import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads, attention_head_size):
        super().__init__()
        self.num_heads = num_heads
        self.attention_head_size = attention_head_size
        self.embed_size = embed_size
        self.W_query = nn.Linear(embed_size, attention_head_size)
        self.W_key = nn.Linear(embed_size, attention_head_size)
        self.W_value = nn.Linear(embed_size, attention_head_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        querys = self.W_query(x)  # (batch_size, seq_len, attention_head_size)
        keys = self.W_key(x)
        values = self.W_value(x)
        assert self.attention_head_size % self.num_heads == 0
        split_size = self.attention_head_size // self.num_heads
        # Split each projection into heads: (num_heads, batch_size, seq_len, split_size)
        querys = querys.view(batch_size, seq_len, self.num_heads, split_size).permute(2, 0, 1, 3)
        keys = keys.view(batch_size, seq_len, self.num_heads, split_size).permute(2, 0, 1, 3)
        values = values.view(batch_size, seq_len, self.num_heads, split_size).permute(2, 0, 1, 3)
        scores = torch.matmul(querys, keys.transpose(2, 3))  # (num_heads, batch_size, seq_len, seq_len)
        scores = scores / (split_size ** 0.5)
        scores = F.softmax(scores, dim=-1)
        out = torch.matmul(scores, values)  # (num_heads, batch_size, seq_len, split_size)
        out = out.transpose(0, 1)  # (batch_size, num_heads, seq_len, split_size)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return out, scores
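

# Minimal usage sketch: the sizes below (batch 2, sequence length 10,
# embed_size=64, num_heads=8, attention_head_size=64) are illustrative
# assumptions chosen only to exercise the module end to end.
if __name__ == "__main__":
    x = torch.randn(2, 10, 64)  # (batch_size, seq_len, embed_size)
    mha = MultiHeadAttention(embed_size=64, num_heads=8, attention_head_size=64)
    out, attn = mha(x)
    print(out.shape)   # torch.Size([2, 10, 64])   -> (batch_size, seq_len, attention_head_size)
    print(attn.shape)  # torch.Size([8, 2, 10, 10]) -> (num_heads, batch_size, seq_len, seq_len)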