@@ -2128,6 +2128,25 @@ SPIRVToLLVM::transType(SPIRVType *T) {
21282128 getOrCreateOpaquePtrType (M, " intel.buffer_rw_t" ,
21292129 SPIRAddressSpace::SPIRAS_Global));
21302130 }
2131+ case OpTypeMatrixINTEL:
2132+ {
2133+ SPIRVTypeMatrixINTEL *MT = static_cast <SPIRVTypeMatrixINTEL *>(T);
2134+ const char *typeName = nullptr ;
2135+ switch (MT->getLayout ()) {
2136+ case SPIRVTypeMatrixINTEL::LayoutPackedA:
2137+ typeName = " intel.joint_matrix_packedA_t" ;
2138+ break ;
2139+ case SPIRVTypeMatrixINTEL::LayoutPackedB:
2140+ typeName = " intel.joint_matrix_packedB_t" ;
2141+ break ;
2142+ case SPIRVTypeMatrixINTEL::LayoutRowMajor:
2143+ case SPIRVTypeMatrixINTEL::LayoutColumnMajor:
2144+ typeName = " intel.joint_matrix_acc_t" ;
2145+ break ;
2146+ }
2147+ IGC_ASSERT_EXIT_MESSAGE (typeName, " Unsupported layout of INTEL Joint Matrix." );
2148+ return mapType (T, getOrCreateOpaquePtrType (M, typeName, SPIRAddressSpace::SPIRAS_Global));
2149+ }
21312150 default : {
21322151 auto OC = T->getOpCode ();
21332152 if (isOpaqueGenericTypeOpCode (OC) ||
@@ -3651,6 +3670,199 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
36513670 auto * BC = static_cast <SPIRVUnary*>(BV);
36523671 return mapValue (BV, transValue (BC->getOperand (0 ), F, BB));
36533672 }
3673+ case OpMatrixLoadINTEL: {
3674+ SPIRVMatrixLoadINTEL *ML = static_cast <SPIRVMatrixLoadINTEL *>(BV);
3675+ std::vector<SPIRVValue *> BArgs = ML->getOperands ();
3676+ enum SPVIdx { Pointer, Stride, Layout, Scope, MemOp };
3677+
3678+ SPIRVTypeMatrixINTEL *MatTy = static_cast <SPIRVTypeMatrixINTEL *>(ML->getType ());
3679+ const unsigned loadLayout = (unsigned )BM->get <SPIRVConstant>(BArgs[Layout]->getId ())->getZExtIntValue ();
3680+
3681+ IGC_ASSERT_MESSAGE (BB, " Invalid BB" );
3682+
3683+ /* Get arugment values for the intrinsic call */
3684+ Value *PtrVal = transValue (BArgs[Pointer], F, BB);
3685+ Value *StrideVal = transValue (BArgs[Stride], F, BB);
3686+
3687+ unsigned AS = static_cast <PointerType *>(PtrVal->getType ())->getAddressSpace ();
3688+ /* Prepare types for the call: */
3689+ Type *RetTy = transType (MatTy);
3690+ Type *PtrTy = PointerType::get (Type::getInt8Ty (*Context), AS);
3691+ Type *StrideTy = Type::getInt32Ty (*Context);
3692+ Type *ElemTypeTy = Type::getInt32Ty (*Context);
3693+ Type *LayoutTy = Type::getInt32Ty (*Context);
3694+ Type *SizeTy = Type::getInt32Ty (*Context);
3695+
3696+ std::vector<Type *> ArgTys = {
3697+ PtrTy, StrideTy, LayoutTy, ElemTypeTy, SizeTy, SizeTy
3698+ };
3699+ FunctionType *builtinTy = FunctionType::get (RetTy, ArgTys, false );
3700+
3701+ /* Cast if necessary and prepare rest of the arguments: */
3702+ CastInst *Ptr = CastInst::CreatePointerCast (PtrVal, PtrTy, " " , BB);
3703+ if (StrideVal->getType () != StrideTy) {
3704+ IGC_ASSERT_MESSAGE (StrideVal->getType ()->isIntegerTy (),
3705+ " Unspupported matrix stide type in load instruction." );
3706+ StrideVal = CastInst::CreateIntegerCast (StrideVal, StrideTy, false , " stride" , Ptr);
3707+ }
3708+
3709+ Value *LoadLayoutVal = ConstantInt::get (LayoutTy, loadLayout);
3710+ Value *ElementTypeVal = ConstantInt::get (ElemTypeTy, MatTy->getElementTypeFlags ());
3711+ Value *RowsVal = ConstantInt::get (SizeTy, MatTy->getRows ());
3712+ Value *ColumnsVal = ConstantInt::get (SizeTy, MatTy->getColumns ());
3713+
3714+ /* Get function to call */
3715+ const char *suffix = nullptr ;
3716+ switch (MatTy->getLayout ()) {
3717+ case SPIRVTypeMatrixINTEL::LayoutPackedA:
3718+ suffix = " _PackedA" ;
3719+ break ;
3720+ case SPIRVTypeMatrixINTEL::LayoutPackedB:
3721+ suffix = " _PackedB" ;
3722+ break ;
3723+ case SPIRVTypeMatrixINTEL::LayoutRowMajor:
3724+ case SPIRVTypeMatrixINTEL::LayoutColumnMajor:
3725+ suffix = " _Accumulator" ;
3726+ break ;
3727+ }
3728+ IGC_ASSERT_MESSAGE (suffix, " Unsupported layout type for INTEL Joint Matrix." );
3729+ auto BI = static_cast <SPIRVInstruction *>(BV);
3730+ std::string builtinName (getSPIRVBuiltinName (BV->getOpCode (), BI, ArgTys, suffix));
3731+ Function *Func = cast<Function>(M->getOrInsertFunction (builtinName, builtinTy));
3732+
3733+ std::vector<Value *> Args = {
3734+ Ptr, StrideVal, LoadLayoutVal, ElementTypeVal, RowsVal, ColumnsVal
3735+ };
3736+ CallInst *CI = CallInst::Create (Func, Args, " matrix" , BB);
3737+ return mapValue (BV, CI);
3738+ }
3739+ case OpMatrixStoreINTEL: {
3740+ SPIRVMatrixStoreINTEL *MS = static_cast <SPIRVMatrixStoreINTEL *>(BV);
3741+ std::vector<SPIRVValue *> BArgs = MS->getOperands ();
3742+ enum SPVIdx { Pointer, Object, Stride, Layout, Scope, MemOp };
3743+
3744+ SPIRVTypeMatrixINTEL *MatTy = static_cast <SPIRVTypeMatrixINTEL *>(BArgs[Object]->getType ());
3745+ const unsigned storeLayout = (unsigned )BM->get <SPIRVConstant>(BArgs[Layout]->getId ())->getZExtIntValue ();
3746+
3747+ IGC_ASSERT_MESSAGE (BB, " Invalid BB" );
3748+
3749+ /* Get arugment values for the intrinsic call */
3750+ Value *MatrixVal = transValue (BArgs[Object], F, BB);
3751+ Value *PtrVal = transValue (BArgs[Pointer], F, BB);
3752+ Value *StrideVal = transValue (BArgs[Stride], F, BB);
3753+
3754+ unsigned AS = static_cast <PointerType *>(PtrVal->getType ())->getAddressSpace ();
3755+ /* Prepare types for the call: */
3756+ Type *MatrixTy = transType (MatTy);
3757+ Type *PtrTy = PointerType::get (Type::getInt8Ty (*Context), AS);
3758+ Type *StrideTy = Type::getInt32Ty (*Context);
3759+ Type *ElemTypeTy = Type::getInt32Ty (*Context);
3760+ Type *LayoutTy = Type::getInt32Ty (*Context);
3761+ Type *SizeTy = Type::getInt32Ty (*Context);
3762+
3763+ std::vector<Type *> ArgTys = {
3764+ PtrTy, MatrixTy, StrideTy, LayoutTy, ElemTypeTy, SizeTy, SizeTy
3765+ };
3766+ FunctionType *builtinTy = FunctionType::get (Type::getVoidTy (*Context), ArgTys, false );
3767+
3768+ /* Cast if necessary and prepare rest of the arguments: */
3769+ CastInst *Ptr = CastInst::CreatePointerCast (PtrVal, PtrTy, " " , BB);
3770+ if (StrideVal->getType () != StrideTy) {
3771+ IGC_ASSERT_MESSAGE (StrideVal->getType ()->isIntegerTy (),
3772+ " Unspupported matrix stide type in store instruction." );
3773+ StrideVal = CastInst::CreateIntegerCast (StrideVal, StrideTy, false , " stride" , Ptr);
3774+ }
3775+
3776+ Value *StoreLayoutVal = ConstantInt::get (LayoutTy, storeLayout);
3777+ Value *ElementTypeVal = ConstantInt::get (ElemTypeTy, MatTy->getElementTypeFlags ());
3778+ Value *RowsVal = ConstantInt::get (SizeTy, MatTy->getRows ());
3779+ Value *ColumnsVal = ConstantInt::get (SizeTy, MatTy->getColumns ());
3780+
3781+ /* Get function to call */
3782+ const char *suffix = nullptr ;
3783+ switch (MatTy->getLayout ()) {
3784+ case SPIRVTypeMatrixINTEL::LayoutPackedA:
3785+ suffix = " _PackedA" ;
3786+ break ;
3787+ case SPIRVTypeMatrixINTEL::LayoutPackedB:
3788+ suffix = " _PackedB" ;
3789+ break ;
3790+ case SPIRVTypeMatrixINTEL::LayoutRowMajor:
3791+ case SPIRVTypeMatrixINTEL::LayoutColumnMajor:
3792+ suffix = " _Accumulator" ;
3793+ break ;
3794+ }
3795+ IGC_ASSERT_MESSAGE (suffix, " Unsupported layout type for INTEL Joint Matrix." );
3796+ auto BI = static_cast <SPIRVInstruction *>(BV);
3797+ std::string builtinName (getSPIRVBuiltinName (BV->getOpCode (), BI, ArgTys, suffix));
3798+ Function *Func = cast<Function>(M->getOrInsertFunction (builtinName, builtinTy));
3799+
3800+ std::vector<Value *> Args = {
3801+ Ptr, MatrixVal, StrideVal, StoreLayoutVal, ElementTypeVal, RowsVal, ColumnsVal
3802+ };
3803+ CallInst *CI = CallInst::Create (Func, Args, " " , BB);
3804+ return mapValue (BV, CI);
3805+ }
3806+ case OpMatrixMadINTEL: {
3807+ SPIRVMatrixMadINTEL *MM = static_cast <SPIRVMatrixMadINTEL *>(BV);
3808+ std::vector<SPIRVValue *> BArgs = MM->getOperands ();
3809+ enum SPVIdx { A, B, C, Scope };
3810+
3811+ auto *MatATy = static_cast <SPIRVTypeMatrixINTEL *>(BArgs[A]->getType ());
3812+ auto *MatBTy = static_cast <SPIRVTypeMatrixINTEL *>(BArgs[B]->getType ());
3813+ auto *MatCTy = static_cast <SPIRVTypeMatrixINTEL *>(BArgs[C]->getType ());
3814+
3815+ auto *ResMatTy = static_cast <SPIRVTypeMatrixINTEL *>(MM->getType ());
3816+
3817+ const unsigned sizeM = MatATy->getRows ();
3818+ const unsigned sizeK = MatATy->getColumns ();
3819+ const unsigned sizeN = MatBTy->getColumns ();
3820+
3821+ IGC_ASSERT (sizeM == MatCTy->getRows ());
3822+ IGC_ASSERT (sizeN == MatCTy->getColumns ());
3823+ IGC_ASSERT (sizeK == MatBTy->getRows ());
3824+
3825+ IGC_ASSERT (ResMatTy->getRows () == MatCTy->getRows ());
3826+ IGC_ASSERT (ResMatTy->getColumns () == MatCTy->getColumns ());
3827+
3828+ Type *RetTy = transType (ResMatTy);
3829+ Type *ATy = transType (MatATy);
3830+ Type *BTy = transType (MatBTy);
3831+ Type *CTy = transType (MatCTy);
3832+ Type *ElemTypeTy = Type::getInt32Ty (*Context);
3833+ Type *SizeTy = Type::getInt32Ty (*Context);
3834+
3835+ std::vector<Type *> ArgTys = {
3836+ ATy, ElemTypeTy, SizeTy, SizeTy,
3837+ BTy, ElemTypeTy, SizeTy, SizeTy,
3838+ CTy, ElemTypeTy, SizeTy, SizeTy
3839+ };
3840+ FunctionType *builtinTy = FunctionType::get (RetTy, ArgTys, false );
3841+
3842+ auto BI = static_cast <SPIRVInstruction *>(BV);
3843+ std::string builtinName (getSPIRVBuiltinName (BV->getOpCode (), BI, ArgTys, " " ));
3844+ Function *Func = cast<Function>(M->getOrInsertFunction (builtinName, builtinTy));
3845+
3846+ std::vector<Value *> Args = {
3847+ /* Matrix A */
3848+ transValue (BArgs[A], F, BB),
3849+ ConstantInt::get (ElemTypeTy, MatATy->getElementTypeFlags ()),
3850+ ConstantInt::get (SizeTy, MatATy->getRows ()),
3851+ ConstantInt::get (SizeTy, MatATy->getColumns ()),
3852+ /* Matrix B */
3853+ transValue (BArgs[B], F, BB),
3854+ ConstantInt::get (ElemTypeTy, MatBTy->getElementTypeFlags ()),
3855+ ConstantInt::get (SizeTy, MatBTy->getRows ()),
3856+ ConstantInt::get (SizeTy, MatBTy->getColumns ()),
3857+ /* Matrix C */
3858+ transValue (BArgs[C], F, BB),
3859+ ConstantInt::get (ElemTypeTy, MatCTy->getElementTypeFlags ()),
3860+ ConstantInt::get (SizeTy, MatCTy->getRows ()),
3861+ ConstantInt::get (SizeTy, MatCTy->getColumns ()),
3862+ };
3863+ CallInst *CI = CallInst::Create (Func, Args, " matrix" , BB);
3864+ return mapValue (BV, CI);
3865+ }
36543866 default : {
36553867 auto OC = BV->getOpCode ();
36563868 if (isSPIRVCmpInstTransToLLVMInst (static_cast <SPIRVInstruction*>(BV))) {
0 commit comments