6#include <dr/sp/detail.hpp>
7#include <dr/sp/init.hpp>
8#include <dr/sp/util.hpp>
// Policy-taking transform over a distributed range: applies `fn` to the
// elements of `in` and writes the results to the range starting at `out`,
// working on one (input segment, output segment) pair at a time.
// NOTE(review): this listing carries embedded source line numbers and omits
// several original lines (the function's parameter list, the opener of the
// buffer allocation / else branch, and closing braces are not visible here) —
// confirm details against the real header before relying on these comments.
25template <
class ExecutionPolicy>
// Tail of a compile-time check whose opening line is not visible (presumably a
// static_assert): the policy type, with cv/ref stripped, must be device_policy.
30 std::is_same_v<std::remove_cvref_t<ExecutionPolicy>, device_policy>);
// Events returned by the queue submissions below; waited on after the loop.
32 std::vector<sycl::event> events;
// Element type of the output iterator; used to type the temporary buffers.
33 using OutT =
typename decltype(out)::value_type;
// Temporary allocations made inside the loop; freed once all work has finished.
34 std::vector<void *> buffers;
// End of the output range: `out` advanced by the number of input elements.
35 const auto out_end = out + rng::size(in);
// Iterate input and output segments pairwise via the zip view's
// zipped_segments().
37 for (
auto &&[in_seg, out_seg] :
38 views::zip(in, rng::subrange(out, out_end)).zipped_segments()) {
// Device picked by the input segment's rank; used for the temporary
// allocation further down.
39 auto in_device = policy.get_devices()[in_seg.rank()];
// Queue looked up by the input segment's rank; all work for this pair is
// submitted to it.
40 auto &&q = __detail::queue(dr::ranges::rank(in_seg));
41 const std::size_t seg_size = rng::size(in_seg);
// Paired segments are expected to have equal length.
42 assert(seg_size == rng::size(out_seg));
43 auto local_in_seg = __detail::local(in_seg);
// Same rank on both sides: write results directly into the output segment.
45 if (in_seg.rank() == out_seg.rank()) {
46 auto local_out_seg = __detail::local(out_seg);
47 events.emplace_back(q.parallel_for(seg_size, [=](
auto idx) {
48 local_out_seg[idx] = fn(local_in_seg[idx]);
// NOTE(review): the lines that close the branch above and begin this
// statement are missing from the listing; presumably this is the else branch
// and the allocation result is the `buffer` used below — confirm.
// Allocates seg_size elements of OutT on in_device.
52 sycl::malloc_device<OutT>(seg_size, in_device, dr::sp::context());
// Remember the allocation so it can be freed after the wait.
53 buffers.push_back(buffer);
// Compute results into the temporary buffer first...
55 sycl::event compute_event = q.parallel_for(
56 seg_size, [=](
auto idx) { buffer[idx] = fn(local_in_seg[idx]); });
// ...then copy them into the output segment; compute_event is passed to the
// copy as its dependency.
57 events.emplace_back(q.copy(buffer, __detail::local(out_seg.begin()),
58 seg_size, compute_event));
// Wait for every recorded event before releasing the temporary buffers.
61 __detail::wait(events);
63 for (
auto *b : buffers)
64 sycl::free(b, dr::sp::context());
// Result pairs the end of the input range with the end of the written output.
66 return rng::unary_transform_result<
decltype(rng::end(in)),
decltype(out_end)>{
67 rng::end(in), out_end};
// Overload without an execution policy: forwards its arguments to the
// policy-taking transform, supplying dr::sp::par_unseq as the policy.
// NOTE(review): the closing brace of this function is not visible in this
// listing.
70template <dr::distributed_range R, dr::distributed_iterator Iter,
typename Fn>
71auto transform(R &&in, Iter out, Fn &&fn) {
72 return transform(dr::sp::par_unseq, std::forward<R>(in),
73 std::forward<Iter>(out), std::forward<Fn>(fn));
// Iterator-pair overload taking an execution policy: wraps
// [in_begin, in_end) in rng::subrange and passes it on together with the
// policy, the output iterator and `fn`.
// NOTE(review): the template header and the line opening the forwarded call
// (presumably `return transform(`) are missing from this listing — confirm.
// NOTE(review): `out_end` is passed in the argument position where the
// range-based overloads take `out`, the start of the output; the parameter
// name looks misleading — verify against callers.
78auto transform(ExecutionPolicy &&policy, Iter1 in_begin, Iter1 in_end,
79 Iter2 out_end, Fn &&fn) {
81 std::forward<ExecutionPolicy>(policy),
82 rng::subrange(std::forward<Iter1>(in_begin), std::forward<Iter1>(in_end)),
83 std::forward<Iter2>(out_end), std::forward<Fn>(fn));
// Iterator-pair overload without an execution policy: forwards to the
// iterator-pair transform with dr::sp::par_unseq as the policy.
// NOTE(review): the template header and closing brace are not visible in this
// listing. As in the policy-taking iterator overload, `out_end` is forwarded
// unchanged as the output iterator; presumably it denotes the start of the
// output despite its name — confirm.
88auto transform(Iter1 in_begin, Iter1 in_end, Iter2 out_end, Fn &&fn) {
89 return transform(dr::sp::par_unseq, std::forward<Iter1>(in_begin),
90 std::forward<Iter1>(in_end), std::forward<Iter2>(out_end),
91 std::forward<Fn>(fn));
Definition: concepts.hpp:31
Definition: concepts.hpp:20