CNNの損失関数(Loss Function) その(1): 交差エントロピーとMSE

1 概要

この記事では,畳み込みニューラルネットワーク(CNN)の教師あり学習に用いられる損失関数(Loss function)のうち,交差エントロピー誤差 (Cross Entropy Error)(識別問題向け)と 平均二乗誤差(MSE: Mean Squared Error)/ノルム損失(回帰問題向け)についてまとめる.また,各損失関数と組み合わせて用いる「 出力層」についても整理する.

この記事は,CNN向け損失関数をまとめたシリーズ記事のうち,識別問題と回帰問題における損失関数について取り上げる「 その(1) 基本編 」である.

読者が,各問題向けのニューラルネット損失関数を整理して復習し,基礎的な考え方が定着しやすくなるように,各記事を執筆していく.

1.1 記事の構成

以下の構成で,教師あり学習の基本問題である識別CNN・回帰CNN向けの「損失関数 (2節)」と,それらに組み合わせる「出力層(3節)」の種類について整理する.

  • 2節:識別CNN・回帰CNN向けの基本的な損失関数
  • 3節:(損失関数,出力層)の基本的な組み合わせ
  • 4節:まとめ

2 識別CNN・回帰CNN向けの基本的な損失関数

2.1 準備

まず問題設定を先に定義しておくことで,2.2節以降の準備を行う.

以下の,$i$番目のサンプルとターゲット$(\bm{x}^{(i)},\bm{y}^{(i)})$のペア$N$個から構成される画像認識データセットを用いて,識別CNNもしくは回帰CNNを学習することを考える:

  • $\bm{x}^{(i)}$:CNNに入力される,3チャンネルの画像サンプル.
  • $\bm{y}^{(i)} = (y_1^{(i)},y_2^{(i)},\ldots,y_K^{(i)})$:CNNの出力として学習させたい正解,ターゲットのone-hotベクトル(識別) or ターゲットのベクトル(回帰).
  • $\hat{\bm{y}}^{(i)} = (\hat{y}_1^{(i)},\hat{y}_2^{(i)},\ldots,\hat{y}_K^{(i)})$:学習中・学習後のCNNの出力層が,順伝搬の結果出力する$K$次元ベクトルの推定値.識別CNNの場合はSoftmax層を出力層に用いて,(マルチヌーイ分布として)確率ベクトル化を行うことが多い.

ここで,上記の各変数は,条件(A)~(D)において以下の通りになる:

(1) 識別問題 (A), (B):

  • $K$:識別したいクラス数.
  • $\bm{y}^{(i)}$:$K$次元のone-hotベクトル.正解のクラス$k$の要素$y_k^{(i)}$だけは値が1であり,他の要素は全て値が0である.
  • $\hat{\bm{y}}$:各$k$次元目の要素$\hat{y}_k^{(i)}$は,$k$番目のクラスの出力確率 ($0 \geq \hat{y}_k^{(i)} \geq 1$)に相当し,出力層は$K$個のニューロンを持つ.

(2) 回帰問題 (C), (D):

  • $K$:回帰対象$\bm{y}^{(i)} $の,ベクトルの次元数.
  • $\bm{y}^{(i)} $,$\hat{\bm{y}}^{(i)} $:(1)のような確率的な出力ではない (ただし,学習時の計算の都合でsoftmaxで回帰するような問題設定もある)

2.1.1 交差エントロピーを復習

交差エントロピー(CE: Cross Entropy)は,元の情報学での定義では,(離散確率分布)$p$と$q$の距離測度(measure)として,以下のように定義される:

\begin{equation} H(p,q) = – \sum_{x \in \mathcal{X}} p(x) \log{q(x)}\end{equation}

交差エントロピー$H(p,q)$は,$p$と$q$の間で取ったエントロピー(平均情報量)である.そして,それは『$p$と$q$の2分布が離れている度合い』の尺度を表す.$p$と$q$が同じ確率分布(の形状)であるほど交差エントロピー値はゼロに近づくので,予測値が確率値であるときの,機械学習モデルの誤差関数として用いることが多い(例:ロジスティック回帰や,識別ニューラルネットのパーセプトロンなど).

2.2 各問題で用いるCNNの損失関数

以降では,各損失関数$\mathcal{L}$は,SGD時のミニバッチ内における$N$個の「(ターゲット,サンプル)ペア間の誤差」の期待値を表現しているとする.実際は,CNNの学習では経験的リスク(empirical risk)を最小化するので,損失関数は$\mathcal{J}(\theta) $と表記してもよいが,ここでは便宜がよいのでひとまず損失関数 $\mathcal{L}(\cdot)$と表記したい.(経験的リスクについて詳しくは, Deep Learning Bookの 8.1 How Learning Differs from Pure Optimization が参考になる)

