Global Namespace

Overview

// namespaces

namespace dnnl;
    namespace dnnl::graph;
        namespace dnnl::graph::ocl_interop;
        namespace dnnl::graph::sycl_interop;
    namespace dnnl::ocl_interop;
    namespace dnnl::sycl_interop;
    namespace dnnl::threadpool_interop;
    namespace dnnl::ukernel;
namespace oneapi;
namespace std;
namespace sycl;

// typedefs

typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef dnnl::memory::dim dim_t;
typedef dnnl::memory::dim dim_t;
typedef logical_tensor::data_type data_type;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef logical_tensor::data_type data_type;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::property_type property_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef logical_tensor::data_type data_type;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef memory::format_tag tag;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef memory::format_tag tag;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef logical_tensor::data_type data_type;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef memory::format_tag tag;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef memory::format_tag tag;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef memory::format_tag tag;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef memory::format_tag tag;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef logical_tensor::data_type data_type;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef logical_tensor::data_type data_type;
typedef logical_tensor::layout_type layout_type;
typedef logical_tensor::dim dim;
typedef logical_tensor::dims dims;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef memory::format_tag tag;
typedef memory::data_type dt;
typedef struct dnnl_memory_desc* dnnl_memory_desc_t;
typedef const struct dnnl_memory_desc* const_dnnl_memory_desc_t;
typedef struct dnnl_memory* dnnl_memory_t;
typedef const struct dnnl_memory* const_dnnl_memory_t;
typedef struct dnnl_primitive_desc* dnnl_primitive_desc_t;
typedef const struct dnnl_primitive_desc* const_dnnl_primitive_desc_t;
typedef struct dnnl_primitive_attr* dnnl_primitive_attr_t;
typedef const struct dnnl_primitive_attr* const_dnnl_primitive_attr_t;
typedef struct dnnl_post_ops* dnnl_post_ops_t;
typedef const struct dnnl_post_ops* const_dnnl_post_ops_t;
typedef struct dnnl_primitive* dnnl_primitive_t;
typedef const struct dnnl_primitive* const_dnnl_primitive_t;
typedef int64_t dnnl_dim_t;
typedef dnnl_dim_t dnnl_dims_t[DNNL_MAX_NDIMS];
typedef struct dnnl_engine* dnnl_engine_t;
typedef struct dnnl_stream* dnnl_stream_t;
typedef const struct dnnl_stream* const_dnnl_stream_t;

typedef void* (*dnnl_graph_ocl_allocate_f)(
    size_t size,
    size_t alignment,
    cl_device_id device,
    cl_context context
    );

typedef void (*dnnl_graph_ocl_deallocate_f)(
    void *buf,
    cl_device_id device,
    cl_context context,
    cl_event event
    );

typedef void* (*dnnl_graph_sycl_allocate_f)(
    size_t size,
    size_t alignment,
    const void *dev,
    const void *context
    );

typedef void (*dnnl_graph_sycl_deallocate_f)(
    void *buf,
    const void *dev,
    const void *context,
    void *event
    );

typedef struct dnnl_graph_partition* dnnl_graph_partition_t;
typedef const struct dnnl_graph_partition* const_dnnl_graph_partition_t;
typedef struct dnnl_graph_graph* dnnl_graph_graph_t;
typedef const struct dnnl_graph_graph* const_dnnl_graph_graph_t;
typedef struct dnnl_graph_op* dnnl_graph_op_t;
typedef const struct dnnl_graph_op* const_dnnl_graph_op_t;

typedef void* (*dnnl_graph_host_allocate_f)(
    size_t size,
    size_t alignment
    );

typedef void (*dnnl_graph_host_deallocate_f)(void *);
typedef struct dnnl_graph_allocator* dnnl_graph_allocator_t;
typedef const struct dnnl_graph_allocator* const_dnnl_graph_allocator_t;
typedef struct dnnl_graph_compiled_partition* dnnl_graph_compiled_partition_t;
typedef const struct dnnl_graph_compiled_partition* const_dnnl_graph_compiled_partition_t;
typedef struct dnnl_graph_tensor* dnnl_graph_tensor_t;
typedef const struct dnnl_graph_tensor* const_dnnl_graph_tensor_t;
typedef struct dnnl_ukernel_attr_params* dnnl_ukernel_attr_params_t;
typedef const struct dnnl_ukernel_attr_params* const_dnnl_ukernel_attr_params_t;
typedef struct dnnl_brgemm* dnnl_brgemm_t;
typedef const struct dnnl_brgemm* const_dnnl_brgemm_t;
typedef struct dnnl_transform* dnnl_transform_t;
typedef const struct dnnl_transform* const_dnnl_transform_t;

// enums

enum api_kind;
enum dnnl_accumulation_mode_t;
enum dnnl_alg_kind_t;
enum dnnl_cpu_isa_hints_t;
enum dnnl_cpu_isa_t;
enum dnnl_data_type_t;
enum dnnl_engine_kind_t;
enum dnnl_format_kind_t;
enum dnnl_format_tag_t;
enum dnnl_fpmath_mode_t;
enum dnnl_graph_layout_type_t;
enum dnnl_graph_op_attr_t;
enum dnnl_graph_op_kind_t;
enum dnnl_graph_partition_policy_t;
enum dnnl_graph_tensor_property_t;
enum dnnl_normalization_flags_t;
enum dnnl_ocl_interop_memory_kind_t;
enum dnnl_pack_type_t;
enum dnnl_primitive_kind_t;
enum dnnl_profiling_data_kind_t;
enum dnnl_prop_kind_t;
enum dnnl_query_t;
enum dnnl_rnn_direction_t;
enum dnnl_rnn_flags_t;
enum dnnl_rounding_mode_t;
enum dnnl_scratchpad_mode_t;
enum dnnl_sparse_encoding_t;
enum dnnl_status_t;
enum dnnl_stream_flags_t;
enum dnnl_sycl_interop_memory_kind_t;

// structs

struct args_t;
struct cpu_deletor;
struct dnnl_brgemm;
struct dnnl_engine;
struct dnnl_exec_arg_t;
struct dnnl_graph_inplace_pair_t;
struct dnnl_graph_logical_tensor_t;
struct dnnl_memory;
struct dnnl_memory_desc;
struct dnnl_post_ops;
struct dnnl_primitive;
struct dnnl_primitive_attr;
struct dnnl_primitive_desc;
struct dnnl_stream;
struct dnnl_transform;
struct dnnl_ukernel_attr_params;
struct dnnl_version_t;
struct example_allows_unimplemented;
struct gemm_dims_t;
struct gqa_dims_t;
struct mlp_dims_t;
struct mqa_dims_t;
struct sdpa_dims_t;
struct sycl_deletor;

// classes

class simple_memory_pool_t;

// global variables

const dim_t batch = 32;
const dim_t src_seq_length_max = 10;
const dim_t tgt_seq_length_max = 10;
const dim_t feature_size = 256;
const dim_t enc_bidir_n_layers = 1;
const dim_t enc_unidir_n_layers = 3;
const dim_t dec_n_layers = 4;
const int lstm_n_gates = 4;
const dim_t batch = 32;
const dim_t src_seq_length_max = 10;
const dim_t tgt_seq_length_max = 10;
const dim_t feature_size = 256;
const dim_t enc_bidir_n_layers = 1;
const dim_t enc_unidir_n_layers = 3;
const dim_t dec_n_layers = 4;
const int lstm_n_gates = 4;
static const int min_runs = 4;
static const int min_runs = 4;
static const int min_runs = 4;
static const int min_runs = 4;
static const int min_runs = 4;
static const int min_runs = 4;
static const int min_runs = 4;
const memory::dims strides = {4, 4};
const memory::dims padding = {0, 0};
const int N0 = 1 + rand() % 31;
const int N1 = 1 + rand() % 31;
const int T0 = 31 + 1 + rand() % 31;
const int T1 = 1 + rand() % 31;
const int leftmost_batch = N0 + N1;
const int rightmost_batch = N0;
const int leftmost_seq_length = T1;
const int rightmost_seq_length = T0 - T1;
const int common_feature_size = 1024;
const int common_n_layers = 1;
const int lstm_n_gates = 4;
engine eng(engine::kind::cpu, 0);
int number_of_runs = 1;
float fixed_beta = 0.f;
engine eng(engine::kind::cpu, 0);
int number_of_runs = 1;
int number_of_runs = 1;

// global functions

