7#ifdef SYCL_LANGUAGE_VERSION
8#include <dr/sp/init.hpp>
9#include <oneapi/dpl/algorithm>
10#include <oneapi/dpl/execution>
11#include <oneapi/dpl/iterator>
19#include <dr/concepts/concepts.hpp>
20#include <dr/detail/logger.hpp>
21#include <dr/detail/onedpl_direct_iterator.hpp>
22#include <dr/detail/ranges_shim.hpp>
23#include <dr/mp/global.hpp>
// Members of `buffer<T>` (fragment): a raw array of `size_` elements obtained
// from `alloc_`. This listing omits the class header and many source lines,
// so the comments below describe only the statements that are visible.
// NOTE(review): buffer owns `data_` and releases it through `alloc_`; its
// copy/move special members are not visible here. If copying is left
// defaulted, copying a buffer would release the same storage twice - confirm
// it is deleted or implemented.
//
// Number of elements storage is held for.
33 std::size_t size() {
return size_; }
// Allocates room for `cnt` elements through `alloc_` and asserts that the
// allocation succeeded.
// NOTE(review): callers build zero-sized buffers (e.g. `vec_left` /
// `vec_right` in dist_sort when the corresponding shift is <= 0). Source
// line 36 is hidden; confirm `cnt == 0` skips allocate(), since an allocator
// may return nullptr for a zero-sized request and trip the assert.
35 buffer(std::size_t cnt) : size_(cnt) {
37 data_ = alloc_.allocate(cnt);
38 assert(data_ !=
nullptr);
// Releases the owned storage. The enclosing function header is not visible;
// this is presumably the destructor.
44 alloc_.deallocate(data_, size_);
// Changes the storage to `cnt` elements and returns a pointer to it.
// The visible reallocation path:
//   1. allocates fresh storage;
//   2. copies the first min(size_, cnt) elements across;
//   3. frees the old block.
// The deallocate at source line 54 belongs to a branch whose condition is
// not visible.
49 T *resize(std::size_t cnt) {
54 alloc_.deallocate(data_, size_);
57 T *newdata = alloc_.allocate(cnt);
58 copy(data_, newdata, std::min(size_, cnt));
59 alloc_.deallocate(data_, size_);
// Takes over `other`'s storage: frees this buffer's own block, adopts
// other's data pointer and size, then nulls `other.data_`.
// NOTE(review): other.size_ is not reset in the visible lines. If the
// destructor deallocates unconditionally, it will be handed
// (nullptr, size_) - confirm that case is guarded or tolerated by the
// allocator.
66 void replace(
buffer &other) {
68 alloc_.deallocate(data_, size_);
70 data_ = rng::data(other);
71 size_ = rng::size(other);
72 other.data_ =
nullptr;
// Raw contiguous access; [begin(), end()) spans size_ elements.
76 T *data() {
return data_; }
77 T *begin() {
return data_; }
78 T *end() {
return data_ + size_; }
// Element count backing `data_`; defaults to 0.
83 std::size_t size_ = 0;
// Sorts the local range `r` in place using `comp`. Ranges with fewer than
// two elements are left untouched.
// With SYCL_LANGUAGE_VERSION defined, the visible lines obtain dpl_policy()
// and the local segment of `r` for a oneDPL sort. The sort call itself
// (source lines 93-99) is not visible in this listing.
// The CPU path sorts [rng::begin(r), rng::end(r)) with rng::sort.
86template <
typename R,
typename Compare>
void local_sort(R &r, Compare &&comp) {
87 if (rng::size(r) >= 2) {
89#ifdef SYCL_LANGUAGE_VERSION
90 auto policy = dpl_policy();
91 auto &&local_segment = dr::ranges::__detail::local(r);
92 DRLOG(
"GPU dpl::sort(), size {}", rng::size(r));
// CPU fallback.
100 DRLOG(
"cpu rng::sort, size {}", rng::size(r));
101 rng::sort(rng::begin(r), rng::end(r), comp);
// Merges the consecutive sorted runs stored in `v` into one sequence sorted
// by `comp`.
// `chunks` lists the run lengths in storage order. It is taken by value
// because it is rewritten in place; exclusive_scan turns the lengths into
// start offsets.
// Each pass of the while-loop:
//   - merges neighbouring runs (2i, 2i+1) in place;
//   - records the start offset of every merged run;
//   - carries an odd trailing run over unchanged.
// The loop stops once a single run is left, after ceil(log2(#runs)) passes.
106template <
typename T,
typename Compare>
107void local_merge(buffer<T> &v, std::vector<std::size_t> chunks,
// NOTE(review): the literal `0` makes exclusive_scan accumulate in `int`,
// because the sum type is the type of `init`. Totals above INT_MAX would
// therefore overflow before being stored back as std::size_t;
// `std::size_t{0}` keeps the scan unsigned.
110 std::exclusive_scan(chunks.begin(), chunks.end(), chunks.begin(), 0);
112 while (chunks.size() > 1) {
113 std::size_t segno = chunks.size();
114 std::vector<std::size_t> next_chunks;
115 for (std::size_t i = 0; i < segno / 2; i++) {
116 auto first = v.begin() + chunks[2 * i];
117 auto middle = v.begin() + chunks[2 * i + 1];
// The pair ends where the next run starts, or at v.end() for the last pair.
118 auto last = (2 * i + 2 < segno) ? v.begin() + chunks[2 * i + 2] : v.end();
// Device path: oneDPL in-place merge on iterators built in source lines
// 121-123 (not visible). Host path (source line 129): std::inplace_merge.
119 if (mp::use_sycl()) {
120#ifdef SYCL_LANGUAGE_VERSION
124 oneapi::dpl::inplace_merge(dpl_policy(), dfirst, dmiddle, dlast, comp);
129 std::inplace_merge(first, middle, last, comp);
// The merged run keeps the start offset of its first half.
131 next_chunks.push_back(chunks[2 * i]);
// An unpaired final run moves on to the next pass as-is.
133 if (segno % 2 == 1) {
134 next_chunks.push_back(chunks[segno - 1]);
136 std::swap(chunks, next_chunks);
// Sample-sort splitter selection for the local segment `lsegment`.
// lsegment is expected to be sorted by `comp` already (dist_sort calls
// local_sort first), because the splitters are located with lower_bound.
// Procedure:
//   1. Each rank samples _comm_size + 1 values: the element at every
//      _step_m-th position, plus the last element.
//   2. All ranks all_gather and sort the samples.
//   3. _comm_size - 1 evenly spaced samples become the global splitters.
// Outputs (both vectors must already hold _comm_size entries; asserted):
//   vec_split_i[k], k >= 1: lower_bound offset of splitter k - 1, i.e. where
//                          the data destined for rank k starts. Index 0 is
//                          not written here; dist_sort passes it as 0.
//   vec_split_s[k]:        number of local elements destined for rank k.
// NOTE(review): sampling always reads lsegment's last element, so an empty
// local segment would index out of bounds. Confirm callers never reach this
// function with an empty segment.
141template <
typename valT,
typename Compare,
typename Seg>
142void splitters(Seg &lsegment, Compare &&comp,
143 std::vector<std::size_t> &vec_split_i,
144 std::vector<std::size_t> &vec_split_s) {
145 const std::size_t _comm_size = default_comm().size();
147 assert(rng::size(vec_split_i) == _comm_size);
148 assert(rng::size(vec_split_s) == _comm_size);
150 std::vector<valT> vec_lmedians(_comm_size + 1);
151 std::vector<valT> vec_gmedians((_comm_size + 1) * _comm_size);
// Sampling stride. i * _step_m is a double used as an index; it is
// truncated when converted to the segment's index type.
153 const double _step_m =
static_cast<double>(rng::size(lsegment)) /
154 static_cast<double>(_comm_size);
// Device path: copy each sample with sycl_queue().memcpy, then wait for all
// copies. Host path (source lines 177-181): plain element assignment.
158 if (mp::use_sycl()) {
159#ifdef SYCL_LANGUAGE_VERSION
160 std::vector<sycl::event> events;
162 for (std::size_t i = 0; i < rng::size(vec_lmedians) - 1; i++) {
163 assert(i * _step_m < rng::size(lsegment));
164 sycl::event ev = sycl_queue().memcpy(
165 &vec_lmedians[i], &lsegment[i * _step_m],
sizeof(valT));
166 events.emplace_back(ev);
// NOTE(review): the line that binds this memcpy's event to `ev` (source
// line 168) is not visible in this listing; the loop's own `ev` is out of
// scope at this point.
169 sycl_queue().memcpy(&vec_lmedians[rng::size(vec_lmedians) - 1],
170 &lsegment[rng::size(lsegment) - 1],
sizeof(valT));
171 events.emplace_back(ev);
172 sycl::event::wait(events);
177 for (std::size_t i = 0; i < rng::size(vec_lmedians) - 1; i++) {
178 assert(i * _step_m < rng::size(lsegment));
179 vec_lmedians[i] = lsegment[i * _step_m];
181 vec_lmedians.back() = lsegment.back();
// Gather every rank's samples (_comm_size * (_comm_size + 1) values) and
// sort them with `comp`.
184 default_comm().all_gather(vec_lmedians, vec_gmedians);
185 rng::sort(rng::begin(vec_gmedians), rng::end(vec_gmedians), comp);
187 std::vector<valT> vec_split_v(_comm_size - 1);
// Splitter i is the last sample of the (i + 1)-th group of _comm_size + 1
// sorted samples.
189 for (std::size_t i = 0; i < _comm_size - 1; i++) {
190 auto global_median_idx = (i + 1) * (_comm_size + 1) - 1;
191 assert(global_median_idx < rng::size(vec_gmedians));
192 vec_split_v[i] = vec_gmedians[global_median_idx];
// Locate each splitter in lsegment; vec_split_i[1..] receive the
// lower_bound offsets. On SYCL this is oneDPL's lower_bound over a range of
// values; on the host it is one std::lower_bound per splitter.
197 if (mp::use_sycl()) {
198#ifdef SYCL_LANGUAGE_VERSION
199 auto &&local_policy = dpl_policy();
200 sycl::queue q = sycl_queue();
205 oneapi::dpl::lower_bound(local_policy, lsb, lse, vec_split_v.begin(),
206 vec_split_v.end(), vec_split_i.begin() + 1, comp);
212 for (std::size_t i = 1; i <= rng::size(vec_split_v); i++) {
213 auto idx = vec_split_v[i - 1];
215 std::lower_bound(lsegment.begin(), lsegment.end(), idx, comp);
216 vec_split_i[i] = rng::distance(lsegment.begin(), lower);
// Slice sizes are the gaps between consecutive start offsets; the final
// slice runs to the end of lsegment.
219 for (std::size_t i = 1; i < vec_split_i.size(); i++) {
220 vec_split_s[i - 1] = vec_split_i[i] - vec_split_i[i - 1];
222 vec_split_s.back() = rng::size(lsegment) - vec_split_i.back();
// Rebalances element counts between neighbouring ranks after the all-to-all
// exchange. Sign convention, as established by the branches below:
//   shift_left  > 0: receive shift_left elements from rank - 1 into vec_left;
//   shift_left  < 0: send the first -shift_left elements of vec_recvdata to
//                    rank - 1;
//   shift_right > 0: receive shift_right elements from rank + 1 into
//                    vec_right;
//   shift_right < 0: send the last -shift_right elements of vec_recvdata to
//                    rank + 1.
// vec_left and vec_right must be pre-sized to max(0, shift) (asserted).
// Special cases:
//   - If this rank holds fewer elements than it must pass left, it first
//     waits for its right neighbour's data, appends that data to
//     vec_recvdata, and only then sends left.
//   - Holding fewer elements than must pass right is reported as an error
//     (the handling, source lines 265-273, is not visible).
// NOTE(review): std::max(0L, shift) mixes `long` with `int64_t`. Where
// int64_t is `long long` (e.g. LLP64 Windows), std::max cannot deduce its
// template argument; `std::max<int64_t>(0, shift)` is portable.
225template <
typename valT>
226void shift_data(
const int64_t shift_left,
const int64_t shift_right,
227 buffer<valT> &vec_recvdata, buffer<valT> &vec_left,
228 buffer<valT> &vec_right) {
229 const std::size_t _comm_rank = default_comm().rank();
231 MPI_Request req_l, req_r;
232 MPI_Status stat_l, stat_r;
234 assert(
static_cast<int64_t
>(rng::size(vec_left)) == std::max(0L, shift_left));
235 assert(
static_cast<int64_t
>(rng::size(vec_right)) ==
236 std::max(0L, shift_right));
238 if (
static_cast<int64_t
>(rng::size(vec_recvdata)) < -shift_left) {
241 DRLOG(
"Get from right first, recvdata size {} shift left {}",
242 rng::size(vec_recvdata), shift_left);
244 assert(shift_right > 0);
// Blocking receive from the right neighbour: its data is needed before
// anything can be sent left.
246 default_comm().irecv(rng::data(vec_right), rng::size(vec_right),
247 _comm_rank + 1, &req_r);
248 MPI_Wait(&req_r, &stat_r);
// Append the data received from the right behind the current contents.
250 std::size_t old_size = rng::size(vec_recvdata);
251 vec_recvdata.resize(rng::size(vec_recvdata) + shift_right);
253 assert(rng::size(vec_right) <= rng::size(vec_recvdata) - old_size);
255 __detail::copy(rng::data(vec_right), rng::data(vec_recvdata) + old_size,
256 rng::size(vec_right));
// Then pass the first -shift_left elements, now available, to the left.
// NOTE(review): vec_right's contents also live in vec_recvdata at this
// point, and copy_results writes vec_right into the output again. Source
// lines 257-259 are hidden; confirm vec_right is emptied there so these
// elements are not duplicated.
260 default_comm().isend(rng::data(vec_recvdata), -shift_left, _comm_rank - 1,
262 MPI_Wait(&req_l, &stat_l);
264 }
else if (
static_cast<int64_t
>(rng::size(vec_recvdata)) < -shift_right) {
269 "Too little data in buffer to shift right - this should never happen");
// Normal case: post non-blocking transfers to/from both neighbours, then
// wait for them.
274 if (shift_left < 0) {
275 default_comm().isend(rng::data(vec_recvdata), -shift_left, _comm_rank - 1,
277 }
else if (shift_left > 0) {
278 assert(shift_left ==
static_cast<int64_t
>(rng::size(vec_left)));
279 default_comm().irecv(rng::data(vec_left), rng::size(vec_left),
280 _comm_rank - 1, &req_l);
282 if (shift_right > 0) {
283 assert(shift_right ==
static_cast<int64_t
>(rng::size(vec_right)));
284 default_comm().irecv(rng::data(vec_right), rng::size(vec_right),
285 _comm_rank + 1, &req_r);
286 }
else if (shift_right < 0) {
287 default_comm().isend(rng::data(vec_recvdata) + rng::size(vec_recvdata) +
289 -shift_right, _comm_rank + 1, &req_r);
// NOTE(review): req_l is only initialised when shift_left != 0. Source
// lines 290-291 are hidden; confirm this wait is guarded the same way as
// the req_r wait below.
292 MPI_Wait(&req_l, &stat_l);
293 if (shift_right != 0)
294 MPI_Wait(&req_r, &stat_r);
// Writes this rank's final data into `lsegment` as
//   [ vec_left | vec_recvdata minus elements sent away | vec_right ].
// invalidate_left and invalidate_right count the elements that were sent to
// the left and right neighbour (negative shifts). They are skipped at the
// front and back of vec_recvdata, respectively.
// The SYCL path issues three queue copies (e_l, e_d, e_r); the wait on them
// is on source lines not visible here. The host path uses std::copy.
// NOTE(review): the asserts only require the pieces to fit (<=); elements of
// lsegment past size_l + size_d + size_r are not written. Presumably the
// pieces sum exactly to the local segment size - confirm.
// NOTE(review): std::max(-shift, 0L) mixes `int64_t` with `long` and fails to
// compile where int64_t is `long long`; `std::max<int64_t>(-shift, 0)` is
// portable.
298template <
typename valT>
299void copy_results(
auto &lsegment,
const int64_t shift_left,
300 const int64_t shift_right, buffer<valT> &vec_recvdata,
301 buffer<valT> &vec_left, buffer<valT> &vec_right) {
302 const std::size_t invalidate_left = std::max(-shift_left, 0L);
303 const std::size_t invalidate_right = std::max(-shift_right, 0L);
305 const std::size_t size_l = rng::size(vec_left);
306 const std::size_t size_r = rng::size(vec_right);
// Elements of vec_recvdata that stay on this rank.
307 const std::size_t size_d =
308 rng::size(vec_recvdata) - (invalidate_left + invalidate_right);
310 if (mp::use_sycl()) {
311#ifdef SYCL_LANGUAGE_VERSION
312 sycl::event e_l, e_d, e_r;
315 assert(size_l <= rng::size(lsegment));
316 e_l = sycl_queue().copy(rng::data(vec_left), rng::data(lsegment), size_l);
319 assert(size_l + size_d + size_r <= rng::size(lsegment));
320 e_r = sycl_queue().copy(rng::data(vec_right),
321 rng::data(lsegment) + size_l + size_d, size_r);
324 assert(size_l + size_d <= rng::size(lsegment));
325 assert(invalidate_left + size_d <= rng::size(vec_recvdata));
326 e_d = sycl_queue().copy(rng::data(vec_recvdata) + invalidate_left,
327 rng::data(lsegment) + size_l, size_d);
// Host path: same layout with std::copy.
341 assert(size_l <= rng::size(lsegment));
342 std::copy(rng::begin(vec_left), rng::end(vec_left), rng::begin(lsegment));
345 assert(size_l + size_d + size_r <= rng::size(lsegment));
346 std::copy(rng::begin(vec_right), rng::end(vec_right),
347 rng::begin(lsegment) + size_l + size_d);
350 assert(size_l + size_d <= rng::size(lsegment));
351 assert(invalidate_left + size_d <= rng::size(vec_recvdata));
352 std::copy(rng::begin(vec_recvdata) + invalidate_left,
353 rng::begin(vec_recvdata) + invalidate_left + size_d,
354 rng::begin(lsegment) + size_l);
// Distributed sample sort of `r`, ordered by `comp`. Phases:
//  1. Sort the local segment.
//  2. Choose global splitters plus the per-destination slice offsets and
//     sizes (splitters).
//  3. Exchange slice sizes (alltoall), derive receive offsets with an
//     exclusive scan, then exchange the slices themselves (alltoallv).
//  4. Merge the received sorted runs (one per source rank) locally.
//  5. Rebalance so every rank except possibly the last holds
//     ceil(total / comm_size) elements. vec_shift[i] is the cumulative
//     deficit (desired minus received) of ranks 0..i. A positive value means
//     that many elements move from rank i + 1 to rank i; a negative value
//     means they move from rank i to rank i + 1.
//  6. Write the result back into the local segment.
// Requires _comm_size >= 2, because vec_shift[0] is written unconditionally;
// sort() only calls this function when there is more than one rank.
359template <dr::distributed_range R,
typename Compare>
360void dist_sort(R &r, Compare &&comp) {
361 using valT =
typename R::value_type;
363 const std::size_t _comm_rank = default_comm().rank();
364 const std::size_t _comm_size = default_comm().size();
366 auto &&lsegment = local_segment(r);
368 std::vector<std::size_t> vec_split_i(_comm_size, 0);
369 std::vector<std::size_t> vec_split_s(_comm_size, 0);
370 std::vector<std::size_t> vec_rsizes(_comm_size, 0);
371 std::vector<std::size_t> vec_rindices(_comm_size, 0);
372 std::vector<std::size_t> vec_recv_elems(_comm_size, 0);
373 std::size_t _total_elems = 0;
375 DRLOG(
"Rank {}: Dist sort, local segment size {}", default_comm().rank(),
376 rng::size(lsegment));
377 __detail::local_sort(lsegment, comp);
380 __detail::splitters<valT>(lsegment, comp, vec_split_i, vec_split_s);
381 default_comm().alltoall(vec_split_s, vec_rsizes, 1);
// NOTE(review): init `0` is an int, so this scan accumulates in `int`;
// `std::size_t{0}` avoids overflow on very large receive counts.
384 std::exclusive_scan(vec_rsizes.begin(), vec_rsizes.end(),
385 vec_rindices.begin(), 0);
386 const std::size_t _recv_elems = vec_rindices.back() + vec_rsizes.back();
392 default_comm().all_gather(_recv_elems, vec_recv_elems);
395 buffer<valT> vec_recvdata(_recv_elems);
399 default_comm().alltoallv(lsegment, vec_split_s, vec_split_i, vec_recvdata,
400 vec_rsizes, vec_rindices);
402 __detail::local_merge(vec_recvdata, vec_rsizes, comp);
406 _total_elems = std::reduce(vec_recv_elems.begin(), vec_recv_elems.end());
409 std::vector<int64_t> vec_shift(_comm_size - 1);
411 const auto desired_elems_num = (_total_elems + _comm_size - 1) / _comm_size;
// NOTE(review): desired_elems_num and vec_recv_elems are std::size_t, so a
// negative difference wraps around and is then converted to int64_t. The
// result is the intended signed value only through modular conversion,
// which is guaranteed since C++20 and implementation-defined before that.
// Casting the operands to int64_t first would make the intent explicit.
413 vec_shift[0] = desired_elems_num - vec_recv_elems[0];
414 for (std::size_t i = 1; i < _comm_size - 1; i++) {
415 vec_shift[i] = vec_shift[i - 1] + desired_elems_num - vec_recv_elems[i];
// Rank 0 has no left neighbour and the last rank has no right neighbour.
418 const int64_t shift_left = _comm_rank == 0 ? 0 : -vec_shift[_comm_rank - 1];
419 const int64_t shift_right =
420 _comm_rank == _comm_size - 1 ? 0 : vec_shift[_comm_rank];
// Receive buffers are needed only for positive shifts.
// NOTE(review): std::max(int64_t, 0L) does not compile where int64_t is
// `long long`; `std::max<int64_t>(shift, 0)` is portable.
422 buffer<valT> vec_left(std::max(shift_left, 0L));
423 buffer<valT> vec_right(std::max(shift_right, 0L));
427 __detail::shift_data<valT>(shift_left, shift_right, vec_recvdata, vec_left,
431 __detail::copy_results<valT>(lsegment, shift_left, shift_right, vec_recvdata,
432 vec_left, vec_right);
// Sorts the distributed range `r` in place, ordered by `comp`
// (std::less<> by default). Strategy:
//  * One rank: sort its local segment directly.
//  * At most (comm_size - 1)^2 elements: copy the whole range into a vector
//    (dr::mp::copy with root 0), sort it on rank 0, then copy it back.
//  * Otherwise: distributed sample sort (dist_sort).
// One consequence of the size threshold: for size > (comm_size - 1)^2,
// ceil(size / comm_size) * (comm_size - 1) <= size. The last rank's target
// count computed in dist_sort is therefore never negative.
// NOTE(review): vec_recvdata is sized to the full range on every rank, even
// though only rank 0 sorts it. Confirm dr::mp::copy needs the buffer on
// non-root ranks.
437template <dr::distributed_range R,
typename Compare = std::less<>>
438void sort(R &r, Compare &&comp = Compare()) {
440 using valT =
typename R::value_type;
442 std::size_t _comm_rank = default_comm().rank();
443 std::size_t _comm_size = default_comm().size();
445 if (_comm_size == 1) {
446 DRLOG(
"mp::sort() - one node only");
447 auto &&lsegment = local_segment(r);
448 __detail::local_sort(lsegment, comp);
450 }
else if (rng::size(r) <= (_comm_size - 1) * (_comm_size - 1)) {
454 DRLOG(
"mp::sort() - local sort on node 0");
456 std::vector<valT> vec_recvdata(rng::size(r));
457 dr::mp::copy(0, r, rng::begin(vec_recvdata));
459 if (_comm_rank == 0) {
460 rng::sort(vec_recvdata, comp);
463 dr::mp::copy(0, vec_recvdata, rng::begin(r));
466 DRLOG(
"mp::sort() - distributed sort");
467 __detail::dist_sort(r, comp);
// Iterator-pair overload: sorts [first, last) of a distributed range by
// forwarding to the range overload of sort().
// NOTE(review): `rng::subrange(first, last)` is a prvalue, but the range
// overload visible in this file takes `R &r` (a non-const lvalue reference),
// which cannot bind to it. Unless another overload exists outside this
// listing, instantiating this template will not compile. Binding the
// subrange to a named local first (`auto r = rng::subrange(first, last);
// sort(r, comp);`) would avoid that.
472template <dr::distributed_iterator RandomIt,
typename Compare = std::less<>>
473void sort(RandomIt first, RandomIt last, Compare comp = Compare()) {
474 sort(rng::subrange(first, last), comp);
Definition: onedpl_direct_iterator.hpp:15
Definition: allocator.hpp:11