Distributed Ranges
Loading...
Searching...
No Matches
halo.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <dr/mp/global.hpp>
8#include <dr/mp/sycl_support.hpp>
9
10namespace dr::mp {
11
/// Message tags distinguishing the direction/kind of a halo transfer.
enum class halo_tag {
  invalid,
  forward, // data traveling prev -> next rank (see span_halo group builders)
  reverse, // data traveling next -> prev rank
  index,   // index-addressed (unstructured) groups
};
18
19template <typename Group> class halo_impl {
20 using T = typename Group::element_type;
21 using Memory = typename Group::memory_type;
22
23public:
24 using group_type = Group;
25
26 // Destructor frees buffer_, so cannot copy
27 halo_impl(const halo_impl &) = delete;
28 halo_impl operator=(const halo_impl &) = delete;
29
31 halo_impl(communicator comm, const std::vector<Group> &owned_groups,
32 const std::vector<Group> &halo_groups,
33 const Memory &memory = Memory())
34 : comm_(comm), halo_groups_(halo_groups), owned_groups_(owned_groups),
35 memory_(memory) {
36 DRLOG("Halo constructed with {}/{} owned/halo", rng::size(owned_groups),
37 rng::size(halo_groups));
38 buffer_size_ = 0;
39 std::size_t i = 0;
40 std::vector<std::size_t> buffer_index;
41 for (auto &g : owned_groups_) {
42 buffer_index.push_back(buffer_size_);
43 g.request_index = i++;
44 buffer_size_ += g.buffer_size();
45 map_.push_back(&g);
46 }
47 for (auto &g : halo_groups_) {
48 buffer_index.push_back(buffer_size_);
49 g.request_index = i++;
50 buffer_size_ += g.buffer_size();
51 map_.push_back(&g);
52 }
53 buffer_ = memory_.allocate(buffer_size_);
54 assert(buffer_ != nullptr);
55 i = 0;
56 for (auto &g : owned_groups_) {
57 g.buffer = &buffer_[buffer_index[i++]];
58 }
59 for (auto &g : halo_groups_) {
60 g.buffer = &buffer_[buffer_index[i++]];
61 }
62 requests_.resize(i);
63 }
64
67 DRLOG("Halo exchange receiving");
68 receive(halo_groups_);
69 DRLOG("Halo exchange sending");
70 send(owned_groups_);
71 DRLOG("Halo exchange begin finished");
72 }
73
76 DRLOG("Halo exchange finalize started");
78 DRLOG("Halo exchange finalize finished");
79 }
80
81 void exchange() {
84 }
85
  /// Begin a halo reduction: the reverse direction of an exchange -- halo
  /// groups are sent and later combined into the owned groups.
  void reduce_begin() {
    receive(owned_groups_); // post irecvs before the matching isends
    send(halo_groups_);
  }
91
93 void reduce_finalize(const auto &op) {
94 for (int pending = rng::size(requests_); pending > 0; pending--) {
95 int completed;
96 MPI_Waitany(rng::size(requests_), requests_.data(), &completed,
97 MPI_STATUS_IGNORE);
98 DRLOG("reduce_finalize(op) waitany completed: {}", completed);
99 auto &g = *map_[completed];
100 if (g.receive && g.buffered) {
101 g.unpack(op);
102 }
103 }
104 }
105
108 for (int pending = rng::size(requests_); pending > 0; pending--) {
109 int completed;
110 MPI_Waitany(rng::size(requests_), requests_.data(), &completed,
111 MPI_STATUS_IGNORE);
112 DRLOG("reduce_finalize() waitany completed: {}", completed);
113 auto &g = *map_[completed];
114 if (g.receive && g.buffered) {
115 g.unpack();
116 }
117 }
118 }
119
120 struct second_op {
121 T operator()(T &a, T &b) const { return b; }
122 } second;
123
124 struct plus_op {
125 T operator()(T &a, T &b) const { return a + b; }
126 } plus;
127
128 struct max_op {
129 T operator()(T &a, T &b) const { return std::max(a, b); }
130 } max;
131
132 struct min_op {
133 T operator()(T &a, T &b) const { return std::min(a, b); }
134 } min;
135
137 T operator()(T &a, T &b) const { return a * b; }
138 } multiplies;
139
  /// Frees the shared staging buffer; groups hold non-owning slices into it.
  ~halo_impl() {
    if (buffer_) {
      memory_.deallocate(buffer_, buffer_size_);
      buffer_ = nullptr;
    }
  }
