Distributed Ranges
Loading...
Searching...
No Matches
mdspan_view.hpp
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause
4
#pragma once

#include <algorithm>
#include <cstddef>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>

#include <dr/detail/mdspan_shim.hpp>
#include <dr/detail/mdspan_utils.hpp>
#include <dr/detail/ranges_shim.hpp>
#include <dr/detail/ranges_utils.hpp>
#include <dr/mp/containers/distributed_vector.hpp>
12
namespace dr::mp::decomp {

// Sentinel tile extent: divide this dimension across the segments of the
// underlying distributed range. mdspan_view replaces it with
// partition_up(full extent, number of segments).
inline constexpr std::size_t div = std::numeric_limits<std::size_t>::max();

// Sentinel tile extent: do not split this dimension. mdspan_view replaces it
// with the full extent of the dimension.
inline constexpr std::size_t all = div - 1;

} // namespace dr::mp::decomp
19
20namespace dr::mp::__detail {
21
22//
23// Add a local mdspan to the underlying segment
24//
25template <typename BaseSegment, std::size_t Rank>
26class md_segment : public rng::view_interface<md_segment<BaseSegment, Rank>> {
27private:
28public:
29 using index_type = dr::__detail::dr_extents<Rank>;
30
31 md_segment() {}
32 md_segment(index_type origin, BaseSegment segment, index_type tile_shape)
33 : base_(segment), origin_(origin),
34 mdspan_(local_tile(segment, tile_shape)) {
35 dr::drlog.debug(dr::logger::mdspan_view,
36 "md_segment\n origin: {}\n tile shape: {}\n", origin,
37 tile_shape);
38 }
39
40 // view_interface uses below to define everything else
41 auto begin() const { return base_.begin(); }
42 auto end() const { return base_.end(); }
43
44 auto reserved() const { return base_.reserved(); }
45
46 auto halo() const { return dr::mp::halo(base_); }
47
48 // mdspan-specific methods
49 auto mdspan() const { return mdspan_; }
50 auto origin() const { return origin_; }
51 // for slices, this would be the underlying mdspan
52 auto root_mdspan() const { return mdspan(); }
53
54private:
55 using T = rng::range_value_t<BaseSegment>;
56
57 static auto local_tile(BaseSegment segment, const index_type &tile_shape) {
58 // Undefined behavior if the segments is not local
59 T *ptr = dr::ranges::rank(segment) == default_comm().rank()
60 ? std::to_address(dr::ranges::local(rng::begin(segment)))
61 : nullptr;
62 return md::mdspan(ptr, tile_shape);
63 }
64
65 BaseSegment base_;
66 index_type origin_;
67 md::mdspan<T, dr::__detail::md_extents<Rank>, md::layout_stride> mdspan_;
68};
69
70} // namespace dr::mp::__detail
71
72namespace dr::mp {
73
74//
75// Wrap a distributed range, adding an mdspan and adapting the
76// segments to also be mdspans for local access
77//
78template <distributed_contiguous_range R, std::size_t Rank,
79 typename Layout = md::layout_right>
80struct mdspan_view : public rng::view_interface<mdspan_view<R, Rank>> {
81private:
82 using base_type = rng::views::all_t<R>;
83 using iterator_type = rng::iterator_t<base_type>;
84 using extents_type = md::dextents<std::size_t, Rank>;
85 using mdspan_type =
86 md::mdspan<iterator_type, extents_type, Layout,
88 using difference_type = rng::iter_difference_t<iterator_type>;
89 using index_type = dr::__detail::dr_extents<Rank>;
90
91 base_type base_;
92 index_type full_shape_;
93 index_type tile_shape_;
94
95 static auto segment_index_to_global_origin(std::size_t linear,
96 index_type full_shape,
97 index_type tile_shape) {
98 index_type grid_shape;
99 for (std::size_t i = 0; i < Rank; i++) {
100 grid_shape[i] = dr::__detail::partition_up(full_shape[i], tile_shape[i]);
101 }
102 auto origin = dr::__detail::linear_to_index(linear, grid_shape);
103 for (std::size_t i = 0; i < Rank; i++) {
104 origin[i] *= tile_shape[i];
105 }
106
107 return origin;
108 }
109
110 static auto make_segments(auto base, auto full_shape, auto tile_shape) {
111 auto make_md = [=](auto v) {
112 auto clipped = tile_shape;
113 std::size_t segment_index = std::get<0>(v);
114 std::size_t end = (segment_index + 1) * tile_shape[0];
115 if (end > full_shape[0]) {
116 clipped[0] -= std::min(end - full_shape[0], clipped[0]);
117 }
119 segment_index_to_global_origin(segment_index, full_shape, tile_shape),
120 std::get<1>(v), clipped);
121 };
122
123 dr::drlog.debug(dr::logger::mdspan_view,
124 "mdspan_view\n full shape: {}\n tile shape: {}\n",
125 full_shape, tile_shape);
126 // use bounded_enumerate so we get a std::ranges::common_range
127 return dr::__detail::bounded_enumerate(dr::ranges::segments(base)) |
128 rng::views::transform(make_md);
129 }
130 using segments_type = decltype(make_segments(std::declval<base_type>(),
131 full_shape_, tile_shape_));
132
133public:
134 mdspan_view(R r, dr::__detail::dr_extents<Rank> full_shape)
135 : base_(rng::views::all(std::forward<R>(r))) {
136 full_shape_ = full_shape;
137
138 // Default tile shape splits on leading dimension
139 tile_shape_ = full_shape;
140 tile_shape_[0] = decomp::div;
141
142 replace_decomp();
143 segments_ = make_segments(base_, full_shape_, tile_shape_);
144 }
145
146 mdspan_view(R r, dr::__detail::dr_extents<Rank> full_shape,
147 dr::__detail::dr_extents<Rank> tile_shape)
148 : base_(rng::views::all(std::forward<R>(r))), full_shape_(full_shape),
149 tile_shape_(tile_shape) {
150 replace_decomp();
151 segments_ = make_segments(base_, full_shape_, tile_shape_);
152 }
153
154 // Base implements random access range
155 auto begin() const { return base_.begin(); }
156 auto end() const { return base_.end(); }
157 auto operator[](difference_type n) { return base_[n]; }
158
159 // Add a local mdspan to the base segment
160 // Mdspan access to base
161 auto mdspan() const { return mdspan_type(rng::begin(base_), full_shape_); }
162 static constexpr auto rank() { return Rank; }
163
164 auto segments() const { return segments_; }
165
166 // Mdspan access to grid
167 auto grid() {
168 dr::__detail::dr_extents<Rank> grid_shape;
169 for (std::size_t i : rng::views::iota(0u, Rank)) {
170 grid_shape[i] =
171 dr::__detail::partition_up(full_shape_[i], tile_shape_[i]);
172 }
173 using grid_iterator_type = rng::iterator_t<segments_type>;
174 using grid_type =
175 md::mdspan<grid_iterator_type, extents_type, Layout,
177 return grid_type(rng::begin(segments_), grid_shape);
178 }
179
180private:
181 // Replace div with actual value
182 void replace_decomp() {
183 auto n = std::size_t(rng::size(dr::ranges::segments(base_)));
184 for (std::size_t i = 0; i < Rank; i++) {
185 if (tile_shape_[i] == decomp::div) {
186 tile_shape_[i] = dr::__detail::partition_up(full_shape_[i], n);
187 } else if (tile_shape_[i] == decomp::all) {
188 tile_shape_[i] = full_shape_[i];
189 }
190 }
191 }
192
193 segments_type segments_;
194};
195
196template <typename R, std::size_t Rank>
197mdspan_view(R &&r, dr::__detail::dr_extents<Rank> extents)
199
200template <typename R, std::size_t Rank>
201mdspan_view(R &&r, dr::__detail::dr_extents<Rank> full_shape,
202 dr::__detail::dr_extents<Rank> tile_shape)
204
205template <typename R>
207 dr::distributed_range<R> && requires(R &r) { r.mdspan(); };
208
209} // namespace dr::mp
210
211namespace dr::mp::views {
212
213template <std::size_t Rank> class mdspan_adapter_closure {
214public:
215 mdspan_adapter_closure(dr::__detail::dr_extents<Rank> full_shape,
216 dr::__detail::dr_extents<Rank> tile_shape)
217 : full_shape_(full_shape), tile_shape_(tile_shape), tile_valid_(true) {}
218
219 mdspan_adapter_closure(dr::__detail::dr_extents<Rank> full_shape)
220 : full_shape_(full_shape) {}
221
222 template <rng::viewable_range R> auto operator()(R &&r) const {
223 if (tile_valid_) {
224 return mdspan_view(std::forward<R>(r), full_shape_, tile_shape_);
225 } else {
226 return mdspan_view(std::forward<R>(r), full_shape_);
227 }
228 }
229
230 template <rng::viewable_range R>
231 friend auto operator|(R &&r, const mdspan_adapter_closure &closure) {
232 return closure(std::forward<R>(r));
233 }
234
235private:
236 dr::__detail::dr_extents<Rank> full_shape_;
237 dr::__detail::dr_extents<Rank> tile_shape_;
238 bool tile_valid_ = false;
239};
240
242public:
243 template <rng::viewable_range R, typename Shape>
244 auto operator()(R &&r, Shape &&full_shape, Shape &&tile_shape) const {
245 return mdspan_adapter_closure(std::forward<Shape>(full_shape),
246 std::forward<Shape>(tile_shape))(
247 std::forward<R>(r));
248 }
249
250 template <rng::viewable_range R, typename Shape>
251 auto operator()(R &&r, Shape &&full_shape) const {
252 return mdspan_adapter_closure(std::forward<Shape>(full_shape))(
253 std::forward<R>(r));
254 }
255
256 template <typename Shape>
257 auto operator()(Shape &&full_shape, Shape &&tile_shape) const {
258 return mdspan_adapter_closure(std::forward<Shape>(full_shape),
259 std::forward<Shape>(tile_shape));
260 }
261
262 template <typename Shape> auto operator()(Shape &&full_shape) const {
263 return mdspan_adapter_closure(std::forward<Shape>(full_shape));
264 }
265};
266
267inline constexpr auto mdspan = mdspan_fn_{};
268
269} // namespace dr::mp::views
Definition: mdspan_utils.hpp:60
Definition: mdspan_view.hpp:26
Definition: mdspan_view.hpp:213
Definition: mdspan_view.hpp:241
Definition: concepts.hpp:42
Definition: concepts.hpp:20
Definition: mdspan_view.hpp:206
Definition: mdspan_view.hpp:80