Distributed Ranges
Loading...
Searching...
No Matches
count.hpp
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7namespace dr::mp::__detail {
8
9inline auto add_counts(rng::forward_range auto &&r) {
10 rng::range_difference_t<decltype(r)> zero{};
11
12 return std::accumulate(rng::begin(r), rng::end(r), zero);
13}
14
15inline auto count_if_local(rng::forward_range auto &&r, auto &&pred) {
16 if (mp::use_sycl()) {
17 dr::drlog.debug(" with DPL\n");
18#ifdef SYCL_LANGUAGE_VERSION
19 return std::count_if(mp::dpl_policy(),
20 dr::__detail::direct_iterator(rng::begin(r)),
21 dr::__detail::direct_iterator(rng::end(r)), pred);
22#else
23 assert(false);
24#endif
25 } else {
26 dr::drlog.debug(" with CPU\n");
27 return std::count_if(std::execution::par_unseq,
28 dr::__detail::direct_iterator(rng::begin(r)),
29 dr::__detail::direct_iterator(rng::end(r)), pred);
30 }
31}
32
33template <dr::distributed_range DR>
34auto count_if(std::size_t root, bool root_provided, DR &&dr, auto &&pred) {
35 using count_type = rng::range_difference_t<decltype(dr)>;
36 auto comm = mp::default_comm();
37
38 if (rng::empty(dr)) {
39 return count_type{};
40 }
41
42 dr::drlog.debug("Parallel count\n");
43
44 // Count within the local segments
45 auto count = [=](auto &&r) {
46 assert(rng::size(r) > 0);
47 return count_if_local(r, pred);
48 };
49 auto locals = rng::views::transform(local_segments(dr), count);
50 auto local = add_counts(locals);
51
52 std::vector<count_type> all(comm.size());
53 if (root_provided) {
54 // Everyone gathers to root, only root adds up the counts
55 comm.gather(local, std::span{all}, root);
56 if (root == comm.rank()) {
57 return add_counts(all);
58 } else {
59 return count_type{};
60 }
61 } else {
62 // Everyone gathers and everyone adds up the counts
63 comm.all_gather(local, all);
64 return add_counts(all);
65 }
66}
67
68} // namespace dr::mp::__detail
69
70namespace dr::mp {
71
72class count_fn_ {
73public:
74 template <typename T, dr::distributed_range DR>
75 auto operator()(std::size_t root, DR &&dr, const T &value) const {
76 auto pred = [=](auto &&v) { return v == value; };
77 return __detail::count_if(root, true, dr, pred);
78 }
79
80 template <typename T, dr::distributed_range DR>
81 auto operator()(DR &&dr, const T &value) const {
82 auto pred = [=](auto &&v) { return v == value; };
83 return __detail::count_if(0, false, dr, pred);
84 }
85
86 template <typename T, dr::distributed_iterator DI>
87 auto operator()(std::size_t root, DI first, DI last, const T &value) const {
88 auto pred = [=](auto &&v) { return v == value; };
89 return __detail::count_if(root, true, rng::subrange(first, last), pred);
90 }
91
92 template <typename T, dr::distributed_iterator DI>
93 auto operator()(DI first, DI last, const T &value) const {
94 auto pred = [=](auto &&v) { return v == value; };
95 return __detail::count_if(0, false, rng::subrange(first, last), pred);
96 }
97};
98
99inline constexpr count_fn_ count;
100
102public:
103 template <dr::distributed_range DR>
104 auto operator()(std::size_t root, DR &&dr, auto &&pred) const {
105 return __detail::count_if(root, true, dr, pred);
106 }
107
108 template <dr::distributed_range DR>
109 auto operator()(DR &&dr, auto &&pred) const {
110 return __detail::count_if(0, false, dr, pred);
111 }
112
113 template <dr::distributed_iterator DI>
114 auto operator()(std::size_t root, DI first, DI last, auto &&pred) const {
115 return __detail::count_if(root, true, rng::subrange(first, last), pred);
116 }
117
118 template <dr::distributed_iterator DI>
119 auto operator()(DI first, DI last, auto &&pred) const {
120 return __detail::count_if(0, false, rng::subrange(first, last), pred);
121 }
122};
123
124inline constexpr count_if_fn_ count_if;
125
126}; // namespace dr::mp
Definition: onedpl_direct_iterator.hpp:15
Definition: count.hpp:72
Definition: count.hpp:101