146
147private:
148 void send(std::vector<Group> &sends) {
149 for (auto &g : sends) {
150 g.pack();
151 g.receive = false;
152 DRLOG("sending: {}", g.request_index);
153 comm_.isend(g.data_pointer(), g.data_size(), g.rank(), g.tag(),
154 &requests_[g.request_index]);
155 }
156 }
157
158 void receive(std::vector<Group> &receives) {
159 for (auto &g : receives) {
160 g.receive = true;
161 DRLOG("receiving: {}", g.request_index);
162 comm_.irecv(g.data_pointer(), g.data_size(), g.rank(), g.tag(),
163 &requests_[g.request_index]);
164 }
165 }
166
167 communicator comm_;
168 std::vector<Group> halo_groups_, owned_groups_;
169 T *buffer_ = nullptr;
170 std::size_t buffer_size_;
171 std::vector<MPI_Request> requests_;
172 std::vector<Group *> map_;
173 Memory memory_;
174};
175
176template <typename T, typename Memory = default_memory<T>> class index_group {
177public:
178 using element_type = T;
179 using memory_type = Memory;
180 T *buffer = nullptr;
181 std::size_t request_index;
182 bool receive;
183 bool buffered;
184
186 index_group(T *data, std::size_t rank,
187 const std::vector<std::size_t> &indices, const Memory &memory)
188 : memory_(memory), data_(data), rank_(rank) {
189 buffered = false;
190 for (std::size_t i = 0; i < rng::size(indices) - 1; i++) {
191 buffered = buffered || (indices[i + 1] - indices[i] != 1);
192 }
193 indices_size_ = rng::size(indices);
194 indices_ = memory_.template allocate<std::size_t>(indices_size_);
195 assert(indices_ != nullptr);
196 memory_.memcpy(indices_, indices.data(),
197 indices_size_ * sizeof(std::size_t));
198 }
199
200 index_group(const index_group &o)
201 : buffer(o.buffer), request_index(o.request_index), receive(o.receive),
202 buffered(o.buffered), memory_(o.memory_), data_(o.data_),
203 rank_(o.rank_), indices_size_(o.indices_size_), tag_(o.tag_) {
204 indices_ = memory_.template allocate<std::size_t>(indices_size_);
205 assert(indices_ != nullptr);
206 memory_.memcpy(indices_, o.indices_, indices_size_ * sizeof(std::size_t));
207 }
208
209 void unpack(const auto &op) {
210 T *dpt = data_;
211 auto n = indices_size_;
212 auto *ipt = indices_;
213 auto *b = buffer;
214 memory_.offload([=]() {
215 for (std::size_t i = 0; i < n; i++) {
216 dpt[ipt[i]] = op(dpt[ipt[i]], b[i]);
217 }
218 });
219 }
220
221 void pack() {
222 T *dpt = data_;
223 auto n = indices_size_;
224 auto *ipt = indices_;
225 auto *b = buffer;
226 memory_.offload([=]() {
227 for (std::size_t i = 0; i < n; i++) {
228 b[i] = dpt[ipt[i]];
229 }
230 });
231 }
232
233 std::size_t buffer_size() {
234 if (buffered) {
235 return indices_size_;
236 }
237 return 0;
238 }
239
240 T *data_pointer() {
241 if (buffered) {
242 return buffer;
243 } else {
244 return &data_[indices_[0]];
245 }
246 }
247
248 std::size_t data_size() { return indices_size_; }
249
250 std::size_t rank() { return rank_; }
251 auto tag() { return tag_; }
252
253 ~index_group() {
254 if (indices_) {
255 memory_.template deallocate<std::size_t>(indices_, indices_size_);
256 indices_ = nullptr;
257 }
258 }
259
260private:
261 Memory memory_;
262 T *data_ = nullptr;
263 std::size_t rank_;
264 std::size_t indices_size_;
265 std::size_t *indices_;
266 halo_tag tag_ = halo_tag::index;
267};
268
/// Halo over index-addressed (possibly non-contiguous) groups.
template <typename T, typename Memory>
using unstructured_halo_impl = halo_impl<index_group<T, Memory>>;
271
272template <typename T, typename Memory = default_memory<T>>
273class unstructured_halo : public unstructured_halo_impl<T, Memory> {
274public:
276 using index_map = std::pair<std::size_t, std::vector<std::size_t>>;
277
282 const std::vector<index_map> &owned,
283 const std::vector<index_map> &halo,
284 const Memory &memory = Memory())
285 : unstructured_halo_impl<T, Memory>(
286 comm, make_groups(comm, data, owned, memory),
287 make_groups(comm, data, halo, memory), memory) {}
288
289private:
290 static std::vector<group_type> make_groups(communicator comm, T *data,
291 const std::vector<index_map> &map,
292 const Memory &memory) {
293 std::vector<group_type> groups;
294 for (auto const &[rank, indices] : map) {
295 groups.emplace_back(data, rank, indices, memory);
296 }
297 return groups;
298 }
299};
300
/// Halo group over one contiguous span of elements; used by span_halo.
template <typename T, typename Memory = default_memory<T>> class span_group {
public:
  using element_type = T;
  using memory_type = Memory;
  T *buffer = nullptr;           // staging slice assigned by halo_impl (non-owning)
  std::size_t request_index = 0; // slot in the halo's MPI_Request vector
  bool receive = false;          // set by halo_impl send()/receive()
  bool buffered = false;         // stage through `buffer` instead of in-place MPI

  span_group(std::span<T> data, std::size_t rank, halo_tag tag)
      : data_(data), rank_(rank), tag_(tag) {
#ifdef SYCL_LANGUAGE_VERSION
    // Shared USM is staged through the halo buffer; other allocations are
    // sent/received in place (buffered stays false).
    if (use_sycl() && sycl_mem_kind() == sycl::usm::alloc::shared) {
      buffered = true;
    }
#endif
  }

  /// Copy received data from the staging buffer into the span. No-op when
  /// unbuffered: MPI then received directly into the span.
  void unpack() {
    if (buffered) {
      if (mp::use_sycl()) {
        __detail::sycl_copy(buffer, buffer + rng::size(data_), data_.data());
      } else {
        std::copy(buffer, buffer + rng::size(data_), data_.data());
      }
    }
  }

  /// Copy the span into the staging buffer before sending. No-op when
  /// unbuffered: MPI then sends directly from the span.
  void pack() {
    if (buffered) {
      if (mp::use_sycl()) {
        __detail::sycl_copy(data_.data(), data_.data() + rng::size(data_),
                            buffer);
      } else {
        std::copy(data_.begin(), data_.end(), buffer);
      }
    }
  }
  // NOTE(review): unlike index_group::buffer_size(), this reports the full
  // span size even when unbuffered, so staging space is always reserved --
  // confirm whether that is intentional (buffered depends on runtime state).
  std::size_t buffer_size() { return rng::size(data_); }

  /// Number of elements transferred.
  std::size_t data_size() { return rng::size(data_); }

  /// Pointer handed to MPI: staging buffer when buffered, else the span.
  T *data_pointer() {
    if (buffered) {
      return buffer;
    } else {
      return data_.data();
    }
  }

  std::size_t rank() { return rank_; }

  auto tag() { return tag_; }

private:
  Memory memory_;
  std::span<T> data_;
  std::size_t rank_;
  halo_tag tag_ = halo_tag::invalid;
};
361
363 std::size_t prev = 0, next = 0;
364 bool periodic = false;
365};
366
367template <typename T, typename Memory>
369
370template <typename T, typename Memory = default_memory<T>>
371class span_halo : public span_halo_impl<T, Memory> {
372public:
374
376
  /// Construct from a raw pointer + element count; validates that the local
  /// span is large enough for the requested halo widths.
  span_halo(communicator comm, T *data, std::size_t size, halo_bounds hb)
      : span_halo_impl<T, Memory>(comm, owned_groups(comm, {data, size}, hb),
                                  halo_groups(comm, {data, size}, hb)) {
    check(size, hb);
  }
382
  /// Construct from an existing span.
  /// NOTE(review): unlike the pointer+size overload, this one does not call
  /// check() on the span size -- confirm whether that is intentional.
  span_halo(communicator comm, std::span<T> span, halo_bounds hb)
      : span_halo_impl<T, Memory>(comm, owned_groups(comm, span, hb),
                                  halo_groups(comm, span, hb)) {}
386
387private:
  // Sanity check: the local span must cover both halo widths plus at least
  // max(prev, next) further elements -- presumably so the owned regions sent
  // to neighbors cannot overlap the halo regions (TODO confirm intent).
  void check(auto size, auto hb) {
    assert(size >= hb.prev + hb.next + std::max(hb.prev, hb.next));
  }
391
392 static std::vector<group_type>
393 owned_groups(communicator comm, std::span<T> span, halo_bounds hb) {
394 std::vector<group_type> owned;
395 DRLOG("owned groups {}/{} first/last", comm.first(), comm.last());
396 if (hb.next > 0 && (hb.periodic || !comm.first())) {
397 owned.emplace_back(span.subspan(hb.prev, hb.next), comm.prev(),
398 halo_tag::reverse);
399 }
400 if (hb.prev > 0 && (hb.periodic || !comm.last())) {
401 owned.emplace_back(
402 span.subspan(rng::size(span) - (hb.prev + hb.next), hb.prev),
403 comm.next(), halo_tag::forward);
404 }
405 return owned;
406 }
407
408 static std::vector<group_type>
409 halo_groups(communicator comm, std::span<T> span, halo_bounds hb) {
410 std::vector<group_type> halo;
411 if (hb.prev > 0 && (hb.periodic || !comm.first())) {
412 halo.emplace_back(span.first(hb.prev), comm.prev(), halo_tag::forward);
413 }
414 if (hb.next > 0 && (hb.periodic || !comm.last())) {
415 halo.emplace_back(span.last(hb.next), comm.next(), halo_tag::reverse);
416 }
417 return halo;
418 }
419};
420
421} // namespace dr::mp
422
423#ifdef DR_FORMAT
424
/// fmt formatter for halo_bounds. Prints only the prev/next widths; the
/// `periodic` flag is not included in the output.
template <>
struct fmt::formatter<dr::mp::halo_bounds> : formatter<string_view> {
  template <typename FmtContext>
  auto format(dr::mp::halo_bounds hb, FmtContext &ctx) {
    return fmt::format_to(ctx.out(), "prev: {} next: {}", hb.prev, hb.next);
  }
};
432
433#endif
Definition: communicator.hpp:13
Definition: halo.hpp:19
void reduce_finalize(const auto &op)
Complete a halo reduction.
Definition: halo.hpp:93
void reduce_begin()
Begin a halo reduction.
Definition: halo.hpp:87
void reduce_finalize()
Complete a halo reduction.
Definition: halo.hpp:107
void exchange_begin()
Begin a halo exchange.
Definition: halo.hpp:66
halo_impl(communicator comm, const std::vector< Group > &owned_groups, const std::vector< Group > &halo_groups, const Memory &memory=Memory())
halo constructor
Definition: halo.hpp:31
void exchange_finalize()
Complete a halo exchange.
Definition: halo.hpp:75
Definition: halo.hpp:176
index_group(T *data, std::size_t rank, const std::vector< std::size_t > &indices, const Memory &memory)
Constructor.
Definition: halo.hpp:186
Definition: halo.hpp:301
Definition: halo.hpp:371
Definition: halo.hpp:273
unstructured_halo(communicator comm, T *data, const std::vector< index_map > &owned, const std::vector< index_map > &halo, const Memory &memory=Memory())
Definition: halo.hpp:281
Definition: halo.hpp:362
Definition: halo.hpp:128
Definition: halo.hpp:132
Definition: halo.hpp:136
Definition: halo.hpp:124
Definition: halo.hpp:120