9#include <sycl/sycl.hpp>
11#include <oneapi/dpl/execution>
12#include <oneapi/dpl/numeric>
14#include <oneapi/dpl/async>
16#include <dr/concepts/concepts.hpp>
17#include <dr/detail/onedpl_direct_iterator.hpp>
18#include <dr/sp/algorithms/execution_policy.hpp>
19#include <dr/sp/allocators.hpp>
20#include <dr/sp/detail.hpp>
21#include <dr/sp/init.hpp>
22#include <dr/sp/vector.hpp>
23#include <dr/sp/views/views.hpp>
// inclusive_scan_impl_: common implementation that the policy-taking range
// overloads of inclusive_scan forward to.
//
// Parameters (from the visible signature):
//   policy    - execution policy; the visible checks compare its
//               cv/ref-stripped type against device_policy.
//   r         - input range.
//   o         - output range; T is its value type.
//   binary_op - operation used to combine values.
//   init      - optional initial value (empty by default); it is tested only
//               while processing the first segment.
//
// Visible flow:
//   1. r and o are zipped and walked segment by segment. For each segment an
//      asynchronous oneDPL inclusive scan is started and a command group is
//      submitted on that segment's queue that reads the last element of the
//      output segment.
//   2. After waiting on the collected events, an inclusive scan is run over
//      the local data of partial_sums.
//   3. A second pass over the segments combines every output element with a
//      value taken from partial_sums via binary_op.
//
// NOTE(review): several original lines are absent from this listing (e.g.
// the start of the template header, the declarations of the allocator and
// partial_sums, the arguments of the async scan calls). The comments below
// describe only what the visible lines show - confirm against the full file.
29 typename U = rng::range_value_t<R>>
30void inclusive_scan_impl_(ExecutionPolicy &&policy, R &&r, O &&o,
31 BinaryOp &&binary_op, std::optional<U> init = {}) {
// T: value type of the output range.
32 using T = rng::range_value_t<O>;
// Tail of a check that the policy type is device_policy (presumably a
// static_assert; its opening line is not visible here).
35 std::is_same_v<std::remove_cvref_t<ExecutionPolicy>, device_policy>);
// Pair up input and output, then take the per-segment (input, output) pairs.
37 auto zipped_view = dr::sp::views::zip(r, o);
38 auto zipped_segments = zipped_view.zipped_segments();
40 if constexpr (std::is_same_v<std::remove_cvref_t<ExecutionPolicy>,
// Events collected here are waited on via __detail::wait(events) further down.
43 std::vector<sycl::event> events;
// The first entry of dr::sp::devices() is used as the "root" device.
45 auto root = dr::sp::devices()[0];
// Tail of a construction sized to one element per segment - presumably the
// partial_sums container used below (its declaration line is not visible).
48 std::size_t(zipped_segments.size()), allocator);
// Index of the segment currently being processed in the first pass.
50 std::size_t segment_id = 0;
// First pass: one iteration per (input segment, output segment) pair.
51 for (
auto &&segs : zipped_segments) {
52 auto &&[in_segment, out_segment] = segs;
// Queue and oneDPL policy are looked up by the rank of the input segment.
54 auto &&q = __detail::queue(dr::ranges::rank(in_segment));
55 auto &&local_policy = __detail::dpl_policy(dr::ranges::rank(in_segment));
// Number of elements in this input segment.
57 auto dist = rng::distance(in_segment);
60 auto first = rng::begin(in_segment);
61 auto last = rng::end(in_segment);
62 auto d_first = rng::begin(out_segment);
// Only the first segment takes the branch that checks for an initial value;
// both visible branches start an asynchronous inclusive scan (arguments are
// not visible in this listing).
66 if (segment_id == 0 && init.has_value()) {
67 event = oneapi::dpl::experimental::inclusive_scan_async(
72 event = oneapi::dpl::experimental::inclusive_scan_async(
// dst_iter: slot number segment_id in the local view of partial_sums.
78 auto dst_iter = dr::ranges::local(partial_sums).data() + segment_id;
// src_iter: last element of this output segment (begin + dist - 1).
// NOTE(review): this assumes dist >= 1; any guard for empty segments is not
// visible here.
80 auto src_iter = dr::ranges::local(out_segment).data();
81 rng::advance(src_iter, dist - 1);
// Command group submitted on the segment's queue; the visible part reads the
// segment's last output value. Presumably it is then stored through dst_iter
// (the store is not visible) - TODO confirm.
83 auto e = q.submit([&](
auto &&h) {
86 rng::range_value_t<O> value = *src_iter;
// Block until every event gathered so far has completed.
96 __detail::wait(events);
// Scan over partial_sums' local data using the policy for index 0.
99 auto &&local_policy = __detail::dpl_policy(0);
101 auto first = dr::ranges::local(partial_sums).data();
102 auto last = first + partial_sums.size();
104 oneapi::dpl::experimental::inclusive_scan_async(local_policy, first, last,
// Second pass over the same segment pairs.
109 for (
auto &&segs : zipped_segments) {
110 auto &&[in_segment, out_segment] = segs;
// This time the queue is chosen by the rank of the output segment.
113 auto &&q = __detail::queue(dr::ranges::rank(out_segment));
115 auto first = rng::begin(out_segment);
// Iterator to element idx - 1 of partial_sums (presumably bound to d_sum,
// which the kernel below dereferences; the left-hand side is not visible).
// NOTE(review): idx - 1 is only valid for idx > 0 - presumably guarded by a
// line not shown here.
119 dr::ranges::__detail::local(partial_sums).begin() + idx - 1;
// One work-item per element of the output segment: combine the element with
// *d_sum using binary_op. The lambda's idx parameter is the work-item index
// and shadows the outer idx.
121 sycl::event e = dr::__detail::parallel_for(
122 q, sycl::range<>(rng::distance(out_segment)),
123 [=](
auto idx) { d_first[idx] = binary_op(d_first[idx], *d_sum); });
// Wait for the events gathered during the second pass.
130 __detail::wait(events);
// Range overload: inclusive scan of r into o with an explicit execution
// policy, binary operation and initial value.
// Perfect-forwards policy, r, o and binary_op to inclusive_scan_impl_ and
// passes init wrapped in a std::optional.
// (The template header of this overload is not visible in this listing.)
139void inclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o,
140 BinaryOp &&binary_op, T init) {
141 inclusive_scan_impl_(std::forward<ExecutionPolicy>(policy),
142 std::forward<R>(r), std::forward<O>(o),
143 std::forward<BinaryOp>(binary_op), std::optional(init));
// Range overload: inclusive scan of r into o with an explicit execution
// policy and binary operation, without an initial value.
// Perfect-forwards every argument to inclusive_scan_impl_; the init
// parameter there is left at its default (an empty optional).
// (The template header of this overload is not visible in this listing.)
148void inclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o,
149 BinaryOp &&binary_op) {
150 inclusive_scan_impl_(std::forward<ExecutionPolicy>(policy),
151 std::forward<R>(r), std::forward<O>(o),
152 std::forward<BinaryOp>(binary_op));
// Range overload with an execution policy but no binary operation.
// Delegates to the overload taking a binary operation, supplying
// std::plus over the value type of the input range R.
// (The template header of this overload is not visible in this listing.)
157void inclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o) {
158 inclusive_scan(std::forward<ExecutionPolicy>(policy), std::forward<R>(r),
159 std::forward<O>(o), std::plus<rng::range_value_t<R>>());
// Iterator overload: inclusive scan of [first, last) written starting at
// d_first, with an explicit execution policy, binary operation and initial
// value.
// The end of the destination is derived by advancing a copy of d_first by
// the length of the input; both iterator pairs are then wrapped in
// rng::subrange and handed to the range overload.
// NOTE(review): the declared return type is OutputIter, but the return
// statement and the template header are not visible in this listing.
166OutputIter inclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
167 OutputIter d_first, BinaryOp &&binary_op, T init) {
// Length of the input sequence.
169 auto dist = rng::distance(first, last);
// d_last = d_first + dist: one past the last element to be written.
170 auto d_last = d_first;
171 rng::advance(d_last, dist);
172 inclusive_scan(std::forward<ExecutionPolicy>(policy),
173 rng::subrange(first, last), rng::subrange(d_first, d_last),
174 std::forward<BinaryOp>(binary_op), init);
// Iterator overload: inclusive scan of [first, last) written starting at
// d_first, with an explicit execution policy and binary operation and no
// initial value.
// Builds the destination end by advancing a copy of d_first by the input
// length, then calls the range overload on the two subranges.
// NOTE(review): the declared return type is OutputIter, but the return
// statement and the template header are not visible in this listing.
181OutputIter inclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
182 OutputIter d_first, BinaryOp &&binary_op) {
// Length of the input sequence.
184 auto dist = rng::distance(first, last);
// d_last = d_first + dist: one past the last element to be written.
185 auto d_last = d_first;
186 rng::advance(d_last, dist);
187 inclusive_scan(std::forward<ExecutionPolicy>(policy),
188 rng::subrange(first, last), rng::subrange(d_first, d_last),
189 std::forward<BinaryOp>(binary_op));
// Iterator overload: inclusive scan of [first, last) written starting at
// d_first, with an explicit execution policy and no binary operation or
// initial value.
// Builds the destination end by advancing a copy of d_first by the input
// length, then calls the range overload on the two subranges.
// NOTE(review): the declared return type is OutputIter, but the return
// statement and the template header are not visible in this listing.
196OutputIter inclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
197 OutputIter d_first) {
// Length of the input sequence.
198 auto dist = rng::distance(first, last);
// d_last = d_first + dist: one past the last element to be written.
199 auto d_last = d_first;
200 rng::advance(d_last, dist);
201 inclusive_scan(std::forward<ExecutionPolicy>(policy),
202 rng::subrange(first, last), rng::subrange(d_first, d_last));
// Range overload without an execution policy: inclusive scan of r into o.
// Delegates to the policy-taking overload with dr::sp::par_unseq.
// (The template header of this overload is not visible in this listing.)
211void inclusive_scan(R &&r, O &&o) {
212 inclusive_scan(dr::sp::par_unseq, std::forward<R>(r), std::forward<O>(o));
// Range overload without an execution policy, with a caller-supplied binary
// operation. Delegates to the policy-taking overload with dr::sp::par_unseq.
// (The template header of this overload is not visible in this listing.)
217void inclusive_scan(R &&r, O &&o, BinaryOp &&binary_op) {
218 inclusive_scan(dr::sp::par_unseq, std::forward<R>(r), std::forward<O>(o),
219 std::forward<BinaryOp>(binary_op));
// Range overload without an execution policy, with a caller-supplied binary
// operation and initial value. Delegates to the policy-taking overload with
// dr::sp::par_unseq.
// (The template header of this overload is not visible in this listing.)
224void inclusive_scan(R &&r, O &&o, BinaryOp &&binary_op, T init) {
225 inclusive_scan(dr::sp::par_unseq, std::forward<R>(r), std::forward<O>(o),
226 std::forward<BinaryOp>(binary_op), init);
// Iterator overload without an execution policy: inclusive scan of
// [first, last) written starting at d_first.
// Both iterator types must satisfy dr::distributed_iterator.
// Returns whatever the policy-taking iterator overload returns when invoked
// with dr::sp::par_unseq.
231template <dr::distributed_iterator Iter, dr::distributed_iterator OutputIter>
232OutputIter inclusive_scan(Iter first, Iter last, OutputIter d_first) {
233 return inclusive_scan(dr::sp::par_unseq, first, last, d_first);
// Iterator overload without an execution policy, with a caller-supplied
// binary operation.
// Returns whatever the policy-taking iterator overload returns when invoked
// with dr::sp::par_unseq; binary_op is perfect-forwarded.
// (The template header of this overload is not visible in this listing.)
238OutputIter inclusive_scan(Iter first, Iter last, OutputIter d_first,
239 BinaryOp &&binary_op) {
240 return inclusive_scan(dr::sp::par_unseq, first, last, d_first,
241 std::forward<BinaryOp>(binary_op));
// Iterator overload without an execution policy, with a caller-supplied
// binary operation and initial value.
// Returns whatever the policy-taking iterator overload returns when invoked
// with dr::sp::par_unseq; binary_op is perfect-forwarded and init is passed
// by value.
// (Only the tail of the template header is visible in this listing.)
245 typename BinaryOp,
typename T>
246OutputIter inclusive_scan(Iter first, Iter last, OutputIter d_first,
247 BinaryOp &&binary_op, T init) {
248 return inclusive_scan(dr::sp::par_unseq, first, last, d_first,
249 std::forward<BinaryOp>(binary_op), init);
Definition: onedpl_direct_iterator.hpp:15
Definition: allocators.hpp:20
Definition: vector.hpp:14
Definition: concepts.hpp:42
Definition: concepts.hpp:31