Distributed Ranges
Loading...
Searching...
No Matches
broadcasted_slim_matrix.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7namespace dr::mp {
8
9template <typename T, typename Allocator = dr::mp::__detail::allocator<T>>
11public:
12 broadcasted_slim_matrix() = default;
13
14 void broadcast_data(std::size_t height, std::size_t width, std::size_t root,
15 T **root_data, dr::communicator comm) {
16 if (_data != nullptr) {
17 destroy_data();
18 }
19 _data_size = height * width;
20 _height = height;
21 _width = width;
22 _data = alloc.allocate(_data_size);
23 if (comm.rank() == root) {
24 for (auto i = 0; i < width; i++) {
25 if (use_sycl()) {
26 __detail::sycl_copy(root_data[i], root_data[i] + height,
27 _data + height * i);
28 } else {
29 rng::copy(root_data[i], root_data[i] + height, _data + height * i);
30 }
31 }
32 }
33 comm.bcast(_data, sizeof(T) * _data_size, root);
34 }
35
36 template <rng::input_range R>
37 void broadcast_data(std::size_t height, std::size_t width, std::size_t root,
38 R root_data, dr::communicator comm) {
39 if (_data != nullptr) {
40 destroy_data();
41 }
42 _data_size = height * width;
43 _height = height;
44 _width = width;
45 _data = alloc.allocate(_data_size);
46 if (comm.rank() == root) {
47 if (use_sycl()) {
48 __detail::sycl_copy(std::to_address(root_data.begin()),
49 std::to_address(root_data.end()), _data);
50 } else {
51 rng::copy(root_data.begin(), root_data.end(), _data);
52 }
53 }
54 std::size_t position = 0;
55 std::size_t reminder = sizeof(T) * _data_size;
56 while (reminder > INT_MAX) {
57 comm.bcast(((uint8_t *)_data) + position, INT_MAX, root);
58 position += INT_MAX;
59 reminder -= INT_MAX;
60 }
61 comm.bcast(((uint8_t *)_data) + position, reminder, root);
62 }
63
64 void destroy_data() {
65 alloc.deallocate(_data, _data_size);
66 _data_size = 0;
67 _data = nullptr;
68 }
69
70 T *operator[](std::size_t index) { return _data + _height * index; }
71
72 T *broadcasted_data() { return _data; }
73 auto width() { return _width; }
74
75private:
76 T *_data = nullptr;
77 std::size_t _data_size = 0;
78 std::size_t _width = 0;
79 std::size_t _height = 0;
80
81 Allocator alloc;
82};
83} // namespace dr::mp
Definition: communicator.hpp:13
Definition: index.hpp:34
Definition: broadcasted_slim_matrix.hpp:10