chainer.functions.lstm(c_prev, x)[source]

活性関数としてのLong Short-Term Memory units。

 

この関数はForget Gate を伴うLSTM units を実装しています。以前のセル・ステート(メモリセル内の状態)を c_prev とし、入力配列を xとします。

 

はじめに、入力配列 x が第2軸に添って、同じShapeを持つ4つの配列a,i,f,oにわけられます。つまり、  x の第2軸は c_prev の第2軸の4倍でなければなりません。

 

分割された入力配列は下記に相当します。

  • aa : sources of cell input
  • ii : sources of input gate
  • ff : sources of forget gate
  • oo : sources of output gate

次に、更新されたセル状態 c と出力シグナルを h とします。:

ch=tanh(a)σ(i)+cprevσ(f),=tanh(c)σ(o),c=tanh⁡(a)σ(i)+cprevσ(f),h=tanh⁡(c)σ(o),

 

 

σ は cより小さい場合、この関数は c[0:len(x)] を更新するのみで、 cや c[len(x):]の残りは変更しません。 ですから、入力シーケンスこの関数へ適用する前に降順にソートしてください。

 

Parameters:
  • c_prev (Variable or numpy.ndarray or cupy.ndarray) – 以前のセルの状態を保持するVariables。セル状態は0配列か. LSTMの以前の呼び出しの出力でなければならない。
  • x (Variable or numpy.ndarray or cupy.ndarray) – cell input、input gate、 forget gate 、 output gateのソースを保持するVariables。セル状態の4倍のサイズの次元数の第2軸を持たなければならない。
Returns:

2つの Variable オブジェクト、 c と hc は更新されたセル状態。 h は出力シグナルを示す。

Return type:

tuple

 

Forget Gateを伴う LSTM を提案した元論文もお読みください。: Long Short-Term Memory in Recurrent Neural Networks.

See also

LSTM

 

Example

 

y がカレントの入力シグナル、 c は以前のセル状態、 h が以前の lstm からの出力シグナルであると仮定します。各 y、 c 、 h は n_units チャンネルを持ちます。最も一般的な x は:

 


>>>
n_units = 100
>>> y = chainer.Variable(np.zeros((1, n_units), 'f'))
>>> h = chainer.Variable(np.zeros((1, n_units), 'f'))
>>> c = chainer.Variable(np.zeros((1, n_units), 'f'))
>>> model = chainer.Chain()
>>> with model.init_scope():
... model.w = L.Linear(n_units, 4 * n_units)
... model.v = L.Linear(n_units, 4 * n_units)
>>> x = model.w(y) + model.v(h)
>>> c, h = F.lstm(c, x)

 

これは、入力配列 x、もしくはカレント入力シグナル y と以前の出力シグナル hからの入力ソースa,i,f,o計算することに相応します。異なる入力ソースへは異なるパラメータが用いられます。 

Note

下記の命名規則を使用しています。

  • incoming signal /入力シグナル
     LSTM の公式の正式な入力(たとえばNLPにおけるword vector もしくは 低次のRNN レイヤーの出力).  chainer.links.LSTM の入力はこの incoming signal.
    input array /入力配列
    incoming signal や以前の出力シグナルから線形変換された配列。この input arrayは4つのソースcell input、 input gate、forget gate 、 output gateを含む。chainer.functions.LSTM の入力はこの input array。