このとき,識別と回帰の2つの学習問題では,それぞれ以下の (A)~(D) の損失関数がファーストチョイスとして用いられる.

2.2 (A) 2クラス識別:二値クロスエントロピー(BCE)

画像入力の2クラス分類問題の学習では,ターゲットはスカラーの確率値$ y (0 \leq y \leq 1) $となる.


2クラス識別CNNの損失関数には,交差エントロピーの2クラス版である二値交差エントロピー(BCE: Binary Cross Entropy)を用いる:

\[
\mathcal{L}_{\text{BCE}}(y,\hat{y})
= – \sum_{i =1}^N [y^{(i)} \log \hat{y}^{(i)} + (1-y^{(i)}) \log (1-\hat{y}^{(i)})]\tag{2.1}
\]

これは,1次元確率であるターゲット$y$に対して,出力の1次元確率$\hat{y}$を,交差エントロピーでフィットする損失関数である.

2.2 (B) 多クラス識別:交差エントロピー(CE)

$C(\geq 2)$クラスのラベルをアノテーションした画像群から,データセットから多クラス識別器CNNを学習する際,損失関数には,交差エントロピー誤差をサンプル数だけ和を取った交差エントロピー損失が標準的に用いられる:

\[
\mathcal{L}_{\text{CE}} = – \sum_{i=1}^{N} \sum_{k=1}^{K} y_k^{(i)} \log \hat{y}_k^{(i)} \tag{2.2}
\]

式(2.2)の交差エントロピー損失は,ターゲットのクラス確率ベクトル$\bm{y}$(識別CNNの場合softmax確率)に,出力のクラス確率ベクトル$\hat{\bm{y}}_k$をフィットさせる損失関数である.

交差エントロピー損失は,負の対数尤度(Negative Log Likelihood) 損失とも呼ばれる(各次元$k$における負の対数尤度の和が,交差エントロピーである)

ここで,式(2.2)の構成を理解しやすくするために,バッチ内にN個ある損失のうち,$i$番目のサンプルペアの交差エントロピー誤差だけを見てみたい:

\begin{equation}\mathcal{L}_{CE}^{(i)}(\hat{\bm{y}},\bm{y}) = – \sum_{k=1}^{K} y_k \log \hat{y}_k \tag{2.1.1}\end{equation}

ここでは,$i$の表記を省略して$y_k^{(i)} = y_k$, $\hat{y}_k^{(i)} = \hat{y}_k$と表記した.

正解ターゲットの$\bm{y}$はone-hotベクトルであるので,結局は各サンプル$i$においては,$\bm{y}$の$k$次元目の$y_k = 1$となる要素についてだけ,積$y_k \log \hat{y}_k$が加算されると,式(2.1.1)から整理できる(他は$y_k = 0$なので加算されない).

2.2 (C) 単回帰:平均二乗誤差の損失関数

今度は回帰である.まずはスカラー$\hat{y}$を推定する,単回帰CNNのロス関数から整理したい.

単回帰CNNを学習する損失関数には,以下の平均二乗誤差(Mean Squared Error, MSE)が用いられる:

\begin{equation}\mathcal{L}_{MSE}(\hat{y},y) = \frac{1}{K} \sum_{k=1}^{K}(y- \hat{y})^2 \tag{2.3} \end{equation}

2.2 (D) 重回帰:平均二乗誤差 (L2ノルムの二乗誤差)

単回帰(C)と同様に,重回帰向けの損失関数のファーストチョイスも,ベクトル間の平均二乗誤差である (= L2ノルム二乗の平均誤差):

\begin{eqnarray}\mathcal{L}_{MSE} (\hat{\bm{y}},\bm{y}) = \frac{1}{K} \sum_{k=1}^{K} \| \bm{y}_k -\hat{\bm{y}_k} \|^2_2 \tag{2.4}
\end{eqnarray}

例えば,画像中のキーポイント(例:人物姿勢推定における人物の関節座標)をベクトル回帰させたり,物体検出において,バウンディングボックス間の修正誤差量 $\bm{t} = (x,y,w,h)$ を回帰する時などに用いられる.これらの2例のように「対象物体領域内の画像特徴に対して,何らかの$K$次元ベクトル$\bm{y}$をCNNに重回帰させる」ことを考えた場合,式(2.4)を基本的な損失関数としてまず検討する.

3 (損失関数,出力層)の基本的な組み合わせ

CNNの学習時は,出力側から,損失のパラメータに関数偏微分を前の層に逆伝播させる.ここで,CNNの出力層の各ニューロン($K$個)と,その1つ手前の全結合層のニューロン($J$個)の間における各重みパラメータを$\bm{w}$とし,各ニューロン間の個別パラメータを$w_{jk}$とする.

