再帰型ニューラルネットワーク(RNN) [LSTM, GRU, 双方向RNN]

1. 概要

再帰型ニューラルネットワーク(Recurrent Neural Network, RNN)と呼ばれる,可変長の系列形式データをモデル化するために用いられるニューラルネットワーク構造について,この記事ではまとめる.RNNの基本形について2節で解説したのち,3節でRNNの発展モデルとしてLSTM, GRUs, 双方向RNNについてとりあげる.

RNNは「時系列順に各フレームの特徴ベクトル同士が相互依存関係を持ちながら変化する」系列データ(の毎フレームの高次元特徴ベクトル)を対象として,系列の次フレームを予測するモデルを学習する仕組みである.RNNに入力する系列データの例としては,自然言語処理でのテキストデータや,環境音・音声などのオーディオデータや,あるいは動画処理でのアクション・イベントなどが上げられる. RNNはそのような系列全体の時間的変化を,時刻$t$の潜在変数ベクトル$\bm{h}_t$の変化として学習できる.

また,RNNは再帰(Recurrent)と名づけられている通り,フレーム間の隠れ状態$\bm{h}_t$と$\bm{h}_{t+1}$の変化が1つの関数(ニューラルネット)により表現されている再起更新型のモデルである.前ステップの潜在状態ベクトルと最新フレームの特徴ベクトルをもとに,次フレームにおける新たな潜在状態およびそれに対応した観測を予測することができるので,長めの系列の変化を学習するモデルでありながら,再帰関数1つで簡潔なモデル構造で済ませることができる.

1.1 従来の系列モデルとの違い

RNN登場以前は,自然言語処理や音声認識などの「系列ラベリング問題」や,動画行動認識などの「系列識別モデル」には,HMM(Hidden Markov Model:隠れマルコフモデル)CRF(Conditional Random Field:条件付き確率場)などの系列モデルがよく用いられていた.

HMMではフレーム間の潜在状態遷移にマルコフ連鎖を仮定しており,前フレームの潜在状態クラス$\bm{h}_{t-1}$の$K$クラスに対して,現フレームの状態が$K$クラスのどのクラスに遷移するかの「フレーム間における状態クラス遷移確率」がモデリングされていた.それに対し,RNNでは回帰関数を用いて,ステップ$t$までに入力されたシーケンス全体の表現を,潜在状態ベクトル$\bf{h}_{t}$として埋め込み(embedding)を学習できる.

従来の系列予測モデルと同じRNNの特徴・利点として,可変長シーケンスを学習できる点がある.RNNでは,回帰関数を用いて潜在状態のフレーム間変化を学習しているので,系列の長短が異なるサンプルでも全て同一の隠れ状態ベクトル空間$\bf{h}_{t}$へ射影(埋め込み)できる.よって,長短の系列長が違うサンプル同士を$\bf{h}_{t}$上で比較することも可能であり,学習データの系列長は統一されていなくても良い.

1.2 記事の構成

2節以降では,以下の順でRNNの基本モデルとその応用モデルを紹介する:

  • 2 RNNの基本モデル
  • 3 RNNの発展モデル
    • LSTM (Long Short-Term Memory)
    • GRUs (Gated Recurrent Units)
    • 双方向RNN
  • 4 まとめ

2 RNNの基本モデル

2節では,まずRNNの構成要素の再帰・繰り返し(Recurrence)記憶(Memory)を用いた,RNNの基本的な更新式について述べる(2.1節).その後,BPTTを用いたRNNの学習方法について簡単にだけ述べる(2.2節)

2.1 フレーム単位での繰り返し更新

RNNの入力である(時)系列データ(音声,テキストなど)を,$\bm{X} = (\bm{x}_1, \bm{x}_2, \ldots, \bm{x}_T)$と定義する.$T$は入力系列の長さであり,$\bm{x}_t $はフレーム$t$における観測データの特徴ベクトルである.

RNNは系列長$L$回分,同じ操作を行う回帰構造を採用している.図1は,RNNの状態遷移を示したダイアグラムである:

