chainer.functions.crf1d(cost, xs, ys, reduce='mean')[source]

linear-chain CRFの負の対数尤度を計算する。

 

この関数は遷移コスト行列、コストのシーケンス、そしてラベルのシーケンスをとります。

  \(c_{st}\) をラベル\(s\) からラベル\(t\)への遷移コストとし、\(x_{it}\) が位置が\(i\)のときのラベル \(t\) 、 a \(y_i\) が位置\(i\)での予測ラベルとします。

linear-chain CRFの負の対数尤度は下記のように定義されます。

\[L = -\left( \sum_{i=1}^l x_{iy_i} + \ \sum_{i=1}^{l-1} c_{y_i y_{i+1}} - {\log(Z)} \right) ,\]

ただし、 \(l\) は入力シーケンスの長さ、\(Z\)は分配関数と呼ばれる規格化定数

Note

異なる長さをもつシーケンスの負の対数尤度を計算したいとき、シーケンスを長さによって降順にソートして転置します。

 

たとえば、3つの入力シーケンスがあるとします。


>>> a1 = a2 = a3 = a4 = np.random.uniform(-1, 1, 3).astype('f')
>>> b1 = b2 = b3 = np.random.uniform(-1, 1, 3).astype('f')
>>> c1 = c2 = np.random.uniform(-1, 1, 3).astype('f')

>>> a = [a1, a2, a3, a4]
>>> b = [b1, b2, b3]
>>> c = [c1, c2]

 

ただし、 a1 と他のvariables全てが(K,) shapeであるような配列の場合。

 

シーケンスを転置してください。


>>> x1 = np.stack([a1, b1, c1])
>>> x2 = np.stack([a2, b2, c2])
>>> x3 = np.stack([a3, b3])
>>> x4 = np.stack([a4])

 

そして、配列のリストを作成してください。

>>> xs = [x1, x2, x3, x4]

 

同様に、ラベルシーケンスを作成する必要があります。

そして、それから、この関数を呼んでください。

>>> cost = chainer.Variable(

...               np.random.uniform(-1, 1, (3, 3)).astype('f'))
>>> ys = [np.zeros(x.shape[0:1], dtype='i') for x in xs]
>>> loss = F.crf1d(cost, xs, ys)

 

3つのシーケンスの負の対数尤度の平均値を計算します。

この出力は reduce の値に依存する値となります。 'no'が設定されている場合、エレメント毎のロス値を保持します。 'mean'が設定されている場合、ロス値の平均を保持します。

 

Parameters:
  • cost (Variable) –  \(K \times K\) 行列。 2つのラベル間の遷移コストを保持している。ただし、\(K\)はラベル数。
  • xs (list of Variable) –各ラベルの入力ベクトル。 len(xs) は、シーケンス長を表す。各 Variable\(B \times K\) 行列を保持する。この \(B\) はミニバッチサイズ。\(K\) はラベル数。 全ての変数中の\(B\)が同じである必要はない。つまり異なる長さのシーケンスも受け入れる。
  • ys (list of Variable) – 予測出力ラベル。xsと同じ長さである必要がある。 各 Variable\(B\)整数ベクトルを保持している。 xs において x が異なる\(B\)を持つ時、対応する y は同じ \(B\)を持つ。つまり ysys[i].shape == xs[i].shape[0:1] を、全てのiについて満たさなければならない。
  • reduce (str) – 削減オプション。 'mean''no'のいずれかでなければならない。それ以外では ValueError が発生する。
Returns:

入力シーケンスの負の平均対数尤度を保持しているVariable。

Return type:

Variable

Note

元論文もお読みください。: Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data.