マルチヘッドアテンション (Multi-head Attention) [Transformer]

1. マルチヘッドアテンション (Multi-head Attention) とは [概要]

マルチヘッドアテンション (Multi-head Attention) とは,Transformerで提案された,複数のアテンションヘッドを並列実行して,系列中の各トークン表現の変換を行うブロック部品である [Vaswani et al., 2017].端的に言うと,単に「並列実行アテンション」である.より具体的には,各ヘッドが「低次元射影する全結合層 + スケール化ドット積アテンション(QKV型のアテンション)」からなり,アテンション機構のを並列実行を行うEncoder-Decoderブロックとなっている.

この記事では,Transformerの主部品としての「マルチヘッドアテンション」について,処理の詳細や利点についてみていく(2節).旧来の系列対系列変換でも用いられていた「ソース・ターゲット系列間の相互アテンション」の役割も引き続きTransformerで担当するが,それに加えて,自己アテンションとしての役割もTransformerで担当する点が,特に重要である(3節).

マルチヘッドアテンションは,入力の「シーケンス長 $n$ × トークン表現次元 $d_{model}$」のサイズで構成される行列Q, K, Vを元に,以下の処理手順で並列にh=8個のアテンションヘッドを実行し,その結果を最後に結合して出力するブロックである (より詳しくは2節):

  • 各i番目のヘッドで,以下の処理を順に実行:
    1. [低次元射影の全結合層] 入力行列の Q, K, Vに対して,全結合層$W_i^Q,W_i^K,W_i^V$ により,64次元ベクトルへと射影 (次元削減).
    2. [スケール化ドット積アテンション] スケール化ドット積アテンションで表現変換を行い,$d_v$次元ベクトル群の$Z_i$を取得.
  • 8個の$Z_i$を1つに結合したのち,各トークン表現を全結合層WO(512→512次元)で線形変換したのものが,最終出力のトークン表現(系列).

つまり,旧来のseq2seq with attention時代の「各フレームでの,自己回帰なアテンション処理」を脱して,系列全体処理のN回スタック処理に変更した.これにより,高精度で計算効率性も高いTransformerを実現できている.

マルチヘッドアテンションの利点は,「ソース・ターゲット系列間のグローバルな関係コンテキストを加味できる」うえに,系列中の全トークン表現の一括変換になっているところである.簡単な演算(ドット積アテンションの並列化)ながら,高精度な系列変換をTransformerが実現できるのは,このマルチヘッドアテンションを,自己回帰で 6回(Encoder) or 12回(Decoder) 使用するTransformerのスタック設計のおかげである.

※ 関連書籍にも挙げたGetting Started with Google BERT では,Transformer (とBERT)の,具体的な数値例も図示しながらの説明が展開されていて,計算時の挙動を理解しやすい.より具体的な,例・図解や数値例も見られる方が理解しやすいという方にはオススメである.

1.1 系列全体の一括アテンション処理

Transformerは,機械翻訳や言語モデル,TTS・音声認識などの目的で,別ドメイン同士の系列を変換する際によく用いる,定番の系列対系列変換モデルである.

従来モデルに対するTransformerの主たる工夫が,マルチヘッドアテンションの考案と,その自己アテンションとしての使用(3節)である.従来のseq2seq with attention時代では,以下の2つの処理により,ローカルな,フレーム単位の自己回帰予測を行っていた:

対して,Transformerは,窓サイズTのマルチヘッドアテンションを用いることで,「T個のトークン幅」での系列内の各トークン表現を一気に変換するようになり,毎フレームでの予測を随時繰り返すRNN方式ではなくなった.また,マルチヘッドアテンションで,「系列全体(グローバル)の長期依存コンテキスト」を加味できるようになったことで,系列対系列変換性能も向上した(3.2.1節で,新旧を表で比較).

1.2 記事の構成

  • 記事の前半(2節):Transformerにおけるマルチヘッドアテンションの処理と定式化を行う.マルチヘッドアテンションが,マルチヘッド(=並列的)構成のEncoder-Decoder構成であることのメリットに焦点を当てる.
  • 記事の後半(3節):Transformer内の各マルチヘッドアテンションが,自己アテンション or 相互アテンションのどちらなのか(3.1節)と,各QKV入力がどのように異なるか(3.2節)について,図を提示しながら整理する.

この記事は,マルチヘッドアテンションにのみフォーカスする.よって,Transformerの全体的な仕組みや他の構成部品については,親記事や元論文,および関連書籍等を参照のこと.