RNN
図1:RNNのダイアグラム.左側は繰り返し方式でRNNの再帰構造を表現したもの,等号より右側は左側の表現を展開して,1ステップずつ個別に処理を表記したもの.青い丸は各ステップ$t$における隠れ状態ベクトル$h_t$を示す.

RNNでは,各ステップ$t$の観測に対し,それまでの全フレームの変遷の結果を表現する内部記憶に相当する潜在状態(hidden state)ベクトル$\bm{h}_t \in \mathbb{R}^d$を更新していくことで系列の変化を,再帰関数を用いてモデル化する.

各ステップ$t$では(以前の全フレームの全記憶を表す)潜在状態$\bm{h}_t$と,ステップ$t$の新規観測$\bm{x}_t$の2つを入力として用いて,再帰関数により新たな出力$\bm{o}_{t}$を予測する.つまり系列データにおいて,各フレームの予測$\bm{o}_{t}$が正しく予測されるようにRNNを学習する (学習手続きについては次の2.2節).図1左側の繰り返し表現を,図1右側のように展開(unfold)することで,RNNを「同一の再帰関数$T$層の繰り返しから構成される再帰的なニューラルネット」として捉えることができる.

RNNの各ステップ$t$では,線形重み層との活性化関数$f$(シグモイド,$\tanh$など)を用いて,以下の式のように出力$\bm{o}_t $と潜在状態$\bm{h}_t$を順伝搬により更新する:

$$\bm{h}_t = f( \bm{U} \bm{x}_t + \bm{W} \bm{h}_{t-1}) $$

また,各フレームにおける出力の予測は,以下のように隠れ状態を重みパラメータ$\bm{V}$で線形変換して行う:

$$\bm{o}_t= \bm{V}\bm{h}_t $$

言語モデルのように,各フレームの出力をクラス確率ベクトル$\bm{y}_t$にしたい場合は,更にsoftmax関数を用いて$\bm{o}_t$をクラス確率へ変換する:

$$ \bm{y}_t = softmax(\bm{o}_t)$$

以上がRNNの基本的な処理である.

NLPでの例

ここで,具体的な各変数の役割をイメージしやすくなるように,自然言語処理(NLP)でのRNNの文章予測のにおける3つの変数例を以下に提示してみる:

  • $\bm{x}_t$:入力の特徴表現ベクトル.ニューラル言語モデルとしてRNNを用いる場合は,元の単語を低次元ベクトルに埋め込んだword2vecGloVeなどの,単語の埋め込みベクトル(分散表現)を用いることが多い.
  • $\bm{h}_t$:潜在変数ベクトル.系列の最初から$t-1$フレームまでの,潜在状態の変化全てを蓄積した記憶に相当.
  • $\bm{y}_t$ :出力のクラス確率.自然言語処理系のRNNでは,この出力として,次のフレーム$t$の単語クラス確率を予測する(語彙サイズ$K$が).

RNN基本モデルの長所と課題

RNNは同じ順伝播を毎ステップで繰り返すが,この構造の利点として「パラメータ共有によるモデルの効率化」が挙げられる.RNNは「各フレーム$t$の更新では,毎回同一のパラメータで再帰するのみ」という省パラメータ設計であるので,学習するデータの系列長が長くとも小規模なモデルで済ますことができる(もし各ステップ$t$で,別々の再帰関数を用意すると,モデルが膨大になってしまい学習できなくなる).

一方で,このRNN基本モデルはシンプル過ぎるので,表現力が限られてしまっており,長期間の系列のモデリングや,より複雑で散発的だったり変動的だったりする時間依存を学習するのには向かない.そこで,3節に述べる発展モデル (LSTMやGRUs)や,アテンション機構やメモリー機構などの,「RNNを拡張する各種仕組み」を用いることで,弱点の解決を図ることになる.(逆に言うと,ナイーブなRNN基本モデルは,短くて分岐のないシーケンスにしか適用しづらいモデルである)

2.2 RNNの学習 (BPTT:時間方向逆伝播)

