Distributed Ranges
Loading...
Searching...
No Matches
distributed_mdarray.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <dr/detail/ranges_shim.hpp>
8#include <dr/mp/views/mdspan_view.hpp>
9
10namespace dr::mp {
11
12template <typename T, std::size_t Rank> class distributed_mdarray {
13public:
14 using shape_type = dr::__detail::dr_extents<Rank>;
15 static constexpr auto rank() { return Rank; }
16
17 distributed_mdarray(dr::__detail::dr_extents<Rank> shape,
19 : tile_shape_(tile_shape(shape)), dv_(dv_size(), dv_dist(dist, shape)),
20 md_view_(make_md_view(dv_, shape, tile_shape_)) {}
21
22 auto begin() const { return rng::begin(md_view_); }
23 auto end() const { return rng::end(md_view_); }
24 auto size() const { return rng::size(md_view_); }
25 auto operator[](auto n) { return md_view_[n]; }
26
27 auto segments() { return dr::ranges::segments(md_view_); }
28 auto &halo() const { return dr::mp::halo(dv_); }
29
30 auto mdspan() const { return md_view_.mdspan(); }
31 auto extent(std::size_t r) const { return mdspan().extent(r); }
32 auto grid() { return md_view_.grid(); }
33 auto view() const { return md_view_; }
34
35 auto operator==(const distributed_mdarray &other) const {
36 return std::equal(begin(), end(), other.begin());
37 }
38
39private:
41
42 static auto tile_shape(auto shape) {
43 std::size_t n = default_comm().size(); // dr-style ignore
44 shape[0] = dr::__detail::partition_up(shape[0], n);
45 return shape;
46 }
47
48 static auto md_size(auto shape) {
49 std::size_t size = 1;
50 for (auto extent : shape) {
51 size *= extent;
52 }
53 return size;
54 }
55
56 auto dv_size() {
57 return default_comm().size() * md_size(tile_shape_); // dr-style ignore
58 }
59
60 static auto dv_dist(distribution incoming_dist, auto shape) {
61 // Decomp is 1 "row" in decomp dimension
62 // TODO: only supports dist on leading dimension
63 shape[0] = 1;
64 std::size_t row_size = md_size(shape);
65 auto incoming_halo = incoming_dist.halo();
66 return distribution().halo(incoming_halo.prev * row_size,
67 incoming_halo.next * row_size);
68 }
69
70 // This wrapper seems to avoid an issue with template argument
71 // deduction for mdspan_view
72 static auto make_md_view(const DV &dv, shape_type shape,
73 shape_type tile_shape) {
74 return views::mdspan(dv, shape, tile_shape);
75 }
76
77 shape_type tile_shape_;
78 DV dv_;
79 using mdspan_type =
80 decltype(make_md_view(std::declval<DV>(), std::declval<shape_type>(),
81 std::declval<shape_type>()));
82 mdspan_type md_view_;
83};
84
85template <typename T, std::size_t Rank>
86auto &halo(const distributed_mdarray<T, Rank> &mdarray) {
87 return mdarray.halo();
88}
89
90template <typename T, std::size_t Rank>
91std::ostream &operator<<(std::ostream &os,
92 const distributed_mdarray<T, Rank> &mdarray) {
93 os << fmt::format("\n{}", mdarray.mdspan());
94 return os;
95}
96
97} // namespace dr::mp
Definition: distributed_mdarray.hpp:12
distributed vector
Definition: distributed_vector.hpp:150
Definition: distribution.hpp:11