マルチヘッドアテンション(Multi-head Attention) [Transformerの部品]

1. マルチヘッドアテンション(Multi-head Attention) とは [概要]

マルチヘッドアテンション(Multi-head Attention) とは Transformerで提案された,複数アテンションを並列に実行するブロック部品である [Vaswani et al., 2017] .

マルチヘッドアテンションは,以下の3処理によって,系列中のトークンT個の表現を一気に変換するモジュールである:

  1. 同一の入力トークン表現[Sennrich et al., 2016]のベクトル群(512次元)をまず次元削減した,トークン表現ベクトル群を入力として用意.
  2. 次元削減されたトークン表現ベクトル入力として,異なる h(=8)個のスケール化ドット積アテンションを分散実行し,8個のアテンション結果(コンテキストベクトル)を取得
  3. 結果の8個のベクトルを再度1つに結合し,元のトークン表現ベクトル次元数(512)まで次元復元して最終出力する.

Transformerは,機械翻訳や言語モデル,TTS・音声認識などの目的で,別ドメイン同士の系列を変換する際によく用いる系列対系列変換の定番モデルである.従来モデルに対するTransformerの主たる工夫は,マルチヘッドアテンションを用いることで,行列積が主体の「(効率的な)ドット積計算」により,系列中のベクトル全部を一括に変換する仕組みを提案した点にある(2節).従来のseq2seq with attention時代では,「(Bahdanau方式の)シングル・アテンション [Bahdanau et al., 2015][Luong et al., 2015]」 と,「RNNによる線形層を用いたフレーム遷移 or 時系列方向畳込み層による遷移 (ConvSeq2seq [Gehring et al., 2017])」のペアを用いることで,時系列方向に「ローカルな」自己回帰予測を行っていた .ローカルとは,それぞれ「1フレーム単位の線形層遷移 (with attention)」 or 「時系列畳み込みの窓内の数フレーム単位 (ConvSeq2seq))での時系列方向の関係性の学習であった.

それが,マルチヘッドアテンションでは,系列全体のトークン間関係をグローバルに学習できるので, 系列全体のトークン表現ベクトルを一括に変換できる. 窓サイズTのマルチヘッドアテンションを用いて,「T個のトークン幅」での表現を一気に変換するように変わったわけである

こうして,Transformerは「グローバルな長期依存コンテキスト」を加味できるマルチヘッドアテンション処理の採用により,表現シーケンスの変換性能が向上したうえに,計算効率も良くなった (※ 論文中の4節「Why Self-Attention」で展開されている主張).

1.1 記事の構成

記事の前半(2節)では,Transformerにおけるマルチヘッドアテンションの処理と定式化を行う.

記事の後半(3節)では,Transformer内の各マルチヘッドアテンションが,自己or 相互アテンションなのか(3.1節)と,各QKV入力がどのように異なるか(3.2節)について,順に整理したい.

この記事は,マルチヘッドアテンションにのみフォーカスしているので,Transformerの全体的な仕組みについては,親記事や元論文,および関連書籍等を参照のこと.

2. マルチヘッドアテンションの処理手順と定式化

マルチヘッドアテンションでは,まず入力トークン表現ベクトル群(系列全体)を元に,スケール化ドット積アテンション(QKV入力方式)を,$h=8$個を並行に実施する.

スケール化ドット積アテンションの入力は,系列中のトークン群N個(=入力系列長)から計算した{クエリ行列$\bm{Q}$,[キー行列$\bm{K}$,バリュー行列$\bm{V}$]}の3つであり,h個の各アテンションには,全て同じ$(\bm{Q},[\bm{K},\bm{V}])$を入力として与える.

ここで,h個のアテンションに,それぞれ異なる別の変換が学習される点が,マルチヘッドアテンション(もといTransformer)のキモである.線形変換中心の,計算負荷が低いアテンション処理の並列実行だけで,系列全体の各トークン間の関係を考慮した,全トークンの表現変換がいっぺんに行えるのが利点である.

要するにマルチヘッドアテンションの処理をまとめると「まず埋込み層で次元削減して,ボトルネックをつくり,そのあと8個並列にアテンションを使ったあと再結合して出力する」というものである.それでは,その処理や構成を細かくみていきたい.

2.1 マルチヘッド・アテンション(multi-head attention)の,処理手順の詳細

マルチヘッド・アテンション(multi-head attention)
図1. マルチヘッド・アテンション(multi-head attention)

