2
// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
3
// Use of this source code is governed by a BSD-style license that can be
4
// found in the LICENSE file.
7
#include "compiler/ForLoopUnroll.h"
11
// Tree traverser whose visitLoop() checks the shape of a for-loop's init
// declaration and then sets the unroll flag on the loop node.
// NOTE(review): the class name suggests only integer-indexed loops are meant
// to be marked, but no guard around setUnrollFlag() is visible in this
// excerpt -- confirm against the complete file.
class IntegerForLoopUnrollMarker : public TIntermTraverser {
11
// Called for each loop node reached during traversal.
virtual bool visitLoop(Visit, TIntermLoop* node)
14
// This runs after the ValidateLimitations pass, so the ASSERTs below are
// sanity checks that are not expected to fire.
16
// See ValidateLimitations::validateForLoopInit().
18
ASSERT(node->getType() == ELoopFor);
20
ASSERT(node->getInit());
21
// The init must be a declaration aggregate...
TIntermAggregate* decl = node->getInit()->getAsAggregate();
22
ASSERT(decl && decl->getOp() == EOpDeclaration);
23
// ...holding exactly one declarator...
TIntermSequence& declSeq = decl->getSequence();
24
ASSERT(declSeq.size() == 1);
25
// ...which is an initialization (binary EOpInitialize node).
TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
26
ASSERT(declInit && declInit->getOp() == EOpInitialize);
27
ASSERT(declInit->getLeft());
28
// Left operand of the initialization is the loop index symbol.
// NOTE(review): symbol is dereferenced below without a visible null check.
TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
29
TBasicType type = symbol->getBasicType();
31
ASSERT(type == EbtInt || type == EbtFloat);
32
// Mark the loop so that later code (see the getUnrollFlag() assertion in
// ForLoopUnroll::FillLoopIndexInfo) treats it as a loop to unroll.
node->setUnrollFlag(true);
40
} // anonymous namespace
42
// Populates |info| from a for-loop node that has been flagged for unrolling.
//
// Fields written: id (loop index symbol id), initValue and currentValue (the
// constant initializer), incrementValue (from getLoopIncrement()), stopValue
// (constant right operand of the loop condition) and op (the condition's
// operator).
//
// |node| must be a for-loop with its unroll flag set, whose init is a single
// int declaration with a constant initializer and whose condition is a binary
// expression with a constant right operand; these are only ASSERTed.
void ForLoopUnroll::FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info)
44
ASSERT(node->getType() == ELoopFor);
45
ASSERT(node->getUnrollFlag());
47
// --- Loop index declaration: "<type> <symbol> = <constant>" ---
TIntermNode* init = node->getInit();
49
TIntermAggregate* decl = init->getAsAggregate();
50
ASSERT((decl != NULL) && (decl->getOp() == EOpDeclaration));
51
TIntermSequence& declSeq = decl->getSequence();
52
ASSERT(declSeq.size() == 1);
53
TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
54
ASSERT((declInit != NULL) && (declInit->getOp() == EOpInitialize));
55
TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
56
ASSERT(symbol != NULL);
57
// Only integer loop indices are handled here.
ASSERT(symbol->getBasicType() == EbtInt);
59
info.id = symbol->getId();
61
// The initializer must be a constant so its value can be read directly.
ASSERT(declInit->getRight() != NULL);
62
TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion();
63
ASSERT(initNode != NULL);
65
info.initValue = evaluateIntConstant(initNode);
66
// The index starts out at its initial value.
info.currentValue = info.initValue;
68
// --- Loop condition: "<index> <relational-op> <constant>" ---
TIntermNode* cond = node->getCondition();
70
TIntermBinary* binOp = cond->getAsBinaryNode();
71
ASSERT(binOp != NULL);
72
ASSERT(binOp->getRight() != NULL);
73
ASSERT(binOp->getRight()->getAsConstantUnion() != NULL);
75
info.incrementValue = getLoopIncrement(node);
76
// The stop value is the constant on the right-hand side of the condition.
info.stopValue = evaluateIntConstant(
77
binOp->getRight()->getAsConstantUnion());
78
info.op = binOp->getOp();
81
// Advances the loop index at the top of mLoopIndexStack (its last element)
// by adding that entry's incrementValue to its currentValue.
// The stack must not be empty.
void ForLoopUnroll::Step()
83
ASSERT(mLoopIndexStack.size() > 0);
84
// Take a reference so the update is applied to the stack entry itself.
TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
85
info.currentValue += info.incrementValue;
88
// Compares currentValue of the loop index at the top of mLoopIndexStack
// against its stopValue and returns the result of that comparison.
// The stack must not be empty.
// NOTE(review): the switch header and most case labels are not visible in
// this excerpt; presumably the comparison is selected by info.op (the
// operator recorded by FillLoopIndexInfo) -- confirm against the full file.
bool ForLoopUnroll::SatisfiesLoopCondition()
90
ASSERT(mLoopIndexStack.size() > 0);
91
TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
92
// Relational operator is one of: > >= < <= == or !=.
95
return (info.currentValue == info.stopValue);
97
return (info.currentValue != info.stopValue);
99
return (info.currentValue < info.stopValue);
101
return (info.currentValue > info.stopValue);
102
case EOpLessThanEqual:
103
return (info.currentValue <= info.stopValue);
104
case EOpGreaterThanEqual:
105
return (info.currentValue >= info.stopValue);
112
// Scans mLoopIndexStack for an entry whose id equals |symbol|'s id, i.e.
// checks whether |symbol| is one of the loop indices currently on the stack.
// NOTE(review): the return statements are not visible in this excerpt;
// presumably true on a match and false otherwise -- confirm.
bool ForLoopUnroll::NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol)
114
for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
115
i != mLoopIndexStack.end();
117
if (i->id == symbol->getId())
123
// Returns the currentValue of the first entry in mLoopIndexStack (searching
// from begin()) whose id equals |symbol|'s id.
// NOTE(review): what happens when no entry matches is not visible in this
// excerpt -- confirm against the full file.
int ForLoopUnroll::GetLoopIndexValue(TIntermSymbol* symbol)
125
for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
126
i != mLoopIndexStack.end();
128
if (i->id == symbol->getId())
129
return i->currentValue;
135
// Appends |info| to the end of mLoopIndexStack, making it the entry that
// Step() and SatisfiesLoopCondition() operate on.
void ForLoopUnroll::Push(TLoopIndexInfo& info)
137
mLoopIndexStack.push_back(info);
140
// Removes the last entry of mLoopIndexStack.
// No emptiness check is performed here.
void ForLoopUnroll::Pop()
142
mLoopIndexStack.pop_back();
146
// Runs an IntegerForLoopUnrollMarker traversal over |root|.
// NOTE(review): the parameter list is truncated in this excerpt; |root| is
// presumably the function's tree-node parameter -- confirm.
void ForLoopUnroll::MarkForLoopsWithIntegerIndicesForUnrolling(
151
IntegerForLoopUnrollMarker marker;
152
root->traverse(&marker);
155
// Computes the per-iteration change of the loop index from the for-loop's
// expression (node->getExpression()).
// For the binary forms visible below, the result is the constant right
// operand, negated in the second of the two branches.
// NOTE(review): the switch header, the case labels of the binary branches and
// the unary (++/--) branches' assignments are not visible in this excerpt.
int ForLoopUnroll::getLoopIncrement(TIntermLoop* node)
157
TIntermNode* expr = node->getExpression();
158
ASSERT(expr != NULL);
159
// for expression has one of the following forms:
// (the list is incomplete in this excerpt; the switch below also has cases
// for post/pre increment and decrement)
162
// loop_index += constant_expression
163
// loop_index -= constant_expression
166
// The last two forms are not specified in the spec, but I am assuming
// NOTE(review): the remainder of the sentence above is missing here, and
// "the last two forms" refers to list entries that are not visible.
168
// The expression is either a unary node or, failing that, a binary node.
TIntermUnary* unOp = expr->getAsUnaryNode();
169
TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
171
TOperator op = EOpNull;
172
TIntermConstantUnion* incrementNode = NULL;
175
} else if (binOp != NULL) {
177
// Binary form: the step amount is the constant right operand.
ASSERT(binOp->getRight() != NULL);
178
incrementNode = binOp->getRight()->getAsConstantUnion();
179
ASSERT(incrementNode != NULL);
183
// The operator is one of: ++ -- += -=.
185
case EOpPostIncrement:
186
case EOpPreIncrement:
187
ASSERT((unOp != NULL) && (binOp == NULL));
190
case EOpPostDecrement:
191
case EOpPreDecrement:
192
ASSERT((unOp != NULL) && (binOp == NULL));
196
ASSERT((unOp == NULL) && (binOp != NULL));
197
increment = evaluateIntConstant(incrementNode);
200
ASSERT((unOp == NULL) && (binOp != NULL));
201
// Negated constant: this branch steps the index downwards.
increment = - evaluateIntConstant(incrementNode);
210
// Returns the integer stored in a constant node: the result of getIConst()
// on the object that node->getUnionArrayPointer() points to.
// |node| and its union array pointer must be non-null (ASSERTed only).
int ForLoopUnroll::evaluateIntConstant(TIntermConstantUnion* node)
212
ASSERT((node != NULL) && (node->getUnionArrayPointer() != NULL));
213
return node->getUnionArrayPointer()->getIConst();