
Commit 3f7fc14

gau-nernst authored and andrewor14 committed
Add INT8 mixed-precision training (#748)
* initial commit
* expose some UX. update test
* add test. update bench
* update test. add doc
* fix ngpu
* fix FSDP
* fix
* fix fsdp test
* fix
* grammar
* simplify fsdp test
* update benchmark script
* update
* make claim more conservative
* register fused adam
* update benchmark script
* add more ops
* update default
* use TorchAOBaseTensor
* fix fsdp param_dtype
* fix param_dtype
* dtype check to prevent unnecessary errors
* move checks
* add note
* fix
* simplify script
* add module-based UX
* fix
* use FP8 impl of __torch_dispatch__
* rename _dynamice interface
* update test
* fix compile on 2.4
* log torch version
* make log interval customizable
* make naming for explicit
* update readme
* some change
* fix big bug
* add docstring. update _get_linear_inserter
* add TorchAOBaseTensor back
* fix FSDP
* update FSDP test. add autocast support
* reduce iter
* update int8_mm fallback
* put leading dims logic to _dynamic_int8_mm
1 parent 10d038f commit 3f7fc14
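For orientation before the diffs: the module-based UX added here is applied through torchao's quantize_ entry point, exactly as the pretrain_llama2.py changes below do. A minimal usage sketch follows; the toy two-layer model is illustrative and not part of this commit, while the import path and the quantize_ call mirror the diff.

import torch
from torchao.quantization.quant_api import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

# toy stand-in for the transformer blocks quantized in the benchmark script
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
).bfloat16().cuda()

# swap nn.Linear weights to INT8 mixed-precision training tensors; the benchmark
# script applies this to model.layers only, leaving the LM head in BF16
quantize_(model, int8_mixed_precision_training(), set_inductor_config=False)

The idea of INT8 mixed precision here is that master weights stay in BF16, while the matmuls in forward and backward run in INT8 using dynamically computed scales.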

File tree

9 files changed: +771 -175 lines changed

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+import pandas as pd
+import torch
+from triton.testing import do_bench
+
+from torchao.prototype.quantized_training.int8_mm import int8_mm_dequant
+
+
+def bench_f(f, *args):
+    return do_bench(lambda: f(*args), fast_flush=False, return_mode="median")
+
+
+shapes = [(sz, sz, sz) for sz in [1024, 2048, 4096]]
+
+# Llama-8B shapes
+shapes += [
+    # linear in attention
+    (32_768, 4096, 4096),
+    (4096, 4096, 32_768),
+    # linear in feed-forward
+    (32_768, 14_336, 4096),
+    (32_768, 4096, 14_336),
+    (14_336, 4096, 32_768),
+]
+
+data = []
+for M, N, K in shapes:
+    print(f"{M=}, {N=}, {K=}")
+
+    A_bf16 = torch.randn(M, K).bfloat16().cuda()
+    B_bf16 = torch.randn(N, K).bfloat16().cuda()
+    A_i8 = torch.randint(-128, 127, size=(M, K), dtype=torch.int8).cuda()
+    B_i8 = torch.randint(-128, 127, size=(N, K), dtype=torch.int8).cuda()
+    A_scale = torch.randn(M).bfloat16().cuda()
+    B_scale = torch.randn(N).bfloat16().cuda()
+
+    # benchmark F.linear() i.e. A @ B.T
+    bf16_time = bench_f(torch.mm, A_bf16, B_bf16.T)
+    i8_time = bench_f(torch._int_mm, A_i8, B_i8.T)
+    i8_dequant_time = bench_f(int8_mm_dequant, A_i8, B_i8.T, A_scale, B_scale)
+
+    sample = [M, N, K, bf16_time / i8_time, bf16_time / i8_dequant_time]
+    data.append(sample)
+
+df = pd.DataFrame(data, columns=["M", "N", "K", "CuBLAS INT8 speedup", "Triton INT8 dequant speedup"])
+print(df.to_markdown())
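As a reading aid for the benchmark above: int8_mm_dequant takes row-wise scales for both operands, so its unfused equivalent is roughly the sketch below. The semantics are assumed from the call signature; the actual Triton kernel fuses the INT8 matmul with the dequantization.

import torch

def int8_mm_dequant_ref(A_i8, B_i8_t, A_scale, B_scale):
    # unfused reference (assumed semantics): INT8 matmul accumulated in int32,
    # then dequantize with per-row scales of A and per-row scales of B (columns of B.T)
    out_i32 = torch._int_mm(A_i8, B_i8_t)
    return out_i32 * A_scale[:, None] * B_scale[None, :]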

benchmarks/quantized_training/pretrain_llama2.py

Lines changed: 37 additions & 20 deletions
@@ -3,12 +3,14 @@
 #
 # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
 # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
+# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision
 
 import os
 
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 import argparse
+import time
 from functools import partial
 from pathlib import Path
 
@@ -18,22 +20,34 @@
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
 
-from torchao._models.llama.model import ModelArgs, Transformer
+from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs
 from torchao.prototype import low_bit_optim
-from torchao.prototype.quantized_training import int8_weight_only_quantized_training
+from torchao.prototype.quantized_training import (
+    int8_mixed_precision_training,
+    int8_weight_only_quantized_training,
+)
 from torchao.quantization.quant_api import quantize_
 
 
+# not official models
+transformer_configs.update(
+    (
+        ("470M", dict(n_layer=24, n_head=16, dim=1024, intermediate_size=4096)),
+        ("1B", dict(n_layer=24, n_head=24, dim=1536, intermediate_size=6144)),
+    )
+)
+
+
 # hack from fairseq
 # https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py
 def enable_activation_checkpointing(m: torch.nn.Module):
     assert not hasattr(m, "_forward")
     m._forward = m.forward
-    m.forward = partial(checkpoint, m.forward)
+    m.forward = partial(checkpoint, m.forward, use_reentrant=False)
 
 
 def get_loss(model: Transformer, batch: torch.Tensor):
-    logits = model(batch)[:, :-1].flatten(0, 1)
+    logits = model(batch)[:, :-1].float().flatten(0, 1)
     labels = batch[:, 1:].flatten()
     return torch.nn.functional.cross_entropy(logits, labels)
 
@@ -77,12 +91,7 @@ def get_tinystories():
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    # default config is 470M
-    parser.add_argument("--d_model", type=int, default=1024)
-    parser.add_argument("--depth", type=int, default=24)
-    parser.add_argument("--ffn_size", type=int, default=4096)
-    parser.add_argument("--head_dim", type=int, default=64)
-
+    parser.add_argument("--model", default="470M", choices=transformer_configs.keys())
     parser.add_argument("--quantize")
     parser.add_argument("--activation_checkpointing", action="store_true")
     parser.add_argument("--compile", action="store_true")
98107
parser.add_argument("--project", default="int8_quantized_training")
99108
parser.add_argument("--run_name")
100109
parser.add_argument("--seed", type=int)
110+
parser.add_argument("--log_interval", type=int, default=10)
101111
args = parser.parse_args()
102112

103113
if args.seed is not None:
104114
torch.manual_seed(args.seed)
105115

106-
config = ModelArgs(
107-
block_size=args.seq_len,
108-
n_layer=args.depth,
109-
n_head=args.d_model // args.head_dim,
110-
dim=args.d_model,
111-
intermediate_size=args.ffn_size,
112-
)
116+
config = ModelArgs.from_name(args.model)
117+
config.block_size = args.seq_len
113118
model = Transformer(config).bfloat16().cuda()
114119
with torch.device("cuda"):
115120
model.setup_caches(args.batch_size, args.seq_len, training=True)
116121
if args.activation_checkpointing:
117122
for layer in model.layers:
118123
enable_activation_checkpointing(layer)
124+
125+
# don't apply int8_mixed_precision to LM head, since it can cause convergence issue.
126+
# TODO: might want to do the same for int8_weight_only to standardize.
119127
if args.quantize == "int8_weight_only":
120128
quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
129+
elif args.quantize == "int8_mixed_precision":
130+
quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False)
121131
elif args.quantize is not None:
122132
raise ValueError(f"Unsupported quantize={args.quantize}")
133+
123134
print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}")
124135
print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")
136+
torch.cuda.reset_peak_memory_stats() # don't count memory occupied by unquantized weights
125137

126138
# only use optimizers from torchao.prototype.low_bit_optim to support quantized training
127139
if args.optim == "AdamW":
128140
args.optim = "_AdamW"
129141
optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
130142

131143
data = get_tinystories().cuda()
144+
args.torch_version = torch.__version__
132145
run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name)
133146

134147
step = 0
135-
log_interval = 50
136148
pbar = tqdm(total=args.n_steps, dynamic_ncols=True)
137149
model.train()
138150
_get_loss = torch.compile(get_loss) if args.compile else get_loss
151+
time0 = time.time()
139152

140153
while step < args.n_steps:
141154
# randomly select a continuous chunk, then reshape it
@@ -145,13 +158,17 @@ def get_tinystories():
         loss = _get_loss(model, batch)
         loss.backward()
 
-        if step % log_interval == 0:
+        if step % args.log_interval == 0:
             log_dict = dict(
                 loss=loss.item(),
                 lr=optim.param_groups[0]["lr"],
                 max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
-                max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9,
+                max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9,
             )
+            if step > 0:
+                time1 = time.time()
+                log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0)
+                time0 = time1
             run.log(log_dict, step=step)
             pbar.set_postfix(loss=log_dict["loss"])
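Finally, a generic sketch of the dynamic INT8 matmul that int8_mixed_precision_training builds on. This is illustrative only: quantize_int8_rowwise and dynamic_int8_mm_sketch are hypothetical names for this sketch, not the torchao internals, which live in _dynamic_int8_mm and int8_mm_dequant.

import torch

def quantize_int8_rowwise(x: torch.Tensor):
    # symmetric per-row absmax quantization to INT8 (hypothetical helper)
    scale = x.abs().amax(dim=1) / 127  # one scale per row
    x_i8 = (x / scale.clamp(min=1e-12)[:, None]).round().clamp(-128, 127).to(torch.int8)
    return x_i8, scale

def dynamic_int8_mm_sketch(A: torch.Tensor, B: torch.Tensor):
    # A: (M, K) activations, B: (N, K) weight; returns an approximation of A @ B.T
    A_i8, A_scale = quantize_int8_rowwise(A)
    B_i8, B_scale = quantize_int8_rowwise(B)
    out_i32 = torch._int_mm(A_i8, B_i8.T)          # INT8 matmul, int32 accumulation
    return out_i32 * A_scale[:, None] * B_scale[None, :]  # dequantize with row-wise scales

Row-wise absmax scaling keeps the quantization error of each output element bounded while still taking the INT8 tensor-core path through torch._int_mm, which is what the benchmark script at the top of this commit measures.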

0 commit comments