void bnorm_u8_via_binary_postops(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
static size_t product(dnnl_dim_t* arr, size_t size);
static void init_net_data(float* data, uint32_t dim, const dnnl_dim_t* dims);
static void prepare_arg_node(args_t* node, int nargs);
static void free_arg_node(args_t* node);
static void set_arg(dnnl_exec_arg_t* arg, int arg_idx, dnnl_memory_t memory);

static void init_data_memory(
    uint32_t dim,
    const dnnl_dim_t* dims,
    dnnl_format_tag_t user_tag,
    dnnl_engine_t engine,
    float* data,
    dnnl_memory_t* memory
    );

dnnl_status_t prepare_reorder(
    dnnl_memory_t* user_memory,
    const_dnnl_memory_desc_t prim_memory_md,
    dnnl_engine_t prim_engine,
    int dir_is_user_to_prim,
    dnnl_memory_t* prim_memory,
    dnnl_primitive_t* reorder,
    uint32_t* net_index,
    dnnl_primitive_t* net,
    args_t* net_args
    );

void simple_net(dnnl_engine_kind_t engine_kind);
int main(int argc, char** argv);
void cnn_inference_f32(engine::kind engine_kind);
int main(int argc, char** argv);
int main(int argc, char** argv);
void simple_net(engine::kind engine_kind);
int main(int argc, char** argv);
void simple_net(engine::kind engine_kind);
int main(int argc, char** argv);
static size_t product(dnnl_dim_t* arr, size_t size);
static void init_net_data(float* data, uint32_t dim, const dnnl_dim_t* dims);
static void prepare_arg_node(args_t* node, int nargs);
static void free_arg_node(args_t* node);
static void set_arg(dnnl_exec_arg_t* arg, int arg_idx, dnnl_memory_t memory);

static void init_data_memory(
    uint32_t dim,
    const dnnl_dim_t* dims,
    dnnl_format_tag_t user_tag,
    dnnl_engine_t engine,
    float* data,
    dnnl_memory_t* memory
    );

dnnl_status_t prepare_reorder(
    dnnl_memory_t* user_memory,
    const_dnnl_memory_desc_t prim_memory_md,
    dnnl_engine_t prim_engine,
    int dir_is_user_to_prim,
    dnnl_memory_t* prim_memory,
    dnnl_primitive_t* reorder,
    uint32_t* net_index,
    dnnl_primitive_t* net,
    args_t* net_args
    );

void simple_net();
int main(int argc, char** argv);
bool check_result(dnnl::memory dst_mem);
void sparse_matmul();
int main(int argc, char** argv);
bool check_result(dnnl::memory dst_mem);
void sparse_matmul();
int main(int argc, char** argv);
void matmul_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
std::vector<float> weighted_src_layer(batch * feature_size, 1.0f);

std::vector<float> alignment_model(
    src_seq_length_max * batch * feature_size,
    1.0f
    );

std::vector<float> alignments(src_seq_length_max * batch, 1.0f);
std::vector<float> exp_sums(batch, 1.0f);

void compute_weighted_annotations(
    float* weighted_annotations,
    dim_t src_seq_length_max,
    dim_t batch,
    dim_t feature_size,
    float* weights_annot,
    float* annotations
    );

void compute_attention(
    float* context_vectors,
    dim_t src_seq_length_max,
    dim_t batch,
    dim_t feature_size,
    float* weights_src_layer,
    float* dec_src_layer,
    float* annotations,
    float* weighted_annotations,
    float* weights_alignments
    );

void copy_context(
    float* src_iter,
    dim_t n_layers,
    dim_t batch,
    dim_t feature_size
    );

int main(int argc, char** argv);
std::vector<int32_t> weighted_src_layer(batch * feature_size, 1);

std::vector<float> alignment_model(
    src_seq_length_max * batch * feature_size,
    1.0f
    );

std::vector<float> alignments(src_seq_length_max * batch, 1.0f);
std::vector<float> exp_sums(batch, 1.0f);

void compute_weighted_annotations(
    float* weighted_annotations,
    dim_t src_seq_length_max,
    dim_t batch,
    dim_t feature_size,
    float* weights_annot,
    float* annotations
    );

void compute_sum_of_rows(int8_t* a, dim_t rows, dim_t cols, int32_t* a_reduced);

void compute_attention(
    float* context_vectors,
    dim_t src_seq_length_max,
    dim_t batch,
    dim_t feature_size,
    int8_t* weights_src_layer,
    float weights_src_layer_scale,
    int32_t* compensation,
    uint8_t* dec_src_layer,
    float dec_src_layer_scale,
    float dec_src_layer_shift,
    uint8_t* annotations,
    float* weighted_annotations,
    float* weights_alignments
    );

void copy_context(
    float* src_iter,
    dim_t n_layers,
    dim_t batch,
    dim_t feature_size
    );

int main(int argc, char** argv);
size_t product(int n_dims, const dnnl_dim_t dims[]);
void fill(dnnl_memory_t mem, int n_dims, const dnnl_dim_t dims[]);
int find_negative(dnnl_memory_t mem, int n_dims, const dnnl_dim_t dims[]);
void cross_engine_reorder();
int main();
void fill(memory& mem, const memory::dims& adims);
int find_negative(memory& mem, const memory::dims& adims);
int main(int argc, char** argv);
static dnnl_engine_kind_t validate_engine_kind(dnnl_engine_kind_t akind);
static dnnl_engine_kind_t parse_engine_kind(int argc, char** argv);
static const char* engine_kind2str_upper(dnnl_engine_kind_t kind);
static void read_from_dnnl_memory(void* handle, dnnl_memory_t mem);
static void write_to_dnnl_memory(void* handle, dnnl_memory_t mem);
void finalize();
dnnl::engine::kind validate_engine_kind(dnnl::engine::kind akind);
const char* engine_kind2str_upper(dnnl::engine::kind kind);

int handle_example_errors(
    std::initializer_list<dnnl::engine::kind> engine_kinds,
    std::function<void()> example
    );

int handle_example_errors(
    std::function<void(dnnl::engine::kind, int, char**)> example,
    dnnl::engine::kind engine_kind,
    int argc,
    char** argv
    );

int handle_example_errors(
    std::function<void(dnnl::engine::kind)> example,
    dnnl::engine::kind engine_kind
    );

dnnl::engine::kind parse_engine_kind(int argc, char** argv, int extra_args = 0);
dnnl::memory::dim product(const dnnl::memory::dims& dims);
void read_from_dnnl_memory(void* handle, dnnl::memory& mem);
void write_to_dnnl_memory(void* handle, dnnl::memory& mem);
int main(int argc, char** argv);

cl_kernel create_init_opencl_kernel(
    cl_context ocl_ctx,
    const char* kernel_name,
    const char* ocl_code
    );

int main(int argc, char** argv);
int main(int argc, char** argv);
int main(int argc, char** argv);
int main(int argc, char** argv);
void fill_random(std::vector<float>& out);
const char* get_type_string(logical_tensor::data_type dt);
void print_test_case(logical_tensor::data_type dt, const mlp_dims_t& p);

void bench_gated_mlp(
    engine::kind ekind,
    logical_tensor::data_type dt,
    const mlp_dims_t& p,
    double time_limit = 0.
    );

void bad_args();

void bench(
    engine::kind ekind,
    dnnl_data_type_t dt,
    const mlp_dims_t& p,
    double time_limit = 0.
    );

void mlp_perf(engine::kind ekind, int argc, char** argv);
int main(int argc, char** argv);
void fill_random(std::vector<float>& out);
const char* get_type_string(logical_tensor::data_type dt);
size_t size_of(logical_tensor::data_type dt);
void print_test_case(logical_tensor::data_type dt, const mlp_dims_t& p);

void bench_gated_mlp(
    engine::kind ekind,
    logical_tensor::data_type dt,
    const mlp_dims_t& p,
    double time_limit = 0.
    );

void bad_args();

void bench(
    engine::kind ekind,
    dnnl_data_type_t dt,
    const mlp_dims_t& p,
    double time_limit = 0.
    );

void mlp_perf(engine::kind ekind, int argc, char** argv);
int main(int argc, char** argv);
int main(int argc, char** argv);
void fill_random(std::vector<float>& out);
void fill_mask(std::vector<float>& mask, size_t seq_len);
const char* get_type_string(logical_tensor::data_type dt);
void print_test_case(logical_tensor::data_type dt, const gqa_dims_t& p);

void bench_gqa(
    engine::kind ekind,
    logical_tensor::data_type dt,
    const gqa_dims_t& p,
    double time_limit = 0.
    );

void bad_args();

void bench(
    engine::kind ekind,
    dnnl_data_type_t dt,
    const gqa_dims_t& p,
    double time_limit = 0.
    );

void gqa_perf(engine::kind ekind, int argc, char** argv);
int main(int argc, char** argv);

void set_any_layout(
    const std::vector<dnnl::graph::partition>& partitions,
    std::unordered_set<size_t>& id_to_set_any_layout
    );

void* sycl_malloc_wrapper(
    size_t size,
    size_t alignment,
    const void* dev,
    const void* ctx
    );

void sycl_free_wrapper(
    void* ptr,
    const void* device,
    const void* context,
    void* event
    );

void allocate_graph_mem(
    std::vector<dnnl::graph::tensor>& tensors,
    const std::vector<dnnl::graph::logical_tensor>& lts,
    std::vector<std::shared_ptr<void>>& data_buffer,
    const dnnl::engine& eng
    );

void allocate_graph_mem(
    std::vector<dnnl::graph::tensor>& tensors,
    const std::vector<dnnl::graph::logical_tensor>& lts,
    std::vector<std::shared_ptr<void>>& data_buffer,
    std::unordered_map<size_t, dnnl::graph::tensor>& global_outputs_ts_map,
    const dnnl::engine& eng,
    bool is_input
    );

void allocate_sycl_graph_mem(
    std::vector<dnnl::graph::tensor>& tensors,
    const std::vector<dnnl::graph::logical_tensor>& lts,
    std::vector<std::shared_ptr<void>>& data_buffer,
    sycl::queue& q,
    const dnnl::engine& eng
    );

void allocate_sycl_graph_mem(
    std::vector<dnnl::graph::tensor>& tensors,
    const std::vector<dnnl::graph::logical_tensor>& lts,
    std::vector<std::shared_ptr<void>>& data_buffer,
    std::unordered_map<size_t, dnnl::graph::tensor>& global_outputs_ts_map,
    sycl::queue& q,
    const dnnl::engine& eng,
    bool is_input
    );

static void* ocl_malloc_shared(
    size_t size,
    size_t alignment,
    cl_device_id dev,
    cl_context ctx
    );

static void* ocl_malloc_device(
    size_t size,
    size_t alignment,
    cl_device_id dev,
    cl_context ctx
    );

static void ocl_free(
    void* ptr,
    cl_device_id dev,
    const cl_context ctx,
    cl_event event
    );

void allocate_ocl_graph_mem(
    std::vector<dnnl::graph::tensor>& tensors,
    const std::vector<dnnl::graph::logical_tensor>& lts,
    std::vector<std::shared_ptr<void>>& data_buffer,
    std::unordered_map<size_t, dnnl::graph::tensor>& global_outputs_ts_map,
    const dnnl::engine& eng,
    bool is_input
    );

void ocl_memcpy(dnnl::engine& eng, void* dst, const void* src, size_t size);

dnnl::memory::desc make_md(
    const dnnl::graph::logical_tensor& lt,
    dnnl::memory::data_type dt = dnnl::memory::data_type::undef
    );

void write_dt(void* handle, dnnl::graph::tensor& ts);
void write_to_dnnl_tensor(void* handle, dnnl::graph::tensor& ts);
simple_memory_pool_t& get_mem_pool();
dnnl::graph::allocator create_allocator(dnnl::engine::kind ekind);
void fill_random(std::vector<float>& out);
void fill_mask(std::vector<float>& mask, size_t seq_len);
const char* get_type_string(logical_tensor::data_type dt);
void print_test_case(logical_tensor::data_type dt, const mqa_dims_t& p);

void bench_mqa(
    engine::kind ekind,
    logical_tensor::data_type dt,
    const mqa_dims_t& p,
    double time_limit = 0.
    );

void bad_args();

void bench(
    engine::kind ekind,
    dnnl_data_type_t dt,
    const mqa_dims_t& p,
    double time_limit = 0.
    );

void mqa_perf(engine::kind ekind, int argc, char** argv);
int main(int argc, char** argv);
void fill_random(std::vector<float>& out);
void fill_mask(std::vector<float>& mask, size_t seq_len);
void print_test_case(memory::data_type dt, const sdpa_dims_t& p);

void bench_sdpa_primitives(
    engine::kind ekind,
    memory::data_type dt,
    const sdpa_dims_t& p,
    double time_limit = 0.
    );

const char* get_type_string(logical_tensor::data_type dt);
void print_test_case(logical_tensor::data_type dt, const sdpa_dims_t& p);

void bench_sdpa(
    engine::kind ekind,
    logical_tensor::data_type dt,
    const sdpa_dims_t& p,
    double time_limit = 0.
    );

void bad_args();

void bench(
    api_kind api,
    engine::kind ekind,
    dnnl_data_type_t dt,
    const sdpa_dims_t& p,
    double time_limit = 0.
    );

void sdpa_perf(engine::kind ekind, int argc, char** argv);
int main(int argc, char** argv);
void fill_random(std::vector<float>& out);
void fill_mask(std::vector<float>& mask, size_t seq_len);
const char* get_type_string(logical_tensor::data_type dt);
size_t size_of(logical_tensor::data_type dt);
void print_test_case(logical_tensor::data_type dt, const sdpa_dims_t& p);

void bench_sdpa(
    engine::kind ekind,
    logical_tensor::data_type dt,
    const sdpa_dims_t& p,
    double time_limit = 0.
    );

void bad_args();

void bench(
    engine::kind ekind,
    dnnl_data_type_t dt,
    const sdpa_dims_t& p,
    double time_limit = 0.
    );

void sdpa_perf(engine::kind ekind, int argc, char** argv);
int main(int argc, char** argv);
int main(int argc, char** argv);
int main(int argc, char** argv);
const char* get_type_string(dt type);
void print_test_case(dt type, gemm_dims_t dims);
void fill_random(std::vector<float>& out, bool is_integer);

double run_case(
    engine::kind engine_kind,
    dt type,
    gemm_dims_t dims,
    double time_limit = 0.
    );

void run(engine::kind engine_kind, dt type, gemm_dims_t dims, double time_limit);
void bad_args();
void matmul_perf(engine::kind engine_kind, int argc, char** argv);
int main(int argc, char** argv);
int main(int argc, char** argv);
void init_data(memory& m, float v);
void create_and_execute_relu(memory& data, engine& eng, stream& s);
primitive_attr create_attr_with_relu_post_op();
void performance_profiling(engine::kind engine_kind, int argc, char** argv);
int main(int argc, char** argv);
void augru_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void batch_normalization_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void binary_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void concat_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void convolution_example(dnnl::engine::kind engine_kind);
void depthwise_convolution_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void deconvolution_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void eltwise_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void group_normalization_example(engine::kind engine_kind);
int main(int argc, char** argv);
void inner_product_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void layer_normalization_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void lbr_gru_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void lrn_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void lstm_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void matmul_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void pooling_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void prelu_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void reduction_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void reorder_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void resampling_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void shuffle_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void softmax_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void sum_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void vanilla_rnn_example(dnnl::engine::kind engine_kind);
int main(int argc, char** argv);
void simple_net(engine::kind engine_kind);
int main(int argc, char** argv);
int main(int argc, char** argv);
void sycl_usm_tutorial(engine::kind engine_kind);
int main(int argc, char** argv);

void quantize(
    const std::vector<float>& X_f32,
    float scale_X,
    int32_t zp_X,
    memory& X_int_m
    );

void f32_matmul_compute(
    int64_t M,
    int64_t N,
    int64_t K,
    const std::vector<float>& A_f32,
    const std::vector<float>& B_f32,
    std::vector<float>& C_f32
    );

void dynamic_q10n_matmul(
    int64_t M,
    int64_t N,
    int64_t K,
    const std::vector<float>& A_f32,
    const std::vector<float>& B_f32,
    std::vector<uint8_t>& C_u8,
    float& scale_C,
    int32_t& zp_C
    );

void compare_f32_and_quantized_matmuls();
int main(int argc, char** argv);
matmul dynamic_matmul_create();

void dynamic_matmul_execute(
    matmul& matmul_p,
    char transA,
    char transB,
    int64_t M,
    int64_t N,
    int64_t K,
    float alpha,
    const float* A,
    int64_t lda,
    const float* B,
    int64_t ldb,
    float beta,
    float* C,
    int64_t ldc
    );

void sgemm_and_matmul_with_params(
    char transA,
    char transB,
    int64_t M,
    int64_t N,
    int64_t K,
    float alpha,
    float beta
    );

void sgemm_and_matmul();
int main(int argc, char** argv);
matmul::primitive_desc matmul_pd_create(int64_t K, int64_t N, const engine& eng);

void prepare_input(
    memory& A_u8_mem,
    memory& sc_A_mem,
    memory& sc_B_mem,
    memory& sc_C_mem,
    memory& zp_A_mem,
    memory& zp_C_mem
    );

void sanity_check(memory& C_u8_mem, memory& zp_C_mem);

void infer(
    const matmul& matmul_p,
    int64_t M,
    int64_t N,
    int64_t K,
    const memory& B_s8_mem,
    const engine& eng
    );

void inference_int8_matmul(engine::kind engine_kind);
int main(int argc, char** argv);

matmul::primitive_desc matmul_pd_create(
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t G,
    const engine& eng
    );

void prepare_input(memory& A_f32_mem, memory& sc_B_mem, memory& zp_B_mem);

void infer(
    const matmul& matmul_p,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t G,
    const memory& B_s8_mem,
    const engine& eng
    );

void weights_decompression_matmul(engine::kind engine_kind);
int main(int argc, char** argv);
void brgemm_example();
int main(int argc, char** argv);
dnnl_status_t DNNL_API dnnl_primitive_desc_next_impl(dnnl_primitive_desc_t primitive_desc);

dnnl_status_t DNNL_API dnnl_primitive_desc_clone(
    dnnl_primitive_desc_t* primitive_desc,
    const_dnnl_primitive_desc_t existing_primitive_desc
    );

dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(
    const_dnnl_primitive_desc_t primitive_desc,
    const_dnnl_primitive_attr_t* attr
    );

dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(dnnl_primitive_desc_t primitive_desc);

dnnl_status_t DNNL_API dnnl_primitive_desc_query(
    const_dnnl_primitive_desc_t primitive_desc,
    dnnl_query_t what,
    int index,
    void* result
    );

const_dnnl_memory_desc_t DNNL_API dnnl_primitive_desc_query_md(
    const_dnnl_primitive_desc_t primitive_desc,
    dnnl_query_t what,
    int index
    );

int DNNL_API dnnl_primitive_desc_query_s32(
    const_dnnl_primitive_desc_t primitive_desc,
    dnnl_query_t what,
    int index
    );

dnnl_status_t DNNL_API dnnl_primitive_create(
    dnnl_primitive_t* primitive,
    const_dnnl_primitive_desc_t primitive_desc
    );

dnnl_status_t DNNL_API dnnl_primitive_create_from_cache_blob(
    dnnl_primitive_t* primitive,
    const_dnnl_primitive_desc_t primitive_desc,
    size_t size,
    const uint8_t* cache_blob
    );

dnnl_status_t DNNL_API dnnl_primitive_execute(
    const_dnnl_primitive_t primitive,
    dnnl_stream_t stream,
    int nargs,
    const dnnl_exec_arg_t* args
    );

dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(
    const_dnnl_primitive_t primitive,
    const_dnnl_primitive_desc_t* primitive_desc
    );

dnnl_status_t DNNL_API dnnl_primitive_get_cache_blob(
    const_dnnl_primitive_t primitive,
    size_t* size,
    uint8_t* cache_blob
    );

dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive);
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t* attr);

dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
    dnnl_primitive_attr_t* attr,
    const_dnnl_primitive_attr_t existing_attr
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);

dnnl_status_t DNNL_API dnnl_primitive_attr_get_dropout(
    const_dnnl_primitive_attr_t attr,
    const_dnnl_memory_desc_t* dropout_desc
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_dropout(
    dnnl_primitive_attr_t attr,
    const_dnnl_memory_desc_t dropout_desc
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode(
    const_dnnl_primitive_attr_t attr,
    dnnl_fpmath_mode_t* mode
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode(
    dnnl_primitive_attr_t attr,
    dnnl_fpmath_mode_t mode
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode_v2(
    const_dnnl_primitive_attr_t attr,
    dnnl_fpmath_mode_t* mode,
    int* apply_to_int
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode_v2(
    dnnl_primitive_attr_t attr,
    dnnl_fpmath_mode_t mode,
    int apply_to_int
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_get_deterministic(
    const_dnnl_primitive_attr_t attr,
    int* value
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_deterministic(
    dnnl_primitive_attr_t attr,
    int value
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_get_accumulation_mode(
    const_dnnl_primitive_attr_t attr,
    dnnl_accumulation_mode_t* mode
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_accumulation_mode(
    dnnl_primitive_attr_t attr,
    dnnl_accumulation_mode_t mode
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(
    const_dnnl_primitive_attr_t attr,
    dnnl_scratchpad_mode_t* mode
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
    dnnl_primitive_attr_t attr,
    dnnl_scratchpad_mode_t mode
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
    dnnl_primitive_attr_t attr,
    int arg,
    int mask
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(
    dnnl_primitive_attr_t attr,
    int arg,
    int mask,
    int ndims,
    const dnnl_dims_t group_dims,
    dnnl_data_type_t data_type
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask(
    dnnl_primitive_attr_t attr,
    int arg,
    int mask
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
    dnnl_primitive_attr_t attr,
    int arg,
    int mask,
    int ndims,
    const dnnl_dims_t group_dims,
    dnnl_data_type_t data_type
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_rounding(
    dnnl_primitive_attr_t attr,
    int arg,
    dnnl_rounding_mode_t mode
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_get_rounding(
    dnnl_primitive_attr_t attr,
    int arg,
    dnnl_rounding_mode_t* mode
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(
    const_dnnl_primitive_attr_t attr,
    const_dnnl_post_ops_t* post_ops
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(
    dnnl_primitive_attr_t attr,
    const_dnnl_post_ops_t post_ops
    );

// Post-ops object lifecycle: create, clone from an existing object, destroy.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t* post_ops);

dnnl_status_t DNNL_API dnnl_post_ops_clone(
    dnnl_post_ops_t* post_ops,
    const_dnnl_post_ops_t existing_post_ops
    );

dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops);
// Note the return types below: these two return the value directly (an int
// and a dnnl_primitive_kind_t) instead of a dnnl_status_t.
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops);

dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(
    const_dnnl_post_ops_t post_ops,
    int index
    );

// Sum post-op: append with (scale, zero_point, data_type), and read the same
// three parameters back for the entry at `index` through pointers.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(
    dnnl_post_ops_t post_ops,
    float scale,
    int32_t zero_point,
    dnnl_data_type_t data_type
    );

dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(
    const_dnnl_post_ops_t post_ops,
    int index,
    float* scale,
    int32_t* zero_point,
    dnnl_data_type_t* data_type
    );

// Element-wise post-op: append with (alg_kind, alpha, beta), and read those
// parameters back for the entry at `index`.
// NOTE(review): the meaning of alpha/beta depends on alg_kind and is not
// visible in this listing.
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(
    dnnl_post_ops_t post_ops,
    dnnl_alg_kind_t alg_kind,
    float alpha,
    float beta
    );

dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
    const_dnnl_post_ops_t post_ops,
    int index,
    dnnl_alg_kind_t* alg_kind,
    float* alpha,
    float* beta
    );

// "dw" post-op (presumably a depthwise convolution — confirm in the oneDNN
// docs): append with three data types and kernel/stride/left-padding sizes,
// and read the same six parameters back for the entry at `index`.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw(
    dnnl_post_ops_t post_ops,
    dnnl_data_type_t weights_data_type,
    dnnl_data_type_t bias_data_type,
    dnnl_data_type_t dst_data_type,
    dnnl_dim_t kernel_size,
    dnnl_dim_t stride_size,
    dnnl_dim_t padding_l_size
    );

dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
    const_dnnl_post_ops_t post_ops,
    int index,
    dnnl_data_type_t* weights_data_type,
    dnnl_data_type_t* bias_data_type,
    dnnl_data_type_t* dst_data_type,
    dnnl_dim_t* kernel_size,
    dnnl_dim_t* stride_size,
    dnnl_dim_t* padding_l_size
    );

// Binary post-op: append with an algorithm kind and the memory descriptor of
// the second source (`src1_desc`); the getter returns both for the entry at
// `index`, the descriptor as a const handle.
dnnl_status_t DNNL_API dnnl_post_ops_append_binary(
    dnnl_post_ops_t post_ops,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src1_desc
    );

dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
    const_dnnl_post_ops_t post_ops,
    int index,
    dnnl_alg_kind_t* alg_kind,
    const_dnnl_memory_desc_t* src1_desc
    );

// PReLU post-op: append with an integer `mask`, and read the mask back for
// the entry at `index`.
dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
    dnnl_post_ops_t post_ops,
    int mask
    );

dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
    const_dnnl_post_ops_t post_ops,
    int index,
    int* mask
    );

// Memory-descriptor lifecycle and serialization.
dnnl_status_t DNNL_API dnnl_memory_desc_destroy(dnnl_memory_desc_t memory_desc);

dnnl_status_t DNNL_API dnnl_memory_desc_clone(
    dnnl_memory_desc_t* memory_desc,
    const_dnnl_memory_desc_t existing_memory_desc
    );

// Serializes a descriptor into a byte buffer; `size` is a pointer.
// NOTE(review): presumably `blob` may be null to query the required size
// first — not visible here, confirm in the oneDNN docs.
dnnl_status_t DNNL_API dnnl_memory_desc_get_blob(
    uint8_t* blob,
    size_t* size,
    const_dnnl_memory_desc_t memory_desc
    );

// Reconstructs a descriptor from a byte buffer. No length parameter is
// passed, so the blob must carry whatever size information is needed.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_blob(
    dnnl_memory_desc_t* memory_desc,
    const uint8_t* blob
    );

// Memory-descriptor constructors. All share the leading parameters
// (memory_desc out-pointer, ndims, dims, data_type) and differ in how the
// layout is specified.

// Layout given by explicit per-dimension strides.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_strides(
    dnnl_memory_desc_t* memory_desc,
    int ndims,
    const dnnl_dims_t dims,
    dnnl_data_type_t data_type,
    const dnnl_dims_t strides
    );

// Layout given by a format tag.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_tag(
    dnnl_memory_desc_t* memory_desc,
    int ndims,
    const dnnl_dims_t dims,
    dnnl_data_type_t data_type,
    dnnl_format_tag_t tag
    );

// CSR encoding: takes the count `nnz` plus data types for indices and
// pointers.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_csr_encoding(
    dnnl_memory_desc_t* memory_desc,
    int ndims,
    const dnnl_dims_t dims,
    dnnl_data_type_t data_type,
    dnnl_dim_t nnz,
    dnnl_data_type_t indices_dt,
    dnnl_data_type_t pointers_dt
    );

// COO encoding: takes `nnz` and a data type for indices only.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_coo_encoding(
    dnnl_memory_desc_t* memory_desc,
    int ndims,
    const dnnl_dims_t dims,
    dnnl_data_type_t data_type,
    dnnl_dim_t nnz,
    dnnl_data_type_t indices_dt
    );

// Packed encoding: takes `nnz` only.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_packed_encoding(
    dnnl_memory_desc_t* memory_desc,
    int ndims,
    const dnnl_dims_t dims,
    dnnl_data_type_t data_type,
    dnnl_dim_t nnz
    );

// Descriptors derived from an existing descriptor. Each takes the source as
// a const handle and produces a new descriptor through an out-pointer.

// Sub-region described by `dims` and `offsets` within a parent descriptor.
dnnl_status_t DNNL_API dnnl_memory_desc_create_submemory(
    dnnl_memory_desc_t* memory_desc,
    const_dnnl_memory_desc_t parent_memory_desc,
    const dnnl_dims_t dims,
    const dnnl_dims_t offsets
    );

// Same input descriptor viewed with a new `ndims`/`dims`.
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
    dnnl_memory_desc_t* out_memory_desc,
    const_dnnl_memory_desc_t in_memory_desc,
    int ndims,
    const dnnl_dims_t dims
    );

