Distributed SPMD model
Refer to Developer Guide: SPMD.
Programming interface
All types and functions in this section are declared in the `oneapi::dal::preview::spmd` namespace. They become available when you include the communicator header of the chosen backend.
The SPMD distributed model consists of the following components:
- Additional `train`, `infer`, and `compute` methods that accept a `communicator` object as the first parameter. These methods are expected to be called on all ranks to start distributed computations.
- The `communicator` class, which contains methods that perform collective operations among all ranks.
- Free functions that create a communicator using a specified communicator backend. The available backends are `ccl` and `mpi`.
Usage Example
The following listings briefly show how to create a communicator for each backend and use it for distributed training.
MPI backend
#ifndef ONEDAL_DATA_PARALLEL
#define ONEDAL_DATA_PARALLEL
#endif
#include "oneapi/dal/algo/kmeans.hpp"
#include "oneapi/dal/spmd/mpi/communicator.hpp"
kmeans::model<> run_training(const table& data,
const table& initial_centroids) {
const auto kmeans_desc = kmeans::descriptor<float>{}
.set_cluster_count(10)
.set_max_iteration_count(50)
.set_accuracy_threshold(1e-4);
auto comm = dal::preview::spmd::make_communicator<dal::preview::spmd::backend::mpi>(queue);
auto rank_id = comm.get_rank();
const auto result_train = dal::preview::train(comm, kmeans_desc, local_input);
if(rank_id == 0) {
print_table("centroids", result.get_model().get_centroids());
print_value("objective", result.get_objective_function_value());
}
return result.get_model();
}
CCL backend
#ifndef ONEDAL_DATA_PARALLEL
#define ONEDAL_DATA_PARALLEL
#endif
#include "oneapi/dal/algo/kmeans.hpp"
#include "oneapi/dal/spmd/ccl/communicator.hpp"
kmeans::model<> run_training(const table& data,
const table& initial_centroids) {
const auto kmeans_desc = kmeans::descriptor<float>{}
.set_cluster_count(10)
.set_max_iteration_count(50)
.set_accuracy_threshold(1e-4);
auto comm = dal::preview::spmd::make_communicator<dal::preview::spmd::backend::ccl>(queue);
auto rank_id = comm.get_rank();
const auto result_train = dal::preview::train(comm, kmeans_desc, local_input);
if(rank_id == 0) {
print_table("centroids", result.get_model().get_centroids());
print_value("objective", result.get_objective_function_value());
}
return result.get_model();
}