このとき,誤差逆伝播のステージ$t$では「損失関数の重みパラメータ$w_{jk}$に対する偏微分値」に,係数$\alpha$をかけた分だけ重みを更新する:

\begin{equation} w_{jk}^{(t+1)} = w_{jk}^{(t)} – \alpha \frac{\partial \mathcal{L}}{\partial w_{jk}^{(t)} }\tag{3.1} \end{equation}

つまり,この逆伝播値の計算に必要な偏微分値を,スムーズに低コストで計算できるような (損失関数,出力層)の組み合わせがよく用いられるのが基本的な考え方となる.

一方で,あなたが計算資源が豊富な環境を持っているのであれば,計算効率化の視点は必要ないので,自由に組み合わせてしまって良いとも言える.

次の3.1節では,まず2.2節の各(A)~(D)の損失関数に対して組み合わせる,出力層のファーストチョイスを整理する.また,3.2節では管理人の個人的な視点として,回帰問題(C,D)について「正規化」の観点からも,組み合わせの良し悪しについて述べてみたい.

3.1 (損失関数,出力層) の組み合わせのファーストチョイス

3.1 (A) 2クラス分類:BCE損失 + シグモイド出力

2クラス分類CNNの場合,基本は「二値交差エントロピー(BCE)損失 + シグモイド関数」を組み合わせる.この組み合わせは,確率値(0~1)になる(ロジスティック)シグモイド関数を出力に用いるロジスティック回帰にも関連している.

出力層シグモイドの一つ前の層のニューロン値$J$個のうち$j$番目の値を$x_j$で表すとする.このとき,式(3.1)を用いて,$w_{j}$に逆伝播させる微分値は以下のようになる:

\begin{eqnarray} \frac{\partial \mathcal{L}}{\partial w_{j}} &=
\frac{y-\hat{y}}{
\textcolor{blue}{
\underbrace{\hat{y}(1- \hat{y})}_{BCEの微分}
}
}
\textcolor{green}{
\underbrace{
\hat{y}(1- \hat{y})
}_{シグモイドの微分}
}
\times x_j \\
&= (y-\hat{y}) x_j \tag{3.1}
\end{eqnarray}

つまり,この「BCE+シグモイド」だと,$\hat{y}(1- \hat{y})$を分子と分母で打ち消すことができ,学習の計算コストが低くて済む利点が出る.

3.1 (B) 多クラス分類:CE + softmax出力

多クラス分類CNNを学習する際は「交差エントロピー(CE)損失 + softmax出力」が基本的な組み合わせである.こちらも,ロジスティック回帰のマルチクラス識別版である「多項ロジスティック回帰」に由来している.

(A)と同様に,この組み合わせでも,計算コストが低い微分値を$w_{ij}$へ逆伝播できる:

\begin{eqnarray} \frac{\partial \mathcal{L}}{\partial w_{jk}} &=
\frac{y_k – \hat{y}_k}{
\textcolor{blue}{
\underbrace{\hat{y}_k (1- \hat{y}_k)}_{CEの微分}
}
}
\textcolor{green}{
\underbrace{
\hat{y}_k (1- \hat{y}_k)
}_{softmaxの微分}
} x_j \\
&= (y_k- \hat{y}_k)x_j \tag{3.2}
\end{eqnarray}

ちなみに主要なDeep Learningライブラリでも,計算が有利になるのでこの組み合わせが実用されている.例えばPyTorchでは,この組み合わせの際に,交差エントロピーとsofsoftmax 関数 [活性化関数]tmaxを個別にそれぞれ計算するのではなく,合成して計算コストを減らした式(3.2)で逆伝播できるtorch.nn.CrossEntropyLossクラスが実装されている.

また,実用上ではsoftmax後に,対数関数も追加して用いる「softmax -> log」構成の出力層にすることも多い.これは,対数関数のおかげで,特定のクラスの生起確率が小さくなりすぎてアンダーフローが生じてしまい,その計算誤差で,学習やテストが止まるのを防ぐためである [Kamath et al., 2019].PyTorchだと,logとsoftmaxの合成関数計算が効率化されているtorch.nn.functional.log_softmax を用いると「softmax -> log 」で計算できるが,前述のCrossEntropyLossクラスの中では,そのlog_softmaxが使用されている.

logを追加する代わりに,正解データで0となっているクラスを0.0001などの小さい値に設定する対処をしても良いが,PyTorchの場合logSoftmaxがCrossEntropyLossに組み込まれているので,そういった前処理を行なっての対応は基本的に不要である.

3.1 (C) 単回帰:平均二乗誤差損失 + 出力ニューロン (活性化関数なし)