// Axis permutation given as an int array.
// NOTE(review): the array length is not passed; presumably it must match the
// input descriptor's ndims — confirm.
dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
    dnnl_memory_desc_t* out_memory_desc,
    const_dnnl_memory_desc_t in_memory_desc,
    const int* permutation
    );

// Generic descriptor query: `what` selects the property and the answer is
// written through the untyped `result` pointer, so the caller must supply
// storage of the type that matches `what`.
dnnl_status_t DNNL_API dnnl_memory_desc_query(
    const_dnnl_memory_desc_t memory_desc,
    dnnl_query_t what,
    void* result
    );

// v2 adds an integer `index`.
// NOTE(review): presumably selects a buffer of a multi-buffer (e.g. sparse)
// descriptor — confirm in the oneDNN docs.
dnnl_status_t DNNL_API dnnl_memory_desc_query_v2(
    const_dnnl_memory_desc_t memory_desc,
    dnnl_query_t what,
    int index,
    void* result
    );

// Returns an int rather than a status code.
// NOTE(review): presumably non-zero means equal — confirm.
int DNNL_API dnnl_memory_desc_equal(
    const_dnnl_memory_desc_t lhs,
    const_dnnl_memory_desc_t rhs
    );

// Size helpers return size_t directly (no status code).
size_t DNNL_API dnnl_memory_desc_get_size(const_dnnl_memory_desc_t memory_desc);

size_t DNNL_API dnnl_memory_desc_get_size_v2(
    const_dnnl_memory_desc_t memory_desc,
    int index
    );

size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type);

// Creates a memory object from a descriptor, an engine and a single untyped
// `handle`.
// NOTE(review): special handle values (e.g. "allocate" / "none" sentinels)
// are not visible in this listing — see the oneDNN docs.
dnnl_status_t DNNL_API dnnl_memory_create(
    dnnl_memory_t* memory,
    const_dnnl_memory_desc_t memory_desc,
    dnnl_engine_t engine,
    void* handle
    );

// v2 takes an array of `nhandles` handles instead of a single one.
dnnl_status_t DNNL_API dnnl_memory_create_v2(
    dnnl_memory_t* memory,
    const_dnnl_memory_desc_t memory_desc,
    dnnl_engine_t engine,
    int nhandles,
    void** handles
    );

// Accessors for the descriptor and engine associated with a memory object.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(
    const_dnnl_memory_t memory,
    const_dnnl_memory_desc_t* memory_desc
    );

dnnl_status_t DNNL_API dnnl_memory_get_engine(
    const_dnnl_memory_t memory,
    dnnl_engine_t* engine
    );

// Map/unmap pairs for a memory object. Mapping yields a pointer through
// `mapped_ptr`; unmapping takes that pointer back. The `_v2` variants add an
// integer `index`.
// All four take a const memory handle.
dnnl_status_t DNNL_API dnnl_memory_map_data(
    const_dnnl_memory_t memory,
    void** mapped_ptr
    );

dnnl_status_t DNNL_API dnnl_memory_map_data_v2(
    const_dnnl_memory_t memory,
    void** mapped_ptr,
    int index
    );

dnnl_status_t DNNL_API dnnl_memory_unmap_data(
    const_dnnl_memory_t memory,
    void* mapped_ptr
    );

dnnl_status_t DNNL_API dnnl_memory_unmap_data_v2(
    const_dnnl_memory_t memory,
    void* mapped_ptr,
    int index
    );

// Data-handle accessors for a memory object. Getters take a const memory
// handle and a `void**`; setters take a mutable memory handle and a `void*`.
// The `_v2` variants add an integer `index`.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(
    const_dnnl_memory_t memory,
    void** handle
    );

dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
    dnnl_memory_t memory,
    void* handle
    );

dnnl_status_t DNNL_API dnnl_memory_get_data_handle_v2(
    const_dnnl_memory_t memory,
    void** handle,
    int index
    );

dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(
    dnnl_memory_t memory,
    void* handle,
    int index
    );

// Destroys a memory object.
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);

// Reorder primitive descriptor: source and destination each have their own
// memory descriptor and engine.
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
    dnnl_primitive_desc_t* reorder_primitive_desc,
    const_dnnl_memory_desc_t src_desc,
    dnnl_engine_t src_engine,
    const_dnnl_memory_desc_t dst_desc,
    dnnl_engine_t dst_engine,
    const_dnnl_primitive_attr_t attr
    );

// Concat primitive descriptor: `n` source descriptors passed as an array,
// joined along `concat_dimension`.
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
    dnnl_primitive_desc_t* concat_primitive_desc,
    dnnl_engine_t engine,
    const_dnnl_memory_desc_t dst_desc,
    int n,
    int concat_dimension,
    const_dnnl_memory_desc_t const* src_descs,
    const_dnnl_primitive_attr_t attr
    );

// Sum primitive descriptor: `n` source descriptors with a parallel `scales`
// array (presumably also of length n — confirm).
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
    dnnl_primitive_desc_t* sum_primitive_desc,
    dnnl_engine_t engine,
    const_dnnl_memory_desc_t dst_desc,
    int n,
    const float* scales,
    const_dnnl_memory_desc_t const* src_descs,
    const_dnnl_primitive_attr_t attr
    );

// Binary primitive descriptor: an algorithm kind applied to two sources
// (src0, src1) producing dst.
dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src0_desc,
    const_dnnl_memory_desc_t src1_desc,
    const_dnnl_memory_desc_t dst_desc,
    const_dnnl_primitive_attr_t attr
    );

// Convolution primitive descriptors. All three take strides, dilates and
// left/right padding as dnnl_dims_t arrays. Only the forward variant takes a
// prop_kind; the two backward variants instead take a forward primitive
// descriptor hint (`hint_fwd_pd`).

// Forward: src, weights, bias -> dst.
dnnl_status_t DNNL_API dnnl_convolution_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t weights_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_desc,
    const dnnl_dims_t strides,
    const dnnl_dims_t dilates,
    const dnnl_dims_t padding_l,
    const dnnl_dims_t padding_r,
    const_dnnl_primitive_attr_t attr
    );

// Backward w.r.t. data: diff_dst, weights -> diff_src (no bias descriptor).
dnnl_status_t DNNL_API dnnl_convolution_backward_data_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t weights_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const dnnl_dims_t strides,
    const dnnl_dims_t dilates,
    const dnnl_dims_t padding_l,
    const dnnl_dims_t padding_r,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Backward w.r.t. weights: src, diff_dst -> diff_weights, diff_bias.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t diff_weights_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const dnnl_dims_t strides,
    const dnnl_dims_t dilates,
    const dnnl_dims_t padding_l,
    const dnnl_dims_t padding_r,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Deconvolution primitive descriptors. The parameter lists are identical in
// shape to the convolution creators in this listing (forward / backward-data
// / backward-weights).
dnnl_status_t DNNL_API dnnl_deconvolution_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t weights_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_desc,
    const dnnl_dims_t strides,
    const dnnl_dims_t dilates,
    const dnnl_dims_t padding_l,
    const dnnl_dims_t padding_r,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t weights_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const dnnl_dims_t strides,
    const dnnl_dims_t dilates,
    const dnnl_dims_t padding_l,
    const dnnl_dims_t padding_r,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t diff_weights_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const dnnl_dims_t strides,
    const dnnl_dims_t dilates,
    const dnnl_dims_t padding_l,
    const dnnl_dims_t padding_r,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Shuffle primitive descriptors, parameterized by an `axis` and a
