CNNの損失関数(Loss Function) その(1): 交差エントロピーとMSE

記事を共有する:

1. 概要

この記事では,CNN(畳み込みニューラルネットワーク)の教師あり学習時に用いられる損失関数(Loss function)のうち,交差エントロピー誤差 (Cross Entropy Error)識別問題向け)と 平均二乗誤差(MSE: Mean Squared Error)/ノルム損失回帰問題向け)についてまとめる(2節).また,各損失関数と組み合わせて用いる「 出力層」についても整理する(3節).

この記事は,「CNN向け損失関数」をまとめたシリーズ記事のうち,一番ベーシックな「識別問題「と「回帰問題」むけの損失関数について取り上げる「その(1) 基本編」である.

読者が,各問題向けのニューラルネット・CNN損失関数を整理して復習し,基礎的な考え方が定着しやすくなるように,横断的に各モデルや各タスクの損失関数をグループ化して比較・列挙する記事を執筆していく.

1.1 記事の構成

この記事の2節以降は,以下の構成である:

  • 2節:識別CNN・回帰CNN向けの基本的な損失関数
  • 3節:(損失関数,出力層)ペアの,基本的な組み合わせ
  • 4節:まとめ

教師あり学習の基本問題である識別CNN・回帰CNN向けの「損失関数 (2節)」と,それらにペアとして組み合わせる「出力層(3節)」の種類について整理する.

2. 識別CNN・回帰CNN向けの基本的な損失関数

2.1 準備

まず問題設定を先に定義しておくことで,2.2節以降の準備を行う.

以下の,$i$番目のサンプルとターゲット$(\bm{x}^{(i)},\bm{y}^{(i)})$のペア$N$個から構成される画像認識データセットを用いて,識別CNNもしくは回帰CNNを学習することを考える:

  • $\bm{x}^{(i)}$:CNNに入力される,3チャンネルの画像サンプル.
  • $\bm{y}^{(i)} = (y_1^{(i)},y_2^{(i)},\ldots,y_K^{(i)})$:CNNの出力として学習させたい正解,ターゲットのone-hotベクトル(識別) or ターゲットのベクトル(回帰).
  • $\hat{\bm{y}}^{(i)} = (\hat{y}_1^{(i)},\hat{y}_2^{(i)},\ldots,\hat{y}_K^{(i)})$:学習中・学習後のCNNの出力層が,順伝搬の結果出力する$K$次元ベクトルの推定値.識別CNNの場合はSoftmax層を出力層に用いて,(マルチヌーイ分布として)確率ベクトル化を行うことが多い.

ここで,上記の各変数は,条件(A)~(D)において以下の通りになる:

(1) 識別問題 (A), (B):

  • $K$:識別したいクラス数.
  • $\bm{y}^{(i)}$:$K$次元のone-hotベクトル.正解のクラス$k$の要素$y_k^{(i)}$だけは値が1であり,他の要素は全て値が0である.
  • $\hat{\bm{y}}$:各$k$次元目の要素$\hat{y}_k^{(i)}$は,$k$番目のクラスの出力確率 ($0 \geq \hat{y}_k^{(i)} \geq 1$)に相当し,出力層は$K$個のニューロンを持つ.

(2) 回帰問題 (C), (D):

  • $K$:回帰対象$\bm{y}^{(i)} $の,ベクトルの次元数.
  • $\bm{y}^{(i)} $,$\hat{\bm{y}}^{(i)} $:(1)のような確率的な出力ではない (ただし,学習時の計算の都合でsoftmaxで回帰するような問題設定もある)

2.1.1 交差エントロピーを復習

交差エントロピー(CE: Cross Entropy)は,元の情報学での定義では,(離散確率分布)$p$と$q$の距離測度(measure)として,以下のように定義される:

\begin{equation} H(p,q) = – \sum_{x \in \mathcal{X}} p(x) \log{q(x)}\end{equation}

