Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t block_stride, const int64_t page_stride,
const int64_t head_stride, const int64_t key_stride,
const int64_t value_stride, const int num_heads, const int head_size,
const int block_size, const float* k_scale, const float* v_scale) {
const int block_size, const float* k_scale, const float* v_scale,
const int kv_scale_stride) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
Expand All @@ -316,21 +317,23 @@ __global__ void reshape_and_cache_flash_kernel(
// this is true for the NHD layout where `head_stride == head_size`
const bool is_contiguous_heads = (head_stride == head_size);

float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
if (is_contiguous_heads) {
// NHD layout

if (is_contiguous_heads && kv_scale_stride == 0) {
// NHD layout and k/v_scales are [1] (i.e. single scale for all heads)
// kv cache: [num_blocks, block_size, num_heads, head_size]
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;

CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};

vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
blockDim.x, k_op);

vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
threadIdx.x, blockDim.x, v_op);

} else {
// HND layout OR k/v_scales are [num_heads] (i.e. per-attn-head)
// HND layout: heads are strided, but each head_size segment is contiguous
// kv cache: [num_blocks, num_heads, block_size, head_size]
const int lane = threadIdx.x & 31; // 0..31 within warp
Expand All @@ -346,6 +349,16 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t* __restrict__ v_dst_h =
value_dst + static_cast<int64_t>(head) * head_stride;

float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
? 0.f
: k_scale[head * kv_scale_stride];
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
? 0.f
: v_scale[head * kv_scale_stride];

CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};

// within each head, let the 32 threads of the warp perform the vector
// copy
vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
Expand Down Expand Up @@ -699,7 +712,8 @@ void reshape_and_cache(
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
head_stride, key_stride, value_stride, num_heads, head_size, \
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
reinterpret_cast<const float*>(v_scale.data_ptr()), \
kv_scale_stride);

void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
Expand All @@ -708,8 +722,9 @@ void reshape_and_cache_flash(
torch::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, // [1] or [num_heads]
torch::Tensor& v_scale) { // [1] or [num_heads]
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
Expand All @@ -732,6 +747,12 @@ void reshape_and_cache_flash(
int64_t head_stride = key_cache.stride(2);
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));

TORCH_CHECK(k_scale.sizes() == v_scale.sizes(),
"k_scale and v_scale must have the same shape");
TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads,
"k_scale and v_scale must be of shape [1] or [num_heads]");
int kv_scale_stride = (k_scale.numel() > 1) ? 1 : 0;

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
Expand Down
43 changes: 37 additions & 6 deletions csrc/quantization/w8a8/fp8/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ __global__ void scaled_fp8_quant_kernel_strided(
});
}

template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel_strided_per_channel(
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
int64_t out_row_stride) {
const int64_t token_idx = blockIdx.x;
const int tid = threadIdx.x;

const scalar_t* token_in = input + token_idx * in_row_stride;
fp8_type* token_out = out + token_idx * out_row_stride;

for (int i = tid; i < hidden_size; i += blockDim.x) {
const float inv_scale = 1.0f / scale[i];
token_out[i] = scaled_fp8_conversion<true, fp8_type>(
static_cast<float>(token_in[i]), inv_scale);
}
}

template <typename scalar_t, typename fp8_type>
__global__ void segmented_max_reduction_strided(
float* __restrict__ scale, const scalar_t* __restrict__ input,
Expand Down Expand Up @@ -135,33 +153,46 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(

void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
torch::Tensor const& scale) // [1] or [d]
{
TORCH_CHECK(input.stride(-1) == 1,
"last dimension of input must be contiguous");
TORCH_CHECK(out.stride(-1) == 1,
"last dimension of output must be contiguous");
TORCH_CHECK(scale.numel() == 1 || scale.numel() == input.size(-1),
"scale must be either shape [1] or [d], where d is hidden_size");

const int hidden_size = input.size(-1);
const int num_tokens = input.numel() / hidden_size;
const int block_size = 256;
dim3 grid(num_tokens);
dim3 block(block_size);
const bool is_per_channel = (scale.numel() == hidden_size);

const int64_t in_row_stride = input.stride(-2);
const int64_t out_row_stride = out.stride(-2);

const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), hidden_size, in_row_stride,
out_row_stride);
if (is_per_channel) {
vllm::scaled_fp8_quant_kernel_strided_per_channel<scalar_t,
fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), hidden_size, in_row_stride,
out_row_stride);
} else {
vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), hidden_size, in_row_stride,
out_row_stride);
}
});
});
}
Expand Down
Loading