 #
 # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
 # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
+# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision
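+# (QT = INT8 weight-only quantized training, MP = INT8 mixed-precision training; both are
+# prototype features imported below from torchao.prototype.quantized_training)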
 
 import os
 
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 import argparse
+import time
 from functools import partial
 from pathlib import Path
 
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
 
-from torchao._models.llama.model import ModelArgs, Transformer
+from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs
 from torchao.prototype import low_bit_optim
-from torchao.prototype.quantized_training import int8_weight_only_quantized_training
+from torchao.prototype.quantized_training import (
+    int8_mixed_precision_training,
+    int8_weight_only_quantized_training,
+)
 from torchao.quantization.quant_api import quantize_
 
 
+# not official models
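+# (these two sizes are added to torchao's Llama config table so they can be picked with
+# --model below and resolved via ModelArgs.from_name)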
+transformer_configs.update(
+    (
+        ("470M", dict(n_layer=24, n_head=16, dim=1024, intermediate_size=4096)),
+        ("1B", dict(n_layer=24, n_head=24, dim=1536, intermediate_size=6144)),
+    )
+)
+
+
 # hack from fairseq
 # https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py
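+# The patch below routes m.forward through torch.utils.checkpoint, so activations inside the
+# module are recomputed during backward instead of being stored; use_reentrant=False selects
+# the non-reentrant checkpoint implementation recommended by PyTorch.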
 def enable_activation_checkpointing(m: torch.nn.Module):
     assert not hasattr(m, "_forward")
     m._forward = m.forward
-    m.forward = partial(checkpoint, m.forward)
+    m.forward = partial(checkpoint, m.forward, use_reentrant=False)
 
 
 def get_loss(model: Transformer, batch: torch.Tensor):
-    logits = model(batch)[:, :-1].flatten(0, 1)
+    logits = model(batch)[:, :-1].float().flatten(0, 1)
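+    # [:, :-1] drops the final position (it has no next-token label) and .float() upcasts the
+    # logits so the cross-entropy below is computed in FP32 rather than BF16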
     labels = batch[:, 1:].flatten()
     return torch.nn.functional.cross_entropy(logits, labels)
 
@@ -77,12 +91,7 @@ def get_tinystories():
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    # default config is 470M
-    parser.add_argument("--d_model", type=int, default=1024)
-    parser.add_argument("--depth", type=int, default=24)
-    parser.add_argument("--ffn_size", type=int, default=4096)
-    parser.add_argument("--head_dim", type=int, default=64)
-
+    parser.add_argument("--model", default="470M", choices=transformer_configs.keys())
     parser.add_argument("--quantize")
     parser.add_argument("--activation_checkpointing", action="store_true")
     parser.add_argument("--compile", action="store_true")
@@ -98,44 +107,48 @@ def get_tinystories():
     parser.add_argument("--project", default="int8_quantized_training")
     parser.add_argument("--run_name")
     parser.add_argument("--seed", type=int)
+    parser.add_argument("--log_interval", type=int, default=10)
     args = parser.parse_args()
 
     if args.seed is not None:
         torch.manual_seed(args.seed)
 
-    config = ModelArgs(
-        block_size=args.seq_len,
-        n_layer=args.depth,
-        n_head=args.d_model // args.head_dim,
-        dim=args.d_model,
-        intermediate_size=args.ffn_size,
-    )
+    config = ModelArgs.from_name(args.model)
+    config.block_size = args.seq_len
     model = Transformer(config).bfloat16().cuda()
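+    # weights start out as plain BF16 tensors on the GPU; if --quantize is set, quantize_()
+    # below swaps them in-place for quantized-training tensor subclasses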
     with torch.device("cuda"):
         model.setup_caches(args.batch_size, args.seq_len, training=True)
     if args.activation_checkpointing:
         for layer in model.layers:
             enable_activation_checkpointing(layer)
+
+    # don't apply int8_mixed_precision to the LM head, since it can cause convergence issues.
+    # TODO: might want to do the same for int8_weight_only to standardize.
     if args.quantize == "int8_weight_only":
         quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
+    elif args.quantize == "int8_mixed_precision":
+        quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False)
     elif args.quantize is not None:
         raise ValueError(f"Unsupported quantize={args.quantize}")
+
     print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}")
     print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")
+    torch.cuda.reset_peak_memory_stats()  # don't count memory occupied by unquantized weights
 
     # only use optimizers from torchao.prototype.low_bit_optim to support quantized training
     if args.optim == "AdamW":
         args.optim = "_AdamW"
     optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
 
     data = get_tinystories().cuda()
+    args.torch_version = torch.__version__
     run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name)
 
     step = 0
-    log_interval = 50
     pbar = tqdm(total=args.n_steps, dynamic_ncols=True)
     model.train()
     _get_loss = torch.compile(get_loss) if args.compile else get_loss
+    time0 = time.time()
 
     while step < args.n_steps:
         # randomly select a continuous chunk, then reshape it
@@ -145,13 +158,17 @@ def get_tinystories():
         loss = _get_loss(model, batch)
         loss.backward()
 
-        if step % log_interval == 0:
+        if step % args.log_interval == 0:
             log_dict = dict(
                 loss=loss.item(),
                 lr=optim.param_groups[0]["lr"],
                 max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
-                max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9,
+                max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9,
             )
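+            # tokens/s is averaged over the last log_interval steps; the first log point
+            # (step 0) is skipped because a full interval has not elapsed yet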
+            if step > 0:
+                time1 = time.time()
+                log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0)
+                time0 = time1
             run.log(log_dict, step=step)
             pbar.set_postfix(loss=log_dict["loss"])
 