1
//===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
3
// The LLVM Compiler Infrastructure
5
// This file is distributed under the University of Illinois Open Source
6
// License. See LICENSE.TXT for details.
8
//===----------------------------------------------------------------------===//
11
/// This pass performs the following type substitutions on all
12
/// non-compute shaders:
15
/// - v16i8 is used for constant memory resource descriptors. This type is
16
/// legal for some compute APIs, and we don't want to declare it as legal
17
/// in the backend, because we want the legalizer to expand all v16i8
/// operations.
20
/// - Having v1* types complicates the legalizer and we can easily replace
21
///   them with the element type.
22
//===----------------------------------------------------------------------===//
25
#include "llvm/IR/IRBuilder.h"
26
#include "llvm/IR/InstVisitor.h"
32
/// Function pass that doubles as an instruction visitor: it derives from both
/// FunctionPass and InstVisitor<SITypeRewriter>, and declares visit hooks for
/// load, call and bitcast instructions.
class SITypeRewriter : public FunctionPass,
33
public InstVisitor<SITypeRewriter> {
41
// Hands the static pass identifier ID to the FunctionPass base.
SITypeRewriter() : FunctionPass(ID) { }
42
// Per-module setup hook (takes the Module being processed).
bool doInitialization(Module &M) override;
43
// Per-function entry point of the pass.
bool runOnFunction(Function &F) override;
44
// Human-readable name of this pass.
const char *getPassName() const override {
45
return "SI Type Rewriter";
47
// Visitor hook for load instructions.
void visitLoadInst(LoadInst &I);
48
// Visitor hook for call instructions.
void visitCallInst(CallInst &I);
49
// Visitor hook for bitcast instructions.
void visitBitCast(BitCastInst &I);
52
} // End anonymous namespace
54
// Static pass identifier; the constructor passes it to the FunctionPass base.
char SITypeRewriter::ID = 0;
56
/// Builds the two vector types the visitors compare against and stores them
/// in the v16i8 and v4i32 members. Both are created in the context of
/// module \p M.
bool SITypeRewriter::doInitialization(Module &M) {
58
// Vector of 16 x i8 (128 bits in total).
v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
59
// Vector of 4 x i32 (also 128 bits in total), used as the replacement type.
v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
63
/// Pass entry point for function \p F. Reads the "ShaderType" function
/// attribute and branches on whether the function is a compute shader.
bool SITypeRewriter::runOnFunction(Function &F) {
64
Attribute A = F.getFnAttribute("ShaderType");
66
// Default to COMPUTE; the value is only overwritten when the attribute is
// present as a string attribute.
unsigned ShaderType = ShaderType::COMPUTE;
67
if (A.isStringAttribute()) {
68
StringRef Str = A.getValueAsString();
69
// Parse the attribute text into ShaderType. The first argument is presumably
// the radix, with 0 meaning auto-detect -- confirm against StringRef docs.
// NOTE(review): the result of getAsInteger is discarded here, so a
// malformed attribute value is not detected at this point.
Str.getAsInteger(0, ShaderType);
71
// Compute shaders get separate handling; the file header says the type
// substitutions apply to non-compute shaders only.
if (ShaderType == ShaderType::COMPUTE)
80
/// Rewrites a load whose pointee type is v16i8 into a load of v4i32 through a
/// bitcast pointer, then bitcasts the loaded value back to the original
/// result type so existing users keep the type they expect.
void SITypeRewriter::visitLoadInst(LoadInst &I) {
81
Value *Ptr = I.getPointerOperand();
82
Type *PtrTy = Ptr->getType();
83
Type *ElemTy = PtrTy->getPointerElementType();
84
// New instructions are inserted immediately before the original load.
IRBuilder<> Builder(&I);
85
// Only loads of exactly the cached v16i8 type are rewritten.
if (ElemTy == v16i8) {
86
// Reinterpret the pointer as pointing to v4i32, keeping the original
// address space.
Value *BitCast = Builder.CreateBitCast(Ptr,
87
PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
88
LoadInst *Load = Builder.CreateLoad(BitCast);
89
// Copy every metadata attachment except the debug location from the
// original load onto the replacement load.
SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
90
I.getAllMetadataOtherThanDebugLoc(MD);
91
for (unsigned i = 0, e = MD.size(); i != e; ++i) {
92
Load->setMetadata(MD[i].first, MD[i].second);
94
// Cast the v4i32 result back to the type of the original load and redirect
// all users of the original load to it.
Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
95
I.replaceAllUsesWith(BitCastLoad);
100
/// Rewrites a call whose arguments use v16i8 or a one-element i32 vector.
/// Each v16i8 argument is bitcast to v4i32, and each one-element i32 vector
/// is replaced by its scalar element. The callee name is adjusted to match
/// the new argument types, and the call is redirected to a function of that
/// name, looked up in (or created in) the module.
void SITypeRewriter::visitCallInst(CallInst &I) {
101
// New instructions are inserted immediately before the original call.
IRBuilder<> Builder(&I);
103
// Rewritten argument values and their types, collected in argument order.
SmallVector <Value*, 8> Args;
104
SmallVector <Type*, 8> Types;
105
bool NeedToReplace = false;
106
// NOTE(review): the result of getCalledFunction() is dereferenced on the next
// statement without a null check -- confirm it cannot be null for the calls
// that reach this visitor (e.g. indirect calls).
Function *F = I.getCalledFunction();
107
std::string Name = F->getName();
108
for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
109
Value *Arg = I.getArgOperand(i);
110
// Case 1: v16i8 argument -> pass it as v4i32 and append ".v4i32" to the
// callee name.
if (Arg->getType() == v16i8) {
111
Args.push_back(Builder.CreateBitCast(Arg, v4i32));
112
Types.push_back(v4i32);
113
NeedToReplace = true;
114
Name = Name + ".v4i32";
115
// Case 2: a vector with exactly one element of type i32 -> pass the scalar.
} else if (Arg->getType()->isVectorTy() &&
116
Arg->getType()->getVectorNumElements() == 1 &&
117
Arg->getType()->getVectorElementType() ==
118
Type::getInt32Ty(I.getContext())){
119
Type *ElementTy = Arg->getType()->getVectorElementType();
120
std::string TypeName = "i32";
121
// NOTE(review): cast<> requires the argument to be an insertelement
// instruction; any other producer of a one-element vector is not handled
// here -- confirm this always holds for the inputs.
InsertElementInst *Def = cast<InsertElementInst>(Arg);
122
// Operand 1 of the insertelement is used as the scalar argument.
Args.push_back(Def->getOperand(1));
123
Types.push_back(ElementTy);
124
std::string VecTypeName = "v1" + TypeName;
125
// Replace the first "v1i32" in the callee name with "i32".
// NOTE(review): the result of find() is passed to replace() unchecked;
// confirm the callee name always contains "v1i32" in this case.
Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
126
NeedToReplace = true;
129
// NOTE(review): presumably part of the fall-through case for arguments that
// need no rewriting -- the surrounding lines are not visible here.
Types.push_back(Arg->getType());
133
// Branch taken when no argument required rewriting.
if (!NeedToReplace) {
136
// Look for an existing function with the adjusted name in the module.
Function *NewF = Mod->getFunction(Name);
138
// Create a function with external linkage under the adjusted name, using the
// original return type, the rewritten parameter types, and no varargs.
NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
139
// Carry the original callee's attributes over to the new function.
NewF->setAttributes(F->getAttributes());
141
// Emit the replacement call and redirect all users of the original call.
I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
145
/// Looks for a bitcast whose operand is itself a bitcast from v4i32 and
/// replaces uses of the outer cast with the original v4i32 value.
void SITypeRewriter::visitBitCast(BitCastInst &I) {
146
IRBuilder<> Builder(&I);
147
// Casts whose destination type is not v4i32 take this branch.
// NOTE(review): presumably an early exit -- the body is not visible here.
if (I.getDestTy() != v4i32) {
151
// Operand is another bitcast: check whether it started from v4i32.
if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
152
if (Op->getSrcTy() == v4i32) {
153
// Redirect users of I to the inner cast's source value, bypassing both
// casts.
I.replaceAllUsesWith(Op->getOperand(0));
159
/// Factory for the pass: returns a newly heap-allocated SITypeRewriter as a
/// FunctionPass pointer.
/// NOTE(review): the caller presumably takes ownership of the returned
/// object -- confirm against the pass-manager call site.
FunctionPass *llvm::createSITypeRewriter() {
160
return new SITypeRewriter();