交差エントロピー$H(p,q)$は,$p$と$q$の間で取ったエントロピー(平均情報量)である.そして,それは『$p$と$q$の2分布が離れている度合い』の尺度を表す.$p$と$q$が同じ確率分布(の形状)であるほど交差エントロピー値はゼロに近づくので,予測値が確率値であるときの,機械学習モデルの誤差関数として用いることが多い(例:ロジスティック回帰や,識別ニューラルネットのパーセプトロンなど).

2.2 各問題で用いるCNNの損失関数

以降では,各損失関数$\mathcal{L}$は,SGD時のミニバッチ内における$N$個の「(ターゲット,サンプル)ペア間の誤差」の期待値を表現しているとする.実際は,CNNの学習では経験的リスク(empirical risk)を最小化するので,損失関数は$\mathcal{J}(\theta) $と表記してもよいが,ここでは便宜がよいのでひとまず損失関数 $\mathcal{L}(\cdot)$と表記したい.(経験的リスクについて詳しくは, Deep Learning Bookの 8.1 How Learning Differs from Pure Optimization が参考になる)

このとき,識別と回帰の2つの学習問題では,それぞれ以下の (A)~(D) の損失関数がファーストチョイスとして用いられる.

2.2 (A) 2クラス識別:二値クロスエントロピー(BCE)

画像入力の2クラス分類問題の学習では,ターゲットはスカラーの確率値$ y (0 \leq y \leq 1) $となる.


2クラス識別CNNの損失関数には,交差エントロピーの2クラス版である二値交差エントロピー(BCE: Binary Cross Entropy)を用いる:

\[
\mathcal{L}_{\text{BCE}}(y,\hat{y})
= – \sum_{i =1}^N [y^{(i)} \log \hat{y}^{(i)} + (1-y^{(i)}) \log (1-\hat{y}^{(i)})]\tag{2.1}
\]

これは,1次元確率であるターゲット$y$に対して,出力の1次元確率$\hat{y}$を,交差エントロピーでフィットする損失関数である.

2.2 (B) 多クラス識別:交差エントロピー(CE)

$C(\geq 2)$クラスのラベルをアノテーションした画像群から,データセットから多クラス識別器CNNを学習する際,損失関数には,交差エントロピー誤差をサンプル数だけ和を取った交差エントロピー損失が標準的に用いられる:

\[
\mathcal{L}_{\text{CE}} = – \sum_{i=1}^{N} \sum_{k=1}^{K} y_k^{(i)} \log \hat{y}_k^{(i)} \tag{2.2}
\]

式(2.2)の交差エントロピー損失は,ターゲットのクラス確率ベクトル$\bm{y}$(識別CNNの場合softmax確率)に,出力のクラス確率ベクトル$\hat{\bm{y}}_k$をフィットさせる損失関数である.

交差エントロピー損失は,負の対数尤度(Negative Log Likelihood) 損失とも呼ばれる(各次元$k$における負の対数尤度の和が,交差エントロピーである)

ここで,式(2.2)の構成を理解しやすくするために,バッチ内にN個ある損失のうち,$i$番目のサンプルペアの交差エントロピー誤差だけを見てみたい:

\begin{equation}\mathcal{L}_{CE}^{(i)}(\hat{\bm{y}},\bm{y}) = – \sum_{k=1}^{K} y_k \log \hat{y}_k \tag{2.1.1}\end{equation}

ここでは,$i$の表記を省略して$y_k^{(i)} = y_k$, $\hat{y}_k^{(i)} = \hat{y}_k$と表記した.

正解ターゲットの$\bm{y}$はone-hotベクトルであるので,結局は各サンプル$i$においては,$\bm{y}$の$k$次元目の$y_k = 1$となる要素についてだけ,積$y_k \log \hat{y}_k$が加算されると,式(2.1.1)から整理できる(他は$y_k = 0$なので加算されない).

2.2 (C) 単回帰:平均二乗誤差の損失関数

