July 8, 2024 | 03:08
Rotary Positional Embeddingとは?
Rotary Positional Embeddingは、Transformerモデルにおける位置エンコーディングの一種です。位置エンコーディングは、シーケンスデータ(例えば、文章や時系列データ)を扱う際に、各トークンの位置情報をモデルに提供するための重要な機能です。
従来のPositional Encodingとその課題
従来のPositional Encoding(例:Sinusoidal Positional Encoding)は、各位置に対して固定の位置ベクトルを生成し、それを入力埋め込みに加算する方式でした。しかし、この方法にはトークン間の相対的な位置関係を考慮できない、大きなシーケンス長に対応できないと言った問題点がありました。
Rotary Positional Embeddingの解決策
Rotary Positional Embeddingは、複素数の回転操作を用いて位置情報をベクトルに組み込むというアイデアを利用しています。具体的には、各トークンのQueryとKeyのベクトルを複素数として扱い、その複素数を位置に応じた角度だけ回転させています。位置 (i) と位置 (i+k) のトークンの回転角度の差は位置に寄らず一定であり、位置情報を相対的に捉える能力を持ちます。
複素数と回転
複素数は、実部と虚部からなる数のことを指します。複素数 (z = a + ib) (ここで、(a) と (b) は実数、(i) は虚数単位)を極形式で表すと、(z = r * (\cos(\theta) + i \sin(\theta))) となります。ここで、(r) は複素数の絶対値、(\theta) は偏角(回転角)を表します。
複素数平面では、複素数を (\theta) ラジアンだけ反時計回りに回転させると、新たな複素数 (z’ = z \cdot (\cos(\theta) + i \sin(\theta))) が得られます。これは、複素数の乗法が回転とスケーリング(大きさの変更)を表すからです。
Rotary Positional Embeddingの実装
Rotary Positional Embeddingは、この複素数の回転操作を用いて位置情報をベクトルに組み込むというアイデアを利用しています。具体的には、以下のような操作を行います:
-
逆数周波数の計算:まず、各次元に対する逆数周波数を計算します。これは、Embeddingsの各次元が異なる周波数を持つことを意味します。
# パラメータの設定 base = 10000 dim = 1024 # iの値を0から1023まで2ステップずつ変化させます # ここで、iは埋め込みベクトルの次元を表しています i = np.arange(0, dim, 2) # 逆数周波数を計算します inv_freq = 1.0 / (base ** (i / dim)) # torch.Tensorに変換します inv_freq = torch.tensor(inv_freq)
-
位置エンコーディングの生成:次に、各位置と各周波数の組み合わせに対応する値を持つテンソルを生成します。これを2つ連結し、その結果に対してcosineとsineの関数を適用することで、位置エンコーディングを得ます。2つ連結するのは実部と虚部を表すためです。
# tを生成します(ここでは、シーケンス長を8192とします) t = torch.arange(8192) # freqsを計算します freqs = torch.einsum("i,j->ij", t, inv_freq) # embを計算します emb = torch.cat((freqs, freqs), dim=-1) # cosとsinを計算します cos = emb.cos() sin = emb.sin()
embはfreqsをconcatenateしたものです。
cosとsinはそれぞれembのcosineとsineを計算したものです。これにより、各位置とEmbeddingsの各次元に対応するcosineとsineの値を持つテンソルが得られます。
左上部分が周波数が高い部分で右下に行くにつれて周波数が低くなります。 Embedding vectorのインデックスが低い部分(左の周波数が高い部分)は位置ごとに異なるベクトルを割り当てて、インデックスが大きい部分(右の周波数が低い部分)は位置が変わっても近いベクトルを割り当てることにより、位置情報をモデルに伝えています。
-
位置エンコーディングの適用:最後に、QueryとKeyのベクトルに対して位置エンコーディングを適用します。具体的には、各ベクトルを位置に応じた角度だけ「回転」させます。
def rotate_half(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat( (-x2, x1), dim=x1.ndim - 1 ) # dim=-1 triggers a bug in earlier torch versions (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
ここで、
rotate_half
関数はベクトルの前半部分と後半部分を入れ替える操作を行います。そして、q * cos
とrotate_half(q) * sin
の組み合わせにより、Queryベクトルq
が位置に応じた角度だけ「回転」されます。Keyベクトルk
についても同様の操作が行われます。
このように、Rotary Positional Embeddingは、位置情報をベクトルに直接足し合わせるのではなく、ベクトルの方向を変化させるという方法で位置情報を組み込んでいます。これにより、各位置でのベクトルの方向が変化し、位置情報が効果的に埋め込まれます。