単回帰CNNでは,通常「MSE + 出力ニューロン$y$」の組み合わせを取る.重み$w_j$へ逆伝播する微分値は,以下の計算式になる:

\begin{equation}\frac{\partial \mathcal{L}}{\partial w_{j}} = (y-\hat{y}) x_j \tag{3.3}\end{equation}

3.1 (D) 重回帰:平均二乗誤差 + 線形出力層 (活性化関数なし)

重回帰でも.(C)の単回帰と同様に,畳み込み層や線形層を出力層に配置した「MSE + 線形出力層」の組み合わせが標準的である.組み合わせを取る.重み$w_{ij}$へ逆伝播する微分値は,以下の計算式になる:

\begin{equation}\frac{\partial \mathcal{L}}{\partial w_{ij}} = (y_i-\hat{y_i}) x_j \tag{3.4}\end{equation}

顔のランドマーク座標や,物体領域のバウンディングボックスサイズなど,幾何的な出力を学習させることが多いことから,素直に線形層でそのまま出力させる.

ただし,単純なL2ノルムを用いるMSEではノイズや外れ値に弱いのもあり,実際の応用では,MSEの代わりに様々なロバストな損失関数が回帰する際に使用される(このシリーズ記事の「その(2)」以降で紹介していく).

ちなみに,もし重回帰CNNで「MSE + soft出力」と出力に活性化関数を採用した場合,出力層手前の偏微分値は以下のようになる:

\begin{equation}
\frac{\partial \mathcal{L}}{\partial w_{jk}} =
\underbrace{
(y_k-\hat{y}_k)}_{MSEの微分}
\textcolor{red}{
\underbrace{
\hat{y}_k (1- \hat{y}_k)
}_{softmaxの微分}
} x_j \tag{3.5}
\end{equation}

これだと式(3.2)の場合とは異なり,$\hat{y}_k (1- \hat{y}_k)$の項が消去できずに残ってしまう.よって逆伝播の計算コストが高くなり,あまりよろしくない (使っても良いが学習は少し遅くなる).

3.2 回帰問題における「正規化」の重要性

クラス識別での確率による出力と異なり,回帰問題では,ターゲットベクトル$\bm{y}$の各次元の値の範囲が幅広かったり,次元間で値の統計に偏りがあるほど,パラメータのベクトル空間も広く偏りが出て回帰CNNをうまく収束させづらくなる.

例えば,入力と出力の関係がしっかり対応づきやすいランドマーク推定問題などでは (例:顔のランドマーク推定や人物姿勢推定など),出力の多様性が大きいと,それだけ広いパラメータ空間で最適化するはめになることから,回帰CNNを学習する難易度は上がる.ましてや,入力画像の見えと相関性があまり無いようなスコア値を回帰させたいデータの場合は(例:スコア付けやランキング問題など),入出力間の相関は少ない場合が多くで,そうした場合は回帰CNNをうまく学習させるのはなおさら難しい(※ 当然,そういう場合でも,隠れ層を用いることでむりやりにでも入出力間の関係性を発見して予測モデルを学習するのが,ニューラルネットワークなわけではあるが).

この意味で,入力や中間層のベクトルにおいて,正規化(あるいは標準化)を行うことも,学習を安定・効率化させて,回帰CNNの精度を高める上で重要となる.回帰でも識別でも,CNNの場合は,入力の正規化 [Lecun et al., 1998] や,中間層のバッチ正規化 [Ioffe et al., 2015] により,この問題に対応することが標準的である.出力層での正規化はあまり効果が出ないことが多く,出力に対してはスケーリングを行うくらい.

4. まとめ

この記事では,基本的なCNN(識別と回帰)の損失関数について,以下の点を整理した:

  • 識別CNNには「交差エントロピー関数」を,回帰CNNには「平均二乗誤差」を,それぞれ基本的な損失関数として用いる(2節).
  • ロジスティック回帰と同様の,識別モデルでのsoftmax層の使用が,識別ニューラルネットでも用いられる.「交差エントロピー + softmax」の組み合わせになるが,微分値がシンプルで計算が有利なことも,この組み合わせが標準的に用いられる理由である(3節 (A),(B))
  • (識別CNNとは対照的であるが)回帰ニューラルネットだと,出力層には活性化関数は用いないで線形層をそのまま出力することが多い.入力と出力に相関が出づらい回帰CNNを学習しやすくするには,識別CNNにも増して入力層や中間層をしっかり正規化することが,安定した収束を達成するために重要となる.

関連書籍

References

参照外部リンク

関連記事

↓ ためになった方は,記事をSNSでシェアをしてくださると,管理人の記事執筆モチベーションが上がります