Commit ceceea5
promote blocksparse from prototype, make it faster (#1734)
This PR promotes block sparsity out of prototype status in torchao.
Chiefly, it ports the Triton addmm block-sparse kernels over from core and makes several performance improvements to them.
All of the numbers reported below are for an H100 with blocksize=64 and sparsity_level=0.9. The default dense baseline is 134 tok/s.
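For context, here is a rough, self-contained sketch of the workload these kernels target. It uses plain PyTorch BSR conversion and a dense reference matmul only for illustration; the real path goes through torchao's sparsified weight subclass and the ported Triton kernel.

```
import torch

# Illustrative shapes matching the benchmark setup (blocksize=64,
# sparsity_level=0.9). This is reference code, not the torchao kernel path.
K, N, blocksize = 4096, 4096, 64

# Keep roughly 10% of the 64x64 blocks and zero out the rest.
block_mask = (torch.rand(N // blocksize, K // blocksize) > 0.9).float()
weight = torch.randn(N, K)
weight = weight * block_mask.repeat_interleave(blocksize, 0).repeat_interleave(blocksize, 1)

# Store the pruned weight in BSR (block sparse row) layout.
weight_bsr = weight.to_sparse_bsr((blocksize, blocksize))

x = torch.randn(8, K)                            # decode-style activation
out_ref = torch.nn.functional.linear(x, weight)  # dense reference for the addmm
```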
1) Adds padding support to the Triton kernel for dense matrices with dimension < 16, like those we run into during decoding (214 -> 218 tok/s). See the padding sketch after this list.
2) Changes the default [num_stages](triton-lang/triton#512) parameter from 1 to 4. This has a large effect on performance, and for some reason the default kernel autotuning either does not modify this parameter or deems it unimportant (218 -> 263 tok/s). A config sketch is shown after the list.
3) Adds an env var, BSR_AUTOTUNE, that users can set if they want kernel autotuning on top of the default parameters (263 -> 266 tok/s); a usage sketch follows the list. This matters more for compute-bound bs=n workloads, where I see prefill time at bs=8192 drop from 0.3855s to 0.3745s (roughly 3%).
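The padding change in (1) boils down to the idea below; the helper names are hypothetical and not the kernel's actual API:

```
import torch
import torch.nn.functional as F

# Sketch of the padding idea: Triton matmul tiles want at least 16 rows,
# so decode-shaped activations get zero-padded up to 16 and the extra
# rows are sliced off the output afterwards. Helper names are made up.
def pad_rows(x: torch.Tensor, min_rows: int = 16):
    m = x.shape[-2]
    if m >= min_rows:
        return x, m
    # F.pad takes (last dim left/right, then second-to-last top/bottom).
    return F.pad(x, (0, 0, 0, min_rows - m)), m

def unpad_rows(out: torch.Tensor, m: int):
    return out[..., :m, :]
```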
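For (2), the change is conceptually just a different launch config for the Triton kernel; the meta-parameter names below are placeholders rather than the kernel's exact ones:

```
import triton

# num_stages controls how many software-pipelining stages Triton uses to
# overlap global-memory loads with compute in the matmul loop, so 4 stages
# hide load latency much better than 1.
default_config = triton.Config(
    {"GROUP_SIZE_ROW": 4, "SPLIT_N": 1},  # placeholder meta-parameters
    num_warps=4,
    num_stages=4,  # previously 1
)
```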
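For (3), the flag is consumed from the environment roughly like this (sketch only; the exact lookup inside torchao may differ). Usage is then just a matter of prefixing the benchmark command with `BSR_AUTOTUNE=1`.

```
import os

# Opt-in autotuning: when set, let Triton's autotuner search over a config
# space on first call; otherwise stick with the hand-picked defaults.
use_autotune = os.environ.get("BSR_AUTOTUNE", "0") == "1"
```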
So in total we are seeing a **1.985x** speedup over the dense baseline (134 -> 266 tok/s) 🚀
I've also updated the documentation to no longer reference prototype; I'm planning on updating the diagram in a subsequent PR.
### Testing
I added a new test case for the padded-input path and moved the test file out of prototype.
```
python test/sparsity/test_sparse_api.py
```
File tree
9 files changed: +843 -43 lines changed
- test/sparsity
- torchao
  - _models/llama
  - kernel
  - sparsity