struct dnnl::lstm_forward::desc

Overview

Descriptor for an LSTM forward propagation primitive.

#include <dnnl.hpp>

struct desc
{
    // fields

    dnnl_rnn_desc_t data;

    // construction

    desc(
        prop_kind aprop_kind,
        rnn_direction direction,
        const memory::desc& src_layer_desc,
        const memory::desc& src_iter_desc,
        const memory::desc& src_iter_c_desc,
        const memory::desc& weights_layer_desc,
        const memory::desc& weights_iter_desc,
        const memory::desc& weights_peephole_desc,
        const memory::desc& weights_projection_desc,
        const memory::desc& bias_desc,
        const memory::desc& dst_layer_desc,
        const memory::desc& dst_iter_desc,
        const memory::desc& dst_iter_c_desc,
        rnn_flags flags = rnn_flags::undef
        );

    desc(
        prop_kind aprop_kind,
        rnn_direction direction,
        const memory::desc& src_layer_desc,
        const memory::desc& src_iter_desc,
        const memory::desc& src_iter_c_desc,
        const memory::desc& weights_layer_desc,
        const memory::desc& weights_iter_desc,
        const memory::desc& weights_peephole_desc,
        const memory::desc& bias_desc,
        const memory::desc& dst_layer_desc,
        const memory::desc& dst_iter_desc,
        const memory::desc& dst_iter_c_desc,
        rnn_flags flags = rnn_flags::undef
        );

    desc(
        prop_kind aprop_kind,
        rnn_direction direction,
        const memory::desc& src_layer_desc,
        const memory::desc& src_iter_desc,
        const memory::desc& src_iter_c_desc,
        const memory::desc& weights_layer_desc,
        const memory::desc& weights_iter_desc,
        const memory::desc& bias_desc,
        const memory::desc& dst_layer_desc,
        const memory::desc& dst_iter_desc,
        const memory::desc& dst_iter_c_desc,
        rnn_flags flags = rnn_flags::undef
        );
};

Detailed Documentation

Descriptor for an LSTM forward propagation primitive.

Construction

desc(
    prop_kind aprop_kind,
    rnn_direction direction,
    const memory::desc& src_layer_desc,
    const memory::desc& src_iter_desc,
    const memory::desc& src_iter_c_desc,
    const memory::desc& weights_layer_desc,
    const memory::desc& weights_iter_desc,
    const memory::desc& weights_peephole_desc,
    const memory::desc& weights_projection_desc,
    const memory::desc& bias_desc,
    const memory::desc& dst_layer_desc,
    const memory::desc& dst_iter_desc,
    const memory::desc& dst_iter_c_desc,
    rnn_flags flags = rnn_flags::undef
    )

Constructs a descriptor for an LSTM (with or without peephole and with or without projection) forward propagation primitive.

The following arguments may point to a zero memory descriptor:

  • src_iter_desc together with src_iter_c_desc,

  • weights_peephole_desc,

  • bias_desc,

  • dst_iter_desc together with dst_iter_c_desc.

A zero memory descriptor indicates that the LSTM forward propagation primitive should not use the corresponding argument. Missing inputs (initial states and bias) default to zero values, and missing output states are not written.
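To make the zero-descriptor convention concrete, here is a minimal scalar reference for one LSTM time step (batch size 1, one layer, one direction). It is a sketch, not oneDNN's implementation. The gate order (i, f, c̃, o), the row-major weight layout, and the names `lstm_step_ref` and `LstmState` are assumptions for illustration. An empty vector stands in for an argument passed as a zero memory descriptor and is read as all zeros. An empty peephole vector means "no peephole".

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result of one reference LSTM step: hidden state h_t and cell state c_t.
struct LstmState {
    std::vector<float> h;
    std::vector<float> c;
};

static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Reads v[i], treating an empty vector as "not provided" (all zeros).
static float at_or_zero(const std::vector<float>& v, std::size_t i) {
    return v.empty() ? 0.0f : v[i];
}

// One LSTM time step for a single batch item (illustrative layout, assumed):
// x      : input, SLC values
// h_prev : previous hidden state, DHC values (empty = zero memory descriptor)
// c_prev : previous cell state, DHC values (empty = zero memory descriptor)
// w_x    : layer weights, 4*DHC rows x SLC columns, row-major, gates i, f, c~, o
// w_h    : recurrent weights, 4*DHC rows x DHC columns, row-major, same gate order
// bias   : 4*DHC values (empty = zero memory descriptor)
// peep   : 3*DHC values for gates i, f, o (empty = no peephole)
LstmState lstm_step_ref(const std::vector<float>& x,
                        const std::vector<float>& h_prev,
                        const std::vector<float>& c_prev,
                        const std::vector<float>& w_x,
                        const std::vector<float>& w_h,
                        const std::vector<float>& bias,
                        const std::vector<float>& peep,
                        std::size_t dhc) {
    const std::size_t slc = x.size();
    // Gate pre-activations: W_x * x + W_h * h_prev + b.
    std::vector<float> g(4 * dhc, 0.0f);
    for (std::size_t r = 0; r < 4 * dhc; ++r) {
        float acc = at_or_zero(bias, r);
        for (std::size_t k = 0; k < slc; ++k) acc += w_x[r * slc + k] * x[k];
        for (std::size_t k = 0; k < dhc; ++k) acc += w_h[r * dhc + k] * at_or_zero(h_prev, k);
        g[r] = acc;
    }
    LstmState out{std::vector<float>(dhc), std::vector<float>(dhc)};
    for (std::size_t j = 0; j < dhc; ++j) {
        const float c_old = at_or_zero(c_prev, j);
        // Input and forget gate peephole terms use the previous cell state c_{t-1}.
        const float i = sigmoid(g[0 * dhc + j] + at_or_zero(peep, 0 * dhc + j) * c_old);
        const float f = sigmoid(g[1 * dhc + j] + at_or_zero(peep, 1 * dhc + j) * c_old);
        const float c_hat = std::tanh(g[2 * dhc + j]);
        out.c[j] = f * c_old + i * c_hat;
        // The output gate peephole term uses the new cell state c_t.
        const float o = sigmoid(g[3 * dhc + j] + at_or_zero(peep, 2 * dhc + j) * out.c[j]);
        out.h[j] = o * std::tanh(out.c[j]);
    }
    return out;
}
```

Passing explicit zero vectors for h_prev, c_prev, bias, and peep gives the same result as passing empty ones. That is the behavior the zero memory descriptor is meant to express.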

The weights_projection_desc may also point to a zero memory descriptor, which indicates that the LSTM has no recurrent projection layer.
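With a projection layer, the cell's hidden output h_raw = o_t ⊙ tanh(c_t) (DHC values) is multiplied by the projection weights to produce a hidden state with DIC values. The cell state keeps DHC values. The sketch below shows only that matrix-vector step. It assumes a row-major DIC x DHC weight matrix and uses the hypothetical name `project_hidden`. An empty weight vector models a zero weights_projection_desc, in which case the hidden state passes through unchanged.

```cpp
#include <cstddef>
#include <vector>

// Applies a recurrent projection h_t = R * h_raw.
// r     : projection weights, dic rows x h_raw.size() columns, row-major (assumed layout)
// h_raw : o_t * tanh(c_t), DHC values
// dic   : number of projected channels
// An empty r models "no projection layer": h_raw is returned as-is and dic is ignored.
std::vector<float> project_hidden(const std::vector<float>& r,
                                  const std::vector<float>& h_raw,
                                  std::size_t dic) {
    if (r.empty()) return h_raw;
    const std::size_t dhc = h_raw.size();
    std::vector<float> h(dic, 0.0f);
    for (std::size_t row = 0; row < dic; ++row)
        for (std::size_t k = 0; k < dhc; ++k)
            h[row] += r[row * dhc + k] * h_raw[k];
    return h;
}
```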

Note

All memory descriptors can be initialized with dnnl::memory::format_tag::any as the format tag.

Parameters:

aprop_kind

Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

direction

RNN direction. See dnnl::rnn_direction for more info.

src_layer_desc

Memory descriptor for the input vector.

src_iter_desc

Memory descriptor for the input recurrent hidden state vector.

src_iter_c_desc

Memory descriptor for the input recurrent cell state vector.

weights_layer_desc

Memory descriptor for the weights applied to the layer input.

weights_iter_desc

Memory descriptor for the weights applied to the recurrent input.

weights_peephole_desc

Memory descriptor for the weights applied to the cell states (according to the Peephole LSTM formula).

weights_projection_desc

Memory descriptor for the weights applied to the hidden states to get the recurrent projection (according to the Projection LSTM formula).

bias_desc

Bias memory descriptor.

dst_layer_desc

Memory descriptor for the output vector.

dst_iter_desc

Memory descriptor for the output recurrent hidden state vector.

dst_iter_c_desc

Memory descriptor for the output recurrent cell state vector.

flags

Unused.
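The direction parameter is only described by reference to dnnl::rnn_direction. For the two bidirectional values, the primitive runs a left-to-right pass and a right-to-left pass and combines their per-time-step outputs in dst_layer. bidirectional_concat concatenates them along the channel dimension, and bidirectional_sum adds them element-wise. The sketch shows only that combination step for one time step and one batch item. `BiMode` and `combine_directions` are hypothetical stand-ins, not dnnl types. Placing the left-to-right channels first in the concatenated output is an assumption.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the two bidirectional values of dnnl::rnn_direction.
enum class BiMode { concat, sum };

// Combines the outputs of the left-to-right (l2r) and right-to-left (r2l) passes.
std::vector<float> combine_directions(const std::vector<float>& l2r,
                                      const std::vector<float>& r2l,
                                      BiMode mode) {
    if (l2r.size() != r2l.size())
        throw std::invalid_argument("direction outputs must have equal size");
    std::vector<float> out;
    if (mode == BiMode::concat) {
        // Channel concatenation: 2*C values (l2r first, assumed ordering).
        out = l2r;
        out.insert(out.end(), r2l.begin(), r2l.end());
    } else {
        // Element-wise sum: C values.
        out.resize(l2r.size());
        for (std::size_t j = 0; j < l2r.size(); ++j) out[j] = l2r[j] + r2l[j];
    }
    return out;
}
```

The concatenating mode doubles the channel count of dst_layer. The summing mode keeps it equal to the per-direction output size.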

desc(
    prop_kind aprop_kind,
    rnn_direction direction,
    const memory::desc& src_layer_desc,
    const memory::desc& src_iter_desc,
    const memory::desc& src_iter_c_desc,
    const memory::desc& weights_layer_desc,
    const memory::desc& weights_iter_desc,
    const memory::desc& weights_peephole_desc,
    const memory::desc& bias_desc,
    const memory::desc& dst_layer_desc,
    const memory::desc& dst_iter_desc,
    const memory::desc& dst_iter_c_desc,
    rnn_flags flags = rnn_flags::undef
    )

Constructs a descriptor for an LSTM (with or without peephole) forward propagation primitive.

The following arguments may point to a zero memory descriptor:

  • src_iter_desc together with src_iter_c_desc,

  • weights_peephole_desc,

  • bias_desc,

  • dst_iter_desc together with dst_iter_c_desc.

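A zero memory descriptor indicates that the LSTM forward propagation primitive should not use the corresponding argument. Missing inputs (initial states, peephole weights, and bias) default to zero values, and missing output states are not written.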
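The weights_peephole_desc description refers to the Peephole LSTM formula without stating it. Below is the usual peephole formulation, written as a sketch. It assumes one element-wise peephole vector per gate for the input, forget, and output gates, consistent with peephole weights covering three gates. Here W, U, P, and b come from weights_layer, weights_iter, weights_peephole, and bias. x_t is taken from src_layer, and h_{t-1}, c_{t-1} are initialized from src_iter and src_iter_c. \odot denotes element-wise multiplication.

```latex
\begin{aligned}
i_t &= \sigma\left(W_i x_t + U_i h_{t-1} + P_i \odot c_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_f x_t + U_f h_{t-1} + P_f \odot c_{t-1} + b_f\right) \\
\tilde{c}_t &= \tanh\left(W_c x_t + U_c h_{t-1} + b_c\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
o_t &= \sigma\left(W_o x_t + U_o h_{t-1} + P_o \odot c_t + b_o\right) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Setting P_i = P_f = P_o = 0, which is what a zero weights_peephole_desc implies, recovers the standard LSTM cell.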

Note

All memory descriptors can be initialized with dnnl::memory::format_tag::any as the format tag.

Parameters:

aprop_kind

Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

direction

RNN direction. See dnnl::rnn_direction for more info.

src_layer_desc

Memory descriptor for the input vector.

src_iter_desc

Memory descriptor for the input recurrent hidden state vector.

src_iter_c_desc

Memory descriptor for the input recurrent cell state vector.

weights_layer_desc

Memory descriptor for the weights applied to the layer input.

weights_iter_desc

Memory descriptor for the weights applied to the recurrent input.

weights_peephole_desc

Memory descriptor for the weights applied to the cell states (according to the Peephole LSTM formula).

bias_desc

Bias memory descriptor.

dst_layer_desc

Memory descriptor for the output vector.

dst_iter_desc

Memory descriptor for the output recurrent hidden state vector.

dst_iter_c_desc

Memory descriptor for the output recurrent cell state vector.

flags

Unused.

desc(
    prop_kind aprop_kind,
    rnn_direction direction,
    const memory::desc& src_layer_desc,
    const memory::desc& src_iter_desc,
    const memory::desc& src_iter_c_desc,
    const memory::desc& weights_layer_desc,
    const memory::desc& weights_iter_desc,
    const memory::desc& bias_desc,
    const memory::desc& dst_layer_desc,
    const memory::desc& dst_iter_desc,
    const memory::desc& dst_iter_c_desc,
    rnn_flags flags = rnn_flags::undef
    )

Constructs a descriptor for an LSTM forward propagation primitive.

The following arguments may point to a zero memory descriptor:

  • src_iter_desc together with src_iter_c_desc,

  • bias_desc,

  • dst_iter_desc together with dst_iter_c_desc.

A zero memory descriptor indicates that the LSTM forward propagation primitive should not use the corresponding argument. Missing inputs (initial states and bias) default to zero values, and missing output states are not written.

Note

All memory descriptors can be initialized with dnnl::memory::format_tag::any as the format tag.

Parameters:

aprop_kind

Propagation kind. Possible values are dnnl::prop_kind::forward_training and dnnl::prop_kind::forward_inference.

direction

RNN direction. See dnnl::rnn_direction for more info.

src_layer_desc

Memory descriptor for the input vector.

src_iter_desc

Memory descriptor for the input recurrent hidden state vector.

src_iter_c_desc

Memory descriptor for the input recurrent cell state vector.

weights_layer_desc

Memory descriptor for the weights applied to the layer input.

weights_iter_desc

Memory descriptor for the weights applied to the recurrent input.

bias_desc

Bias memory descriptor.

dst_layer_desc

Memory descriptor for the output vector.

dst_iter_desc

Memory descriptor for the output recurrent hidden state vector.

dst_iter_c_desc

Memory descriptor for the output recurrent cell state vector.

flags

Unused.
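Taken together, the parameters of this overload imply a set of logical shapes. The sketch computes the dimensions one would typically give each memory descriptor. It assumes oneDNN's usual RNN conventions: tnc for src_layer and dst_layer, ldnc for the states, ldigo for weights, and ldgo for bias, with 4 gates for LSTM. It also assumes no projection, so every state has DHC channels. `LstmDims` and `lstm_dims` are hypothetical helpers, not part of dnnl.hpp.

```cpp
#include <cstdint>
#include <vector>

using dims = std::vector<std::int64_t>;

// Hypothetical container for the logical dimensions of a plain LSTM's memory descriptors.
struct LstmDims {
    dims src_layer, src_iter, src_iter_c;
    dims weights_layer, weights_iter, bias;
    dims dst_layer, dst_iter, dst_iter_c;
};

// T: time steps, N: batch, L: layers, D: directions (1 or 2),
// SLC: input channels, DHC: hidden channels.
// concat_directions selects bidirectional_concat output width when D == 2.
LstmDims lstm_dims(std::int64_t T, std::int64_t N, std::int64_t L, std::int64_t D,
                   std::int64_t SLC, std::int64_t DHC, bool concat_directions) {
    const std::int64_t G = 4;  // LSTM gates: i, f, c~, o
    const std::int64_t out_c = (D == 2 && concat_directions) ? 2 * DHC : DHC;
    LstmDims d;
    d.src_layer     = {T, N, SLC};          // tnc
    d.src_iter      = {L, D, N, DHC};       // ldnc (SIC == DHC without projection)
    d.src_iter_c    = {L, D, N, DHC};       // ldnc
    d.weights_layer = {L, D, SLC, G, DHC};  // ldigo
    d.weights_iter  = {L, D, DHC, G, DHC};  // ldigo
    d.bias          = {L, D, G, DHC};       // ldgo
    d.dst_layer     = {T, N, out_c};        // tnc
    d.dst_iter      = {L, D, N, DHC};       // ldnc
    d.dst_iter_c    = {L, D, N, DHC};       // ldnc
    return d;
}
```

Each dims vector would then be paired with a data type and a format tag, for example dnnl::memory::format_tag::any for the weights as the Note permits, to build the corresponding memory::desc.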