教師データがラベルとして与えられた時の損失関数の実装がすぐ理解できなかったのでメモ

「ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装」を勉強しています。



4.2 損失関数 というところで
4.2.4 [バッチ対応版]交差エントロピー誤差の実装
という項があります。
ここではバッチ処理に対応した交差エントロピー誤差を実装するというのが主題ですが、このとき教師データとの差分を使って誤差を出します。
ここで、それまでは one-hot 表現、つまり例えばクラスが5つあり、正解が3のときの教師データは [0, 0, 1, 0, 0] という配列であったのが、ラベルとしての表現、つまり教師データが数値として 3 が与えられる場合も出てきました。
ここが、説明はあるのですが、すぐ理解できませんでした。

出力が y という行列で、教師データが t であった場合の説明です。
ここではバッチ処理を行っていますが、今はそこは気にせずにone-hotとラベルだけに着目するために無視します。

このとき、例えば出力 y はクラスが5つの場合次のような列になります。
[0.1, 0.3, 0.8, 0.2, 0,1]

正解クラスが、1つ目は3で2つ目は2であるとき、教師データはそれぞれ次のようになります。
one-hot
[0, 0, 1, 0, 0]
ラベル
[3]

まずone-hotの場合を見ていきます。
交差エントロピー誤差の式は次のようになります。

-np.sum(t*np.log(y))/batch_size

batch_sizeは今 1です。

-(t*np.log(y))

すると 3 つ目の値以外は 0 を掛けるために 0になってしまい、

[log(0.1), log(0.3), log(0.8), log(0.2), log(0,1)]

のうちで

log (0.8)

だけが出てきて、計算結果は

-log(0.8) になります。



では次にラベル表現の場合ですが、式は次のようになります。

-np.sum(np.log(y[np.arange(batch_size), t])) / batch_size

ここで np.arange(batch_size) は 0 から始まり、 batch_size -1 までの配列ということですので、batch_size が 1 で t が 3 である今は

-np.sum(np.log(y[0, 3]))

(y[0, 3] という表現は行が1 しかない行列で正しい表現なのかちょっと不明ですが、例えば 2行 5列 の場合の1行目だと思ってください)
こうなります。
ここから意味がよくわからなかったのですが、

y[0, 3] とは y の配列のうちで 3列目を取り出すということなので今は 0.8 という数値です。
つまり
-np.sum(np.log(0.8))

これで無事にone-hot と同じ

-log(0.8)

が出てきました。

こう見ると当たり前なのですが、
最初は、取り出したい数値がラベルの数字?でもラベルの数字は1を超えてるから式に入れたらおかしなことにならないか?などと意味不明なことを考えていました。
ラベル表現の場合でも、one-hot と ラベル表現で 出力である y の配列は同じであって、正解ラベルの数値である0.8を取り出したいのだということが頭に入っていなかったので混乱してしまっていた、ということでした。

スポンサーリンク