7#include <oneapi/dpl/execution>
9#include <oneapi/dpl/algorithm>
10#include <oneapi/dpl/async>
12#include <dr/concepts/concepts.hpp>
13#include <dr/detail/onedpl_direct_iterator.hpp>
14#include <dr/sp/init.hpp>
17#include <sycl/sycl.hpp>
// Starts an asynchronous oneDPL sort (oneapi::dpl::experimental::sort_async)
// with the forwarded execution policy and comparator, and returns what that
// call returns.
// NOTE(review): this listing carries fused line-number prefixes and omits
// several source lines (the end of the signature, the declarations of
// d_first/d_last, and whatever is returned for short ranges) — confirm
// against the original header before relying on these comments.
23template <
typename LocalPolicy,
typename InputIt,
typename Compare>
24sycl::event sort_async(LocalPolicy &&policy, InputIt first, InputIt last,
// The sort is only launched when the range holds at least two elements.
26 if (rng::distance(first, last) >= 2) {
// d_first / d_last are declared on lines not shown here; presumably they
// wrap `first` / `last` for oneDPL — TODO confirm.
29 return oneapi::dpl::experimental::sort_async(
30 std::forward<LocalPolicy>(policy), d_first, d_last,
31 std::forward<Compare>(comp));
// Thin wrapper that forwards to oneapi::dpl::lower_bound and returns its
// result. `comp` defaults to a default-constructed Comparator (std::less<>
// unless the caller specifies otherwise).
// Presumably performs, for every value in [value_first, value_last), a
// lower-bound search in [start, end) and writes the resulting positions to
// `result` — verify against the oneDPL documentation.
// NOTE(review): the declarations of d_start, d_end, d_value_first,
// d_value_last and d_result, and the end of the call, are on lines missing
// from this listing.
37template <
typename LocalPolicy,
typename InputIt1,
typename InputIt2,
38 typename OutputIt,
typename Comparator = std::less<>>
39OutputIt lower_bound(LocalPolicy &&policy, InputIt1 start, InputIt1 end,
40 InputIt2 value_first, InputIt2 value_last, OutputIt result,
41 Comparator comp = Comparator()) {
// The policy is perfectly forwarded; the d_* arguments come from lines not
// visible here.
50 return oneapi::dpl::lower_bound(std::forward<LocalPolicy>(policy), d_start,
51 d_end, d_value_first, d_value_last, d_result,
// Sorts the distributed range `r` with comparator `comp` (default
// std::less<>).
//
// Visible outline of the multi-segment path:
//   1. sort every segment locally and sample n_splitters values from each
//      into the `medians` buffer;
//   2. sort the gathered samples and pick n_splitters of them as splitters;
//   3. per segment, locate the splitters with lower_bound and derive the
//      chunk sizes / destination offsets;
//   4. copy every chunk into a per-destination buffer;
//   5. merge the chunks of each buffer in an OpenMP parallel region;
//   6. copy the buffers back into `r` and free the temporaries.
//
// NOTE(review): this listing carries fused line-number prefixes and omits
// many source lines (closing braces, some declarations, some statement
// tails); comments below describe only what is visible.
58template <dr::distributed_range R,
typename Compare = std::less<>>
59void sort(R &&r, Compare comp = Compare()) {
60 auto &&segments = dr::ranges::segments(r);
// Case: no segments. The body of this branch is not visible in this listing.
62 if (rng::size(segments) == 0) {
64 }
// Case: exactly one segment — sort it locally with the policy of its rank.
else if (rng::size(segments) == 1) {
65 auto &&segment = *rng::begin(segments);
// (The left-hand side of this statement, presumably `local_policy`, is on a
// line missing from the listing.)
67 dr::sp::__detail::dpl_policy(dr::ranges::rank(segment));
68 auto &&local_segment = dr::sp::__detail::local(segment);
70 __detail::sort_async(local_policy, rng::begin(local_segment),
71 rng::end(local_segment), comp)
// General case (more than one segment) starts here.
76 using T = rng::range_value_t<R>;
77 std::vector<sycl::event> events;
79 const std::size_t n_segments = std::size_t(rng::size(segments));
// One splitter fewer than there are segments.
80 const std::size_t n_splitters = n_segments - 1;
// Device buffer holding n_splitters samples for each of the n_segments
// segments, allocated on sp::devices()[0].
86 T *medians = sycl::malloc_device<T>(n_segments * n_splitters,
87 sp::devices()[0], sp::context());
// Step 1: sort each segment locally and sample it.
89 for (
auto &&[segment_id_, segment] : rng::views::enumerate(segments)) {
90 auto const segment_id =
static_cast<std::size_t
>(segment_id_);
91 auto &&q = dr::sp::__detail::queue(dr::ranges::rank(segment));
93 dr::sp::__detail::dpl_policy(dr::ranges::rank(segment));
95 auto &&local_segment = dr::sp::__detail::local(segment);
97 auto s = __detail::sort_async(local_policy, rng::begin(local_segment),
98 rng::end(local_segment), comp);
// Spacing between samples: segment size divided by the number of segments.
100 double step_size =
static_cast<double>(rng::size(segment)) / n_segments;
102 auto local_begin = rng::begin(local_segment);
// Kernel writing this segment's n_splitters samples into row `segment_id`
// of `medians`. Presumably ordered after the sort event `s` on a line not
// shown — TODO confirm.
104 auto e = q.submit([&](
auto &&h) {
107 h.parallel_for(n_splitters, [=](
auto i) {
// Sample i is taken at the rounded position step_size * (i + 1).
108 medians[n_splitters * segment_id + i] =
109 local_begin[std::size_t(step_size * (i + 1) + 0.5)];
116 dr::sp::__detail::wait(events);
// Step 2: sort all gathered samples using the policy for device 0.
121 auto &&local_policy = dr::sp::__detail::dpl_policy(0);
122 __detail::sort_async(local_policy, medians,
123 medians + n_segments * n_splitters, comp)
// Spacing used to pick the final splitters out of the sorted samples.
126 double step_size =
static_cast<double>(n_segments * n_splitters) / n_segments;
131 auto &&q = dr::sp::__detail::queue(0);
// Compacts the chosen splitters into medians[0 .. n_splitters).
// NOTE(review): `medians` is device memory; presumably this loop runs inside
// a kernel submitted to `q` on lines not shown — confirm.
133 for (std::size_t i = 0; i < n_splitters; i++) {
134 medians[i] = medians[std::size_t(step_size * (i + 1) + 0.5)];
// Step 3 bookkeeping:
//   splitter_indices[s]  - per-segment array of splitter positions
//   sorted_seg_sizes[d]  - running element count of destination buffer d
//   push_positions[s][d] - offset in buffer d where segment s's chunk goes
138 std::vector<std::size_t *> splitter_indices;
141 std::vector<std::size_t> sorted_seg_sizes(n_segments, 0);
144 std::vector<std::vector<std::size_t>> push_positions(n_segments);
150 for (
auto &&[segment_id, segment] : rng::views::enumerate(segments)) {
151 auto &&q = dr::sp::__detail::queue(dr::ranges::rank(segment));
152 auto &&local_policy =
153 dr::sp::__detail::dpl_policy(dr::ranges::rank(segment));
155 auto &&local_segment = dr::sp::__detail::local(segment);
// Shared (host-readable) array of n_segments positions: n_splitters search
// results plus one trailing entry set to the segment size below.
160 std::size_t *splitter_i = sycl::malloc_shared<std::size_t>(
161 n_segments, q.get_device(), sp::context());
162 splitter_indices.push_back(splitter_i);
// Device-local copy of the splitters (assigned to `medians_l`; the start of
// this declaration is on a line missing from the listing).
166 sycl::malloc_device<T>(n_splitters, q.get_device(), sp::context());
168 q.memcpy(medians_l, medians,
sizeof(T) * n_splitters).wait();
// Find where each splitter falls inside this (already sorted) segment.
170 __detail::lower_bound(local_policy, rng::begin(local_segment),
171 rng::end(local_segment), medians_l,
172 medians_l + n_splitters, splitter_i, comp);
174 sycl::free(medians_l, sp::context());
// Trailing entry: the last chunk ends at the end of the segment.
176 splitter_i[n_splitters] = rng::size(local_segment);
178 for (std::size_t i = 0; i < n_segments; i++) {
// Size of chunk i = distance between consecutive splitter positions.
179 const std::size_t n_elements =
180 splitter_i[i] - (i == 0 ? 0 : splitter_i[i - 1]);
// fetch_add returns the previous total, which is this chunk's offset in
// destination buffer i.
181 const std::size_t pos =
182 std::atomic_ref(sorted_seg_sizes[i]).fetch_add(n_elements);
183 push_positions[
static_cast<std::size_t
>(segment_id)].push_back(pos);
// Step 4a: allocate one destination buffer per segment, sized by the totals
// accumulated above.
188 std::vector<T *> sorted_segments;
190 for (
auto &&[segment_id, segment] : rng::views::enumerate(segments)) {
191 auto &&q = dr::sp::__detail::queue(dr::ranges::rank(segment));
// NOTE(review): allocated via the queue here but freed with sp::context()
// at the end — presumably the queue uses that same context; confirm.
193 T *buffer = sycl::malloc_device<T>(
194 sorted_seg_sizes[
static_cast<std::size_t
>(segment_id)], q);
195 sorted_segments.push_back(buffer);
// Step 4b: copy every chunk of every segment to its destination buffer.
199 for (
auto &&[segment_id_, segment] : rng::views::enumerate(segments)) {
200 auto &&local_segment = dr::sp::__detail::local(segment);
201 const auto segment_id =
static_cast<std::size_t
>(segment_id_);
203 std::size_t *splitter_i = splitter_indices[segment_id];
205 auto p_first = rng::begin(local_segment);
206 auto p_last = p_first;
207 for (std::size_t i = 0; i < n_segments; i++) {
208 p_last = rng::begin(local_segment) + splitter_i[i];
210 const std::size_t pos = push_positions[segment_id][i];
// Chunk [p_first, p_last) goes to buffer i at offset pos. Advancing p_first
// and recording `e` presumably happen on lines not shown.
212 auto e = sp::copy_async(p_first, p_last, sorted_segments[i] + pos);
219 dr::sp::__detail::wait(events);
// Step 5: one OpenMP thread per destination buffer merges that buffer's
// chunks.
224#pragma omp parallel num_threads(n_segments)
226 int t = omp_get_thread_num();
// Start offsets of the chunks inside buffer t (one per source segment).
228 std::vector<std::size_t> chunks_ind;
229 for (std::size_t i = 0; i < n_segments; i++) {
230 chunks_ind.push_back(push_positions[i][t]);
// Merge neighbouring chunks pairwise; each pass roughly halves the number
// of chunks until a single one remains.
233 auto _segments = n_segments;
234 while (_segments > 1) {
235 std::vector<std::size_t> new_chunks;
236 new_chunks.push_back(0);
// NOTE(review): `s` is int while `_segments / 2` is std::size_t
// (signed/unsigned comparison).
238 for (
int s = 0; s < _segments / 2; s++) {
// End of the pair being merged: start of the next pair, or the end of the
// buffer for the last pair.
240 const std::size_t l = (2 * s + 2 < _segments) ? chunks_ind[2 * s + 2]
241 : sorted_seg_sizes[t];
// (first / middle / last are computed on lines partly missing from the
// listing; chunks_ind[2 * s + 1] below is the tail of one such statement.)
246 chunks_ind[2 * s + 1]);
249 new_chunks.push_back(l);
// NOTE(review): `comp` is a by-value parameter, so std::forward<Compare>
// acts as a move here, repeated per iteration and per thread — verify this
// is intended.
251 oneapi::dpl::inplace_merge(
252 __detail::dpl_policy(dr::ranges::rank(segments[t])), first, middle,
253 last, std::forward<Compare>(comp));
// Ceiling division: an unpaired trailing chunk is carried to the next pass.
256 _segments = (_segments + 1) / 2;
258 std::swap(chunks_ind, new_chunks);
// Step 6: copy the merged buffers back into `r`, one after another.
264 auto d_first = rng::begin(r);
266 for (std::size_t i = 0; i < sorted_segments.size(); i++) {
267 T *seg = sorted_segments[i];
268 std::size_t n_elements = sorted_seg_sizes[i];
270 auto e = sp::copy_async(seg, seg + n_elements, d_first);
// Move the output position past the elements just scheduled for copy.
274 rng::advance(d_first, n_elements);
277 dr::sp::__detail::wait(events);
// Release all temporary USM allocations.
281 for (
auto &&sorted_seg : sorted_segments) {
282 sycl::free(sorted_seg, sp::context());
285 for (
auto &&splitter_i : splitter_indices) {
286 sycl::free(splitter_i, sp::context());
289 sycl::free(medians, sp::context());
// Iterator-pair overload: wraps [first, last) in an rng::subrange and
// forwards to sort(subrange, comp). `comp` defaults to a default-constructed
// Compare (std::less<> unless specified).
// NOTE(review): the closing brace is not visible in this listing.
292template <dr::distributed_iterator RandomIt,
typename Compare = std::less<>>
293void sort(RandomIt first, RandomIt last, Compare comp = Compare()) {
294 sort(rng::subrange(first, last), comp);
Definition: onedpl_direct_iterator.hpp:15