Distributed Ranges
Loading...
Searching...
No Matches
sort.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#ifdef SYCL_LANGUAGE_VERSION
8#include <dr/sp/init.hpp>
9#include <oneapi/dpl/algorithm>
10#include <oneapi/dpl/execution>
11#include <oneapi/dpl/iterator>
12#endif
13
14#include <mpi.h>
15
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>
18
19#include <dr/concepts/concepts.hpp>
20#include <dr/detail/logger.hpp>
21#include <dr/detail/onedpl_direct_iterator.hpp>
22#include <dr/detail/ranges_shim.hpp>
23#include <dr/mp/global.hpp>
24
25namespace dr::mp {
26
27namespace __detail {
28
29template <typename T> class buffer {
30
31public:
32 using value_type = T;
33 std::size_t size() { return size_; }
34
35 buffer(std::size_t cnt) : size_(cnt) {
36 if (cnt > 0) {
37 data_ = alloc_.allocate(cnt);
38 assert(data_ != nullptr);
39 }
40 }
41
42 ~buffer() {
43 if (data_ != nullptr)
44 alloc_.deallocate(data_, size_);
45 data_ = nullptr;
46 size_ = 0;
47 }
48
49 T *resize(std::size_t cnt) {
50 if (cnt == size_)
51 return data_;
52
53 if (cnt == 0) {
54 alloc_.deallocate(data_, size_);
55 data_ = nullptr;
56 } else {
57 T *newdata = alloc_.allocate(cnt);
58 copy(data_, newdata, std::min(size_, cnt));
59 alloc_.deallocate(data_, size_);
60 data_ = newdata;
61 }
62 size_ = cnt;
63 return data_;
64 }
65
66 void replace(buffer &other) {
67 if (data_ != nullptr)
68 alloc_.deallocate(data_, size_);
69
70 data_ = rng::data(other);
71 size_ = rng::size(other);
72 other.data_ = nullptr;
73 other.size_ = 0;
74 }
75
76 T *data() { return data_; }
77 T *begin() { return data_; }
78 T *end() { return data_ + size_; }
79
80private:
81 allocator<T> alloc_;
82 T *data_ = nullptr;
83 std::size_t size_ = 0;
84}; // class buffer
85
86template <typename R, typename Compare> void local_sort(R &r, Compare &&comp) {
87 if (rng::size(r) >= 2) {
88 if (mp::use_sycl()) {
89#ifdef SYCL_LANGUAGE_VERSION
90 auto policy = dpl_policy();
91 auto &&local_segment = dr::ranges::__detail::local(r);
92 DRLOG("GPU dpl::sort(), size {}", rng::size(r));
93 oneapi::dpl::sort(
94 policy, dr::__detail::direct_iterator(rng::begin(local_segment)),
95 dr::__detail::direct_iterator(rng::end(local_segment)), comp);
96#else
97 assert(false);
98#endif
99 } else {
100 DRLOG("cpu rng::sort, size {}", rng::size(r));
101 rng::sort(rng::begin(r), rng::end(r), comp);
102 }
103 }
104}
105
// Merge `chunks.size()` consecutive sorted runs stored back-to-back in `v`
// into one sorted sequence via pairwise in-place merges (log(#runs) passes).
// `chunks` holds the run lengths on entry; it is taken by value because it
// is repurposed in place as a table of run start offsets.
template <typename T, typename Compare>
void local_merge(buffer<T> &v, std::vector<std::size_t> chunks,
                 Compare &&comp) {

  // Turn run lengths into starting offsets of each run within v.
  std::exclusive_scan(chunks.begin(), chunks.end(), chunks.begin(), 0);

  while (chunks.size() > 1) {
    std::size_t segno = chunks.size();
    std::vector<std::size_t> next_chunks;
    for (std::size_t i = 0; i < segno / 2; i++) {
      // Merge runs 2i and 2i+1; the pair ends where run 2i+2 starts, or at
      // the end of the buffer for the final pair.
      auto first = v.begin() + chunks[2 * i];
      auto middle = v.begin() + chunks[2 * i + 1];
      auto last = (2 * i + 2 < segno) ? v.begin() + chunks[2 * i + 2] : v.end();
      if (mp::use_sycl()) {
#ifdef SYCL_LANGUAGE_VERSION
        auto dfirst = dr::__detail::direct_iterator(first);
        auto dmiddle = dr::__detail::direct_iterator(middle);
        auto dlast = dr::__detail::direct_iterator(last);
        oneapi::dpl::inplace_merge(dpl_policy(), dfirst, dmiddle, dlast, comp);
#else
        assert(false); // SYCL requested but not compiled in
#endif
      } else {
        std::inplace_merge(first, middle, last, comp);
      }
      // The merged pair becomes a single run starting at run 2i's offset.
      next_chunks.push_back(chunks[2 * i]);
    }
    if (segno % 2 == 1) {
      // Odd run count: the trailing run is carried unmerged to the next pass.
      next_chunks.push_back(chunks[segno - 1]);
    }
    std::swap(chunks, next_chunks);
  }
}
139
/* elements of dist_sort */
// Compute, for the sorted local segment, where it must be split so each
// piece can be sent to its destination rank. Every rank samples
// comm_size + 1 roughly evenly spaced values from its segment; the gathered
// samples approximate the global distribution, and comm_size - 1 global
// splitter values are picked from them. Each splitter is then mapped back
// to an index in the local segment with lower_bound.
// On exit vec_split_i[k] is the start index of rank k's piece and
// vec_split_s[k] its length; both vectors must be pre-sized to comm_size
// (vec_split_i[0] stays 0).
template <typename valT, typename Compare, typename Seg>
void splitters(Seg &lsegment, Compare &&comp,
               std::vector<std::size_t> &vec_split_i,
               std::vector<std::size_t> &vec_split_s) {
  const std::size_t _comm_size = default_comm().size(); // dr-style ignore

  assert(rng::size(vec_split_i) == _comm_size);
  assert(rng::size(vec_split_s) == _comm_size);

  std::vector<valT> vec_lmedians(_comm_size + 1);
  std::vector<valT> vec_gmedians((_comm_size + 1) * _comm_size);

  // Fractional sampling stride, so samples stay evenly spread even when the
  // segment size is not a multiple of comm_size.
  const double _step_m = static_cast<double>(rng::size(lsegment)) /
                         static_cast<double>(_comm_size);

  /* calculate splitting values and indices - find n-1 dividers splitting
   * each segment into equal parts */
  if (mp::use_sycl()) {
#ifdef SYCL_LANGUAGE_VERSION
    std::vector<sycl::event> events;

    // Segment may live in device memory: copy each sample to the host
    // individually and wait for all copies at once.
    for (std::size_t i = 0; i < rng::size(vec_lmedians) - 1; i++) {
      assert(i * _step_m < rng::size(lsegment));
      sycl::event ev = sycl_queue().memcpy(
          &vec_lmedians[i], &lsegment[i * _step_m], sizeof(valT));
      events.emplace_back(ev);
    }
    // The last sample is always the segment's final element.
    sycl::event ev =
        sycl_queue().memcpy(&vec_lmedians[rng::size(vec_lmedians) - 1],
                            &lsegment[rng::size(lsegment) - 1], sizeof(valT));
    events.emplace_back(ev);
    sycl::event::wait(events);
#else
    assert(false);
#endif
  } else {
    for (std::size_t i = 0; i < rng::size(vec_lmedians) - 1; i++) {
      assert(i * _step_m < rng::size(lsegment));
      // NOTE: double index is truncated on conversion to the subscript type.
      vec_lmedians[i] = lsegment[i * _step_m];
    }
    vec_lmedians.back() = lsegment.back();
  }

  // Gather every rank's samples and sort them to approximate the global
  // value distribution.
  default_comm().all_gather(vec_lmedians, vec_gmedians);
  rng::sort(rng::begin(vec_gmedians), rng::end(vec_gmedians), comp);

  std::vector<valT> vec_split_v(_comm_size - 1);

  // Pick comm_size - 1 evenly spaced splitter values from the sorted
  // global sample set.
  for (std::size_t i = 0; i < _comm_size - 1; i++) {
    auto global_median_idx = (i + 1) * (_comm_size + 1) - 1;
    assert(global_median_idx < rng::size(vec_gmedians));
    vec_split_v[i] = vec_gmedians[global_median_idx];
  }

  /* The while loop is executed in host memory, and together with
   * sycl_copy takes most of the execution time of the sort procedure */
  if (mp::use_sycl()) {
#ifdef SYCL_LANGUAGE_VERSION
    auto &&local_policy = dpl_policy();
    sycl::queue q = sycl_queue();

    auto lsb = dr::__detail::direct_iterator(rng::begin(lsegment));
    auto lse = dr::__detail::direct_iterator(rng::end(lsegment));

    // Vectorized lower_bound: writes one index per splitter value into
    // vec_split_i[1 .. comm_size-1]; vec_split_i[0] remains 0.
    oneapi::dpl::lower_bound(local_policy, lsb, lse, vec_split_v.begin(),
                             vec_split_v.end(), vec_split_i.begin() + 1, comp);

#else
    assert(false);
#endif
  } else {
    for (std::size_t i = 1; i <= rng::size(vec_split_v); i++) {
      auto idx = vec_split_v[i - 1];
      auto lower =
          std::lower_bound(lsegment.begin(), lsegment.end(), idx, comp);
      vec_split_i[i] = rng::distance(lsegment.begin(), lower);
    }
  }
  // Convert split indices into piece lengths; the last piece extends to the
  // end of the segment.
  for (std::size_t i = 1; i < vec_split_i.size(); i++) {
    vec_split_s[i - 1] = vec_split_i[i] - vec_split_i[i - 1];
  }
  vec_split_s.back() = rng::size(lsegment) - vec_split_i.back();
}
224
225template <typename valT>
226void shift_data(const int64_t shift_left, const int64_t shift_right,
227 buffer<valT> &vec_recvdata, buffer<valT> &vec_left,
228 buffer<valT> &vec_right) {
229 const std::size_t _comm_rank = default_comm().rank();
230
231 MPI_Request req_l, req_r;
232 MPI_Status stat_l, stat_r;
233
234 assert(static_cast<int64_t>(rng::size(vec_left)) == std::max(0L, shift_left));
235 assert(static_cast<int64_t>(rng::size(vec_right)) ==
236 std::max(0L, shift_right));
237
238 if (static_cast<int64_t>(rng::size(vec_recvdata)) < -shift_left) {
239 // Too little data in recv buffer to shift left - first get from right,
240 // then send left
241 DRLOG("Get from right first, recvdata size {} shift left {}",
242 rng::size(vec_recvdata), shift_left);
243
244 assert(shift_right > 0);
245
246 default_comm().irecv(rng::data(vec_right), rng::size(vec_right),
247 _comm_rank + 1, &req_r);
248 MPI_Wait(&req_r, &stat_r);
249
250 std::size_t old_size = rng::size(vec_recvdata);
251 vec_recvdata.resize(rng::size(vec_recvdata) + shift_right);
252
253 assert(rng::size(vec_right) <= rng::size(vec_recvdata) - old_size);
254
255 __detail::copy(rng::data(vec_right), rng::data(vec_recvdata) + old_size,
256 rng::size(vec_right));
257
258 vec_right.resize(0);
259
260 default_comm().isend(rng::data(vec_recvdata), -shift_left, _comm_rank - 1,
261 &req_l);
262 MPI_Wait(&req_l, &stat_l);
263
264 } else if (static_cast<int64_t>(rng::size(vec_recvdata)) < -shift_right) {
265 // Too little data in buffer to shift right - first get from left, then
266 // send right
267 // ** This will never happen, because values eq to split go right
268 DRLOG(
269 "Too little data in buffer to shift right - this should never happen");
270 assert(false);
271
272 } else {
273 // enough data in recv buffer
274 if (shift_left < 0) {
275 default_comm().isend(rng::data(vec_recvdata), -shift_left, _comm_rank - 1,
276 &req_l);
277 } else if (shift_left > 0) {
278 assert(shift_left == static_cast<int64_t>(rng::size(vec_left)));
279 default_comm().irecv(rng::data(vec_left), rng::size(vec_left),
280 _comm_rank - 1, &req_l);
281 }
282 if (shift_right > 0) {
283 assert(shift_right == static_cast<int64_t>(rng::size(vec_right)));
284 default_comm().irecv(rng::data(vec_right), rng::size(vec_right),
285 _comm_rank + 1, &req_r);
286 } else if (shift_right < 0) {
287 default_comm().isend(rng::data(vec_recvdata) + rng::size(vec_recvdata) +
288 shift_right,
289 -shift_right, _comm_rank + 1, &req_r);
290 }
291 if (shift_left != 0)
292 MPI_Wait(&req_l, &stat_l);
293 if (shift_right != 0)
294 MPI_Wait(&req_r, &stat_r);
295 }
296}
297
298template <typename valT>
299void copy_results(auto &lsegment, const int64_t shift_left,
300 const int64_t shift_right, buffer<valT> &vec_recvdata,
301 buffer<valT> &vec_left, buffer<valT> &vec_right) {
302 const std::size_t invalidate_left = std::max(-shift_left, 0L);
303 const std::size_t invalidate_right = std::max(-shift_right, 0L);
304
305 const std::size_t size_l = rng::size(vec_left);
306 const std::size_t size_r = rng::size(vec_right);
307 const std::size_t size_d =
308 rng::size(vec_recvdata) - (invalidate_left + invalidate_right);
309
310 if (mp::use_sycl()) {
311#ifdef SYCL_LANGUAGE_VERSION
312 sycl::event e_l, e_d, e_r;
313
314 if (size_l > 0) {
315 assert(size_l <= rng::size(lsegment));
316 e_l = sycl_queue().copy(rng::data(vec_left), rng::data(lsegment), size_l);
317 }
318 if (size_r > 0) {
319 assert(size_l + size_d + size_r <= rng::size(lsegment));
320 e_r = sycl_queue().copy(rng::data(vec_right),
321 rng::data(lsegment) + size_l + size_d, size_r);
322 }
323 if (size_d > 0) {
324 assert(size_l + size_d <= rng::size(lsegment));
325 assert(invalidate_left + size_d <= rng::size(vec_recvdata));
326 e_d = sycl_queue().copy(rng::data(vec_recvdata) + invalidate_left,
327 rng::data(lsegment) + size_l, size_d);
328 }
329 if (size_l > 0)
330 e_l.wait();
331 if (size_r > 0)
332 e_r.wait();
333 if (size_d > 0)
334 e_d.wait();
335
336#else
337 assert(false);
338#endif
339 } else {
340 if (size_l > 0) {
341 assert(size_l <= rng::size(lsegment));
342 std::copy(rng::begin(vec_left), rng::end(vec_left), rng::begin(lsegment));
343 }
344 if (size_r > 0) {
345 assert(size_l + size_d + size_r <= rng::size(lsegment));
346 std::copy(rng::begin(vec_right), rng::end(vec_right),
347 rng::begin(lsegment) + size_l + size_d);
348 }
349 if (size_d > 0) {
350 assert(size_l + size_d <= rng::size(lsegment));
351 assert(invalidate_left + size_d <= rng::size(vec_recvdata));
352 std::copy(rng::begin(vec_recvdata) + invalidate_left,
353 rng::begin(vec_recvdata) + invalidate_left + size_d,
354 rng::begin(lsegment) + size_l);
355 }
356 }
357}
358
359template <dr::distributed_range R, typename Compare>
360void dist_sort(R &r, Compare &&comp) {
361 using valT = typename R::value_type;
362
363 const std::size_t _comm_rank = default_comm().rank();
364 const std::size_t _comm_size = default_comm().size(); // dr-style ignore
365
366 auto &&lsegment = local_segment(r);
367
368 std::vector<std::size_t> vec_split_i(_comm_size, 0);
369 std::vector<std::size_t> vec_split_s(_comm_size, 0);
370 std::vector<std::size_t> vec_rsizes(_comm_size, 0);
371 std::vector<std::size_t> vec_rindices(_comm_size, 0);
372 std::vector<std::size_t> vec_recv_elems(_comm_size, 0);
373 std::size_t _total_elems = 0;
374
375 DRLOG("Rank {}: Dist sort, local segment size {}", default_comm().rank(),
376 rng::size(lsegment));
377 __detail::local_sort(lsegment, comp);
378
379 /* find splitting values - limits of areas to send to other processes */
380 __detail::splitters<valT>(lsegment, comp, vec_split_i, vec_split_s);
381 default_comm().alltoall(vec_split_s, vec_rsizes, 1);
382
383 /* prepare data to send and receive */
384 std::exclusive_scan(vec_rsizes.begin(), vec_rsizes.end(),
385 vec_rindices.begin(), 0);
386 const std::size_t _recv_elems = vec_rindices.back() + vec_rsizes.back();
387
388 /* send and receive data belonging to each node, then redistribute
389 * data to achieve size of data equal to size of local segment */
390 /* async i_all_gather causes problems on some systems */
391 // MPI_Request req_recvelems;
392 default_comm().all_gather(_recv_elems, vec_recv_elems);
393
394 /* buffer for received data */
395 buffer<valT> vec_recvdata(_recv_elems);
396
397 /* send data not belonging and receive data belonging to local processes
398 */
399 default_comm().alltoallv(lsegment, vec_split_s, vec_split_i, vec_recvdata,
400 vec_rsizes, vec_rindices);
401
402 __detail::local_merge(vec_recvdata, vec_rsizes, comp);
403
404 // MPI_Wait(&req_recvelems, MPI_STATUS_IGNORE);
405
406 _total_elems = std::reduce(vec_recv_elems.begin(), vec_recv_elems.end());
407
408 /* prepare data for shift to neighboring processes */
409 std::vector<int64_t> vec_shift(_comm_size - 1);
410
411 const auto desired_elems_num = (_total_elems + _comm_size - 1) / _comm_size;
412
413 vec_shift[0] = desired_elems_num - vec_recv_elems[0];
414 for (std::size_t i = 1; i < _comm_size - 1; i++) {
415 vec_shift[i] = vec_shift[i - 1] + desired_elems_num - vec_recv_elems[i];
416 }
417
418 const int64_t shift_left = _comm_rank == 0 ? 0 : -vec_shift[_comm_rank - 1];
419 const int64_t shift_right =
420 _comm_rank == _comm_size - 1 ? 0 : vec_shift[_comm_rank];
421
422 buffer<valT> vec_left(std::max(shift_left, 0L));
423 buffer<valT> vec_right(std::max(shift_right, 0L));
424
425 /* shift data if necessary, to have exactly the number of elements equal to
426 * lsegment size */
427 __detail::shift_data<valT>(shift_left, shift_right, vec_recvdata, vec_left,
428 vec_right);
429
430 /* copy results to distributed vector's local segment */
431 __detail::copy_results<valT>(lsegment, shift_left, shift_right, vec_recvdata,
432 vec_left, vec_right);
433} // __detail::dist_sort
434
435} // namespace __detail
436
437template <dr::distributed_range R, typename Compare = std::less<>>
438void sort(R &r, Compare &&comp = Compare()) {
439
440 using valT = typename R::value_type;
441
442 std::size_t _comm_rank = default_comm().rank();
443 std::size_t _comm_size = default_comm().size(); // dr-style ignore
444
445 if (_comm_size == 1) {
446 DRLOG("mp::sort() - one node only");
447 auto &&lsegment = local_segment(r);
448 __detail::local_sort(lsegment, comp);
449
450 } else if (rng::size(r) <= (_comm_size - 1) * (_comm_size - 1)) {
451 /* Distributed vector of size <= (comm_size-1) * (comm_size-1) may have
452 * 0-size local segments. It is also small enough to prefer sequential sort
453 */
454 DRLOG("mp::sort() - local sort on node 0");
455
456 std::vector<valT> vec_recvdata(rng::size(r));
457 dr::mp::copy(0, r, rng::begin(vec_recvdata));
458
459 if (_comm_rank == 0) {
460 rng::sort(vec_recvdata, comp);
461 }
462 dr::mp::barrier();
463 dr::mp::copy(0, vec_recvdata, rng::begin(r));
464
465 } else {
466 DRLOG("mp::sort() - distributed sort");
467 __detail::dist_sort(r, comp);
468 dr::mp::barrier();
469 }
470}
471
472template <dr::distributed_iterator RandomIt, typename Compare = std::less<>>
473void sort(RandomIt first, RandomIt last, Compare comp = Compare()) {
474 sort(rng::subrange(first, last), comp);
475}
476
477} // namespace dr::mp
Definition: onedpl_direct_iterator.hpp:15
Definition: allocator.hpp:11
Definition: sort.hpp:29