@@ -222,17 +222,17 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N>
222222__impl_fp2ui__double__ (vector<double , N> a) {
223223 // vector of doubles -> vector of ints
224224 vector<uint32_t , 2 *N> LoHi = a.template format <uint32_t >();
225- const vector<uint32_t , N> Exp_mask ( 0xff << 20 );
226- const vector<uint32_t , N> Mantissa_mask (( 1u << 20 ) - 1 );
225+ const vector<uint32_t , N> MantissaMask (( 1u << 20 ) - 1 );
226+ const vector<uint32_t , N> ExpMask ( 0x7ff );
227227 const vector<uint32_t , N> Zero (0 );
228228 const vector<uint32_t , N> Ones (0xffffffff );
229229 const vector<uint32_t , N> One (1 );
230230 vector<uint32_t , N> Lo = LoHi.template select <N, 2 >(0 );
231231 vector<uint32_t , N> Hi = LoHi.template select <N, 2 >(1 );
232- vector<uint32_t , N> Exp = (Hi >> 20 ) & vector< uint32_t , N>( 0x7ff ) ;
232+ vector<uint32_t , N> Exp = (Hi >> 20 ) & ExpMask ;
233233 // mantissa without hidden bit
234234 vector<uint32_t , N> LoMant = Lo;
235- vector<uint32_t , N> HiMant = Hi & Mantissa_mask ;
235+ vector<uint32_t , N> HiMant = Hi & MantissaMask ;
236236 // for normalized numbers (1 + mant/2^52) * 2 ^ (exp-1023)
237237 vector<int32_t , N> MantShift = Exp - 1023 - 52 ;
238238 vector<int32_t , N> OneShift = Exp - 1023 ;
@@ -272,6 +272,7 @@ __impl_fp2ui__double__(vector<double, N> a) {
272272 // check for Exponent overflow (when sign bit set)
273273 auto FlagExpO = (Exp > vector<uint32_t , N>(1089 ));
274274 auto FlagExpUO = FlagNoSignSet & FlagExpO;
275+ auto IsNaN = (Exp == ExpMask) & ((LoMant != Zero) | (HiMant != Zero));
275276 if constexpr (isSigned) {
276277 // calculate (NOT[Lo, Hi] + 1) (integer sign negation)
277278 vector<uint32_t , N> NegLo = ~LoRes;
@@ -307,29 +308,38 @@ __impl_fp2ui__double__(vector<double, N> a) {
307308 LoRes.merge (Ones, FlagExpUO);
308309 HiRes.merge (vector<uint32_t , N>((1u << 31 ) - 1 ), FlagExpUO);
309310
311+ // if (IsNaN)
312+ LoRes.merge (Zero, IsNaN);
313+ HiRes.merge (Zero, IsNaN);
314+
310315 } else {
311316 // if (FlagSignSet)
312317 LoRes.merge (Zero, FlagSignSet);
313318 HiRes.merge (Zero, FlagSignSet);
319+
314320 // if (FlagExpUO)
315321 LoRes.merge (Ones, FlagExpUO);
316322 HiRes.merge (Ones, FlagExpUO);
323+
324+ // if (IsNaN)
325+ LoRes.merge (Zero, IsNaN);
326+ HiRes.merge (Zero, IsNaN);
317327 }
318328 return __impl_combineLoHi<N>(LoRes, HiRes);
319329}
320330template <unsigned N, bool isSigned>
321331CM_NODEBUG CM_INLINE vector<uint64_t , N> __impl_fp2ui__ (vector<float , N> a) {
322332 // vector of floats -> vector of ints
323333 vector<uint32_t , N> Uifl = a.template format <uint32_t >();
324- const vector<uint32_t , N> Exp_mask (0xff << 23 );
325- const vector<uint32_t , N> Mantissa_mask ((1u << 23 ) - 1 );
334+ const vector<uint32_t , N> ExpMask (0xff );
335+ const vector<uint32_t , N> MantissaMask ((1u << 23 ) - 1 );
326336 const vector<uint32_t , N> Zero (0 );
327337 const vector<uint32_t , N> Ones (0xffffffff );
328338 const vector<uint32_t , N> One (1 );
329339
330- vector<uint32_t , N> Exp = (Uifl >> 23 ) & vector< uint32_t , N>( 0xff ) ;
340+ vector<uint32_t , N> Exp = (Uifl >> 23 ) & ExpMask ;
331341 // mantissa without hidden bit
332- vector<uint32_t , N> Pmantissa = Uifl & Mantissa_mask ;
342+ vector<uint32_t , N> Pmantissa = Uifl & MantissaMask ;
333343 // take hidden bit into account
334344 vector<uint32_t , N> Mantissa = Pmantissa | vector<uint32_t , N>(1 << 23 );
335345 vector<uint32_t , N> Data_h = Mantissa << 8 ;
@@ -354,8 +364,8 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
354364
355365 // Discard results if shift is greater than 63
356366 vector<uint32_t , N> Mask = Ones;
357- auto Flag_discard = (Shift > vector<uint32_t , N>(63 ));
358- Mask.merge (Zero, Flag_discard );
367+ auto FlagDiscard = (Shift > vector<uint32_t , N>(63 ));
368+ Mask.merge (Zero, FlagDiscard );
359369 Lo = Lo & Mask;
360370 Hi = Hi & Mask;
361371 vector<uint32_t , N> SignedBitMask (1u << 31 );
@@ -365,6 +375,7 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
365375 // check for Exponent overflow (when sign bit set)
366376 auto FlagExpO = (Exp > vector<uint32_t , N>(0xbe ));
367377 auto FlagExpUO = FlagNoSignSet & FlagExpO;
378+ auto IsNaN = (Exp == ExpMask) & (Pmantissa != Zero);
368379 if constexpr (isSigned) {
369380 // calculate (NOT[Lo, Hi] + 1) (integer sign negation)
370381 vector<uint32_t , N> NegLo = ~Lo;
@@ -401,13 +412,21 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
401412 Lo.merge (Ones, FlagExpUO);
402413 Hi.merge (vector<uint32_t , N>((1u << 31 ) - 1 ), FlagExpUO);
403414
415+ // if (IsNaN)
416+ Lo.merge (Zero, IsNaN);
417+ Hi.merge (Zero, IsNaN);
404418 } else {
405419 // if (FlagSignSet)
406420 Lo.merge (Zero, FlagSignSet);
407421 Hi.merge (Zero, FlagSignSet);
422+
408423 // if (FlagExpUO)
409424 Lo.merge (Ones, FlagExpUO);
410425 Hi.merge (Ones, FlagExpUO);
426+
427+ // if (IsNaN)
428+ Lo.merge (Zero, IsNaN);
429+ Hi.merge (Zero, IsNaN);
411430 }
412431 return __impl_combineLoHi<N>(Lo, Hi);
413432}
0 commit comments