Distributed Ranges
Loading...
Searching...
No Matches
inclusive_scan.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <optional>
8
9#include <sycl/sycl.hpp>
10
11#include <oneapi/dpl/execution>
12#include <oneapi/dpl/numeric>
13
14#include <oneapi/dpl/async>
15
16#include <dr/concepts/concepts.hpp>
17#include <dr/detail/onedpl_direct_iterator.hpp>
18#include <dr/sp/algorithms/execution_policy.hpp>
19#include <dr/sp/allocators.hpp>
20#include <dr/sp/detail.hpp>
21#include <dr/sp/init.hpp>
22#include <dr/sp/vector.hpp>
23#include <dr/sp/views/views.hpp>
24
25namespace dr::sp {
26
27template <typename ExecutionPolicy, dr::distributed_contiguous_range R,
28 dr::distributed_contiguous_range O, typename BinaryOp,
29 typename U = rng::range_value_t<R>>
30void inclusive_scan_impl_(ExecutionPolicy &&policy, R &&r, O &&o,
31 BinaryOp &&binary_op, std::optional<U> init = {}) {
32 using T = rng::range_value_t<O>;
33
34 static_assert(
35 std::is_same_v<std::remove_cvref_t<ExecutionPolicy>, device_policy>);
36
37 auto zipped_view = dr::sp::views::zip(r, o);
38 auto zipped_segments = zipped_view.zipped_segments();
39
40 if constexpr (std::is_same_v<std::remove_cvref_t<ExecutionPolicy>,
41 device_policy>) {
42
43 std::vector<sycl::event> events;
44
45 auto root = dr::sp::devices()[0];
46 dr::sp::device_allocator<T> allocator(dr::sp::context(), root);
48 std::size_t(zipped_segments.size()), allocator);
49
50 std::size_t segment_id = 0;
51 for (auto &&segs : zipped_segments) {
52 auto &&[in_segment, out_segment] = segs;
53
54 auto &&q = __detail::queue(dr::ranges::rank(in_segment));
55 auto &&local_policy = __detail::dpl_policy(dr::ranges::rank(in_segment));
56
57 auto dist = rng::distance(in_segment);
58 assert(dist > 0);
59
60 auto first = rng::begin(in_segment);
61 auto last = rng::end(in_segment);
62 auto d_first = rng::begin(out_segment);
63
64 sycl::event event;
65
66 if (segment_id == 0 && init.has_value()) {
67 event = oneapi::dpl::experimental::inclusive_scan_async(
68 local_policy, dr::__detail::direct_iterator(first),
70 dr::__detail::direct_iterator(d_first), binary_op, init.value());
71 } else {
72 event = oneapi::dpl::experimental::inclusive_scan_async(
73 local_policy, dr::__detail::direct_iterator(first),
75 dr::__detail::direct_iterator(d_first), binary_op);
76 }
77
78 auto dst_iter = dr::ranges::local(partial_sums).data() + segment_id;
79
80 auto src_iter = dr::ranges::local(out_segment).data();
81 rng::advance(src_iter, dist - 1);
82
83 auto e = q.submit([&](auto &&h) {
84 h.depends_on(event);
85 h.single_task([=]() {
86 rng::range_value_t<O> value = *src_iter;
87 *dst_iter = value;
88 });
89 });
90
91 events.push_back(e);
92
93 segment_id++;
94 }
95
96 __detail::wait(events);
97 events.clear();
98
99 auto &&local_policy = __detail::dpl_policy(0);
100
101 auto first = dr::ranges::local(partial_sums).data();
102 auto last = first + partial_sums.size();
103
104 oneapi::dpl::experimental::inclusive_scan_async(local_policy, first, last,
105 first, binary_op)
106 .wait();
107
108 std::size_t idx = 0;
109 for (auto &&segs : zipped_segments) {
110 auto &&[in_segment, out_segment] = segs;
111
112 if (idx > 0) {
113 auto &&q = __detail::queue(dr::ranges::rank(out_segment));
114
115 auto first = rng::begin(out_segment);
116 dr::__detail::direct_iterator d_first(first);
117
118 auto d_sum =
119 dr::ranges::__detail::local(partial_sums).begin() + idx - 1;
120
121 sycl::event e = dr::__detail::parallel_for(
122 q, sycl::range<>(rng::distance(out_segment)),
123 [=](auto idx) { d_first[idx] = binary_op(d_first[idx], *d_sum); });
124
125 events.push_back(e);
126 }
127 idx++;
128 }
129
130 __detail::wait(events);
131
132 } else {
133 assert(false);
134 }
135}
136
137template <typename ExecutionPolicy, dr::distributed_contiguous_range R,
138 dr::distributed_contiguous_range O, typename BinaryOp, typename T>
139void inclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o,
140 BinaryOp &&binary_op, T init) {
141 inclusive_scan_impl_(std::forward<ExecutionPolicy>(policy),
142 std::forward<R>(r), std::forward<O>(o),
143 std::forward<BinaryOp>(binary_op), std::optional(init));
144}
145
146template <typename ExecutionPolicy, dr::distributed_contiguous_range R,
147 dr::distributed_contiguous_range O, typename BinaryOp>
148void inclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o,
149 BinaryOp &&binary_op) {
150 inclusive_scan_impl_(std::forward<ExecutionPolicy>(policy),
151 std::forward<R>(r), std::forward<O>(o),
152 std::forward<BinaryOp>(binary_op));
153}
154
155template <typename ExecutionPolicy, dr::distributed_contiguous_range R,
157void inclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o) {
158 inclusive_scan(std::forward<ExecutionPolicy>(policy), std::forward<R>(r),
159 std::forward<O>(o), std::plus<rng::range_value_t<R>>());
160}
161
162// Distributed iterator versions
163
164template <typename ExecutionPolicy, dr::distributed_iterator Iter,
165 dr::distributed_iterator OutputIter, typename BinaryOp, typename T>
166OutputIter inclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
167 OutputIter d_first, BinaryOp &&binary_op, T init) {
168
169 auto dist = rng::distance(first, last);
170 auto d_last = d_first;
171 rng::advance(d_last, dist);
172 inclusive_scan(std::forward<ExecutionPolicy>(policy),
173 rng::subrange(first, last), rng::subrange(d_first, d_last),
174 std::forward<BinaryOp>(binary_op), init);
175
176 return d_last;
177}
178
179template <typename ExecutionPolicy, dr::distributed_iterator Iter,
180 dr::distributed_iterator OutputIter, typename BinaryOp>
181OutputIter inclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
182 OutputIter d_first, BinaryOp &&binary_op) {
183
184 auto dist = rng::distance(first, last);
185 auto d_last = d_first;
186 rng::advance(d_last, dist);
187 inclusive_scan(std::forward<ExecutionPolicy>(policy),
188 rng::subrange(first, last), rng::subrange(d_first, d_last),
189 std::forward<BinaryOp>(binary_op));
190
191 return d_last;
192}
193
194template <typename ExecutionPolicy, dr::distributed_iterator Iter,
195 dr::distributed_iterator OutputIter>
196OutputIter inclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
197 OutputIter d_first) {
198 auto dist = rng::distance(first, last);
199 auto d_last = d_first;
200 rng::advance(d_last, dist);
201 inclusive_scan(std::forward<ExecutionPolicy>(policy),
202 rng::subrange(first, last), rng::subrange(d_first, d_last));
203
204 return d_last;
205}
206
207// Execution policy-less versions
208
211void inclusive_scan(R &&r, O &&o) {
212 inclusive_scan(dr::sp::par_unseq, std::forward<R>(r), std::forward<O>(o));
213}
214
216 dr::distributed_contiguous_range O, typename BinaryOp>
217void inclusive_scan(R &&r, O &&o, BinaryOp &&binary_op) {
218 inclusive_scan(dr::sp::par_unseq, std::forward<R>(r), std::forward<O>(o),
219 std::forward<BinaryOp>(binary_op));
220}
221
223 dr::distributed_contiguous_range O, typename BinaryOp, typename T>
224void inclusive_scan(R &&r, O &&o, BinaryOp &&binary_op, T init) {
225 inclusive_scan(dr::sp::par_unseq, std::forward<R>(r), std::forward<O>(o),
226 std::forward<BinaryOp>(binary_op), init);
227}
228
229// Distributed iterator versions
230
231template <dr::distributed_iterator Iter, dr::distributed_iterator OutputIter>
232OutputIter inclusive_scan(Iter first, Iter last, OutputIter d_first) {
233 return inclusive_scan(dr::sp::par_unseq, first, last, d_first);
234}
235
236template <dr::distributed_iterator Iter, dr::distributed_iterator OutputIter,
237 typename BinaryOp>
238OutputIter inclusive_scan(Iter first, Iter last, OutputIter d_first,
239 BinaryOp &&binary_op) {
240 return inclusive_scan(dr::sp::par_unseq, first, last, d_first,
241 std::forward<BinaryOp>(binary_op));
242}
243
244template <dr::distributed_iterator Iter, dr::distributed_iterator OutputIter,
245 typename BinaryOp, typename T>
246OutputIter inclusive_scan(Iter first, Iter last, OutputIter d_first,
247 BinaryOp &&binary_op, T init) {
248 return inclusive_scan(dr::sp::par_unseq, first, last, d_first,
249 std::forward<BinaryOp>(binary_op), init);
250}
251
252} // namespace dr::sp
Definition: onedpl_direct_iterator.hpp:15
Definition: allocators.hpp:20
Definition: vector.hpp:14
Definition: concepts.hpp:42
Definition: concepts.hpp:31