1. マルチヘッドアテンション (Multi-head Attention) とは [概要]
マルチヘッドアテンション (Multi-head Attention) とは,Transformerで提案された,複数のアテンションヘッドを並列実行して,系列中の各トークン表現の変換を行うブロック部品である [Vaswani et al., 2017].端的に言うと「並列型アテンション」である.
この記事では,Transformerの主部品としての「マルチヘッドアテンション」について,処理の詳細や利点についてみていく(2節).旧来の系列対系列変換でも用いられていた「ソース・ターゲット系列間の相互アテンション」の役割も引き続きTransformerで担当するが,それに加えて,自己アテンションとしての役割もTransformerで担当する点が,特に重要である(3節).
より具体的には,h個のヘッドが「低次元射影する全結合層 + スケール化ドット積アテンション(QKV型のアテンション)」を,少し下げた次元でアテンション処理を個別実行し,できあがったh個の表現を結合してまとめたのち,最後に線形変換する構成となっている.
1.1 記事の構成
- 記事の前半(2節):
- Transformerにおけるマルチヘッドアテンションの処理と定式化を行う.
- マルチヘッドアテンションが,マルチヘッド(=並列的)構成のEncoder-Decoder構成であることのメリットに焦点を当てる.
- 記事の後半(3節):
- Transformer内の各マルチヘッドアテンションが,(1) 自己アテンション or 相互アテンションのどちらなのか(3.1節)と,(2) 各QKV入力がどう異なるか(3.2節)について,順に整理する.
この記事は,マルチヘッドアテンションにのみフォーカスする.Transformerの全体的な仕組みや他構成部品については,親記事や元論文,および関連書籍等を参照のこと.
1.2 マルチヘッドアテンションの概要
マルチヘッドアテンションは(図1),入力の「シーケンス長 $n$ × トークン表現次元 $d_{\text{model}}$」のサイズで構成される行列である$\bm{Q}$(Query), $\bm{K}$(Key), $\bm{V}$(Value)を,h個の並列したヘッドで異なる処理を行った結果を,最後に1つに結合する処理である.
マルチヘッドアテンションでは,以下のように,並列に h=8個のアテンションヘッドを実行し,その結果を,後半でアテンション処理済みの表現h個を1つに結合して最終的に出力するブロックである (より詳細は2.1節で):
- 前半:h=8個のパスを並列にアテンション処理:
- 後半:アテンション処理を終えたh個の表現を合体:
- h=8個の$Z_i$を,1つに結合して統合(Concat).
- 各トークン表現を全結合層 $W^{O}$(512→512次元)で変換.
マルチヘッドアテンションの利点は,「ソース・ターゲット系列間のグローバルな関係コンテキストを加味できる」うえに,系列中の全トークン表現の一括変換になっているところである.簡単な演算(ドット積アテンションの並列化)ながら,高精度な系列変換をTransformerが実現できるのは,このマルチヘッドアテンションを,自己回帰で 6回(Encoder) or 12回(Decoder) 使用するTransformerのスタック設計のおかげである.
これにより,旧来のseq2seq with attention時代の,
- Encoder-Decoder間
- 各フレームでの,ソース-ターゲット間の相互(シングル)アテンション処理
でのアテンションの使用を,Transformer内では,以下の3箇所で「自己 or 相互マルチヘッドアテンション」に替えた:
- Transformer-Encoder
- Encoderの系列内の自己アテンション処理
- Transformer-Decoder
- Decoder前半の系列内の自己アテンション
- Transformer-Encoder v.s Tranformer-Decoder間
- Decoder後半の,系列間の相互アテンション
このように,Transformerは「(マルチ)アテンション中心の構成」恩恵で,高精度で計算効率性も高い系列変換モデルを実現できた.以上がTransformer目線での「自己or相互」のどちらをどこで使うかの概要である(3節では,図つきでより詳しく解説).
※ 関連書籍にも挙げたGetting Started with Google BERT では,Transformer (とBERT)の,具体的な数値例も図示しながらの説明が展開されていて,計算時の挙動を理解しやすい.より具体的な,例・図解や数値例も見られる方が理解しやすいという方にはオススメである.
1.3 系列全体の「一括アテンション処理」
Transformerは,機械翻訳や言語モデル,TTS・音声認識などの目的で,別ドメイン同士の系列を変換する際によく用いる,定番の系列対系列変換モデルである.
従来モデルに対するTransformerの主たる工夫が,マルチヘッドアテンションの考案と,その自己アテンションとしての使用(3節)である.従来のseq2seq with attention時代では,以下の2つの処理により,ローカルな,フレーム単位の自己回帰予測を行っていた:
- seq2seq with attentionのシングル・アテンション [Bahdanau et al., 2015], [Luong et al., 2015]
- RNNによるフレーム間遷移 (seq2seq with attention) or 畳み込み層によるフレーム間遷移 (ConvSeq2seq [Gehring et al., 2017])
対して,Transformerは,窓サイズTのマルチヘッドアテンションを用いることで,「T個のトークン幅」での系列内の各トークン表現を一気に変換するようになった.つまり,フレーム間のみで予測を随時繰り返すRNN方式ではなくなった.ただし,Transformer-Decoderは,フレーム間で1トークンずつ予測する自己回帰モデルのままである.
また,マルチヘッドアテンションで,「系列全体(グローバル)の長期依存コンテキスト」を加味できるようになったことで,系列対系列変換性能も向上した(3.2.1節で,新旧を表で比較).
2. マルチヘッドアテンションの詳細
マルチヘッドアテンションは,並列にアテンションを実行するためのTransformer向けブロックである.スケール化ドット積アテンション(QKV入力方式)を,$h=8$ 個並行に実施したのち,最後に8個の表現を1表現に再度結合して,全体を全結合層$\bm{W^O}$で線形変換したあと最終出力する (1.1節で概要)
各ヘッドで使用される「スケール化ドット積アテンション(2.2.2節)」は,QKV方式アテンションである.入力表現は3つあり,{クエリ行列$\bm{Q}$,[キー行列$\bm{K}$,バリュー行列$\bm{V}$]}である.h個の各アテンションには,同一の入力$(\bm{Q},[\bm{K},\bm{V}])$をコピーして,$h=8$ 個のヘッドに全て同一入力として与える.
このとき,各i番目ヘッド内の「QKV低次元射影変換 + QKVアテンション変換」には,それぞれ異なる変換が学習されている点がポイントである.マルチヘッド化により,同じ入力が,h(=8)個異なる変換が行われたのち合成される.これにより,それまで主流だったシングルヘッド・アテンションよりも,表現力が高くなり,高精度な変換をシンプルな並列計算だけで学習できるようになった.(※ 3.2.1節に,新旧比較表を用意).
つまり,1節冒頭でも述べたように,各マルチヘッドアテンションは「(1)並列計算」かつ「(2)行列演算(ドット積)中心 (※ 計算負荷の高いアテンション係数値を予測するサブネットは学習しない)」という性質で設計されているブロックなので,計算機的に非常に有利である.
2.1 詳細な処理手順
マルチヘッドアテンション(図1)では,入力の「トークン表現の系列」の変換処理を,以下の3手順により,h = 8個のアテンションヘッドで並列に行う:
- 前処理:トークン表現の次元削減
- 本処理:h=8個のアテンションを並列実行
- 低次元に射影された$\bm{Q}\bm{W}_i^Q,\bm{K}\bm{W}_i^K,\bm{V}\bm{W}_i^V$を入力に,8個のスケール化ドット積アテンション(2.2式)を実行.
- 8個の$\bm{Z}_1, \bm{Z}_2 ,\ldots, \bm{Z}_8$ が,各i番目のヘッドで得られる.
- [後処理]:並列処理結果の結合と,最後の線形変換
- $\bm{Z}_i$内の各トークン表現を,1つのベクトルにそれぞれ結合し,$8 \cdot d_v = 512$次元ベクトル表現になる.
- 結合したものを,全結合層$\bm{W^O}$で変換する.これにより,並列アテンション処理中の$h \cdot d_v$次元から,元の$d_{model}$次元に戻った最終的なトークン表現出力を得る.
Transformer は最初,seq2seq with attentionを差し替える「次世代の系列対系列変換ネットワーク」として,機械翻訳向けに提案された.その後seq2seq with attention同様に,様々な系列対系列変換タスクで,広くTransformerは用いられている
Transformerまでの段階的な経緯は,系列対系列変換とアテンション機構の記事を参照のこと.
2.2 マルチヘッドアテンションの定式化
2.1節の説明を元論文や式と対応付けられるようにするために,2.1節で述べた処理手順を,Transformerの元論文の「変換式」通り,処理順に示す.
2.2.1 各ヘッドでの処理:
まず,各$i$番目ヘッドでは,$i$番目のアテンション$(\text{Attention}(\cdot,\cdot,\cdot))$を用いて,出力$\bm{Z}_i$を並列に計算する:
\[
\bm{Z}_i = \text{Attention}(\bm{Q}\bm{W}_i^Q,\bm{K}\bm{W}_i^K,\bm{V}\bm{W}_i^V) \tag{2.1}
\]
式(2.1)は,「スケール化ドット積アテンション$\text{Attention}$」と,「$d_{model}$次元のトークン行ベクトルを列数だけ重ねた行列 $\bm{Q}, \bm{K}, \bm{V} $ (それぞれQuery, Key, Value)の表現を,低次元に射影する3つの全結合層 $\bm{W}_i^Q,\bm{W}_i^K,\bm{W}_i^V$ から構成されている (図1-a 中盤赤色枠の部分).
ここで,各線形層のパラメータ行列の次元数は,それぞれ以下のようになる(※ 2.1節でも述べたように $d_k = d_v = d_{model} / 8= 64$):
- $\bm{W}_i^Q \in \mathcal{R}^{d_{model} \times d_k}$
- $\bm{W}_i^K \in \mathcal{R}^{d_{model} \times d_k}$
- $\bm{W}_i^V \in \mathcal{R}^{d_{model} \times d_v}$
2.2.2 スケール化ドット積アテンション
式(2.1)中の,各$i$番目ヘッドで行うスケール化ドット積アテンションは,以下の(2.2)式の変換である:
\[
\text{Attention}(\bm{Q},\bm{K},\bm{V}) = \text{softmax}\left(\frac{\bm{Q} \bm{K}^T}{\sqrt{d_k}}\right) \bm{V} \tag{2.2}
\]
ここで,$\sqrt{d_k}$はKey入力$\bm{K}_k$の,各ベクトルの次元数である.
このスケール化ドット積アテンションは,QKVアテンションに,スケール正規化を追加したものである.系列全体のベクトルを積んだ行列の,Query表現行列 \bm{Q}とKey表現行列$\bm{K}$のあいだで,ドット積をもちいて,両者の行列内の各ベクトル同士の関連度(=類似度)計算を系列全体を一括に行う.
seq2seq with attention時代は,RNNにより1フレームずつトークン表現を予測・更新していた.それが,Transformerでは系列内のトークン全部を,QKVアテンションの行列変換で,一括変換するように設計変更したわけである.
スケール化ドット積アテンションでは,「複雑なスコア関数の計算・学習」は行わずに,単なるベクトル同士の内積で算出したスカラー類似度を,$\frac{1}{\sqrt{d_k}}$でスケール化した値を,各ベクトル同士のアテンション係数として使用する.(※ QKVアテンション全般に共通の性質)
また,低次元表現へ射影する全結合層$\bm{W}_i^Q,\bm{W}_i^K,\bm{W}_i^V$により射影しても,Q,K間の類似度計算と合成後の最終出力ベクトル値が調整されることになる(2.2.1).
ドット積アテンションを採用しているのは,行列計算テクニックを用いて高速化しやすいからである.また著者が,$\frac{1}{\sqrt{d_k}}$でスケール化するようにしているのは,次のsoftmax層の勾配が小さくなって,勾配消失してしまうのを防ぐ目的である.
2.2.3 各アテンションの出力を,結合および変換
式(2.1),(2.2)による h = 8 ヘッドの並列計算が全て完了したら,(2.1)の各ヘッド$i$の出力表現$\bm{Z}_i$を結合して,1つの結合表現($h \times d_v = 8 \times 64 = 512$次元)にまとめる.
そして,結合した表現を,全結合層$\bm{W}^{O} \in \mathcal{R}^{h \cdot d_v \times d_{model}}$で最後に変換することで,最終的なマルチヘッドアテンションで変換済みのトークン表現行列が得られる:
\[
\text{MultiHeadAttention}(\bm{Q},\bm{K},\bm{V}) = \text{Concat}(\bm{Z}_1, \bm{Z}_2, \ldots, \bm{Z}_8) \bm{W}^{O} \tag{2.3}
\]
以上が,マルチヘッドアテンションの計算手順である.その変換式(2.1)~(2.3)は,2018年以降のTransformer流行により,非常に多くの「Transformerを応用したパターン認識論文」で見かける式となった (図1も,2.2.1~2.2.3節の3処理への対応付けが行いやすいように描いたつもりである).
2.3 命名について
seq2seq with attention 世代で使用されていた「Bahdanau方式のアテンション」との違いを明確にするためには,もう少し詳しく「マルチヘッドQKVスケール化ドット積アテンション」と呼んでもよいと,管理人的には思う.
とはいえ,Transformerの著者らは,と「QKV」と「スケール化ドット積」を両方とも省略した「マルチヘッド・アテンション」と,簡潔な命名を行った.これはこれで,並列化アテンションと同義なので,汎用的な呼び名である.
3. Transformerにおける使用
Transformer内の各マルチヘッドアテンションは,入力のQKVを引っ張ってくる場所次第で,「自己アテンション」なのか,「相互アテンション」なのかが変わり,果たす役割も違ってくる.
3節では,自己アテンションと相互アテンションの違いについてまず整理したうえで(3.1節),Transformer内の3箇所のマルチヘッドアテンションについて,QKV入力の受け取り元がどのようになっているかと,どれが自己 or 相互アテンションなのかを整理できるようにしておきたい.
3.1 自己アテンション と 相互アテンションの違い
Transformer登場以前の時代から,アテンションを系列対系列変換で用いる際には,アテンション処理を「 (1)系列内で行う」か「(2)系列間で行う」かによって以下の2種類に分類してきた:
- [図2-a] 自己アテンション(self-attention) :系列内での,各トークン表現間のアテンション.
- [図2-b] 相互アテンション(cross-attention) :系列間での,各トークン表現間のアテンション(source-target アテンションと呼ぶこともある).
自己アテンション(系列内アテンション) [Cheng et al, 2016], [Parikh et al, 2016] は,系列内のトークン表現同士で,関連度を学習する(図3-a). これに対して,Bahdanau式の seq2seq with attentionでは,系列間でのトークン表現間の関連度を学習するので,相互アテンション(系列間アテンション)である(図3-b).Transformerが,自己アテンション(セルフアテンション)の良さを全面に押し出して,人気モデルとなったことから,この「自己アテンション v.s 相互アテンション」の呼び方が主流となった.
それ以前も,VQAや,画像キャプション生成,画像グラウンディング などのクロスモーダルアテンション技術においては,別の呼び方として「intra v.s inter アテンション」という同じことを指す呼び方もあった.しかし,(私の記憶や,知ってる範囲に限れば) Transformer流行後は「intra-interアテンション」という呼び方はあまり使われなくなり「自己アテンション・相互アテンション」の呼び名に統一された印象である.
3.2 「自己 or 相互」でTransformerの各アテンションを復習
図3は,Transformerの全体構成図であるが,構成部品の多くは省略したことで,各マルチヘッドアテンションにおける,以下2点の違いがわかりやく確認できる全体図にしてある:
- 入力の Q, K, Vのそれぞれが「X(入力系列)から来ている 」or 「Y(出力系列)から来ている」
- 自己アテンション or 相互アテンション
(図3は,[西田 2022] の図にインスパイアされた図である.それを,このサイトでの色分けに合わせたものにアレンジした).
この図3と同様に,各アテンションの違いにのみフォーカスし,TransformerのEncoderとDecoderの構成について以下に整理した (より詳しくは,親記事Transformerや元論文を参考):
- Transformer-Encoderブロック(× 6回):
- 入力$\bm{X}$:入力系列のT個のトークン表現ベクトルを並べた行列.
- マルチヘッド自己アテンション:
- マルチヘッド自己アテンションで,各トークン表現を更新.
- トークン位置ごとにFFNを順伝搬:
- $N$ブロック分繰り返したのち,$\bm{X}^{(N)}$を最終出力 (Decoderの相互アテンション用の出力).
- Transformer-Decoderブロック (× 6回) :
- 入力トークン行列$\bm{Y}$:前フレームまでの予測 $y_t$を,$N$フレームならべた行列.マルチヘッド自己アテンション:
- $(\bm{Q}=\bm{Y}, \bm{K}=\bm{Y}, \bm{V} =\bm{Y}$)の入力で,各ヘッドを並列実行し,中間出力$\bm{Z}$を出力.
- マルチヘッド相互アテンション:
- $(\bm{Q}=\bm{Z}, \bm{K}=\bm{X}^{(N)},\bm{V} =\bm{X}^{(N)}$)で各ヘッドを並列実行し,入力系列の符号$X^{(N)}$から,予測に対応するコンテキストを取り込む.
- トークン位置ごとにFFNを順伝搬して,各ベクトルを表現変換.
- 入力トークン行列$\bm{Y}$:前フレームまでの予測 $y_t$を,$N$フレームならべた行列.マルチヘッド自己アテンション:
- 旧来のseq2seq同様に,あらたに予測されたトークンを,次ぎフレームの入力$K$としてフィードバックし,全体を自己回帰 ([EOS]トークンが出るまで繰り返し).
3.2.1 旧来のseq2seq 向けと,Transformer 向けアテンションの違い
上記の処理を復習したので,従来手法とTransformerのあいだでの「アテンションの相違点」を,表にまとめておく:
系列の符号化・復号化 | 系列間での関係学習 | |
seq2seq with attention (従来手法) | LSTMのセル間遷移: ローカルな各フレーム間での変換. | (相互)シングルアテンション: 次の予測単語と,入力系列の全単語のコンテキストを使用. |
ConvSeq2seq (従来手法) | 系列方向の1D畳み込み層: ローカルな窓内での畳み込み遷移. | 上に同じ |
Transformer | マルチヘッド自己アテンション: 系列全体(=global)の,各トークン表現ベクトルを,一挙に変換. + トークンごとにFFNで変換 (図3 Encoder下部, Decoder下部) | マルチヘッド相互アテンション: Encoder-Decoder間もマルチヘッド化し,入力系列全体の各トークンを一挙に変換. (図3 右Decoder上部). |
ちなみに,この記事では話さない「マスク処理による未来トークンの隠ぺい処理」については,親記事Transformer 2.3.2節を参照のこと.
3.3 なぜマルチヘッドアテンションが重要?
この記事の最後として「マルチヘッドアテンションが,Transformerにおいてどう重要か」について,元論文中の4節の主張(3.3.1)と,クロスモーダル目線での重要性(3.3.2)をまとめたい.
3.3.1 論文中の主張点
Transformerの構成は,マルチヘッドアテンション部分が占める割合が非常に大きい.そのため,その他の部品の改良では,あまり大きく改善しようがない側面がある.従って,その後のTransformerを応用するために改善する研究では,マルチヘッドアテンションを各自の目的に沿って改良することが,盛んに行われている .
自己アテンションが,「系列内Encode」と「系列内Decode」の主役である.よって,CNNの主役部品である畳み込み層のようにTransformerのマルチヘッド自己アテンションは重要部品である.
3.3.2 クロスモーダル問題での重要性
また「クロスモーダルな問題 (Vision-LanguageやVision-Audio)」などでも,2~3モーダル間の各EncoderやDecoderをあいだで関係づける際に,クロスモーダルな(マルチヘッド)アテンションを使用することが多い.
ただし,Transformer が隆盛になる前の,seq2seq with attention 時代から,V-LやV-Aでは,(モーダル間の)クロスモーダルアテンションの研究は盛んである.つまり,Transformer時代になっても「クロスモーダルなアテンションの重要さは継続している」というだけではある.
4. まとめ
Transformer の,主要部品であるマルチヘッドアテンションについて詳しく紹介した.マルチヘッドアテンションの重要点は,系列内の自己アテンションを担当する点にあり,これにより,系列内の「全トークン間」における長い依存関係を加味したEncoding/Decodingを,全トークン表現ベクトルに対して一気に適用できる.
4.1 改善モデルTransformer-XLの話
ちなみに,元祖のTransformerにも,マルチヘッドアテンションを主部品にしているので,固定長コンテキストしかモデリングできない問題があった.
Transformerは,入力系列行列Qと出力系列行列V(=K)のあいだで,QKVアテンションを行う.その際に各ヘッドがドット積行列計算ベースなので,系列長の可変具合にはうまく対応できない.自己アテンションの採用により,長期コンテキストはアテンションで捉えやすくなったが,固定長コンテキストしか学習できないのは,まだ少し良くない面があった.
そこで,Transformer-XLでは,TransformerにRNN的な時系列遷移を半分追加することで,可変長コンテキストに対応するようにした(※ 親記事Transformerも,5節がTransformer-XLについて).
関連書籍
- Transformers for Natural Language Processing(2nd Edition): Build, train, and fine-tune deep neural network architectures for NLP with Python, PyTorch, TensorFlow, BERT, and GPT-3, Denis Rothman, Packet Publishing, March 2022.
- 2節 Getting Started with the Architecture of the Transformer Model
- 書籍全体のもう少し詳しい紹介はCV・DLのおすすめ書籍の4節「NLP本紹介」にて.
- 深層学習 改訂第2版 (機械学習プロフェッショナルシリーズ) 岡谷貴之,講談社,2022.
- 第2版になりseq2seq やアテンション,Transformerの章が追加されている
- 1.1.4 爆発的な発展 (p5)
- 7.3 トランスフォーマー (p153)
- IT Text 自然言語処理の基礎 岡﨑直観, 荒瀬由紀, 鈴木潤, 鶴岡慶雅, 宮尾祐介 .オーム社,2022.
- 6.3 Transformer の構成要素 → マルチヘッド注意機構 (p149)
- Getting Started with Google BERT: Build and train state-of-the-art natural language processing models using BERT, Packet publishing, 2020.
- 「Q,K,V」行列の中身が,例文の英語文と共に詳細に説明されている.行列の実際の値を図として,デバッグ表示しながらの説明があり,イメージしやすい.
References
- [西田 2022]「自然言語処理とVision-and-Language」人工知能学会全国大会(第36回) チュートリアル,2022, 西田京介.
- [Bahdanau et al., 2015] Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. “Neural machine translation by jointly learning to align and translate.” In ICLR, 2015.
- [Gehring et al., 2017] Gehring, Jonas, et al. “Convolutional sequence to sequence learning.” In ICML. , 2017.
- [Luong et al., 2015] M.-T. Luong, H. Pham, and C. D. Manning, Effective approaches to attention-based neural machine translation. In EMNLP, 2015.
- [Sennrich et al., 2016] Rico Sennrich, Barry Haddow, and Alexandra Birch. 2016. Neural machine translation of rare words with subword units. In ACL, 2016.
- [Vaswani et al., 2017] Ashish Vaswani, et al. Attention is all you need. In NIPS, 2017.
参照外部リンク
- Multi-Head Attention Explained | paper with code
- The Illustrated Transformer | Jay Alammar
- PyTorch Lightning | Tutorial 5: Transformers and Multi-Head Attention
- Qiita(@keitain): 自然言語処理とDeep Learning – TransformerのMulti-Head Attentionの実装について