7#include <dr/detail/ranges_shim.hpp>
8#include <dr/mp/views/mdspan_view.hpp>
14 using shape_type = dr::__detail::dr_extents<Rank>;
15 static constexpr auto rank() {
return Rank; }
19 : tile_shape_(tile_shape(shape)), dv_(dv_size(), dv_dist(dist, shape)),
20 md_view_(make_md_view(dv_, shape, tile_shape_)) {}
22 auto begin()
const {
return rng::begin(md_view_); }
23 auto end()
const {
return rng::end(md_view_); }
24 auto size()
const {
return rng::size(md_view_); }
25 auto operator[](
auto n) {
return md_view_[n]; }
27 auto segments() {
return dr::ranges::segments(md_view_); }
28 auto &halo()
const {
return dr::mp::halo(dv_); }
30 auto mdspan()
const {
return md_view_.mdspan(); }
31 auto extent(std::size_t r)
const {
return mdspan().extent(r); }
32 auto grid() {
return md_view_.grid(); }
33 auto view()
const {
return md_view_; }
36 return std::equal(begin(), end(), other.begin());
42 static auto tile_shape(
auto shape) {
43 std::size_t n = default_comm().size();
44 shape[0] = dr::__detail::partition_up(shape[0], n);
48 static auto md_size(
auto shape) {
50 for (
auto extent : shape) {
57 return default_comm().size() * md_size(tile_shape_);
60 static auto dv_dist(
distribution incoming_dist,
auto shape) {
64 std::size_t row_size = md_size(shape);
65 auto incoming_halo = incoming_dist.halo();
66 return distribution().halo(incoming_halo.prev * row_size,
67 incoming_halo.next * row_size);
72 static auto make_md_view(
const DV &dv, shape_type shape,
73 shape_type tile_shape) {
74 return views::mdspan(dv, shape, tile_shape);
77 shape_type tile_shape_;
80 decltype(make_md_view(std::declval<DV>(), std::declval<shape_type>(),
81 std::declval<shape_type>()));
85template <
typename T, std::
size_t Rank>
87 return mdarray.halo();
90template <
typename T, std::
size_t Rank>
91std::ostream &operator<<(std::ostream &os,
92 const distributed_mdarray<T, Rank> &mdarray) {
93 os << fmt::format(
"\n{}", mdarray.mdspan());
Definition: distributed_mdarray.hpp:12
distributed vector
Definition: distributed_vector.hpp:150
Definition: distribution.hpp:11