7#include <dr/mp/global.hpp>
8#include <dr/mp/sycl_support.hpp>
20 using T =
typename Group::element_type;
21 using Memory =
typename Group::memory_type;
24 using group_type = Group;
32 const std::vector<Group> &halo_groups,
33 const Memory &memory = Memory())
34 : comm_(comm), halo_groups_(halo_groups), owned_groups_(owned_groups),
36 DRLOG(
"Halo constructed with {}/{} owned/halo", rng::size(owned_groups),
37 rng::size(halo_groups));
40 std::vector<std::size_t> buffer_index;
41 for (
auto &g : owned_groups_) {
42 buffer_index.push_back(buffer_size_);
43 g.request_index = i++;
44 buffer_size_ += g.buffer_size();
47 for (
auto &g : halo_groups_) {
48 buffer_index.push_back(buffer_size_);
49 g.request_index = i++;
50 buffer_size_ += g.buffer_size();
53 buffer_ = memory_.allocate(buffer_size_);
54 assert(buffer_ !=
nullptr);
56 for (
auto &g : owned_groups_) {
57 g.buffer = &buffer_[buffer_index[i++]];
59 for (
auto &g : halo_groups_) {
60 g.buffer = &buffer_[buffer_index[i++]];
67 DRLOG(
"Halo exchange receiving");
68 receive(halo_groups_);
69 DRLOG(
"Halo exchange sending");
71 DRLOG(
"Halo exchange begin finished");
76 DRLOG(
"Halo exchange finalize started");
78 DRLOG(
"Halo exchange finalize finished");
88 receive(owned_groups_);
94 for (
int pending = rng::size(requests_); pending > 0; pending--) {
96 MPI_Waitany(rng::size(requests_), requests_.data(), &completed,
98 DRLOG(
"reduce_finalize(op) waitany completed: {}", completed);
99 auto &g = *map_[completed];
100 if (g.receive && g.buffered) {
108 for (
int pending = rng::size(requests_); pending > 0; pending--) {
110 MPI_Waitany(rng::size(requests_), requests_.data(), &completed,
112 DRLOG(
"reduce_finalize() waitany completed: {}", completed);
113 auto &g = *map_[completed];
114 if (g.receive && g.buffered) {
121 T operator()(T &a, T &b)
const {
return b; }
125 T operator()(T &a, T &b)
const {
return a + b; }
129 T operator()(T &a, T &b)
const {
return std::max(a, b); }
133 T operator()(T &a, T &b)
const {
return std::min(a, b); }
137 T operator()(T &a, T &b)
const {
return a * b; }
142 memory_.deallocate(buffer_, buffer_size_);
148 void send(std::vector<Group> &sends) {
149 for (
auto &g : sends) {
152 DRLOG(
"sending: {}", g.request_index);
153 comm_.isend(g.data_pointer(), g.data_size(), g.rank(), g.tag(),
154 &requests_[g.request_index]);
158 void receive(std::vector<Group> &receives) {
159 for (
auto &g : receives) {
161 DRLOG(
"receiving: {}", g.request_index);
162 comm_.irecv(g.data_pointer(), g.data_size(), g.rank(), g.tag(),
163 &requests_[g.request_index]);
168 std::vector<Group> halo_groups_, owned_groups_;
169 T *buffer_ =
nullptr;
170 std::size_t buffer_size_;
171 std::vector<MPI_Request> requests_;
172 std::vector<Group *> map_;
176template <
typename T,
typename Memory = default_memory<T>>
class index_group {
178 using element_type = T;
179 using memory_type = Memory;
181 std::size_t request_index;
187 const std::vector<std::size_t> &indices,
const Memory &memory)
188 : memory_(memory), data_(data), rank_(rank) {
190 for (std::size_t i = 0; i < rng::size(indices) - 1; i++) {
191 buffered = buffered || (indices[i + 1] - indices[i] != 1);
193 indices_size_ = rng::size(indices);
194 indices_ = memory_.template allocate<std::size_t>(indices_size_);
195 assert(indices_ !=
nullptr);
196 memory_.memcpy(indices_, indices.data(),
197 indices_size_ *
sizeof(std::size_t));
201 : buffer(o.buffer), request_index(o.request_index), receive(o.receive),
202 buffered(o.buffered), memory_(o.memory_), data_(o.data_),
203 rank_(o.rank_), indices_size_(o.indices_size_), tag_(o.tag_) {
204 indices_ = memory_.template allocate<std::size_t>(indices_size_);
205 assert(indices_ !=
nullptr);
206 memory_.memcpy(indices_, o.indices_, indices_size_ *
sizeof(std::size_t));
209 void unpack(
const auto &op) {
211 auto n = indices_size_;
212 auto *ipt = indices_;
214 memory_.offload([=]() {
215 for (std::size_t i = 0; i < n; i++) {
216 dpt[ipt[i]] = op(dpt[ipt[i]], b[i]);
223 auto n = indices_size_;
224 auto *ipt = indices_;
226 memory_.offload([=]() {
227 for (std::size_t i = 0; i < n; i++) {
233 std::size_t buffer_size() {
235 return indices_size_;
244 return &data_[indices_[0]];
248 std::size_t data_size() {
return indices_size_; }
250 std::size_t rank() {
return rank_; }
251 auto tag() {
return tag_; }
255 memory_.template deallocate<std::size_t>(indices_, indices_size_);
264 std::size_t indices_size_;
265 std::size_t *indices_;
266 halo_tag tag_ = halo_tag::index;
269template <
typename T,
typename Memory>
270using unstructured_halo_impl = halo_impl<index_group<T, Memory>>;
272template <
typename T,
typename Memory = default_memory<T>>
276 using index_map = std::pair<std::size_t, std::vector<std::size_t>>;
282 const std::vector<index_map> &owned,
283 const std::vector<index_map> &halo,
284 const Memory &memory = Memory())
286 comm, make_groups(comm, data, owned, memory),
287 make_groups(comm, data, halo, memory), memory) {}
290 static std::vector<group_type> make_groups(
communicator comm, T *data,
291 const std::vector<index_map> &map,
292 const Memory &memory) {
293 std::vector<group_type> groups;
294 for (
auto const &[rank, indices] : map) {
295 groups.emplace_back(data, rank, indices, memory);
301template <
typename T,
typename Memory = default_memory<T>>
class span_group {
303 using element_type = T;
304 using memory_type = Memory;
306 std::size_t request_index = 0;
307 bool receive =
false;
308 bool buffered =
false;
310 span_group(std::span<T> data, std::size_t rank, halo_tag tag)
311 : data_(data), rank_(rank), tag_(tag) {
312#ifdef SYCL_LANGUAGE_VERSION
313 if (use_sycl() && sycl_mem_kind() == sycl::usm::alloc::shared) {
321 if (mp::use_sycl()) {
322 __detail::sycl_copy(buffer, buffer + rng::size(data_), data_.data());
324 std::copy(buffer, buffer + rng::size(data_), data_.data());
331 if (mp::use_sycl()) {
332 __detail::sycl_copy(data_.data(), data_.data() + rng::size(data_),
335 std::copy(data_.begin(), data_.end(), buffer);
339 std::size_t buffer_size() {
return rng::size(data_); }
341 std::size_t data_size() {
return rng::size(data_); }
351 std::size_t rank() {
return rank_; }
353 auto tag() {
return tag_; }
359 halo_tag tag_ = halo_tag::invalid;
363 std::size_t prev = 0, next = 0;
364 bool periodic =
false;
367template <
typename T,
typename Memory>
370template <
typename T,
typename Memory = default_memory<T>>
379 halo_groups(comm, {data, size}, hb)) {
385 halo_groups(comm, span, hb)) {}
388 void check(
auto size,
auto hb) {
389 assert(size >= hb.prev + hb.next + std::max(hb.prev, hb.next));
392 static std::vector<group_type>
394 std::vector<group_type> owned;
395 DRLOG(
"owned groups {}/{} first/last", comm.first(), comm.last());
396 if (hb.next > 0 && (hb.periodic || !comm.first())) {
397 owned.emplace_back(span.subspan(hb.prev, hb.next), comm.prev(),
400 if (hb.prev > 0 && (hb.periodic || !comm.last())) {
402 span.subspan(rng::size(span) - (hb.prev + hb.next), hb.prev),
403 comm.next(), halo_tag::forward);
408 static std::vector<group_type>
410 std::vector<group_type> halo;
411 if (hb.prev > 0 && (hb.periodic || !comm.first())) {
412 halo.emplace_back(span.first(hb.prev), comm.prev(), halo_tag::forward);
414 if (hb.next > 0 && (hb.periodic || !comm.last())) {
415 halo.emplace_back(span.last(hb.next), comm.next(), halo_tag::reverse);
427 template <
typename FmtContext>
429 return fmt::format_to(ctx.out(),
"prev: {} next: {}", hb.prev, hb.next);
Definition: communicator.hpp:13
void reduce_finalize(const auto &op)
Complete a halo reduction.
Definition: halo.hpp:93
void reduce_begin()
Begin a halo reduction.
Definition: halo.hpp:87
void reduce_finalize()
Complete a halo reduction.
Definition: halo.hpp:107
void exchange_begin()
Begin a halo exchange.
Definition: halo.hpp:66
halo_impl(communicator comm, const std::vector< Group > &owned_groups, const std::vector< Group > &halo_groups, const Memory &memory=Memory())
halo constructor
Definition: halo.hpp:31
void exchange_finalize()
Complete a halo exchange.
Definition: halo.hpp:75
index_group(T *data, std::size_t rank, const std::vector< std::size_t > &indices, const Memory &memory)
Constructor.
Definition: halo.hpp:186
unstructured_halo(communicator comm, T *data, const std::vector< index_map > &owned, const std::vector< index_map > &halo, const Memory &memory=Memory())
Definition: halo.hpp:281