RNNでは、図1-右図の展開図を元に,ステップの系列逆方向に沿って3つのパラメータ行列$\bf{U}, \bf{W}, \bf{V}$に対して誤差逆伝播を行う.

ロス関数としては,まず毎フレーム$t$での誤差として,そのフレームでの予測$\hat{\bm{o}_t}$とGround Truth$\bm{o}_t$の間の,クロスエントロピー$E_t$を計算する.その$E_t$を全フレーム分だけ足した各サンプルの合計ロスを,学習データ$N$サンプルあるうち平均した$L$を,RNNのロス関数としても用いる:

$$ L(\hat{\bm{y}_t}, \bm{y}_t)= – \bm{y}_t \log \hat{\bm{y}}_t $$

$$ L(\hat{\bm{y}}, \bm{y}) = – \frac{1}{N} \sum_t^{T} L(\hat{\bm{y}_t}, \bm{y}_t) $$

RNNの学習では,上記のロス関数を用いて,系列の最後の$t=T$から$t=1$の逆方向に,ステップ$t$ごとに誤差$L(\hat{\bm{y}_t}, \bm{y}_t)$の逆伝播を行う.このRNNの学習方法を,時間方向逆伝播(Back Propagation Through Time) と呼ぶ.

RNN を学習する際の課題:うまく長期記憶が行えない.

RNNは長い系列データが対象になるほど「勾配消失・勾配爆発が起きやすく,系列データの前半部分をうまく学習できない」という課題がある.

そこでRNNのネットワーク構造を改良することで,次の3節で述べる「長期記憶も保持可能な工夫(ゲート機構)」が追加されたLSTMや,「LTSMの簡易版のGRUs」が提案されるなどによって,より長い系列データの変化を学習する際にはLSTMやGRUsが使われるようになった.また,系列データの持つ意味を順方向だけでなく逆方向からも辿って表現したいので「双方向RNN」が登場した.次の3節ではこれらの発展モデルについて簡単にまとめる.

3 RNN の発展モデル

RNNの基本的なモデルでは学習しづらい「少し長めの系列データ」を学習するためのゲート機構付きRNN(Gated RNN)であるLSTM(3.1節)とその簡略化改善版のGRU(3.2節)について紹介する.また順-逆両方の系列モデリングを同時に行うRNNの双方向RNN(3.3節)も紹介する.

ちなみにVision and Language周辺において,更に長期の系列変化をモデル化する目的で「2階層LTSM」を用いることも多いが,RNNの階層化については本記事では触れない. 

3.1 LSTM (Long Short-Term Memory)

LSTM (Long-Short Term Memory) [Hochreiter and Schmidhuber]は,ベクトル合成用の重みを入力の値に則して制御するゲート(gate)と呼ばれる変数間調整機構をRNNに追加することで,潜在変数(記憶)の長期伝達も可能にしたRNNの改良型である.

RNN基本形では,長い系列長のデータセットにより長期記憶を行おうとしても,勾配消失問題(vanishing gradient problem)のせいでうまくいかず,短い系列長のデータの変化しかうまく学習できなかった.

それをLSTMでは,RNN基本形モデルの持つ短期的な記憶 $\bm{h}_t$ に加えて,記憶セル(memory cell) (もしくは「セル」)と呼ばれる「短期情報を長く保持する記憶」の $\bm{c}_t$ も追加したモデルである.これによりLSTMは「長短の記憶間の依存関係の変化(調整に用いる係数)」を3つゲート機構を利用し,学習できるようになっている.

詳細な処理手順を述べる前に,まず以下の図2にLSTMのネットーワーク構造(各ステップ$t$におけるブロック構造)を示す:

LSTM
図2:LSTMの,ステップ$t$におけるネットワーク構造.緑のブロックは層による操作を表し,青ブロックは要素ごとの演算を表す.「concat」は「ベクトル同士の結合」を示す.

LSTMのセルブロック内では,3種類のゲート(忘却ゲート,入力ゲート,出力ゲート)が,図2のような入出力関係のもとで配置され,セルの情報の「忘れる度合い/ 無視する度合い/保持する度合い」を各ゲートが調整することにより,セルの情報を保護することができ,情報の長期伝達を可能にしている.