今度は回帰である.まずはスカラー$\hat{y}$を推定する,単回帰CNNのロス関数から整理したい.

単回帰CNNを学習する損失関数には,以下の平均二乗誤差(Mean Squared Error, MSE)が用いられる:

\begin{equation}\mathcal{L}_{MSE}(\hat{y},y) = \frac{1}{K} \sum_{k=1}^{K}(y- \hat{y})^2 \tag{2.3} \end{equation}

2.2 (D) 重回帰:平均二乗誤差 (L2ノルムの二乗誤差)

単回帰(C)と同様に,重回帰向けの損失関数のファーストチョイスも,ベクトル間の平均二乗誤差である (= L2ノルム二乗の平均誤差):

\begin{eqnarray}\mathcal{L}_{MSE} (\hat{\bm{y}},\bm{y}) = \frac{1}{K} \sum_{k=1}^{K} \| \bm{y}_k -\hat{\bm{y}_k} \|^2_2 \tag{2.4}
\end{eqnarray}

例えば,画像中のキーポイント(例:人物姿勢推定における人物の関節座標)をベクトル回帰させたり,物体検出において,バウンディングボックス間の修正誤差量 $\bm{t} = (x,y,w,h)$ を回帰する時などに用いられる.これらの2例のように「対象物体領域内の画像特徴に対して,何らかの$K$次元ベクトル$\bm{y}$をCNNに重回帰させる」ことを考えた場合,式(2.4)を基本的な損失関数としてまず検討する.

3. [損失関数,出力層] の基本的なペア

CNNの学習時は,計算グラフにそって,バッチ単位誤差の偏微分値を,出力側から手前の各層へと逆伝播させる.

ここで,CNNの出力層の各ニューロン($K$個)と,その1つ手前の全結合層のニューロン($J$個)の間における,重みパラメータ全体を$\bm{w}$ とし,そのうち各ニューロン間の個別重みを$w_{jk}$ とする.このときSGDにおける誤差逆伝播の,各ステージ $t$ においては,「(バッチ単位の)損失関数の重みパラメータ$w_{jk}$に対する偏微分値」に,係数$\alpha$ をかけた分だけ,重みを毎回更新するのであった:

\[
w_{jk}^{(t+1)} = w_{jk}^{(t)} – \alpha \frac{\partial \mathcal{L}}{\partial w_{jk}^{(t)} }\tag{3.1}
\]

つまり,『この逆伝播値の計算に必要な偏微分値を,スムーズに低コストで計算できるような [損失関数,出力層]のペアほど,よくCNNで用いられる』というのが基本的な考え方となる.

一方で,あなたが計算資源が豊富な環境を持っているほど,計算効率化の視点は必要なくなるので,計算効率など度外視で自由に組み合わせてしまって良いとも言える.

次の3.1節では,まず2.2節の各(A)~(D)の損失関数に対してペアにする,出力層のファーストチョイスを整理する.また,3.2節では管理人の個人的な視点として,回帰問題(C,D)について「正規化」の観点からも,組み合わせの良し悪しについて述べてみたい.

3.1 [損失関数,出力層] ペアのファーストチョイス

3.1 (A) 2クラス分類:BCE損失 + シグモイド出力

2クラス分類CNNの場合,基本は「二値交差エントロピー(BCE)損失 + シグモイド関数」を組み合わせる.この組み合わせは,確率値(0~1)になる(ロジスティック)シグモイド関数を出力に用いるロジスティック回帰にも関連している.

出力層のシグモイド関の前にある層のニューロン値$J$個のうち,$j$番目の値を$x_j$で表すとする.このとき,式(3.1)を用いて,$w_{j}$に逆伝播させる微分値は以下のようになる:

