@@ -16,18 +16,22 @@ SPDX-License-Identifier: MIT
1616// /
1717// ===----------------------------------------------------------------------===//
1818
19-
2019#include " GenX.h"
20+ #include " GenXSubtarget.h"
21+ #include " GenXTargetMachine.h"
2122#include " GenXUtil.h"
23+
2224#include " llvm/ADT/EquivalenceClasses.h"
2325#include " llvm/ADT/Statistic.h"
26+ #include " llvm/CodeGen/TargetPassConfig.h"
2427#include " llvm/IR/IRBuilder.h"
2528#include " llvm/IR/InstIterator.h"
29+ #include " llvm/InitializePasses.h"
2630#include " llvm/Pass.h"
2731
2832#include " llvmWrapper/IR/DerivedTypes.h"
2933
30- #define DEBUG_TYPE " GENX_PROMOTE_PREDICATE "
34+ #define DEBUG_TYPE " genx-promote-predicate "
3135
3236using namespace llvm ;
3337using namespace genx ;
@@ -48,6 +52,7 @@ class GenXPromotePredicate : public FunctionPass {
4852 bool runOnFunction (Function &F) override ;
4953 StringRef getPassName () const override { return " GenXPromotePredicate" ; }
5054 void getAnalysisUsage (AnalysisUsage &AU) const override {
55+ AU.addRequired <TargetPassConfig>();
5156 AU.setPreservesCFG ();
5257 }
5358};
@@ -61,6 +66,7 @@ void initializeGenXPromotePredicatePass(PassRegistry &);
6166}
6267INITIALIZE_PASS_BEGIN (GenXPromotePredicate, " GenXPromotePredicate" ,
6368 " GenXPromotePredicate" , false , false )
69+ INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
6470INITIALIZE_PASS_END(GenXPromotePredicate, " GenXPromotePredicate" ,
6571 " GenXPromotePredicate" , false , false )
6672
@@ -138,8 +144,9 @@ static Value *promoteInstToScalar(Instruction *Inst) {
138144
139145// Promote one predicate instruction to grf - promote all its operands and
140146// instruction itself, and then sink the result back to predicate.
141- static Value *promoteInst (Instruction *Inst) {
142- if (auto *VTy = dyn_cast<IGCLLVM::FixedVectorType>(Inst->getType ())) {
147+ static Value *promoteInst (Instruction *Inst, bool AllowScalarPromotion) {
148+ if (auto *VTy = dyn_cast<IGCLLVM::FixedVectorType>(Inst->getType ());
149+ VTy && AllowScalarPromotion) {
143150 IGC_ASSERT (VTy->isIntOrIntVectorTy (1 ));
144151 auto Width = VTy->getNumElements ();
145152
@@ -220,7 +227,8 @@ static void foldBitcast(BitCastInst *Cast) {
220227class PredicateWeb {
221228public:
222229 template <class InputIt >
223- PredicateWeb (InputIt first, InputIt last) : Web(first, last) {}
230+ PredicateWeb (InputIt First, InputIt Last, bool AllowScalar)
231+ : Web(First, Last), AllowScalarPromotion(AllowScalar) {}
224232 void print (llvm::raw_ostream &O) const {
225233 for (auto Inst : Web)
226234 O << *Inst << ' \n ' ;
@@ -236,7 +244,7 @@ class PredicateWeb {
236244 // Do promotion.
237245 SmallVector<Instruction *, 8 > Worklist;
238246 for (auto *Inst : Web) {
239- auto *PromotedInst = promoteInst (Inst);
247+ auto *PromotedInst = promoteInst (Inst, AllowScalarPromotion );
240248
241249 if (isa<TruncInst>(PromotedInst) || isa<BitCastInst>(PromotedInst))
242250 Worklist.push_back (cast<Instruction>(PromotedInst));
@@ -254,6 +262,7 @@ class PredicateWeb {
254262
255263private:
256264 SmallPtrSet<Instruction *, 16 > Web;
265+ bool AllowScalarPromotion;
257266};
258267
259268constexpr const char IdxMDName[] = " pred.index" ;
@@ -273,6 +282,11 @@ struct Comparator {
273282};
274283
275284bool GenXPromotePredicate::runOnFunction (Function &F) {
285+ auto &ST = getAnalysis<TargetPassConfig>()
286+ .getTM <GenXTargetMachine>()
287+ .getGenXSubtarget ();
288+ bool AllowScalarPromotion = !ST.hasFusedEU ();
289+
276290 // Put every predicate instruction into its own equivalence class.
277291 long Idx = 0 ;
278292 llvm::EquivalenceClasses<Instruction *, Comparator> PredicateWebs;
@@ -303,7 +317,8 @@ bool GenXPromotePredicate::runOnFunction(Function &F) {
303317 for (auto I = PredicateWebs.begin (), E = PredicateWebs.end (); I != E; ++I) {
304318 if (!I->isLeader ())
305319 continue ;
306- PredicateWeb Web (PredicateWebs.member_begin (I), PredicateWebs.member_end ());
320+ PredicateWeb Web (PredicateWebs.member_begin (I), PredicateWebs.member_end (),
321+ AllowScalarPromotion);
307322 LLVM_DEBUG (dbgs () << " Predicate web:\n " ; Web.dump ());
308323 ++NumCollectedPredicateWebs;
309324 if (!Web.isBeneficialToPromote ())
0 commit comments