Distributed Ranges
Loading...
Searching...
No Matches
inclusive_exclusive_scan_impl.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#ifdef SYCL_LANGUAGE_VERSION
8#include <oneapi/dpl/async>
9#endif
10
11#include <dr/detail/sycl_utils.hpp>
12
13namespace dr::mp::__detail {
14
15namespace detail = dr::__detail;
16
17}
18
19namespace dr::mp::__detail {
20
21void local_inclusive_scan(auto policy, auto in, auto out, auto binary_op,
22 auto init, std::size_t seg_index) {
23 auto in_begin_direct = detail::direct_iterator(in.begin());
24 auto in_end_direct = detail::direct_iterator(in.end());
25 auto out_begin_direct = detail::direct_iterator(out.begin());
26 if (init && seg_index == 0) {
27 std::inclusive_scan(policy, in_begin_direct, in_end_direct,
28 out_begin_direct, binary_op, init.value());
29 } else {
30 std::inclusive_scan(policy, in_begin_direct, in_end_direct,
31 out_begin_direct, binary_op);
32 }
33}
34
35void local_exclusive_scan(auto policy, auto in, auto out, auto binary_op,
36 auto init, std::size_t seg_index) {
37 auto in_begin_direct = detail::direct_iterator(in.begin());
38 auto in_end_direct = detail::direct_iterator(in.end());
39 auto out_begin_direct = detail::direct_iterator(out.begin());
40
41 if (seg_index != 0) {
42 assert(rng::size(in) > 1);
43 assert(rng::size(out) > 1);
44 --in_end_direct;
45 ++out_begin_direct;
46 std::inclusive_scan(policy, in_begin_direct, in_end_direct,
47 out_begin_direct, binary_op);
48 } else {
49 assert(init.has_value());
50 std::exclusive_scan(policy, in_begin_direct, in_end_direct,
51 out_begin_direct, init.value(), binary_op);
52 }
53}
54
55template <bool is_exclusive, dr::distributed_contiguous_range R,
56 dr::distributed_iterator O, typename BinaryOp,
57 typename U = rng::range_value_t<R>>
58auto inclusive_exclusive_scan_impl_(R &&r, O &&d_first, BinaryOp &&binary_op,
59 std::optional<U> init = {}) {
60 using value_type = U;
61 assert(aligned(r, d_first));
62
63 bool use_sycl = mp::use_sycl();
64 auto comm = default_comm();
65
66 // for input vector, which may have segment of size 1, do sequential scan
67 if (rng::size(r) <= comm.size() * (comm.size() - 1) + 1) {
68 std::vector<value_type> vec_in(rng::size(r));
69 std::vector<value_type> vec_out(rng::size(r));
70 mp::copy(0, r, vec_in.begin());
71
72 if (comm.rank() == 0) {
73 if constexpr (is_exclusive) {
74 assert(init.has_value());
75 std::exclusive_scan(detail::direct_iterator(vec_in.begin()),
76 detail::direct_iterator(vec_in.end()),
77 detail::direct_iterator(vec_out.begin()),
78 init.value(), binary_op);
79 } else {
80 if (init.has_value()) {
81 std::inclusive_scan(detail::direct_iterator(vec_in.begin()),
82 detail::direct_iterator(vec_in.end()),
83 detail::direct_iterator(vec_out.begin()),
84 binary_op, init.value());
85 } else {
86 std::inclusive_scan(detail::direct_iterator(vec_in.begin()),
87 detail::direct_iterator(vec_in.end()),
88 detail::direct_iterator(vec_out.begin()),
89 binary_op);
90 }
91 }
92 }
93 mp::copy(0, vec_out, d_first);
94 return d_first + rng::size(r);
95 }
96
97 auto rank = comm.rank();
98 auto local_segs = rng::views::zip(local_segments(r), local_segments(d_first));
99 auto global_segs =
100 rng::views::zip(dr::ranges::segments(r), dr::ranges::segments(d_first));
101 std::size_t num_segs = std::size_t(rng::size(dr::ranges::segments(r)));
102
103 // Pass 1 local inclusive scan
104 std::size_t seg_index = 0;
105 for (auto global_seg : global_segs) {
106 auto [global_in, global_out] = global_seg;
107
108 if (dr::ranges::rank(global_in) == rank) {
109 auto local_in = dr::ranges::__detail::local(global_in);
110 auto local_out = dr::ranges::__detail::local(global_out);
111 if (use_sycl) {
112#ifdef SYCL_LANGUAGE_VERSION
113 if constexpr (is_exclusive) {
114 local_exclusive_scan(dpl_policy(), local_in, local_out, binary_op,
115 init, seg_index);
116 } else {
117 local_inclusive_scan(dpl_policy(), local_in, local_out, binary_op,
118 init, seg_index);
119 }
120#else
121 assert(false);
122#endif
123 } else {
124 if constexpr (is_exclusive) {
125 local_exclusive_scan(std::execution::par_unseq, local_in, local_out,
126 binary_op, init, seg_index);
127 } else {
128 local_inclusive_scan(std::execution::par_unseq, local_in, local_out,
129 binary_op, init, seg_index);
130 }
131 }
132 }
133
134 seg_index++;
135 }
136 // Pass 2 put partial sums on root
137 seg_index = 0;
138 auto win = root_win();
139 for (auto global_seg : global_segs) {
140 // Do not need last segment
141 if (seg_index == num_segs - 1) {
142 break;
143 }
144
145 auto [global_in, global_out] = global_seg;
146 if (dr::ranges::rank(global_in) == rank) {
147 auto local_out = dr::ranges::__detail::local(global_out);
148 auto local_in = dr::ranges::__detail::local(global_in);
149 rng::range_value_t<R> back;
150 if constexpr (is_exclusive) {
151 if (use_sycl) {
152 auto ret = sycl_get(local_out.back(), local_in.back());
153 back = binary_op(ret.first, ret.second);
154 } else {
155 back = binary_op(local_out.back(), local_in.back());
156 }
157 } else {
158 back = use_sycl ? sycl_get(local_out.back()) : local_out.back();
159 }
160
161 win.put(back, 0, seg_index);
162 }
163
164 seg_index++;
165 }
166 win.fence();
167
168 // Pass 3: scan of partial sums on root
169 if (rank == 0) {
170 value_type *partials = win.local_data<value_type>();
171 std::inclusive_scan(partials, partials + num_segs, partials, binary_op);
172 }
173 barrier();
174
175 // Pass 4: rebase
176 seg_index = 0;
177 for (auto global_seg : global_segs) {
178 if (seg_index > 0) {
179 auto [global_in, global_out] = global_seg;
180
181 auto offset = win.get<value_type>(0, seg_index - 1);
182 auto rebase = [offset, binary_op](auto &v) { v = binary_op(v, offset); };
183 if (dr::ranges::rank(global_in) == rank) {
184 auto local_in = dr::ranges::__detail::local(global_in);
185 auto local_out = rng::views::take(
186 dr::ranges::__detail::local(global_out), rng::size(local_in));
187 auto local_out_adj = [use_sycl](auto local_out, auto offset) {
188 bool _use_sycl = use_sycl;
189 if constexpr (is_exclusive) {
190 auto local_out_begin_direct =
191 detail::direct_iterator(local_out.begin());
192 if (_use_sycl) {
193 sycl_copy(&offset, &(*local_out_begin_direct));
194 } else {
195 *local_out_begin_direct = offset;
196 }
197 return local_out | rng::views::drop(1);
198 } else {
199 return local_out;
200 }
201 }(local_out, offset);
202 if (use_sycl) {
203#ifdef SYCL_LANGUAGE_VERSION
204 auto wrap_rebase = [rebase, base = rng::begin(local_out_adj)](
205 auto idx) { rebase(base[idx]); };
206 detail::parallel_for(dr::mp::sycl_queue(),
207 sycl::range<>(rng::distance(local_out_adj)),
208 wrap_rebase)
209 .wait();
210#else
211 assert(false);
212#endif
213 } else {
214 std::for_each(std::execution::par_unseq, local_out_adj.begin(),
215 local_out_adj.end(), rebase);
216 }
217 // dr::drlog.debug("rebase after: {}\n", local_out_adj);
218 }
219 }
220 seg_index++;
221 }
222
223 barrier();
224 return d_first + rng::size(r);
225}
226} // namespace dr::mp::__detail
Definition: onedpl_direct_iterator.hpp:15
Definition: concepts.hpp:42
Definition: concepts.hpp:31