Descriptor for an LSTM forward propagation primitive.
#include <dnnl.hpp>
Public Member Functions
desc (prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM (with or without peephole) forward propagation primitive.
desc (prop_kind prop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM forward propagation primitive.
inline
Constructs a descriptor for an LSTM (with or without peephole) forward propagation primitive.
The src_iter_desc, src_iter_c_desc, weights_peephole_desc, bias_desc, dst_iter_desc, and dst_iter_c_desc may point to a zero memory descriptor. This indicates that the LSTM forward propagation primitive should not use them and should default to zero values instead.
src_iter_desc can be initialized with a format_tag value of dnnl::memory::format_tag::any.
Inputs:
Outputs:
prop_kind equals dnnl::prop_kind::forward_training; must be queried for using dnnl::primitive_desc_base::query_md() after a corresponding primitive descriptor is created
prop_kind | Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference. |
direction | RNN direction. See dnnl::rnn_direction for more info. |
src_layer_desc | Memory descriptor for the input vector. |
src_iter_desc | Memory descriptor for the input recurrent hidden state vector. |
src_iter_c_desc | Memory descriptor for the input recurrent cell state vector. |
weights_layer_desc | Memory descriptor for the weights applied to the layer input. |
weights_iter_desc | Memory descriptor for the weights applied to the recurrent input. |
weights_peephole_desc | Memory descriptor for the weights applied to the cell states (according to the Peephole LSTM formula). |
bias_desc | Bias memory descriptor. |
dst_layer_desc | Memory descriptor for the output vector. |
dst_iter_desc | Memory descriptor for the output recurrent hidden state vector. |
dst_iter_c_desc | Memory descriptor for the output recurrent cell state vector. |
flags | Unused. |
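The parameters above each describe a tensor whose dimensions follow from a handful of sizes. The sketch below lists those dimensions in plain C++ without calling the library. It assumes the usual RNN dimension conventions (T time steps, N batch size, C input channels, HC hidden channels, L layers, D directions, four LSTM gates and three peephole gates) and assumes that directions are not concatenated in the output. The names lstm_shapes and make_lstm_shapes are hypothetical helpers, not part of dnnl.hpp.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using dims = std::vector<std::int64_t>;

// Hypothetical container for the dimensions of each constructor argument.
struct lstm_shapes {
    dims src_layer, src_iter, src_iter_c;
    dims weights_layer, weights_iter, weights_peephole, bias;
    dims dst_layer, dst_iter, dst_iter_c;
};

// Assumed conventions: T time steps, N batch, C input channels,
// HC hidden channels, L layers, D directions; directions not concatenated.
lstm_shapes make_lstm_shapes(std::int64_t T, std::int64_t N, std::int64_t C,
        std::int64_t HC, std::int64_t L = 1, std::int64_t D = 1) {
    assert(T > 0 && N > 0 && C > 0 && HC > 0 && L > 0 && D > 0);
    lstm_shapes s;
    s.src_layer = {T, N, C};              // input vector per time step
    s.src_iter = {L, D, N, HC};           // initial hidden state
    s.src_iter_c = {L, D, N, HC};         // initial cell state
    s.weights_layer = {L, D, C, 4, HC};   // 4 gates applied to layer input
    s.weights_iter = {L, D, HC, 4, HC};   // 4 gates applied to recurrent input
    s.weights_peephole = {L, D, 3, HC};   // 3 peephole gates on cell state
    s.bias = {L, D, 4, HC};               // one bias per gate and channel
    s.dst_layer = {T, N, HC};             // output vector per time step
    s.dst_iter = {L, D, N, HC};           // final hidden state
    s.dst_iter_c = {L, D, N, HC};         // final cell state
    return s;
}
```

Each dims value would then be passed to a memory::desc together with a data type and a format tag. Arguments that are not needed can be left as default-constructed (zero) memory descriptors instead.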
inline
Constructs a descriptor for an LSTM forward propagation primitive.
The src_iter_desc, src_iter_c_desc, bias_desc, dst_iter_desc, and dst_iter_c_desc may point to a zero memory descriptor. This indicates that the LSTM forward propagation primitive should not use them and should default to zero values instead.
src_iter_desc can be initialized with a format_tag value of dnnl::memory::format_tag::any.
Inputs:
Outputs:
prop_kind equals dnnl::prop_kind::forward_training; must be queried for using dnnl::primitive_desc_base::query_md() after a corresponding primitive descriptor is created
prop_kind | Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference. |
direction | RNN direction. See dnnl::rnn_direction for more info. |
src_layer_desc | Memory descriptor for the input vector. |
src_iter_desc | Memory descriptor for the input recurrent hidden state vector. |
src_iter_c_desc | Memory descriptor for the input recurrent cell state vector. |
weights_layer_desc | Memory descriptor for the weights applied to the layer input. |
weights_iter_desc | Memory descriptor for the weights applied to the recurrent input. |
bias_desc | Bias memory descriptor. |
dst_layer_desc | Memory descriptor for the output vector. |
dst_iter_desc | Memory descriptor for the output recurrent hidden state vector. |
dst_iter_c_desc | Memory descriptor for the output recurrent cell state vector. |
flags | Unused. |
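To illustrate what "default to zero values" means for the optional arguments, here is a minimal sketch of a single LSTM time step (no peephole) for one sample in plain C++. It is not the library's implementation. The function lstm_step, the flat weight layout, and the gate order (input, forget, candidate, output) are assumptions made for this example. An omitted bias, hidden state, or cell state is treated as all zeros, mirroring the effect of passing a zero memory descriptor.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <optional>
#include <vector>

using vec = std::vector<float>;

struct lstm_state {
    vec h; // hidden state
    vec c; // cell state
};

inline float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Hypothetical single-step LSTM. Assumed layouts:
//   w_layer[(g * hc + o) * c_in + i], w_iter[(g * hc + o) * hc + j],
//   bias[g * hc + o], with gates g = 0..3 in the order i, f, c~, o.
lstm_state lstm_step(const vec &x, std::size_t hc, const vec &w_layer,
        const vec &w_iter, const std::optional<vec> &bias = std::nullopt,
        const std::optional<vec> &h_prev = std::nullopt,
        const std::optional<vec> &c_prev = std::nullopt) {
    const std::size_t c_in = x.size();
    assert(w_layer.size() == 4 * hc * c_in);
    assert(w_iter.size() == 4 * hc * hc);

    // Omitted optional inputs fall back to zeros.
    const vec zero_state(hc, 0.0f);
    const vec zero_bias(4 * hc, 0.0f);
    const vec &h0 = h_prev ? *h_prev : zero_state;
    const vec &c0 = c_prev ? *c_prev : zero_state;
    const vec &b = bias ? *bias : zero_bias;
    assert(h0.size() == hc && c0.size() == hc && b.size() == 4 * hc);

    lstm_state out {vec(hc), vec(hc)};
    for (std::size_t o = 0; o < hc; ++o) {
        float pre[4];
        for (std::size_t g = 0; g < 4; ++g) {
            float acc = b[g * hc + o];
            for (std::size_t i = 0; i < c_in; ++i)
                acc += w_layer[(g * hc + o) * c_in + i] * x[i];
            for (std::size_t j = 0; j < hc; ++j)
                acc += w_iter[(g * hc + o) * hc + j] * h0[j];
            pre[g] = acc;
        }
        const float in_gate = sigmoid(pre[0]);
        const float forget_gate = sigmoid(pre[1]);
        const float candidate = std::tanh(pre[2]);
        const float out_gate = sigmoid(pre[3]);
        out.c[o] = forget_gate * c0[o] + in_gate * candidate;
        out.h[o] = out_gate * std::tanh(out.c[o]);
    }
    return out;
}
```

Calling lstm_step(x, hc, w_layer, w_iter) gives the same result as passing explicit all-zero vectors for the bias and both states, which is the behavior the zero memory descriptor paragraph describes.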