Distributed Ranges
Loading...
Searching...
No Matches
mdspan_utils.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <dr/detail/mdspan_shim.hpp>
8
9namespace dr::__detail {
10
11template <std::size_t Rank> auto dims(md::dextents<std::size_t, Rank> extents) {
12 if constexpr (Rank == 1) {
13 return std::tuple(extents.extent(0));
14 } else if constexpr (Rank == 2) {
15 return std::tuple(extents.extent(0), extents.extent(1));
16 } else if constexpr (Rank == 3) {
17 return std::tuple(extents.extent(0), extents.extent(1), extents.extent(2));
18 } else {
19 assert(false);
20 }
21}
22
23template <typename Index> auto shape_to_strides(const Index &shape) {
24 const std::size_t rank = rng::size(shape);
25 Index strides;
26 strides[rank - 1] = 1;
27 for (std::size_t i = 1; i < rank; i++) {
28 strides[rank - i - 1] = strides[rank - i] * shape[rank - i];
29 }
30 return strides;
31}
32
33template <typename Index>
34auto linear_to_index(std::size_t linear, const Index &shape) {
35 Index index, strides(shape_to_strides(shape));
36
37 for (std::size_t i = 0; i < rng::size(shape); i++) {
38 index[i] = linear / strides[i];
39 linear = linear % strides[i];
40 }
41
42 return index;
43}
44
45template <typename Mdspan>
46concept mdspan_like = requires(Mdspan &mdspan) {
47 mdspan.rank();
48 mdspan.extents();
49};
50
51template <typename Mdarray>
52concept mdarray_like = requires(Mdarray &mdarray) { mdarray.to_mdspan(); };
53
54template <std::size_t Rank> using dr_extents = std::array<std::size_t, Rank>;
55template <std::size_t Rank> using md_extents = md::dextents<std::size_t, Rank>;
56
57//
58// Mdspan accessor using an iterator
59//
60template <std::random_access_iterator Iter> class mdspan_iter_accessor {
61public:
62 using data_handle_type = Iter;
63 using reference = std::iter_reference_t<Iter>;
65
66 constexpr mdspan_iter_accessor() noexcept = default;
67 constexpr auto access(Iter iter, std::size_t index) const {
68 return iter[index];
69 }
70
71 constexpr auto offset(Iter iter, std::size_t index) const noexcept {
72 return iter + index;
73 }
74};
75
76template <typename M, std::size_t Rank, std::size_t... indexes>
77auto make_submdspan_impl(M mdspan, const dr_extents<Rank> &starts,
78 const dr_extents<Rank> &ends,
79 std::index_sequence<indexes...>) {
80 return md::submdspan(mdspan, std::tuple(starts[indexes], ends[indexes])...);
81}
82
83// Mdspan accepts slices, but that is hard to work with because it
84// requires parameter packs. Work with starts/size vectors internally
85// and use slices at the interface
86template <std::size_t Rank>
87auto make_submdspan(auto mdspan, const std::array<std::size_t, Rank> &starts,
88 const std::array<std::size_t, Rank> &ends) {
89 return make_submdspan_impl(mdspan, starts, ends,
90 std::make_index_sequence<Rank>{});
91}
92
93template <std::size_t Rank, typename Op>
94void mdspan_foreach(md_extents<Rank> extents, Op op,
95 dr_extents<Rank> index = dr_extents<Rank>(),
96 std::size_t rank = 0) {
97 for (index[rank] = 0; index[rank] < extents.extent(rank); index[rank]++) {
98 if (rank == Rank - 1) {
99 op(index);
100 } else {
101 mdspan_foreach(extents, op, index, rank + 1);
102 }
103 }
104}
105
106// Pack mdspan into contiguous container
107template <mdspan_like Src>
108auto mdspan_copy(Src src, std::forward_iterator auto dst) {
109 __detail::event event;
110
111 constexpr std::size_t rank = std::remove_cvref_t<Src>::rank();
112 if (rank >= 2 && rank <= 3 && mp::use_sycl()) {
113#ifdef SYCL_LANGUAGE_VERSION
114 constexpr std::size_t rank = std::remove_cvref_t<Src>::rank();
115 if constexpr (rank == 2) {
116 event = dr::__detail::parallel_for(
117 dr::mp::sycl_queue(), sycl::range(src.extent(0), src.extent(1)),
118 [src, dst](auto idx) {
119 dst[idx[0] * src.extent(1) + idx[1]] = src(idx);
120 });
121 } else if constexpr (rank == 3) {
122 event = dr::__detail::parallel_for(
123 dr::mp::sycl_queue(),
124 sycl::range(src.extent(0), src.extent(1), src.extent(2)),
125 [src, dst](auto idx) {
126 dst[idx[0] * src.extent(1) * src.extent(2) +
127 idx[1] * src.extent(2) + idx[2]] = src(idx);
128 });
129 } else {
130 assert(false);
131 }
132#endif
133 } else {
134 auto pack = [src, &dst](auto index) { *dst++ = src(index); };
135 mdspan_foreach<src.rank(), decltype(pack)>(src.extents(), pack);
136 }
137
138 return event;
139}
140
141// unpack contiguous container into mdspan
142template <mdspan_like Dst>
143auto mdspan_copy(std::forward_iterator auto src, Dst dst) {
144 __detail::event event;
145
146 constexpr std::size_t rank = std::remove_cvref_t<Dst>::rank();
147 if (rank >= 2 && rank <= 3 && mp::use_sycl()) {
148#ifdef SYCL_LANGUAGE_VERSION
149 if constexpr (rank == 2) {
150 event = dr::__detail::parallel_for(
151 dr::mp::sycl_queue(), sycl::range(dst.extent(0), dst.extent(1)),
152 [src, dst](auto idx) {
153 dst(idx) = src[idx[0] * dst.extent(1) + idx[1]];
154 });
155 } else if constexpr (rank == 3) {
156 event = dr::__detail::parallel_for(
157 dr::mp::sycl_queue(),
158 sycl::range(dst.extent(0), dst.extent(1), dst.extent(2)),
159 [src, dst](auto idx) {
160 dst(idx) = src[idx[0] * dst.extent(1) * dst.extent(2) +
161 idx[1] * dst.extent(2) + idx[2]];
162 });
163 } else {
164 assert(false);
165 }
166#endif
167 } else {
168 auto unpack = [&src, dst](auto index) { dst(index) = *src++; };
169 mdspan_foreach<dst.rank(), decltype(unpack)>(dst.extents(), unpack);
170 }
171
172 return event;
173}
174
175// copy mdspan to mdspan
176auto mdspan_copy(mdspan_like auto src, mdspan_like auto dst) {
177 __detail::event event;
178
179 assert(src.extents() == dst.extents());
180
181 constexpr std::size_t rank = std::remove_cvref_t<decltype(src)>::rank();
182 if (rank >= 2 && rank <= 3 && mp::use_sycl()) {
183#ifdef SYCL_LANGUAGE_VERSION
184 dr::drlog.debug("mdspan_copy using sycl\n");
185 if constexpr (rank == 2) {
186 event = dr::__detail::parallel_for(
187 dr::mp::sycl_queue(), sycl::range(dst.extent(0), dst.extent(1)),
188 [src, dst](auto idx) { dst(idx) = src(idx); });
189 } else if constexpr (rank == 3) {
190 event = dr::__detail::parallel_for(
191 dr::mp::sycl_queue(),
192 sycl::range(dst.extent(0), dst.extent(1), dst.extent(2)),
193 [src, dst](auto idx) { dst(idx) = src(idx); });
194 } else {
195 assert(false);
196 }
197#endif
198 } else {
199
200 auto copy = [src, dst](auto index) { dst(index) = src(index); };
201 mdspan_foreach<src.rank(), decltype(copy)>(src.extents(), copy);
202 }
203
204 return event;
205}
206
207// For operator(), rearrange indices according to template arguments.
208//
209// For mdtranspose<mdspan3d, 2, 0, 1> a(b);
210//
211// a(1, 2, 3) references b(3, 1, 2)
212//
213template <typename Mdspan, std::size_t... Is>
214class mdtranspose : public Mdspan {
215private:
216 static constexpr std::size_t rank_ = Mdspan::rank();
217
218public:
219 // Inherit constructors from base class
220 mdtranspose(Mdspan &mdspan) : Mdspan(mdspan) {}
221
222 // rearrange indices according to template arguments
223 template <std::integral... Indexes>
224 auto &operator()(Indexes... indexes) const {
225 std::tuple index(indexes...);
226 return Mdspan::operator()(std::get<Is>(index)...);
227 }
228 auto &operator()(std::array<std::size_t, rank_> index) const {
229 return Mdspan::operator()(index[Is]...);
230 }
231
232 auto extents() const {
233 // To get the extents, we must invert the index mapping
234 std::array<std::size_t, rank_> from_transposed({Is...});
235 std::array<std::size_t, rank_> extents_t;
236 for (std::size_t i = 0; i < rank_; i++) {
237 extents_t[from_transposed[i]] = Mdspan::extent(i);
238 }
239
240 return md_extents<rank_>(extents_t);
241 }
242 auto extent(std::size_t d) const { return extents().extent(d); }
243};
244
245} // namespace dr::__detail
246
247template <dr::__detail::mdspan_like Mdspan>
248struct fmt::formatter<Mdspan, char> : public formatter<string_view> {
249 template <typename FmtContext>
250 auto format(Mdspan mdspan, FmtContext &ctx) const {
251 std::array<std::size_t, mdspan.rank()> index;
252 rng::fill(index, 0);
253 format_mdspan(ctx, mdspan, index, 0);
254 return ctx.out();
255 }
256
257 void format_mdspan(auto &ctx, auto mdspan, auto &index,
258 std::size_t dim) const {
259 for (std::size_t i = 0; i < mdspan.extent(dim); i++) {
260 index[dim] = i;
261 if (dim == mdspan.rank() - 1) {
262 if (i == 0) {
263 fmt::format_to(ctx.out(), "{}: ", index);
264 }
265 fmt::format_to(ctx.out(), "{:4} ", mdspan(index));
266 } else {
267 format_mdspan(ctx, mdspan, index, dim + 1);
268 }
269 }
270 fmt::format_to(ctx.out(), "\n");
271 }
272};
273
274namespace MDSPAN_NAMESPACE {
275
276template <dr::__detail::mdspan_like M1, dr::__detail::mdspan_like M2>
277bool operator==(const M1 &m1, const M2 &m2) {
278 constexpr std::size_t rank1 = M1::rank(), rank2 = M2::rank();
279 static_assert(rank1 == rank2);
280 if (dr::__detail::dims<rank1>(m1.extents()) !=
281 dr::__detail::dims<rank1>(m2.extents())) {
282 return false;
283 }
284
285 // See mdspan_foreach for a way to generalize this to all ranks
286 if constexpr (M1::rank() == 1) {
287 for (std::size_t i = 0; i < m1.extent(0); i++) {
288 if (m1(i) != m2(i)) {
289 return false;
290 }
291 }
292 } else if constexpr (M1::rank() == 2) {
293 for (std::size_t i = 0; i < m1.extent(0); i++) {
294 for (std::size_t j = 0; j < m1.extent(1); j++) {
295 if (m1(i, j) != m2(i, j)) {
296 return false;
297 }
298 }
299 }
300 } else if constexpr (M1::rank() == 3) {
301 for (std::size_t i = 0; i < m1.extent(0); i++) {
302 for (std::size_t j = 0; j < m1.extent(1); j++) {
303 for (std::size_t k = 0; k < m1.extent(2); k++) {
304 if (m1(i, j, k) != m2(i, j, k)) {
305 return false;
306 }
307 }
308 }
309 }
310 } else {
311 assert(false);
312 }
313
314 return true;
315}
316
317template <dr::__detail::mdspan_like M>
318inline std::ostream &operator<<(std::ostream &os, const M &m) {
319 if constexpr (dr::__detail::mdarray_like<M>) {
320 os << fmt::format("\n{}", m.to_mdspan());
321 } else {
322 os << fmt::format("\n{}", m);
323 }
324 return os;
325}
326
327} // namespace MDSPAN_NAMESPACE
328
329namespace dr {
330
331template <typename R>
333 distributed_range<R> && requires(R &r) { r.mdspan(); };
334
335} // namespace dr
Definition: mdspan_utils.hpp:60
Definition: mdspan_utils.hpp:214
Definition: index.hpp:34
Definition: mdspan_utils.hpp:52
Definition: mdspan_utils.hpp:46
Definition: mdspan_utils.hpp:332
Definition: concepts.hpp:20