12#include <dr/concepts/concepts.hpp>
13#include <dr/detail/logger.hpp>
14#include <dr/detail/onedpl_direct_iterator.hpp>
15#include <dr/detail/ranges_shim.hpp>
16#include <dr/detail/tuple_utils.hpp>
17#include <dr/mp/global.hpp>
19namespace dr::mp::__detail {
// Conversion operator template: a const, noexcept conversion to any type T
// whose body simply returns std::declval<T>().
// NOTE(review): std::declval may only appear in unevaluated operands, so this
// operator is presumably meant to be named inside requires-expressions or
// decltype only and never actually executed -- confirm against the enclosing
// type, whose declaration is not part of this listing.
22 template <
typename T>
operator T()
const noexcept {
23 return std::declval<T>();
27template <
typename F,
typename Arg1>
32template <
typename F,
typename Arg1,
typename Arg2>
34 { f(Arg1{}, Arg2{}) };
41namespace detail = dr::__detail;
44template <
typename... Ts>
46 auto ranges = std::tie(drs...);
47 auto &&dr0 = std::get<0>(ranges);
48 if (rng::empty(dr0)) {
52 auto all_segments = rng::views::zip(dr::ranges::segments(drs)...);
53 for (
auto segs : all_segments) {
54 auto seg0 = std::get<0>(segs);
55 auto mdspan0 = seg0.mdspan();
58 if (dr::ranges::rank(seg0) == default_comm().rank()) {
61 auto operand_infos = detail::tuple_transform(segs, [](
auto &&seg) {
64 return std::make_pair(seg.mdspan(), seg.root_mdspan().extents());
68#ifdef SYCL_LANGUAGE_VERSION
69 auto do_point = [=](
auto index) {
72 detail::tuple_transform(operand_infos, [=](
auto info) {
74 std::to_address(&info.first(index[0], index[1])),
81 dr::__detail::parallel_for(
82 mp::sycl_queue(), sycl::range(mdspan0.extent(0), mdspan0.extent(1)),
90 auto invoke_index = [=](
auto index) {
93 detail::tuple_transform(operand_infos, [=](
auto info) {
94 return md::mdspan(std::to_address(&info.first(index)),
101 detail::mdspan_foreach<mdspan0.rank(),
decltype(invoke_index)>(
102 mdspan0.extents(), invoke_index);
104 for (std::size_t i = 0; i < mdspan0.extents().extent(0); i++) {
105 for (std::size_t j = 0; j < mdspan0.extents().extent(1); j++) {
106 invoke_index(std::array<std::size_t, 2>{i, j});
// for_each over one or more mdspan-view operands.
//
// op  - callable applied per element; the visible std::invocable checks accept
//       either op(references) or op(index, references).
// drs - one or more operands constrained by is_mdspan_view. The first operand
//       (dr0 / seg0 / mdspan0) supplies the emptiness test, the rank
//       comparison, the origin, and the iteration extents.
//
// NOTE(review): the embedded line numbers in this listing skip values (for
// example 122 -> 126, 152 -> 156, 158 -> 161, 181 -> 190), so some statements
// and closing braces of this function are not visible here. The comments below
// describe only the lines that are shown.
118template <
typename F,
typename... Ts>
119void for_each(F op, is_mdspan_view
auto &&...drs) {
// Bundle the operands by reference so the first one can be picked out.
120 auto ranges = std::tie(drs...);
121 auto &&dr0 = std::get<0>(ranges);
// Branch taken when the first operand is empty (its body is not shown).
122 if (rng::empty(dr0)) {
// Zip the segment lists of all operands: each element of all_segments is a
// tuple holding one segment from every operand.
126 auto all_segments = rng::views::zip(dr::ranges::segments(drs)...);
127 for (
auto segs : all_segments) {
128 auto seg0 = std::get<0>(segs);
129 auto mdspan0 = seg0.mdspan();
// Process this tuple only when the first segment's rank equals
// default_comm().rank().
132 if (dr::ranges::rank(seg0) == default_comm().rank()) {
// origin of the first segment; added to the per-segment index below to form
// global_index.
133 auto origin = seg0.origin();
// One mdspan per operand segment, in operand order.
136 auto operand_mdspans = detail::tuple_transform(
137 segs, [](
auto &&seg) {
return seg.mdspan(); });
// SYCL path: chosen at run time via mp::use_sycl(); the code under the #ifdef
// is compiled only when SYCL_LANGUAGE_VERSION is defined.
139 if (mp::use_sycl()) {
140#ifdef SYCL_LANGUAGE_VERSION
// Per-index callable passed to dr::__detail::parallel_for below. It captures
// by copy ([=]).
142 auto invoke_index = [=](
auto index) {
// Collect, for every operand mdspan, the element at `index`. decltype(auto)
// keeps whatever reference type the mdspan call operator returns.
144 auto references = detail::tie_transform(
145 operand_mdspans, [mdspan0, index](
auto mdspan) ->
decltype(
auto) {
// index is unpacked component by component here, so only ranks 1 to 3 are
// accepted on this path.
146 static_assert(1 <= mdspan0.rank() && mdspan0.rank() <= 3);
147 if constexpr (mdspan0.rank() == 1) {
148 return mdspan(index[0]);
149 }
else if constexpr (mdspan0.rank() == 2) {
150 return mdspan(index[0], index[1]);
151 }
else if constexpr (mdspan0.rank() == 3) {
152 return mdspan(index[0], index[1], index[2]);
// Condition on F: invocable with (references) or with (index, references).
// NOTE(review): the construct that consumes this condition is on a line not
// shown; presumably a static_assert -- confirm.
156 std::invocable<F,
decltype(references)> ||
157 std::invocable<F,
decltype(index),
decltype(references)>);
158 if constexpr (std::invocable<F,
decltype(references)>) {
// NOTE(review): embedded lines 159-160 are missing from this listing, so the
// body of the branch above is not visible. The statements below shift a copy
// of index by origin, element by element, and pass it to op together with
// references.
161 auto global_index = index;
162 for (std::size_t i = 0; i < rng::size(global_index); i++) {
163 global_index[i] += origin[i];
166 op(global_index, references);
// Launch invoke_index through dr::__detail::parallel_for on mp::sycl_queue(),
// over a sycl::range built from mdspan0's extents; one branch per rank.
170 if constexpr (mdspan0.rank() == 1) {
171 auto range = sycl::range(mdspan0.extent(0));
172 dr::__detail::parallel_for(mp::sycl_queue(), range, invoke_index)
174 }
else if constexpr (mdspan0.rank() == 2) {
175 auto range = sycl::range(mdspan0.extent(0), mdspan0.extent(1));
176 dr::__detail::parallel_for(mp::sycl_queue(), range, invoke_index)
178 }
else if constexpr (mdspan0.rank() == 3) {
179 auto range = sycl::range(mdspan0.extent(0), mdspan0.extent(1),
181 dr::__detail::parallel_for(mp::sycl_queue(), range, invoke_index)
// Second definition of invoke_index. NOTE(review): presumably the non-SYCL
// branch; the lines that separate it from the SYCL code are not shown.
// Here each mdspan is called with the whole index object instead of unpacked
// components.
190 auto invoke_index = [=](
auto index) {
192 auto references = detail::tie_transform(
194 [index](
auto mdspan) ->
decltype(
auto) {
return mdspan(index); });
// Same condition on F as in the SYCL variant above.
196 std::invocable<F,
decltype(references)> ||
197 std::invocable<F,
decltype(index),
decltype(references)>);
198 if constexpr (std::invocable<F,
decltype(references)>) {
200 }
else if constexpr (std::invocable<F,
decltype(index),
201 decltype(references)>) {
// Two-argument form: shift a copy of index by origin, element by element,
// and pass it to op together with references.
202 auto global_index = index;
203 for (std::size_t i = 0; i < rng::size(global_index); i++) {
204 global_index[i] += origin[i];
207 op(global_index, references);
// Hand invoke_index and mdspan0's extents to detail::mdspan_foreach,
// instantiated with mdspan0's rank.
// NOTE(review): presumably this visits every index in the extents -- confirm
// against the helper's definition.
212 detail::mdspan_foreach<mdspan0.rank(),
decltype(invoke_index)>(
213 mdspan0.extents(), invoke_index);
Definition: md_for_each.hpp:28
Definition: md_for_each.hpp:33
Definition: mdspan_view.hpp:206
Definition: md_for_each.hpp:21