2. マルチヘッドアテンションの詳細

マルチヘッドアテンション(Multi-head Attention) [Transformerの主部品]
図1. マルチヘッドアテンション (Multi-head Attention) : Transformerの主部品.

マルチヘッドアテンション(図1)は,スケール化ドット積アテンション(QKV入力方式),$h=8$ 個並行に実施したのち,最後に8個の表現を結合した1表現に,再度結合して出力するというDNN用ブロックである.

各ヘッドで使用される「スケール化ドット積アテンション(2.2.2節)」は,QKV方式アテンションである.入力表現は3つあり,{クエリ行列$\bm{Q}$,[キー行列$\bm{K}$,バリュー行列$\bm{V}$]}である.h個の各アテンションには,同一の入力$(\bm{Q},[\bm{K},\bm{V}])$をコピーして,$h=8$ 個のヘッドに全て同一入力として与える.

このとき,各i番目ヘッド内の「QKV低次元射影変換 + QKVアテンション変換」には,それぞれ異なる変換が学習されている点がポイントである.マルチヘッド化により,同じ入力が,h(=8)個異なる変換が行われたのち合成される.これにより,それまで主流だったシングルヘッド・アテンションよりも,表現力が高くなり,高精度な変換をシンプルな並列計算だけで学習できるようになった.(※ 3.2.1節に,新旧比較表を用意).

つまり,1節冒頭でも述べたように,各マルチヘッドアテンションは「(1)並列計算」かつ「(2)行列演算(ドット積)中心 (※ 計算負荷の高いアテンション係数値を予測するサブネットは学習しない)」という性質で設計されているブロックなので,計算機的に非常に有利である.

2.1 マルチヘッド・アテンションの処理手順

マルチヘッドアテンション (図1) では,入力の「トークン表現の系列」の変換処理を,h = 8個の複数アテンションを用いて,以下の3手順により行う:

  1. [アテンション前の前処理:トークン表現の次元削減]
    • 入力のQ, K, V を構成する入力ベクトル $\bm{x}$を $d_k = d_v = d_{model} /h = 64$の低次元ベクトルへと,$i$番目ヘッド個別の全結合層$\bm{W}_i^Q,\bm{W}_i^K,\bm{W}_i^V$を用いて射影する.(※ 入力は前の層から来るので,埋め込み層ではない)
  2. [h=8個のアテンションを並列実行]
    • 低次元に射影された$\bm{Q}\bm{W}_i^Q,\bm{K}\bm{W}_i^K,\bm{V}\bm{W}_i^V$を入力に,8個のスケール化ドット積アテンション(2.2式)を実行.
    • 8個の$\bm{Z}_1, \bm{Z}_2 ,\ldots, \bm{Z}_8$ が,各i番目のヘッドで得られる.
  3. [出力:並列処理結果の結合]
    • $\bm{Z}_i$内の各トークン表現を,1つのベクトルにそれぞれ結合し,$8 \cdot d_v = 512$次元ベクトル表現になる.
    • 結合したものを,全結合層$\bm{W^O}$で変換する.これにより,並列アテンション処理中の$h \cdot d_v$次元から,元の$d_{model}$次元に戻った最終的なトークン表現出力を得る.

Transformer は最初,seq2seq with attentionを差し替える「次世代の系列対系列変換ネットワーク」として,機械翻訳向けに提案された.その後seq2seq with attention同様に,様々な系列対系列変換タスクで,広くTransformerは用いられている

Transformerまでの段階的な経緯は,系列対系列変換とアテンション機構の記事を参照のこと.

2.2 マルチヘッドアテンションの定式化

2.1節の説明を元論文や式と対応付けられるようにするために,2.1節で述べた処理手順を,Transformerの元論文の「変換式」通り,処理順に示す.

2.2.1 各ヘッドでの処理:

まず,各$i$番目ヘッドでは,$i$番目のアテンション$(\text{Attention}(\cdot,\cdot,\cdot))$を用いて,出力$\bm{Z}_i$が並列に計算される:

\[
\bm{Z}_i = \text{Attention}(\bm{Q}\bm{W}_i^Q,\bm{K}\bm{W}_i^K,\bm{V}\bm{W}_i^V) \tag{2.1}
\]

