struct dnnl::post_ops
Overview
Post-ops.
#include <dnnl.hpp>

struct post_ops: public dnnl::handle
{
    // methods

    int len() const;
    primitive::kind kind(int index) const;
    void append_sum(float scale = 1.f, memory::data_type data_type = memory::data_type::undef);
    void append_sum(float scale, int32_t zero_point, memory::data_type data_type = memory::data_type::undef);
    void get_params_sum(int index, float& scale) const;
    void get_params_sum(int index, float& scale, memory::data_type& data_type) const;
    void get_params_sum(int index, float& scale, int32_t& zero_point, memory::data_type& data_type) const;
    void append_eltwise(float scale, algorithm aalgorithm, float alpha, float beta);
    void get_params_eltwise(int index, float& scale, algorithm& aalgorithm, float& alpha, float& beta) const;
    void append_dw_k3s1p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector<float>& scales);
    void get_params_dw_k3s1p1(int index, memory::data_type& weights_data_type, memory::data_type& bias_data_type, memory::data_type& dst_data_type, int& mask, std::vector<float>& scales) const;
    void append_dw_k3s2p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector<float>& scales);
    void get_params_dw_k3s2p1(int index, memory::data_type& weights_data_type, memory::data_type& bias_data_type, memory::data_type& dst_data_type, int& mask, std::vector<float>& scales) const;
    void append_binary(algorithm aalgorithm, const memory::desc& src1_desc);
    void get_params_binary(int index, algorithm& aalgorithm, memory::desc& src1_desc) const;
    void append_prelu(int mask);
    void get_params_prelu(int index, int& mask) const;
};

struct primitive_attr: public dnnl::handle
{
    // construction

    primitive_attr();
    primitive_attr(dnnl_primitive_attr_t attr);

    // methods

    fpmath_mode get_fpmath_mode() const;
    void set_fpmath_mode(fpmath_mode mode);
    scratchpad_mode get_scratchpad_mode() const;
    void set_scratchpad_mode(scratchpad_mode mode);
    void get_output_scales(int& mask, std::vector<float>& scales) const;
    void set_output_scales(int mask, const std::vector<float>& scales);
    void get_scales(int arg, int& mask, std::vector<float>& scales) const;
    void set_scales(int arg, int mask, const std::vector<float>& scales);
    void get_zero_points(int arg, int& mask, std::vector<int32_t>& zero_points) const;
    void set_zero_points(int arg, int mask, const std::vector<int32_t>& zero_points);
    const post_ops get_post_ops() const;
    void set_post_ops(const post_ops ops);
    void set_rnn_data_qparams(float scale, float shift);
    void get_rnn_data_qparams(float& scale, float& shift);
    void set_rnn_weights_qparams(int mask, const std::vector<float>& scales);
    void get_rnn_weights_qparams(int& mask, std::vector<float>& scales);
    void set_rnn_weights_projection_qparams(int mask, const std::vector<float>& scales);
    void get_rnn_weights_projection_qparams(int& mask, std::vector<float>& scales);
};
Inherited Members
public:
    // methods

    handle<T, traits>& operator = (const handle<T, traits>&);
    handle<T, traits>& operator = (handle<T, traits>&&);
    void reset(T t, bool weak = false);
    T get(bool allow_empty = false) const;
    operator T () const;
    operator bool () const;
    bool operator == (const handle<T, traits>& other) const;
    bool operator != (const handle& other) const;
Detailed Documentation
Post-ops.
Post-ops are computations executed after the main primitive computations and are attached to the primitive via primitive attributes.
See also:
Primitive Attributes: Post-ops
Methods
int len() const
Returns the number of post-ops entries.
primitive::kind kind(int index) const
Returns the primitive kind of the post-op at the given index.
Parameters:
index |
Index of the post-op to return the kind for. |
Returns:
Primitive kind of the post-op at the specified index.
void append_sum( float scale = 1.f, memory::data_type data_type = memory::data_type::undef )
Appends an accumulation (sum) post-op.
Prior to accumulating the result, the previous value is multiplied by the scaling factor scale.
The kind of this post-op is dnnl::primitive::kind::sum.
This feature may improve performance for cases like residual learning blocks, where the result of convolution is accumulated to the previously computed activations. The parameter scale may be used for integer-based computations when the result and previous activations have different logical scaling factors.
In the simplest case when the accumulation is the only post-op, the computations will be dst[:] <- scale * dst[:] + op(...) instead of dst[:] <- op(...).
If data_type is specified, the original dst tensor will be reinterpreted as a tensor with the provided data type. Because it is a reinterpretation, data_type and dst data type should have the same size. As a result, computations will be dst[:] <- scale * as_data_type(dst[:]) + op(...) instead of dst[:] <- op(...).
Note
This post-op executes in-place and does not change the destination layout.
Parameters:
scale |
Scaling factor. |
data_type |
Data type. |
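The arithmetic of the sum post-op can be illustrated without the library. The following is a minimal reference sketch, assuming the sum is the only post-op and all tensors are f32; apply_sum_post_op and op_result are hypothetical names introduced for illustration and are not part of the oneDNN API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference arithmetic only (hypothetical helper, not oneDNN API):
// dst[i] <- scale * dst[i] + op_result[i]
// `dst` holds the previously computed values, `op_result` the output of the
// main primitive computation op(...).
inline void apply_sum_post_op(std::vector<float>& dst,
        const std::vector<float>& op_result, float scale = 1.f) {
    assert(dst.size() == op_result.size());
    for (std::size_t i = 0; i < dst.size(); ++i) {
        dst[i] = scale * dst[i] + op_result[i];
    }
}
```

With scale left at its default of 1.f this reduces to plain accumulation, which is the residual-block case described above.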
void append_sum( float scale, int32_t zero_point, memory::data_type data_type = memory::data_type::undef )
Appends an accumulation (sum) post-op.
Prior to accumulating the result, the previous value will be reduced by the zero point zero_point and multiplied by the scaling factor scale.
The kind of this post-op is dnnl::primitive::kind::sum.
This feature may improve performance in cases such as dequantizing the asymmetrically quantized src1 tensor of a sum to the f32 domain before performing the sum operation: zero_point is subtracted before the scaling is applied.
In the simplest case when the accumulation is the only post-op, the computations will be dst[:] <- scale * (dst[:] - zero_point) + op(...) instead of dst[:] <- op(...).
If data_type is specified, the original dst tensor will be reinterpreted as a tensor with the provided data type. Because it is a reinterpretation, data_type and dst data type should have the same size. As a result, computations will be dst[:] <- scale * (as_data_type(dst[:]) - zero_point) + op(...) instead of dst[:] <- op(...).
Note
This post-op executes in-place and does not change the destination layout.
Parameters:
scale |
Scaling factor. |
zero_point |
Zero point. |
data_type |
Data type. |
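The zero-point variant can be sketched the same way. This example assumes the previous destination values are stored as u8 and, to stay clear of any rounding or saturation behavior, returns the result in f32 rather than writing it back in place as the real post-op does. apply_sum_with_zero_point is a hypothetical helper, not a oneDNN function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference arithmetic only (hypothetical helper, not oneDNN API):
// out[i] <- scale * (prev_dst[i] - zero_point) + op_result[i]
// Assumption: previous dst values are asymmetrically quantized u8.
inline std::vector<float> apply_sum_with_zero_point(
        const std::vector<std::uint8_t>& prev_dst,
        const std::vector<float>& op_result, float scale,
        std::int32_t zero_point) {
    assert(prev_dst.size() == op_result.size());
    std::vector<float> out(prev_dst.size());
    for (std::size_t i = 0; i < prev_dst.size(); ++i) {
        // Subtract the zero point first, then scale into the f32 domain.
        const std::int32_t centered
                = static_cast<std::int32_t>(prev_dst[i]) - zero_point;
        out[i] = scale * static_cast<float>(centered) + op_result[i];
    }
    return out;
}
```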
void get_params_sum(int index, float& scale) const
Returns the parameters of an accumulation (sum) post-op.
Parameters:
index |
Index of the sum post-op. |
scale |
Scaling factor of the sum post-op. |
void get_params_sum(int index, float& scale, memory::data_type& data_type) const
Returns the parameters of an accumulation (sum) post-op.
Parameters:
index |
Index of the sum post-op. |
scale |
Scaling factor of the sum post-op. |
data_type |
Data type of the sum post-op. |
void get_params_sum( int index, float& scale, int32_t& zero_point, memory::data_type& data_type ) const
Returns the parameters of an accumulation (sum) post-op.
Parameters:
index |
Index of the sum post-op. |
scale |
Scaling factor of the sum post-op. |
zero_point |
Zero point of the sum post-op (a single int32_t scalar). |
data_type |
Data type of the sum post-op. |
void append_eltwise(float scale, algorithm aalgorithm, float alpha, float beta)
Appends an elementwise post-op.
The kind of this post-op is dnnl::primitive::kind::eltwise.
In the simplest case when the elementwise is the only post-op, the computations would be dst[:] <- scale * eltwise_op(op(...)) instead of dst[:] <- op(...), where eltwise_op is configured with the given parameters.
Parameters:
scale |
Scaling factor. |
aalgorithm |
Elementwise algorithm. |
alpha |
Alpha parameter for the elementwise algorithm. |
beta |
Beta parameter for the elementwise algorithm. |
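As an illustration of the formula above, the sketch below applies a ReLU-style elementwise operation followed by the scale. It assumes alpha acts as the negative slope and beta is unused, which is how the ReLU algorithm is commonly parameterized; apply_eltwise_relu_post_op is a hypothetical name and the code is reference arithmetic, not the library implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// ReLU with a negative slope: x if x > 0, alpha * x otherwise.
inline float relu_with_slope(float x, float alpha) {
    return x > 0.f ? x : alpha * x;
}

// Reference arithmetic only (hypothetical helper, not oneDNN API):
// dst[i] <- scale * eltwise_op(op_result[i])
inline void apply_eltwise_relu_post_op(std::vector<float>& dst,
        const std::vector<float>& op_result, float scale, float alpha) {
    assert(dst.size() == op_result.size());
    for (std::size_t i = 0; i < dst.size(); ++i) {
        dst[i] = scale * relu_with_slope(op_result[i], alpha);
    }
}
```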
void get_params_eltwise( int index, float& scale, algorithm& aalgorithm, float& alpha, float& beta ) const
Returns parameters of an elementwise post-op.
Parameters:
index |
Index of the post-op. |
scale |
Output scaling factor. |
aalgorithm |
Output elementwise algorithm kind. |
alpha |
Output alpha parameter for the elementwise algorithm. |
beta |
Output beta parameter for the elementwise algorithm. |
void append_dw_k3s1p1( memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector<float>& scales )
Appends a depthwise post-op convolution with stride 1.
This post-op can only be fused with a 2D 1x1 convolution (convolution with weights spatial dimension equal to 1 i.e., kh=kw=1).
The kind of this post-op is dnnl::primitive::kind::convolution.
The number of outputs of the primitive remains the same as before. The output size remains the same as that of the original primitive because stride=1.
The post-op can be defined as:
dst[:] <- scales * (conv_dw(conv_1x1))
See dev_guide_attributes_post_ops_depthwise and dev_guide_attributes_post_ops_depthwise_fusion for more info.
Parameters:
weights_data_type |
Weights data type of depthwise post-op |
bias_data_type |
Bias data type of depthwise post-op |
dst_data_type |
Output data type of depthwise post-op |
mask |
Output scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. |
scales |
Constant vector of float output scaling factors. |
void get_params_dw_k3s1p1( int index, memory::data_type& weights_data_type, memory::data_type& bias_data_type, memory::data_type& dst_data_type, int& mask, std::vector<float>& scales ) const
Returns the parameters of a depthwise post-op with stride 1.
Parameters:
index |
Index of the depthwise post-op. |
weights_data_type |
Weights data type of depthwise post-op |
bias_data_type |
Bias data type of depthwise post-op |
dst_data_type |
Output data type of depthwise post-op |
mask |
Output scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. |
scales |
Output vector of float scaling factors. |
void append_dw_k3s2p1( memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector<float>& scales )
Appends a depthwise post-op convolution with stride 2.
This post-op can only be fused with a 2D 1x1 convolution (convolution with weights spatial dimension equal to 1 i.e., kh=kw=1).
The kind of this post-op is dnnl::primitive::kind::convolution.
The number of outputs of the primitive remains the same as before. The output spatial size can be derived as follows:
output_height = ceil(output_height_1x1_convolution / stride)
output_width = ceil(output_width_1x1_convolution / stride)
The post-op can be defined as:
dst[:] <- scales * (conv_dw(conv_1x1))
See dev_guide_attributes_post_ops_depthwise and dev_guide_attributes_post_ops_depthwise_fusion for more info.
Parameters:
weights_data_type |
Weights data type of depthwise post-op |
bias_data_type |
Bias data type of depthwise post-op |
dst_data_type |
Output data type of depthwise post-op |
mask |
Output scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. |
scales |
Constant vector of float output scaling factors. |
Returns:
dnnl_success on success and a status describing the error otherwise
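The output spatial size of the fused depthwise convolution can be computed with integer ceiling division. The sketch below reads the ceil expression above as the ceiling of the ratio of the 1x1 convolution output size to the stride; dw_output_size is a hypothetical helper introduced for illustration.

```cpp
#include <cassert>

// Ceiling division: spatial size after a depthwise post-op with the given
// stride, derived from the output size of the 1x1 convolution.
// Hypothetical helper, not oneDNN API.
inline int dw_output_size(int conv_1x1_size, int stride) {
    assert(conv_1x1_size > 0 && stride > 0);
    return (conv_1x1_size + stride - 1) / stride;
}
```

For a 14x14 output of the 1x1 convolution, stride 2 gives 7x7 and stride 1 leaves the size at 14x14. The result agrees with the usual convolution size formula for a 3x3 kernel with padding 1, (size + 2 * 1 - 3) / stride + 1.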
void get_params_dw_k3s2p1( int index, memory::data_type& weights_data_type, memory::data_type& bias_data_type, memory::data_type& dst_data_type, int& mask, std::vector<float>& scales ) const
Returns the parameters of a depthwise post-op with stride 2.
Parameters:
index |
Index of the depthwise post-op. |
weights_data_type |
Weights data type of depthwise post-op |
bias_data_type |
Bias data type of depthwise post-op |
dst_data_type |
Output data type of depthwise post-op |
mask |
Output scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. |
scales |
Output vector of float scaling factors. |
void append_binary(algorithm aalgorithm, const memory::desc& src1_desc)
Appends a binary post-op.
The kind of this post-op is dnnl::primitive::kind::binary.
In the simplest case when the binary is the only post-op, the computations would be:
dst[:] <- binary_op (dst[:], another_input[:])
where binary_op is configured with the given parameters and another_input is the second operand described by src1_desc. binary_op supports broadcast semantics for the second operand.
Parameters:
aalgorithm |
Binary algorithm for the post-op. |
src1_desc |
Memory descriptor of a second operand. |
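Broadcasting of the second operand can be illustrated with a per-channel addition. The sketch assumes a dense destination with logical dimensions (n, c, spatial) stored in that order and a second operand holding one value per channel; the function name and the layout are assumptions made for this example, not part of the oneDNN API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference arithmetic only (hypothetical helper, not oneDNN API):
// dst[n][c][s] <- dst[n][c][s] + src1[c]
// The second operand is broadcast along the batch and spatial dimensions.
inline void apply_binary_add_per_channel(std::vector<float>& dst,
        const std::vector<float>& src1, std::size_t channels,
        std::size_t spatial) {
    assert(channels > 0 && spatial > 0);
    assert(src1.size() == channels);
    assert(dst.size() % (channels * spatial) == 0);
    for (std::size_t i = 0; i < dst.size(); ++i) {
        // Recover the channel index from the flat offset.
        const std::size_t c = (i / spatial) % channels;
        dst[i] = dst[i] + src1[c];
    }
}
```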
void get_params_binary( int index, algorithm& aalgorithm, memory::desc& src1_desc ) const
Returns the parameters of a binary post-op.
Parameters:
index |
Index of the binary post-op. |
aalgorithm |
Output binary algorithm kind. |
src1_desc |
Output memory descriptor of a second operand. |
void append_prelu(int mask)
Appends a prelu forward post-op.
The kind of this post-op is dnnl::primitive::kind::prelu.
The post-op can be defined as:
dst[:] <- prelu(dst[:], weights[:])
prelu:
dst[:] <- dst[:] if dst[:] > 0
dst[:] <- dst[:] * weights[:] if dst[:] <= 0
Example usage:
int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
// unique weights per output channel
std::vector<float> weights = { ... };
int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...

// construct a convolution descriptor
dnnl::convolution_forward::desc conv_d(/* arguments */);

// prelu is appended to a post_ops object that is attached to the attributes
dnnl::post_ops ops;
ops.append_prelu(1 << oc_dim);
dnnl::primitive_attr attr;
attr.set_post_ops(ops);

dnnl::convolution_forward::primitive_desc conv_pd(conv_d, attr, engine);

memory prelu_weights({{1}, dt::f32, {1}}, engine, weights.data());

std::unordered_map<int, memory> conv_args;
conv_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_WEIGHTS, prelu_weights});
Note
The order of dimensions does not depend on how elements are laid out in memory. For example:
for a 2D CNN activations tensor the order is always (n, c)
for a 4D CNN activations tensor the order is always (n, c, h, w)
for a 5D CNN weights tensor the order is always (g, oc, ic, kh, kw)
The prelu weights tensor is passed at execution time. Its data type is implicitly assumed to be f32 with a plain layout (a, ab, acb, acdb, acdeb).
Parameters:
mask |
Defines the correspondence between the output tensor dimensions and the prelu weights tensor. The set i-th bit indicates that a dedicated weights value is used for each index along that dimension. Set the mask to 0 to use a common weights value for the whole output tensor. |
void get_params_prelu(int index, int& mask) const
Returns the parameters of a prelu post-op.
Parameters:
index |
Index of the prelu post-op. |
mask |
Weights mask of prelu post-op. |
struct dnnl::primitive_attr
Primitive attributes.
See also:
dev_guide_attributes
primitive_attr()
Constructs default (empty) primitive attributes.
primitive_attr(dnnl_primitive_attr_t attr)
Creates primitive attributes from a C API dnnl_primitive_attr_t handle. The resulting handle is not weak and the C handle will be destroyed during the destruction of the C++ object.
Parameters:
attr |
The C API primitive attributes. |
fpmath_mode get_fpmath_mode() const
Returns the fpmath mode.
void set_fpmath_mode(fpmath_mode mode)
Sets fpmath mode.
Parameters:
mode |
Specified fpmath mode. |
scratchpad_mode get_scratchpad_mode() const
Returns the scratchpad mode.
void set_scratchpad_mode(scratchpad_mode mode)
Sets scratchpad mode.
Parameters:
mode |
Specified scratchpad mode. |
void get_output_scales(int& mask, std::vector<float>& scales) const
Returns output scaling factors correspondence mask and values.
Parameters:
mask |
Scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. The set i-th bit indicates that a dedicated output scaling factor is used for each index along that dimension. The mask value of 0 implies a common output scaling factor for the whole output tensor. |
scales |
Vector of output scaling factors. |
void set_output_scales(int mask, const std::vector<float>& scales)
Sets output scaling factors correspondence mask and values.
Example usage:
int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
// unique output scales per output channel
std::vector<float> scales = { ... };
int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...

// construct a convolution descriptor
dnnl::convolution_forward::desc conv_d(/* arguments */);

dnnl::primitive_attr attr;
attr.set_output_scales(1 << oc_dim, scales);

dnnl::convolution_forward::primitive_desc conv_pd(conv_d, attr, engine);
Note
The order of dimensions does not depend on how elements are laid out in memory. For example:
for a 2D CNN activations tensor the order is always (n, c)
for a 4D CNN activations tensor the order is always (n, c, h, w)
for a 5D CNN weights tensor the order is always (g, oc, ic, kh, kw)
Parameters:
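mask |
Defines the correspondence between the output tensor dimensions and the scales vector. The set i-th bit indicates that a dedicated scaling factor is used for each index along that dimension. Set the mask to 0 to use a common output scaling factor for the whole output tensor. |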
scales |
Constant vector of output scaling factors. If the scaling factors are known at the time of this call, the following equality must hold: \(scales.size() = \prod\limits_{d \in mask} output.dims[d].\) Violations can only be detected when the attributes are used to create a primitive descriptor. If the scaling factors are not known at the time of the call, this vector must contain a single DNNL_RUNTIME_F32_VAL value and the output scaling factors must be passed at execution time as an argument with index DNNL_ARG_ATTR_OUTPUT_SCALES. |
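The size requirement on scales follows directly from the mask: it is the product of the output dimensions whose bits are set. A minimal sketch of that computation is shown below; expected_scales_count is a hypothetical helper for checking a scales vector before it is passed to the attributes, not a oneDNN function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Product of dims[d] over every dimension d whose bit is set in `mask`.
// A mask of 0 yields 1, i.e. a single common scaling factor.
// Hypothetical helper, not oneDNN API.
inline std::int64_t expected_scales_count(
        const std::vector<std::int64_t>& dims, int mask) {
    assert(dims.size() < 31); // every dimension must map to a bit of `mask`
    std::int64_t count = 1;
    for (std::size_t d = 0; d < dims.size(); ++d) {
        if ((mask >> d) & 1) count *= dims[d];
    }
    return count;
}
```

For the (n, c, h, w) = (32, 32, 14, 14) output used in the example, a mask of 1 << 1 requires 32 scaling factors, one per output channel.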
void get_scales(int arg, int& mask, std::vector<float>& scales) const
Returns scaling factors correspondence mask and values for a given memory argument.
Parameters:
arg |
Parameter argument index as passed to the primitive::execute() call. |
mask |
Scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. |
scales |
Output vector of scaling factors. |
void set_scales(int arg, int mask, const std::vector<float>& scales)
Sets scaling factors for primitive operations for a given memory argument.
Parameters:
arg |
Parameter argument index as passed to the primitive::execute() call. |
mask |
Scaling factors correspondence mask that defines the correspondence between the tensor dimensions and the scales vector. |
scales |
Constant vector of scaling factors. The following equality must hold: \(scales.size() = \prod\limits_{d \in mask} argument.dims[d].\) |
See also:
dnnl_primitive_attr_set_scales
dnnl::primitive_attr::set_output_scales
void get_zero_points(int arg, int& mask, std::vector<int32_t>& zero_points) const
Returns zero points correspondence mask and values.
Parameters:
arg |
Parameter argument index as passed to the primitive::execute() call. |
mask |
Zero points correspondence mask that defines the correspondence between the output tensor dimensions and the zero_points vector. |
zero_points |
Output vector of zero points. |
void set_zero_points(int arg, int mask, const std::vector<int32_t>& zero_points)
Sets zero points for primitive operations for a given memory argument.
Parameters:
arg |
Parameter argument index as passed to the primitive::execute() call. |
mask |
Zero points correspondence mask that defines the correspondence between the tensor dimensions and the zero_points vector. |
zero_points |
Constant vector of zero points. If the zero points are known at the time of this call, the following equality must hold: \(zero\_points.size() = \prod\limits_{d \in mask} argument.dims[d].\) If the zero points are not known at the time of the call, this vector must contain a single DNNL_RUNTIME_S32_VAL value and the zero points must be passed at execution time as an argument with index DNNL_ARG_ATTR_ZERO_POINTS. |
See also:
dnnl_primitive_attr_set_zero_points
dnnl::primitive_attr::set_output_scales
const post_ops get_post_ops() const
Returns post-ops previously set via set_post_ops().
Returns:
Post-ops.
void set_post_ops(const post_ops ops)
Sets post-ops.
Note
There is no way to check whether the post-ops would be supported by the target primitive. Any error will be reported by the respective primitive descriptor constructor.
Parameters:
ops |
Post-ops object to copy post-ops from. |
void set_rnn_data_qparams(float scale, float shift)
Sets quantization scale and shift parameters for RNN data tensors.
For performance reasons, the low-precision configuration of the RNN primitives expects input activations to have the unsigned 8-bit integer data type. The scale and shift parameters are used to quantize floating-point data to unsigned integer and must be passed to the RNN primitive using attributes.
The quantization formula is scale * data + shift.
Example usage:
// RNN parameters
int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
// Activations quantization parameters
float scale = 63.f, shift = 64.f;

primitive_attr attr;

// Set scale and shift for int8 quantization of activation
attr.set_rnn_data_qparams(scale, shift);

// Create and configure rnn op_desc
vanilla_rnn_forward::desc rnn_d(/* arguments */);
vanilla_rnn_forward::primitive_desc rnn_pd(rnn_d, attr, engine);
Note
Quantization scale and shift are common for src_layer, src_iter, dst_iter, and dst_layer.
Parameters:
scale |
The value to scale the data by. |
shift |
The value to shift the data by. |
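The quantization formula can be tried out on individual values. The sketch below applies scale * data + shift and converts the result to an unsigned 8-bit integer; rounding to nearest and saturation to the range 0 to 255 are assumptions made for this illustration, and quantize_rnn_data is a hypothetical helper rather than a oneDNN function.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Reference arithmetic only (hypothetical helper, not oneDNN API):
// q = saturate_u8(round(scale * data + shift))
// Assumption: round-to-nearest and saturation to [0, 255].
inline std::uint8_t quantize_rnn_data(float data, float scale, float shift) {
    assert(std::isfinite(data));
    const float q = std::round(scale * data + shift);
    return static_cast<std::uint8_t>(std::min(255.f, std::max(0.f, q)));
}
```

With scale = 63.f and shift = 64.f, as in the example above, activations in the range -1 to 1 map to the range 1 to 127.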
void get_rnn_data_qparams(float& scale, float& shift)
Returns the quantization scale and shift parameters for RNN data tensors.
Note
Quantization scale and shift are common for src_layer, src_iter, dst_iter, and dst_layer.
Parameters:
scale |
The value to scale the data by. |
shift |
The value to shift the data by. |
void set_rnn_weights_qparams(int mask, const std::vector<float>& scales)
Sets quantization scaling factors for RNN weights tensors.
The low-precision configuration of the RNN primitives expects input weights to use the signed 8-bit integer data type. The scaling factors are used to quantize floating-point data to signed integer and must be passed to RNN primitives using attributes.
Note
The dimension order is always native and does not depend on the actual layout used. For example, five-dimensional weights always have (l, d, i, g, o) logical dimension ordering.
Note
Quantization scales are common for weights_layer and weights_iteration
Parameters:
mask |
Scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. |
scales |
Constant vector of output scaling factors. The following equality must hold: \(scales.size() = \prod\limits_{d \in mask} weights.dims[d].\) Violations can only be detected when the attributes are used to create a primitive descriptor. |
void get_rnn_weights_qparams(int& mask, std::vector<float>& scales)
Returns the quantization scaling factors for RNN weights tensors.
Note
The dimension order is always native and does not depend on the actual layout used. For example, five-dimensional weights always have (l, d, i, g, o) logical dimension ordering.
Parameters:
mask |
Scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. |
scales |
Constant vector of output scaling factors. The following equality must hold: \(scales.size() = \prod\limits_{d \in mask} weights.dims[d].\) Violations can only be detected when the attributes are used to create a primitive descriptor. |
void set_rnn_weights_projection_qparams( int mask, const std::vector<float>& scales )
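Sets quantization scaling factors for RNN projection weights tensors.
The scaling factors are used to quantize floating-point projection weights to signed integer and must be passed to RNN primitives using attributes.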
Note
The dimension order is always native and does not depend on the actual layout used. For example, five-dimensional weights always have (l, d, i, g, o) logical dimension ordering.
Note
Quantization scales are common for weights_layer and weights_iteration
Parameters:
mask |
Scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. |
scales |
Constant vector of output scaling factors. The following equality must hold: \(scales.size() = \prod\limits_{d \in mask} weights.dims[d].\) Violations can only be detected when the attributes are used to create a primitive descriptor. |
void get_rnn_weights_projection_qparams(int& mask, std::vector<float>& scales)
Returns the quantization scaling factors for RNN projection weights tensors.
Note
The dimension order is always native and does not depend on the actual layout used. For example, five-dimensional weights always have (l, d, i, g, o) logical dimension ordering.
Parameters:
mask |
Scaling factors correspondence mask that defines the correspondence between the output tensor dimensions and the scales vector. |
scales |
Constant vector of output scaling factors. The following equality must hold: \(scales.size() = \prod\limits_{d \in mask} weights.dims[d].\) Violations can only be detected when the attributes are used to create a primitive descriptor. |