// Evaluates to true in an #if when either MPI_VERSION is at least 4 or the
// Intel MPI numeric version (I_MPI_NUMVERSION) is defined and newer than
// 20211200000. The same condition guards the MPI_Rget_c / MPI_Rput_c calls
// further down in this file.
//
// The whole expression is wrapped in parentheses so the macro can be safely
// negated or combined with other conditions (e.g. `#if !MPI_SUPPORTS_RGET_C`).
// NOTE(review): `defined` produced by macro expansion inside #if is not
// guaranteed by the standard; GCC and Clang evaluate it as expected.
#define MPI_SUPPORTS_RGET_C                                                    \
  ((MPI_VERSION >= 4) ||                                                       \
   (defined(I_MPI_NUMVERSION) && (I_MPI_NUMVERSION > 20211200000)))
// Constructs a communicator wrapper around `comm` (MPI_COMM_WORLD when no
// argument is given). The handle is stored in mpi_comm_, and the calling
// process' rank and the communicator size are queried from MPI into `rank`
// and `size`.
// NOTE(review): the declarations of `rank`/`size` and the rest of the
// constructor body are not visible in this listing.
15 communicator(MPI_Comm comm = MPI_COMM_WORLD) : mpi_comm_(comm) {
17 MPI_Comm_rank(comm, &rank);
18 MPI_Comm_size(comm, &size);
23 auto size()
const {
return size_; }
24 auto rank()
const {
return rank_; }
25 auto prev()
const {
return (rank() + size() - 1) % size(); }
26 auto next()
const {
return (rank() + 1) % size(); }
27 auto first()
const {
return rank() == 0; }
28 auto last()
const {
return rank() == size() - 1; }
30 MPI_Comm mpi_comm()
const {
return mpi_comm_; }
// Synchronizes on mpi_comm_ via MPI_Barrier, logging before and after.
// NOTE(review): the ISHMEM log line below presumably belongs to a
// conditionally compiled branch whose preprocessor lines and fence call are
// not visible in this listing — confirm against the full file.
32 void barrier()
const {
34 DRLOG(
"calling COMM barrier (by calling fence) in ISHMEM");
37 DRLOG(
"calling COMM barrier in MPI");
38 MPI_Barrier(mpi_comm_);
39 DRLOG(
"COMM barrier finished");
// Calls MPI_Bcast on mpi_comm_ to broadcast `count` bytes (MPI_BYTE) starting
// at `src`, with `root` as the broadcasting rank.
// `count` and `root` are std::size_t and are converted to the integer types
// MPI_Bcast expects.
42 void bcast(
void *src, std::size_t count, std::size_t root)
const {
43 MPI_Bcast(src, count, MPI_BYTE, root, mpi_comm_);
// Byte-level scatter: calls MPI_Scatter on mpi_comm_ with `count` bytes
// (MPI_BYTE) as both the per-rank send count and the receive count.
//   src   - send buffer passed to MPI_Scatter
//   dst   - receive buffer for this rank's `count` bytes
//   root  - rank passed as the scatter root
46 void scatter(
const void *src,
void *dst, std::size_t count,
47 std::size_t root)
const {
48 MPI_Scatter(src, count, MPI_BYTE, dst, count, MPI_BYTE, root, mpi_comm_);
// Typed scatter of one element per rank: asserts that `src` holds at least
// size_ elements, then forwards to the byte-level scatter with sizeof(T)
// bytes per rank, receiving into `dst`.
// NOTE(review): T is presumably introduced by a template header that is not
// visible in this listing.
52 void scatter(
const std::span<T> src, T &dst, std::size_t root)
const {
53 assert(rng::size(src) >= size_);
54 scatter(rng::data(src), &dst,
sizeof(T), root);
// Variable-count byte scatter via MPI_Scatterv (MPI_BYTE for both sides).
//   counts / offsets - per-rank send counts and displacements
//   dst / dst_count  - receive buffer and the number of bytes to receive
// When `counts` is non-null, asserts that this rank's entry equals
// `dst_count`.
// NOTE(review): the trailing arguments of the MPI_Scatterv call are cut off
// in this listing.
57 void scatterv(
const void *src,
int *counts,
int *offsets,
void *dst,
58 int dst_count, std::size_t root)
const {
59 assert(counts ==
nullptr || counts[rank()] == dst_count);
60 MPI_Scatterv(src, counts, offsets, MPI_BYTE, dst, dst_count, MPI_BYTE, root,
// Byte-level gather: calls MPI_Gather_c on mpi_comm_ with `count` bytes
// (MPI_BYTE) as both the send count and the per-rank receive count, using
// `root` as the gathering rank.
64 void gather(
const void *src,
void *dst, std::size_t count,
65 std::size_t root)
const {
66 MPI_Gather_c(src, count, MPI_BYTE, dst, count, MPI_BYTE, root, mpi_comm_);
// Typed gather: forwards to the byte-level gather with count * sizeof(T)
// bytes. The C-style casts convert both pointers to void* (dropping the
// const qualifier on `src`).
// NOTE(review): T is presumably introduced by a template header that is not
// visible in this listing.
70 void gather(
const T *src, T *dst, std::size_t count, std::size_t root)
const {
71 gather((
void *)src, (
void *)dst, count *
sizeof(T), root);
// Gathers a single element per rank: asserts that `dst` has room for at
// least size_ elements, then forwards to the pointer overload with a count
// of 1.
75 void gather(
const T &src, std::span<T> dst, std::size_t root)
const {
76 assert(rng::size(dst) >= size_);
77 gather(&src, rng::data(dst), 1, root);
// Typed all-gather via MPI_Allgather_c: sends count * sizeof(T) bytes
// (MPI_BYTE) from `src` and receives the same byte count per rank into `dst`.
// NOTE(review): the trailing arguments of the call are cut off in this
// listing.
81 void all_gather(
const T *src, T *dst, std::size_t count)
const {
83 MPI_Allgather_c(src, count *
sizeof(T), MPI_BYTE, dst, count *
sizeof(T),
// All-gathers a single element per rank into `dst`: asserts that `dst`
// already holds at least size_ elements (it is not resized here), then
// forwards to the pointer overload with a count of 1.
88 void all_gather(
const T &src, std::vector<T> &dst)
const {
89 assert(rng::size(dst) >= size_);
90 all_gather(&src, rng::data(dst), 1);
// All-gathers a contiguous range: every rank contributes rng::size(src)
// elements. Asserts that `dst` can hold size_ * rng::size(src) elements and
// forwards to the pointer overload.
93 template <rng::contiguous_range R>
94 void all_gather(
const R &src, R &dst)
const {
95 assert(rng::size(dst) >= size_ * rng::size(src));
96 all_gather(rng::data(src), rng::data(dst), rng::size(src));
// Typed all-gather through MPI_Iallgather_c: passes count * sizeof(T) bytes
// (MPI_BYTE) for both the send and per-rank receive counts on mpi_comm_, and
// hands `req` to MPI as the request handle to fill in.
// NOTE(review): T is presumably introduced by a template header that is not
// visible in this listing.
100 void i_all_gather(
const T *src, T *dst, std::size_t count,
101 MPI_Request *req)
const {
103 MPI_Iallgather_c(src, count *
sizeof(T), MPI_BYTE, dst, count *
sizeof(T),
104 MPI_BYTE, mpi_comm_, req);
// Single-element variant of i_all_gather: asserts that `dst` already holds
// at least size_ elements, then forwards to the pointer overload with a
// count of 1 and the caller's request handle.
107 template <
typename T>
108 void i_all_gather(
const T &src, std::vector<T> &dst, MPI_Request *req)
const {
109 assert(rng::size(dst) >= size_);
110 i_all_gather(&src, rng::data(dst), 1, req);
// Variable-count byte gather via MPI_Gatherv_c. This rank sends
// counts[rank()] bytes from `src`; `counts` and `offsets` are also passed as
// the receive counts and displacements for `dst`.
// NOTE(review): the trailing arguments of the call are cut off in this
// listing.
113 void gatherv(
const void *src, MPI_Count *counts, MPI_Aint *offsets,
void *dst,
114 std::size_t root)
const {
115 MPI_Gatherv_c(src, counts[rank()], MPI_BYTE, dst, counts, offsets, MPI_BYTE,
// Typed send through MPI_Isend_c: passes count * sizeof(T) bytes (MPI_BYTE)
// starting at `data`, addressed to `dst_rank`, with `tag` converted to int.
// NOTE(review): the trailing arguments of the call (presumably the
// communicator and `request`) are cut off in this listing.
120 template <
typename T>
121 void isend(
const T *data, std::size_t count, std::size_t dst_rank,
auto tag,
122 MPI_Request *request)
const {
123 MPI_Isend_c(data, count *
sizeof(T), MPI_BYTE, dst_rank,
int(tag),
// Convenience overload of the pointer-based isend that uses tag 0.
128 template <
typename T>
129 void isend(
const T *data, std::size_t count, std::size_t dst_rank,
130 MPI_Request *request)
const {
131 isend(data, count, dst_rank, 0, request);
// Range overload of isend: forwards the range's data pointer and element
// count, together with `dst_rank`, `tag` and `request`, to the pointer-based
// overload.
135 template <rng::contiguous_range R>
136 void isend(
const R &data, std::size_t dst_rank,
auto tag,
137 MPI_Request *request)
const {
138 isend(rng::data(data), rng::size(data), dst_rank, tag, request);
// Range overload of isend that uses tag 0.
142 template <rng::contiguous_range R>
143 void isend(
const R &data, std::size_t dst_rank, MPI_Request *request)
const {
144 isend(data, dst_rank, 0, request);
// Typed receive through MPI_Irecv_c: receives up to size * sizeof(T) bytes
// (MPI_BYTE) into `data` from `src_rank`, with `tag` converted to int, on
// mpi_comm_.
// NOTE(review): the final argument of the call (presumably `request`) is cut
// off in this listing.
148 template <
typename T>
149 void irecv(T *data, std::size_t size, std::size_t src_rank,
auto tag,
150 MPI_Request *request)
const {
151 MPI_Irecv_c(data, size *
sizeof(T), MPI_BYTE, src_rank,
int(tag), mpi_comm_,
// Convenience overload of the pointer-based irecv that uses tag 0.
156 template <
typename T>
157 void irecv(T *data, std::size_t size, std::size_t src_rank,
158 MPI_Request *request)
const {
159 irecv(data, size, src_rank, 0, request);
// Range overload of irecv: forwards the range's data pointer and element
// count, together with `src_rank`, `tag` and `request`, to the pointer-based
// overload. Unlike the isend counterpart, `tag` is a plain int here.
163 template <rng::contiguous_range R>
164 void irecv(R &data, std::size_t src_rank,
int tag,
165 MPI_Request *request)
const {
166 irecv(rng::data(data), rng::size(data), src_rank, tag, request);
// Range overload of irecv that uses tag 0.
170 template <rng::contiguous_range R>
171 void irecv(R &data, std::size_t src_rank, MPI_Request *request)
const {
172 irecv(data, src_rank, 0, request);
// Waits for `request` with MPI_Wait, ignoring the status.
// The request is taken by value, so MPI_Wait operates on this function's
// local copy of the handle; the caller's variable is not modified.
175 void wait(MPI_Request request)
const {
176 MPI_Wait(&request, MPI_STATUS_IGNORE);
178 void waitall(std::size_t count, MPI_Request *requests)
const {
179 MPI_Waitall(count, requests, MPI_STATUS_IGNORE);
// Range overload of alltoall: forwards the data pointers of `sendr` and
// `recvr` together with `count` to the pointer-based overload.
// Note that, unlike most members here, this function is not const-qualified.
182 template <rng::contiguous_range R>
183 void alltoall(
const R &sendr, R &recvr, std::size_t count) {
184 alltoall(rng::data(sendr), rng::data(recvr), count);
// Typed all-to-all via MPI_Alltoall_c: exchanges count * sizeof(T) bytes
// (MPI_BYTE) per rank between `send` and `receive` on mpi_comm_, then emits a
// debug log entry with the byte count and an elapsed time.
// NOTE(review): the timing code and the end of the log call are not visible
// in this listing.
187 template <
typename T>
188 void alltoall(
const T *send, T *receive, std::size_t count) {
189 std::size_t bytes = count *
sizeof(T);
192 MPI_Alltoall_c(send, bytes, MPI_BYTE, receive, bytes, MPI_BYTE, mpi_comm_);
193 dr::drlog.debug(dr::logger::mpi,
"alltoall bytes: {} elapsed: {}\n", bytes,
// Variable-count all-to-all via MPI_Alltoallv_c.
//
// The four count/displacement vectors are given in elements; each must have
// exactly size_ entries (asserted). They are converted to byte quantities by
// multiplying with sizeof(valT) and stored in MPI_Count / MPI_Aint vectors,
// which are then passed to MPI together with MPI_BYTE as the datatype.
// The send and receive ranges must have the same value type (static_assert).
// NOTE(review): the trailing arguments of the MPI_Alltoallv_c call are cut
// off in this listing.
197 template <rng::contiguous_range SendR, rng::contiguous_range RecvR>
198 void alltoallv(
const SendR &sendbuf,
const std::vector<std::size_t> &sendcnt,
199 const std::vector<std::size_t> &senddsp, RecvR &recvbuf,
200 const std::vector<std::size_t> &recvcnt,
201 const std::vector<std::size_t> &recvdsp) {
202 using valT =
typename RecvR::value_type;
204 static_assert(std::is_same_v<std::ranges::range_value_t<SendR>,
205 std::ranges::range_value_t<RecvR>>);
207 assert(rng::size(sendcnt) == size_);
208 assert(rng::size(senddsp) == size_);
209 assert(rng::size(recvcnt) == size_);
210 assert(rng::size(recvdsp) == size_);
// Byte-scaled copies of the caller's element counts and displacements.
212 std::vector<MPI_Count> _sendcnt(size_);
213 std::vector<MPI_Aint> _senddsp(size_);
214 std::vector<MPI_Count> _recvcnt(size_);
215 std::vector<MPI_Aint> _recvdsp(size_);
217 rng::transform(sendcnt, _sendcnt.begin(),
218 [](
auto e) { return e * sizeof(valT); });
219 rng::transform(senddsp, _senddsp.begin(),
220 [](
auto e) { return e * sizeof(valT); });
221 rng::transform(recvcnt, _recvcnt.begin(),
222 [](
auto e) { return e * sizeof(valT); });
223 rng::transform(recvdsp, _recvdsp.begin(),
224 [](
auto e) { return e * sizeof(valT); });
226 MPI_Alltoallv_c(rng::data(sendbuf), rng::data(_sendcnt),
227 rng::data(_senddsp), MPI_BYTE, rng::data(recvbuf),
228 rng::data(_recvcnt), rng::data(_recvdsp), MPI_BYTE,
233 return mpi_comm_ == other.mpi_comm_;
// Creates the MPI window over the caller-provided buffer.
// Stores `comm` in communicator_, logs the size and address, and calls
// MPI_Win_create with a displacement unit of 1 and MPI_INFO_NULL, writing
// the new handle into win_.
// NOTE(review): one original line of this body is not visible in this
// listing — confirm against the full file.
244 void create(
communicator comm,
void *data, std::size_t size) {
246 communicator_ = comm;
247 DRLOG(
"win create:: size: {} data:{}", size, data);
248 MPI_Win_create(data, size, 1, MPI_INFO_NULL, comm.mpi_comm(), &win_);
// Returns local_data_ reinterpreted as a pointer to T via static_cast.
251 template <
typename T>
auto local_data() {
252 return static_cast<T *
>(local_data_);
255 void free() { MPI_Win_free(&win_); }
// Two rma_window objects compare equal when their win_ handles are equal.
// `other` is taken by value, so the argument is copied for the comparison.
257 bool operator==(
const rma_window other)
const noexcept {
258 return this->win_ == other.win_;
261 void set_null() { win_ = MPI_WIN_NULL; }
262 bool null()
const noexcept {
return win_ == MPI_WIN_NULL; }
// Typed get: reads sizeof(T) bytes from `rank` through the byte-level get
// into `dst`. `disp` is an element index and is scaled by sizeof(T) to a
// byte displacement.
// NOTE(review): the declaration of `dst` and the return statement are not
// visible in this listing.
264 template <
typename T> T get(std::size_t rank, std::size_t disp)
const {
266 get(&dst,
sizeof(T), rank, disp *
sizeof(T));
// Byte-level get: reads `size` bytes at displacement `disp` on `rank`
// through win_ into `dst`, then waits on the request with MPI_Wait before
// continuing.
//
// When MPI_VERSION >= 4 or Intel MPI is newer than 20211200000, MPI_Rget_c
// is used. The remaining lines use MPI_Rget guarded by a check that `size`
// fits in INT_MAX.
// NOTE(review): the declaration of `request`, the `assert(` opener and the
// #else/#endif lines are not visible in this listing.
270 void get(
void *dst, std::size_t size, std::size_t rank,
271 std::size_t disp)
const {
272 DRLOG(
"MPI comm get:: ({}:{}:{})", rank, disp, size);
274#if (MPI_VERSION >= 4) || \
275 (defined(I_MPI_NUMVERSION) && (I_MPI_NUMVERSION > 20211200000))
276 MPI_Rget_c(dst, size, MPI_BYTE, rank, disp, size, MPI_BYTE, win_, &request);
279 size <= (std::size_t)INT_MAX &&
280 "MPI API requires origin_count to be positive signed 32-bit integer");
281 MPI_Rget(dst, size, MPI_BYTE, rank, disp, size, MPI_BYTE, win_, &request);
283 MPI_Wait(&request, MPI_STATUS_IGNORE);
// Typed put: writes the object `src` (sizeof(src) bytes) to `rank` through
// the byte-level put. `disp` is an element index and is scaled by
// sizeof(src) to a byte displacement.
286 void put(
const auto &src, std::size_t rank, std::size_t disp)
const {
287 put(&src,
sizeof(src), rank, disp *
sizeof(src));
// Byte-level put: writes `size` bytes from `src` to displacement `disp` on
// `rank` through win_, then waits on the request with MPI_Wait, logging
// before and after the wait.
//
// When MPI_VERSION >= 4 or Intel MPI is newer than 20211200000, MPI_Rput_c
// is used. The remaining lines use MPI_Rput guarded by a check that `size`
// fits in INT_MAX.
// NOTE(review): the declaration of `request`, the `assert(` opener and the
// #else/#endif lines are not visible in this listing.
290 void put(
const void *src, std::size_t size, std::size_t rank,
291 std::size_t disp)
const {
292 DRLOG(
"MPI comm put:: ({}:{}:{})", rank, disp, size);
295#if (MPI_VERSION >= 4) || \
296 (defined(I_MPI_NUMVERSION) && (I_MPI_NUMVERSION > 20211200000))
297 MPI_Rput_c(src, size, MPI_BYTE, rank, disp, size, MPI_BYTE, win_, &request);
301 size <= (std::size_t)INT_MAX &&
302 "MPI API requires origin_count to be positive signed 32-bit integer");
303 MPI_Rput(src, size, MPI_BYTE, rank, disp, size, MPI_BYTE, win_, &request);
306 DRLOG(
"MPI comm wait:: ({}:{}:{})", rank, disp, size);
307 MPI_Wait(&request, MPI_STATUS_IGNORE);
308 DRLOG(
"MPI comm wait finished:: ({}:{}:{})", rank, disp, size);
// NOTE(review): the header of the enclosing function is not visible in this
// listing; judging by the log messages this is the body of a fence
// operation — confirm against the full file.
// When win_ is a real window, MPI_Win_fence is called with assert flags 0,
// with log lines before and after. The last log line reports that the fence
// was skipped for a null window (presumably the else branch, whose keyword
// is not visible here).
312 if (win_ != MPI_WIN_NULL) {
313 DRLOG(
"MPI comm fence:: win:{}", win_);
314 MPI_Win_fence(0, win_);
315 DRLOG(
"MPI comm fence finished:: win:{}", win_);
317 DRLOG(
"MPI comm fence skipped because win is NULL");
// Calls MPI_Win_flush for target `rank` on win_, logging before and after
// the call.
321 void flush(std::size_t rank)
const {
322 DRLOG(
"MPI comm flush:: rank:{} win:{}", rank, win_);
323 MPI_Win_flush(rank, win_);
324 DRLOG(
"MPI comm flush finished:: rank:{} win:{}", rank, win_);
327 const auto &
communicator()
const {
return communicator_; }
328 auto mpi_win() {
return win_; }
// Window handle; starts as MPI_WIN_NULL until MPI_Win_create writes to it.
332 MPI_Win win_ = MPI_WIN_NULL;
// Pointer returned (after a cast) by local_data(); starts as nullptr.
333 void *local_data_ =
nullptr;
Definition: communicator.hpp:13
Definition: communicator.hpp:242
Definition: logger.hpp:18