マルチヘッドアテンション(図1)では,入力の「トークン表現の系列」の変換処理を,h = 8個の複数アテンションを用いて,以下の3手順で行う:

  1. [アテンション前の前処理:トークン表現の次元削減]
    • 入力のQKVを構成する入力ベクトル $\bm{x}$を,QKVを $d_k = d_v = d_{model} /h= 64$の低次元ベクトルへ射影(※ 並列に8個アテンションで処理するので,各ヘッドの計算負荷を減らすため次元削減).
  2. [h=8個のアテンションを並列実行]
    • 低次元に射影されたQKVから,h=8個のスケール化ドット積アテンションを,並列に実行.
    • 結果, $\bm{Z}_1, \bm{Z}_2 ,\ldots, \bm{Z}_8$ が,各アテンションヘッドから得られる.
  3. [出力:並列処理結果の結合]
    • 8ヘッドの結果$\bm{Z}_i$を,1つのベクトルに再結合し,$h \cdot d_v$次元ベクトルになる.
    • 結合したものを,線形層$\bm{W^O}$で変換することで,$h \cdot d_v$次元から元の$d_{model}$次元に戻った最終出力を得る.

Transformer は最初,seq2seq with attention を差し替える「次世代の系列対系列変換モデル」として,機械翻訳向けに提案された.その後seq2seq with attention同様に,様々な系列対系列変換問題で,Transformerは広く用いられている(Transformerまでの経緯は,アテンションと系列対系列変換の記事を参照).

2.2 定式化

2.1節の説明を元論文や式と対応付けられるようにするために,2.1節で述べた処理手順を,Transformerの元論文に沿って,各「変換式」で示す.

2.2.1 各ヘッドでの処理:

まず,各$i$番目ヘッドでは,$i$番目のアテンションモジュール$(\text{Attention}(\cdot,\cdot,\cdot))$を用いて,出力$\bm{Z}_i$が並列に計算される:

\[
\bm{Z}_i = \text{Attention}(\bm{Q}\bm{W}_i^Q,\bm{K}\bm{W}_i^K,\bm{V}\bm{W}_i^V) \tag{2.1}
\]

式(1.1)は,「スケール化ドット積アテンション$\text{Attention}$」と,「$d_{model}$次元のトークン行ベクトルを列数個重ね行列である $\bm{Q}, \bm{K}, \bm{V} $ (それぞれQuery, Key, Value)の表現をそれぞれ変換する3つの線形層 ($\bm{W}_i^Q,\bm{W}_i^K,\bm{W}_i^V$が重みパラメータ) から構成されている(図1 (a)中盤赤色枠の部分).

ここで,各線形層のパラメータ行列の次元数は,それぞれ以下のようになる(※ 2.1節でも述べたように $d_k = d_v = d_{model} /h= 64$):

  • $\bm{W}_i^Q \in \mathcal{R}^{d_{model} \times d_k}$
  • $\bm{W}_i^K \in \mathcal{R}^{d_{model} \times d_k}$
  • $\bm{W}_i^V \in \mathcal{R}^{d_{model} \times d_v} $

式(2.1)中の,各$i$番目のヘッドで行う「スケール化ドット積アテンション」は,以下の(2.2)式の変換処理である:

\[
\text{Attention}(\bm{Q},\bm{K},\bm{V}) = \text{softmax}\left(\frac{\bm{Q} \bm{K}^T}{\sqrt{d_k}}\right) \bm{V} \tag{2.2}
\]

2.2.2 各アテンションの出力を,結合および変換

式(2.1)(2.2)によるN=8ヘッドの並列計算が全て完了したら,(2.1)の各ヘッド$i$の出力表現$\bm{Z}_i$を結合して,1つの結合表現($h \cdot d_v $次元)にまとめる.

その結合した表現を,線形層$\bm{W}^{O} \in \mathcal{R}^{h \cdot d_v \times d_{model}}$で変換して,最終的に,元の$d_{model}$次元に次元数が戻ったトークン表現の系列行列を最後に出力する:

\[
\text{MultiHeadAttention}(\bm{Q},\bm{K},\bm{V}) = \text{Concat}(\bm{Z}_1, \bm{Z}_2, \ldots, \bm{Z}_8) \bm{W}^{O} \tag{2.3}
\]

以上が,マルチヘッドアテンションの計算手順である.

このマルチヘッドアテンションの変換式(2.1)(2.2)は,2018年以降のTransformer流行により,非常に多くのTransformerを応用したパターン認識系の論文で見かけることになる式である (図1も,2.1節との対応付けが行いやすいように描いた).

2.3 命名について

seq2seq with attention 世代で使用されていた「Bahdanau方式のアテンション」との違いを明確にするためには,もう少し詳しく「マルチヘッドQKVスケール化ドット積アテンション」と呼んでもよいと,管理人は個人的に思う.

とはいえ,Transformerの著者らは,と「QKV」と「スケール化ドット積」を両方とも省略した「マルチヘッド・アテンション」と,簡潔な命名を行った.

3. Transformerにおけるマルチヘッドアテンションの使用

