Distributed Ranges
Loading...
Searching...
No Matches
exclusive_scan.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <sycl/sycl.hpp>
8
9#include <oneapi/dpl/execution>
10#include <oneapi/dpl/numeric>
11
12#include <oneapi/dpl/async>
13
14#include <dr/concepts/concepts.hpp>
15#include <dr/detail/onedpl_direct_iterator.hpp>
16#include <dr/sp/algorithms/execution_policy.hpp>
17#include <dr/sp/allocators.hpp>
18#include <dr/sp/detail.hpp>
19#include <dr/sp/init.hpp>
20#include <dr/sp/vector.hpp>
21#include <dr/sp/views/views.hpp>
22
23namespace dr::sp {
24
25template <typename ExecutionPolicy, dr::distributed_contiguous_range R,
26 dr::distributed_contiguous_range O, typename U, typename BinaryOp>
27void exclusive_scan_impl_(ExecutionPolicy &&policy, R &&r, O &&o, U init,
28 BinaryOp &&binary_op) {
29 using T = rng::range_value_t<O>;
30
31 static_assert(
32 std::is_same_v<std::remove_cvref_t<ExecutionPolicy>, device_policy>);
33
34 auto zipped_view = dr::sp::views::zip(r, o);
35 auto zipped_segments = zipped_view.zipped_segments();
36
37 if constexpr (std::is_same_v<std::remove_cvref_t<ExecutionPolicy>,
38 device_policy>) {
39
40 U *d_inits = sycl::malloc_device<U>(rng::size(zipped_segments),
41 sp::devices()[0], sp::context());
42
43 std::vector<sycl::event> events;
44
45 std::size_t segment_id = 0;
46 for (auto &&segs : zipped_segments) {
47 auto &&[in_segment, out_segment] = segs;
48
49 auto last_element = rng::prev(rng::end(__detail::local(in_segment)));
50 auto dest = d_inits + segment_id;
51
52 auto &&q = __detail::queue(dr::ranges::rank(in_segment));
53
54 auto e = q.single_task([=] { *dest = *last_element; });
55 events.push_back(e);
56 segment_id++;
57 }
58
59 __detail::wait(events);
60 events.clear();
61
62 std::vector<U> inits(rng::size(zipped_segments));
63
64 sp::copy(d_inits, d_inits + inits.size(), inits.data() + 1);
65
66 sycl::free(d_inits, sp::context());
67
68 inits[0] = init;
69
70 auto root = dr::sp::devices()[0];
71 dr::sp::device_allocator<T> allocator(dr::sp::context(), root);
73 std::size_t(zipped_segments.size()), allocator);
74
75 segment_id = 0;
76 for (auto &&segs : zipped_segments) {
77 auto &&[in_segment, out_segment] = segs;
78
79 auto &&q = __detail::queue(dr::ranges::rank(in_segment));
80 auto &&local_policy = __detail::dpl_policy(dr::ranges::rank(in_segment));
81
82 auto dist = rng::distance(in_segment);
83 assert(dist > 0);
84
85 auto first = rng::begin(in_segment);
86 auto last = rng::end(in_segment);
87 auto d_first = rng::begin(out_segment);
88
89 auto init = inits[segment_id];
90
91 auto event = oneapi::dpl::experimental::exclusive_scan_async(
92 local_policy, dr::__detail::direct_iterator(first),
94 dr::__detail::direct_iterator(d_first), init, binary_op);
95
96 auto dst_iter = dr::ranges::local(partial_sums).data() + segment_id;
97
98 auto src_iter = dr::ranges::local(out_segment).data();
99 rng::advance(src_iter, dist - 1);
100
101 auto e = q.submit([&](auto &&h) {
102 h.depends_on(event);
103 h.single_task([=]() {
104 rng::range_value_t<O> value = *src_iter;
105 *dst_iter = value;
106 });
107 });
108
109 events.push_back(e);
110
111 segment_id++;
112 }
113
114 __detail::wait(events);
115 events.clear();
116
117 auto &&local_policy = __detail::dpl_policy(0);
118
119 auto first = dr::ranges::local(partial_sums).data();
120 auto last = first + partial_sums.size();
121
122 oneapi::dpl::experimental::inclusive_scan_async(local_policy, first, last,
123 first, binary_op)
124 .wait();
125
126 std::size_t idx = 0;
127 for (auto &&segs : zipped_segments) {
128 auto &&[in_segment, out_segment] = segs;
129
130 if (idx > 0) {
131 auto &&q = __detail::queue(dr::ranges::rank(out_segment));
132
133 auto first = rng::begin(out_segment);
134 dr::__detail::direct_iterator d_first(first);
135
136 auto d_sum =
137 dr::ranges::__detail::local(partial_sums).begin() + idx - 1;
138
139 sycl::event e = dr::__detail::parallel_for(
140 q, sycl::range<>(rng::distance(out_segment)),
141 [=](auto idx) { d_first[idx] = binary_op(d_first[idx], *d_sum); });
142
143 events.push_back(e);
144 }
145 idx++;
146 }
147
148 __detail::wait(events);
149
150 } else {
151 assert(false);
152 }
153}
154
155// Ranges versions
156
157template <typename ExecutionPolicy, dr::distributed_contiguous_range R,
158 dr::distributed_contiguous_range O, typename T, typename BinaryOp>
159void exclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o, T init,
160 BinaryOp &&binary_op) {
161 exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy),
162 std::forward<R>(r), std::forward<O>(o), init,
163 std::forward<BinaryOp>(binary_op));
164}
165
166template <typename ExecutionPolicy, dr::distributed_contiguous_range R,
168void exclusive_scan(ExecutionPolicy &&policy, R &&r, O &&o, T init) {
169 exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy),
170 std::forward<R>(r), std::forward<O>(o), init,
171 std::plus<>{});
172}
173
175 dr::distributed_contiguous_range O, typename T, typename BinaryOp>
176void exclusive_scan(R &&r, O &&o, T init, BinaryOp &&binary_op) {
177 exclusive_scan_impl_(dr::sp::par_unseq, std::forward<R>(r),
178 std::forward<O>(o), init,
179 std::forward<BinaryOp>(binary_op));
180}
181
184void exclusive_scan(R &&r, O &&o, T init) {
185 exclusive_scan_impl_(dr::sp::par_unseq, std::forward<R>(r),
186 std::forward<O>(o), init, std::plus<>{});
187}
188
189// Iterator versions
190
191template <typename ExecutionPolicy, dr::distributed_iterator Iter,
192 dr::distributed_iterator OutputIter, typename T, typename BinaryOp>
193void exclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
194 OutputIter d_first, T init, BinaryOp &&binary_op) {
195 auto dist = rng::distance(first, last);
196 auto d_last = d_first;
197 rng::advance(d_last, dist);
198 exclusive_scan_impl_(
199 std::forward<ExecutionPolicy>(policy), rng::subrange(first, last),
200 rng::subrange(d_first, d_last), init, std::forward<BinaryOp>(binary_op));
201}
202
203template <typename ExecutionPolicy, dr::distributed_iterator Iter,
204 dr::distributed_iterator OutputIter, typename T>
205void exclusive_scan(ExecutionPolicy &&policy, Iter first, Iter last,
206 OutputIter d_first, T init) {
207 exclusive_scan(std::forward<ExecutionPolicy>(policy), first, last, d_first,
208 init, std::plus<>{});
209}
210
211template <dr::distributed_iterator Iter, dr::distributed_iterator OutputIter,
212 typename T, typename BinaryOp>
213void exclusive_scan(Iter first, Iter last, OutputIter d_first, T init,
214 BinaryOp &&binary_op) {
215 exclusive_scan(dr::sp::par_unseq, first, last, d_first, init,
216 std::forward<BinaryOp>(binary_op));
217}
218
219template <dr::distributed_iterator Iter, dr::distributed_iterator OutputIter,
220 typename T>
221void exclusive_scan(Iter first, Iter last, OutputIter d_first, T init) {
222 exclusive_scan(dr::sp::par_unseq, first, last, d_first, init);
223}
224
225} // namespace dr::sp
Definition: onedpl_direct_iterator.hpp:15
Definition: allocators.hpp:20
Definition: vector.hpp:14
Definition: concepts.hpp:42
Definition: concepts.hpp:31