// `group_size`. The backward variant works on diff_src/diff_dst and takes a
// forward primitive descriptor hint instead of a prop_kind.
dnnl_status_t DNNL_API dnnl_shuffle_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    int axis,
    dnnl_dim_t group_size,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_shuffle_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    int axis,
    dnnl_dim_t group_size,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Element-wise primitive descriptors, parameterized by alg_kind and two
// floats (alpha, beta).
dnnl_status_t DNNL_API dnnl_eltwise_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    float alpha,
    float beta,
    const_dnnl_primitive_attr_t attr
    );

// Backward variant additionally takes `data_desc`.
// NOTE(review): whether `data_desc` describes the forward source or the
// forward destination likely depends on alg_kind — confirm in the oneDNN docs.
dnnl_status_t DNNL_API dnnl_eltwise_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_memory_desc_t data_desc,
    float alpha,
    float beta,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Softmax primitive descriptors, parameterized by alg_kind and the axis
// (`softmax_axis`). The backward variant takes the forward destination
// descriptor (`dst_desc`) alongside diff_src/diff_dst.
dnnl_status_t DNNL_API dnnl_softmax_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    int softmax_axis,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_softmax_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_memory_desc_t dst_desc,
    int softmax_axis,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Pooling primitive descriptors. Geometry is given by five dnnl_dims_t
// arrays: strides, kernel, dilation, padding_l, padding_r.
dnnl_status_t DNNL_API dnnl_pooling_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    const dnnl_dims_t strides,
    const dnnl_dims_t kernel,
    const dnnl_dims_t dilation,
    const dnnl_dims_t padding_l,
    const dnnl_dims_t padding_r,
    const_dnnl_primitive_attr_t attr
    );

// Backward variant: diff descriptors plus a forward primitive descriptor
// hint.
dnnl_status_t DNNL_API dnnl_pooling_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const dnnl_dims_t strides,
    const dnnl_dims_t kernel,
    const dnnl_dims_t dilation,
    const dnnl_dims_t padding_l,
    const dnnl_dims_t padding_r,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// PReLU primitive descriptors. Forward: src, weights -> dst.
dnnl_status_t DNNL_API dnnl_prelu_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t weights_desc,
    const_dnnl_memory_desc_t dst_desc,
    const_dnnl_primitive_attr_t attr
    );

// Backward takes both the forward tensors (src, weights) and the three diff
// descriptors (diff_src, diff_weights, diff_dst).
dnnl_status_t DNNL_API dnnl_prelu_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t weights_desc,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_weights_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// LRN primitive descriptors, parameterized by `local_size` and three floats
// (alpha, beta, k). The backward variant also takes the forward `src_desc`.
dnnl_status_t DNNL_API dnnl_lrn_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    dnnl_dim_t local_size,
    float alpha,
    float beta,
    float k,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_lrn_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_memory_desc_t src_desc,
    dnnl_dim_t local_size,
    float alpha,
    float beta,
    float k,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Batch-normalization primitive descriptors, parameterized by `epsilon` and
// an unsigned `flags` bit field.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    float epsilon,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

// Unlike the convolution/pooling backward creators in this listing, the
// normalization backward creators take a prop_kind as well as hint_fwd_pd.
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_memory_desc_t src_desc,
    float epsilon,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Group-normalization primitive descriptors. Same shape as the
// batch-normalization creators plus a `groups` count.
dnnl_status_t DNNL_API dnnl_group_normalization_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    dnnl_dim_t groups,
    float epsilon,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_group_normalization_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_memory_desc_t src_desc,
    dnnl_dim_t groups,
    float epsilon,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Layer-normalization primitive descriptors. In addition to src/dst these
// take a separate statistics descriptor (`stat_desc`).
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    const_dnnl_memory_desc_t stat_desc,
    float epsilon,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_layer_normalization_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t stat_desc,
    float epsilon,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// v2 forward: adds an explicit `scale_shift_data_type`.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_primitive_desc_create_v2(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    const_dnnl_memory_desc_t stat_desc,
    dnnl_data_type_t scale_shift_data_type,
    float epsilon,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

// v2 backward: adds both `diff_scale_shift_data_type` and
// `scale_shift_data_type` (in that order).
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_primitive_desc_create_v2(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t stat_desc,
    dnnl_data_type_t diff_scale_shift_data_type,
    dnnl_data_type_t scale_shift_data_type,
    float epsilon,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Inner-product primitive descriptors (forward / backward-data /
// backward-weights). None of them takes an alg_kind or geometry arrays.

// Forward: src, weights, bias -> dst.
dnnl_status_t DNNL_API dnnl_inner_product_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t weights_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_desc,
    const_dnnl_primitive_attr_t attr
    );

// Backward w.r.t. data: diff_dst, weights -> diff_src.
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t weights_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Backward w.r.t. weights: src, diff_dst -> diff_weights, diff_bias.
dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t diff_weights_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// RNN quantization parameters stored on a primitive-attributes object.

// Data qparams: a single (scale, shift) pair.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(
    dnnl_primitive_attr_t attr,
    const float scale,
    const float shift
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(
    const_dnnl_primitive_attr_t attr,
    float* scale,
    float* shift
    );

// Weights qparams: `count` scales with an integer `mask`.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(
    dnnl_primitive_attr_t attr,
    dnnl_dim_t count,
    int mask,
    const float* scales
    );

// The getter returns the scales as `const float**`, i.e. a pointer to
// read-only data rather than a copy into a caller buffer.
// NOTE(review): lifetime of that pointer is not visible here — presumably
// tied to `attr`; confirm.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(
    const_dnnl_primitive_attr_t attr,
    dnnl_dim_t* count,
    int* mask,
    const float** scales
    );

// Projection-weights qparams: same signatures as the weights qparams pair.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(
    dnnl_primitive_attr_t attr,
    dnnl_dim_t count,
    int mask,
    const float* scales
    );

dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(
    const_dnnl_primitive_attr_t attr,
    dnnl_dim_t* count,
    int* mask,
    const float** scales
    );

// Vanilla RNN primitive descriptors. These are the only RNN creators in this
// listing that take an `activation` algorithm kind and alpha/beta floats.
// Tensors come in layer/iter pairs for src, weights and dst, plus a bias.
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const dnnl_alg_kind_t activation,
    const dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    float alpha,
    float beta,
    const_dnnl_primitive_attr_t attr
    );

// Backward: the forward descriptors followed by a matching diff_* descriptor
// for each of them, then flags/alpha/beta and the forward hint.
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const dnnl_alg_kind_t activation,
    const dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    float alpha,
    float beta,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// LSTM primitive descriptors. Compared with the GRU creators in this listing
// they add src_iter_c/dst_iter_c descriptors and peephole/projection weight
// descriptors.
// NOTE(review): which descriptors are optional (and how "absent" is
// expressed) is not visible in this listing — see the oneDNN docs.
dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t src_iter_c_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t weights_peephole_desc,
    const_dnnl_memory_desc_t weights_projection_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t dst_iter_c_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

// Backward: every forward descriptor is followed later in the list by its
// diff_* counterpart, in the same order.
dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t src_iter_c_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t weights_peephole_desc,
    const_dnnl_memory_desc_t weights_projection_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t dst_iter_c_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_src_iter_c_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_weights_peephole_desc,
    const_dnnl_memory_desc_t diff_weights_projection_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    const_dnnl_memory_desc_t diff_dst_iter_c_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// GRU primitive descriptors: layer/iter pairs for src, weights and dst, plus
// a bias; no activation or alpha/beta parameters.
dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

// Backward: forward descriptors followed by their diff_* counterparts and
// the forward primitive descriptor hint.
dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// LBR GRU primitive descriptors. The parameter lists match the plain GRU
// creators in this listing one-for-one.
// NOTE(review): "lbr" presumably stands for linear-before-reset; any
// difference in expected descriptor shapes (e.g. bias) is not visible here.
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// AUGRU primitive descriptors. Same as the GRU creators plus an
// `attention_desc` placed right after src_iter_desc.
dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t attention_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

// Backward: also carries a `diff_attention_desc` in the matching position
// among the diff_* descriptors.
dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t attention_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_attention_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// LBR AUGRU primitive descriptors. The parameter lists match the AUGRU
// creators in this listing one-for-one (including attention_desc and
// diff_attention_desc).
dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t attention_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t attention_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_attention_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Matmul primitive descriptor: src, weights, bias -> dst. There is no
// prop_kind parameter and no backward counterpart in this listing.
dnnl_status_t DNNL_API dnnl_matmul_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t weights_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_desc,
    const_dnnl_primitive_attr_t attr
    );

// Resampling primitive descriptors. Scaling is given by a `factors` float
// array whose length is not passed.
// NOTE(review): presumably one factor per spatial dimension, and possibly
// nullable when dst_desc fully determines the scaling — confirm in the
// oneDNN docs.
dnnl_status_t DNNL_API dnnl_resampling_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_alg_kind_t alg_kind,
    const float* factors,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_resampling_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const float* factors,
    const_dnnl_memory_desc_t diff_src_desc,
    const_dnnl_memory_desc_t diff_dst_desc,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

// Reduction primitive descriptor, parameterized by alg_kind and two floats
// `p` and `eps`.
// NOTE(review): presumably `p`/`eps` only matter for norm-style algorithms —
// confirm in the oneDNN docs.
dnnl_status_t DNNL_API dnnl_reduction_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_alg_kind_t alg_kind,
    const_dnnl_memory_desc_t src_desc,
    const_dnnl_memory_desc_t dst_desc,
    float p,
    float eps,
    const_dnnl_primitive_attr_t attr
    );

// Library-wide service functions (none of these takes an engine or a
// primitive handle).

// Primitive cache capacity get/set.
dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int* capacity);
dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity);
// JIT dump / JIT profiling controls.
dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable);
dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags);
dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char* dir);
// CPU ISA controls. The two getters return the value directly instead of a
// status code.
dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa);
dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void);
dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints);
dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void);
// Per-stream profiling: reset, then query.
dnnl_status_t DNNL_API dnnl_reset_profiling(dnnl_stream_t stream);

// `num_entries` is a pointer and `data` is a uint64_t buffer.
// NOTE(review): presumably `data` may be null to query the entry count first
// — confirm in the oneDNN docs.
dnnl_status_t DNNL_API dnnl_query_profiling_data(
    dnnl_stream_t stream,
    dnnl_profiling_data_kind_t data_kind,
    int* num_entries,
    uint64_t* data
    );

// GEMM entry points. The argument order (transa/transb, M/N/K, alpha, A/lda,
// B/ldb, beta, C/ldc) looks like the BLAS GEMM convention.
// NOTE(review): row- vs column-major interpretation is not visible in this
// listing and is easy to get wrong — confirm in the oneDNN docs.

// Single-precision: float A, B and C.
dnnl_status_t DNNL_API dnnl_sgemm(
    char transa,
    char transb,
    dnnl_dim_t M,
    dnnl_dim_t N,
    dnnl_dim_t K,
    float alpha,
    const float* A,
    dnnl_dim_t lda,
    const float* B,
    dnnl_dim_t ldb,
    float beta,
    float* C,
    dnnl_dim_t ldc
    );