式(2.1)は,「スケール化ドット積アテンション$\text{Attention}$」と,「$d_{model}$次元のトークン行ベクトルを列数だけ重ねた行列 $\bm{Q}, \bm{K}, \bm{V} $ (それぞれQuery, Key, Value)の表現を,低次元に射影する3つの全結合層 $\bm{W}_i^Q,\bm{W}_i^K,\bm{W}_i^V$ から構成されている (図1-a 中盤赤色枠の部分).

ここで,各線形層のパラメータ行列の次元数は,それぞれ以下のようになる(※ 2.1節でも述べたように $d_k = d_v = d_{model} / 8= 64$):

  • $\bm{W}_i^Q \in \mathcal{R}^{d_{model} \times d_k}$
  • $\bm{W}_i^K \in \mathcal{R}^{d_{model} \times d_k}$
  • $\bm{W}_i^V \in \mathcal{R}^{d_{model} \times d_v}$

2.2.2 スケール化ドット積アテンション

式(2.1)中の,各$i$番目ヘッドで行うスケール化ドット積アテンションは,以下の(2.2)式の変換である:

\[
\text{Attention}(\bm{Q},\bm{K},\bm{V}) = \text{softmax}\left(\frac{\bm{Q} \bm{K}^T}{\sqrt{d_k}}\right) \bm{V} \tag{2.2}
\]

ここで,$\sqrt{d_k}$はKey入力$\bm{K}_k$の,各ベクトルの次元数である.

このスケール化ドット積アテンションは,QKVアテンションに,スケール正規化を追加したものである.系列全体のベクトルを積んだ行列の,Query表現行列 \bm{Q}とKey表現行列$\bm{K}$のあいだで,ドット積をもちいて,両者の行列内の各ベクトル同士の関連度(=類似度)計算を系列全体を一括に行う

seq2seq with attention時代は,RNNにより1フレームずつトークン表現を予測・更新していた.それが,Transformerでは系列内のトークン全部を,QKVアテンションの行列変換で,一括変換するように設計変更したわけである.

スケール化ドット積アテンションでは,「複雑なスコア関数の計算・学習」は行わずに,単なるベクトル同士の内積で算出したスカラー類似度を,$\frac{1}{\sqrt{d_k}}$でスケール化した値を,各ベクトル同士のアテンション係数として使用する.(※ QKVアテンション全般に共通の性質)

また,低次元表現へ射影する全結合層$\bm{W}_i^Q,\bm{W}_i^K,\bm{W}_i^V$により射影しても,Q,K間の類似度計算と合成後の最終出力ベクトル値が調整されることになる(2.2.1).

ドット積アテンションを採用しているのは,行列計算テクニックを用いて高速化しやすいからである.また著者が,$\frac{1}{\sqrt{d_k}}$でスケール化するようにしているのは,次のsoftmax層の勾配が小さくなって,勾配消失してしまうのを防ぐ目的である.

2.2.3 各アテンションの出力を,結合および変換

式(2.1),(2.2)による h = 8 ヘッドの並列計算が全て完了したら,(2.1)の各ヘッド$i$の出力表現$\bm{Z}_i$を結合して,1つの結合表現($h \times d_v = 8 \times 64 = 512$次元)にまとめる.

そして,結合した表現を,全結合層$\bm{W}^{O} \in \mathcal{R}^{h \cdot d_v \times d_{model}}$で最後に変換することで,最終的なマルチヘッドアテンションで変換済みのトークン表現行列が得られる:

\[
\text{MultiHeadAttention}(\bm{Q},\bm{K},\bm{V}) = \text{Concat}(\bm{Z}_1, \bm{Z}_2, \ldots, \bm{Z}_8) \bm{W}^{O} \tag{2.3}
\]

以上が,マルチヘッドアテンションの計算手順である.その変換式(2.1)~(2.3)は,2018年以降のTransformer流行により,非常に多くの「Transformerを応用したパターン認識論文」で見かける式となった (図1も,2.2.1~2.2.3節の3処理への対応付けが行いやすいように描いたつもりである).

2.3 命名について

seq2seq with attention 世代で使用されていた「Bahdanau方式のアテンション」との違いを明確にするためには,もう少し詳しく「マルチヘッドQKVスケール化ドット積アテンション」と呼んでもよいと,管理人的には思う.

とはいえ,Transformerの著者らは,と「QKV」と「スケール化ドット積」を両方とも省略した「マルチヘッド・アテンション」と,簡潔な命名を行った.これはこれで,並列化アテンションと同義なので,汎用的な呼び名である.

3. Transformerにおけるマルチヘッドアテンションの使用

