7#include <dr/detail/mdspan_shim.hpp>
9namespace dr::__detail {
// dims: converts a dynamic-extents object of rank 1, 2 or 3 into a std::tuple
// holding its per-dimension extents, e.g. for Rank == 2 the result is
// std::tuple(extents.extent(0), extents.extent(1)). Because the branches are
// `if constexpr`, each Rank instantiation yields a tuple of exactly Rank
// elements, which lets two extents objects be compared with tuple `!=`
// (as the mdspan operator== in this file does).
// NOTE(review): the handling of Rank outside 1..3 and the closing braces of
// this function are not visible in this listing — confirm what other ranks do.
11template <std::
size_t Rank>
auto dims(md::dextents<std::size_t, Rank> extents) {
12 if constexpr (Rank == 1) {
13 return std::tuple(extents.extent(0));
14 }
else if constexpr (Rank == 2) {
15 return std::tuple(extents.extent(0), extents.extent(1));
16 }
else if constexpr (Rank == 3) {
17 return std::tuple(extents.extent(0), extents.extent(1), extents.extent(2));
// shape_to_strides: computes row-major strides (last dimension varies
// fastest) for a multi-dimensional `shape`. The last stride is 1, and each
// earlier stride is the following stride times the following dimension's
// size. Example: shape {2, 3, 4} -> strides {12, 4, 1}.
// NOTE(review): the declaration of `strides` (presumably an `Index`) and the
// function's return are on lines not visible in this listing.
// NOTE(review): `rank` is std::size_t, so an empty `shape` makes `rank - 1`
// wrap around and index out of bounds. Callers presumably always pass
// rank >= 1 — confirm.
23template <
typename Index>
auto shape_to_strides(
const Index &shape) {
24 const std::size_t rank = rng::size(shape);
26 strides[rank - 1] = 1;
27 for (std::size_t i = 1; i < rank; i++) {
28 strides[rank - i - 1] = strides[rank - i] * shape[rank - i];
// linear_to_index: converts a linear (flattened, row-major) offset into a
// multi-dimensional index for `shape`, using the strides computed by
// shape_to_strides. Each coordinate is the remaining offset divided by that
// dimension's stride; the remainder carries on to the next dimension.
// Example: shape {2, 3, 4}, linear 17 -> index {1, 1, 1}.
// Assumes Index is default-constructible and subscriptable like std::array.
// NOTE(review): the end of this function (presumably `return index;`) is not
// visible in this listing.
33template <
typename Index>
34auto linear_to_index(std::size_t linear,
const Index &shape) {
35 Index index, strides(shape_to_strides(shape));
37 for (std::size_t i = 0; i < rng::size(shape); i++) {
38 index[i] = linear / strides[i];
39 linear = linear % strides[i];
// NOTE(review): only the template header of a declaration constrained on
// `Mdspan` is visible here; its body is missing from this listing.
// Presumably this begins the `mdspan_like` concept that the mdspan_copy
// overloads, the formatter and operator== in this file are constrained on —
// confirm against the full source.
45template <
typename Mdspan>
51template <
typename Mdarray>
52concept mdarray_like =
requires(Mdarray &mdarray) { mdarray.to_mdspan(); };
/// dr_extents: a fixed-size array of N std::size_t values. This file uses it
/// for per-dimension sizes, start/end coordinates and multi-dimensional
/// indices.
template <std::size_t N>
using dr_extents = std::array<std::size_t, N>;
55template <std::
size_t Rank>
using md_extents = md::dextents<std::size_t, Rank>;
// NOTE(review): these lines are the interior of a struct whose header
// (presumably `template <typename Iter> struct ...`) and member bodies are not
// visible in this listing. The member names data_handle_type, reference,
// access(handle, index) and offset(handle, index) match the mdspan
// AccessorPolicy requirements. This presumably lets an mdspan address its
// elements through an iterator `Iter` instead of a raw pointer — confirm
// against the full definition.
62 using data_handle_type = Iter;
63 using reference = std::iter_reference_t<Iter>;
67 constexpr auto access(Iter iter, std::size_t
index)
const {
71 constexpr auto offset(Iter iter, std::size_t
index)
const noexcept {
// make_submdspan_impl: expands the per-dimension `starts` / `ends` arrays into
// one std::tuple(start, end) slice argument per dimension and forwards them all
// to md::submdspan. `indexes` is the pack 0..Rank-1, supplied by the caller
// through std::index_sequence.
// NOTE(review): for pair-like slices std::submdspan selects [start, end);
// confirm the `md` implementation behind mdspan_shim follows the same rule.
// The closing brace of this function is not visible in this listing.
76template <
typename M, std::size_t Rank, std::size_t... indexes>
77auto make_submdspan_impl(M mdspan,
const dr_extents<Rank> &starts,
78 const dr_extents<Rank> &ends,
79 std::index_sequence<indexes...>) {
80 return md::submdspan(mdspan, std::tuple(starts[indexes], ends[indexes])...);
// make_submdspan: returns the sub-view of `mdspan` bounded per dimension by
// `starts[d]` and `ends[d]`. It generates the index pack 0..Rank-1 with
// std::make_index_sequence and delegates to make_submdspan_impl, which builds
// the per-dimension slices.
// NOTE(review): the closing brace of this function is not visible in this
// listing.
86template <std::
size_t Rank>
87auto make_submdspan(
auto mdspan,
const std::array<std::size_t, Rank> &starts,
88 const std::array<std::size_t, Rank> &ends) {
89 return make_submdspan_impl(mdspan, starts, ends,
90 std::make_index_sequence<Rank>{});
// mdspan_foreach: visits every multi-dimensional index of `extents` in
// row-major order (the last dimension varies fastest), recursing one
// dimension per call. `index` holds the coordinates chosen so far (passed by
// value, default all-zero) and `rank` is the dimension this call iterates
// over (default 0). Callers normally supply only `extents` and `op`.
// NOTE(review): the lines executed for the innermost dimension (presumably
// `op(index);` followed by `} else {`) and the closing braces are not visible
// in this listing.
93template <std::
size_t Rank,
typename Op>
94void mdspan_foreach(md_extents<Rank> extents, Op op,
95 dr_extents<Rank> index = dr_extents<Rank>(),
96 std::size_t rank = 0) {
97 for (index[rank] = 0; index[rank] < extents.extent(rank); index[rank]++) {
98 if (rank == Rank - 1) {
101 mdspan_foreach(extents, op, index, rank + 1);
// mdspan_copy (mdspan -> iterator): packs the elements of `src` into the
// sequence starting at `dst`, in row-major order.
// - If compiled with SYCL, rank is 2 or 3, and mp::use_sycl() is true, the copy
//   runs as a dr::__detail::parallel_for over src's extents. Element idx is
//   written to dst[linearized idx], and the resulting event is kept in `event`.
//   This path subscripts `dst`, which std::forward_iterator alone does not
//   guarantee; callers presumably pass a random-access iterator or pointer —
//   confirm.
// - Otherwise it walks `src` with mdspan_foreach and writes `*dst++` per
//   element. `dst` is captured by reference so the iterator advances across
//   calls.
// NOTE(review): the #else/#endif lines, the host-path branch structure and the
// function's return (presumably `event`) are not visible in this listing.
107template <mdspan_like Src>
108auto mdspan_copy(Src src, std::forward_iterator
auto dst) {
109 __detail::event event;
111 constexpr std::size_t rank = std::remove_cvref_t<Src>::rank();
112 if (rank >= 2 && rank <= 3 && mp::use_sycl()) {
113#ifdef SYCL_LANGUAGE_VERSION
114 constexpr std::size_t rank = std::remove_cvref_t<Src>::rank(); // NOTE(review): redeclares the identical `rank` above (shadowing); redundant
115 if constexpr (rank == 2) {
116 event = dr::__detail::parallel_for(
117 dr::mp::sycl_queue(), sycl::range(src.extent(0), src.extent(1)),
118 [src, dst](
auto idx) {
119 dst[idx[0] * src.extent(1) + idx[1]] = src(idx);
121 }
else if constexpr (rank == 3) {
122 event = dr::__detail::parallel_for(
123 dr::mp::sycl_queue(),
124 sycl::range(src.extent(0), src.extent(1), src.extent(2)),
125 [src, dst](
auto idx) {
126 dst[idx[0] * src.extent(1) * src.extent(2) +
127 idx[1] * src.extent(2) + idx[2]] = src(idx);
134 auto pack = [src, &dst](
auto index) { *dst++ = src(index); };
135 mdspan_foreach<src.rank(),
decltype(pack)>(src.extents(), pack);
// mdspan_copy (iterator -> mdspan): unpacks the sequence starting at `src`
// into `dst`, in row-major order. This is the inverse of the
// mdspan -> iterator overload.
// - If compiled with SYCL, rank is 2 or 3, and mp::use_sycl() is true, a
//   dr::__detail::parallel_for over dst's extents sets dst(idx) from
//   src[linearized idx], and the event is kept in `event`. This path
//   subscripts `src`, which std::forward_iterator alone does not guarantee;
//   callers presumably pass a random-access iterator or pointer — confirm.
// - Otherwise it walks `dst` with mdspan_foreach, assigning `*src++` per
//   element. `src` is captured by reference so the iterator advances across
//   calls.
// NOTE(review): the #else/#endif lines, the host-path branch structure and the
// function's return (presumably `event`) are not visible in this listing.
142template <mdspan_like Dst>
143auto mdspan_copy(std::forward_iterator
auto src, Dst dst) {
144 __detail::event event;
146 constexpr std::size_t rank = std::remove_cvref_t<Dst>::rank();
147 if (rank >= 2 && rank <= 3 && mp::use_sycl()) {
148#ifdef SYCL_LANGUAGE_VERSION
149 if constexpr (rank == 2) {
150 event = dr::__detail::parallel_for(
151 dr::mp::sycl_queue(), sycl::range(dst.extent(0), dst.extent(1)),
152 [src, dst](
auto idx) {
153 dst(idx) = src[idx[0] * dst.extent(1) + idx[1]];
155 }
else if constexpr (rank == 3) {
156 event = dr::__detail::parallel_for(
157 dr::mp::sycl_queue(),
158 sycl::range(dst.extent(0), dst.extent(1), dst.extent(2)),
159 [src, dst](
auto idx) {
160 dst(idx) = src[idx[0] * dst.extent(1) * dst.extent(2) +
161 idx[1] * dst.extent(2) + idx[2]];
168 auto unpack = [&src, dst](
auto index) { dst(index) = *src++; };
169 mdspan_foreach<dst.rank(),
decltype(unpack)>(dst.extents(), unpack);
// mdspan_copy (mdspan -> mdspan): element-wise copy dst(i...) = src(i...).
// Precondition: src and dst have equal extents. This is checked only by
// assert, so it is not checked in NDEBUG builds.
// - If compiled with SYCL, rank is 2 or 3, and mp::use_sycl() is true, it logs
//   a debug message and runs a dr::__detail::parallel_for over dst's extents,
//   keeping the event in `event`.
// - Otherwise it visits every index of `src` with mdspan_foreach.
// NOTE(review): the #else/#endif lines, the host-path branch structure and the
// function's return (presumably `event`) are not visible in this listing.
176auto mdspan_copy(mdspan_like
auto src, mdspan_like
auto dst) {
177 __detail::event event;
179 assert(src.extents() == dst.extents());
181 constexpr std::size_t rank = std::remove_cvref_t<
decltype(src)>::rank();
182 if (rank >= 2 && rank <= 3 && mp::use_sycl()) {
183#ifdef SYCL_LANGUAGE_VERSION
184 dr::drlog.debug(
"mdspan_copy using sycl\n");
185 if constexpr (rank == 2) {
186 event = dr::__detail::parallel_for(
187 dr::mp::sycl_queue(), sycl::range(dst.extent(0), dst.extent(1)),
188 [src, dst](
auto idx) { dst(idx) = src(idx); });
189 }
else if constexpr (rank == 3) {
190 event = dr::__detail::parallel_for(
191 dr::mp::sycl_queue(),
192 sycl::range(dst.extent(0), dst.extent(1), dst.extent(2)),
193 [src, dst](
auto idx) { dst(idx) = src(idx); });
200 auto copy = [src, dst](
auto index) { dst(index) = src(index); };
201 mdspan_foreach<src.rank(),
decltype(copy)>(src.extents(), copy);
// NOTE(review): fragment of a class template parameterized on an mdspan type
// and an index permutation `Is...`. Its name, base-class list and several
// members are not visible in this listing. The visible members call
// Mdspan::operator() and Mdspan::extent, so it presumably derives from Mdspan
// and presents a dimension-permuted (transposed) view of it:
// - operator()(i...) forwards to Mdspan::operator()(i[Is]...), so the
//   underlying mdspan's k-th index is the caller's argument at position Is[k];
// - extents() stores the underlying extent(k) at position Is[k], so each of
//   the view's dimensions has the size of the index that lands there;
// - extent(d) recomputes the full permuted extents on every call.
213template <
typename Mdspan, std::size_t... Is>
216 static constexpr std::size_t rank_ = Mdspan::rank();
223 template <std::integral... Indexes>
224 auto &operator()(Indexes... indexes)
const {
225 std::tuple
index(indexes...);
226 return Mdspan::operator()(std::get<Is>(
index)...);
228 auto &operator()(std::array<std::size_t, rank_>
index)
const {
229 return Mdspan::operator()(
index[Is]...);
232 auto extents()
const {
234 std::array<std::size_t, rank_> from_transposed({Is...});
235 std::array<std::size_t, rank_> extents_t;
236 for (std::size_t i = 0; i < rank_; i++) {
237 extents_t[from_transposed[i]] = Mdspan::extent(i);
240 return md_extents<rank_>(extents_t);
242 auto extent(std::size_t d)
const {
return extents().extent(d); }
// fmt::formatter specialization for any mdspan_like type. It reuses the
// string_view formatter's parsing and implements format() by recursing over
// the dimensions in format_mdspan:
// - non-innermost dimensions recurse with dim + 1;
// - in the innermost dimension each element is written as "{:4} " (minimum
//   field width 4), and the current index is written as "{}: ";
// - a "\n" is written as well.
// NOTE(review): the lines that assign index[dim], the conditions guarding the
// "{}: " prefix and the "\n", and format()'s return are not visible in this
// listing. `index` is declared without an initializer, so presumably each
// element is assigned in format_mdspan before it is read — confirm.
247template <dr::__detail::mdspan_like Mdspan>
248struct fmt::formatter<Mdspan, char> :
public formatter<string_view> {
249 template <
typename FmtContext>
250 auto format(Mdspan mdspan, FmtContext &ctx)
const {
251 std::array<std::size_t, mdspan.rank()> index;
253 format_mdspan(ctx, mdspan, index, 0);
257 void format_mdspan(
auto &ctx,
auto mdspan,
auto &index,
258 std::size_t dim)
const {
259 for (std::size_t i = 0; i < mdspan.extent(dim); i++) {
261 if (dim == mdspan.rank() - 1) {
263 fmt::format_to(ctx.out(),
"{}: ", index);
265 fmt::format_to(ctx.out(),
"{:4} ", mdspan(index));
267 format_mdspan(ctx, mdspan, index, dim + 1);
270 fmt::format_to(ctx.out(),
"\n");
274namespace MDSPAN_NAMESPACE {
// operator== for two mdspan_like objects: requires equal ranks (enforced at
// compile time by static_assert). It first compares the extents as tuples via
// dr::__detail::dims, then compares elements with nested loops for rank 1, 2
// and 3. It is declared inside MDSPAN_NAMESPACE, presumably so that
// argument-dependent lookup finds it for the mdspan types — confirm.
// NOTE(review): the return statements (presumably false on an extents or
// element mismatch, true otherwise), the closing braces and any handling of
// other ranks are not visible in this listing.
276template <dr::__detail::mdspan_like M1, dr::__detail::mdspan_like M2>
277bool operator==(
const M1 &m1,
const M2 &m2) {
278 constexpr std::size_t rank1 = M1::rank(), rank2 = M2::rank();
279 static_assert(rank1 == rank2);
280 if (dr::__detail::dims<rank1>(m1.extents()) !=
281 dr::__detail::dims<rank1>(m2.extents())) {
286 if constexpr (M1::rank() == 1) {
287 for (std::size_t i = 0; i < m1.extent(0); i++) {
288 if (m1(i) != m2(i)) {
292 }
else if constexpr (M1::rank() == 2) {
293 for (std::size_t i = 0; i < m1.extent(0); i++) {
294 for (std::size_t j = 0; j < m1.extent(1); j++) {
295 if (m1(i, j) != m2(i, j)) {
300 }
else if constexpr (M1::rank() == 3) {
301 for (std::size_t i = 0; i < m1.extent(0); i++) {
302 for (std::size_t j = 0; j < m1.extent(1); j++) {
303 for (std::size_t k = 0; k < m1.extent(2); k++) {
304 if (m1(i, j, k) != m2(i, j, k)) {
// operator<< for mdspan_like objects: writes a newline followed by the
// fmt-formatted contents. Two alternative writes are visible, one formatting
// m.to_mdspan() and one formatting m directly.
// NOTE(review): the condition selecting between them (presumably an
// mdarray_like / if constexpr check), the closing braces and the return of
// `os` are not visible in this listing — confirm.
317template <dr::__detail::mdspan_like M>
318inline std::ostream &operator<<(std::ostream &os,
const M &m) {
320 os << fmt::format(
"\n{}", m.to_mdspan());
322 os << fmt::format(
"\n{}", m);
Definition: mdspan_utils.hpp:60
Definition: mdspan_utils.hpp:214
Definition: mdspan_utils.hpp:52
Definition: mdspan_utils.hpp:46
Definition: mdspan_utils.hpp:332
Definition: concepts.hpp:20