Transformerでは,アテンション(=スケールドット積アテンション)の入力のQKVのもってくる場所次第で,「マルチヘッド自己アテンション」なのか,「マルチヘッド相互アテンション」なのかが変わり,各マルチヘッドアテンションの果たす役割も微妙に違ってくる.

3節では,自己アテンションなのか相互アテンションの違いについて整理したのち(3.1節),Transformer内で使われている3箇所のマルチヘッドアテンションについて,QKVの受け取り元がどのようになっているかを確認しておきたい.

3.1 自己アテンション と 相互アテンションの違い

自己アテンションと相互アテンションの比較
図2 自己アテンションと相互アテンションの比較

Transformer登場以前の時代から,アテンションを系列対系列変換で用いる際には, (1) 系列内で行うか,(2)系列間で行うかによって,アテンションを以下の2種類に分類していた:

  • [図2-a] 自己アテンション(self-attention) 系列内における,トークン表現ベクトル間のアテンション.
  • [図2-b] 相互アテンション(cross-attention) 系列間での,トークン表現ベクトル間のアテンション(「source-target アテンション」と呼ぶこともある).

自己アテンション=系列内アテンション [Cheng et al, 2016] [Parikh et al, 2016] は,系列内のトークン同士で,トークン表現ベクトル同士の関連度を学習する (図2-a). これに対して,Bahdanau式のseq2seq with attentionでは,系列間でのベクトル間の関連度を学習するので相互アテンション=系列間アテンションであると言える.

ちなみにTransformerが,自己アテンション(セルフアテンション)の良さを全面に押し出して,人気mモデルとなったことから,この「自己アテンション v.s 相互アテンション」という呼び方が主流となった.それ以前も,VQAやImage Captioning/Grounding などのクロスモーダルアテンション技術においては,別の呼び方として「intra v.s inter アテンション」という同じことを指す呼び方もあったのだ.しかし,Transformer流行後は(私の記憶や,知ってる範囲に限れば),intra-interアテンションの呼び方はほぼ使われなくなり,「自己アテンション・相互アテンション」の呼び名に統一された印象である.

3.2 アテンションの種類にもとづいた,全体構成の整理

Transformerの概要図:各マルチヘッドアテンションの種類の整理
図3 Transformerの概要図:各マルチヘッドアテンションの種類の整理

図3は,Transformerの全体構成図なのであるが,その構成部品の多くは省力し,マルチヘッドアテンションの違いにのみフォーカスした図である.図3では,各マルチヘッドアテンションが「(1) QKVが X(入力系列)から来ているのか,Y(出力系列)から来ているのか」と「(2) 自己アテンションか相互アテンション化」の2点のみを確認できる([西田 2022] の図にインスパイアされ,その図を本サイトでの色分けに合わせてアレンジすることで,更にTransformerの全体像がつかみやすいようにした).

この図3と同様に,各アテンションの違いにのみフォーカスし,TransformerのEncoderとDecoderの構成について以下に整理した(より詳しい部品構成は,親記事Transformerや元論文などを参考のこと):

  • Transformer-Encoderブロック(×6回):
    • マルチヘッド自己アテンション:
      • 前ブロックのトークン表現系列の行列$X$を入力して,$i$番目のマルチヘッド自己アテンションで更新.
    • トークン位置ごとにFFNをforward
      • 同一のFFNで,各トークン表現を非線形変換.
    • Nブロック分繰り返したのち,$X^{(N)}$を最終出力(Decoderの相互アテンション用の出力).
  • Transformer-Decoderブロック (×6回) :
    • マルチヘッド自己アテンション:
      • 前フレームの予測 $y_t$ を$N$フレーム分ならべた入力トークン行列$Y$.
      • $(Q=Y, K=Y,V =Y$)で各ヘッドを並列実行し,中間出力$\bm{Z}$を出力.
    • マルチヘッド相互アテンション:
      • $(Q=Z, K=X^{(N)},V =X^{(N)}$)で各ヘッドを並列実行し,入力系列の符号$X^{(N)}$から,予測に対応するコンテキストを取り込む.
    • トークン位置ごとにFFNをforward
      • 同一のFFNで,各トークン表現を非線形変換.
    • se2seq同様に,あらたに予測されたトークンを,次ぎフレームの入力$K$にフィードバックして,全体を自己回帰する([EOS]トークンが出るまで繰り返し).

以上で,マルチヘッドアテンション処理の紹介は終わりである.

最後に,ここまで述べてきた,従来手法とTransformerのあいだで「アテンション」がどのように違うかを,表にまとめておく:

