+import math
+from typing import Optional
 import torch
 from torch import nn
+from torch.nn import functional as F
 import numpy as np


@@ -24,7 +27,13 @@ def __init__(self, embed_dim: int, num_heads: int):
         self.num_heads: int = num_heads
         self.head_size: int = embed_dim // num_heads

-    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ):
         """
         Calculates softmax(Q @ K^T / sqrt(d_k)) @ V.

@@ -46,30 +55,34 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         """

         q_len, kv_len = q.size(0), k.size(0)
-        # q: (num_heads, q_len, head_size)
+        # q -> (num_heads, q_len, head_size)
         q = q.view(q_len, self.num_heads, self.head_size).transpose(0, 1)
-        # k: (num_heads, kv_len, head_size)
+        # k -> (num_heads, kv_len, head_size)
         k = k.view(kv_len, self.num_heads, self.head_size).transpose(0, 1)
-        # v: (num_heads, kv_len, head_size)
+        # v -> (num_heads, kv_len, head_size)
         v = v.view(kv_len, self.num_heads, self.head_size).transpose(0, 1)
-
-        attn_weights = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(
-            torch.tensor(self.head_size, dtype=torch.float32)
+        # scores -> (num_heads, q_len, kv_len)
+        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(
+            self.head_size
         )
-
-        # logits: (num_heads, q_len, kv_len)
-        logits = torch.softmax(attn_weights, dim=-1)
-
-        # out: (num_head, q_len, head_size)
-        out = torch.matmul(logits, v)
-        # out: (q_len, embed_dim)
+        # scores -> (num_heads, q_len, kv_len)
+        scores = scores + mask if mask is not None else scores
+        # scores -> (num_heads, q_len, kv_len)
+        scores = F.softmax(scores, dim=-1)
+        # out -> (num_heads, q_len, head_size)
+        out = torch.matmul(scores, v)
+        # out -> (q_len, num_heads, head_size)
         out = out.transpose(0, 1).reshape(q_len, self.embed_dim)

         return out


 class MultiHeadAttention(nn.Module):
-    def __init__(self, embed_dim: int, num_heads: int):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+    ):
         super().__init__()

         self.embed_dim: int = embed_dim
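As a quick sanity check on the rewritten kernel, the sketch below compares it against PyTorch's built-in `F.scaled_dot_product_attention`, which computes the same `softmax(Q @ K^T / sqrt(head_size) + mask) @ V` per head. This is not part of the diff: it assumes PyTorch >= 2.0 for the fused op, and the sizes and random inputs are illustrative.

import torch
from torch.nn import functional as F

embed_dim, num_heads, seq_len = 64, 4, 10
head_size = embed_dim // num_heads

kernel = MultiHeadAttentionKernel(embed_dim, num_heads)
q, k, v = (torch.randn(seq_len, embed_dim) for _ in range(3))
# Additive causal mask, built the same way as in the prefill branch further down.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

out = kernel(q, k, v, mask)

# Reference: reshape to (num_heads, seq_len, head_size) and let the fused op
# apply the same scaled-dot-product-attention formula with the additive mask.
qh = q.view(seq_len, num_heads, head_size).transpose(0, 1)
kh = k.view(seq_len, num_heads, head_size).transpose(0, 1)
vh = v.view(seq_len, num_heads, head_size).transpose(0, 1)
ref = F.scaled_dot_product_attention(qh, kh, vh, attn_mask=mask)
ref = ref.transpose(0, 1).reshape(seq_len, embed_dim)

assert torch.allclose(out, ref, atol=1e-5)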
@@ -82,7 +95,11 @@ def __init__(self, embed_dim: int, num_heads: int):

         self.attn_kernel = MultiHeadAttentionKernel(embed_dim, num_heads)

-    def forward(self, seq: torch.Tensor):
+    def forward(
+        self,
+        seq: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ):
         """
         Parameters
         ----------
@@ -96,15 +113,15 @@ def forward(self, seq: torch.Tensor):
             Attention output, cached K and cached V.
         """

-        # q: (seq_len, embed_dim)
+        # q -> (seq_len, embed_dim)
         q = self.Wq(seq)
-        # k: (seq_len, embed_dim)
+        # k -> (seq_len, embed_dim)
         k = self.Wk(seq)
-        # v: (seq_len, embed_dim)
+        # v -> (seq_len, embed_dim)
         v = self.Wv(seq)

-        # out: (seq_len, embed_dim)
-        out = self.Wo(self.attn_kernel(q, k, v))
+        # out -> (seq_len, embed_dim)
+        out = self.Wo(self.attn_kernel(q, k, v, mask))

         return out, k, v

@@ -125,7 +142,10 @@ def __init__(self, embed_dim: int, num_heads: int):
         self.attn_kernel = MultiHeadAttentionKernel(embed_dim, num_heads)

     def forward(
-        self, seq: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor
+        self,
+        seq: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Parameters
@@ -147,19 +167,19 @@ def forward(
         When decoding, the input seq only has ONE newly generated token.
         """

-        # q: (1, embed_dim)
+        # q -> (1, embed_dim)
         q = self.Wq(seq)
-        # k: (1, embed_dim)
+        # k -> (1, embed_dim)
         k = self.Wk(seq)
-        # v: (1, embed_dim)
+        # v -> (1, embed_dim)
         v = self.Wv(seq)

-        # k_cache: (kv_len + 1, embed_dim)
+        # k_cache -> (kv_len + 1, embed_dim)
         k_cache = torch.cat([k_cache, k.detach()], dim=0)
-        # v_cache: (kv_len + 1, embed_dim)
+        # v_cache -> (kv_len + 1, embed_dim)
         v_cache = torch.cat([v_cache, v.detach()], dim=0)

-        # out: (seq_len, embed_dim)
+        # out -> (seq_len, embed_dim)
         out = self.Wo(self.attn_kernel(q, k_cache, v_cache))

         return out, k_cache, v_cache
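One way to convince yourself the cached path matches the uncached one: decode the last token against a cache built from the earlier tokens and compare with a full causal pass. This is only a sketch; `CachedMultiHeadAttention` is a stand-in name (the class name sits outside the hunks shown), and it assumes both modules expose the same `Wq`/`Wk`/`Wv`/`Wo` linears so their state dicts line up.

import torch

embed_dim, num_heads, seq_len = 64, 4, 10

mha = MultiHeadAttention(embed_dim, num_heads)
cached_mha = CachedMultiHeadAttention(embed_dim, num_heads)  # stand-in name
cached_mha.load_state_dict(mha.state_dict())  # tie Wq/Wk/Wv/Wo for the comparison

seq = torch.randn(seq_len, embed_dim)
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

with torch.no_grad():
    # Full causal pass over the whole sequence.
    full_out, k, v = mha(seq, mask)
    # Cached pass: K/V for the first seq_len - 1 tokens come from the cache;
    # only the last token is projected and appended.
    last_out, k_cache, v_cache = cached_mha(seq[-1:], k[:-1], v[:-1])

assert torch.allclose(full_out[-1], last_out[0], atol=1e-5)
assert k_cache.shape == (seq_len, embed_dim)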
@@ -189,7 +209,11 @@ def __init__(
         self.k_cache = nn.Buffer(torch.zeros(size=(0, embed_dim)))
         self.v_cache = nn.Buffer(torch.zeros(size=(0, embed_dim)))

-    def forward(self, prompt: torch.Tensor, is_prefilling: bool = True):
+    def forward(
+        self,
+        prompt: torch.Tensor,
+        is_prefilling: bool = True,
+    ):
         """
         Parameters
         ----------
@@ -203,19 +227,29 @@ def forward(self, prompt: torch.Tensor, is_prefilling: bool = True):
             step, which means `seq_len` should equal `1`.
         """

-        # embedded_prompt: (seq_len, embed_dim)
+        # embedded_prompt -> (seq_len, embed_dim)
         embedded_prompt = self.embed(prompt)

         if is_prefilling:
-            # out: (seq_len, embed_dim)
-            # k: (seq_len, embed_dim)
-            # v: (seq_len, embed_dim)
-            out, k, v = self.mha(embedded_prompt)
+            seq_len = prompt.size(0)
+            mask = None
+            if seq_len > 1:
+                mask = torch.full(
+                    (seq_len, seq_len), -float("Inf"), device=prompt.device
+                )
+                mask = torch.triu(mask, diagonal=1)
+            # out -> (seq_len, embed_dim)
+            # k -> (seq_len, embed_dim)
+            # v -> (seq_len, embed_dim)
+            out, k, v = self.mha(embedded_prompt, mask)
         else:
-            # out: (seq_len, embed_dim)
-            # k: (kv_len, embed_dim)
-            # v: (kv_len, embed_dim)
-            out, k, v = self.cached_mha(embedded_prompt, self.k_cache, self.v_cache)
+            assert prompt.size(0) == 1
+            # out -> (seq_len, embed_dim)
+            # k -> (kv_len, embed_dim)
+            # v -> (kv_len, embed_dim)
+            out, k, v = self.cached_mha(
+                embedded_prompt, self.k_cache, self.v_cache
+            )

         # Update k cache and v cache
         # [NOTE]
@@ -226,7 +260,7 @@ def forward(self, prompt: torch.Tensor, is_prefilling: bool = True):

         # Use the last token to calculate the probability of each word in the
         # vocabulary bank:
-        # probs: (vocab_size,)
+        # probs -> (vocab_size,)
         probs = torch.softmax(self.proj_to_vocab(out[-1]), dim=-1)

         return probs
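Finally, a rough sketch of how the prefill/decode split is meant to be driven. It assumes the top-level module (whose class name and constructor sit outside the hunks shown) is already instantiated as `model`, that `prompt` is a 1-D tensor of token ids, and that decoding is greedy; the ids and the loop length are made up.

import torch

prompt = torch.tensor([3, 17, 42, 8])  # (seq_len,) token ids, illustrative

with torch.no_grad():
    # Prefill: run the whole prompt once under the causal mask; this also
    # fills the module's k_cache / v_cache buffers (the "Update k cache and
    # v cache" step above).
    probs = model(prompt, is_prefilling=True)  # (vocab_size,)
    next_token = torch.argmax(probs, dim=-1, keepdim=True)  # (1,)

    generated = [int(next_token)]
    for _ in range(16):  # greedy decoding for 16 steps
        # Decode: feed exactly ONE new token; K/V for earlier tokens are read
        # from (and appended to) the cache instead of being recomputed.
        probs = model(next_token, is_prefilling=False)
        next_token = torch.argmax(probs, dim=-1, keepdim=True)
        generated.append(int(next_token))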