\begin{eqnarray} \frac{\partial \mathcal{L}}{\partial w_{j}} &=
\frac{y-\hat{y}}{
\textcolor{blue}{
\underbrace{\hat{y}(1- \hat{y})}_{BCEの微分}
}
}
\textcolor{green}{
\underbrace{
\hat{y}(1- \hat{y})
}_{シグモイドの微分}
}
\times x_j \\
&= (y-\hat{y}) x_j \tag{3.1}
\end{eqnarray}

つまり,この「BCE+シグモイド」だと,$\hat{y}(1- \hat{y})$を分子と分母で打ち消すことができ,学習の計算コストが低くて済む利点が出る.

3.1 (B) 多クラス分類:CE + softmax出力

多クラス分類CNNを学習する際は「交差エントロピー(CE)損失 + softmax出力」が基本的な組み合わせである.こちらも,ロジスティック回帰のマルチクラス識別版である「多項ロジスティック回帰」に由来している. 例えば,VGGNetの論文「3.1 Training」 にも,「AlexNetで行わた方法と同じく,多項ロジスティック回帰の目的関数をもとに,SGDで最適化する」と書かれている.

(A)と同様に,この組み合わせでも,計算コストが低い微分値を$w_{ij}$へ逆伝播できる:

\begin{eqnarray} \frac{\partial \mathcal{L}}{\partial w_{jk}} &=
\frac{y_k – \hat{y}_k}{
\textcolor{blue}{
\underbrace{\hat{y}_k (1- \hat{y}_k)}_{CEの微分}
}
}
\textcolor{green}{
\underbrace{
\hat{y}_k (1- \hat{y}_k)
}_{softmaxの微分}
} x_j \\
&= (y_k- \hat{y}_k)x_j \tag{3.2}
\end{eqnarray}

ちなみに主要なDeep Learningライブラリでも,計算が有利になるのでこの組み合わせが実用されている.例えばPyTorchでは,このペアの際に,交差エントロピーとsoftmax 関数を,個別にそれぞれ計算するのではなく,合成して計算コストを減らした式(3.2)で逆伝播できるtorch.nn.CrossEntropyLossクラスが実装されている.

また,実用上では,softmax後に対数関数も追加した「softmax -> log」構成の出力層にすることも多い.その理由であるが,softmaxの対数関数のおかげで,特定のクラスの生起確率が小さくなりすぎてアンダーフローが生じてしまい,その計算誤差によって,学習やテストが止まってしまうのを防ぐためである [Kamath et al., 2019].PyTorchだと,logとsoftmaxの合成関数計算が効率化されているtorch.nn.functional.log_softmax を用いると「softmax -> log 」構成で計算できるが,前述のCrossEntropyLossクラスの中では,そのlog_softmaxが使用されている.

logを追加する代わりに,正解データで0となっているクラスを0.0001などの小さい値に設定する対処をしても良いが,PyTorchの場合だと,logSoftmaxが既にCrossEntropyLossに組み込まれているので,そういった前処理を挿入しての対処は,基本的に不要である.

3.1 (C) 単回帰:平均二乗誤差損失 + 出力ニューロン (活性化関数なし)

単回帰CNNでは,通常「MSE + 出力ニューロン$y$」の組み合わせを取る.重み$w_j$へ逆伝播する微分値は,以下の計算式になる:

\begin{equation}\frac{\partial \mathcal{L}}{\partial w_{j}} = (y-\hat{y}) x_j \tag{3.3}\end{equation}

3.1 (D) 重回帰:平均二乗誤差 + 線形出力層 (活性化関数なし)

重回帰CNNでも,(C)の単回帰と同様に,畳み込み層や線形層を出力層に配置した「MSE + 線形出力層」の組み合わせが標準的である.組み合わせを取る.重み$w_{ij}$へ逆伝播する微分値は,以下の計算式になる:

\begin{equation}\frac{\partial \mathcal{L}}{\partial w_{ij}} = (y_i-\hat{y_i}) x_j \tag{3.4}\end{equation}

