Distributed Ranges
sycl_utils.hpp
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause

#pragma once

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

#include <dr/detail/utils.hpp>

#ifdef SYCL_LANGUAGE_VERSION

#include <sycl/sycl.hpp>

namespace dr::__detail {

// With the ND-range workaround, the maximum kernel size is
// `std::numeric_limits<std::int32_t>::max()` rounded down to
// the nearest multiple of the block size.
inline std::size_t max_kernel_size_(std::size_t block_size = 128) {
  std::size_t max_kernel_size = std::numeric_limits<std::int32_t>::max();
  return (max_kernel_size / block_size) * block_size;
}
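
// For example, with the default block size of 128, max_kernel_size_()
// returns (2147483647 / 128) * 128 = 2147483520, i.e. INT32_MAX rounded
// down to a multiple of 128.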

// This is a workaround to avoid performance degradation
// in DPC++ for odd range sizes.
template <typename Fn>
sycl::event parallel_for_workaround(sycl::queue &q, sycl::range<1> numWorkItems,
                                    Fn &&fn, std::size_t block_size = 128) {
  // Round the launch size up to a whole number of blocks.
  std::size_t num_blocks = (numWorkItems.size() + block_size - 1) / block_size;

  // Narrowing to int32_t is safe here: callers in this header launch at most
  // max_kernel_size_(block_size) work-items, which fits in 32 bits.
  int32_t range_size = numWorkItems.size();

  auto event = q.parallel_for(
      sycl::nd_range<>(num_blocks * block_size, block_size), [=](auto nd_idx) {
        auto idx = nd_idx.get_global_id(0);
        // Skip the padding work-items added by rounding up.
        if (idx < range_size) {
          fn(idx);
        }
      });
  return event;
}
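
// Usage sketch (hypothetical example):
//   sycl::queue q;
//   int *data = sycl::malloc_shared<int>(1000, q);
//   parallel_for_workaround(q, sycl::range<1>(1000), [=](auto idx) {
//     data[idx] = 1;
//   }).wait();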

template <typename Fn>
sycl::event parallel_for_64bit(sycl::queue &q, sycl::range<1> numWorkItems,
                               Fn &&fn) {
  std::size_t block_size = 128;
  std::size_t max_kernel_size = max_kernel_size_(block_size);

  // Launch the range in chunks that each fit within the 32-bit limit,
  // offsetting indices by the chunk base.
  std::vector<sycl::event> events;
  for (std::size_t base_idx = 0; base_idx < numWorkItems.size();
       base_idx += max_kernel_size) {
    std::size_t launch_size =
        std::min(numWorkItems.size() - base_idx, max_kernel_size);

    auto e = parallel_for_workaround(
        q, launch_size,
        [=](sycl::id<1> idx_) {
          sycl::id<1> idx(base_idx + idx_);
          fn(idx);
        },
        block_size);

    events.push_back(e);
  }

  auto e = q.submit([&](auto &&h) {
    h.depends_on(events);
    // Empty host task necessary due to [CMPLRLLVM-46542]
    h.host_task([] {});
  });

  return e;
}
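
// Chunking example (illustrative): with the default block size each chunk
// covers 2147483520 work-items, so a range of 5'000'000'000 items is split
// into ceil(5000000000 / 2147483520) = 3 kernel launches.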

//
// Return true if the device can be partitioned by affinity domain.
//
inline auto partitionable(sycl::device device) {
  // Earlier commits used the query API, but it returns true even when the
  // partition will fail (e.g., under Intel MPI mpirun with multiple
  // processes).
  try {
    device.create_sub_devices<
        sycl::info::partition_property::partition_by_affinity_domain>(
        sycl::info::partition_affinity_domain::numa);
  } catch (sycl::exception const &e) {
    if (e.code() == sycl::errc::invalid ||
        e.code() == sycl::errc::feature_not_supported) {
      return false;
    } else {
      throw;
    }
  }

  return true;
}
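
// Usage sketch (hypothetical):
//   sycl::device d = sycl::queue{}.get_device();
//   if (partitionable(d)) {
//     auto subdevices = d.create_sub_devices<
//         sycl::info::partition_property::partition_by_affinity_domain>(
//         sycl::info::partition_affinity_domain::numa);
//   }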

// Convert a global range to an nd_range using a generic block size.
// GPUs require uniform-size workgroups, so round the global size up to
// a multiple of the workgroup size.
template <int Dim> auto nd_range(sycl::range<Dim> global) {
  if constexpr (Dim == 1) {
    sycl::range local(128);
    return sycl::nd_range<Dim>(sycl::range(round_up(global[0], local[0])),
                               local);
  } else if constexpr (Dim == 2) {
    sycl::range local(16, 16);
    return sycl::nd_range<Dim>(sycl::range(round_up(global[0], local[0]),
                                           round_up(global[1], local[1])),
                               local);
  } else if constexpr (Dim == 3) {
    sycl::range local(8, 8, 8);
    return sycl::nd_range<Dim>(sycl::range(round_up(global[0], local[0]),
                                           round_up(global[1], local[1]),
                                           round_up(global[2], local[2])),
                               local);
  } else {
    assert(false);
    return sycl::range(0);
  }
}
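
// Example (illustrative): a 2-D global range of {100, 50} with the
// {16, 16} workgroup above yields an nd_range with global size {112, 64},
// since round_up(100, 16) == 112 and round_up(50, 16) == 64.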

template <typename Fn>
sycl::event parallel_for_nd(sycl::queue &q, sycl::range<1> global, Fn &&fn) {
  return q.parallel_for(nd_range(global), [=](auto nd_idx) {
    auto idx0 = nd_idx.get_global_id(0);
    if (idx0 < global[0]) {
      fn(idx0);
    }
  });
}

template <typename Fn>
sycl::event parallel_for_nd(sycl::queue &q, sycl::range<2> global, Fn &&fn) {
  return q.parallel_for(nd_range(global), [=](auto nd_idx) {
    auto idx0 = nd_idx.get_global_id(0);
    auto idx1 = nd_idx.get_global_id(1);
    if (idx0 < global[0] && idx1 < global[1]) {
      fn(std::array{idx0, idx1});
    }
  });
}

template <typename Fn>
sycl::event parallel_for_nd(sycl::queue &q, sycl::range<3> global, Fn &&fn) {
  return q.parallel_for(nd_range(global), [=](auto nd_idx) {
    auto idx0 = nd_idx.get_global_id(0);
    auto idx1 = nd_idx.get_global_id(1);
    auto idx2 = nd_idx.get_global_id(2);
    if (idx0 < global[0] && idx1 < global[1] && idx2 < global[2]) {
      fn(std::array{idx0, idx1, idx2});
    }
  });
}
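
// Usage sketch (hypothetical): the 2-D overload passes the index as a
// std::array, so a row-major store over a 100 x 50 grid looks like:
//   int *buf = sycl::malloc_shared<int>(100 * 50, q);
//   parallel_for_nd(q, sycl::range<2>(100, 50), [=](auto idx) {
//     buf[idx[0] * 50 + idx[1]] = 1;
//   }).wait();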

// Combine a list of events into a single event that completes once all of
// the input events have completed.
auto combine_events(sycl::queue &q, const auto &events) {
  return q.submit([&](auto &&h) {
    h.depends_on(events);
    // Empty host task necessary due to [CMPLRLLVM-46542]
    h.host_task([] {});
  });
}
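
// Usage sketch (hypothetical, with e1/e2/e3 standing in for events from
// earlier submissions):
//   std::vector<sycl::event> events = {e1, e2, e3};
//   combine_events(q, events).wait();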

// Launch a 1-D kernel, falling back to the 64-bit chunked path when the
// range exceeds the 32-bit workaround limit.
template <typename Fn>
sycl::event parallel_for(sycl::queue &q, sycl::range<1> numWorkItems, Fn &&fn) {
  std::size_t block_size = 128;
  std::size_t max_kernel_size = max_kernel_size_();

  if (numWorkItems.size() < max_kernel_size) {
    return parallel_for_workaround(q, numWorkItems, std::forward<Fn>(fn),
                                   block_size);
  } else {
    return parallel_for_64bit(q, numWorkItems, std::forward<Fn>(fn));
  }
}
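
// Dispatch example (illustrative): a range of 10'000 items takes the
// parallel_for_workaround path, while a range of 3'000'000'000 items
// exceeds max_kernel_size_() and is chunked by parallel_for_64bit.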

template <typename Fn>
sycl::event parallel_for(sycl::queue &q, sycl::range<2> global, Fn &&fn) {
  auto max = std::numeric_limits<std::int32_t>::max();
  assert(global[0] < max && global[1] < max);
  return parallel_for_nd(q, global, fn);
}

template <typename Fn>
sycl::event parallel_for(sycl::queue &q, sycl::range<3> global, Fn &&fn) {
  auto max = std::numeric_limits<std::int32_t>::max();
  assert(global[0] < max && global[1] < max && global[2] < max);
  return parallel_for_nd(q, global, fn);
}
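
// Note: the 2-D and 3-D overloads assert that each extent fits in 32 bits
// rather than chunking; only the 1-D overload has the 64-bit fallback.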

using event = sycl::event;

} // namespace dr::__detail

#else

namespace dr::__detail {

// No-op stand-in for sycl::event when SYCL is not available, so callers
// can unconditionally call `.wait()` on returned events.
class event {
public:
  void wait() {}
};

} // namespace dr::__detail

#endif // SYCL_LANGUAGE_VERSION