Transformer内の各マルチヘッドアテンションは,入力のQKVを引っ張ってくる場所次第で,「自己アテンション」なのか,「相互アテンション」なのかが変わり,果たす役割も違ってくる.

3節では,自己アテンションと相互アテンションの違いについてまず整理したうえで(3.1節),Transformer内の3箇所のマルチヘッドアテンションについて,QKV入力の受け取り元がどのようになっているかと,どれが自己 or 相互アテンションなのかを整理できるようにしておきたい.

3.1 自己アテンション と 相互アテンションの違い

自己アテンションと相互アテンションの比較
図2. 自己アテンションと相互アテンションの比較

Transformer登場以前の時代から,アテンションを系列対系列変換で用いる際には,アテンション処理を「 (1)系列内で行う」か「(2)系列間で行う」かによって以下の2種類に分類してきた:

  • [図2-a] 自己アテンション(self-attention) 系列内での,各トークン表現間のアテンション.
  • [図2-b] 相互アテンション(cross-attention) 系列間での,各トークン表現間のアテンション(source-target アテンションと呼ぶこともある).

自己アテンション(系列内アテンション) [Cheng et al, 2016], [Parikh et al, 2016] は,系列内のトークン表現同士で,関連度を学習する(図3-a). これに対して,Bahdanau式の seq2seq with attentionでは,系列間でのトークン表現間の関連度を学習するので,相互アテンション(系列間アテンション)である(図3-b).Transformerが,自己アテンション(セルフアテンション)の良さを全面に押し出して,人気モデルとなったことから,この「自己アテンション v.s 相互アテンション」の呼び方が主流となった.

それ以前も,VQAや,画像キャプション生成,画像グラウンディング などのクロスモーダルアテンション技術においては,別の呼び方として「intra v.s inter アテンション」という同じことを指す呼び方もあった.しかし,(私の記憶や,知ってる範囲に限れば) Transformer流行後は「intra-interアテンション」という呼び方はあまり使われなくなり「自己アテンション・相互アテンション」の呼び名に統一された印象である.

3.2 自己 or 相互アテンションの視点で,Transformerの構成を復習

Transformerの概要図:各マルチヘッドアテンションの種類の整理
図3. Transformerの概要図:各マルチヘッドアテンションの種類の整理

図3は,Transformerの全体構成図であるが,構成部品の多くは省略したことで,各マルチヘッドアテンションにおける,以下2点の違いがわかりやく確認できる全体図にしてある:

  1. 入力の Q, K, Vのそれぞれが, 「X(入力系列)から来ている 」or 「Y(出力系列)から来ている」
  2. 自己アテンション or 相互アテンション

(図3は,[西田 2022] の図にインスパイアされた図で,本サイトの色分けに合わせてアレンジした).

この図3と同様に,各アテンションの違いにのみフォーカスし,TransformerのEncoderとDecoderの構成について以下に整理した (より詳しくは,親記事Transformerや元論文を参考):

  • Transformer-Encoderブロック(× 6回):
    • 入力$\bm{X}$:入力系列のT個のトークン表現ベクトルを並べた行列.
    • マルチヘッド自己アテンション:
      • マルチヘッド自己アテンションで,各トークン表現を更新.
    • トークン位置ごとにFFNを順伝搬:
    • Nブロック分繰り返したのち,$X^{(N)}$を最終出力(Decoderの相互アテンション用の出力).
  • Transformer-Decoderブロック (× 6回) :
    • 入力トークン行列$Y$:前フレームまでの予測 $y_t$を,$N$フレームならべた行列.
    • マルチヘッド自己アテンション:
      • $(Q=Y, K=Y,V =Y$)の入力で,各ヘッドを並列実行し,中間出力$\bm{Z}$を出力.
    • マルチヘッド相互アテンション:
      • $(Q=Z, K=X^{(N)},V =X^{(N)}$)で各ヘッドを並列実行し,入力系列の符号$X^{(N)}$から,予測に対応するコンテキストを取り込む.
    • トークン位置ごとにFFNを順伝搬
  • 旧来のseq2seq同様に,あらたに予測されたトークンを,次ぎフレームの入力$K$としてフィードバックし,全体を自己回帰 ([EOS]トークンが出るまで繰り返し).

以上で,Transformerの構成と処理手順における,マルチヘッドアテンションの違いの紹介は,終わりである.

3.2.1 旧来のseq2seq 向けと,Transformer 向けアテンションの違い

