Distributed Ranges
Loading...
Searching...
No Matches
md_for_each.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <algorithm>
8#include <execution>
9#include <type_traits>
10#include <utility>
11
12#include <dr/concepts/concepts.hpp>
13#include <dr/detail/logger.hpp>
14#include <dr/detail/onedpl_direct_iterator.hpp>
15#include <dr/detail/ranges_shim.hpp>
16#include <dr/detail/tuple_utils.hpp>
17#include <dr/mp/global.hpp>
18
19namespace dr::mp::__detail {
20
21struct any {
22 template <typename T> operator T() const noexcept {
23 return std::declval<T>();
24 }
25};
26
27template <typename F, typename Arg1>
28concept one_argument = requires(F &f) {
29 { f(Arg1{}) };
30};
31
32template <typename F, typename Arg1, typename Arg2>
33concept two_arguments = requires(F &f) {
34 { f(Arg1{}, Arg2{}) };
35};
36
37}; // namespace dr::mp::__detail
38
39namespace dr::mp {
40
41namespace detail = dr::__detail;
42
44template <typename... Ts>
45void stencil_for_each(auto op, is_mdspan_view auto &&...drs) {
46 auto ranges = std::tie(drs...);
47 auto &&dr0 = std::get<0>(ranges);
48 if (rng::empty(dr0)) {
49 return;
50 }
51
52 auto all_segments = rng::views::zip(dr::ranges::segments(drs)...);
53 for (auto segs : all_segments) {
54 auto seg0 = std::get<0>(segs);
55 auto mdspan0 = seg0.mdspan();
56
57 // If local
58 if (dr::ranges::rank(seg0) == default_comm().rank()) {
59 // Calculate loop invariant info about the operands. Use a tuple
60 // to hold the info for all operands.
61 auto operand_infos = detail::tuple_transform(segs, [](auto &&seg) {
62 // mdspan for tile. This could be a submdspan, so we need the
63 // extents of the root to get the memory strides
64 return std::make_pair(seg.mdspan(), seg.root_mdspan().extents());
65 });
66
67 if (mp::use_sycl()) {
68#ifdef SYCL_LANGUAGE_VERSION
69 auto do_point = [=](auto index) {
70 // Transform operand_infos into stencils
71 auto stencils =
72 detail::tuple_transform(operand_infos, [=](auto info) {
73 return md::mdspan(
74 std::to_address(&info.first(index[0], index[1])),
75 info.second);
76 });
77 op(stencils);
78 };
79 // TODO: Extend sycl_utils.hpp to handle ranges > 1D. It uses
80 // ndrange and handles > 32 bits.
81 dr::__detail::parallel_for(
82 mp::sycl_queue(), sycl::range(mdspan0.extent(0), mdspan0.extent(1)),
83 do_point)
84 .wait();
85#else
86 assert(false);
87#endif
88 } else {
89 // Given an index, invoke op on a tuple of stencils
90 auto invoke_index = [=](auto index) {
91 // Transform operand_infos into stencils
92 auto stencils =
93 detail::tuple_transform(operand_infos, [=](auto info) {
94 return md::mdspan(std::to_address(&info.first(index)),
95 info.second);
96 });
97 op(stencils);
98 };
99#if 0
100 // Does not vectorize. Something about loop index being forced into memory
101 detail::mdspan_foreach<mdspan0.rank(), decltype(invoke_index)>(
102 mdspan0.extents(), invoke_index);
103#else
104 for (std::size_t i = 0; i < mdspan0.extents().extent(0); i++) {
105 for (std::size_t j = 0; j < mdspan0.extents().extent(1); j++) {
106 invoke_index(std::array<std::size_t, 2>{i, j});
107 }
108 }
109#endif
110 }
111 }
112 }
113
114 barrier();
115}
116
118template <typename F, typename... Ts>
119void for_each(F op, is_mdspan_view auto &&...drs) {
120 auto ranges = std::tie(drs...);
121 auto &&dr0 = std::get<0>(ranges);
122 if (rng::empty(dr0)) {
123 return;
124 }
125
126 auto all_segments = rng::views::zip(dr::ranges::segments(drs)...);
127 for (auto segs : all_segments) {
128 auto seg0 = std::get<0>(segs);
129 auto mdspan0 = seg0.mdspan();
130
131 // If local
132 if (dr::ranges::rank(seg0) == default_comm().rank()) {
133 auto origin = seg0.origin();
134
135 // make a tuple of mdspans
136 auto operand_mdspans = detail::tuple_transform(
137 segs, [](auto &&seg) { return seg.mdspan(); });
138
139 if (mp::use_sycl()) {
140#ifdef SYCL_LANGUAGE_VERSION
141 //
142 auto invoke_index = [=](auto index) {
143 // Transform mdspans into references
144 auto references = detail::tie_transform(
145 operand_mdspans, [mdspan0, index](auto mdspan) -> decltype(auto) {
146 static_assert(1 <= mdspan0.rank() && mdspan0.rank() <= 3);
147 if constexpr (mdspan0.rank() == 1) {
148 return mdspan(index[0]);
149 } else if constexpr (mdspan0.rank() == 2) {
150 return mdspan(index[0], index[1]);
151 } else if constexpr (mdspan0.rank() == 3) {
152 return mdspan(index[0], index[1], index[2]);
153 }
154 });
155 static_assert(
156 std::invocable<F, decltype(references)> ||
157 std::invocable<F, decltype(index), decltype(references)>);
158 if constexpr (std::invocable<F, decltype(references)>) {
159 op(references);
160 } else {
161 auto global_index = index;
162 for (std::size_t i = 0; i < rng::size(global_index); i++) {
163 global_index[i] += origin[i];
164 }
165
166 op(global_index, references);
167 }
168 };
169
170 if constexpr (mdspan0.rank() == 1) {
171 auto range = sycl::range(mdspan0.extent(0));
172 dr::__detail::parallel_for(mp::sycl_queue(), range, invoke_index)
173 .wait();
174 } else if constexpr (mdspan0.rank() == 2) {
175 auto range = sycl::range(mdspan0.extent(0), mdspan0.extent(1));
176 dr::__detail::parallel_for(mp::sycl_queue(), range, invoke_index)
177 .wait();
178 } else if constexpr (mdspan0.rank() == 3) {
179 auto range = sycl::range(mdspan0.extent(0), mdspan0.extent(1),
180 mdspan0.extent(2));
181 dr::__detail::parallel_for(mp::sycl_queue(), range, invoke_index)
182 .wait();
183 }
184#else
185 assert(false);
186#endif
187 } else {
188 // invoke op on a tuple of references created by using the mdspan's and
189 // index
190 auto invoke_index = [=](auto index) {
191 // Transform operand_infos into references
192 auto references = detail::tie_transform(
193 operand_mdspans,
194 [index](auto mdspan) -> decltype(auto) { return mdspan(index); });
195 static_assert(
196 std::invocable<F, decltype(references)> ||
197 std::invocable<F, decltype(index), decltype(references)>);
198 if constexpr (std::invocable<F, decltype(references)>) {
199 op(references);
200 } else if constexpr (std::invocable<F, decltype(index),
201 decltype(references)>) {
202 auto global_index = index;
203 for (std::size_t i = 0; i < rng::size(global_index); i++) {
204 global_index[i] += origin[i];
205 }
206
207 op(global_index, references);
208 } else {
209 assert(false);
210 }
211 };
212 detail::mdspan_foreach<mdspan0.rank(), decltype(invoke_index)>(
213 mdspan0.extents(), invoke_index);
214 }
215 }
216 }
217
218 barrier();
219}
220
221} // namespace dr::mp
Definition: md_for_each.hpp:28
Definition: md_for_each.hpp:33
Definition: mdspan_view.hpp:206
Definition: md_for_each.hpp:21