Skip to content

Commit 24a78fe

Browse files
authored
Add run_tutorials github action and fix existing errors (#1546)
* Add run_tutorials github action and fix existing errors Summary: Added a GHA button for release oncall to check tutorial code are runnable can also be enabled by add a tag `ciflow/tutorials` Test Plan: CI github action Reviewers: Subscribers: Tasks: Tags: * add yml * add script * revert profile changes
1 parent 79979ec commit 24a78fe

File tree

9 files changed

+87
-138
lines changed

9 files changed

+87
-138
lines changed

.github/pytorch-probot.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
mergebot: True
22
ciflow_push_tags:
33
- ciflow/benchmark
4+
- ciflow/tutorials
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
name: Run tutorials
2+
3+
on:
4+
push:
5+
tags:
6+
- ciflow/tutorials/*
7+
jobs:
8+
run_tutorials:
9+
runs-on: linux.aws.a100
10+
strategy:
11+
matrix:
12+
torch-spec:
13+
- '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124'
14+
steps:
15+
- uses: actions/checkout@v4
16+
17+
- name: Setup miniconda
18+
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
19+
with:
20+
python-version: "3.9"
21+
22+
- name: Run tutorials
23+
shell: bash
24+
run: |
25+
set -eux
26+
${CONDA_RUN} python -m pip install --upgrade pip
27+
${CONDA_RUN} pip install ${{ matrix.torch-spec }}
28+
${CONDA_RUN} pip install -r dev-requirements.txt
29+
${CONDA_RUN} pip install .
30+
cd tutorials
31+
${CONDA_RUN} sh run_all.sh

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,10 @@ def _quantized_linear_op(
8080
input_quant_func = weight_tensor.input_quant_func
8181
original_weight_tensor = weight_tensor.original_weight_tensor
8282
quant_kwargs = weight_tensor.quant_kwargs
83-
aqt = input_quant_func(input_tensor, **quant_kwargs)
84-
return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
83+
quantized_tensor = input_quant_func(input_tensor, **quant_kwargs)
84+
return torch.nn.functional.linear(
85+
quantized_tensor, original_weight_tensor, bias
86+
)
8587

8688
@classmethod
8789
def from_float(

tutorials/calibration_flow/awq_like.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,13 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
176176
act_obs = AffineQuantizedMinMaxObserver(
177177
mapping_type,
178178
target_dtype,
179-
granularity_type=PerTensor(),
179+
granularity=PerTensor(),
180180
eps=torch.finfo(torch.float32).eps,
181181
)
182182
weight_obs = AffineQuantizedMinMaxObserver(
183183
mapping_type,
184184
target_dtype,
185-
granularity_type=PerAxis(axis=0),
185+
granularity=PerAxis(axis=0),
186186
eps=torch.finfo(torch.float32).eps,
187187
)
188188

tutorials/calibration_flow/gptq_like.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,20 @@
3333
import torch
3434
from torch.utils._pytree import tree_flatten, tree_unflatten
3535

36-
from torchao.dtypes import to_affine_quantized_intx_static
36+
from torchao.dtypes import (
37+
to_affine_quantized_intx,
38+
to_affine_quantized_intx_static,
39+
)
3740
from torchao.quantization import (
41+
AffineQuantizedMinMaxObserver,
3842
LinearActivationQuantizedTensor,
43+
MappingType,
44+
PerTensor,
45+
fake_quantize_affine,
3946
quantize_,
4047
to_linear_activation_quantized,
4148
)
42-
from torchao.quantization.granularity import PerTensor
43-
from torchao.quantization.observer import (
44-
AffineQuantizedMinMaxObserver,
45-
)
4649
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
47-
from torchao.quantization.quant_primitives import (
48-
MappingType,
49-
fake_quantize_affine,
50-
)
5150
from torchao.quantization.utils import compute_error
5251

5352
torch.manual_seed(0)
@@ -211,7 +210,7 @@ def forward_pre_hook(
211210
act_obs = AffineQuantizedMinMaxObserver(
212211
MappingType.ASYMMETRIC,
213212
torch.uint8,
214-
granularity_type=PerTensor(),
213+
granularity=PerTensor(),
215214
eps=torch.finfo(torch.float32).eps,
216215
scale_dtype=torch.float32,
217216
zero_point_dtype=torch.int32,
@@ -254,8 +253,8 @@ def _register_forward_pre_hook(module: torch.nn.Module):
254253

255254

256255
# using a function to align with the API in quant_api
257-
def apply_activation_static_quant():
258-
def _apply_activation_static_quant(observed_linear):
256+
def apply_activation_static_weight_quant():
257+
def _apply_activation_static_weight_quant(observed_linear):
259258
target_dtype = torch.uint8
260259

261260
# we can quantize the weight here as well
@@ -268,16 +267,21 @@ def _apply_activation_static_quant(observed_linear):
268267
input_quant_func = lambda x: to_affine_quantized_intx_static(
269268
x, act_scale, act_zero_point, x.shape, target_dtype
270269
)
270+
# for demo purpose only, we quantize the weight here
271+
weight = observed_linear.weight
272+
weight = to_affine_quantized_intx(
273+
weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
274+
)
271275
observed_linear.weight = torch.nn.Parameter(
272-
to_linear_activation_quantized(observed_linear.weight, input_quant_func),
276+
to_linear_activation_quantized(weight, input_quant_func),
273277
requires_grad=False,
274278
)
275279

276280
del observed_linear.input_scale
277281
del observed_linear.input_zp
278282
return observed_linear
279283

280-
return _apply_activation_static_quant
284+
return _apply_activation_static_weight_quant
281285

282286

283287
example_inputs = (torch.randn(32, 64),)
@@ -294,7 +298,7 @@ def _apply_activation_static_quant(observed_linear):
294298

295299
# just quantizing activation since we only observed quantization, this could be extended to support
296300
# quantizing weight as well
297-
quantize_(m, apply_activation_static_quant(), _is_linear)
301+
quantize_(m, apply_activation_static_weight_quant(), _is_linear)
298302
for l in m.modules():
299303
if isinstance(l, torch.nn.Linear):
300304
assert isinstance(l.weight, LinearActivationQuantizedTensor)

tutorials/calibration_flow/static_quant.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
to_affine_quantized_floatx_static,
1414
to_affine_quantized_intx_static,
1515
)
16+
from torchao.float8.inference import Float8MMConfig
1617
from torchao.quantization import quantize_, to_linear_activation_quantized
1718
from torchao.quantization.granularity import (
1819
PerAxis,
@@ -26,6 +27,7 @@
2627
MappingType,
2728
)
2829
from torchao.quantization.utils import compute_error
30+
from torchao.utils import is_sm_at_least_90
2931

3032

3133
class ObservedLinear(torch.nn.Linear):
@@ -90,12 +92,13 @@ def weight_quant_func(weight):
9092
weight, weight_scale, weight_zero_point, block_size, target_dtype
9193
)
9294
elif target_dtype == torch.float8_e4m3fn:
95+
mm_config = Float8MMConfig(use_fast_accum=True)
9396
return to_affine_quantized_floatx_static(
9497
weight,
9598
weight_scale,
9699
block_size,
97100
target_dtype,
98-
Float8Layout(mm_config=None),
101+
Float8Layout(mm_config=mm_config),
99102
)
100103
else:
101104
raise ValueError(f"Unsupported target dtype {target_dtype}")
@@ -248,15 +251,15 @@ def test_static_quant(target_dtype: torch.dtype, mapping_type: MappingType):
248251
act_obs = AffineQuantizedMinMaxObserver(
249252
mapping_type,
250253
target_dtype,
251-
granularity_type=PerTensor(),
254+
granularity=PerTensor(),
252255
eps=torch.finfo(torch.float32).eps,
253256
scale_dtype=torch.float32,
254257
zero_point_dtype=torch.float32,
255258
)
256259
weight_obs = AffineQuantizedMinMaxObserver(
257260
mapping_type,
258261
target_dtype,
259-
granularity_type=PerAxis(axis=0),
262+
granularity=PerAxis(axis=0),
260263
eps=torch.finfo(torch.float32).eps,
261264
scale_dtype=torch.float32,
262265
zero_point_dtype=torch.float32,
@@ -293,4 +296,6 @@ def test_static_quant(target_dtype: torch.dtype, mapping_type: MappingType):
293296

294297
if __name__ == "__main__":
295298
test_static_quant(torch.uint8, MappingType.ASYMMETRIC)
296-
test_static_quant(torch.float8_e4m3fn, MappingType.SYMMETRIC)
299+
if is_sm_at_least_90():
300+
# this is testing per row float8 quant
301+
test_static_quant(torch.float8_e4m3fn, MappingType.SYMMETRIC)

tutorials/developer_api_guide/my_trainable_tensor_subclass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"""
1212

1313
import torch
14-
from my_dtype_tensor_subclass import MyDTypeLayout, MyDTypeTensor
14+
from my_dtype_tensor_subclass import MyDTypeTensor, MyDTypeTensorImpl
1515
from torch.utils._python_dispatch import return_and_correct_aliasing
1616

1717
from torchao.dtypes.utils import Layout, PlainLayout
@@ -35,7 +35,7 @@ def _quantize(
3535
cls,
3636
input_float: torch.Tensor,
3737
_layout: Layout,
38-
) -> MyDTypeLayout:
38+
) -> MyDTypeTensorImpl:
3939
"""
4040
Convert from a floating point tensor (fp32/fp16/bf16) to the desired dtype.
4141
"""

tutorials/huggingface_24sparse_example.py

Lines changed: 0 additions & 113 deletions
This file was deleted.

tutorials/run_all.sh

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/bin/bash
2+
find . -type d | while read dir; do
3+
if [ -f "$dir/run.sh" ]; then
4+
echo "Running: $dir/run.sh"
5+
pushd "$dir"
6+
bash run.sh
7+
popd
8+
else
9+
find "$dir" -maxdepth 1 -name "*.py" | while read file; do
10+
if [[ "$file" == *"tensor_parallel"* ]]; then
11+
echo "Running: torchrun --standalone --nnodes=1 --nproc-per-node=1 $file"
12+
torchrun --standalone --nnodes=1 --nproc-per-node=4 "$file"
13+
else
14+
echo "Running: python $file"
15+
python "$file"
16+
fi
17+
done
18+
fi
19+
done

0 commit comments

Comments
 (0)