import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


class MultiHead(nn.Module):
    def __init__(self, n_head, model_dim, drop_rate):
        super().__init__()
        self.n_head = n_head
        self.head_dim = model_dim // n_head
        self.model_dim = model_dim
        self.wq = nn.Linear(model_dim, n_head * self.head_dim)
        self.wk = nn.Linear(model_dim, n_head * self.head_dim)
        self.wv = nn.Linear(model_dim, n_head * self.head_dim)
        self.o_dense = nn.Linear(model_dim, model_dim)
        self.o_drop = nn.Dropout(drop_rate)
        self.layer_norm = nn.LayerNorm(model_dim)
        self.attention = None  # attention weights cached from the last forward pass
    def forward(self, q, k, v, mask, training):
        # `training` is kept for interface compatibility but is not used here:
        # nn.Dropout already follows the module's train()/eval() mode.
        residual = q  # save the query input for the residual connection

        query = self.wq(q)  # [n, q_step, n_head * head_dim]
        key = self.wk(k)    # [n, step, n_head * head_dim]
        value = self.wv(v)  # [n, step, n_head * head_dim]

        query = self.split_heads(query)  # [n, n_head, q_step, head_dim]
        key = self.split_heads(key)      # [n, n_head, step, head_dim]
        value = self.split_heads(value)  # [n, n_head, step, head_dim]

        context = self.scaled_dot_product_attention(query, key, value, mask)  # [n, q_step, model_dim]
        o = self.o_dense(context)  # [n, q_step, model_dim]
        o = self.o_drop(o)
        o = self.layer_norm(residual + o)  # add & norm
        return o
    def split_heads(self, x):
        # [n, step, model_dim] -> [n, n_head, step, head_dim]
        x = torch.reshape(x, (x.shape[0], x.shape[1], self.n_head, self.head_dim))
        return x.permute(0, 2, 1, 3)
    def scaled_dot_product_attention(self, q, k, v, mask=None):
        dk = torch.tensor(k.shape[-1]).type(torch.float)
        score = torch.matmul(q, k.permute(0, 1, 3, 2)) / (torch.sqrt(dk) + 1e-8)  # [n, n_head, q_step, step]
        if mask is not None:
            # Set masked positions to negative infinity so that their
            # attention weights are close to 0 after the softmax.
            score = score.masked_fill_(mask, -np.inf)
        self.attention = F.softmax(score, dim=-1)  # [n, n_head, q_step, step]
        context = torch.matmul(self.attention, v)  # [n, n_head, q_step, head_dim]
        context = context.permute(0, 2, 1, 3)      # [n, q_step, n_head, head_dim]
        context = context.reshape((context.shape[0], context.shape[1], -1))
        return context  # [n, q_step, model_dim]
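

# A minimal usage sketch for self-attention with a causal mask. The batch size,
# sequence length, and hyperparameters below are illustrative assumptions, not
# values taken from the original code.
if __name__ == "__main__":
    n, step, model_dim, n_head, drop_rate = 2, 5, 32, 4, 0.1
    mha = MultiHead(n_head=n_head, model_dim=model_dim, drop_rate=drop_rate)
    x = torch.rand(n, step, model_dim)
    # Boolean causal mask: True marks positions that must NOT be attended to.
    causal = torch.triu(torch.ones(step, step, dtype=torch.bool), diagonal=1)
    mask = causal[None, None, :, :]  # broadcastable to [n, n_head, step, step]
    out = mha(x, x, x, mask, training=True)
    print(out.shape)            # torch.Size([2, 5, 32])
    print(mha.attention.shape)  # torch.Size([2, 4, 5, 5])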