各ステップ$t$では,短期記憶$\bm{h}_{t-1}$と観測ベクトル$\bm{x}_t$を結合したベクトル$[\bm{x}_{t},\bm{h}_{t-1}]$を入力に用いて,3つのゲートが調整係数であるゲート値の予測を行い,そのゲート係数を用いてセルのgating(忘却,入力,出力)を行う.

忘却ゲート

忘却ゲート$\bm{f}_t$では,セル状態$\bm{c}_{t-1}$から,どのくらいの割合だけ情報を廃棄するかを決める(0~1の値):

\begin{equation} \bm{f}_t = \sigma (\bm{W}_{f}[\bm{x}_{t},\bm{h}_{t-1}]+ \bm{b}_{f})\end{equation}

入力ゲート

入力ゲート$\bm{i}_t$ は,どの新情報をセル(長期記憶)に記憶するかを決定する.そして入力ゲート値の分だけ,セルに貯める新情報として$\tilde{\bm{c}}_t$を作成する ($\tanh$により$\tilde{\bm{c}}_t$は-1から1のあいだの値をとる):

\begin{equation}\bm{i}_t = \sigma ( \bm{W}_{i}[\bm{x}_{t},\bm{h}_{t-1}]+ \bm{b}_{i} ) \end{equation}

\begin{equation} \tilde{\bm{c}}_t = \tanh(\bm{W}_{h}[\bm{x}_{t},\bm{h}_{t-1}])\end{equation}

長期記憶(セル)の更新


忘却ゲート値と入力ゲート値を用いて,セルの長期記憶$\bm{c}_t$を「忘却量 (長期記憶)+入力量(短期記憶の予測値)」の重み付け和によって更新する:


$$ \bm{c}_t = \bm{f}_t \circ \bm{c}_{t-1} + \bm{i}_t \circ \tilde{\bm{c}}_t$$
($\circ$はアダマール積で,要素ごとの積)

出力ゲート

出力ゲートでは,セルから新情報のうち何を出力すべきかの重み係数を決定する.出力ゲートの値を用いて,次フレームの$\bm{h}_t$の予測を行う:

\begin{equation} \bm{o}_t = \sigma ( \bm{W}_{o}[\bm{x}_{t},\bm{h}_{t-1}]+\bm{b}_{o})\end{equation}

\begin{equation} \bm{h}_t = \bm{o}_t \circ \tanh(\bm{c}_t)\end{equation}

以上がLSTMの順伝搬におけるステップ$t$での挙動である.ゲート機構の導入で長期記憶を更新する際に,どういうタイミングにどのくらいのバランスで短期記憶と長期記憶がインタクションするかが学習される.これにより,短期記憶しか保持できないRNNと比べて,LSTMでは,やや長い系列の変化も「2つの潜在変数(長期記憶と短期記憶)+3つのゲート関数」を用いて学習させることができる.

[2020/7/12 追記] LTSMのブロック図については,2015年に公開されたcolah’s blogの記事「Understanding LSTMs」による美しい図による詳細な説明が有名であり,その後はこの記事の図示の方式に習ったRNN/LSTMの図解がとても多い.ただ,この当時はこのような詳細なLSTMの説明の需要が高かったものの,今となってはLSTMは当時ほどハードルが高い技術ではなく,標準知識や標準APIとなっており,この記事のレベルほどに中身をつぶさに知る必要はなくなっている.なにより,現在はLSTMはあまり使われず,Transformerで系列データを表現する時代に移ってきている.

3.2 Gated Recurrent Units (GRU)

Gated Recurrent Units (GRU) は, LSTMを少し簡易したモデルである.LSTMのゲート機構はパラメータも多く,学習に時間がかかる原因ともなる.そこで,LSTMの各ゲートを結合して2つのゲート構成に単純化し,セルも用いずに$\bm{h}_t$だけで似たようなモデルを達成できるようにして,計算効率性を向上させたものがGRUである.各ステップ$t$でのモジュールをGated Recurrent Unitと呼び,それを$t$ステップ分つなげたものなので,GRUsもしくはGRU network とも呼ぶ.