// Integer GEMM: uint8_t A, int8_t B, int32_t C. Adds per-matrix offsets
// (ao, bo), an `offsetc` selector character and a `co` offset array for C.
dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(
    char transa,
    char transb,
    char offsetc,
    dnnl_dim_t M,
    dnnl_dim_t N,
    dnnl_dim_t K,
    float alpha,
    const uint8_t* A,
    dnnl_dim_t lda,
    uint8_t ao,
    const int8_t* B,
    dnnl_dim_t ldb,
    int8_t bo,
    float beta,
    int32_t* C,
    dnnl_dim_t ldc,
    const int32_t* co
    );

// Integer GEMM: int8_t A and B, int32_t C. Otherwise the same parameter list
// as dnnl_gemm_u8s8s32 (with `ao` as int8_t).
dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(
    char transa,
    char transb,
    char offsetc,
    dnnl_dim_t M,
    dnnl_dim_t N,
    dnnl_dim_t K,
    float alpha,
    const int8_t* A,
    dnnl_dim_t lda,
    int8_t ao,
    const int8_t* B,
    dnnl_dim_t ldb,
    int8_t bo,
    float beta,
    int32_t* C,
    dnnl_dim_t ldc,
    const int32_t* co
    );

const char DNNL_API* dnnl_status2str(dnnl_status_t v);
const char DNNL_API* dnnl_dt2str(dnnl_data_type_t v);
const char DNNL_API* dnnl_fpmath_mode2str(dnnl_fpmath_mode_t v);
const char DNNL_API* dnnl_accumulation_mode2str(dnnl_accumulation_mode_t v);
const char DNNL_API* dnnl_engine_kind2str(dnnl_engine_kind_t v);
const char DNNL_API* dnnl_sparse_encoding2str(dnnl_sparse_encoding_t v);
const char DNNL_API* dnnl_fmt_tag2str(dnnl_format_tag_t v);
const char DNNL_API* dnnl_prop_kind2str(dnnl_prop_kind_t v);
const char DNNL_API* dnnl_prim_kind2str(dnnl_primitive_kind_t v);
const char DNNL_API* dnnl_alg_kind2str(dnnl_alg_kind_t v);
const char DNNL_API* dnnl_rnn_flags2str(dnnl_rnn_flags_t v);
const char DNNL_API* dnnl_rnn_direction2str(dnnl_rnn_direction_t v);
const char DNNL_API* dnnl_scratchpad_mode2str(dnnl_scratchpad_mode_t v);
const char DNNL_API* dnnl_rounding_mode2str(dnnl_rounding_mode_t v);
const char DNNL_API* dnnl_cpu_isa2str(dnnl_cpu_isa_t v);
const char DNNL_API* dnnl_cpu_isa_hints2str(dnnl_cpu_isa_hints_t v);
const char DNNL_API* dnnl_runtime2str(unsigned v);
const char DNNL_API* dnnl_fmt_kind2str(dnnl_format_kind_t v);

