Distributed Ranges
Loading...
Searching...
No Matches
sort.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <oneapi/dpl/execution>
8
9#include <oneapi/dpl/algorithm>
10#include <oneapi/dpl/async>
11
12#include <dr/concepts/concepts.hpp>
13#include <dr/detail/onedpl_direct_iterator.hpp>
14#include <dr/sp/init.hpp>
15
16#include <omp.h>
17#include <sycl/sycl.hpp>
18
19namespace dr::sp {
20
21namespace __detail {
22
23template <typename LocalPolicy, typename InputIt, typename Compare>
24sycl::event sort_async(LocalPolicy &&policy, InputIt first, InputIt last,
25 Compare &&comp) {
26 if (rng::distance(first, last) >= 2) {
27 dr::__detail::direct_iterator d_first(first);
29 return oneapi::dpl::experimental::sort_async(
30 std::forward<LocalPolicy>(policy), d_first, d_last,
31 std::forward<Compare>(comp));
32 } else {
33 return sycl::event{};
34 }
35}
36
37template <typename LocalPolicy, typename InputIt1, typename InputIt2,
38 typename OutputIt, typename Comparator = std::less<>>
39OutputIt lower_bound(LocalPolicy &&policy, InputIt1 start, InputIt1 end,
40 InputIt2 value_first, InputIt2 value_last, OutputIt result,
41 Comparator comp = Comparator()) {
42 dr::__detail::direct_iterator d_start(start);
44
45 dr::__detail::direct_iterator d_value_first(value_first);
46 dr::__detail::direct_iterator d_value_last(value_last);
47
48 dr::__detail::direct_iterator d_result(result);
49
50 return oneapi::dpl::lower_bound(std::forward<LocalPolicy>(policy), d_start,
51 d_end, d_value_first, d_value_last, d_result,
52 comp)
53 .base();
54}
55
56} // namespace __detail
57
58template <dr::distributed_range R, typename Compare = std::less<>>
59void sort(R &&r, Compare comp = Compare()) {
60 auto &&segments = dr::ranges::segments(r);
61
62 if (rng::size(segments) == 0) {
63 return;
64 } else if (rng::size(segments) == 1) {
65 auto &&segment = *rng::begin(segments);
66 auto &&local_policy =
67 dr::sp::__detail::dpl_policy(dr::ranges::rank(segment));
68 auto &&local_segment = dr::sp::__detail::local(segment);
69
70 __detail::sort_async(local_policy, rng::begin(local_segment),
71 rng::end(local_segment), comp)
72 .wait();
73 return;
74 }
75
76 using T = rng::range_value_t<R>;
77 std::vector<sycl::event> events;
78
79 const std::size_t n_segments = std::size_t(rng::size(segments));
80 const std::size_t n_splitters = n_segments - 1;
81
82 // Sort each local segment, then compute medians.
83 // Each segment has `n_splitters` medians,
84 // so `n_segments * n_splitters` medians total.
85
86 T *medians = sycl::malloc_device<T>(n_segments * n_splitters,
87 sp::devices()[0], sp::context());
88
89 for (auto &&[segment_id_, segment] : rng::views::enumerate(segments)) {
90 auto const segment_id = static_cast<std::size_t>(segment_id_);
91 auto &&q = dr::sp::__detail::queue(dr::ranges::rank(segment));
92 auto &&local_policy =
93 dr::sp::__detail::dpl_policy(dr::ranges::rank(segment));
94
95 auto &&local_segment = dr::sp::__detail::local(segment);
96
97 auto s = __detail::sort_async(local_policy, rng::begin(local_segment),
98 rng::end(local_segment), comp);
99
100 double step_size = static_cast<double>(rng::size(segment)) / n_segments;
101
102 auto local_begin = rng::begin(local_segment);
103
104 auto e = q.submit([&](auto &&h) {
105 h.depends_on(s);
106
107 h.parallel_for(n_splitters, [=](auto i) {
108 medians[n_splitters * segment_id + i] =
109 local_begin[std::size_t(step_size * (i + 1) + 0.5)];
110 });
111 });
112
113 events.push_back(e);
114 }
115
116 dr::sp::__detail::wait(events);
117 events.clear();
118
119 // Compute global medians by sorting medians and
120 // computing `n_splitters` medians from the medians.
121 auto &&local_policy = dr::sp::__detail::dpl_policy(0);
122 __detail::sort_async(local_policy, medians,
123 medians + n_segments * n_splitters, comp)
124 .wait();
125
126 double step_size = static_cast<double>(n_segments * n_splitters) / n_segments;
127
128 // - Collect median of medians to get final splitters.
129 // - Write splitters to [0, n_splitters) in `medians`
130
131 auto &&q = dr::sp::__detail::queue(0);
132 q.single_task([=] {
133 for (std::size_t i = 0; i < n_splitters; i++) {
134 medians[i] = medians[std::size_t(step_size * (i + 1) + 0.5)];
135 }
136 }).wait();
137
138 std::vector<std::size_t *> splitter_indices;
139 // sorted_seg_sizes[i]: how many elements exists in all segments between
140 // medians[i-1] and medians[i]
141 std::vector<std::size_t> sorted_seg_sizes(n_segments, 0);
142 // push_positions[snd_idx][rcv_idx]: shift inside final segment of rcv_idx for
143 // data being sent from initial snd_idx segment
144 std::vector<std::vector<std::size_t>> push_positions(n_segments);
145
146 // Compute how many elements will be sent to each of the new "sorted
147 // segments". Simultaneously compute the offsets `push_positions` where each
148 // segments' corresponding elements will be pushed.
149
150 for (auto &&[segment_id, segment] : rng::views::enumerate(segments)) {
151 auto &&q = dr::sp::__detail::queue(dr::ranges::rank(segment));
152 auto &&local_policy =
153 dr::sp::__detail::dpl_policy(dr::ranges::rank(segment));
154
155 auto &&local_segment = dr::sp::__detail::local(segment);
156
157 // slitter_i = [ index in local_segment of first element greater or equal
158 // 1st global median, index ... 2nd global median, ..., size of
159 // local_segment]
160 std::size_t *splitter_i = sycl::malloc_shared<std::size_t>(
161 n_segments, q.get_device(), sp::context());
162 splitter_indices.push_back(splitter_i);
163
164 // Local copy `medians_l` necessary due to [GSD-3893]
165 T *medians_l =
166 sycl::malloc_device<T>(n_splitters, q.get_device(), sp::context());
167
168 q.memcpy(medians_l, medians, sizeof(T) * n_splitters).wait();
169
170 __detail::lower_bound(local_policy, rng::begin(local_segment),
171 rng::end(local_segment), medians_l,
172 medians_l + n_splitters, splitter_i, comp);
173
174 sycl::free(medians_l, sp::context());
175
176 splitter_i[n_splitters] = rng::size(local_segment);
177
178 for (std::size_t i = 0; i < n_segments; i++) {
179 const std::size_t n_elements =
180 splitter_i[i] - (i == 0 ? 0 : splitter_i[i - 1]);
181 const std::size_t pos =
182 std::atomic_ref(sorted_seg_sizes[i]).fetch_add(n_elements);
183 push_positions[static_cast<std::size_t>(segment_id)].push_back(pos);
184 }
185 }
186
187 // Allocate new "sorted segments"
188 std::vector<T *> sorted_segments;
189
190 for (auto &&[segment_id, segment] : rng::views::enumerate(segments)) {
191 auto &&q = dr::sp::__detail::queue(dr::ranges::rank(segment));
192
193 T *buffer = sycl::malloc_device<T>(
194 sorted_seg_sizes[static_cast<std::size_t>(segment_id)], q);
195 sorted_segments.push_back(buffer);
196 }
197
198 // Copy corresponding elements to each "sorted segment"
199 for (auto &&[segment_id_, segment] : rng::views::enumerate(segments)) {
200 auto &&local_segment = dr::sp::__detail::local(segment);
201 const auto segment_id = static_cast<std::size_t>(segment_id_);
202
203 std::size_t *splitter_i = splitter_indices[segment_id];
204
205 auto p_first = rng::begin(local_segment);
206 auto p_last = p_first;
207 for (std::size_t i = 0; i < n_segments; i++) {
208 p_last = rng::begin(local_segment) + splitter_i[i];
209
210 const std::size_t pos = push_positions[segment_id][i];
211
212 auto e = sp::copy_async(p_first, p_last, sorted_segments[i] + pos);
213 events.push_back(e);
214
215 p_first = p_last;
216 }
217 }
218
219 dr::sp::__detail::wait(events);
220 events.clear();
221
222 // merge sorted chunks within each of these new segments
223
224#pragma omp parallel num_threads(n_segments)
225 {
226 int t = omp_get_thread_num();
227
228 std::vector<std::size_t> chunks_ind;
229 for (std::size_t i = 0; i < n_segments; i++) {
230 chunks_ind.push_back(push_positions[i][t]);
231 }
232
233 auto _segments = n_segments;
234 while (_segments > 1) {
235 std::vector<std::size_t> new_chunks;
236 new_chunks.push_back(0);
237
238 for (int s = 0; s < _segments / 2; s++) {
239
240 const std::size_t l = (2 * s + 2 < _segments) ? chunks_ind[2 * s + 2]
241 : sorted_seg_sizes[t];
242
243 auto first = dr::__detail::direct_iterator(sorted_segments[t] +
244 chunks_ind[2 * s]);
245 auto middle = dr::__detail::direct_iterator(sorted_segments[t] +
246 chunks_ind[2 * s + 1]);
247 auto last = dr::__detail::direct_iterator(sorted_segments[t] + l);
248
249 new_chunks.push_back(l);
250
251 oneapi::dpl::inplace_merge(
252 __detail::dpl_policy(dr::ranges::rank(segments[t])), first, middle,
253 last, std::forward<Compare>(comp));
254 }
255
256 _segments = (_segments + 1) / 2;
257
258 std::swap(chunks_ind, new_chunks);
259 }
260 } // End of omp parallel region
261
262 // Copy the results into the output.
263
264 auto d_first = rng::begin(r);
265
266 for (std::size_t i = 0; i < sorted_segments.size(); i++) {
267 T *seg = sorted_segments[i];
268 std::size_t n_elements = sorted_seg_sizes[i];
269
270 auto e = sp::copy_async(seg, seg + n_elements, d_first);
271
272 events.push_back(e);
273
274 rng::advance(d_first, n_elements);
275 }
276
277 dr::sp::__detail::wait(events);
278
279 // Free temporary memory.
280
281 for (auto &&sorted_seg : sorted_segments) {
282 sycl::free(sorted_seg, sp::context());
283 }
284
285 for (auto &&splitter_i : splitter_indices) {
286 sycl::free(splitter_i, sp::context());
287 }
288
289 sycl::free(medians, sp::context());
290}
291
292template <dr::distributed_iterator RandomIt, typename Compare = std::less<>>
293void sort(RandomIt first, RandomIt last, Compare comp = Compare()) {
294 sort(rng::subrange(first, last), comp);
295}
296
297} // namespace dr::sp
Definition: onedpl_direct_iterator.hpp:15