顔のランドマーク座標や,物体領域のバウンディングボックスサイズなど,幾何的な出力をCNNに学習させることも多いことから,素直に「線形層」を最後の層にして,そのまま出力させる.

ただし,単純なL2ノルムを用いるMSEでは,ノイズや外れ値に弱い.よって,物体検出など実際の応用では MSEより「ロバストな回帰向けの損失関数」が提案され,使用されていく (このシリーズ記事の「その(2)」以降で紹介予定).

ちなみに,もし重回帰CNNで「MSE + softmax出力」という組み合わせで,出力に活性化関数を採用した場合は,出力層手前の偏微分値が以下のようになる:

\begin{equation}
\frac{\partial \mathcal{L}}{\partial w_{jk}} =
\underbrace{
(y_k-\hat{y}_k)}_{MSEの微分}
\textcolor{red}{
\underbrace{
\hat{y}_k (1- \hat{y}_k)
}_{softmaxの微分}
} x_j \tag{3.5}
\end{equation}

これだと式(3.2)の場合とは異なり,$\hat{y}_k (1- \hat{y}_k)$の項が消去できずに残ってしまう.よって逆伝播の計算コストが高くなり,あまりよろしくない (使っても良いが学習は少し遅くなる).

3.2 回帰問題における「正規化」の重要性

クラス識別での確率による出力と異なり,回帰問題では,ターゲットベクトル$\bm{y}$の各次元の値の範囲が幅広かったり,次元間で値の統計に偏りがあるほど,パラメータのベクトル空間も広く偏りが出て回帰CNNをうまく収束させづらくなる.

例えば,入力と出力の関係がしっかり対応づきやすいランドマーク推定(キーポイント推定)問題などでは (例:顔のランドマーク推定や,人物姿勢推定での関節キーポイント推定など),出力の多様性が大きいと,それだけ広いパラメータ空間で最適化するはめになり,回帰CNNを学習する難易度は上がる.ましてや,入力画像の見えと相関性があまり無いような出力スコア値を回帰させたいデータの場合は(例:スコア付けやランキング問題など),入出力間の相関は少ない場合も多く,そうした場合は回帰CNNをうまく学習させるのはなおさら難しい(※ 当然,そういう場合でも,隠れ層を用いることでむりやりにでも入出力間の関係性を発見して予測モデルを学習するのが,ニューラルネットワークなわけではあるが).

この意味で,入力や各中間層の活性化後ベクトルにおいて,正規化(あるいは標準化)を行うことも,SGD学習を安定・効率化させて,回帰CNNの精度を高める上で重要となる.回帰でも識別でも,CNNの場合は,入力の正規化 [Lecun et al., 1998] や,中間層のバッチ正規化 [Ioffe et al., 2015] により,この問題に対応することが標準的である.出力層での正規化は,あまり効果が出ないことが多い(出力に対しては,スケーリングを行う程度).

4. まとめ

この記事では,基本的なCNN(識別と回帰)の損失関数について,以下の点を整理した:

  • 識別CNNには「交差エントロピー関数」を,回帰CNNには「平均二乗誤差」を,それぞれ基本的な損失関数として用いる(2節).
  • ロジスティック回帰と同様の,識別モデルでのsoftmax層の使用が,識別ニューラルネットでも用いられる.「交差エントロピー + softmax」ペアになるが,微分値がシンプルで計算が有利なことも,この組み合わせが標準的に用いられる理由である(3節 (A),(B))
  • 【著者の個人的経験・意見】(識別CNNとは対照的であるが)回帰ニューラルネットだと,出力層には活性化関数は用いないで,線形層の出力を,そのまま出力することも多い.入力と出力の間にはっきりとした相関が存在しづらい,多くの画像入力回帰問題において,回帰CNNを学習しやすくするには,識別CNNにも増して入力層や中間層をしっかり正規化してデータ分布を整えていくことが,安定した収束解を達成するためには重要である.

関連書籍

References

参照外部リンク