Distributed Ranges
Loading...
Searching...
No Matches
submdspan_view.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <dr/detail/mdspan_shim.hpp>
8#include <dr/detail/ranges_shim.hpp>
9#include <dr/mp/views/mdspan_view.hpp>
10
11namespace dr::mp::__detail {
12
13//
14// Add a local mdspan to the underlying segment
15//
16template <typename BaseSegment, std::size_t Rank,
17 typename Layout = md::layout_stride>
18class mdsub_segment : public BaseSegment {
19private:
20public:
21 using index_type = dr::__detail::dr_extents<Rank>;
22
23 mdsub_segment(){};
24 mdsub_segment(BaseSegment segment, const index_type &slice_starts,
25 const index_type &slice_ends)
26 : BaseSegment(segment),
27 mdspan_(local_tile(segment, slice_starts, slice_ends)),
28 root_mdspan_(segment.mdspan()) {}
29
30 auto mdspan() const { return mdspan_; }
31 auto root_mdspan() const { return root_mdspan_; }
32
33private:
34 using T = rng::range_value_t<BaseSegment>;
35
36 static auto local_tile(BaseSegment segment, const index_type &slice_starts,
37 const index_type &slice_ends) {
38 index_type starts, ends;
39 index_type base_starts = segment.origin();
40 auto base_mdspan = segment.mdspan();
41
42 for (std::size_t i = 0; i < Rank; i++) {
43 // Clip base to area covered by requested span, and translate from global
44 // to local indexing
45 auto base_end = base_starts[i] + base_mdspan.extent(i);
46 starts[i] =
47 std::min(base_end, std::max(slice_starts[i], base_starts[i])) -
48 base_starts[i];
49 ends[i] = std::max(base_starts[i], std::min(slice_ends[i], base_end)) -
50 base_starts[i];
51 }
52
53 return dr::__detail::make_submdspan(base_mdspan, starts, ends);
54 }
55
56 md::mdspan<T, dr::__detail::md_extents<Rank>, md::layout_stride> mdspan_;
57 md::mdspan<T, dr::__detail::md_extents<Rank>, md::layout_stride> root_mdspan_;
58};
59
60} // namespace dr::mp::__detail
61
62namespace dr::mp {
63
64//
65// Wrap an mdspan view, restricting it to the sub-region bounded by slice_starts and slice_ends
66//
67template <is_mdspan_view Base>
68struct submdspan_view : public rng::view_interface<submdspan_view<Base>> {
69private:
70 static auto make_segments(auto base, auto slice_starts, auto slice_ends) {
71 auto make_md = [=](auto segment) {
72 return __detail::mdsub_segment(segment, slice_starts, slice_ends);
73 };
74 return dr::ranges::segments(base) | rng::views::transform(make_md);
75 }
76
// Iterator/index types taken from the base view.
77 using iterator_type = rng::iterator_t<Base>;
78 using extents_type = dr::__detail::dr_extents<Base::rank()>;
79 using difference_type = rng::iter_difference_t<iterator_type>;
// Type of the range returned by make_segments(): the base view's segments,
// each wrapped in an __detail::mdsub_segment.
80 using segments_type =
81 decltype(make_segments(std::declval<Base>(), std::declval<extents_type>(),
82 std::declval<extents_type>()));
83
// Keep segments_ declared after base_, slice_starts_ and slice_ends_: it is
// computed from those three members when the view is constructed.
84 Base base_;
// Global start index of the slice, one entry per dimension.
85 extents_type slice_starts_;
// Global end index of the slice, one entry per dimension.
86 extents_type slice_ends_;
87 segments_type segments_;
88
89public:
90 submdspan_view(is_mdspan_view auto base, extents_type slice_starts,
91 extents_type slice_ends)
92 : base_(base), slice_starts_(std::forward<extents_type>(slice_starts)),
93 slice_ends_(std::forward<extents_type>(slice_ends)) {
94 segments_ = make_segments(base_, slice_starts_, slice_ends_);
95 }
96
97 // Base implements random access range
98 auto begin() const { return base_.begin(); }
99 auto end() const { return base_.end(); }
100 auto operator[](difference_type n) { return base_[n]; }
101
102 auto mdspan() const {
103 return dr::__detail::make_submdspan(base_.mdspan(), slice_starts_,
104 slice_ends_);
105 }
106
107 auto segments() const { return segments_; }
108
109 // Mdspan access to grid
// Returns an mdspan shaped like the base view's grid (base_.grid().extents()),
// in layout_right order, whose data handle is rng::begin(segments_), i.e. an
// iterator over this view's wrapped segments.
110 auto grid() {
111 using grid_iterator_type = rng::iterator_t<segments_type>;
112 using grid_type =
113 md::mdspan<grid_iterator_type, dr::__detail::md_extents<Base::rank()>,
114 md::layout_right,
// NOTE(review): source line 115 is missing from this rendering, so the
// remaining template argument(s) of grid_type and the closing ">;" are not
// visible here; presumably an accessor over grid_iterator_type. Confirm
// against the original header.
116 return grid_type(rng::begin(segments_), base_.grid().extents());
117 }
118};
119
// Class template argument deduction guide for submdspan_view. The CTAD call
// `submdspan_view(std::forward<R>(r), ...)` in submdspan_adapter_closure
// relies on it.
// NOTE(review): the trailing "-> submdspan_view<...>;" part (source line 122)
// is missing from this rendering; presumably it maps R to a view type of R.
// Confirm against the original header.
120template <typename R, typename Extents>
121submdspan_view(R r, Extents slice_starts, Extents slice_ends)
123
124} // namespace dr::mp
125
126namespace dr::mp::views {
127
128template <typename Extents> class submdspan_adapter_closure {
129public:
130 submdspan_adapter_closure(Extents slice_starts, Extents slice_ends)
131 : slice_starts_(slice_starts), slice_ends_(slice_ends) {}
132
133 template <rng::viewable_range R> auto operator()(R &&r) const {
134 return submdspan_view(std::forward<R>(r), slice_starts_, slice_ends_);
135 }
136
137 template <rng::viewable_range R>
138 friend auto operator|(R &&r, const submdspan_adapter_closure &closure) {
139 return closure(std::forward<R>(r));
140 }
141
142private:
143 Extents slice_starts_;
144 Extents slice_ends_;
145};
146
148public:
149 template <is_mdspan_view R, typename Extents>
150 auto operator()(R r, Extents &&slice_starts, Extents &&slice_ends) const {
151 return submdspan_adapter_closure(std::forward<Extents>(slice_starts),
152 std::forward<Extents>(slice_ends))(
153 std::forward<R>(r));
154 }
155
156 template <typename Extents>
157 auto operator()(Extents &&slice_starts, Extents &&slice_ends) const {
158 return submdspan_adapter_closure(std::forward<Extents>(slice_starts),
159 std::forward<Extents>(slice_ends));
160 }
161};
162
163inline constexpr auto submdspan = submdspan_fn_{};
164
165} // namespace dr::mp::views
Definition: mdspan_utils.hpp:60
Definition: submdspan_view.hpp:18
Definition: submdspan_view.hpp:128
Definition: submdspan_view.hpp:147
Definition: mdspan_view.hpp:206
Definition: submdspan_view.hpp:68