動画認識や言語処理など,中~高次元の特徴ベクトルを使う場合にはGRUが効率が良い.しかし,1次元の生波形をそのまま特徴化しないで解析する場合など,入力特徴ベクトルが低~中次元ベクトルであり,なおかつ系列長もあまり長くない対象であればLSTMで十分である.またGRUは,LSTMほどの長期記憶性能はないので,GRUで性能が出ない場合にLSTMのほうが予測性能がよくなる場合もあるゆえ,2者の使い分け意識は大事である(※ 2021年7月追記.ただし最近だとTransformersで解ける問題では,そもそもLSTM/GRUの出番が無いことも多い).

以下の図3に,GRUのステップ$t$におけるネットワーク構造を示す.

GRU
図3 GRUのフレーム$t$におけるブロック構造

GRUでは,以下のようにリセット(reset) ゲート更新(update)ゲートの2つのゲートを用いて,(LTSMより簡素化した形で)毎ステップ$t$における潜在状態$\bm{h}_t$の回帰を行う:

\begin{align}\bm{z}_t &= \sigma (\bm{W}_{z } [\bm{x}_{t}, \bm{h}_{t-1}]) \\
\bm{r}_t &= \sigma (\bm{W}_{r}[\bm{x}_{t}, \bm{h}_{t-1}]) \\
\tilde{\bm{h}}_t &= \tanh (\bm{W}_{h}[\bm{x}_{t}, \bm{h}_{t-1}] \circ \bm{r}_t) \\
\bm{h}_t &= (1 -\bm{z}_t) \circ \tilde{\bm{h}_t} + \bm{z}_t \circ \bm{h}_{t-1} \end{align}

最後の行がGRUのメインの仕組みに相当しており,ここでは予測された次ステップの潜在ベクトルの候補状態$\tilde{\bm{h}}$と前ステップの潜在ベクトル$\bm{h}_{t-1}$を,更新ゲート値$\bm{z}_t$の比率$(1-\bm{z}_t):(\bm{z}_t)$で合成することで,次ステップの$\bm{h}_t $を予測する.すなわち,異なる時間スケール間での依存関係を,適応的に取得できるようになっている.

また,リセットゲートはLSTMの忘却ゲートと同じく「前フレームの潜在状態をどれだけ忘れるか(メモリ上から除去するか)」の役割を担当している.

3.3 双方向 RNN

双方向RNN (Bidirectional RNN) は,系列の最初のステップ$$から繰り返して順方向に予想することに加えて,系列の最後のステップ$t=T$からの逆方向の予測も行う,RNNを双方向形に拡張したモデルである.

以下の図4に双方向RNNのダイアグラムを示す.

双方向RNN
図4 双方向RNNのネットワーク構造.

双方向RNNも,毎フレーム$t$において観測$\bf{x}_t$を入力して潜在変数をもとに出力$\bf{o}_t$を予測するところは通常の(単方向)RNNと同じである.一方で,単方向RNNと異なる点は,潜在変数が順方向$\vec{\bf{h}}_t$と逆方向の$\cev{\bf{h}}_t$の2種類が使用される点である.これにより,双方向のコンテキストを集約した出力の予測が可能となる.もう少し正確に述べると,順方向側はステップ1からステップ$t-1$までのコンテキストを知っており,逆方向側はステップ2からステップ$t+1$までのコンテキストを双方向とも学習できて予測できるようになる.

4. まとめ

系列モデル化のためのニューラルネットワークである,再帰ニューラルネットワーク(RNN)の基本的な原理を紹介した.また,RNN基本モデルの「短い系列しかうまく学習できない」弱点を補うための派生系であるLSTMと,そのLSTMの簡素版であるGRUsについて述べた.また,RNNに逆方向の隠れ状態も追加した双方向RNNについても触れた.

References

  • [Hochreiter and Schmidhuber] S. Hochreiter and J. Schmidhuber. “Long short-term memory.” Neural computation 9.8 (1997): 1735-1780.

外部参照リンク

関連記事