```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from rich import print
import torch.utils.benchmark as benchmark
import math
class CausalSelfAttention(nn.Module):
    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        self.dropout = dropout
        # flash attention makes the GPU go brrrrr, but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            # (block_size is a module-level hyperparameter defined below)
            self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size))
                                        .view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (embed_dimension)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.embed_dimension, dim=2)
        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        # causal self-attention; self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            # (only apply dropout while training, not during eval/benchmarking)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                               dropout_p=self.dropout if self.training else 0,
                                               is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
# Use the GPU on Windows and Linux:
device = torch.device("cuda")
# device = "cuda" if torch.cuda.is_available() else "cpu"
# For GPU acceleration on Mac, use the MPS backend instead:
# device = "mps" if torch.backends.mps.is_built() else "cpu"
# Hyperparameters:
batch_size = 32
# max_sequence_len = 128
max_sequence_len = 10240
num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
block_size = 1024
dtype = torch.float16
# Define a timer that returns the mean run time in microseconds:
def torch_timer(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6
# Instantiate the CausalSelfAttention class defined above
model = CausalSelfAttention(num_heads=num_heads,
                            embed_dimension=embed_dimension,
                            bias=False,
                            dropout=0.1).to(device).to(dtype).eval()  # cuda / mps
print(model)
# Simulated input data
x = torch.rand(batch_size,
               max_sequence_len,
               embed_dimension,
               device=device,
               dtype=dtype)
print(f"Original model run time: {torch_timer(model, x):.3f} microseconds")
# Original model run time: 9169.492 microseconds
# Compile the model
compiled_model = torch.compile(model, backend="inductor")
# Warm-up call: the first invocation triggers compilation, so run it once before timing
compiled_model(x)
print(f"Compiled model run time: {torch_timer(compiled_model, x):.3f} microseconds")
# Compiled model run time: 6786.322 microseconds
```
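
Beyond comparing eager mode against `torch.compile`, it can be useful to check how much the fused `scaled_dot_product_attention` kernel itself contributes. The snippet below is a minimal sketch, assuming PyTorch 2.x on CUDA where the `torch.backends.cuda.sdp_kernel` context manager is available; the helper `time_with_backend` and the specific backend combinations are illustrative additions, not part of the original benchmark.

```python
# Minimal sketch: time the same model while restricting which SDPA backend
# scaled_dot_product_attention may dispatch to. Assumes PyTorch 2.x on CUDA;
# `time_with_backend` is a hypothetical helper, not from the original post.
import torch
import torch.utils.benchmark as benchmark

def time_with_backend(f, x, enable_flash, enable_mem_efficient, enable_math):
    # Only the enabled backend(s) can be chosen inside this context.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=enable_flash,
        enable_mem_efficient=enable_mem_efficient,
        enable_math=enable_math,
    ):
        t = benchmark.Timer(stmt="f(x)", globals={"f": f, "x": x})
        return t.blocked_autorange().mean * 1e6

# Example usage (numbers depend on the GPU, dtype, and sequence length):
# print(f"flash only: {time_with_backend(model, x, True, False, False):.3f} microseconds")
# print(f"math only: {time_with_backend(model, x, False, False, True):.3f} microseconds")
```

Enabling only the math backend falls back to the plain matmul-plus-softmax composition, so the gap between the fused kernel and the naive path becomes explicit on your own hardware.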