Distributed Ranges
Loading...
Searching...
No Matches
transform.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4#pragma once
5
6#include <dr/sp/detail.hpp>
7#include <dr/sp/init.hpp>
8#include <dr/sp/util.hpp>
9
10namespace dr::sp {
11
25template <class ExecutionPolicy>
26auto transform(ExecutionPolicy &&policy, dr::distributed_range auto &&in,
27 dr::distributed_iterator auto out, auto &&fn) {
28
29 static_assert( // currently only one policy supported
30 std::is_same_v<std::remove_cvref_t<ExecutionPolicy>, device_policy>);
31
32 std::vector<sycl::event> events;
33 using OutT = typename decltype(out)::value_type;
34 std::vector<void *> buffers;
35 const auto out_end = out + rng::size(in);
36
37 for (auto &&[in_seg, out_seg] :
38 views::zip(in, rng::subrange(out, out_end)).zipped_segments()) {
39 auto in_device = policy.get_devices()[in_seg.rank()];
40 auto &&q = __detail::queue(dr::ranges::rank(in_seg));
41 const std::size_t seg_size = rng::size(in_seg);
42 assert(seg_size == rng::size(out_seg));
43 auto local_in_seg = __detail::local(in_seg);
44
45 if (in_seg.rank() == out_seg.rank()) {
46 auto local_out_seg = __detail::local(out_seg);
47 events.emplace_back(q.parallel_for(seg_size, [=](auto idx) {
48 local_out_seg[idx] = fn(local_in_seg[idx]);
49 }));
50 } else {
51 OutT *buffer =
52 sycl::malloc_device<OutT>(seg_size, in_device, dr::sp::context());
53 buffers.push_back(buffer);
54
55 sycl::event compute_event = q.parallel_for(
56 seg_size, [=](auto idx) { buffer[idx] = fn(local_in_seg[idx]); });
57 events.emplace_back(q.copy(buffer, __detail::local(out_seg.begin()),
58 seg_size, compute_event));
59 }
60 }
61 __detail::wait(events);
62
63 for (auto *b : buffers)
64 sycl::free(b, dr::sp::context());
65
66 return rng::unary_transform_result<decltype(rng::end(in)), decltype(out_end)>{
67 rng::end(in), out_end};
68}
69
70template <dr::distributed_range R, dr::distributed_iterator Iter, typename Fn>
71auto transform(R &&in, Iter out, Fn &&fn) {
72 return transform(dr::sp::par_unseq, std::forward<R>(in),
73 std::forward<Iter>(out), std::forward<Fn>(fn));
74}
75
76template <typename ExecutionPolicy, dr::distributed_iterator Iter1,
77 dr::distributed_iterator Iter2, typename Fn>
78auto transform(ExecutionPolicy &&policy, Iter1 in_begin, Iter1 in_end,
79 Iter2 out_end, Fn &&fn) {
80 return transform(
81 std::forward<ExecutionPolicy>(policy),
82 rng::subrange(std::forward<Iter1>(in_begin), std::forward<Iter1>(in_end)),
83 std::forward<Iter2>(out_end), std::forward<Fn>(fn));
84}
85
87 typename Fn>
88auto transform(Iter1 in_begin, Iter1 in_end, Iter2 out_end, Fn &&fn) {
89 return transform(dr::sp::par_unseq, std::forward<Iter1>(in_begin),
90 std::forward<Iter1>(in_end), std::forward<Iter2>(out_end),
91 std::forward<Fn>(fn));
92}
93
94} // namespace dr::sp
Definition: concepts.hpp:31
Definition: concepts.hpp:20