系列の符号化系列間での関係学習
seq2seq with attention
(従来手法)
LSTMのセル間遷移
ローカルな各フレーム間で変換
(相互)シングルアテンション:
次の予測単語と,入力系列の全単語のコンテキストを使用.
ConvSeq2seq
(従来手法)
時系列畳み込み:
ローカルな窓内での畳み込み遷移.
上に同じ
Transformerマルチヘッド自己アテンション
系列全体(=global)の,各トークン表現ベクトルを,一挙に変換.
+
FFNでトークンごとに変換
(図3 Encoder下部, Decoder下部)
マルチヘッド相互アテンション
Encoder-Decoder間もマルチヘッド化し,入力系列全体の各トークんを一挙に変換.
(図3 右Decoder上部).
テーブル1 旧来手法とTransformerの比較

ちなみにこの記事では,Transformer-Decoderで次トークンを予測する時の,マルチヘッドアテンションの「マスク化」については話していない.マスク化による,未来トークンの隠ぺい処理については,親記事を参照のこと.

3.3 なぜマルチヘッドアテンションが重要?

この記事の最後として,「マルチヘッドアテンションが,Transformerにおいてどう重要か」について,元論文中の4節の主張(3.3.1)と,クロスモーダル目線での重要性(3.3.2)をまとめたい.

3.3.1 論文中の主張点

Transformerの構成は,マルチヘッドアテンション部分が占める割合が非常に大きい.そのため,その他の部品の改良では,あまり大きく改善しようがない側面がある.従って,その後のTransformerを応用するために改善する研究では,マルチヘッドアテンションを各自の目的に沿って改良することが,盛んに行われている .

自己アテンションが系列内EncodeとDecodeの主役であるので,CNNでの畳み込み層のように,当然Tranformerのマルチヘッド自己アテンションは重要部品である.

3.3.2 クロスモーダル問題での重要性

また,「クロスモーダルな問題(Vision-LanguageやVision-Audio)」などでも,2~3モーダル間の各EncoderやDecoderをあいだで関係づける際に,クロスモーダルな(マルチヘッド)アテンションを使用することが多い.

ただし,Transformer隆盛になる前の,seq2seq with attention時代から,V-LやV-Aでは,クロスモーダルアテンションの研究は盛んである(※ 私もseq2seq 時代から研究していた技術である).つまり,Transformer時代になっても,「クロスモーダルなアテンションの重要さは継続している」というだけではある.

4. まとめ

Transformerの,主要部品であるマルチヘッドアテンションについて紹介した.マルチヘッドアテンションの重要点は,系列内の自己アテンションを担当する点にある.これにより,系列内の「全トークン間」における長い依存関係を加味したCodingを,全トークン表現ベクトルに対して一気に適用できる.

Transformerの前世代のモデルであるseq2seq with attentionでは,そのうち,1回のアテンションのみを,各フレームで逐次実行していた.これを,マルチヘッドアテンションを主部品に採用したブロックの繰り返しによる自己回帰により,Transformerは長距離依存性(もとい,系列内の全トークン間の関係性)を捉えての「トークン表現一括変換」を用いた系列変換が行えるようになった.(全体設計について詳しくは,親記事のTransformerを参照)

4.1 改善モデルTransformer-XLの話

ちなみに,元祖のTransformerにも,マルチヘッドアテンションを主部品にしているので,固定長コンテキストしかモデリングできない問題があった.

Transformerは,入力系列行列Qと出力系列行列V(=K)のあいだで,QKVアテンションを行う.その際に各ヘッドがドット積行列計算ベースなので,系列長の可変具合にはうまく対応できない.自己アテンションの採用により,長期コンテキストはアテンションで捉えやすくなったが,固定長コンテキストしか学習できないのは,まだ少し良くない面があった.

そこで,Transformer-XLでは,TransformerにRNN的な時系列遷移を半分追加することで,可変長コンテキストに対応するようにした(※ 親記事Transformerの終盤で,Transformer-XLの節を用意してある).

関連記事

関連書籍

References

  • [西田 2022]「自然言語処理とVision-and-Language」人工知能学会全国大会(第36回) チュートリアル,2022, 西田京介.
  • [Bahdanau et al., 2015] Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. “Neural machine translation by jointly learning to align and translate.” In ICLR, 2015.
  • [Gehring et al., 2017] Gehring, Jonas, et al. “Convolutional sequence to sequence learning.” In ICML. , 2017.
  • [Luong et al., 2015] M.-T. Luong, H. Pham, and C. D. Manning, Effective approaches to attention-based neural machine translation. In EMNLP, 2015.
  • [Sennrich et al., 2016] Rico Sennrich, Barry Haddow, and Alexandra Birch. 2016. Neural machine translation of rare words with subword units. In ACL, 2016.
  • [Vaswani et al., 2017] Ashish Vaswani, et al. Attention is all you need. In NIPS, 2017.

参照外部リンク

関連記事