7#ifdef SYCL_LANGUAGE_VERSION
8#include <oneapi/dpl/async>
11#include <dr/detail/sycl_utils.hpp>
13namespace dr::mp::__detail {
15namespace detail = dr::__detail;
19namespace dr::mp::__detail {
21void local_inclusive_scan(
auto policy,
auto in,
auto out,
auto binary_op,
22 auto init, std::size_t seg_index) {
26 if (init && seg_index == 0) {
27 std::inclusive_scan(policy, in_begin_direct, in_end_direct,
28 out_begin_direct, binary_op, init.value());
30 std::inclusive_scan(policy, in_begin_direct, in_end_direct,
31 out_begin_direct, binary_op);
35void local_exclusive_scan(
auto policy,
auto in,
auto out,
auto binary_op,
36 auto init, std::size_t seg_index) {
42 assert(rng::size(in) > 1);
43 assert(rng::size(out) > 1);
46 std::inclusive_scan(policy, in_begin_direct, in_end_direct,
47 out_begin_direct, binary_op);
49 assert(init.has_value());
50 std::exclusive_scan(policy, in_begin_direct, in_end_direct,
51 out_begin_direct, init.value(), binary_op);
57 typename U = rng::range_value_t<R>>
58auto inclusive_exclusive_scan_impl_(R &&r, O &&d_first, BinaryOp &&binary_op,
59 std::optional<U> init = {}) {
61 assert(aligned(r, d_first));
63 bool use_sycl = mp::use_sycl();
64 auto comm = default_comm();
67 if (rng::size(r) <= comm.size() * (comm.size() - 1) + 1) {
68 std::vector<value_type> vec_in(rng::size(r));
69 std::vector<value_type> vec_out(rng::size(r));
70 mp::copy(0, r, vec_in.begin());
72 if (comm.rank() == 0) {
73 if constexpr (is_exclusive) {
74 assert(init.has_value());
78 init.value(), binary_op);
80 if (init.has_value()) {
84 binary_op, init.value());
93 mp::copy(0, vec_out, d_first);
94 return d_first + rng::size(r);
97 auto rank = comm.rank();
98 auto local_segs = rng::views::zip(local_segments(r), local_segments(d_first));
100 rng::views::zip(dr::ranges::segments(r), dr::ranges::segments(d_first));
101 std::size_t num_segs = std::size_t(rng::size(dr::ranges::segments(r)));
104 std::size_t seg_index = 0;
105 for (
auto global_seg : global_segs) {
106 auto [global_in, global_out] = global_seg;
108 if (dr::ranges::rank(global_in) == rank) {
109 auto local_in = dr::ranges::__detail::local(global_in);
110 auto local_out = dr::ranges::__detail::local(global_out);
112#ifdef SYCL_LANGUAGE_VERSION
113 if constexpr (is_exclusive) {
114 local_exclusive_scan(dpl_policy(), local_in, local_out, binary_op,
117 local_inclusive_scan(dpl_policy(), local_in, local_out, binary_op,
124 if constexpr (is_exclusive) {
125 local_exclusive_scan(std::execution::par_unseq, local_in, local_out,
126 binary_op, init, seg_index);
128 local_inclusive_scan(std::execution::par_unseq, local_in, local_out,
129 binary_op, init, seg_index);
138 auto win = root_win();
139 for (
auto global_seg : global_segs) {
141 if (seg_index == num_segs - 1) {
145 auto [global_in, global_out] = global_seg;
146 if (dr::ranges::rank(global_in) == rank) {
147 auto local_out = dr::ranges::__detail::local(global_out);
148 auto local_in = dr::ranges::__detail::local(global_in);
149 rng::range_value_t<R> back;
150 if constexpr (is_exclusive) {
152 auto ret = sycl_get(local_out.back(), local_in.back());
153 back = binary_op(ret.first, ret.second);
155 back = binary_op(local_out.back(), local_in.back());
158 back = use_sycl ? sycl_get(local_out.back()) : local_out.back();
161 win.put(back, 0, seg_index);
170 value_type *partials = win.local_data<value_type>();
171 std::inclusive_scan(partials, partials + num_segs, partials, binary_op);
177 for (
auto global_seg : global_segs) {
179 auto [global_in, global_out] = global_seg;
181 auto offset = win.get<value_type>(0, seg_index - 1);
182 auto rebase = [offset, binary_op](
auto &v) { v = binary_op(v, offset); };
183 if (dr::ranges::rank(global_in) == rank) {
184 auto local_in = dr::ranges::__detail::local(global_in);
185 auto local_out = rng::views::take(
186 dr::ranges::__detail::local(global_out), rng::size(local_in));
187 auto local_out_adj = [use_sycl](
auto local_out,
auto offset) {
188 bool _use_sycl = use_sycl;
189 if constexpr (is_exclusive) {
190 auto local_out_begin_direct =
193 sycl_copy(&offset, &(*local_out_begin_direct));
195 *local_out_begin_direct = offset;
197 return local_out | rng::views::drop(1);
201 }(local_out, offset);
203#ifdef SYCL_LANGUAGE_VERSION
204 auto wrap_rebase = [rebase, base = rng::begin(local_out_adj)](
205 auto idx) { rebase(base[idx]); };
206 detail::parallel_for(dr::mp::sycl_queue(),
207 sycl::range<>(rng::distance(local_out_adj)),
214 std::for_each(std::execution::par_unseq, local_out_adj.begin(),
215 local_out_adj.end(), rebase);
224 return d_first + rng::size(r);
Definition: onedpl_direct_iterator.hpp:15
Definition: concepts.hpp:42
Definition: concepts.hpp:31