7#include <dr/detail/mdspan_shim.hpp>
8#include <dr/detail/ranges_shim.hpp>
9#include <dr/mp/allocator.hpp>
10#include <dr/mp/containers/distributed_mdarray.hpp>
12namespace dr::mp::__detail {
16 tmp_buffer(std::size_t size,
auto &&candidate) {
18 data_ = candidate.mdspan().data_handle();
20 allocated_data_ =
nullptr;
23 if (size_ > candidate.reserved()) {
25 dr::logger::transpose,
26 "Allocating a temporary buffer requested size {} candidate size {}\n",
27 size, candidate.reserved());
29 data_ = allocated_data_;
31 assert(data_ !=
nullptr);
34 T *data() {
return data_; }
38 if (allocated_data_) {
40 allocated_data_ =
nullptr;
46 T *allocated_data_ =
nullptr;
50template <dr::distributed_mdspan_range MR1, dr::distributed_mdspan_range MR2>
51void transpose2D(MR1 &&src, MR2 &&dst,
auto sm,
auto dm) {
52 auto comm = default_comm();
54 using T = rng::range_value_t<MR1>;
56 using index_type = dr::__detail::dr_extents<2>;
59 assert(sm.extent(0) == dm.extent(1) && sm.extent(1) == dm.extent(0));
61 auto src_tile = src.grid()(comm.rank(), 0);
62 auto dst_tile = dst.grid()(comm.rank(), 0);
64 if (comm.size() == 1) {
65 dr::drlog.debug(dr::logger::transpose,
"direct transpose on single rank\n");
66 auto sm = src_tile.mdspan();
67 auto dm = dst_tile.mdspan();
69 dr::__detail::mdspan_copy(src_tile_t, dm).wait();
80 std::size_t sub_tile_size = src.grid()(0, 0).mdspan().extent(0) *
81 dst.grid()(0, 0).mdspan().extent(0);
82 std::size_t sub_tiles_size = sub_tile_size * comm.size();
83 dr::drlog.debug(dr::logger::transpose,
"sub_tile_size: {}x{} total: {}\n",
84 src.grid()(0, 0).mdspan().extent(0),
85 dst.grid()(0, 0).mdspan().extent(0), sub_tile_size);
89 T *buffer = send_buffer.data();
91 std::vector<dr::__detail::event> pack_events;
92 index_type start({0, 0}), end({src_tile.mdspan().extent(0), 0});
93 for (std::size_t i = 0; i < dst.grid().extent(0); i++) {
94 auto num_cols = dst.grid()(i, 0).mdspan().extent(0);
96 end[1] = start[1] + num_cols;
97 dr::drlog.debug(dr::logger::transpose,
"Packing start: {}, end: {}\n",
100 dr::__detail::make_submdspan(src_tile.mdspan(), start, end);
102 pack_events.push_back(dr::__detail::mdspan_copy(sub_tile_t, buffer));
103 buffer += sub_tile_size;
104 start[1] += num_cols;
106 rng::for_each(pack_events, [](
auto e) { e.wait(); });
110 __detail::tmp_buffer<T> receive_buffer(sub_tiles_size, src_tile);
111 buffer = receive_buffer.data();
112 comm.alltoall(send_buffer.data(), receive_buffer.data(), sub_tile_size);
114 std::vector<dr::__detail::event> unpack_events;
116 end = {dst_tile.mdspan().extent(0), 0};
117 for (std::size_t i = 0; i < src.grid().extent(0); i++) {
118 auto num_cols = src.grid()(i, 0).mdspan().extent(0);
120 end[1] = start[1] + num_cols;
121 dr::drlog.debug(dr::logger::transpose,
"Unpacking start: {}, end: {}\n",
124 dr::__detail::make_submdspan(dst_tile.mdspan(), start, end);
125 unpack_events.push_back(dr::__detail::mdspan_copy(buffer, sub_tile));
126 buffer += sub_tile_size;
127 start[1] += num_cols;
129 rng::for_each(unpack_events, [](
auto e) { e.wait(); });
135void transpose3D_slab(MR1 &&src, MR2 &&dst,
auto sm,
auto dm) {
136 auto comm = default_comm();
138 using T = rng::range_value_t<MR1>;
140 using index_type = dr::__detail::dr_extents<3>;
143 dr::drlog.debug(dr::logger::transpose,
144 "transpose src: [{}, {}, {}] dst: [{}, {}, {}]\n",
145 sm.extent(0), sm.extent(1), sm.extent(2), dm.extent(0),
146 dm.extent(1), dm.extent(2));
148 constexpr std::array<std::size_t, 3> from_transposed{Is...};
150 assert(sm.extent(0) == dm.extent(from_transposed[0]) &&
151 sm.extent(1) == dm.extent(from_transposed[1]) &&
152 sm.extent(2) == dm.extent(from_transposed[2]));
155 std::size_t mask_p = 1, mask_u = 2;
156 if (from_transposed[0] == 1) {
162 auto origin_dst_tile = dst.grid()(0, 0, 0).mdspan();
163 auto origin_src_tile = src.grid()(0, 0, 0).mdspan();
164 std::size_t sub_tile_size = origin_src_tile.extent(0) *
165 origin_dst_tile.extent(0) *
166 origin_dst_tile.extent(mask_p);
168 std::size_t sub_tiles_size = sub_tile_size * comm.size();
170 auto src_tile = src.grid()(comm.rank(), 0, 0);
171 auto dst_tile = dst.grid()(comm.rank(), 0, 0);
173 if (comm.size() == 1) {
174 dr::drlog.debug(dr::logger::transpose,
"direct transpose on single rank\n");
175 auto sm = src_tile.mdspan();
176 auto dm = dst_tile.mdspan();
178 dr::__detail::mdspan_copy(src_tile_t, dm).wait();
182 __detail::tmp_buffer<T> send_buffer(sub_tiles_size, dst_tile);
184 T *buffer = send_buffer.data();
186 std::vector<dr::__detail::event> pack_events;
187 index_type start({0, 0, 0}),
188 end({src_tile.mdspan().extent(0), src_tile.mdspan().extent(1),
189 src_tile.mdspan().extent(2)});
191 for (std::size_t i = 0; i < dst.grid().extent(0); i++) {
192 auto num_cols = dst.grid()(i, 0, 0).mdspan().extent(0);
193 end[mask_p] = start[mask_p] + num_cols;
195 dr::drlog.debug(dr::logger::transpose,
"Packing start: {}, end: {}\n",
198 dr::__detail::make_submdspan(src_tile.mdspan(), start, end);
201 pack_events.push_back(dr::__detail::mdspan_copy(sub_tile_t, buffer));
202 buffer += sub_tile_size;
203 start[mask_p] += num_cols;
206 rng::for_each(pack_events, [](
auto e) { e.wait(); });
208 __detail::tmp_buffer<T> receive_buffer(sub_tiles_size, src_tile);
209 buffer = receive_buffer.data();
210 comm.alltoall(send_buffer.data(), receive_buffer.data(), sub_tile_size);
212 std::vector<dr::__detail::event> unpack_events;
214 end = {dst_tile.mdspan().extent(0), dst_tile.mdspan().extent(1),
215 dst_tile.mdspan().extent(2)};
216 for (std::size_t i = 0; i < src.grid().extent(0); i++) {
217 auto num_cols = src.grid()(i, 0, 0).mdspan().extent(0);
219 end[mask_u] = start[mask_u] + num_cols;
220 dr::drlog.debug(dr::logger::transpose,
"Unpacking start: {}, end: {}\n",
223 dr::__detail::make_submdspan(dst_tile.mdspan(), start, end);
224 unpack_events.push_back(dr::__detail::mdspan_copy(buffer, sub_tile));
225 buffer += sub_tile_size;
226 start[mask_u] += num_cols;
228 rng::for_each(unpack_events, [](
auto e) { e.wait(); });
238template <dr::distributed_mdspan_range MR1, dr::distributed_mdspan_range MR2>
239void transpose(MR1 &&src, MR2 &&dst,
bool forward =
true) {
240 constexpr std::size_t rank1 = std::remove_cvref_t<MR1>::rank();
241 constexpr std::size_t rank2 = std::remove_cvref_t<MR2>::rank();
242 static_assert(rank1 == rank2);
245 for (std::size_t i = 1; i < rank1; i++) {
246 assert(src.grid().extent(i) == 1);
249 auto sm = src.mdspan();
250 auto dm = dst.mdspan();
252 if constexpr (rank1 == 2) {
253 __detail::transpose2D(src, dst, sm, dm);
254 }
else if constexpr (rank1 == 3) {
256 __detail::transpose3D_slab<MR1, MR2, 2, 0, 1>(src, dst, sm, dm);
258 __detail::transpose3D_slab<MR1, MR2, 1, 2, 0>(src, dst, sm, dm);
Definition: mdspan_utils.hpp:214
Definition: allocator.hpp:11
Definition: transpose.hpp:14
Definition: mdspan_utils.hpp:332