7#include <dr/detail/index.hpp>
8#include <dr/sp/containers/detail.hpp>
9#include <dr/sp/init.hpp>
/// Sentinel tile extent meaning "split this dimension evenly across the
/// processor grid". When a concrete tile shape is computed for a matrix, a
/// tile dimension equal to `div` is replaced by
/// ceil(matrix_extent / grid_extent) for that dimension.
inline constexpr std::size_t div = std::numeric_limits<std::size_t>::max();
25 virtual std::size_t tile_rank(
dr::index<> matrix_shape,
30 virtual std::unique_ptr<matrix_partition> clone()
const = 0;
37 dr::index<> grid_shape = detail::factor(dr::sp::nprocs()))
38 : tile_shape_(tile_shape), grid_shape_(grid_shape) {}
42 dr::index<> tile_shape()
const {
return tile_shape_; }
45 dr::index<> pgrid_idx = {tile_id[0] % grid_shape_[0],
46 tile_id[1] % grid_shape_[1]};
48 auto pgrid = processor_grid_();
50 return pgrid[pgrid_idx[0] * grid_shape_[1] + pgrid_idx[1]];
54 auto ts = this->tile_shape(matrix_shape);
56 return dr::index<>((matrix_shape[0] + ts[0] - 1) / ts[0],
57 (matrix_shape[1] + ts[1] - 1) / ts[1]);
61 std::array<std::size_t, 2> tshape = {tile_shape_[0], tile_shape_[1]};
63 constexpr std::size_t ndims = 2;
64 for (std::size_t i = 0; i < ndims; i++) {
65 if (tshape[i] == dr::sp::tile::div) {
66 tshape[i] = (matrix_shape[i] + grid_shape_[i] - 1) / grid_shape_[i];
73 std::unique_ptr<matrix_partition> clone()
const noexcept {
74 return std::unique_ptr<matrix_partition>(
new block_cyclic(*
this));
78 std::vector<std::size_t> processor_grid_()
const {
79 std::vector<std::size_t> grid(grid_shape_[0] * grid_shape_[1]);
81 for (std::size_t i = 0; i < grid.size(); i++) {
91inline auto row_cyclic() {
92 return block_cyclic({dr::sp::tile::div, dr::sp::tile::div},
93 {dr::sp::nprocs(), 1});
96inline std::vector<block_cyclic> partition_matmul(std::size_t m, std::size_t n,
100 block_cyclic c_block({dr::sp::tile::div, dr::sp::tile::div},
101 {c_pgrid[0], c_pgrid[1]});
105 if (m * k >= k * n) {
106 k_block = (sp::nprocs() + c_pgrid[0] - 1) / c_pgrid[0];
108 k_block = (sp::nprocs() + c_pgrid[1] - 1) / c_pgrid[1];
111 block_cyclic a_block({dr::sp::tile::div, dr::sp::tile::div},
112 {c_pgrid[0], k_block});
113 block_cyclic b_block({dr::sp::tile::div, dr::sp::tile::div},
114 {k_block, c_pgrid[1]});
116 return {a_block, b_block, c_block};
Definition: matrix_partition.hpp:34
Definition: matrix_partition.hpp:23