From 0e4a0b667fda354b755d4ff9d2a0491a8a4e96c5 Mon Sep 17 00:00:00 2001
From: zyeric <619828575@qq.com>
Date: Fri, 29 Mar 2024 21:38:48 +0800
Subject: [PATCH] fix xentropy bf16 bug

---
 apex/contrib/csrc/xentropy/xentropy_kernel.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/apex/contrib/csrc/xentropy/xentropy_kernel.cu b/apex/contrib/csrc/xentropy/xentropy_kernel.cu
index 8ef62a334..13c49bdfc 100644
--- a/apex/contrib/csrc/xentropy/xentropy_kernel.cu
+++ b/apex/contrib/csrc/xentropy/xentropy_kernel.cu
@@ -574,7 +574,7 @@ std::vector<Tensor> host_softmax_xentropy(
     const Tensor & labels_,
     const float smoothing,
     const bool half_to_float){
-  if (half_to_float) TORCH_CHECK(input_.scalar_type() == ScalarType::Half,"conversion is supported for Half type only");
+  if (half_to_float) TORCH_CHECK(input_.scalar_type() == ScalarType::Half || input_.scalar_type() == ScalarType::BFloat16,"conversion is supported for Half type only");
   TORCH_CHECK(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");
 
   auto input = input_.contiguous();
@@ -712,7 +712,7 @@ at::Tensor softmax_xentropy_backward_cuda(
     const float smoothing) {
   bool half_to_float = grad_loss.scalar_type() != logits.scalar_type();
   if (half_to_float) {
-    TORCH_CHECK((grad_loss.scalar_type() == ScalarType::Float && logits.scalar_type() == ScalarType::Half), "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
+    TORCH_CHECK((grad_loss.scalar_type() == ScalarType::Float && (logits.scalar_type() == ScalarType::Half || logits.scalar_type() == ScalarType::BFloat16)), "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
   }
   return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float);
 }
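
Note (not part of the patch): the two hunks only widen the dtype guards so that BFloat16 logits are accepted on the half_to_float path, in both the forward and the backward check. A minimal standalone sketch of what the widened checks express is below; check_forward_dtype and check_backward_dtypes are hypothetical helpers written for illustration and do not exist in apex.

    #include <torch/extension.h>

    // Forward path: when half_to_float conversion is requested, the logits may
    // now be either Half or BFloat16 (previously Half only).
    void check_forward_dtype(const at::Tensor& input, bool half_to_float) {
      if (half_to_float) {
        TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
                        input.scalar_type() == at::ScalarType::BFloat16,
                    "conversion is supported for Half and BFloat16 types only");
      }
    }

    // Backward path: a dtype mismatch between grad_loss and logits is only legal
    // when the gradient is Float and the logits are Half or BFloat16.
    void check_backward_dtypes(const at::Tensor& grad_loss, const at::Tensor& logits) {
      bool half_to_float = grad_loss.scalar_type() != logits.scalar_type();
      if (half_to_float) {
        TORCH_CHECK(grad_loss.scalar_type() == at::ScalarType::Float &&
                        (logits.scalar_type() == at::ScalarType::Half ||
                         logits.scalar_type() == at::ScalarType::BFloat16),
                    "expected input and grad types to match, or input to be "
                    "Half/BFloat16 and grad to be Float");
      }
    }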