さて,上記の処理を復習したので,従来手法とTransformerのあいだで「アテンション」がどのように異なるかを,以下の表にまとめておく:

系列の符号化・復号化系列間での関係学習
seq2seq with attention
(従来手法)
LSTMのセル間遷移
ローカルな各フレーム間で変換
(相互)シングルアテンション:
次の予測単語と,入力系列の全単語のコンテキストを使用.
ConvSeq2seq
(従来手法)
系列方向の1D畳み込み層
ローカルな窓内での畳み込み遷移.
上に同じ
Transformerマルチヘッド自己アテンション
系列全体(=global)の,各トークン表現ベクトルを,一挙に変換.
+ トークンごとにFFNで変換
(図3 Encoder下部, Decoder下部)
マルチヘッド相互アテンション
Encoder-Decoder間もマルチヘッド化し,入力系列全体の各トークンを一挙に変換.
(図3 右Decoder上部).
テーブル1 旧来手法とTransformerの比較

ちなみに,この記事では話さない「マスク処理による未来トークンの隠ぺい処理」については,親記事Transformer 2.3.2節を参照のこと.

3.3 なぜマルチヘッドアテンションが重要?

この記事の最後として「マルチヘッドアテンションが,Transformerにおいてどう重要か」について,元論文中の4節の主張(3.3.1)と,クロスモーダル目線での重要性(3.3.2)をまとめたい.

3.3.1 論文中の主張点

Transformerの構成は,マルチヘッドアテンション部分が占める割合が非常に大きい.そのため,その他の部品の改良では,あまり大きく改善しようがない側面がある.従って,その後のTransformerを応用するために改善する研究では,マルチヘッドアテンションを各自の目的に沿って改良することが,盛んに行われている .

自己アテンションが系列内EncodeとDecodeの主役であるので,CNNでの畳み込み層のように,当然Tranformerのマルチヘッド自己アテンションは重要部品である.

3.3.2 クロスモーダル問題での重要性

また「クロスモーダルな問題 (Vision-LanguageやVision-Audio)」などでも,2~3モーダル間の各EncoderやDecoderをあいだで関係づける際に,クロスモーダルな(マルチヘッド)アテンションを使用することが多い.

ただし,Transformer が隆盛になる前の,seq2seq with attention 時代から,V-LやV-Aでは,(モーダル間の)クロスモーダルアテンションの研究は盛んである.つまり,Transformer時代になっても「クロスモーダルなアテンションの重要さは継続している」というだけではある.

4. まとめ

Transformer の,主要部品であるマルチヘッドアテンションについて詳しく紹介した.マルチヘッドアテンションの重要点は,系列内の自己アテンションを担当する点にあり,これにより,系列内の「全トークン間」における長い依存関係を加味したEncoding/Decodingを,全トークン表現ベクトルに対して一気に適用できる.

4.1 改善モデルTransformer-XLの話

ちなみに,元祖のTransformerにも,マルチヘッドアテンションを主部品にしているので,固定長コンテキストしかモデリングできない問題があった.

Transformerは,入力系列行列Qと出力系列行列V(=K)のあいだで,QKVアテンションを行う.その際に各ヘッドがドット積行列計算ベースなので,系列長の可変具合にはうまく対応できない.自己アテンションの採用により,長期コンテキストはアテンションで捉えやすくなったが,固定長コンテキストしか学習できないのは,まだ少し良くない面があった.

そこで,Transformer-XLでは,TransformerにRNN的な時系列遷移を半分追加することで,可変長コンテキストに対応するようにした(※ 親記事Transformerも,5節がTransformer-XLについて).

関連書籍

References

  • [西田 2022]「自然言語処理とVision-and-Language」人工知能学会全国大会(第36回) チュートリアル,2022, 西田京介.
  • [Bahdanau et al., 2015] Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. “Neural machine translation by jointly learning to align and translate.” In ICLR, 2015.
  • [Gehring et al., 2017] Gehring, Jonas, et al. “Convolutional sequence to sequence learning.” In ICML. , 2017.
  • [Luong et al., 2015] M.-T. Luong, H. Pham, and C. D. Manning, Effective approaches to attention-based neural machine translation. In EMNLP, 2015.
  • [Sennrich et al., 2016] Rico Sennrich, Barry Haddow, and Alexandra Birch. 2016. Neural machine translation of rare words with subword units. In ACL, 2016.
  • [Vaswani et al., 2017] Ashish Vaswani, et al. Attention is all you need. In NIPS, 2017.

参照外部リンク

関連記事