9#include <dr/detail/utils.hpp>
11#ifdef SYCL_LANGUAGE_VERSION
13#include <sycl/sycl.hpp>
15namespace dr::__detail {
/// Upper bound on the number of work-items in a single 1-D launch:
/// `std::numeric_limits<std::int32_t>::max()` rounded down to a whole
/// multiple of `block_size`.
///
/// @param block_size Work-group size the result must be a multiple of.
///        0 is treated as "no alignment requirement" so the function never
///        divides by zero.
/// @return The rounded-down limit (0 if `block_size` exceeds the 32-bit
///         limit itself).
inline std::size_t max_kernel_size_(std::size_t block_size = 128) {
  constexpr std::size_t limit = std::numeric_limits<std::int32_t>::max();
  if (block_size == 0) {
    return limit;
  }
  // Integer division truncates, so this is the largest multiple <= limit.
  return (limit / block_size) * block_size;
}
// Submits a 1-D nd_range kernel to `q`. The global size is
// numWorkItems.size() rounded up to a whole number of work-groups of
// `block_size` items; the `idx < range_size` guard masks off the padding
// work-items added by that rounding.
// NOTE(review): range_size is int32_t, so a size above INT32_MAX would be
// narrowed here. The callers visible in this file (parallel_for,
// parallel_for_64bit) bound the size with max_kernel_size_ first.
// NOTE(review): the guarded statement and the return are not visible in this
// listing.
28sycl::event parallel_for_workaround(sycl::queue &q, sycl::range<1> numWorkItems,
29 Fn &&fn, std::size_t block_size = 128) {
// Ceiling division: number of work-groups needed to cover every work-item.
30 std::size_t num_blocks = (numWorkItems.size() + block_size - 1) / block_size;
32 int32_t range_size = numWorkItems.size();
34 auto event = q.parallel_for(
35 sycl::nd_range<>(num_blocks * block_size, block_size), [=](
auto nd_idx) {
36 auto idx = nd_idx.get_global_id(0);
37 if (idx < range_size) {
// Runs a 1-D range as a sequence of launches of bounded size. The range is
// walked in steps of max_kernel_size_(block_size); each step launches
// min(remaining, max_kernel_size) work-items through parallel_for_workaround.
// NOTE(review): several lines (remaining call arguments, what is done with
// each launch's event, and the body of the final q.submit) are not visible in
// this listing.
45sycl::event parallel_for_64bit(sycl::queue &q, sycl::range<1> numWorkItems,
47 std::size_t block_size = 128;
48 std::size_t max_kernel_size = max_kernel_size_(block_size);
// NOTE(review): presumably filled with one event per launch for the final
// submit below; confirm against the full source.
50 std::vector<sycl::event> events;
51 for (std::size_t base_idx = 0; base_idx < numWorkItems.size();
52 base_idx += max_kernel_size) {
// The last chunk may be shorter than max_kernel_size.
53 std::size_t launch_size =
54 std::min(numWorkItems.size() - base_idx, max_kernel_size);
56 auto e = parallel_for_workaround(
58 [=](sycl::id<1> idx_) {
// Shift the chunk-local index by the chunk's starting offset.
59 sycl::id<1> idx(base_idx + idx_);
67 auto e = q.submit([&](
auto &&h) {
// Probes whether `device` can be split by NUMA affinity domain by actually
// calling create_sub_devices and catching sycl::exception, then inspecting
// the error code.
// NOTE(review): the `try` line and all return statements are not visible in
// this listing. Presumably errc::invalid / errc::feature_not_supported mean
// "not partitionable"; confirm against the full source.
79inline auto partitionable(sycl::device device) {
84 device.create_sub_devices<
85 sycl::info::partition_property::partition_by_affinity_domain>(
86 sycl::info::partition_affinity_domain::numa);
87 }
catch (sycl::exception
const &e) {
// Only these two error codes are singled out by this condition.
88 if (e.code() == sycl::errc::invalid ||
89 e.code() == sycl::errc::feature_not_supported) {
// Builds a sycl::nd_range<Dim> from a global range. The local range is fixed
// per dimensionality (128, 16x16, 8x8x8) and each global extent is passed
// through round_up(extent, matching local extent).
// NOTE(review): round_up is defined outside this file's visible part;
// presumably it rounds up to a multiple of its second argument. Confirm.
// NOTE(review): the trailing argument of each nd_range constructor call and
// part of the fallback branch are not visible in this listing.
102template <
int Dim>
auto nd_range(sycl::range<Dim> global) {
103 if constexpr (Dim == 1) {
// 1-D: local range of 128.
104 sycl::range local(128);
105 return sycl::nd_range<Dim>(sycl::range(round_up(global[0], local[0])),
107 }
else if constexpr (Dim == 2) {
// 2-D: local range of 16 x 16.
108 sycl::range local(16, 16);
109 return sycl::nd_range<Dim>(sycl::range(round_up(global[0], local[0]),
110 round_up(global[1], local[1])),
112 }
else if constexpr (Dim == 3) {
// 3-D: local range of 8 x 8 x 8.
113 sycl::range local(8, 8, 8);
114 return sycl::nd_range<Dim>(sycl::range(round_up(global[0], local[0]),
115 round_up(global[1], local[1]),
116 round_up(global[2], local[2])),
// Fallback return for any Dim other than 1, 2 or 3.
120 return sycl::range(0);
// 1-D overload: launches a kernel on `q` over nd_range(global) and guards
// the body with `idx0 < global[0]`, so work-items beyond the requested
// extent do nothing.
// NOTE(review): the guarded statement (presumably the call to `fn`) is not
// visible in this listing.
124template <
typename Fn>
125sycl::event parallel_for_nd(sycl::queue &q, sycl::range<1> global, Fn &&fn) {
126 return q.parallel_for(nd_range(global), [=](
auto nd_idx) {
127 auto idx0 = nd_idx.get_global_id(0);
128 if (idx0 < global[0]) {
134template <
typename Fn>
135sycl::event parallel_for_nd(sycl::queue &q, sycl::range<2> global, Fn &&fn) {
136 return q.parallel_for(nd_range(global), [=](
auto nd_idx) {
137 auto idx0 = nd_idx.get_global_id(0);
138 auto idx1 = nd_idx.get_global_id(1);
139 if (idx0 < global[0] && idx1 < global[1]) {
140 fn(std::array{idx0, idx1});
145template <
typename Fn>
146sycl::event parallel_for_nd(sycl::queue &q, sycl::range<3> global, Fn &&fn) {
147 return q.parallel_for(nd_range(global), [=](
auto nd_idx) {
148 auto idx0 = nd_idx.get_global_id(0);
149 auto idx1 = nd_idx.get_global_id(1);
150 auto idx2 = nd_idx.get_global_id(2);
151 if (idx0 < global[0] && idx1 < global[1] && idx2 < global[2]) {
152 fn(std::array{idx0, idx1, idx2});
// Submits a command group to `q` that declares a dependency on `events` and
// returns whatever q.submit returns.
// NOTE(review): the remainder of the command group is not visible in this
// listing.
157auto combine_events(sycl::queue &q,
const auto &events) {
158 return q.submit([&](
auto &&h) {
159 h.depends_on(events);
165template <
typename Fn>
166sycl::event parallel_for(sycl::queue &q, sycl::range<1> numWorkItems, Fn &&fn) {
167 std::size_t block_size = 128;
168 std::size_t max_kernel_size = max_kernel_size_();
170 if (numWorkItems.size() < max_kernel_size) {
171 return parallel_for_workaround(q, numWorkItems, std::forward<Fn>(fn),
174 return parallel_for_64bit(q, numWorkItems, std::forward<Fn>(fn));
178template <
typename Fn>
179sycl::event parallel_for(sycl::queue &q, sycl::range<2> global, Fn &&fn) {
180 auto max = std::numeric_limits<std::int32_t>::max();
181 assert(global[0] < max && global[1] < max);
182 return parallel_for_nd(q, global, fn);
185template <
typename Fn>
186sycl::event parallel_for(sycl::queue &q, sycl::range<3> global, Fn &&fn) {
187 auto max = std::numeric_limits<std::int32_t>::max();
188 assert(global[0] < max && global[1] < max && global[2] < max);
189 return parallel_for_nd(q, global, fn);
// Alias `event` for sycl::event.
// NOTE(review): presumably paired with a non-SYCL `event` type in the other
// branch of the SYCL_LANGUAGE_VERSION #ifdef; that part is not visible here.
192using event = sycl::event;
198namespace dr::__detail {
Definition: sycl_utils.hpp:200