7#include <dr/detail/mdspan_shim.hpp>
8#include <dr/detail/mdspan_utils.hpp>
9#include <dr/detail/ranges_shim.hpp>
10#include <dr/detail/ranges_utils.hpp>
11#include <dr/mp/containers/distributed_vector.hpp>
13namespace dr::mp::decomp {
// Tile-extent sentinel: mdspan_view::replace_decomp() substitutes
// partition_up(full extent, number of base segments) for this value.
inline constexpr std::size_t div = ~std::size_t{0};
// Tile-extent sentinel: mdspan_view::replace_decomp() substitutes the full
// extent of that dimension for this value.
inline constexpr std::size_t all = div - 1;
20namespace dr::mp::__detail {
25template <
typename BaseSegment, std::
size_t Rank>
26class md_segment :
public rng::view_interface<md_segment<BaseSegment, Rank>> {
29 using index_type = dr::__detail::dr_extents<Rank>;
32 md_segment(index_type origin, BaseSegment segment, index_type tile_shape)
33 : base_(segment), origin_(origin),
34 mdspan_(local_tile(segment, tile_shape)) {
35 dr::drlog.debug(dr::logger::mdspan_view,
36 "md_segment\n origin: {}\n tile shape: {}\n", origin,
41 auto begin()
const {
return base_.begin(); }
42 auto end()
const {
return base_.end(); }
44 auto reserved()
const {
return base_.reserved(); }
46 auto halo()
const {
return dr::mp::halo(base_); }
49 auto mdspan()
const {
return mdspan_; }
50 auto origin()
const {
return origin_; }
52 auto root_mdspan()
const {
return mdspan(); }
55 using T = rng::range_value_t<BaseSegment>;
57 static auto local_tile(BaseSegment segment,
const index_type &tile_shape) {
59 T *ptr = dr::ranges::rank(segment) == default_comm().rank()
60 ? std::to_address(dr::ranges::local(rng::begin(segment)))
62 return md::mdspan(ptr, tile_shape);
67 md::mdspan<T, dr::__detail::md_extents<Rank>, md::layout_stride> mdspan_;
79 typename Layout = md::layout_right>
80struct mdspan_view :
public rng::view_interface<mdspan_view<R, Rank>> {
82 using base_type = rng::views::all_t<R>;
83 using iterator_type = rng::iterator_t<base_type>;
84 using extents_type = md::dextents<std::size_t, Rank>;
86 md::mdspan<iterator_type, extents_type, Layout,
88 using difference_type = rng::iter_difference_t<iterator_type>;
89 using index_type = dr::__detail::dr_extents<Rank>;
92 index_type full_shape_;
93 index_type tile_shape_;
95 static auto segment_index_to_global_origin(std::size_t linear,
96 index_type full_shape,
97 index_type tile_shape) {
98 index_type grid_shape;
99 for (std::size_t i = 0; i < Rank; i++) {
100 grid_shape[i] = dr::__detail::partition_up(full_shape[i], tile_shape[i]);
102 auto origin = dr::__detail::linear_to_index(linear, grid_shape);
103 for (std::size_t i = 0; i < Rank; i++) {
104 origin[i] *= tile_shape[i];
110 static auto make_segments(
auto base,
auto full_shape,
auto tile_shape) {
111 auto make_md = [=](
auto v) {
112 auto clipped = tile_shape;
113 std::size_t segment_index = std::get<0>(v);
114 std::size_t end = (segment_index + 1) * tile_shape[0];
115 if (end > full_shape[0]) {
116 clipped[0] -= std::min(end - full_shape[0], clipped[0]);
119 segment_index_to_global_origin(segment_index, full_shape, tile_shape),
120 std::get<1>(v), clipped);
123 dr::drlog.debug(dr::logger::mdspan_view,
124 "mdspan_view\n full shape: {}\n tile shape: {}\n",
125 full_shape, tile_shape);
127 return dr::__detail::bounded_enumerate(dr::ranges::segments(base)) |
128 rng::views::transform(make_md);
130 using segments_type =
decltype(make_segments(std::declval<base_type>(),
131 full_shape_, tile_shape_));
134 mdspan_view(R r, dr::__detail::dr_extents<Rank> full_shape)
135 : base_(rng::views::all(std::forward<R>(r))) {
136 full_shape_ = full_shape;
139 tile_shape_ = full_shape;
140 tile_shape_[0] = decomp::div;
143 segments_ = make_segments(base_, full_shape_, tile_shape_);
146 mdspan_view(R r, dr::__detail::dr_extents<Rank> full_shape,
147 dr::__detail::dr_extents<Rank> tile_shape)
148 : base_(rng::views::all(std::forward<R>(r))), full_shape_(full_shape),
149 tile_shape_(tile_shape) {
151 segments_ = make_segments(base_, full_shape_, tile_shape_);
155 auto begin()
const {
return base_.begin(); }
156 auto end()
const {
return base_.end(); }
157 auto operator[](difference_type n) {
return base_[n]; }
161 auto mdspan()
const {
return mdspan_type(rng::begin(base_), full_shape_); }
162 static constexpr auto rank() {
return Rank; }
164 auto segments()
const {
return segments_; }
168 dr::__detail::dr_extents<Rank> grid_shape;
169 for (std::size_t i : rng::views::iota(0u, Rank)) {
171 dr::__detail::partition_up(full_shape_[i], tile_shape_[i]);
173 using grid_iterator_type = rng::iterator_t<segments_type>;
175 md::mdspan<grid_iterator_type, extents_type, Layout,
177 return grid_type(rng::begin(segments_), grid_shape);
182 void replace_decomp() {
183 auto n = std::size_t(rng::size(dr::ranges::segments(base_)));
184 for (std::size_t i = 0; i < Rank; i++) {
185 if (tile_shape_[i] == decomp::div) {
186 tile_shape_[i] = dr::__detail::partition_up(full_shape_[i], n);
187 }
else if (tile_shape_[i] == decomp::all) {
188 tile_shape_[i] = full_shape_[i];
193 segments_type segments_;
196template <
typename R, std::
size_t Rank>
197mdspan_view(R &&r, dr::__detail::dr_extents<Rank> extents)
200template <
typename R, std::
size_t Rank>
201mdspan_view(R &&r, dr::__detail::dr_extents<Rank> full_shape,
202 dr::__detail::dr_extents<Rank> tile_shape)
211namespace dr::mp::views {
216 dr::__detail::dr_extents<Rank> tile_shape)
217 : full_shape_(full_shape), tile_shape_(tile_shape), tile_valid_(
true) {}
220 : full_shape_(full_shape) {}
222 template <rng::viewable_range R>
auto operator()(R &&r)
const {
224 return mdspan_view(std::forward<R>(r), full_shape_, tile_shape_);
226 return mdspan_view(std::forward<R>(r), full_shape_);
230 template <rng::viewable_range R>
232 return closure(std::forward<R>(r));
236 dr::__detail::dr_extents<Rank> full_shape_;
237 dr::__detail::dr_extents<Rank> tile_shape_;
238 bool tile_valid_ =
false;
243 template <rng::viewable_range R,
typename Shape>
244 auto operator()(R &&r, Shape &&full_shape, Shape &&tile_shape)
const {
246 std::forward<Shape>(tile_shape))(
250 template <rng::viewable_range R,
typename Shape>
251 auto operator()(R &&r, Shape &&full_shape)
const {
256 template <
typename Shape>
257 auto operator()(Shape &&full_shape, Shape &&tile_shape)
const {
259 std::forward<Shape>(tile_shape));
262 template <
typename Shape>
auto operator()(Shape &&full_shape)
const {
Definition: mdspan_utils.hpp:60
Definition: mdspan_view.hpp:26
Definition: mdspan_view.hpp:213
Definition: mdspan_view.hpp:241
Definition: concepts.hpp:42
Definition: concepts.hpp:20
Definition: mdspan_view.hpp:206
Definition: mdspan_view.hpp:80