The matrix multiplication (MatMul) primitive computes the product of two 2D tensors with optional bias addition (the variable names follow the standard Naming Conventions):
\[ \dst(m, n) = \sum_{k=0}^{K-1} \left( \src(m, k) \cdot \weights(k, n) \right) + \bias(m, n) \]
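As a sanity check of the formula above, here is a minimal pure-Python reference implementation (the function name is illustrative, not part of the oneDNN API):

```python
def matmul_2d(src, weights, bias):
    """Reference 2D MatMul: dst(m, n) = sum_k src(m, k) * weights(k, n) + bias(m, n)."""
    M, K = len(src), len(src[0])
    N = len(weights[0])
    return [[sum(src[m][k] * weights[k][n] for k in range(K)) + bias[m][n]
             for n in range(N)]
            for m in range(M)]

src = [[1.0, 2.0],
       [3.0, 4.0]]        # M x K = 2 x 2
weights = [[5.0, 6.0],
           [7.0, 8.0]]    # K x N = 2 x 2
bias = [[0.5, 0.5],
        [0.5, 0.5]]       # M x N = 2 x 2
print(matmul_2d(src, weights, bias))  # [[19.5, 22.5], [43.5, 50.5]]
```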
The MatMul primitive also supports batching multiple independent matrix multiplication operations, in which case the tensors can be up to 12D:
\[ \dst(bs_0, bs_1, bs_2, \ldots, m, n) = \sum_{k=0}^{K-1} \left( \src(bs_0, bs_1, bs_2, \ldots, m, k) \cdot \weights(bs_0, bs_1, bs_2, \ldots, k, n) \right) + \bias(bs_0, bs_1, bs_2, \ldots, m, n) \]
MatMul also supports implicit broadcast semantics; i.e., \(\src\) can be broadcast into \(\weights\) if the corresponding dimension in \(\src\) is 1 (and vice versa). However, all tensors (including \(\bias\), if it exists) must have the same number of dimensions.
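The batch-dimension broadcast rule described above can be sketched in pure Python (the helper names are hypothetical, not part of the oneDNN API):

```python
def broadcast_dim(a, b):
    """Resolve one pair of batch dimensions: a size-1 dim stretches to match."""
    if a == b:
        return a
    if a == 1 or b == 1:
        return max(a, b)
    raise ValueError(f"incompatible batch dims {a} and {b}")

def batch_shape(src_batch, wei_batch):
    # All tensors must have the same number of dimensions, so the batch
    # dimension lists must have equal length.
    assert len(src_batch) == len(wei_batch)
    return [broadcast_dim(a, b) for a, b in zip(src_batch, wei_batch)]

print(batch_shape([2, 1, 5], [1, 3, 5]))  # [2, 3, 5]
```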
The shape of \(\dst\) depends only on the \(\src\) and \(\weights\) tensors; \(\bias\) cannot change the dimensions of \(\dst\) via broadcasting. In other words, for every dimension the following constraint must hold true: dimension(bias) == dimension(dst) || dimension(bias) == 1.
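The bias constraint can be expressed as a small validity check (illustrative pure Python, not the oneDNN API):

```python
def bias_shape_ok(bias_dims, dst_dims):
    """For every dimension: dim(bias) == dim(dst) or dim(bias) == 1."""
    return len(bias_dims) == len(dst_dims) and all(
        b == d or b == 1 for b, d in zip(bias_dims, dst_dims))

print(bias_shape_ok([1, 1, 4], [2, 3, 4]))  # True: bias broadcasts over dst
print(bias_shape_ok([2, 2, 4], [2, 3, 4]))  # False: 2 != 3 and 2 != 1
```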
When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.
Primitive input/output | Execution argument index |
---|---|
\(\src\) | DNNL_ARG_SRC |
\(\weights\) | DNNL_ARG_WEIGHTS |
\(\bias\) | DNNL_ARG_BIAS |
\(\dst\) | DNNL_ARG_DST |
Binary post-op | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
The MatMul primitive supports tensors with run-time specified shapes, where a dimension is given as the DNNL_RUNTIME_DIM_VAL wildcard value at primitive creation time. A dimension that is fixed in one tensor cannot be run-time specified in a tensor it must be consistent with; for example, \(\src\) and \(\weights\) shapes of {3, 4, 4} and {DNNL_RUNTIME_DIM_VAL, 4, 4} respectively are invalid. The broadcasting shape consistency check is not done for dimensions with DNNL_RUNTIME_DIM_VAL; it is the user's responsibility to make sure the dimensions for the tensors are valid.
The MatMul primitive supports the following combinations of data types for source, destination, weights, and bias tensors:
Source | Weights | Destination | Bias |
---|---|---|---|
f32 | f32 | f32 | f32 |
f16 | f16 | f16 | f16 |
bf16 | bf16 | bf16 | bf16, f32 |
u8, s8 | s8, u8 | u8, s8, s32, f32 | u8, s8, s32, f32 |
The MatMul primitive expects the following tensors:
Dims | Source | Weights | Destination | Bias |
---|---|---|---|---|
2D | \(M \times K\) | \(K \times N\) | \(M \times N\) | None or \((M \text{ or } 1) \times (N \text{ or } 1)\) |
ND | \((\prod_{i=0}^{ND - 3} src{\_}dims[i]) \times M \times K\) | \((\prod_{i=0}^{ND - 3} weights{\_}dims[i]) \times K \times N\) | \((\prod_{i=0}^{ND - 3} dst{\_}dims[i]) \times M \times N\) | None or \(\prod_{i=0}^{ND - 1} (dst{\_}dims[i] \text{ or } 1)\) |
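Assuming the shape rules in the table above, the destination shape can be derived from the source and weights shapes as follows (hypothetical pure-Python helper, not the oneDNN API):

```python
def matmul_dst_shape(src_dims, wei_dims):
    """Compute dst dims: batch dims broadcast, last two dims are M x K and K x N."""
    assert len(src_dims) == len(wei_dims) >= 2
    *src_batch, M, K = src_dims
    *wei_batch, K2, N = wei_dims
    assert K == K2, "inner (k) dimensions must match"
    batch = []
    for a, b in zip(src_batch, wei_batch):
        assert a == b or a == 1 or b == 1, "invalid batch broadcast"
        batch.append(max(a, b))
    return batch + [M, N]

print(matmul_dst_shape([3, 5], [5, 2]))            # [3, 2]
print(matmul_dst_shape([2, 1, 3, 5], [1, 4, 5, 6]))  # [2, 4, 3, 6]
```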
The MatMul primitive is generally optimized for the case in which memory objects use plain memory formats. Additionally, \(\src\) must have at least one of the m and k axes contiguous (i.e., with stride 1), and \(\weights\) must have at least one of the n and k axes contiguous. However, it is recommended to use the placeholder memory format dnnl::memory::format_tag::any if an input tensor is reused across multiple executions. In this case, the primitive will set the most appropriate memory format for the corresponding input tensor.
The memory format of the destination tensor should always be plain with the n axis contiguous. For example, dnnl::memory::format_tag::ab for the 2D case, and dnnl::memory::format_tag::abc or dnnl::memory::format_tag::bac for the 3D case.
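To illustrate what "plain with the n axis contiguous" means, the sketch below computes the strides implied by a dense layout; for a format like dnnl::memory::format_tag::ab the last-listed axis gets stride 1. The helper name is illustrative, not part of the oneDNN API:

```python
def dense_strides(dims):
    """Dense layout: the last axis has stride 1; each earlier axis
    strides over everything to its right."""
    strides = [1] * len(dims)
    for i in range(len(dims) - 2, -1, -1):
        strides[i] = strides[i + 1] * dims[i + 1]
    return strides

# 2D dst of shape M x N in format ab: the n axis is contiguous.
print(dense_strides([4, 6]))     # [6, 1] -> stride over n is 1

# 3D dst of shape B x M x N in format abc: the n axis is still contiguous.
print(dense_strides([2, 4, 6]))  # [24, 6, 1]
```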
Attributes and post-ops enable modifying the behavior of the MatMul primitive. The following attributes and post-ops are supported:
Type | Operation | Description | Restrictions |
---|---|---|---|
Attribute | Output scales | Scales the result by given scale factor(s) | |
Attribute | Zero points | Sets zero point(s) for the corresponding tensors | Int8 computations only |
Post-op | Eltwise | Applies an Eltwise operation to the result | |
Post-op | Sum | Adds the operation result to the destination tensor instead of overwriting it | |
Post-op | Binary | Applies a Binary operation to the result | General binary post-op restrictions |
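As a conceptual sketch (pure Python, not the oneDNN post-op API) of how an eltwise post-op followed by a sum post-op would transform the matmul result, assuming the post-ops are configured in that order:

```python
def apply_post_ops(matmul_result, dst, use_relu=True, use_sum=True):
    """Apply an eltwise (ReLU) post-op, then a sum post-op that accumulates
    into the existing destination values instead of overwriting them."""
    out = []
    for row_r, row_d in zip(matmul_result, dst):
        row = []
        for r, d in zip(row_r, row_d):
            if use_relu:
                r = max(r, 0.0)   # eltwise post-op
            if use_sum:
                r = r + d         # sum post-op: add prior dst value
            row.append(r)
        out.append(row)
    return out

print(apply_post_ops([[-1.0, 2.0]], [[10.0, 10.0]]))  # [[10.0, 12.0]]
```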
To facilitate dynamic quantization, the primitive supports run-time output scales. That means a user could configure attributes with output scales set to the DNNL_RUNTIME_F32_VAL wildcard value instead of the actual scales, if the scales are not known at the primitive descriptor creation stage. In this case, the user must provide the scales as an additional input memory object with argument DNNL_ARG_ATTR_OUTPUT_SCALES during the execution stage.
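Conceptually, a run-time output scale is just a multiplier applied to the accumulated result at execution time; a minimal sketch (illustrative names, not the oneDNN API):

```python
def apply_output_scale(acc, scale):
    """dst = scale * (src x weights); the scale is known only at execution."""
    return [[scale * v for v in row] for row in acc]

acc = [[19, 22], [43, 50]]   # e.g. an int32 accumulator from an int8 matmul
runtime_scale = 0.05         # computed at run time, e.g. src_scale * wei_scale
print(apply_output_scale(acc, runtime_scale))
```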
Similarly to run-time output scales, the primitive supports run-time zero points. The wildcard value for zero points is DNNL_RUNTIME_S32_VAL. During the execution stage, the corresponding memory object needs to be passed in the argument with index set to (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_${MEMORY_INDEX}). For example, the zero points for the source tensor are passed with index (DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC).
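The effect of zero points on an int8 computation can be sketched in pure Python (illustrative, assuming the effective operation is dst = scale * sum_k (src - zp_src) * (weights - zp_wei)):

```python
def int8_matmul_with_zero_points(src, wei, zp_src, zp_wei, scale):
    """Reference int8 MatMul: subtract zero points before multiplying,
    then apply the output scale to the int32 accumulator."""
    M, K, N = len(src), len(src[0]), len(wei[0])
    return [[scale * sum((src[m][k] - zp_src) * (wei[k][n] - zp_wei)
                         for k in range(K))
             for n in range(N)]
            for m in range(M)]

src = [[130, 132]]          # u8 data with zero point 128 -> real values 2, 4
wei = [[1, -1],
       [2, -2]]             # s8 data with zero point 0
print(int8_matmul_with_zero_points(src, wei, 128, 0, 0.5))  # [[5.0, -5.0]]
```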
Engine | Name | Comments |
---|---|---|
CPU/GPU | Matmul Primitive Example | This C++ API example demonstrates how to create and execute a MatMul primitive, along with several key optimizations. |
CPU | MatMul Tutorial: Comparison with SGEMM | C++ API example demonstrating MatMul as a replacement for SGEMM functions. |
CPU/GPU | MatMul Tutorial: INT8 Inference | C++ API example demonstrating how one can use MatMul fused with ReLU in INT8 inference. |
CPU | MatMul Tutorial: Quantization | C++ API example demonstrating reduced-precision matrix-matrix multiplication using MatMul, and the accuracy of the result compared to floating-point computations. |