7#include <sycl/sycl.hpp>
9#include <oneapi/dpl/execution>
10#include <oneapi/dpl/numeric>
12#include <oneapi/dpl/async>
14#include <dr/concepts/concepts.hpp>
15#include <dr/detail/onedpl_direct_iterator.hpp>
16#include <dr/sp/algorithms/execution_policy.hpp>
17#include <dr/sp/allocators.hpp>
18#include <dr/sp/detail.hpp>
19#include <dr/sp/init.hpp>
20#include <dr/sp/vector.hpp>
21#include <dr/sp/views/views.hpp>
// Implementation of exclusive scan over the zipped segments of input range
// `r` and output range `o`, combining elements with `binary_op`.
// NOTE(review): this listing is missing lines (the template header, several
// call arguments, loop tails and closing braces); the comments below describe
// only what the visible lines show. `init` is presumably used to seed the
// first segment, but that assignment is not visible here -- confirm.
27void exclusive_scan_impl_(ExecutionPolicy &&policy, R &&r, O &&o, U init,
28 BinaryOp &&binary_op) {
// Value type of the output range.
29 using T = rng::range_value_t<O>;
// Tail of a compile-time check that the (cvref-stripped) policy type is
// device_policy; presumably a static_assert whose first line is not shown.
32 std::is_same_v<std::remove_cvref_t<ExecutionPolicy>, device_policy>);
// Pair up the input and output ranges and obtain their zipped segments.
34 auto zipped_view = dr::sp::views::zip(r, o);
35 auto zipped_segments = zipped_view.zipped_segments();
// Compile-time branch on the policy type (the rest of the condition is cut).
37 if constexpr (std::is_same_v<std::remove_cvref_t<ExecutionPolicy>,
// USM device buffer with one U slot per segment, allocated on devices()[0].
40 U *d_inits = sycl::malloc_device<U>(rng::size(zipped_segments),
41 sp::devices()[0], sp::context());
43 std::vector<sycl::event> events;
// Index of the segment currently being processed.
45 std::size_t segment_id = 0;
// Pass 1: for every segment, copy the last element of the local input
// segment into d_inits[segment_id] with a single_task on the queue selected
// by that segment's rank.
46 for (
auto &&segs : zipped_segments) {
47 auto &&[in_segment, out_segment] = segs;
49 auto last_element = rng::prev(rng::end(__detail::local(in_segment)));
50 auto dest = d_inits + segment_id;
52 auto &&q = __detail::queue(dr::ranges::rank(in_segment));
54 auto e = q.single_task([=] { *dest = *last_element; });
// Block until the collected events have completed.
59 __detail::wait(events);
// Host-side vector with one entry per segment.
62 std::vector<U> inits(rng::size(zipped_segments));
// Copy the device values into inits starting at index 1.
// NOTE(review): the source span is inits.size() elements long while the
// destination starts at inits.data() + 1, so the final element would land
// one past the end of inits unless sp::copy limits it -- confirm.
64 sp::copy(d_inits, d_inits + inits.size(), inits.data() + 1);
// Release the temporary device buffer.
66 sycl::free(d_inits, sp::context());
70 auto root = dr::sp::devices()[0];
// Tail of a construction sized by the number of segments and using
// `allocator`; presumably this declares `partial_sums` (referenced below),
// but the declaring line is not shown.
73 std::size_t(zipped_segments.size()), allocator);
// Pass 2: per segment, launch an asynchronous oneDPL exclusive scan (its
// arguments are cut from this listing) and then read the last element of the
// local output segment inside a single_task; dst_iter points at this
// segment's slot in partial_sums.
76 for (
auto &&segs : zipped_segments) {
77 auto &&[in_segment, out_segment] = segs;
79 auto &&q = __detail::queue(dr::ranges::rank(in_segment));
80 auto &&local_policy = __detail::dpl_policy(dr::ranges::rank(in_segment));
// Number of elements in this input segment.
82 auto dist = rng::distance(in_segment);
85 auto first = rng::begin(in_segment);
86 auto last = rng::end(in_segment);
87 auto d_first = rng::begin(out_segment);
// Per-segment initial value taken from the host-side inits vector.
89 auto init = inits[segment_id];
91 auto event = oneapi::dpl::experimental::exclusive_scan_async(
96 auto dst_iter = dr::ranges::local(partial_sums).data() + segment_id;
// Advance to the last element (index dist - 1) of the local output segment.
98 auto src_iter = dr::ranges::local(out_segment).data();
99 rng::advance(src_iter, dist - 1);
101 auto e = q.submit([&](
auto &&h) {
103 h.single_task([=]() {
104 rng::range_value_t<O> value = *src_iter;
// Block until the collected events have completed.
114 __detail::wait(events);
// Pass 3: asynchronous inclusive scan across partial_sums using the policy
// for device index 0 (remaining arguments are cut from this listing).
117 auto &&local_policy = __detail::dpl_policy(0);
119 auto first = dr::ranges::local(partial_sums).data();
120 auto last = first + partial_sums.size();
122 oneapi::dpl::experimental::inclusive_scan_async(local_policy, first, last,
// Pass 4: per segment, combine every output element with the partial sum at
// position idx - 1 via binary_op, using a parallel_for over the segment.
127 for (
auto &&segs : zipped_segments) {
128 auto &&[in_segment, out_segment] = segs;
131 auto &&q = __detail::queue(dr::ranges::rank(out_segment));
133 auto first = rng::begin(out_segment);
// Tail of the expression selecting the partial sum at idx - 1; presumably it
// initializes `d_sum` (used below), but the declaring line is not shown.
137 dr::ranges::__detail::local(partial_sums).begin() + idx - 1;
139 sycl::event e = dr::__detail::parallel_for(
140 q, sycl::range<>(rng::distance(out_segment)),
141 [=](
auto idx) { d_first[idx] = binary_op(d_first[idx], *d_sum); });
// Block until the collected events have completed.
148 __detail::wait(events);
// Range overload taking an explicit execution policy and binary operation.
// Perfect-forwards policy, r, o and binary_op to exclusive_scan_impl_ and
// passes init by value.
// NOTE(review): the template header and closing brace are not in this listing.
159void exclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o, T init,
160 BinaryOp &&binary_op) {
161 exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy),
162 std::forward<R>(r), std::forward<O>(o), init,
163 std::forward<BinaryOp>(binary_op));
// Range overload taking an explicit execution policy but no binary operation.
// Forwards policy, r and o to exclusive_scan_impl_ together with init.
// NOTE(review): the final argument of the call is cut from this listing;
// presumably a default operation such as std::plus<>{} -- confirm.
168void exclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o, T init) {
169 exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy),
170 std::forward<R>(r), std::forward<O>(o), init,
// Range overload without an execution policy: calls exclusive_scan_impl_
// with dr::sp::par_unseq, forwarding r, o and binary_op and passing init by
// value.
176void exclusive_scan(R &&r, O &&o, T init, BinaryOp &&binary_op) {
177 exclusive_scan_impl_(dr::sp::par_unseq, std::forward<R>(r),
178 std::forward<O>(o), init,
179 std::forward<BinaryOp>(binary_op));
// Range overload with neither policy nor operation: calls
// exclusive_scan_impl_ with dr::sp::par_unseq and std::plus<>{}.
184void exclusive_scan(R &&r, O &&o, T init) {
185 exclusive_scan_impl_(dr::sp::par_unseq, std::forward<R>(r),
186 std::forward<O>(o), init, std::plus<>{});
// Iterator overload taking an explicit execution policy and binary operation.
// Scans [first, last) into the output range that starts at d_first.
193void exclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
194 OutputIter d_first, T init, BinaryOp &&binary_op) {
// The output range has the same length as the input range: d_last is d_first
// advanced by the distance between first and last.
195 auto dist = rng::distance(first, last);
196 auto d_last = d_first;
197 rng::advance(d_last, dist);
// Wrap both iterator pairs in subranges and delegate to the implementation.
198 exclusive_scan_impl_(
199 std::forward<ExecutionPolicy>(policy), rng::subrange(first, last),
200 rng::subrange(d_first, d_last), init, std::forward<BinaryOp>(binary_op));
// Iterator overload taking an explicit execution policy but no operation:
// delegates to the exclusive_scan overload that accepts a binary operation,
// supplying std::plus<>{}.
205void exclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
206 OutputIter d_first, T init) {
207 exclusive_scan(std::forward<ExecutionPolicy>(policy), first, last, d_first,
208 init, std::plus<>{});
// Iterator overload without an execution policy: delegates to the
// policy-taking overload with dr::sp::par_unseq, forwarding binary_op.
// NOTE(review): only the tail of the template parameter list is visible here.
212 typename T,
typename BinaryOp>
213void exclusive_scan(Iter first, Iter last, OutputIter d_first, T init,
214 BinaryOp &&binary_op) {
215 exclusive_scan(dr::sp::par_unseq, first, last, d_first, init,
216 std::forward<BinaryOp>(binary_op));
// Iterator overload with neither policy nor operation: delegates to the
// policy-taking overload with dr::sp::par_unseq and no explicit operation.
221void exclusive_scan(Iter first, Iter last, OutputIter d_first, T init) {
222 exclusive_scan(dr::sp::par_unseq, first, last, d_first, init);
Definition: onedpl_direct_iterator.hpp:15
Definition: allocators.hpp:20
Definition: vector.hpp:14
Definition: concepts.hpp:42
Definition: concepts.hpp:31