Distributed Ranges
Loading...
Searching...
No Matches
matrix_partition.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
#include <array>
#include <cstddef>
#include <limits>
#include <memory>
#include <vector>

#include <dr/detail/index.hpp>
#include <dr/sp/containers/detail.hpp>
#include <dr/sp/init.hpp>
10
11namespace dr::sp {
12
namespace tile {

/// Sentinel tile extent. When a tile dimension is set to `div`, the tile
/// shape {ceil(m / p_m), ceil(n / p_n)} is selected so that the dimension
/// is divided evenly amongst the ranks in the processor grid.
inline constexpr std::size_t div{std::numeric_limits<std::size_t>::max()};

} // namespace tile
22
24public:
25 virtual std::size_t tile_rank(dr::index<> matrix_shape,
26 dr::index<> tile_id) const = 0;
27 virtual dr::index<> grid_shape(dr::index<> matrix_shape) const = 0;
28 virtual dr::index<> tile_shape(dr::index<> matrix_shape) const = 0;
29
30 virtual std::unique_ptr<matrix_partition> clone() const = 0;
31 virtual ~matrix_partition(){};
32};
33
34class block_cyclic final : public matrix_partition {
35public:
36 block_cyclic(dr::index<> tile_shape = {dr::sp::tile::div, dr::sp::tile::div},
37 dr::index<> grid_shape = detail::factor(dr::sp::nprocs()))
38 : tile_shape_(tile_shape), grid_shape_(grid_shape) {}
39
40 block_cyclic(const block_cyclic &) noexcept = default;
41
42 dr::index<> tile_shape() const { return tile_shape_; }
43
44 std::size_t tile_rank(dr::index<> matrix_shape, dr::index<> tile_id) const {
45 dr::index<> pgrid_idx = {tile_id[0] % grid_shape_[0],
46 tile_id[1] % grid_shape_[1]};
47
48 auto pgrid = processor_grid_();
49
50 return pgrid[pgrid_idx[0] * grid_shape_[1] + pgrid_idx[1]];
51 }
52
53 dr::index<> grid_shape(dr::index<> matrix_shape) const {
54 auto ts = this->tile_shape(matrix_shape);
55
56 return dr::index<>((matrix_shape[0] + ts[0] - 1) / ts[0],
57 (matrix_shape[1] + ts[1] - 1) / ts[1]);
58 }
59
60 dr::index<> tile_shape(dr::index<> matrix_shape) const {
61 std::array<std::size_t, 2> tshape = {tile_shape_[0], tile_shape_[1]};
62
63 constexpr std::size_t ndims = 2;
64 for (std::size_t i = 0; i < ndims; i++) {
65 if (tshape[i] == dr::sp::tile::div) {
66 tshape[i] = (matrix_shape[i] + grid_shape_[i] - 1) / grid_shape_[i];
67 }
68 }
69
70 return tshape;
71 }
72
73 std::unique_ptr<matrix_partition> clone() const noexcept {
74 return std::unique_ptr<matrix_partition>(new block_cyclic(*this));
75 }
76
77private:
78 std::vector<std::size_t> processor_grid_() const {
79 std::vector<std::size_t> grid(grid_shape_[0] * grid_shape_[1]);
80
81 for (std::size_t i = 0; i < grid.size(); i++) {
82 grid[i] = i;
83 }
84 return grid;
85 }
86
87 dr::index<> tile_shape_;
88 dr::index<> grid_shape_;
89};
90
91inline auto row_cyclic() {
92 return block_cyclic({dr::sp::tile::div, dr::sp::tile::div},
93 {dr::sp::nprocs(), 1});
94}
95
96inline std::vector<block_cyclic> partition_matmul(std::size_t m, std::size_t n,
97 std::size_t k) {
98 dr::index<> c_pgrid = detail::factor(sp::nprocs());
99
100 block_cyclic c_block({dr::sp::tile::div, dr::sp::tile::div},
101 {c_pgrid[0], c_pgrid[1]});
102
103 std::size_t k_block;
104
105 if (m * k >= k * n) {
106 k_block = (sp::nprocs() + c_pgrid[0] - 1) / c_pgrid[0];
107 } else {
108 k_block = (sp::nprocs() + c_pgrid[1] - 1) / c_pgrid[1];
109 }
110
111 block_cyclic a_block({dr::sp::tile::div, dr::sp::tile::div},
112 {c_pgrid[0], k_block});
113 block_cyclic b_block({dr::sp::tile::div, dr::sp::tile::div},
114 {k_block, c_pgrid[1]});
115
116 return {a_block, b_block, c_block};
117}
118
119} // namespace dr::sp
Definition: index.hpp:34
Definition: matrix_partition.hpp:34
Definition: matrix_partition.hpp:23