dnnl_status_t DNNL_API dnnl_ocl_interop_memory_create(
    dnnl_memory_t* memory,
    const_dnnl_memory_desc_t memory_desc,
    dnnl_engine_t engine,
    dnnl_ocl_interop_memory_kind_t memory_kind,
    void* handle
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_memory_create_v2(
    dnnl_memory_t* memory,
    const_dnnl_memory_desc_t memory_desc,
    dnnl_engine_t engine,
    dnnl_ocl_interop_memory_kind_t memory_kind,
    int nhandles,
    void** handles
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_memory_get_memory_kind(
    const_dnnl_memory_t memory,
    dnnl_ocl_interop_memory_kind_t* memory_kind
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_memory_get_mem_object(
    const_dnnl_memory_t memory,
    cl_mem* mem_object
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_memory_set_mem_object(
    dnnl_memory_t memory,
    cl_mem mem_object
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_engine_get_cache_blob_id(
    cl_device_id device,
    size_t* size,
    uint8_t* cache_blob_id
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_engine_get_cache_blob(
    dnnl_engine_t engine,
    size_t* size,
    uint8_t* cache_blob
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_engine_create_from_cache_blob(
    dnnl_engine_t* engine,
    cl_device_id device,
    cl_context context,
    size_t size,
    const uint8_t* cache_blob
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_engine_create(
    dnnl_engine_t* engine,
    cl_device_id device,
    cl_context context
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_engine_get_context(
    dnnl_engine_t engine,
    cl_context* context
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_get_device(
    dnnl_engine_t engine,
    cl_device_id* device
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_stream_create(
    dnnl_stream_t* stream,
    dnnl_engine_t engine,
    cl_command_queue queue
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_stream_get_command_queue(
    dnnl_stream_t stream,
    cl_command_queue* queue
    );

dnnl_status_t DNNL_API dnnl_ocl_interop_primitive_execute(
    const_dnnl_primitive_t primitive,
    dnnl_stream_t stream,
    int nargs,
    const dnnl_exec_arg_t* args,
    const cl_event* deps,
    int ndeps,
    cl_event* return_event
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_engine_create(
    dnnl_engine_t* engine,
    const void* device,
    const void* context
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_engine_get_context(
    dnnl_engine_t engine,
    void** context
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_engine_get_device(
    dnnl_engine_t engine,
    void** device
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_memory_create(
    dnnl_memory_t* memory,
    const_dnnl_memory_desc_t memory_desc,
    dnnl_engine_t engine,
    dnnl_sycl_interop_memory_kind_t memory_kind,
    void* handle
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_memory_create_v2(
    dnnl_memory_t* memory,
    const_dnnl_memory_desc_t memory_desc,
    dnnl_engine_t engine,
    dnnl_sycl_interop_memory_kind_t memory_kind,
    int nhandles,
    void** handles
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_memory_get_memory_kind(
    const_dnnl_memory_t memory,
    dnnl_sycl_interop_memory_kind_t* memory_kind
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_memory_set_buffer(
    dnnl_memory_t memory,
    void* buffer
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_stream_create(
    dnnl_stream_t* stream,
    dnnl_engine_t engine,
    void* queue
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_stream_get_queue(
    dnnl_stream_t stream,
    void** queue
    );

dnnl_status_t DNNL_API dnnl_sycl_interop_primitive_execute(
    const_dnnl_primitive_t primitive,
    dnnl_stream_t stream,
    int nargs,
    const dnnl_exec_arg_t* args,
    const void* deps,
    void* return_event
    );

dnnl_status_t DNNL_API dnnl_threadpool_interop_stream_create(
    dnnl_stream_t* stream,
    dnnl_engine_t engine,
    void* threadpool
    );

dnnl_status_t DNNL_API dnnl_threadpool_interop_stream_get_threadpool(
    dnnl_stream_t astream,
    void** threadpool
    );

dnnl_status_t DNNL_API dnnl_threadpool_interop_set_max_concurrency(int max_concurrency);
dnnl_status_t DNNL_API dnnl_threadpool_interop_get_max_concurrency(int* max_concurrency);

dnnl_status_t DNNL_API dnnl_threadpool_interop_sgemm(
    char transa,
    char transb,
    dnnl_dim_t M,
    dnnl_dim_t N,
    dnnl_dim_t K,
    float alpha,
    const float* A,
    dnnl_dim_t lda,
    const float* B,
    dnnl_dim_t ldb,
    float beta,
    float* C,
    dnnl_dim_t ldc,
    void* threadpool
    );

dnnl_status_t DNNL_API dnnl_threadpool_interop_gemm_u8s8s32(
    char transa,
    char transb,
    char offsetc,
    dnnl_dim_t M,
    dnnl_dim_t N,
    dnnl_dim_t K,
    float alpha,
    const uint8_t* A,
    dnnl_dim_t lda,
    uint8_t ao,
    const int8_t* B,
    dnnl_dim_t ldb,
    int8_t bo,
    float beta,
    int32_t* C,
    dnnl_dim_t ldc,
    const int32_t* co,
    void* threadpool
    );

dnnl_status_t DNNL_API dnnl_threadpool_interop_gemm_s8s8s32(
    char transa,
    char transb,
    char offsetc,
    dnnl_dim_t M,
    dnnl_dim_t N,
    dnnl_dim_t K,
    float alpha,
    const int8_t* A,
    dnnl_dim_t lda,
    int8_t ao,
    const int8_t* B,
    dnnl_dim_t ldb,
    int8_t bo,
    float beta,
    int32_t* C,
    dnnl_dim_t ldc,
    const int32_t* co,
    void* threadpool
    );

size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind);

dnnl_status_t DNNL_API dnnl_engine_create(
    dnnl_engine_t* engine,
    dnnl_engine_kind_t kind,
    size_t index
    );

dnnl_status_t DNNL_API dnnl_engine_get_kind(
    dnnl_engine_t engine,
    dnnl_engine_kind_t* kind
    );

dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine);

dnnl_status_t DNNL_API dnnl_stream_create(
    dnnl_stream_t* stream,
    dnnl_engine_t engine,
    unsigned flags
    );

dnnl_status_t DNNL_API dnnl_stream_get_engine(
    const_dnnl_stream_t stream,
    dnnl_engine_t* engine
    );

dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream);
dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream);
dnnl_status_t DNNL_API dnnl_get_default_fpmath_mode(dnnl_fpmath_mode_t* mode);
dnnl_status_t DNNL_API dnnl_set_default_fpmath_mode(dnnl_fpmath_mode_t mode);
dnnl_status_t DNNL_API dnnl_set_verbose(int level);
const dnnl_version_t DNNL_API* dnnl_version(void);

dnnl_status_t DNNL_API dnnl_graph_allocator_create(
    dnnl_graph_allocator_t* allocator,
    dnnl_graph_host_allocate_f host_malloc,
    dnnl_graph_host_deallocate_f host_free
    );

dnnl_status_t DNNL_API dnnl_graph_allocator_destroy(dnnl_graph_allocator_t allocator);

dnnl_status_t DNNL_API dnnl_graph_make_engine_with_allocator(
    dnnl_engine_t* engine,
    dnnl_engine_kind_t kind,
    size_t index,
    const_dnnl_graph_allocator_t alloc
    );

dnnl_status_t DNNL_API dnnl_graph_logical_tensor_init(
    dnnl_graph_logical_tensor_t* logical_tensor,
    size_t tid,
    dnnl_data_type_t dtype,
    int32_t ndims,
    dnnl_graph_layout_type_t ltype,
    dnnl_graph_tensor_property_t ptype
    );

dnnl_status_t DNNL_API dnnl_graph_logical_tensor_init_with_dims(
    dnnl_graph_logical_tensor_t* logical_tensor,
    size_t tid,
    dnnl_data_type_t dtype,
    int32_t ndims,
    const dnnl_dims_t dims,
    dnnl_graph_layout_type_t ltype,
    dnnl_graph_tensor_property_t ptype
    );

dnnl_status_t DNNL_API dnnl_graph_logical_tensor_init_with_strides(
    dnnl_graph_logical_tensor_t* logical_tensor,
    size_t tid,
    dnnl_data_type_t dtype,
    int32_t ndims,
    const dnnl_dims_t dims,
    const dnnl_dims_t strides,
    dnnl_graph_tensor_property_t ptype
    );

dnnl_status_t DNNL_API dnnl_graph_logical_tensor_get_mem_size(
    const dnnl_graph_logical_tensor_t* logical_tensor,
    size_t* size
    );

dnnl_status_t DNNL_API dnnl_graph_logical_tensor_is_equal(
    const dnnl_graph_logical_tensor_t* lt1,
    const dnnl_graph_logical_tensor_t* lt2,
    uint8_t* is_equal
    );

dnnl_status_t DNNL_API dnnl_graph_tensor_create(
    dnnl_graph_tensor_t* tensor,
    const dnnl_graph_logical_tensor_t* logical_tensor,
    dnnl_engine_t engine,
    void* handle
    );

dnnl_status_t DNNL_API dnnl_graph_tensor_destroy(dnnl_graph_tensor_t tensor);

dnnl_status_t DNNL_API dnnl_graph_tensor_get_data_handle(
    const_dnnl_graph_tensor_t tensor,
    void** handle
    );

dnnl_status_t DNNL_API dnnl_graph_tensor_set_data_handle(
    dnnl_graph_tensor_t tensor,
    void* handle
    );

dnnl_status_t DNNL_API dnnl_graph_tensor_get_engine(
    const_dnnl_graph_tensor_t tensor,
    dnnl_engine_t* engine
    );

dnnl_status_t DNNL_API dnnl_graph_tensor_get_logical_tensor(
    const_dnnl_graph_tensor_t tensor,
    dnnl_graph_logical_tensor_t* logical_tensor
    );

dnnl_status_t DNNL_API dnnl_graph_op_create(
    dnnl_graph_op_t* op,
    size_t id,
    dnnl_graph_op_kind_t kind,
    const char* verbose_name
    );

dnnl_status_t DNNL_API dnnl_graph_op_destroy(dnnl_graph_op_t op);

dnnl_status_t DNNL_API dnnl_graph_op_add_input(
    dnnl_graph_op_t op,
    const dnnl_graph_logical_tensor_t* input
    );

dnnl_status_t DNNL_API dnnl_graph_op_add_output(
    dnnl_graph_op_t op,
    const dnnl_graph_logical_tensor_t* output
    );

dnnl_status_t DNNL_API dnnl_graph_op_set_attr_f32(
    dnnl_graph_op_t op,
    dnnl_graph_op_attr_t name,
    const float* value,
    size_t value_len
    );

dnnl_status_t DNNL_API dnnl_graph_op_set_attr_bool(
    dnnl_graph_op_t op,
    dnnl_graph_op_attr_t name,
    const uint8_t* value,
    size_t value_len
    );

dnnl_status_t DNNL_API dnnl_graph_op_set_attr_s64(
    dnnl_graph_op_t op,
    dnnl_graph_op_attr_t name,
    const int64_t* value,
    size_t value_len
    );

dnnl_status_t DNNL_API dnnl_graph_op_set_attr_str(
    dnnl_graph_op_t op,
    dnnl_graph_op_attr_t name,
    const char* value,
    size_t value_len
    );

dnnl_status_t DNNL_API dnnl_graph_op_get_id(
    const_dnnl_graph_op_t op,
    size_t* id
    );

dnnl_status_t DNNL_API dnnl_graph_op_get_kind(
    const_dnnl_graph_op_t op,
    dnnl_graph_op_kind_t* kind
    );

dnnl_status_t DNNL_API dnnl_graph_partition_create_with_op(
    dnnl_graph_partition_t* partition,
    const_dnnl_graph_op_t op,
    dnnl_engine_kind_t ekind
    );

dnnl_status_t DNNL_API dnnl_graph_partition_destroy(dnnl_graph_partition_t partition);

dnnl_status_t DNNL_API dnnl_graph_partition_get_op_num(
    const_dnnl_graph_partition_t partition,
    size_t* num
    );

dnnl_status_t DNNL_API dnnl_graph_partition_get_ops(
    dnnl_graph_partition_t partition,
    size_t num,
    size_t* ids
    );

dnnl_status_t DNNL_API dnnl_graph_partition_get_id(
    const_dnnl_graph_partition_t partition,
    size_t* id
    );

dnnl_status_t DNNL_API dnnl_graph_partition_compile(
    dnnl_graph_partition_t partition,
    dnnl_graph_compiled_partition_t compiled_partition,
    size_t in_num,
    const dnnl_graph_logical_tensor_t** inputs,
    size_t out_num,
    const dnnl_graph_logical_tensor_t** outputs,
    dnnl_engine_t engine
    );

dnnl_status_t DNNL_API dnnl_graph_partition_get_input_ports_num(
    const_dnnl_graph_partition_t partition,
    size_t* num
    );

dnnl_status_t DNNL_API dnnl_graph_partition_get_input_ports(
    const_dnnl_graph_partition_t partition,
    size_t num,
    dnnl_graph_logical_tensor_t* inputs
    );

dnnl_status_t DNNL_API dnnl_graph_partition_get_output_ports_num(
    const_dnnl_graph_partition_t partition,
    size_t* num
    );

dnnl_status_t DNNL_API dnnl_graph_partition_get_output_ports(
    const_dnnl_graph_partition_t partition,
    size_t num,
    dnnl_graph_logical_tensor_t* outputs
    );

dnnl_status_t DNNL_API dnnl_graph_partition_is_supported(
    const_dnnl_graph_partition_t partition,
    uint8_t* is_supported
    );

dnnl_status_t DNNL_API dnnl_graph_partition_get_engine_kind(
    const_dnnl_graph_partition_t partition,
    dnnl_engine_kind_t* kind
    );

dnnl_status_t DNNL_API dnnl_graph_compiled_partition_create(
    dnnl_graph_compiled_partition_t* compiled_partition,
    dnnl_graph_partition_t partition
    );

dnnl_status_t DNNL_API dnnl_graph_compiled_partition_execute(
    const_dnnl_graph_compiled_partition_t compiled_partition,
    dnnl_stream_t stream,
    size_t num_inputs,
    const_dnnl_graph_tensor_t* inputs,
    size_t num_outputs,
    const_dnnl_graph_tensor_t* outputs
    );

dnnl_status_t DNNL_API dnnl_graph_compiled_partition_destroy(dnnl_graph_compiled_partition_t compiled_partition);

dnnl_status_t DNNL_API dnnl_graph_compiled_partition_query_logical_tensor(
    const_dnnl_graph_compiled_partition_t compiled_partition,
    size_t tid,
    dnnl_graph_logical_tensor_t* lt
    );

dnnl_status_t DNNL_API dnnl_graph_compiled_partition_get_inplace_ports(
    const_dnnl_graph_compiled_partition_t compiled_partition,
    size_t* num_inplace_pairs,
    const dnnl_graph_inplace_pair_t** inplace_pairs
    );

dnnl_status_t DNNL_API dnnl_graph_graph_create(
    dnnl_graph_graph_t* graph,
    dnnl_engine_kind_t engine_kind
    );

dnnl_status_t DNNL_API dnnl_graph_graph_create_with_fpmath_mode(
    dnnl_graph_graph_t* graph,
    dnnl_engine_kind_t engine_kind,
    dnnl_fpmath_mode_t mode
    );

dnnl_status_t DNNL_API dnnl_graph_graph_destroy(dnnl_graph_graph_t graph);

dnnl_status_t DNNL_API dnnl_graph_graph_set_fpmath_mode(
    dnnl_graph_graph_t graph,
    dnnl_fpmath_mode_t mode,
    int apply_to_int
    );

dnnl_status_t DNNL_API dnnl_graph_graph_get_fpmath_mode(
    dnnl_graph_graph_t graph,
    dnnl_fpmath_mode_t* mode,
    int* apply_to_int
    );

dnnl_status_t DNNL_API dnnl_graph_add_op(
    dnnl_graph_graph_t graph,
    dnnl_graph_op_t op
    );

dnnl_status_t DNNL_API dnnl_graph_graph_finalize(dnnl_graph_graph_t graph);

dnnl_status_t DNNL_API dnnl_graph_graph_is_finalized(
    dnnl_graph_graph_t graph,
    uint8_t* finalized
    );

dnnl_status_t DNNL_API dnnl_graph_graph_filter(
    dnnl_graph_graph_t graph,
    dnnl_graph_partition_policy_t policy
    );

dnnl_status_t DNNL_API dnnl_graph_graph_get_partition_num(
    const_dnnl_graph_graph_t graph,
    size_t* num
    );

dnnl_status_t DNNL_API dnnl_graph_graph_get_partitions(
    dnnl_graph_graph_t graph,
    size_t num,
    dnnl_graph_partition_t* partitions
    );

dnnl_status_t DNNL_API dnnl_graph_get_compiled_partition_cache_capacity(int* capacity);
dnnl_status_t DNNL_API dnnl_graph_set_compiled_partition_cache_capacity(int capacity);
dnnl_status_t DNNL_API dnnl_graph_set_constant_tensor_cache(int flag);
dnnl_status_t DNNL_API dnnl_graph_get_constant_tensor_cache(int* flag);

dnnl_status_t DNNL_API dnnl_graph_set_constant_tensor_cache_capacity(
    dnnl_engine_kind_t eng_kind,
    size_t size
    );

dnnl_status_t DNNL_API dnnl_graph_get_constant_tensor_cache_capacity(
    dnnl_engine_kind_t eng_kind,
    size_t* size
    );

dnnl_status_t DNNL_API dnnl_graph_ocl_interop_allocator_create(
    dnnl_graph_allocator_t* allocator,
    dnnl_graph_ocl_allocate_f ocl_malloc,
    dnnl_graph_ocl_deallocate_f ocl_free
    );

dnnl_status_t DNNL_API dnnl_graph_ocl_interop_make_engine_with_allocator(
    dnnl_engine_t* engine,
    cl_device_id device,
    cl_context context,
    const_dnnl_graph_allocator_t alloc
    );

dnnl_status_t DNNL_API dnnl_graph_ocl_interop_make_engine_from_cache_blob_with_allocator(
    dnnl_engine_t* engine,
    cl_device_id device,
    cl_context context,
    const_dnnl_graph_allocator_t alloc,
    size_t size,
    const uint8_t* cache_blob
    );

dnnl_status_t DNNL_API dnnl_graph_ocl_interop_compiled_partition_execute(
    const_dnnl_graph_compiled_partition_t compiled_partition,
    dnnl_stream_t stream,
    size_t num_inputs,
    const_dnnl_graph_tensor_t* inputs,
    size_t num_outputs,
    const_dnnl_graph_tensor_t* outputs,
    const cl_event* deps,
    int ndeps,
    cl_event* return_event
    );

dnnl_status_t DNNL_API dnnl_graph_sycl_interop_allocator_create(
    dnnl_graph_allocator_t* allocator,
    dnnl_graph_sycl_allocate_f sycl_malloc,
    dnnl_graph_sycl_deallocate_f sycl_free
    );

dnnl_status_t DNNL_API dnnl_graph_sycl_interop_make_engine_with_allocator(
    dnnl_engine_t* engine,
    const void* device,
    const void* context,
    const_dnnl_graph_allocator_t alloc
    );

dnnl_status_t DNNL_API dnnl_graph_sycl_interop_compiled_partition_execute(
    const_dnnl_graph_compiled_partition_t compiled_partition,
    dnnl_stream_t stream,
    size_t num_inputs,
    const_dnnl_graph_tensor_t* inputs,
    size_t num_outputs,
    const_dnnl_graph_tensor_t* outputs,
    const void* deps,
    void* sycl_event
    );

dnnl_status_t DNNL_API dnnl_ukernel_attr_params_create(dnnl_ukernel_attr_params_t* attr_params);

dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_post_ops_args(
    dnnl_ukernel_attr_params_t attr_params,
    const void** post_ops_args
    );

dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_A_scales(
    dnnl_ukernel_attr_params_t attr_params,
    const void* a_scales
    );

dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_B_scales(
    dnnl_ukernel_attr_params_t attr_params,
    const void* b_scales
    );

dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_D_scales(
    dnnl_ukernel_attr_params_t attr_params,
    const void* d_scales
    );

dnnl_status_t DNNL_API dnnl_ukernel_attr_params_destroy(dnnl_ukernel_attr_params_t attr_params);

dnnl_status_t DNNL_API dnnl_brgemm_create(
    dnnl_brgemm_t* brgemm,
    dnnl_dim_t M,
    dnnl_dim_t N,
    dnnl_dim_t K,
    dnnl_dim_t batch_size,
    dnnl_dim_t lda,
    dnnl_dim_t ldb,
    dnnl_dim_t ldc,
    dnnl_data_type_t a_dt,
    dnnl_data_type_t b_dt,
    dnnl_data_type_t c_dt
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_add_C(dnnl_brgemm_t brgemm, int add_C);

dnnl_status_t DNNL_API dnnl_brgemm_set_post_ops(
    dnnl_brgemm_t brgemm,
    dnnl_dim_t ldd,
    dnnl_data_type_t d_dt,
    const_dnnl_post_ops_t post_ops
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_A_scales(
    dnnl_brgemm_t brgemm,
    int a_scale_mask
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_B_scales(
    dnnl_brgemm_t brgemm,
    int b_scale_mask
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_D_scales(
    dnnl_brgemm_t brgemm,
    int d_scale_mask
    );

dnnl_status_t DNNL_API dnnl_brgemm_finalize(dnnl_brgemm_t brgemm);

dnnl_status_t DNNL_API dnnl_brgemm_get_B_pack_type(
    const_dnnl_brgemm_t brgemm,
    dnnl_pack_type_t* pack_type
    );

dnnl_status_t DNNL_API dnnl_brgemm_get_scratchpad_size(
    const_dnnl_brgemm_t brgemm,
    size_t* size
    );

dnnl_status_t DNNL_API dnnl_brgemm_is_execute_postops_valid(
    const_dnnl_brgemm_t brgemm,
    int* valid
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_hw_context(const_dnnl_brgemm_t brgemm);
dnnl_status_t DNNL_API dnnl_brgemm_release_hw_context();
dnnl_status_t DNNL_API dnnl_brgemm_generate(dnnl_brgemm_t brgemm);

dnnl_status_t DNNL_API dnnl_brgemm_execute(
    const_dnnl_brgemm_t brgemm,
    const void* A_ptr,
    const void* B_ptr,
    const dnnl_dim_t* A_B_offsets,
    void* C_ptr,
    void* scratchpad_ptr
    );

dnnl_status_t DNNL_API dnnl_brgemm_execute_postops(
    const_dnnl_brgemm_t brgemm,
    const void* A,
    const void* B,
    const dnnl_dim_t* A_B_offsets,
    const void* C_ptr,
    void* D_ptr,
    void* scratchpad_ptr,
    const_dnnl_ukernel_attr_params_t attr_params
    );

dnnl_status_t DNNL_API dnnl_brgemm_destroy(dnnl_brgemm_t brgemm);

dnnl_status_t DNNL_API dnnl_transform_create(
    dnnl_transform_t* transform,
    dnnl_dim_t K,
    dnnl_dim_t N,
    dnnl_pack_type_t in_pack_type,
    dnnl_dim_t in_ld,
    dnnl_dim_t out_ld,
    dnnl_data_type_t in_dt,
    dnnl_data_type_t out_dt
    );

dnnl_status_t DNNL_API dnnl_transform_generate(dnnl_transform_t transform);

dnnl_status_t DNNL_API dnnl_transform_execute(
    const_dnnl_transform_t transform,
    const void* in_ptr,
    void* out_ptr
    );

dnnl_status_t DNNL_API dnnl_transform_destroy(dnnl_transform_t transform);

// macros

#define BATCH
#define BATCH
#define CHECK(f)
#define COMPLAIN_DNNL_ERROR_AND_EXIT(what, status)
#define COMPLAIN_EXAMPLE_ERROR_AND_EXIT(complain_fmt, ...)
#define CONV_IH
#define CONV_IH
#define CONV_IW
#define CONV_IW
#define CONV_OH
#define CONV_OH
#define CONV_OW
#define CONV_OW
#define CONV_PAD
#define CONV_PAD
#define CONV_STRIDE
#define CONV_STRIDE
#define DNNL_ARG_ATTR_DROPOUT_MASK
#define DNNL_ARG_ATTR_DROPOUT_PROBABILITY
#define DNNL_ARG_ATTR_DROPOUT_SEED
#define DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
#define DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE
#define DNNL_ARG_ATTR_OUTPUT_SCALES
#define DNNL_ARG_ATTR_POST_OP_DW
#define DNNL_ARG_ATTR_ROUNDING_SEED
#define DNNL_ARG_ATTR_SCALES
#define DNNL_ARG_ATTR_ZERO_POINTS
#define DNNL_ARG_AUGRU_ATTENTION
#define DNNL_ARG_BIAS
#define DNNL_ARG_DIFF_AUGRU_ATTENTION
#define DNNL_ARG_DIFF_BIAS
#define DNNL_ARG_DIFF_DST
#define DNNL_ARG_DIFF_DST_0
#define DNNL_ARG_DIFF_DST_1
#define DNNL_ARG_DIFF_DST_2
#define DNNL_ARG_DIFF_DST_ITER
#define DNNL_ARG_DIFF_DST_ITER_C
#define DNNL_ARG_DIFF_DST_LAYER
#define DNNL_ARG_DIFF_SCALE
#define DNNL_ARG_DIFF_SHIFT
#define DNNL_ARG_DIFF_SRC
#define DNNL_ARG_DIFF_SRC_0
#define DNNL_ARG_DIFF_SRC_1
#define DNNL_ARG_DIFF_SRC_2
#define DNNL_ARG_DIFF_SRC_3
#define DNNL_ARG_DIFF_SRC_ITER
#define DNNL_ARG_DIFF_SRC_ITER_C
#define DNNL_ARG_DIFF_SRC_LAYER
#define DNNL_ARG_DIFF_WEIGHTS
#define DNNL_ARG_DIFF_WEIGHTS_0
#define DNNL_ARG_DIFF_WEIGHTS_1
#define DNNL_ARG_DIFF_WEIGHTS_2
#define DNNL_ARG_DIFF_WEIGHTS_3
#define DNNL_ARG_DIFF_WEIGHTS_ITER
#define DNNL_ARG_DIFF_WEIGHTS_LAYER
#define DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE
#define DNNL_ARG_DIFF_WEIGHTS_PROJECTION
#define DNNL_ARG_DST
#define DNNL_ARG_DST_0
#define DNNL_ARG_DST_1
#define DNNL_ARG_DST_2
#define DNNL_ARG_DST_ITER
#define DNNL_ARG_DST_ITER_C
#define DNNL_ARG_DST_LAYER
#define DNNL_ARG_FROM
#define DNNL_ARG_MEAN
#define DNNL_ARG_MULTIPLE_DST
#define DNNL_ARG_MULTIPLE_SRC
#define DNNL_ARG_SCALE
#define DNNL_ARG_SCRATCHPAD
#define DNNL_ARG_SHIFT
#define DNNL_ARG_SRC
#define DNNL_ARG_SRC_0
#define DNNL_ARG_SRC_1
#define DNNL_ARG_SRC_2
#define DNNL_ARG_SRC_3
#define DNNL_ARG_SRC_ITER
#define DNNL_ARG_SRC_ITER_C
#define DNNL_ARG_SRC_LAYER
#define DNNL_ARG_TO
#define DNNL_ARG_UNDEF
#define DNNL_ARG_VARIANCE
#define DNNL_ARG_WEIGHTS
#define DNNL_ARG_WEIGHTS_0
#define DNNL_ARG_WEIGHTS_1
#define DNNL_ARG_WEIGHTS_2
#define DNNL_ARG_WEIGHTS_3
#define DNNL_ARG_WEIGHTS_ITER
#define DNNL_ARG_WEIGHTS_LAYER
#define DNNL_ARG_WEIGHTS_PEEPHOLE
#define DNNL_ARG_WEIGHTS_PROJECTION
#define DNNL_ARG_WORKSPACE
#define DNNL_ENABLE_EXCEPTIONS
#define DNNL_GRAPH_UNKNOWN_DIM
#define DNNL_GRAPH_UNKNOWN_NDIMS
#define DNNL_JIT_PROFILE_LINUX_JITDUMP
#define DNNL_JIT_PROFILE_LINUX_JITDUMP_USE_TSC
#define DNNL_JIT_PROFILE_LINUX_PERF
#define DNNL_JIT_PROFILE_LINUX_PERFMAP
#define DNNL_JIT_PROFILE_NONE
#define DNNL_JIT_PROFILE_VTUNE
#define DNNL_MAX_NDIMS
#define DNNL_MEMORY_ALLOCATE
#define DNNL_MEMORY_NONE
#define DNNL_RUNTIME_DIM_VAL
#define DNNL_RUNTIME_F32_VAL
#define DNNL_RUNTIME_S32_VAL
#define DNNL_RUNTIME_SIZE_VAL
#define DNNL_THROW_ERROR(status, msg)
#define IC
#define IC
#define OC
#define OC
#define OCL_CHECK(x)
#define OCL_CHECK(x)
#define POOL_OH
#define POOL_OH
#define POOL_OW
#define POOL_OW
#define POOL_PAD
#define POOL_PAD
#define POOL_STRIDE
#define POOL_STRIDE
#define PRAGMA_MACRO(x)
#define PRAGMA_MACRo(x)
#define PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(n)
#define TYPE_CASE(T)
#define TYPE_CASE(T)
#define TYPE_CASE(T)
#define TYPE_CASE(T)
#define TYPE_CASE(T)
#define TYPE_CASE(T)
#define TYPE_CASE(T)
#define UNUSED(x)
#define _POSIX_C_SOURCE
#define _POSIX_C_SOURCE

Detailed Documentation

Global Functions

void set_any_layout(
    const std::vector<dnnl::graph::partition>& partitions,
    std::unordered_set<size_t>& id_to_set_any_layout
    )

Sets the "any" layout type based on how the partitions are connected to one another.

Parameters:

partitions

a list of partitions

id_to_set_any_layout

a set of IDs of the logical tensors that have the "any" layout type