The pooling primitive performs forward or backward max or average pooling on 2D or 3D spatial data.
The pooling operation is defined by the following formulas. The formulas are shown only for 2D spatial data; they generalize straightforwardly to higher and lower dimensions. Variable names follow the standard Naming Conventions.
Max pooling:
\[ dst(n, c, oh, ow) = \max\limits_{kh, kw} \left( src(n, c, oh \cdot SH + kh - ph_0, ow \cdot SW + kw - pw_0) \right) \]
Average pooling:
\[ dst(n, c, oh, ow) = \frac{1}{DENOM} \sum\limits_{kh, kw} src(n, c, oh \cdot SH + kh - ph_0, ow \cdot SW + kw - pw_0) \]
where \(ph_0, pw_0\) are padding_l[0] and padding_l[1], respectively, and the output spatial dimensions are calculated the same way as in convolution.
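As a sketch of the formulas above, the following C code computes one output spatial size using the convolution rule without dilation, \(O = \lfloor (I - K + p_0 + p_1) / S \rfloor + 1\), and applies naive max pooling to a single \((n, c)\) plane. The names `pool_out_dim` and `max_pool_2d` are hypothetical helpers, not oneDNN API. Skipping window taps that fall into the padding (so padding never wins the maximum) is an assumption of this sketch.

```c
#include <float.h>

/* Output size along one spatial dimension, assuming no dilation:
 * O = floor((I - K + pad_l + pad_r) / S) + 1.
 * Assumes I + pad_l + pad_r >= K, so C's truncating division equals floor. */
static int pool_out_dim(int in, int k, int stride, int pad_l, int pad_r) {
    return (in - k + pad_l + pad_r) / stride + 1;
}

/* Naive 2D max pooling over one (n, c) plane stored row-major as IH x IW.
 * SH/SW are the strides and ph0/pw0 the left paddings from the formula. */
static void max_pool_2d(const float *src, int ih, int iw,
                        float *dst, int oh_n, int ow_n,
                        int kh_n, int kw_n, int sh, int sw, int ph0, int pw0) {
    for (int oh = 0; oh < oh_n; ++oh)
        for (int ow = 0; ow < ow_n; ++ow) {
            float m = -FLT_MAX; /* stays -FLT_MAX if every tap is padding */
            for (int kh = 0; kh < kh_n; ++kh)
                for (int kw = 0; kw < kw_n; ++kw) {
                    int h = oh * sh + kh - ph0;
                    int w = ow * sw + kw - pw0;
                    if (h < 0 || h >= ih || w < 0 || w >= iw)
                        continue; /* tap lies in the padding: ignore it */
                    if (src[h * iw + w] > m)
                        m = src[h * iw + w];
                }
            dst[oh * ow_n + ow] = m;
        }
}
```

With a 4x4 input holding 0..15, a 2x2 kernel, and stride 2, `pool_out_dim(4, 2, 2, 0, 0)` is 2 and the pooled plane is {5, 7, 13, 15}.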
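Average pooling supports two algorithms:
- dnnl_pooling_avg_include_padding, in which case \(DENOM = KH \cdot KW\),
- dnnl_pooling_avg_exclude_padding, in which case \(DENOM\) equals the size of the overlap between the averaging window and the image.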
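The two averaging variants differ only in DENOM: either every window tap counts, including taps in the padding (\(DENOM = KH \cdot KW\)), or only the taps that overlap the image count. The hypothetical helper `avg_pool_2d` below is a naive single-plane reference, not the library's implementation; its `include_padding` flag selects between the two denominators.

```c
#include <stdbool.h>

/* Naive 2D average pooling over one (n, c) plane stored row-major as IH x IW.
 * include_padding = true : DENOM = KH * KW (padded taps contribute zero).
 * include_padding = false: DENOM = number of taps that land inside the image. */
static void avg_pool_2d(const float *src, int ih, int iw,
                        float *dst, int oh_n, int ow_n,
                        int kh_n, int kw_n, int sh, int sw, int ph0, int pw0,
                        bool include_padding) {
    for (int oh = 0; oh < oh_n; ++oh)
        for (int ow = 0; ow < ow_n; ++ow) {
            float sum = 0.0f;
            int inside = 0; /* taps overlapping the image */
            for (int kh = 0; kh < kh_n; ++kh)
                for (int kw = 0; kw < kw_n; ++kw) {
                    int h = oh * sh + kh - ph0;
                    int w = ow * sw + kw - pw0;
                    if (h < 0 || h >= ih || w < 0 || w >= iw)
                        continue; /* padded tap adds nothing to the sum */
                    sum += src[h * iw + w];
                    ++inside;
                }
            int denom = include_padding ? kh_n * kw_n : inside;
            dst[oh * ow_n + ow] = denom > 0 ? sum / (float)denom : 0.0f;
        }
}
```

On an all-ones 2x2 input with a 2x2 kernel, stride 1, and one element of padding on each side, the top-left window covers a single image element, so the include-padding variant yields 0.25 there while the exclude-padding variant yields 1.0.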
Max pooling requires a workspace output for the dnnl_forward_training propagation kind and doesn't require it for dnnl_forward_inference.

The backward propagation computes \(diff\_src(n, c, h, w)\) based on \(diff\_dst(n, c, h, w)\) and, in the case of max pooling, the workspace.
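Conceptually, the workspace lets backward max pooling route each gradient to the input position that won the forward maximum. The sketch below makes this concrete with a plain `int` array of argmax indices. In oneDNN the workspace format is opaque, so this array, the helper names `max_pool_fwd_ws` / `max_pool_bwd_ws`, and the square kernel/stride/padding parameters `k`, `s`, `p` are all illustrative assumptions.

```c
#include <float.h>
#include <string.h>

/* Forward max pooling over one row-major IH x IW plane that also records, for
 * every output element, the flat src index of the selected maximum.
 * ws is a hypothetical stand-in for the library's opaque workspace. */
static void max_pool_fwd_ws(const float *src, int ih, int iw,
                            float *dst, int *ws, int oh_n, int ow_n,
                            int k, int s, int p) {
    for (int oh = 0; oh < oh_n; ++oh)
        for (int ow = 0; ow < ow_n; ++ow) {
            float m = -FLT_MAX;
            int arg = -1; /* -1 means the window saw only padding */
            for (int kh = 0; kh < k; ++kh)
                for (int kw = 0; kw < k; ++kw) {
                    int h = oh * s + kh - p, w = ow * s + kw - p;
                    if (h < 0 || h >= ih || w < 0 || w >= iw)
                        continue;
                    if (arg < 0 || src[h * iw + w] > m) {
                        m = src[h * iw + w];
                        arg = h * iw + w;
                    }
                }
            dst[oh * ow_n + ow] = m;
            ws[oh * ow_n + ow] = arg;
        }
}

/* Backward max pooling: each diff_dst value is added to the src position that
 * produced the forward maximum; every other diff_src entry stays zero. */
static void max_pool_bwd_ws(const float *diff_dst, const int *ws, int n_out,
                            float *diff_src, int n_in) {
    memset(diff_src, 0, sizeof(float) * (size_t)n_in);
    for (int o = 0; o < n_out; ++o)
        if (ws[o] >= 0)
            diff_src[ws[o]] += diff_dst[o];
}
```

Forward inference can skip recording these indices because no gradient is propagated afterwards, which matches the training/inference distinction described above.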
A user can use memory format tag dnnl_format_tag_any for the dst memory descriptor when creating pooling forward propagation. The library derives the appropriate format from the src memory descriptor; however, the src memory descriptor itself must be fully defined. Similarly, a user can use memory format tag dnnl_format_tag_any for the diff_src memory descriptor when creating pooling backward propagation.

The pooling primitive supports the following combinations of data types:
Propagation | Source / Destination | Accumulation data type |
---|---|---|
forward / backward | f32, bf16 | f32 |
forward | f16 | f16 |
forward | s8, u8, s32 | s32 |
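The accumulation column matters for integer average pooling: nine s8 values of 127 sum to 1143, which does not fit in s8 but fits easily in s32. The hypothetical helper `avg_s8` below averages s8 values with an s32 accumulator. The round-half-away-from-zero step is an assumption of this sketch, not a statement of the library's rounding mode.

```c
#include <stdint.h>

/* Average n int8 values with an int32 accumulator (n > 0).
 * An int8 accumulator would overflow as soon as the running sum passes 127. */
static int8_t avg_s8(const int8_t *vals, int n) {
    int32_t acc = 0;
    for (int i = 0; i < n; ++i)
        acc += vals[i];
    /* Assumed rounding: half away from zero, done in integer arithmetic. */
    int32_t q = acc >= 0 ? (acc + n / 2) / n : -((-acc + n / 2) / n);
    return (int8_t)q; /* the mean of int8 values always fits back into int8 */
}
```

Only the final quotient is narrowed back to s8; all intermediate sums stay in s32.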
Like other CNN primitives, the pooling primitive expects data to be an \(N \times C \times H \times W\) tensor in the case of 2D spatial data and an \(N \times C \times D \times H \times W\) tensor in the case of 3D spatial data.
The pooling primitive is optimized for the following memory formats:
Spatial | Logical tensor | Data type | Implementations optimized for memory formats |
---|---|---|---|
2D | NCHW | f32 | dnnl_nchw (dnnl_abcd), dnnl_nhwc (dnnl_acdb), optimized^ |
2D | NCHW | s32, s8, u8 | dnnl_nhwc (dnnl_acdb), optimized^ |
3D | NCDHW | f32 | dnnl_ncdhw (dnnl_abcde), dnnl_ndhwc (dnnl_acdeb), optimized^ |
3D | NCDHW | s32, s8, u8 | dnnl_ndhwc (dnnl_acdeb), optimized^ |
Here, optimized^ denotes whatever format is produced by the preceding compute-intensive primitive (for example, a convolution).
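The logical \(N \times C \times H \times W\) indexing used in the formulas is independent of the physical layout; the format tag decides where each element lives in memory. The hypothetical helpers below compute flat element offsets for dense dnnl_nchw (dnnl_abcd) and dnnl_nhwc (dnnl_acdb) buffers. Blocked optimized^ formats are not covered by this sketch.

```c
#include <stddef.h>

/* Flat offset of logical (n, c, h, w) in a dense NCHW buffer
 * (dnnl_nchw / dnnl_abcd): w varies fastest, then h, then c, then n. */
static size_t off_nchw(int n, int c, int h, int w, int C, int H, int W) {
    return (((size_t)n * C + c) * H + h) * W + w;
}

/* Flat offset of the same logical element in a dense NHWC buffer
 * (dnnl_nhwc / dnnl_acdb): channels are innermost, so c varies fastest. */
static size_t off_nhwc(int n, int c, int h, int w, int C, int H, int W) {
    return (((size_t)n * H + h) * W + w) * C + c;
}
```

For example, with C = 3 and H = W = 2, element (0, 1, 0, 0) sits at offset 4 in NCHW but at offset 1 in NHWC, while the last element of a two-image batch, (1, 2, 1, 1), is at offset 23 in both.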
The pooling primitive doesn't support any post-ops or attributes.