Distributed Ranges
Loading...
Searching...
No Matches
communicator.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
// MPI_SUPPORTS_RGET_C is 1 when the large-count MPI_Rget_c/MPI_Rput_c
// bindings are available (MPI >= 4, or Intel MPI newer than 2021.12), else 0.
//
// The condition is evaluated here, once, instead of being pasted into every
// use site: a macro whose expansion contains `defined` has undefined behavior
// when used in `#if`. The original unparenthesized expansion also broke
// `!MPI_SUPPORTS_RGET_C` and similar uses through operator precedence. A
// plain 0/1 value also works in ordinary C++ expressions. <mpi.h> must be
// included before this point, as the rest of this header already requires.
#if (MPI_VERSION >= 4) ||                                                      \
    (defined(I_MPI_NUMVERSION) && (I_MPI_NUMVERSION > 20211200000))
#define MPI_SUPPORTS_RGET_C 1
#else
#define MPI_SUPPORTS_RGET_C 0
#endif
10
11namespace dr {
12
14public:
15 communicator(MPI_Comm comm = MPI_COMM_WORLD) : mpi_comm_(comm) {
16 int rank, size;
17 MPI_Comm_rank(comm, &rank);
18 MPI_Comm_size(comm, &size);
19 rank_ = rank;
20 size_ = size;
21 }
22
23 auto size() const { return size_; }
24 auto rank() const { return rank_; }
25 auto prev() const { return (rank() + size() - 1) % size(); }
26 auto next() const { return (rank() + 1) % size(); }
27 auto first() const { return rank() == 0; }
28 auto last() const { return rank() == size() - 1; }
29
30 MPI_Comm mpi_comm() const { return mpi_comm_; }
31
32 void barrier() const {
33#ifdef DRISHMEM
34 DRLOG("calling COMM barrier (by calling fence) in ISHMEM");
35 ishmem_fence();
36#endif
37 DRLOG("calling COMM barrier in MPI");
38 MPI_Barrier(mpi_comm_);
39 DRLOG("COMM barrier finished");
40 }
41
42 void bcast(void *src, std::size_t count, std::size_t root) const {
43 MPI_Bcast(src, count, MPI_BYTE, root, mpi_comm_);
44 }
45
46 void scatter(const void *src, void *dst, std::size_t count,
47 std::size_t root) const {
48 MPI_Scatter(src, count, MPI_BYTE, dst, count, MPI_BYTE, root, mpi_comm_);
49 }
50
51 template <typename T>
52 void scatter(const std::span<T> src, T &dst, std::size_t root) const {
53 assert(rng::size(src) >= size_);
54 scatter(rng::data(src), &dst, sizeof(T), root);
55 }
56
57 void scatterv(const void *src, int *counts, int *offsets, void *dst,
58 int dst_count, std::size_t root) const {
59 assert(counts == nullptr || counts[rank()] == dst_count);
60 MPI_Scatterv(src, counts, offsets, MPI_BYTE, dst, dst_count, MPI_BYTE, root,
61 mpi_comm_);
62 }
63
64 void gather(const void *src, void *dst, std::size_t count,
65 std::size_t root) const {
66 MPI_Gather_c(src, count, MPI_BYTE, dst, count, MPI_BYTE, root, mpi_comm_);
67 }
68
69 template <typename T>
70 void gather(const T *src, T *dst, std::size_t count, std::size_t root) const {
71 gather((void *)src, (void *)dst, count * sizeof(T), root);
72 }
73
74 template <typename T>
75 void gather(const T &src, std::span<T> dst, std::size_t root) const {
76 assert(rng::size(dst) >= size_);
77 gather(&src, rng::data(dst), 1, root);
78 }
79
80 template <typename T>
81 void all_gather(const T *src, T *dst, std::size_t count) const {
82 // Gather size elements from each rank
83 MPI_Allgather_c(src, count * sizeof(T), MPI_BYTE, dst, count * sizeof(T),
84 MPI_BYTE, mpi_comm_);
85 }
86
87 template <typename T>
88 void all_gather(const T &src, std::vector<T> &dst) const {
89 assert(rng::size(dst) >= size_);
90 all_gather(&src, rng::data(dst), 1);
91 }
92
93 template <rng::contiguous_range R>
94 void all_gather(const R &src, R &dst) const {
95 assert(rng::size(dst) >= size_ * rng::size(src));
96 all_gather(rng::data(src), rng::data(dst), rng::size(src));
97 }
98
99 template <typename T>
100 void i_all_gather(const T *src, T *dst, std::size_t count,
101 MPI_Request *req) const {
102 // Gather size elements from each rank
103 MPI_Iallgather_c(src, count * sizeof(T), MPI_BYTE, dst, count * sizeof(T),
104 MPI_BYTE, mpi_comm_, req);
105 }
106
107 template <typename T>
108 void i_all_gather(const T &src, std::vector<T> &dst, MPI_Request *req) const {
109 assert(rng::size(dst) >= size_);
110 i_all_gather(&src, rng::data(dst), 1, req);
111 }
112
113 void gatherv(const void *src, MPI_Count *counts, MPI_Aint *offsets, void *dst,
114 std::size_t root) const {
115 MPI_Gatherv_c(src, counts[rank()], MPI_BYTE, dst, counts, offsets, MPI_BYTE,
116 root, mpi_comm_);
117 }
118
119 // pointer with explicit tag
120 template <typename T>
121 void isend(const T *data, std::size_t count, std::size_t dst_rank, auto tag,
122 MPI_Request *request) const {
123 MPI_Isend_c(data, count * sizeof(T), MPI_BYTE, dst_rank, int(tag),
124 mpi_comm_, request);
125 }
126
127 // pointer, no tag
128 template <typename T>
129 void isend(const T *data, std::size_t count, std::size_t dst_rank,
130 MPI_Request *request) const {
131 isend(data, count, dst_rank, 0, request);
132 }
133
134 // range and tag
135 template <rng::contiguous_range R>
136 void isend(const R &data, std::size_t dst_rank, auto tag,
137 MPI_Request *request) const {
138 isend(rng::data(data), rng::size(data), dst_rank, tag, request);
139 }
140
141 // range, no tag
142 template <rng::contiguous_range R>
143 void isend(const R &data, std::size_t dst_rank, MPI_Request *request) const {
144 isend(data, dst_rank, 0, request);
145 }
146
147 // pointer and tag
148 template <typename T>
149 void irecv(T *data, std::size_t size, std::size_t src_rank, auto tag,
150 MPI_Request *request) const {
151 MPI_Irecv_c(data, size * sizeof(T), MPI_BYTE, src_rank, int(tag), mpi_comm_,
152 request);
153 }
154
155 // pointer, no tag
156 template <typename T>
157 void irecv(T *data, std::size_t size, std::size_t src_rank,
158 MPI_Request *request) const {
159 irecv(data, size, src_rank, 0, request);
160 }
161
162 // range and tag
163 template <rng::contiguous_range R>
164 void irecv(R &data, std::size_t src_rank, int tag,
165 MPI_Request *request) const {
166 irecv(rng::data(data), rng::size(data), src_rank, tag, request);
167 }
168
169 // range, no tag
170 template <rng::contiguous_range R>
171 void irecv(R &data, std::size_t src_rank, MPI_Request *request) const {
172 irecv(data, src_rank, 0, request);
173 }
174
175 void wait(MPI_Request request) const {
176 MPI_Wait(&request, MPI_STATUS_IGNORE);
177 }
178 void waitall(std::size_t count, MPI_Request *requests) const {
179 MPI_Waitall(count, requests, MPI_STATUS_IGNORE);
180 }
181
182 template <rng::contiguous_range R>
183 void alltoall(const R &sendr, R &recvr, std::size_t count) {
184 alltoall(rng::data(sendr), rng::data(recvr), count);
185 }
186
187 template <typename T>
188 void alltoall(const T *send, T *receive, std::size_t count) {
189 std::size_t bytes = count * sizeof(T);
190
191 timer time;
192 MPI_Alltoall_c(send, bytes, MPI_BYTE, receive, bytes, MPI_BYTE, mpi_comm_);
193 dr::drlog.debug(dr::logger::mpi, "alltoall bytes: {} elapsed: {}\n", bytes,
194 time.elapsed());
195 }
196
197 template <rng::contiguous_range SendR, rng::contiguous_range RecvR>
198 void alltoallv(const SendR &sendbuf, const std::vector<std::size_t> &sendcnt,
199 const std::vector<std::size_t> &senddsp, RecvR &recvbuf,
200 const std::vector<std::size_t> &recvcnt,
201 const std::vector<std::size_t> &recvdsp) {
202 using valT = typename RecvR::value_type;
203
204 static_assert(std::is_same_v<std::ranges::range_value_t<SendR>,
205 std::ranges::range_value_t<RecvR>>);
206
207 assert(rng::size(sendcnt) == size_);
208 assert(rng::size(senddsp) == size_);
209 assert(rng::size(recvcnt) == size_);
210 assert(rng::size(recvdsp) == size_);
211
212 std::vector<MPI_Count> _sendcnt(size_);
213 std::vector<MPI_Aint> _senddsp(size_);
214 std::vector<MPI_Count> _recvcnt(size_);
215 std::vector<MPI_Aint> _recvdsp(size_);
216
217 rng::transform(sendcnt, _sendcnt.begin(),
218 [](auto e) { return e * sizeof(valT); });
219 rng::transform(senddsp, _senddsp.begin(),
220 [](auto e) { return e * sizeof(valT); });
221 rng::transform(recvcnt, _recvcnt.begin(),
222 [](auto e) { return e * sizeof(valT); });
223 rng::transform(recvdsp, _recvdsp.begin(),
224 [](auto e) { return e * sizeof(valT); });
225
226 MPI_Alltoallv_c(rng::data(sendbuf), rng::data(_sendcnt),
227 rng::data(_senddsp), MPI_BYTE, rng::data(recvbuf),
228 rng::data(_recvcnt), rng::data(_recvdsp), MPI_BYTE,
229 mpi_comm_);
230 }
231
232 bool operator==(const communicator &other) const {
233 return mpi_comm_ == other.mpi_comm_;
234 }
235
236private:
237 MPI_Comm mpi_comm_;
238 std::size_t rank_;
239 std::size_t size_;
240};
241
243public:
244 void create(communicator comm, void *data, std::size_t size) {
245 local_data_ = data;
246 communicator_ = comm;
247 DRLOG("win create:: size: {} data:{}", size, data);
248 MPI_Win_create(data, size, 1, MPI_INFO_NULL, comm.mpi_comm(), &win_);
249 }
250
251 template <typename T> auto local_data() {
252 return static_cast<T *>(local_data_);
253 }
254
255 void free() { MPI_Win_free(&win_); }
256
257 bool operator==(const rma_window other) const noexcept {
258 return this->win_ == other.win_;
259 }
260
261 void set_null() { win_ = MPI_WIN_NULL; }
262 bool null() const noexcept { return win_ == MPI_WIN_NULL; }
263
264 template <typename T> T get(std::size_t rank, std::size_t disp) const {
265 T dst;
266 get(&dst, sizeof(T), rank, disp * sizeof(T));
267 return dst;
268 }
269
270 void get(void *dst, std::size_t size, std::size_t rank,
271 std::size_t disp) const {
272 DRLOG("MPI comm get:: ({}:{}:{})", rank, disp, size);
273 MPI_Request request;
274#if (MPI_VERSION >= 4) || \
275 (defined(I_MPI_NUMVERSION) && (I_MPI_NUMVERSION > 20211200000))
276 MPI_Rget_c(dst, size, MPI_BYTE, rank, disp, size, MPI_BYTE, win_, &request);
277#else
278 assert(
279 size <= (std::size_t)INT_MAX &&
280 "MPI API requires origin_count to be positive signed 32-bit integer");
281 MPI_Rget(dst, size, MPI_BYTE, rank, disp, size, MPI_BYTE, win_, &request);
282#endif
283 MPI_Wait(&request, MPI_STATUS_IGNORE);
284 }
285
286 void put(const auto &src, std::size_t rank, std::size_t disp) const {
287 put(&src, sizeof(src), rank, disp * sizeof(src));
288 }
289
290 void put(const void *src, std::size_t size, std::size_t rank,
291 std::size_t disp) const {
292 DRLOG("MPI comm put:: ({}:{}:{})", rank, disp, size);
293 MPI_Request request;
294
295#if (MPI_VERSION >= 4) || \
296 (defined(I_MPI_NUMVERSION) && (I_MPI_NUMVERSION > 20211200000))
297 MPI_Rput_c(src, size, MPI_BYTE, rank, disp, size, MPI_BYTE, win_, &request);
298#else
299 // MPI_Rput origin_count is 32-bit signed int - check range
300 assert(
301 size <= (std::size_t)INT_MAX &&
302 "MPI API requires origin_count to be positive signed 32-bit integer");
303 MPI_Rput(src, size, MPI_BYTE, rank, disp, size, MPI_BYTE, win_, &request);
304#endif
305
306 DRLOG("MPI comm wait:: ({}:{}:{})", rank, disp, size);
307 MPI_Wait(&request, MPI_STATUS_IGNORE);
308 DRLOG("MPI comm wait finished:: ({}:{}:{})", rank, disp, size);
309 }
310
311 void fence() const {
312 if (win_ != MPI_WIN_NULL) {
313 DRLOG("MPI comm fence:: win:{}", win_);
314 MPI_Win_fence(0, win_);
315 DRLOG("MPI comm fence finished:: win:{}", win_);
316 } else {
317 DRLOG("MPI comm fence skipped because win is NULL");
318 }
319 }
320
321 void flush(std::size_t rank) const {
322 DRLOG("MPI comm flush:: rank:{} win:{}", rank, win_);
323 MPI_Win_flush(rank, win_);
324 DRLOG("MPI comm flush finished:: rank:{} win:{}", rank, win_);
325 }
326
327 const auto &communicator() const { return communicator_; }
328 auto mpi_win() { return win_; }
329
330private:
331 dr::communicator communicator_;
332 MPI_Win win_ = MPI_WIN_NULL;
333 void *local_data_ = nullptr;
334};
335
336} // namespace dr
Definition: communicator.hpp:13
Definition: communicator.hpp:242
Definition: logger.hpp:18