fcd Source Code Analysis, Part 5

Let's continue analyzing the fcd source code. Last time we covered the argrec and sroa optimization passes; this time we look at two more, intnarrowing and signext.

intnarrowing

As the name suggests, the intnarrowing pass takes instructions that use a wider register than they need and narrows them. Here is a simple example:

or rax, 0x12

can be turned into

or al, 0x12

This example is only an approximation: the pass actually operates on LLVM IR, not on x86 instructions.
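To see why the narrowed form gives the same register value, here is a small C++ sketch on plain integers (no LLVM) that emulates both instructions on a 64-bit register. The helper names orWide, orLowByte and narrowingPreservesValue are made up for illustration, and the sketch ignores the flags register, which the two forms do not always set identically (for example, SF comes from bit 63 in one case and from bit 7 in the other).

```cpp
#include <cassert>
#include <cstdint>

// Emulates `or rax, imm` on a 64-bit register value.
uint64_t orWide(uint64_t rax, uint64_t imm) {
    return rax | imm;
}

// Emulates `or al, imm8`: only the low byte (al) changes,
// the upper 56 bits of rax keep their previous contents.
uint64_t orLowByte(uint64_t rax, uint8_t imm8) {
    uint8_t al = static_cast<uint8_t>(rax & 0xFF);
    al = static_cast<uint8_t>(al | imm8);
    return (rax & ~uint64_t{0xFF}) | al;
}

// Hypothetical check: does narrowing `or rax, imm` to `or al, imm`
// leave the same value in rax? Only meaningful when imm fits in 8 bits.
bool narrowingPreservesValue(uint64_t rax, uint64_t imm) {
    assert(imm <= 0xFF && "immediate must fit in al");
    return orWide(rax, imm) == orLowByte(rax, static_cast<uint8_t>(imm));
}
```

For instance, with rax = 0x1122334455667788 both forms produce 0x112233445566779A, because OR-ing with 0x12 can only set bits inside the low byte.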

Let's analyze the pass's runOnFunction.

A for loop visits every instruction and picks out binary operators whose type is an integer rather than a floating-point type:

for (BasicBlock& bb : fn)
{
    for (Instruction& inst : bb)
    {
        // Only integer binary operators are candidates for narrowing.
        if (auto binOp = dyn_cast<BinaryOperator>(&inst))
            if (binOp->getType()->isIntegerTy())
                // ... the activeBits computation follows

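Next, activeBits is computed depending on the instruction:
if it is an and instruction whose second operand is a constant, activeBits is the active-bit count of that constant;
for any instruction other than and, the pass asks DemandedBitsWrapperPass (through db) which bits of the result are demanded and takes their activeBits. An and whose second operand is not a constant falls into neither branch and keeps its initial activeBits value, which is not shown in this excerpt.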

if (binOp->getOpcode() == BinaryOperator::And)
{
    // and with a constant mask: only the mask's active bits can survive
    if (auto constantMask = dyn_cast<ConstantInt>(binOp->getOperand(1)))
    {
        activeBits = constantMask->getValue().getActiveBits();
    }
}
else
{
    // any other operator: ask DemandedBits which result bits are actually used
    activeBits = db.getDemandedBits(binOp).getActiveBits();
}
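For the and case, the active bits of a constant are the index of its highest set bit plus one; this is what APInt::getActiveBits returns (the bit width minus the number of leading zeros). A minimal sketch on plain uint64_t values, assuming the type is at most 64 bits wide; activeBits here is a hypothetical helper, not fcd's or LLVM's code:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: number of active bits of `value` viewed as a
// `width`-bit integer, i.e. index of the highest set bit plus one (0 for zero).
unsigned activeBits(uint64_t value, unsigned width) {
    assert(width >= 1 && width <= 64 && "sketch only handles up to 64 bits");
    if (width < 64) {
        value &= (uint64_t{1} << width) - 1;  // drop bits outside the type
    }
    unsigned bits = 0;
    while (value != 0) {
        ++bits;
        value >>= 1;
    }
    return bits;
}
```

So `and i64 %x, 0xFF` has 8 active bits, fewer than the 64 bits of its type, which makes that instruction a narrowing candidate.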

If activeBits is less than typeBits and greater than 0, the pass calls narrowDown to narrow the instruction down:

if (activeBits < typeBits && activeBits > 0)
{
    narrowDown(binOp, activeBits);
}

We won't go through narrowDown in detail here; in short, it stores the narrowed instructions in
unordered_map<Value*, SmallDenseMap<unsigned, Instruction*, 8>> resized;

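A for loop then walks over this map. For each original instruction that was narrowed to exactly one size, it inserts a zext of the narrowed instruction back to the original type right after the narrowed instruction, and redirects the original instruction's uses to that zext, except for the narrowed instruction itself and for users that are themselves keys in resized: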

for (auto& pair : resized)
{
    if (auto key = dyn_cast<Instruction>(pair.first))
    {
        auto& otherSizes = pair.second;
        // Only handle instructions that were narrowed to exactly one size.
        if (otherSizes.size() == 1)
        {
            auto toEnlarge = otherSizes.begin()->second;
            auto type = key->getType();
            // Zero-extend the narrowed result back to the original type.
            CastInst* enlarged = CastInst::Create(Instruction::ZExt, toEnlarge, type);
            enlarged->insertAfter(toEnlarge);

            // Replace almost all uses of key with the zero-extended value.
            // Note: use.set() unlinks the use from key's use list while that list
            // is being iterated; newer LLVM provides make_early_inc_range for this.
            for (Use& use : key->uses())
            {
                auto user = use.getUser();
                if (user != toEnlarge && resized.count(user) == 0)
                {
                    use.set(enlarged);
                }
            }
        }
    }
}
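The inserted zext is what keeps the rewrite sound: when only the low bits of a result matter, computing it in a narrower type and zero-extending gives the same low bits. The C++ sketch below (plain integers, hypothetical function names) checks this for an add whose result is masked to 8 bits; it does not reproduce what narrowDown itself emits, which is not covered here.

```cpp
#include <cassert>
#include <cstdint>

// Wide form: the add is done in 64 bits and only the low 8 bits are kept.
uint64_t wideAddMasked(uint64_t a, uint64_t b) {
    return (a + b) & 0xFF;
}

// Narrowed form: the add is done on 8-bit truncations of the operands,
// then zero-extended back to 64 bits (the role of the inserted ZExt).
uint64_t narrowAddZext(uint64_t a, uint64_t b) {
    uint8_t narrowSum = static_cast<uint8_t>(static_cast<uint8_t>(a) + static_cast<uint8_t>(b));
    return static_cast<uint64_t>(narrowSum);
}

// Compares both forms on a handful of boundary values.
void checkSamples() {
    const uint64_t samples[] = {0, 1, 0xFF, 0x100, 0x1234, 0xFFFFFFFFFFFFFFFFULL};
    for (uint64_t a : samples)
        for (uint64_t b : samples)
            assert(wideAddMasked(a, b) == narrowAddZext(a, b));
}
```

The zext is only a valid replacement for uses that ignore the upper bits, which is exactly what the activeBits test established before narrowDown was called.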

signext

This pass targets the following pattern, which can be replaced with a single sext instruction:

// %1 = /* i32 */
// %2 = ashr i32 %1, 31
// %3 = zext i32 %2 to i64
// %4 = shl nuw i64 %3, 32
// %5 = zext i32 %1 to i64
// %6 = or i64 %4, %5
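You can check that this pattern really computes a sign extension by evaluating it on concrete values. The sketch below follows the IR line by line with C++ integers; it assumes C++20 (or any mainstream two's-complement compiler), where converting to int32_t wraps and >> on a negative signed value is an arithmetic shift. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Evaluates the IR pattern step by step for a concrete i32 value %1.
uint64_t evalPattern(uint32_t v1) {
    // %2 = ashr i32 %1, 31: all ones if %1 is negative, else zero
    uint32_t v2 = static_cast<uint32_t>(static_cast<int32_t>(v1) >> 31);
    uint64_t v3 = v2;          // %3 = zext i32 %2 to i64
    uint64_t v4 = v3 << 32;    // %4 = shl nuw i64 %3, 32
    uint64_t v5 = v1;          // %5 = zext i32 %1 to i64
    return v4 | v5;            // %6 = or i64 %4, %5
}

// What a single `sext i32 %1 to i64` produces.
uint64_t evalSext(uint32_t v1) {
    return static_cast<uint64_t>(static_cast<int64_t>(static_cast<int32_t>(v1)));
}

// Asserts that both forms agree on boundary values.
void checkBoundaries() {
    const uint32_t samples[] = {0u, 1u, 0x7FFFFFFFu, 0x80000000u, 0xFFFFFFFFu};
    for (uint32_t s : samples)
        assert(evalPattern(s) == evalSext(s));
}
```

The upper half of %6 is the replicated sign bit produced by the ashr, and the lower half is %1 itself, which is exactly the definition of sext.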

Start with runOnFunction: it iterates over all Or instructions and hands each one to handleOrInst:

for (BasicBlock& bb : fn)
{
    for (Instruction& inst : bb)
    {
        if (inst.getOpcode() == Instruction::Or)
        {
            auto& orInst = cast<BinaryOperator>(inst);
            changed |= handleOrInst(orInst);
        }
    }
}

handleOrInst contains a chain of nested ifs that checks whether the or instruction matches the pattern above:

// shiftLeft is presumably the shl operand of orInst; its definition is not part of this excerpt.
if (auto zExtSign = dyn_cast<ZExtInst>(shiftLeft->getOperand(0)))
    if (auto shiftRight = dyn_cast<BinaryOperator>(zExtSign->getOperand(0)))
        if (shiftRight->getOpcode() == Instruction::AShr)
            if (auto shiftLeftAmountAP = dyn_cast<ConstantInt>(shiftLeft->getOperand(1)))
                if (auto shiftRightAmountAP = dyn_cast<ConstantInt>(shiftRight->getOperand(1)))
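The nested ifs amount to walking down the expression tree one operand at a time. To make the order of checks concrete, here is a self-contained C++ sketch over a made-up miniature IR (Op, Node, SignExtMatch and matchSignExt are hypothetical, not fcd types). It assumes the shl is the or's first operand and, like the excerpt, does not check the or's other operand.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical miniature IR, just rich enough for this pattern.
enum class Op { Value, Const, AShr, ZExt, Shl, Or };

struct Node {
    Op op;
    int64_t constant = 0;       // only meaningful when op == Op::Const
    const Node* lhs = nullptr;  // operand 0
    const Node* rhs = nullptr;  // operand 1
};

struct SignExtMatch {
    const Node* initialValue = nullptr;
    int64_t shiftRightAmount = 0;
    int64_t shiftLeftAmount = 0;
};

// Mirrors the nested ifs: shl -> zext -> ashr, with constant shift amounts.
bool matchSignExt(const Node& orNode, SignExtMatch& out) {
    // runOnFunction only passes Or instructions to handleOrInst.
    assert(orNode.op == Op::Or);
    const Node* shiftLeft = orNode.lhs;  // assumption: shl is operand 0
    if (shiftLeft == nullptr || shiftLeft->op != Op::Shl) return false;
    const Node* zExtSign = shiftLeft->lhs;
    if (zExtSign == nullptr || zExtSign->op != Op::ZExt) return false;
    const Node* shiftRight = zExtSign->lhs;
    if (shiftRight == nullptr || shiftRight->op != Op::AShr) return false;
    const Node* leftAmount = shiftLeft->rhs;
    const Node* rightAmount = shiftRight->rhs;
    if (leftAmount == nullptr || leftAmount->op != Op::Const) return false;
    if (rightAmount == nullptr || rightAmount->op != Op::Const) return false;
    out.initialValue = shiftRight->lhs;
    out.shiftRightAmount = rightAmount->constant;
    out.shiftLeftAmount = leftAmount->constant;
    return true;
}
```

Building the pattern from the top of this section and calling matchSignExt on its or node yields initialValue = %1, shiftRightAmount = 31 and shiftLeftAmount = 32.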

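It then derives the expected widths from the shift amounts: predictedInitialWidth = shiftRightAmount + 1 and predictedFinalWidth = predictedInitialWidth + shiftLeftAmount, and compares them with the actual initialWidth and finalWidth. If a prediction exceeds the actual width, the pattern cannot be a sign extension and the function gives up.
If initialWidth is greater than predictedInitialWidth, it also inserts a Trunc to shrink the value first: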

auto initialValue = shiftRight->getOperand(0);

// This should be (bit length of original int) - 1.
auto shiftRightAmount = shiftRightAmountAP->getLimitedValue();

// This should be (extended length) - (original length).
auto shiftLeftAmount = shiftLeftAmountAP->getLimitedValue();

auto predictedInitialWidth = shiftRightAmount + 1;
auto predictedFinalWidth = predictedInitialWidth + shiftLeftAmount;

auto initialWidth = initialValue->getType()->getIntegerBitWidth();
auto finalWidth = orInst.getType()->getIntegerBitWidth();

if (predictedInitialWidth > initialWidth || predictedFinalWidth > finalWidth)
{
    // Sign extension doesn't make sense.
    assert(false);
    return false;
}

// Insert trunc/ext as necessary to simplify pattern next to orInst.
if (predictedInitialWidth < initialWidth)
{
    auto truncatedType = Type::getIntNTy(orInst.getContext(), static_cast<unsigned>(predictedInitialWidth));
    initialValue = CastInst::Create(Instruction::Trunc, initialValue, truncatedType, "", &orInst);
}
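Plugging in numbers makes this width bookkeeping easier to follow. The sketch below repeats the same arithmetic on plain integers (Widths and computeWidths are hypothetical names); where handleOrInst asserts and returns false, the sketch only reports valid = false.

```cpp
#include <cassert>
#include <cstdint>

// Width bookkeeping from handleOrInst (field names follow the source).
struct Widths {
    uint64_t predictedInitialWidth;
    uint64_t predictedFinalWidth;
    bool valid;       // false where handleOrInst would bail out
    bool needsTrunc;  // predictedInitialWidth < initialWidth
};

Widths computeWidths(uint64_t shiftRightAmount, uint64_t shiftLeftAmount,
                     unsigned initialWidth, unsigned finalWidth) {
    assert(initialWidth > 0 && finalWidth > 0);
    Widths w{};
    w.predictedInitialWidth = shiftRightAmount + 1;
    w.predictedFinalWidth = w.predictedInitialWidth + shiftLeftAmount;
    w.valid = !(w.predictedInitialWidth > initialWidth ||
                w.predictedFinalWidth > finalWidth);
    w.needsTrunc = w.valid && w.predictedInitialWidth < initialWidth;
    return w;
}
```

For the canonical pattern (i32 input, ashr by 31, shl by 32, i64 result) this gives predicted widths 32 and 64 and no Trunc; an ashr by 15 with a shl by 16 on the same types gives 16 and 32, so a Trunc to i16 would be inserted first.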

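Finally it inserts a SExt to predictedFinalWidth. If predictedFinalWidth is still narrower than the type of the or instruction (this is independent of whether a Trunc was inserted), it adds a ZExt on top. All uses of the or instruction are then replaced with the result: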

auto extendedType = Type::getIntNTy(orInst.getContext(), static_cast<unsigned>(predictedFinalWidth));
auto extended = CastInst::Create(Instruction::SExt, initialValue, extendedType, "", &orInst);
if (predictedFinalWidth < finalWidth)
{
    extended = CastInst::Create(Instruction::ZExt, extended, orInst.getType(), "", &orInst);
}

orInst.replaceAllUsesWith(extended);
return true;
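Putting the pieces together, the casts emitted in front of the or depend only on the four widths. A hypothetical helper that spells out the resulting cast chain as text, assuming the bail-out check has already passed:

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: describes the casts handleOrInst emits before orInst.
std::string plannedCasts(unsigned initialWidth, unsigned finalWidth,
                         unsigned predictedInitialWidth, unsigned predictedFinalWidth) {
    assert(predictedInitialWidth <= initialWidth && predictedFinalWidth <= finalWidth);
    std::string plan;
    unsigned width = initialWidth;
    // Trunc only when the ashr amount says the original int was narrower.
    if (predictedInitialWidth < initialWidth) {
        plan += "trunc i" + std::to_string(width) + " to i" +
                std::to_string(predictedInitialWidth) + "; ";
        width = predictedInitialWidth;
    }
    // The SExt is always emitted.
    plan += "sext i" + std::to_string(width) + " to i" + std::to_string(predictedFinalWidth);
    // ZExt only when the sext result is still narrower than the or's type.
    if (predictedFinalWidth < finalWidth) {
        plan += "; zext i" + std::to_string(predictedFinalWidth) + " to i" +
                std::to_string(finalWidth);
    }
    return plan;
}
```

With the canonical widths the whole pattern collapses to `sext i32 to i64`; with predicted widths 16 and 32 on an i32 to i64 pattern, the chain becomes `trunc i32 to i16; sext i16 to i32; zext i32 to i64`.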