/*
Copyright (c) 2010-2023, Intel Corporation
SPDX-License-Identifier: BSD-3-Clause
*/
/** @file stmt.cpp
@brief File with definitions of classes related to statements in the language
*/
#include "stmt.h"
#include "builtins-info.h"
#include "ctx.h"
#include "expr.h"
#include "func.h"
#include "llvmutil.h"
#include "module.h"
#include "sym.h"
#include "type.h"
#include "util.h"
#include <algorithm>
#include <iterator>
#include <map>
#include <sstream>
#include <stdio.h>
#include <llvm/IR/CallingConv.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Metadata.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Value.h>
#include <llvm/Support/raw_ostream.h>
using namespace ispc;
///////////////////////////////////////////////////////////////////////////
// Stmt
Stmt *Stmt::Optimize() { return this; }
void Stmt::SetLoopAttribute(std::pair<Globals::pragmaUnrollType, int> lAttr) {
Error(pos, "Illegal pragma - expected a loop to follow '#pragma unroll/nounroll'.");
}
///////////////////////////////////////////////////////////////////////////
// ExprStmt
ExprStmt::ExprStmt(Expr *e, SourcePos p) : Stmt(p, ExprStmtID) { expr = e; }
void ExprStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
ctx->SetDebugPos(pos);
if (expr)
expr->GetValue(ctx);
}
Stmt *ExprStmt::TypeCheck() { return this; }
void ExprStmt::Print(Indent &indent) const {
indent.PrintLn("ExprStmt", pos);
indent.pushSingle();
if (expr != nullptr) {
expr->Print(indent);
} else {
indent.Print("<NULL EXPR>\n");
indent.Done();
}
indent.Done();
}
int ExprStmt::EstimateCost() const { return 0; }
ExprStmt *ExprStmt::Instantiate(TemplateInstantiation &templInst) const {
return new ExprStmt(expr ? expr->Instantiate(templInst) : nullptr, pos);
}
///////////////////////////////////////////////////////////////////////////
// DeclStmt
DeclStmt::DeclStmt(const std::vector<VariableDeclaration> &v, SourcePos p) : Stmt(p, DeclStmtID), vars(v) {}
static bool lHasUnsizedArrays(const Type *type) {
const ArrayType *at = CastType<ArrayType>(type);
if (at == nullptr)
return false;
if (at->GetElementCount() == 0)
return true;
else
return lHasUnsizedArrays(at->GetElementType());
}
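// Illustrative example, assuming typical ispc array syntax: the check above
// distinguishes
//     uniform float a[] = { 1, 2, 3 };   // OK: the initializer sizes the array
//     uniform float b[];                 // still unsized -> error reported below
// where the error for the second case is issued from DeclStmt::EmitCode().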
void DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
for (unsigned int i = 0; i < vars.size(); ++i) {
Symbol *sym = vars[i].sym;
AssertPos(pos, sym != nullptr);
if (sym->type == nullptr)
continue;
Expr *initExpr = vars[i].init;
// Now that we're in the thick of emitting code, it's easy for us
// to find out the level of nesting of varying control flow we're
// in at this declaration. So we can finally set that
// Symbol::varyingCFDepth variable.
// @todo It's disgusting to be doing this here.
sym->varyingCFDepth = ctx->VaryingCFDepth();
ctx->SetDebugPos(sym->pos);
// If it's an array that was declared without a size but has an
// initializer list, then use the number of elements in the
// initializer list to finally set the array's size.
sym->type = ArrayType::SizeUnsizedArrays(sym->type, initExpr);
if (sym->type == nullptr)
continue;
if (lHasUnsizedArrays(sym->type)) {
Error(pos, "Illegal to declare an unsized array variable without "
"providing an initializer expression to set its size.");
continue;
}
// References must have initializer expressions as well.
if (IsReferenceType(sym->type) == true) {
if (initExpr == nullptr) {
Error(sym->pos,
"Must provide initializer for reference-type "
"variable \"%s\".",
sym->name.c_str());
continue;
}
if (IsReferenceType(initExpr->GetType()) == false) {
const Type *initLVType = initExpr->GetLValueType();
if (initLVType == nullptr) {
Error(initExpr->pos,
"Initializer for reference-type variable "
"\"%s\" must have an lvalue type.",
sym->name.c_str());
continue;
}
if (initLVType->IsUniformType() == false) {
Error(initExpr->pos,
"Initializer for reference-type variable "
"\"%s\" must have a uniform lvalue type.",
sym->name.c_str());
continue;
}
}
}
llvm::Type *llvmType = sym->type->LLVMType(g->ctx);
if (llvmType == nullptr) {
AssertPos(pos, m->errorCount > 0);
return;
}
if (sym->storageClass == SC_STATIC) {
// For static variables, we need a compile-time constant value
// for its initializer; if there's no initializer, we use a
// zero value.
llvm::Constant *cinit = nullptr;
if (initExpr != nullptr) {
if (PossiblyResolveFunctionOverloads(initExpr, sym->type) == false)
continue;
// FIXME: we only need this for function pointers; it was
// already done for atomic types and enums in
// DeclStmt::TypeCheck()...
if (llvm::dyn_cast<ExprList>(initExpr) == nullptr) {
initExpr = TypeConvertExpr(initExpr, sym->type, "initializer");
// FIXME: and this is only needed to re-establish
// constant-ness so that GetConstant below works for
// constant arithmetic expressions...
initExpr = ::Optimize(initExpr);
}
std::pair<llvm::Constant *, bool> cinitPair = initExpr->GetConstant(sym->type);
cinit = cinitPair.first;
if (cinit == nullptr)
Error(initExpr->pos,
"Initializer for static variable "
"\"%s\" must be a constant.",
sym->name.c_str());
}
if (cinit == nullptr)
cinit = llvm::Constant::getNullValue(llvmType);
// Allocate space for the static variable in global scope, so
// that it persists across function calls
sym->storageInfo = new AddressInfo(
new llvm::GlobalVariable(
*m->module, llvmType, sym->type->IsConstType(), llvm::GlobalValue::InternalLinkage, cinit,
llvm::Twine("static.") + llvm::Twine(sym->pos.first_line) + llvm::Twine(".") + sym->name.c_str()),
llvmType);
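// Worked example (illustrative): a declaration like
// "static uniform int hits = 0;" appearing on source line 42 becomes an
// internal-linkage module global named "static.42.hits" initialized with
// the constant 0, so its value persists across calls to the function.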
// Tell the FunctionEmitContext about the variable
ctx->EmitVariableDebugInfo(sym);
} else {
// For non-static variables, allocate storage on the stack
sym->storageInfo = ctx->AllocaInst(sym->type, sym->name.c_str());
// Tell the FunctionEmitContext about the variable; must do
// this before the initializer stuff.
ctx->EmitVariableDebugInfo(sym);
if (initExpr == nullptr && sym->type->IsConstType())
Error(sym->pos,
"Missing initializer for const variable "
"\"%s\".",
sym->name.c_str());
// And then get it initialized...
sym->parentFunction = ctx->GetFunction();
InitSymbol(sym->storageInfo, sym->type, initExpr, ctx, sym->pos);
}
}
}
Stmt *DeclStmt::Optimize() {
for (unsigned int i = 0; i < vars.size(); ++i) {
Expr *init = vars[i].init;
if (init != nullptr && llvm::dyn_cast<ExprList>(init) == nullptr) {
// If the variable is const-qualified, after we've optimized
// the initializer expression, see if we have a ConstExpr. If
// so, save it in Symbol::constValue where it can be used in
// optimizing later expressions that have this symbol in them.
// Note that there are cases where the expression may be
// constant but where we don't have a ConstExpr; an example is
// const arrays--the ConstExpr implementation just can't
// represent an array of values.
//
// All this is fine in terms of the code that's generated in
// the end (LLVM's constant folding stuff is good), but it
// means that the ispc compiler's ability to reason about what
// is definitely a compile-time constant for things like
// computing array sizes from non-trivial expressions is
// consequently limited.
Symbol *sym = vars[i].sym;
if (sym->type && sym->type->IsConstType() && Type::Equal(init->GetType(), sym->type))
sym->constValue = llvm::dyn_cast<ConstExpr>(init);
}
}
return this;
}
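// Illustrative example for the constValue optimization above, assuming
// typical ispc syntax: after
//     const uniform int n = 2 + 2;
// the optimized initializer is a ConstExpr, so Symbol::constValue records 4
// and a later use of n in an array-size expression can be folded to a
// compile-time constant.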
// Do type conversion if needed and check that an array is not being
// initialized with another array (array assignment is not allowed).
// Do this recursively to handle brace initialization, which may itself
// contain nested brace initialization.
static bool checkInit(const Type *type, Expr **init) {
bool encounteredError = false;
if (type && type->IsDependentType()) {
return false;
}
if (*init == nullptr) {
// Return error if an initializer expression is malformed, e.g., undeclared symbol is used.
return true;
}
// get the right type for stuff like const float foo = 2; so that
// the int->float type conversion is in there and we don't return
// an int as the constValue later...
if (CastType<AtomicType>(type) != nullptr || CastType<EnumType>(type) != nullptr) {
// If it's an expr list with an atomic type, we'll later issue
// an error. Need to leave vars[i].init as is in that case so
// it is in fact caught later, though.
if (llvm::dyn_cast<ExprList>(*init) == nullptr) {
const Type *t = (*init) ? (*init)->GetType() : nullptr;
if (t && t->IsDependentType()) {
return false;
}
*init = TypeConvertExpr(*init, type, "initializer");
if (*init == nullptr)
encounteredError = true;
}
} else if (CastType<ArrayType>(type) != nullptr && llvm::dyn_cast<ExprList>(*init) == nullptr) {
encounteredError = true;
Error((*init)->pos, "Array initializer must be an initializer list");
} else if (CastType<StructType>(type) != nullptr && llvm::dyn_cast<ExprList>(*init) != nullptr) {
const StructType *st = CastType<StructType>(type);
ExprList *el = llvm::dyn_cast<ExprList>(*init);
int elt_count = st->GetElementCount() < el->exprs.size() ? st->GetElementCount() : el->exprs.size();
for (int i = 0; i < elt_count; i++) {
encounteredError |= checkInit(st->GetElementType(i), &(el->exprs[i]));
}
}
return encounteredError;
}
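// Illustrative example of the recursion above, assuming typical ispc syntax:
//     struct Foo { int a[3]; float b; };
//     uniform Foo f = { { 1, 2, 3 }, 4 };
// checkInit() recurses into the struct's members: the inner { 1, 2, 3 } list
// is accepted for the array member a, and the int constant 4 is
// type-converted to float for the member b.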
Stmt *DeclStmt::TypeCheck() {
bool encounteredError = false;
for (unsigned int i = 0; i < vars.size(); ++i) {
if (vars[i].sym == nullptr) {
encounteredError = true;
continue;
}
if (vars[i].init == nullptr)
continue;
// Check an init.
encounteredError |= checkInit(vars[i].sym->type, &(vars[i].init));
}
return encounteredError ? nullptr : this;
}
void DeclStmt::Print(Indent &indent) const {
indent.PrintLn("DeclStmt", pos);
indent.pushList(vars.size());
for (unsigned int i = 0; i < vars.size(); ++i) {
indent.Print();
printf("Variable %s (%s)\n", vars[i].sym->name.c_str(), vars[i].sym->type->GetString().c_str());
if (vars[i].init != nullptr) {
indent.pushSingle();
indent.setNextLabel("init");
vars[i].init->Print(indent);
}
indent.Done();
}
indent.Done();
}
int DeclStmt::EstimateCost() const { return 0; }
DeclStmt *DeclStmt::Instantiate(TemplateInstantiation &templInst) const {
std::vector<VariableDeclaration> instVars;
for (auto &var : vars) {
Expr *instInit = var.init ? var.init->Instantiate(templInst) : nullptr;
Symbol *instSym = templInst.InstantiateSymbol(var.sym);
VariableDeclaration instDecl(instSym, instInit);
instVars.push_back(instDecl);
}
return new DeclStmt(instVars, pos);
}
///////////////////////////////////////////////////////////////////////////
// IfStmt
IfStmt::IfStmt(Expr *t, Stmt *ts, Stmt *fs, bool checkCoherence, SourcePos p)
: Stmt(p, IfStmtID), test(t), trueStmts(ts), falseStmts(fs),
doAllCheck(checkCoherence && !g->opt.disableCoherentControlFlow) {}
static void lEmitIfStatements(FunctionEmitContext *ctx, Stmt *stmts, const char *trueOrFalse) {
if (!stmts)
return;
if (llvm::dyn_cast<StmtList>(stmts) == nullptr)
ctx->StartScope();
ctx->AddInstrumentationPoint(trueOrFalse);
stmts->EmitCode(ctx);
if (llvm::dyn_cast<const StmtList>(stmts) == nullptr)
ctx->EndScope();
}
/** Returns true if the "true" block for the if statement consists of a
single 'break' statement, and the "false" block is empty. */
/*
static bool
lCanApplyBreakOptimization(Stmt *trueStmts, Stmt *falseStmts) {
if (falseStmts != nullptr) {
if (StmtList *sl = llvm::dyn_cast<StmtList>(falseStmts)) {
return (sl->stmts.size() == 0);
}
else
return false;
}
if (llvm::dyn_cast<BreakStmt>(trueStmts))
return true;
else if (StmtList *sl = llvm::dyn_cast<StmtList>(trueStmts))
return (sl->stmts.size() == 1 &&
llvm::dyn_cast<BreakStmt>(sl->stmts[0]) != nullptr);
else
return false;
}
*/
void IfStmt::EmitCode(FunctionEmitContext *ctx) const {
// First check all of the things that might happen due to errors
// earlier in compilation and bail out if needed so that we don't
// dereference nullptr pointers in the below...
if (!ctx->GetCurrentBasicBlock())
return;
if (!test)
return;
const Type *testType = test->GetType();
if (!testType)
return;
ctx->SetDebugPos(pos);
bool isUniform = testType->IsUniformType();
llvm::Value *testValue = test->GetValue(ctx);
if (testValue == nullptr)
return;
bool emulateUniform = false;
if (ctx->emitXeHardwareMask() && !isUniform) {
/* With the Xe target we generate uniform control flow but
emit varying conditions using the CM simdcf.any intrinsic. We mark the
scope as emulateUniform = true to let nested scopes know that they
should generate vector conditions before branching.
This is needed because CM does not support scalar control flow inside
SIMD control flow.
*/
isUniform = true;
emulateUniform = true;
}
if (isUniform) {
ctx->StartUniformIf(emulateUniform);
if (doAllCheck && !emulateUniform)
Warning(test->pos, "Uniform condition supplied to \"cif\" statement.");
// 'If' statements with uniform conditions are relatively
// straightforward. We evaluate the condition and then jump to
// either the 'then' or 'else' clause depending on its value.
llvm::BasicBlock *bthen = ctx->CreateBasicBlock("if_then", ctx->GetCurrentBasicBlock());
llvm::BasicBlock *belse = ctx->CreateBasicBlock("if_else", bthen);
llvm::BasicBlock *bexit = ctx->CreateBasicBlock("if_exit", belse);
// Jump to the appropriate basic block based on the value of
// the 'if' test
ctx->BranchInst(bthen, belse, testValue);
// Emit code for the 'true' case
ctx->SetCurrentBasicBlock(bthen);
lEmitIfStatements(ctx, trueStmts, "true");
if (ctx->GetCurrentBasicBlock())
ctx->BranchInst(bexit);
// Emit code for the 'false' case
ctx->SetCurrentBasicBlock(belse);
lEmitIfStatements(ctx, falseStmts, "false");
if (ctx->GetCurrentBasicBlock())
ctx->BranchInst(bexit);
// Set the active basic block to the newly-created exit block
// so that subsequent emitted code starts there.
ctx->SetCurrentBasicBlock(bexit);
ctx->EndIf();
}
/*
// Disabled for performance reasons. Change to an optional compile-time opt switch.
else if (lCanApplyBreakOptimization(trueStmts, falseStmts)) {
// If we have a simple break statement inside the 'if' and are
// under varying control flow, just update the execution mask
// directly and don't emit code for the statements. This leads to
// better code for this case--this is surprising and should be
// root-caused further, but for now this gives us performance
// benefit in this case.
ctx->SetInternalMaskAndNot(ctx->GetInternalMask(), testValue);
}
*/
else
emitVaryingIf(ctx, testValue);
}
Stmt *IfStmt::TypeCheck() {
if (test != nullptr) {
const Type *testType = test->GetType();
if (testType != nullptr) {
if (testType->IsDependentType()) {
return this;
}
bool isUniform = (testType->IsUniformType() && !g->opt.disableUniformControlFlow);
test = TypeConvertExpr(test, isUniform ? AtomicType::UniformBool : AtomicType::VaryingBool,
"\"if\" statement test");
if (test == nullptr)
return nullptr;
}
}
return this;
}
int IfStmt::EstimateCost() const {
const Type *type;
if (test == nullptr || (type = test->GetType()) == nullptr)
return 0;
return type->IsUniformType() ? COST_UNIFORM_IF : COST_VARYING_IF;
}
IfStmt *IfStmt::Instantiate(TemplateInstantiation &templInst) const {
Expr *instTestExpr = test ? test->Instantiate(templInst) : nullptr;
Stmt *instTrueStmts = trueStmts ? trueStmts->Instantiate(templInst) : nullptr;
Stmt *instFalseStmts = falseStmts ? falseStmts->Instantiate(templInst) : nullptr;
return new IfStmt(instTestExpr, instTrueStmts, instFalseStmts, doAllCheck, pos);
}
void IfStmt::Print(Indent &indent) const {
indent.PrintLn(doAllCheck ? "IfStmt DO ALL CHECK" : "IfStmt", pos);
int totalChildren = 1 + (trueStmts ? 1 : 0) + (falseStmts ? 1 : 0);
indent.pushList(totalChildren);
indent.setNextLabel("test");
test->Print(indent);
if (trueStmts) {
indent.setNextLabel("true");
trueStmts->Print(indent);
}
if (falseStmts) {
indent.setNextLabel("false");
falseStmts->Print(indent);
}
indent.Done();
}
/** Emit code to run both the true and false statements for the if test,
with the mask set appropriately before running each one.
*/
void IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask, llvm::Value *test) const {
if (trueStmts) {
ctx->SetInternalMaskAnd(oldMask, test);
lEmitIfStatements(ctx, trueStmts, "if: expr mixed, true statements");
// Under varying control flow, returns can't stop instruction
// emission, so this had better be non-nullptr...
AssertPos(ctx->GetDebugPos(), ctx->GetCurrentBasicBlock());
}
if (falseStmts) {
ctx->SetInternalMaskAndNot(oldMask, test);
lEmitIfStatements(ctx, falseStmts, "if: expr mixed, false statements");
AssertPos(ctx->GetDebugPos(), ctx->GetCurrentBasicBlock());
}
}
/** Emit code for an if test that checks the mask and the test values and
tries to be smart about jumping over code that doesn't need to be run.
*/
void IfStmt::emitVaryingIf(FunctionEmitContext *ctx, llvm::Value *ltest) const {
llvm::Value *oldMask = ctx->GetInternalMask();
if (doAllCheck) {
// We can't tell if the mask going into the if is all on at the
// compile time. Emit code to check for this and then either run
// the code for the 'all on' or the 'mixed' case depending on the
// mask's value at runtime.
llvm::BasicBlock *bAllOn = ctx->CreateBasicBlock("cif_mask_all");
llvm::BasicBlock *bMixedOn = ctx->CreateBasicBlock("cif_mask_mixed");
llvm::BasicBlock *bDone = ctx->CreateBasicBlock("cif_done");
// Jump to either bAllOn or bMixedOn, depending on the mask's value
llvm::Value *maskAllQ = ctx->All(ctx->GetFullMask());
ctx->BranchInst(bAllOn, bMixedOn, maskAllQ);
// Emit code for the 'mask all on' case
ctx->SetCurrentBasicBlock(bAllOn);
emitMaskAllOn(ctx, ltest, bDone);
// And emit code for the mixed mask case
ctx->SetCurrentBasicBlock(bMixedOn);
emitMaskMixed(ctx, oldMask, ltest, bDone);
// When done, set the current basic block to the block that the two
// paths above jump to when they're done.
ctx->SetCurrentBasicBlock(bDone);
} else if (trueStmts != nullptr || falseStmts != nullptr) {
// If there is nothing that is potentially unsafe to run with all
// lanes off in the true and false statements and if the total
// complexity of those two is relatively simple, then we'll go
// ahead and emit straightline code that runs both sides, updating
// the mask accordingly. This is useful for efficiently compiling
// things like:
//
// if (foo) x = 0;
// else ++x;
//
// Where the overhead of checking if any of the program instances wants
// to run one side or the other is more than the actual computation.
// SafeToRunWithMaskAllOff() checks to make sure that we don't do this
// for potentially dangerous code like:
//
// if (index < count) array[index] = 0;
//
// where our use of blend for conditional assignments doesn't check
// for the 'all lanes' off case.
int trueFalseCost = (::EstimateCost(trueStmts) + ::EstimateCost(falseStmts));
bool costIsAcceptable = (trueFalseCost < PREDICATE_SAFE_IF_STATEMENT_COST);
bool safeToRunWithAllLanesOff = (SafeToRunWithMaskAllOff(trueStmts) && SafeToRunWithMaskAllOff(falseStmts));
Debug(pos, "If statement: true cost %d (safe %d), false cost %d (safe %d).", ::EstimateCost(trueStmts),
(int)SafeToRunWithMaskAllOff(trueStmts), ::EstimateCost(falseStmts),
(int)SafeToRunWithMaskAllOff(falseStmts));
if (safeToRunWithAllLanesOff && (costIsAcceptable || g->opt.disableCoherentControlFlow)) {
ctx->StartVaryingIf(oldMask);
emitMaskedTrueAndFalse(ctx, oldMask, ltest);
AssertPos(pos, ctx->GetCurrentBasicBlock());
ctx->EndIf();
} else {
llvm::BasicBlock *bDone = ctx->CreateBasicBlock("if_done");
emitMaskMixed(ctx, oldMask, ltest, bDone);
ctx->SetCurrentBasicBlock(bDone);
}
}
}
/** Emits code for 'if' tests under the case where we know that the program
mask is all on going into the 'if'.
*/
void IfStmt::emitMaskAllOn(FunctionEmitContext *ctx, llvm::Value *ltest, llvm::BasicBlock *bDone) const {
// We start by explicitly storing "all on" into the mask. Note
// that this doesn't change its actual value, but doing so lets the
// compiler see what's going on so that subsequent optimizations for
// code emitted here can operate with the knowledge that the mask is
// definitely all on (until it modifies the mask itself).
AssertPos(pos, !g->opt.disableCoherentControlFlow);
if (!g->opt.disableMaskAllOnOptimizations)
ctx->SetInternalMask(LLVMMaskAllOn);
llvm::Value *oldFunctionMask = ctx->GetFunctionMask();
if (!g->opt.disableMaskAllOnOptimizations)
ctx->SetFunctionMask(LLVMMaskAllOn);
// First, check the value of the test. If it's all on, then we jump to
// a basic block that will only have code for the true case.
llvm::BasicBlock *bTestAll = ctx->CreateBasicBlock("cif_test_all");
llvm::BasicBlock *bTestNoneCheck = ctx->CreateBasicBlock("cif_test_none_check");
llvm::Value *testAllQ = ctx->All(ltest);
ctx->BranchInst(bTestAll, bTestNoneCheck, testAllQ);
// Emit code for the 'test is all true' case
ctx->SetCurrentBasicBlock(bTestAll);
ctx->StartVaryingIf(LLVMMaskAllOn);
lEmitIfStatements(ctx, trueStmts, "if: all on mask, expr all true");
ctx->EndIf();
if (ctx->GetCurrentBasicBlock() != nullptr)
// bblock may legitimately be nullptr since if there's a return stmt
// or break or continue we can actually jump and end emission since
// we know all of the lanes are following this path...
ctx->BranchInst(bDone);
// The test isn't all true. Now emit code to determine if it's all
// false, or has mixed values.
ctx->SetCurrentBasicBlock(bTestNoneCheck);
llvm::BasicBlock *bTestNone = ctx->CreateBasicBlock("cif_test_none");
llvm::BasicBlock *bTestMixed = ctx->CreateBasicBlock("cif_test_mixed");
llvm::Value *testMixedQ = ctx->Any(ltest);
ctx->BranchInst(bTestMixed, bTestNone, testMixedQ);
// Emit code for the 'test is all false' case
ctx->SetCurrentBasicBlock(bTestNone);
ctx->StartVaryingIf(LLVMMaskAllOn);
lEmitIfStatements(ctx, falseStmts, "if: all on mask, expr all false");
ctx->EndIf();
if (ctx->GetCurrentBasicBlock())
// bblock may be nullptr since if there's a return stmt or break or
// continue we can actually jump or whatever and end emission...
ctx->BranchInst(bDone);
// Finally emit code for the 'mixed true/false' case. We unavoidably
// need to run both the true and the false statements.
ctx->SetCurrentBasicBlock(bTestMixed);
ctx->StartVaryingIf(LLVMMaskAllOn);
emitMaskedTrueAndFalse(ctx, LLVMMaskAllOn, ltest);
// In this case, return/break/continue isn't allowed to jump and end
// emission.
AssertPos(pos, ctx->GetCurrentBasicBlock());
ctx->EndIf();
ctx->BranchInst(bDone);
ctx->SetCurrentBasicBlock(bDone);
ctx->SetFunctionMask(oldFunctionMask);
}
/** Emit code for an 'if' test where the lane mask is known to be mixed
on/off going into it.
*/
void IfStmt::emitMaskMixed(FunctionEmitContext *ctx, llvm::Value *oldMask, llvm::Value *ltest,
llvm::BasicBlock *bDone) const {
ctx->StartVaryingIf(oldMask);
llvm::BasicBlock *bNext = ctx->CreateBasicBlock("safe_if_after_true");
llvm::BasicBlock *bRunTrue = ctx->CreateBasicBlock("safe_if_run_true");
ctx->SetInternalMaskAnd(oldMask, ltest);
// Do any of the program instances want to run the 'true'
// block? If not, jump ahead to bNext.
llvm::Value *maskAnyTrueQ = ctx->Any(ctx->GetFullMask());
ctx->BranchInst(bRunTrue, bNext, maskAnyTrueQ);
// Emit statements for true
ctx->SetCurrentBasicBlock(bRunTrue);
if (trueStmts != nullptr)
lEmitIfStatements(ctx, trueStmts, "if: expr mixed, true statements");
AssertPos(pos, ctx->GetCurrentBasicBlock());
ctx->BranchInst(bNext);
ctx->SetCurrentBasicBlock(bNext);
// False...
llvm::BasicBlock *bRunFalse = ctx->CreateBasicBlock("safe_if_run_false");
ctx->SetInternalMaskAndNot(oldMask, ltest);
// Similarly, check to see if any of the instances want to
// run the 'false' block...
llvm::Value *maskAnyFalseQ = ctx->Any(ctx->GetFullMask());
ctx->BranchInst(bRunFalse, bDone, maskAnyFalseQ);
// Emit code for false
ctx->SetCurrentBasicBlock(bRunFalse);
if (falseStmts)
lEmitIfStatements(ctx, falseStmts, "if: expr mixed, false statements");
AssertPos(pos, ctx->GetCurrentBasicBlock());
ctx->BranchInst(bDone);
ctx->SetCurrentBasicBlock(bDone);
ctx->EndIf();
}
///////////////////////////////////////////////////////////////////////////
// DoStmt
struct VaryingBCCheckInfo {
VaryingBCCheckInfo() {
varyingControlFlowDepth = 0;
foundVaryingBreakOrContinue = false;
}
int varyingControlFlowDepth;
bool foundVaryingBreakOrContinue;
};
/** Returns true if the given node is an 'if' statement where the test
condition has varying type. */
static bool lIsVaryingFor(ASTNode *node) {
IfStmt *ifStmt;
if ((ifStmt = llvm::dyn_cast<IfStmt>(node)) != nullptr && ifStmt->test != nullptr) {
const Type *type = ifStmt->test->GetType();
return (type != nullptr && type->IsVaryingType());
} else
return false;
}
/** Preorder callback function for checking for varying breaks or
continues. */
static bool lVaryingBCPreFunc(ASTNode *node, void *d) {
VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d;
// We found a break or continue statement; if we're under varying
// control flow, then bingo.
if ((llvm::dyn_cast<BreakStmt>(node) != nullptr || llvm::dyn_cast<ContinueStmt>(node) != nullptr) &&
info->varyingControlFlowDepth > 0) {
info->foundVaryingBreakOrContinue = true;
return false;
}
// Update the count of the nesting depth of varying control flow if
// this is an if statement with a varying condition.
if (lIsVaryingFor(node))
++info->varyingControlFlowDepth;
if (llvm::dyn_cast<ForStmt>(node) != nullptr || llvm::dyn_cast<DoStmt>(node) != nullptr ||
llvm::dyn_cast<ForeachStmt>(node) != nullptr)
// Don't recurse into these guys, since we don't care about varying
// breaks or continues within them...
return false;
else
return true;
}
/** Postorder callback function for checking for varying breaks or
continues; decrement the varying control flow depth after the node's
children have been processed, if this is a varying if statement. */
static ASTNode *lVaryingBCPostFunc(ASTNode *node, void *d) {
VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d;
if (lIsVaryingFor(node))
--info->varyingControlFlowDepth;
return node;
}
/** Given a statement, walk through it to see if there is a 'break' or
'continue' statement inside its children, under varying control
flow. We need to detect this case for loops since what might otherwise
look like a 'uniform' loop needs to have code emitted to do all of the
lane management stuff if this is the case.
*/
static bool lHasVaryingBreakOrContinue(Stmt *stmt) {
VaryingBCCheckInfo info;
WalkAST(stmt, lVaryingBCPreFunc, lVaryingBCPostFunc, &info);
return info.foundVaryingBreakOrContinue;
}
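// Illustrative case this detects, assuming typical ispc syntax:
//     for (uniform int i = 0; i < n; ++i) {
//         if (x[programIndex] > 0)   // varying condition
//             break;
//     }
// The break runs under varying control flow, so DoStmt/ForStmt::TypeCheck()
// below force the loop to be emitted with per-lane mask management even
// though the loop test itself is uniform.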
DoStmt::DoStmt(Expr *t, Stmt *s, bool cc, SourcePos p)
: Stmt(p, DoStmtID), testExpr(t), bodyStmts(s), doCoherentCheck(cc && !g->opt.disableCoherentControlFlow) {}
void DoStmt::EmitCode(FunctionEmitContext *ctx) const {
// Check for things that could be nullptr due to earlier errors during
// compilation.
if (!ctx->GetCurrentBasicBlock())
return;
if (!testExpr || !testExpr->GetType())
return;
bool uniformTest = testExpr->GetType()->IsUniformType();
if (uniformTest && doCoherentCheck)
Warning(testExpr->pos, "Uniform condition supplied to \"cdo\" "
"statement.");
llvm::BasicBlock *bloop = ctx->CreateBasicBlock("do_loop", ctx->GetCurrentBasicBlock());
llvm::BasicBlock *btest = ctx->CreateBasicBlock("do_test", bloop);
llvm::BasicBlock *bexit = ctx->CreateBasicBlock("do_exit", btest);
bool emulateUniform = false;
llvm::Instruction *branchInst = nullptr;
if (ctx->emitXeHardwareMask() && !uniformTest) {
/* With the Xe target we generate uniform control flow but
emit varying conditions using the CM simdcf.any intrinsic. We mark the
scope as emulateUniform = true to let nested scopes know that they
should generate vector conditions before branching.
This is needed because CM does not support scalar control flow inside
SIMD control flow.
*/
uniformTest = true;
emulateUniform = true;
}
ctx->StartLoop(bexit, btest, uniformTest, emulateUniform);
// Start by jumping into the loop body
ctx->BranchInst(bloop);
// And now emit code for the loop body
ctx->SetCurrentBasicBlock(bloop);
ctx->SetBlockEntryMask(ctx->GetFullMask());
ctx->SetDebugPos(pos);
// FIXME: the StmtList::EmitCode() method starts/stops a new
// scope around the statements in the list. So if the body is just a
// single statement (and thus not a statement list), we need a new
// scope here, but we don't want two scopes in the StmtList case.
if (!bodyStmts || !llvm::dyn_cast<StmtList>(bodyStmts))
ctx->StartScope();
ctx->AddInstrumentationPoint("do loop body");
if (doCoherentCheck && !uniformTest) {
// Check to see if the mask is all on
llvm::BasicBlock *bAllOn = ctx->CreateBasicBlock("do_all_on");
llvm::BasicBlock *bMixed = ctx->CreateBasicBlock("do_mixed");
ctx->BranchIfMaskAll(bAllOn, bMixed);
// If so, emit code for the 'mask all on' case. In particular,
// explicitly set the mask to 'all on' (see rationale in
// IfStmt::emitCoherentTests()), and then emit the code for the
// loop body.
ctx->SetCurrentBasicBlock(bAllOn);
if (!g->opt.disableMaskAllOnOptimizations)
ctx->SetInternalMask(LLVMMaskAllOn);
llvm::Value *oldFunctionMask = ctx->GetFunctionMask();
if (!g->opt.disableMaskAllOnOptimizations)
ctx->SetFunctionMask(LLVMMaskAllOn);
if (bodyStmts)
bodyStmts->EmitCode(ctx);
AssertPos(pos, ctx->GetCurrentBasicBlock());
ctx->SetFunctionMask(oldFunctionMask);
ctx->BranchInst(btest);
// The mask is mixed. Just emit the code for the loop body.
ctx->SetCurrentBasicBlock(bMixed);
if (bodyStmts)
bodyStmts->EmitCode(ctx);
AssertPos(pos, ctx->GetCurrentBasicBlock());
ctx->BranchInst(btest);
} else {
// Otherwise just emit the code for the loop body. The current
// mask is good.
if (bodyStmts)
bodyStmts->EmitCode(ctx);
if (ctx->GetCurrentBasicBlock()) {
ctx->BranchInst(btest);
}
}
// End the scope we started above, if needed.
if (!bodyStmts || !llvm::dyn_cast<StmtList>(bodyStmts))
ctx->EndScope();
// Now emit code for the loop test.
ctx->SetCurrentBasicBlock(btest);
// First, emit code to restore the mask value for any lanes that
// executed a 'continue' during the current loop before we go and emit
// the code for the test. This is only necessary for varying loops;
// 'uniform' loops just jump when they hit a continue statement and
// don't mess with the mask.
if (!uniformTest) {
ctx->RestoreContinuedLanes();
ctx->ClearBreakLanes();
}
llvm::Value *testValue = testExpr->GetValue(ctx);
if (!testValue)
return;
if (uniformTest) {
// For the uniform case, just jump to the top of the loop or the
// exit basic block depending on the value of the test.
branchInst = ctx->BranchInst(bloop, bexit, testValue);
ctx->setLoopUnrollMetadata(branchInst, loopAttribute, pos);
} else {
// For the varying case, update the mask based on the value of the
// test. If any program instances still want to be running, jump
// to the top of the loop. Otherwise, jump out.
llvm::Value *mask = ctx->GetInternalMask();
ctx->SetInternalMaskAnd(mask, testValue);
ctx->BranchIfMaskAny(bloop, bexit);
}
// ...and we're done. Set things up for subsequent code to be emitted
// in the right basic block.
ctx->SetCurrentBasicBlock(bexit);
ctx->EndLoop();
}
Stmt *DoStmt::TypeCheck() {
const Type *testType;
if (testExpr != nullptr && (testType = testExpr->GetType()) != nullptr) {
if (testType->IsDependentType()) {
return this;
}
// Should the test condition for the loop be uniform or varying?
// It can be uniform only if three conditions are met:
//
// - First and foremost, the type of the test condition must be
// uniform.
//
// - Second, the user must not have set the option that disables
// uniform control flow.
//
// - Third, and most subtly, there must not be any break or
// continue statements inside the loop that are within the scope
// of a 'varying' if statement. If there are, then we type cast
// the test to be 'varying', so that the code generated for the
// loop includes masking stuff, so that we can track which lanes
// actually want to be running, accounting for breaks/continues.
//
bool uniformTest =
(testType->IsUniformType() && !g->opt.disableUniformControlFlow && !lHasVaryingBreakOrContinue(bodyStmts));
testExpr = TypeConvertExpr(testExpr, uniformTest ? AtomicType::UniformBool : AtomicType::VaryingBool,
"\"do\" statement");
}
return this;
}
static bool lLoopStmtUniformTest(Expr *expr, Stmt *stmts) {
if (expr) {
const Type *type = expr->GetType();
Assert(type);
return type->IsUniformType();
} else {
return (!g->opt.disableUniformControlFlow && !lHasVaryingBreakOrContinue(stmts));
}
}
void DoStmt::SetLoopAttribute(std::pair<Globals::pragmaUnrollType, int> lAttr) {
if (loopAttribute.first != Globals::pragmaUnrollType::none)
Error(pos, "Multiple '#pragma unroll/nounroll' directives used.");
if (lLoopStmtUniformTest(testExpr, bodyStmts)) {
loopAttribute = lAttr;
} else {
Warning(pos, "'#pragma unroll/nounroll' ignored - not supported for varying do loop.");
}
}
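// Illustrative usage of the pragma handled above, assuming the syntax named
// in the diagnostics:
//     #pragma unroll
//     do { ... } while (uniformCond);   // attribute recorded in loopAttribute
//     #pragma unroll
//     do { ... } while (varyingCond);   // warning above: ignored for varying do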
int DoStmt::EstimateCost() const {
return lLoopStmtUniformTest(testExpr, bodyStmts) ? COST_UNIFORM_LOOP : COST_VARYING_LOOP;
}
DoStmt *DoStmt::Instantiate(TemplateInstantiation &templInst) const {
Expr *instTestExpr = testExpr ? testExpr->Instantiate(templInst) : nullptr;
Stmt *instBodyStmts = bodyStmts ? bodyStmts->Instantiate(templInst) : nullptr;
return new DoStmt(instTestExpr, instBodyStmts, doCoherentCheck, pos);
}
void DoStmt::Print(Indent &indent) const {
indent.PrintLn("DoStmt", pos);
int totalChildren = (testExpr ? 1 : 0) + (bodyStmts ? 1 : 0);
indent.pushList(totalChildren);
if (testExpr) {
indent.setNextLabel("test");
testExpr->Print(indent);
}
if (bodyStmts) {
indent.setNextLabel("body");
bodyStmts->Print(indent);
}
indent.Done();
}
///////////////////////////////////////////////////////////////////////////
// ForStmt
ForStmt::ForStmt(Stmt *i, Expr *t, Stmt *s, Stmt *st, bool cc, SourcePos p)
: Stmt(p, ForStmtID), init(i), test(t), step(s), stmts(st),
doCoherentCheck(cc && !g->opt.disableCoherentControlFlow) {}
void ForStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
llvm::BasicBlock *btest = ctx->CreateBasicBlock("for_test", ctx->GetCurrentBasicBlock());
llvm::BasicBlock *bloop = ctx->CreateBasicBlock("for_loop", btest);
llvm::BasicBlock *bstep = ctx->CreateBasicBlock("for_step", bloop);
llvm::BasicBlock *bexit = ctx->CreateBasicBlock("for_exit", bstep);
bool uniformTest = test ? test->GetType()->IsUniformType()
: (!g->opt.disableUniformControlFlow && !lHasVaryingBreakOrContinue(stmts));
bool emulateUniform = false;
if (ctx->emitXeHardwareMask() && !uniformTest) {
/* With the Xe target we generate uniform control flow but
emit varying conditions using the CM simdcf.any intrinsic. We mark the
scope as emulateUniform = true to let nested scopes know that they
should generate vector conditions before branching.
This is needed because CM does not support scalar control flow inside
SIMD control flow.
*/
uniformTest = true;
emulateUniform = true;
}
ctx->StartLoop(bexit, bstep, uniformTest, emulateUniform);
ctx->SetDebugPos(pos);
// If we have an initializer statement, start by emitting the code for
// it and then jump into the loop test code. (Also start a new scope
// since the initializer may be a declaration statement.)
if (init) {
AssertPos(pos, llvm::dyn_cast<StmtList>(init) == nullptr);
ctx->StartScope();
init->EmitCode(ctx);
}
ctx->BranchInst(btest);
// Emit code to get the value of the loop test. If no test expression
// was provided, just go with a true value.
ctx->SetCurrentBasicBlock(btest);
llvm::Value *ltest = nullptr;
if (test) {
ltest = test->GetValue(ctx);
if (!ltest) {
// We only need to end the scope if we had an initializer statement.
if (init) {
ctx->EndScope();
}
ctx->EndLoop();
return;
}
} else
ltest = uniformTest ? LLVMTrue : LLVMBoolVector(true);
// Now use the test's value. For a uniform loop, we can either jump to
// the loop body or the loop exit, based on whether it's true or false.
// For a non-uniform loop, we update the mask and jump into the loop if
// any of the mask values are true.
if (uniformTest) {
if (doCoherentCheck && !emulateUniform)
if (test)
Warning(test->pos, "Uniform condition supplied to cfor/cwhile "
"statement.");
if (!ctx->emitXeHardwareMask())
AssertPos(pos, ltest->getType() == LLVMTypes::BoolType);
ctx->BranchInst(bloop, bexit, ltest);
} else {
llvm::Value *mask = ctx->GetInternalMask();
ctx->SetInternalMaskAnd(mask, ltest);
ctx->BranchIfMaskAny(bloop, bexit);
}
// On to emitting the code for the loop body.
ctx->SetCurrentBasicBlock(bloop);
ctx->SetBlockEntryMask(ctx->GetFullMask());
ctx->AddInstrumentationPoint("for loop body");
if (!llvm::dyn_cast_or_null<StmtList>(stmts))
ctx->StartScope();
if (doCoherentCheck && !uniformTest) {
// For 'varying' loops with the coherence check, we start by
// checking to see if the mask is all on, after it has been updated
// based on the value of the test.
llvm::BasicBlock *bAllOn = ctx->CreateBasicBlock("for_all_on");
llvm::BasicBlock *bMixed = ctx->CreateBasicBlock("for_mixed");
ctx->BranchIfMaskAll(bAllOn, bMixed);
// Emit code for the mask being all on. Explicitly set the mask to
// be on so that the optimizer can see that it's on (i.e. now that
// the runtime test has passed, make this fact clear for code
// generation at compile time here.)
ctx->SetCurrentBasicBlock(bAllOn);
if (!g->opt.disableMaskAllOnOptimizations)
ctx->SetInternalMask(LLVMMaskAllOn);
llvm::Value *oldFunctionMask = ctx->GetFunctionMask();
if (!g->opt.disableMaskAllOnOptimizations)
ctx->SetFunctionMask(LLVMMaskAllOn);
if (stmts)
stmts->EmitCode(ctx);
AssertPos(pos, ctx->GetCurrentBasicBlock());
ctx->SetFunctionMask(oldFunctionMask);
ctx->BranchInst(bstep);
// Emit code for the mask being mixed. We should never run the
// loop with the mask all off, based on the BranchIfMaskAny call
// above.
ctx->SetCurrentBasicBlock(bMixed);
if (stmts)
stmts->EmitCode(ctx);
ctx->BranchInst(bstep);
} else {
// For both uniform loops and varying loops without the coherence
// check, we know that at least one program instance wants to be
// running the loop, so just emit code for the loop body and jump
// to the loop step code.
if (stmts)
stmts->EmitCode(ctx);
if (ctx->GetCurrentBasicBlock())
ctx->BranchInst(bstep);
}
if (!llvm::dyn_cast_or_null<StmtList>(stmts))
ctx->EndScope();
// Emit code for the loop step. First, restore the lane mask of any
// program instances that executed a 'continue' during the previous
// iteration. Then emit code for the loop step and then jump to the
// test code.
ctx->SetCurrentBasicBlock(bstep);
ctx->RestoreContinuedLanes();
ctx->ClearBreakLanes();
if (step)
step->EmitCode(ctx);
llvm::Instruction *branchInst = ctx->BranchInst(btest);
ctx->setLoopUnrollMetadata(branchInst, loopAttribute, pos);
// Set the current emission basic block to the loop exit basic block
ctx->SetCurrentBasicBlock(bexit);
if (init)
ctx->EndScope();
ctx->EndLoop();
}
Stmt *ForStmt::TypeCheck() {
const Type *testType;
if (test && (testType = test->GetType()) != nullptr) {
if (testType->IsDependentType()) {
return this;
}
// See comments in DoStmt::TypeCheck() regarding
// 'uniformTest' and the type conversion here.
bool uniformTest =
(testType->IsUniformType() && !g->opt.disableUniformControlFlow && !lHasVaryingBreakOrContinue(stmts));
test = TypeConvertExpr(test, uniformTest ? AtomicType::UniformBool : AtomicType::VaryingBool,
"\"for\"/\"while\" statement");
if (test == nullptr)
return nullptr;
}
return this;
}
void ForStmt::SetLoopAttribute(std::pair<Globals::pragmaUnrollType, int> lAttr) {
if (loopAttribute.first != Globals::pragmaUnrollType::none)
Error(pos, "Multiple '#pragma unroll/nounroll' directives used.");
if (!lLoopStmtUniformTest(test, stmts)) {
PerformanceWarning(
pos,
"'#pragma unroll/nounroll' for varying for loop is slow. Try '#pragma unroll/nounroll' for foreach loop.");
}
loopAttribute = lAttr;
}
int ForStmt::EstimateCost() const { return lLoopStmtUniformTest(test, stmts) ? COST_UNIFORM_LOOP : COST_VARYING_LOOP; }
ForStmt *ForStmt::Instantiate(TemplateInstantiation &templInst) const {
Expr *instTestExpr = test ? test->Instantiate(templInst) : nullptr;
Stmt *instInitStmts = init ? init->Instantiate(templInst) : nullptr;
Stmt *instStepStmts = step ? step->Instantiate(templInst) : nullptr;
Stmt *instBodyStmts = stmts ? stmts->Instantiate(templInst) : nullptr;
return new ForStmt(instInitStmts, instTestExpr, instStepStmts, instBodyStmts, doCoherentCheck, pos);
}
void ForStmt::Print(Indent &indent) const {
indent.PrintLn("ForStmt", pos);
int totalChildren = (init ? 1 : 0) + (test ? 1 : 0) + (step ? 1 : 0) + (stmts ? 1 : 0);
indent.pushList(totalChildren);
if (init) {
indent.setNextLabel("init");
init->Print(indent);
}
if (test) {
indent.setNextLabel("test");
test->Print(indent);
}
if (step) {
indent.setNextLabel("step");
step->Print(indent);
}
if (stmts) {
indent.setNextLabel("stmts");
stmts->Print(indent);
}
indent.Done();
}
///////////////////////////////////////////////////////////////////////////
// BreakStmt
BreakStmt::BreakStmt(SourcePos p) : Stmt(p, BreakStmtID) {}
void BreakStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
ctx->SetDebugPos(pos);
ctx->Break(true);
}
Stmt *BreakStmt::TypeCheck() { return this; }
int BreakStmt::EstimateCost() const { return COST_BREAK_CONTINUE; }
BreakStmt *BreakStmt::Instantiate(TemplateInstantiation &templInst) const { return new BreakStmt(pos); }
void BreakStmt::Print(Indent &indent) const {
indent.PrintLn("BreakStmt", pos);
indent.Done();
}
///////////////////////////////////////////////////////////////////////////
// ContinueStmt
ContinueStmt::ContinueStmt(SourcePos p) : Stmt(p, ContinueStmtID) {}
void ContinueStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
ctx->SetDebugPos(pos);
ctx->Continue(true);
}
Stmt *ContinueStmt::TypeCheck() { return this; }
int ContinueStmt::EstimateCost() const { return COST_BREAK_CONTINUE; }
ContinueStmt *ContinueStmt::Instantiate(TemplateInstantiation &templInst) const { return new ContinueStmt(pos); }
void ContinueStmt::Print(Indent &indent) const {
indent.PrintLn("ContinueStmt", pos);
indent.Done();
}
///////////////////////////////////////////////////////////////////////////
// ForeachStmt
ForeachStmt::ForeachStmt(const std::vector<Symbol *> &lvs, const std::vector<Expr *> &se, const std::vector<Expr *> &ee,
Stmt *s, bool t, SourcePos pos)
: Stmt(pos, ForeachStmtID), dimVariables(lvs), startExprs(se), endExprs(ee), isTiled(t), stmts(s) {}
/* Calculate the delta that should be added to the varying counter
between iterations for a given dimension.
*/
static llvm::Constant *lCalculateDeltaForVaryingCounter(int dim, int nDims, const std::vector<int> &spans) {
// Figure out the offsets; this is a little bit tricky. As an example,
// consider a 2D tiled foreach loop, where we're running 8-wide and
// where the inner dimension has a stride of 4 and the outer dimension
// has a stride of 2. For the inner dimension, we want the offsets
// (0,1,2,3,0,1,2,3), and for the outer dimension we want
// (0,0,0,0,1,1,1,1).
int32_t delta[ISPC_MAX_NVEC];
for (int i = 0; i < g->target->getVectorWidth(); ++i) {
int d = i;
// First, account for the effect of any dimensions at deeper
// nesting levels than the current one.
int prevDimSpanCount = 1;
for (int j = dim; j < nDims - 1; ++j)
prevDimSpanCount *= spans[j + 1];
d /= prevDimSpanCount;
// And now with what's left, figure out our own offset
delta[i] = d % spans[dim];
}
return LLVMInt32Vector(delta);
}
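// Worked example for the comment above: running 8-wide with a 2D tiled
// foreach and spans = (2, 4) (outer, inner):
//   dim == 1 (inner): prevDimSpanCount == 1, delta[i] = i % 4
//                     -> (0,1,2,3,0,1,2,3)
//   dim == 0 (outer): prevDimSpanCount == spans[1] == 4, delta[i] = (i/4) % 2
//                     -> (0,0,0,0,1,1,1,1)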
/* Given a uniform counter value in the memory location pointed to by
uniformCounterPtrInfo, compute the corresponding set of varying counter
values for use within the loop body.
*/
static llvm::Value *lUpdateVaryingCounter(int dim, int nDims, FunctionEmitContext *ctx,
AddressInfo *uniformCounterPtrInfo, AddressInfo *varyingCounterPtrInfo,
const std::vector<int> &spans) {
// Smear the uniform counter value out to be varying
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrInfo);
llvm::Value *smearCounter = ctx->BroadcastValue(counter, LLVMTypes::Int32VectorType, "smear_counter");
llvm::Constant *delta = lCalculateDeltaForVaryingCounter(dim, nDims, spans);
// Add the deltas to compute the varying counter values; store the
// result to memory and then return it directly as well.
llvm::Value *varyingCounter =
ctx->BinaryOperator(llvm::Instruction::Add, smearCounter, delta, WrapSemantics::NSW, "iter_val");
ctx->StoreInst(varyingCounter, varyingCounterPtrInfo);
return varyingCounter;
}
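// Illustrative: if the uniform counter currently holds 8 and this
// dimension's delta is (0,1,2,3,0,1,2,3), the varying counter stored and
// returned here is (8,9,10,11,8,9,10,11).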
/** Returns the integer log2 of the given integer. */
static int lLog2(int i) {
int ret = 0;
while (i != 0) {
++ret;
i >>= 1;
}
return ret - 1;
}
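// A few illustrative values: lLog2(1) == 0, lLog2(8) == 3, lLog2(10) == 3,
// lLog2(16) == 4; i.e. floor(log2(i)) for i > 0 (and -1 for i == 0).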
/* Figure out how many elements to process in each dimension on each trip
through a foreach loop. The untiled case is easy; all of the outer
dimensions up until the innermost one have a span of 1, and the
innermost one takes the entire vector width. For the tiled case, we
give wider spans to the innermost dimensions while also trying to
generate relatively square domains.
This code works recursively from outer dimensions to inner dimensions.
*/
static void lGetSpans(int dimsLeft, int nDims, int itemsLeft, bool isTiled, int *a) {
if (dimsLeft == 0) {
// Nothing left to do but give all of the remaining work to the
// innermost domain.
*a = itemsLeft;
return;
}
if (isTiled == false || (dimsLeft >= lLog2(itemsLeft)))
// If we're not tiled, or if there are enough dimensions left that
// giving this one any more than a span of one would mean that a
// later dimension would have to have a span of one, give this one
// a span of one to save the available items for later.
*a = 1;
else if (itemsLeft >= 16 && (dimsLeft == 1))
// Special case to have 4x4 domains for the 2D case when running
// 16-wide.
*a = 4;
else
// Otherwise give this dimension a span of two.
*a = 2;
lGetSpans(dimsLeft - 1, nDims, itemsLeft / *a, isTiled, a + 1);
}
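// Illustrative results computed from the rules above: on an 8-wide target a
// 2D untiled foreach gets spans (1, 8) and a 2D tiled foreach gets (2, 4);
// on a 16-wide target the 2D tiled case gets the 4x4 domains, spans (4, 4).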
/* Emit code for a foreach statement. We effectively emit code to run the
set of n-dimensional nested loops corresponding to the dimensionality of
the foreach statement along with the extra logic to deal with mismatches
between the vector width we're compiling to and the number of elements
to process.
*/
void ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
#ifdef ISPC_XE_ENABLED
if (ctx->emitXeHardwareMask()) {
EmitCodeForXe(ctx);
return;
}
#endif
if (ctx->GetCurrentBasicBlock() == nullptr || stmts == nullptr)
return;
llvm::BasicBlock *bbFullBody = ctx->CreateBasicBlock("foreach_full_body");
llvm::BasicBlock *bbMaskedBody = ctx->CreateBasicBlock("foreach_masked_body");
llvm::BasicBlock *bbExit = ctx->CreateBasicBlock("foreach_exit");
llvm::Value *oldMask = ctx->GetInternalMask();
llvm::Value *oldFunctionMask = ctx->GetFunctionMask();
ctx->SetDebugPos(pos);
ctx->StartScope();
ctx->SetInternalMask(LLVMMaskAllOn);
ctx->SetFunctionMask(LLVMMaskAllOn);
// This should be caught during typechecking
AssertPos(pos, startExprs.size() == dimVariables.size() && endExprs.size() == dimVariables.size());
int nDims = (int)dimVariables.size();
///////////////////////////////////////////////////////////////////////
// Setup: compute the number of items we have to work on in each
// dimension and a number of derived values.
std::vector<llvm::BasicBlock *> bbReset, bbStep, bbTest;
std::vector<llvm::Value *> startVals, endVals;
std::vector<llvm::Value *> nExtras, alignedEnd;
std::vector<AddressInfo *> uniformCounterPtrs, extrasMaskPtrs;
std::vector<int> span(nDims, 0);
lGetSpans(nDims - 1, nDims, g->target->getVectorWidth(), isTiled, &span[0]);
for (int i = 0; i < nDims; ++i) {
// Basic blocks that we'll fill in later with the looping logic for
// this dimension.
bbReset.push_back(ctx->CreateBasicBlock("foreach_reset"));
if (i < nDims - 1)
// stepping for the innermost dimension is handled specially
bbStep.push_back(ctx->CreateBasicBlock("foreach_step"));
bbTest.push_back(ctx->CreateBasicBlock("foreach_test"));
// Start and end value for this loop dimension
llvm::Value *sv = startExprs[i]->GetValue(ctx);
llvm::Value *ev = endExprs[i]->GetValue(ctx);
if (sv == nullptr || ev == nullptr)
return;
startVals.push_back(sv);
endVals.push_back(ev);
// nItems = endVal - startVal
llvm::Value *nItems = ctx->BinaryOperator(llvm::Instruction::Sub, ev, sv, WrapSemantics::NSW, "nitems");
// nExtras = nItems % (span for this dimension)
// This gives us the number of extra elements we need to deal with
// at the end of the loop for this dimension that don't fit cleanly
// into a vector width.
nExtras.push_back(
ctx->BinaryOperator(llvm::Instruction::SRem, nItems, LLVMInt32(span[i]), WrapSemantics::None, "nextras"));
// alignedEnd = endVal - nExtras
alignedEnd.push_back(
ctx->BinaryOperator(llvm::Instruction::Sub, ev, nExtras[i], WrapSemantics::NSW, "aligned_end"));
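// Worked example (illustrative): with start == 0, end == 10 and a span of 4
// for this dimension, nItems == 10, nExtras == 10 % 4 == 2, and
// alignedEnd == 8; iterations with the counter at 0 and 4 run at full
// width, and the final two elements are handled under the "extras" mask.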
///////////////////////////////////////////////////////////////////////
// Each dimension has a loop counter that is a uniform value that
// goes from startVal to endVal, in steps of the span for this
// dimension. Its value is only used internally here for looping
// logic and isn't directly available in the user's program code.
uniformCounterPtrs.push_back(ctx->AllocaInst(LLVMTypes::Int32Type, "counter"));
ctx->StoreInst(startVals[i], uniformCounterPtrs[i]);
// There is also a varying variable that holds the set of index
// values for each dimension in the current loop iteration; this is
// the value that is program-visible.
dimVariables[i]->storageInfo = ctx->AllocaInst(LLVMTypes::Int32VectorType, dimVariables[i]->name.c_str());
dimVariables[i]->parentFunction = ctx->GetFunction();
ctx->EmitVariableDebugInfo(dimVariables[i]);
// Each dimension also maintains a mask that represents which of
// the varying elements in the current iteration should be
// processed. (i.e. this is used to disable the lanes that have
// out-of-bounds offsets.)
extrasMaskPtrs.push_back(ctx->AllocaInst(LLVMTypes::MaskType, "extras mask"));
ctx->StoreInst(LLVMMaskAllOn, extrasMaskPtrs[i]);
}
ctx->StartForeach(FunctionEmitContext::FOREACH_REGULAR);
// On to the outermost loop's test
llvm::Instruction *bbBIOuter = ctx->BranchInst(bbTest[0]);
ctx->setLoopUnrollMetadata(bbBIOuter, loopAttribute, pos);
///////////////////////////////////////////////////////////////////////////
// foreach_reset: this code runs when we need to reset the counter for
// a given dimension in preparation for running through its loop again,
// after the enclosing level advances its counter.
for (int i = 0; i < nDims; ++i) {
ctx->SetCurrentBasicBlock(bbReset[i]);
if (i == 0)
ctx->BranchInst(bbExit);
else {
ctx->StoreInst(LLVMMaskAllOn, extrasMaskPtrs[i]);
ctx->StoreInst(startVals[i], uniformCounterPtrs[i]);
ctx->BranchInst(bbStep[i - 1]);
}
}
///////////////////////////////////////////////////////////////////////////
// foreach_step: increment the uniform counter by the vector width.
// Note that we don't increment the varying counter here as well but
// just generate its value when we need it in the loop body. Don't do
// this for the innermost dimension, which has a more complex stepping
// structure.
for (int i = 0; i < nDims - 1; ++i) {
ctx->SetCurrentBasicBlock(bbStep[i]);
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[i]);
llvm::Value *newCounter =
ctx->BinaryOperator(llvm::Instruction::Add, counter, LLVMInt32(span[i]), WrapSemantics::NSW, "new_counter");
ctx->StoreInst(newCounter, uniformCounterPtrs[i]);
ctx->BranchInst(bbTest[i]);
}
///////////////////////////////////////////////////////////////////////////
// foreach_test (for all dimensions other than the innermost...)
std::vector<llvm::Value *> inExtras;
for (int i = 0; i < nDims - 1; ++i) {
ctx->SetCurrentBasicBlock(bbTest[i]);
llvm::Value *haveExtras =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SGT, endVals[i], alignedEnd[i], "have_extras");
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[i], nullptr, "counter");
llvm::Value *atAlignedEnd =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, counter, alignedEnd[i], "at_aligned_end");
llvm::Value *inEx =
ctx->BinaryOperator(llvm::Instruction::And, haveExtras, atAlignedEnd, WrapSemantics::None, "in_extras");
if (i == 0)
inExtras.push_back(inEx);
else
inExtras.push_back(ctx->BinaryOperator(llvm::Instruction::Or, inEx, inExtras[i - 1], WrapSemantics::None,
"in_extras_all"));
llvm::Value *varyingCounter =
lUpdateVaryingCounter(i, nDims, ctx, uniformCounterPtrs[i], dimVariables[i]->storageInfo, span);
llvm::Value *smearEnd = ctx->BroadcastValue(endVals[i], LLVMTypes::Int32VectorType, "smear_end");
// Do a vector compare of its value to the end value to generate a
// mask for this last bit of work.
llvm::Value *emask = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT, varyingCounter, smearEnd);
emask = ctx->I1VecToBoolVec(emask);
if (i == 0)
ctx->StoreInst(emask, extrasMaskPtrs[i]);
else {
llvm::Value *oldMask = ctx->LoadInst(extrasMaskPtrs[i - 1]);
llvm::Value *newMask =
ctx->BinaryOperator(llvm::Instruction::And, oldMask, emask, WrapSemantics::None, "extras_mask");
ctx->StoreInst(newMask, extrasMaskPtrs[i]);
}
llvm::Value *notAtEnd = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT, counter, endVals[i]);
ctx->BranchInst(bbTest[i + 1], bbReset[i], notAtEnd);
}
///////////////////////////////////////////////////////////////////////////
// foreach_test (for innermost dimension)
//
// All of the outer dimensions are handled generically--basically as a
// for() loop from the start value to the end value, where at each loop
// test, we compute the mask of active elements for the current
// dimension and then update an overall mask that is the AND
// combination of all of the outer ones.
//
// The innermost loop is handled specially, for performance purposes.
// When starting the innermost dimension, we start by checking once
// whether any of the outer dimensions has set the mask to be
// partially-active or not. We follow different code paths for these
// two cases, taking advantage of the knowledge that the mask is all
// on, when this is the case.
//
// In each of these code paths, we start with a loop from the starting
// value to the aligned end value for the innermost dimension; we can
// guarantee that the innermost loop will have an "all on" mask (as far
// as its dimension is concerned) for the duration of this loop. Doing
// so allows us to emit code that assumes the mask is all on (for the
// case where none of the outer dimensions has set the mask to be
// partially on), or allows us to emit code that just uses the mask
// from the outer dimensions directly (for the case where they have).
//
// After this loop, we just need to deal with one vector's worth of
// "ragged extra bits", where the mask used includes the effect of the
// mask for the innermost dimension.
//
// We start out this process by emitting the check that determines
// whether any of the enclosing dimensions is partially active
// (i.e. processing extra elements that don't exactly fit into a
// vector).
llvm::BasicBlock *bbOuterInExtras = ctx->CreateBasicBlock("outer_in_extras");
llvm::BasicBlock *bbOuterNotInExtras = ctx->CreateBasicBlock("outer_not_in_extras");
ctx->SetCurrentBasicBlock(bbTest[nDims - 1]);
if (inExtras.size()) {
ctx->BranchInst(bbOuterInExtras, bbOuterNotInExtras, inExtras.back());
} else
// for a 1D iteration domain, we certainly don't have any enclosing
// dimensions that are processing extra elements.
ctx->BranchInst(bbOuterNotInExtras);
///////////////////////////////////////////////////////////////////////////
// One or more outer dimensions in extras, so we need to mask for the loop
// body regardless. We break this into two cases, roughly:
// for (counter = start; counter < alignedEnd; counter += step) {
// // mask is all on for inner, so set mask to outer mask
// // run loop body with mask
// }
// // counter == alignedEnd
// if (counter < end) {
// // set mask to outermask & (counter+programCounter < end)
// // run loop body with mask
// }
llvm::BasicBlock *bbAllInnerPartialOuter = ctx->CreateBasicBlock("all_inner_partial_outer");
llvm::BasicBlock *bbPartial = ctx->CreateBasicBlock("both_partial");
ctx->SetCurrentBasicBlock(bbOuterInExtras);
{
// Update the varying counter value here, since all subsequent
// blocks along this path need it.
lUpdateVaryingCounter(nDims - 1, nDims, ctx, uniformCounterPtrs[nDims - 1],
dimVariables[nDims - 1]->storageInfo, span);
// here we just check to see if counter < alignedEnd
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[nDims - 1], nullptr, "counter");
llvm::Value *beforeAlignedEnd = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT, counter,
alignedEnd[nDims - 1], "before_aligned_end");
ctx->BranchInst(bbAllInnerPartialOuter, bbPartial, beforeAlignedEnd);
}
// Below we have a basic block that runs the loop body code for the
// case where the mask is partially but not fully on. This same block
// runs in multiple cases: both for handling any ragged extra data for
// the innermost dimension but also when outer dimensions have set the
// mask to be partially on.
//
// The value stored in stepIndexAfterMaskedBodyPtrInfo is used after each
// execution of the body code to determine whether the innermost index
// value should be incremented by the step (we're running the "for"
// loop of full vectors at the innermost dimension, with outer
// dimensions having set the mask to be partially on), or whether we're
// running once for the ragged extra bits at the end of the innermost
// dimension, in which case we're done with the innermost dimension and
// should step the loop counter for the next enclosing dimension
// instead.
    // Revisit: should this be an i1?
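    // Concretely: the flag is stored as true in "all_inner_partial_outer" (keep
    // stepping the innermost counter after the masked body), and as false in
    // "both_partial" and "partial_inner_only" (the innermost dimension is done,
    // so reset it and step the next enclosing dimension instead).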
AddressInfo *stepIndexAfterMaskedBodyPtrInfo = ctx->AllocaInst(LLVMTypes::BoolType, "step_index");
///////////////////////////////////////////////////////////////////////////
// We're in the inner loop part where the only masking is due to outer
// dimensions but the innermost dimension fits fully into a vector's
// width. Set the mask and jump to the masked loop body.
ctx->SetCurrentBasicBlock(bbAllInnerPartialOuter);
{
llvm::Value *mask;
if (nDims == 1)
// 1D loop; we shouldn't ever get here anyway
mask = LLVMMaskAllOff;
else
mask = ctx->LoadInst(extrasMaskPtrs[nDims - 2]);
ctx->SetInternalMask(mask);
ctx->StoreInst(LLVMTrue, stepIndexAfterMaskedBodyPtrInfo);
ctx->BranchInst(bbMaskedBody);
}
///////////////////////////////////////////////////////////////////////////
// We need to include the effect of the innermost dimension in the mask
// for the final bits here
ctx->SetCurrentBasicBlock(bbPartial);
{
llvm::Value *varyingCounter =
ctx->LoadInst(dimVariables[nDims - 1]->storageInfo, dimVariables[nDims - 1]->type);
llvm::Value *smearEnd = ctx->BroadcastValue(endVals[nDims - 1], LLVMTypes::Int32VectorType, "smear_end");
llvm::Value *emask = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT, varyingCounter, smearEnd);
emask = ctx->I1VecToBoolVec(emask);
if (nDims == 1) {
ctx->SetInternalMask(emask);
} else {
llvm::Value *oldMask = ctx->LoadInst(extrasMaskPtrs[nDims - 2]);
llvm::Value *newMask =
ctx->BinaryOperator(llvm::Instruction::And, oldMask, emask, WrapSemantics::None, "extras_mask");
ctx->SetInternalMask(newMask);
}
ctx->StoreInst(LLVMFalse, stepIndexAfterMaskedBodyPtrInfo);
        // check to see if counter != end; if it's already equal, the extra masked pass isn't necessary
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[nDims - 1], nullptr, "counter");
llvm::Value *atEnd =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, counter, endVals[nDims - 1], "at_end");
ctx->BranchInst(bbMaskedBody, bbReset[nDims - 1], atEnd);
}
///////////////////////////////////////////////////////////////////////////
// None of the outer dimensions is processing extras; along the lines
// of above, we can express this as:
// for (counter = start; counter < alignedEnd; counter += step) {
// // mask is all on
// // run loop body with mask all on
// }
// // counter == alignedEnd
// if (counter < end) {
    // // set mask to (counter+programIndex < end)
// // run loop body with mask
// }
llvm::BasicBlock *bbPartialInnerAllOuter = ctx->CreateBasicBlock("partial_inner_all_outer");
ctx->SetCurrentBasicBlock(bbOuterNotInExtras);
{
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[nDims - 1], nullptr, "counter");
llvm::Value *beforeAlignedEnd = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT, counter,
alignedEnd[nDims - 1], "before_aligned_end");
llvm::Instruction *bbBIOuterNotInExtras = ctx->BranchInst(bbFullBody, bbPartialInnerAllOuter, beforeAlignedEnd);
ctx->setLoopUnrollMetadata(bbBIOuterNotInExtras, loopAttribute, pos);
}
///////////////////////////////////////////////////////////////////////////
// full_body: do a full vector's worth of work. We know that all
// lanes will be running here, so we explicitly set the mask to be 'all
// on'. This ends up being relatively straightforward: just update the
// value of the varying loop counter and have the statements in the
// loop body emit their code.
llvm::BasicBlock *bbFullBodyContinue = ctx->CreateBasicBlock("foreach_full_continue");
ctx->SetCurrentBasicBlock(bbFullBody);
{
ctx->SetInternalMask(LLVMMaskAllOn);
ctx->SetBlockEntryMask(LLVMMaskAllOn);
lUpdateVaryingCounter(nDims - 1, nDims, ctx, uniformCounterPtrs[nDims - 1],
dimVariables[nDims - 1]->storageInfo, span);
ctx->SetContinueTarget(bbFullBodyContinue);
ctx->AddInstrumentationPoint("foreach loop body (all on)");
stmts->EmitCode(ctx);
AssertPos(pos, ctx->GetCurrentBasicBlock() != nullptr);
ctx->BranchInst(bbFullBodyContinue);
}
ctx->SetCurrentBasicBlock(bbFullBodyContinue);
{
ctx->RestoreContinuedLanes();
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[nDims - 1]);
llvm::Value *newCounter = ctx->BinaryOperator(llvm::Instruction::Add, counter, LLVMInt32(span[nDims - 1]),
WrapSemantics::NSW, "new_counter");
ctx->StoreInst(newCounter, uniformCounterPtrs[nDims - 1]);
ctx->BranchInst(bbOuterNotInExtras);
}
///////////////////////////////////////////////////////////////////////////
// We're done running blocks with the mask all on; see if the counter is
// less than the end value, in which case we need to run the body one
// more time to get the extra bits.
llvm::BasicBlock *bbSetInnerMask = ctx->CreateBasicBlock("partial_inner_only");
ctx->SetCurrentBasicBlock(bbPartialInnerAllOuter);
{
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[nDims - 1], nullptr, "counter");
llvm::Value *beforeFullEnd = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT, counter,
endVals[nDims - 1], "before_full_end");
ctx->BranchInst(bbSetInnerMask, bbReset[nDims - 1], beforeFullEnd);
}
///////////////////////////////////////////////////////////////////////////
// The outer dimensions are all on, so the mask is just given by the
// mask for the innermost dimension
ctx->SetCurrentBasicBlock(bbSetInnerMask);
{
llvm::Value *varyingCounter = lUpdateVaryingCounter(nDims - 1, nDims, ctx, uniformCounterPtrs[nDims - 1],
dimVariables[nDims - 1]->storageInfo, span);
llvm::Value *smearEnd = ctx->BroadcastValue(endVals[nDims - 1], LLVMTypes::Int32VectorType, "smear_end");
llvm::Value *emask = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT, varyingCounter, smearEnd);
emask = ctx->I1VecToBoolVec(emask);
ctx->SetInternalMask(emask);
ctx->SetBlockEntryMask(emask);
ctx->StoreInst(LLVMFalse, stepIndexAfterMaskedBodyPtrInfo);
ctx->BranchInst(bbMaskedBody);
}
///////////////////////////////////////////////////////////////////////////
// masked_body: set the mask and have the statements emit their
// code again. Note that it's generally worthwhile having two copies
// of the statements' code, since the code above is emitted with the
// mask known to be all-on, which in turn leads to more efficient code
// for that case.
llvm::BasicBlock *bbStepInnerIndex = ctx->CreateBasicBlock("step_inner_index");
llvm::BasicBlock *bbMaskedBodyContinue = ctx->CreateBasicBlock("foreach_masked_continue");
ctx->SetCurrentBasicBlock(bbMaskedBody);
{
ctx->AddInstrumentationPoint("foreach loop body (masked)");
ctx->SetContinueTarget(bbMaskedBodyContinue);
ctx->DisableGatherScatterWarnings();
ctx->SetBlockEntryMask(ctx->GetFullMask());
stmts->EmitCode(ctx);
ctx->EnableGatherScatterWarnings();
ctx->BranchInst(bbMaskedBodyContinue);
}
ctx->SetCurrentBasicBlock(bbMaskedBodyContinue);
{
ctx->RestoreContinuedLanes();
llvm::Value *stepIndex = ctx->LoadInst(stepIndexAfterMaskedBodyPtrInfo);
ctx->BranchInst(bbStepInnerIndex, bbReset[nDims - 1], stepIndex);
}
///////////////////////////////////////////////////////////////////////////
// step the innermost index, for the case where we're doing the
// innermost for loop over full vectors.
ctx->SetCurrentBasicBlock(bbStepInnerIndex);
{
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[nDims - 1]);
llvm::Value *newCounter = ctx->BinaryOperator(llvm::Instruction::Add, counter, LLVMInt32(span[nDims - 1]),
WrapSemantics::NSW, "new_counter");
ctx->StoreInst(newCounter, uniformCounterPtrs[nDims - 1]);
ctx->BranchInst(bbOuterInExtras);
}
///////////////////////////////////////////////////////////////////////////
// foreach_exit: All done. Restore the old mask and clean up
ctx->SetCurrentBasicBlock(bbExit);
ctx->SetInternalMask(oldMask);
ctx->SetFunctionMask(oldFunctionMask);
ctx->EndForeach();
ctx->EndScope();
}
#ifdef ISPC_XE_ENABLED
/* Emit code for a foreach statement on Xe. We effectively emit code to run
the set of n-dimensional nested loops corresponding to the dimensionality of
the foreach statement along with the extra logic to deal with mismatches
between the vector width we're compiling to and the number of elements
   to process. Because of the Xe Execution Mask, the handling logic differs from
   the other targets: we do not need to generate separate bodies for the full-
   and partial-mask cases.
*/
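/* For reference, for an n-dimensional foreach the Xe path below emits one
   (foreach_test, foreach_step, foreach_reset) block triple per dimension plus a
   single body block; roughly:
       test[i]:  if (counter[i] < end[i]) goto (i == innermost ? body : test[i+1]);
                 else goto reset[i];
       body:     <statements>; goto step[innermost];
       step[i]:  counter[i] += span[i]; goto test[i];
       reset[i]: counter[i] = start[i]; goto step[i-1];  // reset[0] exits the loop
   Lanes that fail a test are simply disabled by the Xe execution mask until the
   corresponding reset block is reached. */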
void ForeachStmt::EmitCodeForXe(FunctionEmitContext *ctx) const {
AssertPos(pos, g->target->isXeTarget());
if (ctx->GetCurrentBasicBlock() == nullptr || stmts == nullptr)
return;
// We store current EM and reset it to AllOn state.
llvm::Value *oldMask = ctx->GetInternalMask();
llvm::Value *oldFunctionMask = ctx->GetFunctionMask();
llvm::Value *execMask = nullptr;
if (g->opt.enableForeachInsideVarying) {
Warning(pos, "\"foreach\" statement is not optimized for Xe targets yet.");
ctx->SetInternalMask(LLVMMaskAllOn);
ctx->SetFunctionMask(LLVMMaskAllOn);
execMask = ctx->XeStartUnmaskedRegion();
} else {
Warning(pos, "\"foreach\" statement is not supported under varying CF for Xe targets yet. Make sure that"
" it is not called under varying CF or use \"--opt=enable-xe-foreach-varying\" to enable its "
"experimental support.");
}
llvm::BasicBlock *bbBody = ctx->CreateBasicBlock("foreach_body", ctx->GetCurrentBasicBlock());
llvm::BasicBlock *bbExit = ctx->CreateBasicBlock("foreach_exit", bbBody);
ctx->SetDebugPos(pos);
ctx->StartScope();
// This should be caught during typechecking
AssertPos(pos, startExprs.size() == dimVariables.size() && endExprs.size() == dimVariables.size());
int nDims = (int)dimVariables.size();
///////////////////////////////////////////////////////////////////////
// Setup: compute the number of items we have to work on in each
// dimension and a number of derived values.
std::vector<llvm::BasicBlock *> bbReset, bbTest, bbStep;
std::vector<llvm::Value *> startVals, endVals;
std::vector<llvm::Constant *> steps;
std::vector<int> span(nDims, 0);
lGetSpans(nDims - 1, nDims, g->target->getVectorWidth(), isTiled, &span[0]);
for (int i = 0; i < nDims; ++i) {
// Basic blocks that we'll fill in later with the looping logic for
// this dimension.
bbTest.push_back(ctx->CreateBasicBlock("foreach_test", i == 0 ? ctx->GetCurrentBasicBlock() : bbTest[i - 1]));
bbStep.push_back(ctx->CreateBasicBlock("foreach_step", bbBody));
bbReset.push_back(ctx->CreateBasicBlock("foreach_reset", bbStep[i]));
llvm::Value *sv = startExprs[i]->GetValue(ctx);
llvm::Value *ev = endExprs[i]->GetValue(ctx);
if (sv == nullptr || ev == nullptr)
return;
// Store varying start
sv = ctx->BroadcastValue(sv, LLVMTypes::Int32VectorType, "start_broadcast");
llvm::Constant *delta = lCalculateDeltaForVaryingCounter(i, nDims, span);
sv = ctx->BinaryOperator(llvm::Instruction::Add, sv, delta, WrapSemantics::NSW, "varying_start");
startVals.push_back(sv);
// Store broadcasted end values
ev = ctx->BroadcastValue(ev, LLVMTypes::Int32VectorType, "end_broadcast");
endVals.push_back(ev);
// Store vectorized step
llvm::Constant *step = LLVMInt32Vector(span[i]);
steps.push_back(step);
// Init vectorized counters
dimVariables[i]->storageInfo = ctx->AllocaInst(LLVMTypes::Int32VectorType, dimVariables[i]->name.c_str());
dimVariables[i]->parentFunction = ctx->GetFunction();
ctx->StoreInst(sv, dimVariables[i]->storageInfo);
ctx->EmitVariableDebugInfo(dimVariables[i]);
}
// Officially start foreach. Emulating uniform for proper continue handlers.
ctx->StartForeach(FunctionEmitContext::FOREACH_REGULAR, true);
// Jump to outermost test block
ctx->BranchInst(bbTest[0]);
///////////////////////////////////////////////////////////////////////////
// foreach_reset: this code runs when we need to reset the counter for
// a given dimension in preparation for running through its loop again,
// after the enclosing level advances its counter.
for (int i = 0; i < nDims; ++i) {
ctx->SetCurrentBasicBlock(bbReset[i]);
if (i == 0)
// Outermost loop finished - exit
ctx->BranchInst(bbExit);
else {
// Reset counter for this dimension, iterate over previous one
ctx->StoreInst(startVals[i], dimVariables[i]->storageInfo);
ctx->BranchInst(bbStep[i - 1]);
}
}
///////////////////////////////////////////////////////////////////////////
// foreach_step: iterate counters with step that was calculated before
// entering foreach.
for (int i = 0; i < nDims; ++i) {
ctx->SetCurrentBasicBlock(bbStep[i]);
llvm::Value *counter = ctx->LoadInst(dimVariables[i]->storageInfo);
llvm::Value *newCounter =
ctx->BinaryOperator(llvm::Instruction::Add, counter, steps[i], WrapSemantics::NSW, "new_counter");
ctx->StoreInst(newCounter, dimVariables[i]->storageInfo);
ctx->BranchInst(bbTest[i]);
}
///////////////////////////////////////////////////////////////////////////
// foreach_test: compare varying counter with end value and branch to
// target or reset. Xe EM magic happens here: we turn off all lanes
    // that fail the check until reset is reached. And reset is reached only when
    // all lanes fail this check, due to the test -> target -> step -> test loop.
    //
    // It looks tricky for the multidimensional case. Suppose we have a 3-dimensional
    // loop and some lanes were turned off in the second dimension. When we reach
    // the reset of the innermost (3rd) dimension, we won't be able to reset lanes that
    // were turned off in the second dimension. But we don't actually need to reset
    // them: they were reset right before the test of the second dimension turned
    // them off. So after the 2nd dimension's reset, both the 2nd and 3rd counters
    // will have been reset. If some lanes were turned off in the first dimension,
    // none of this matters: we will exit the loop after this iteration anyway.
for (int i = 0; i < nDims; ++i) {
ctx->SetCurrentBasicBlock(bbTest[i]);
llvm::Value *val = ctx->LoadInst(dimVariables[i]->storageInfo, nullptr, "val");
llvm::Value *checkVal = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT, val, endVals[i]);
// Target is body for innermost dimension, next dimension test for others
llvm::BasicBlock *targetBB = (i < nDims - 1) ? bbTest[i + 1] : bbBody;
        // Turn off lanes until reset is reached
ctx->BranchInst(targetBB, bbReset[i], checkVal);
}
///////////////////////////////////////////////////////////////////////////
// foreach_body: emit code for loop body. Execution is driven by
// Xe Execution Mask.
ctx->SetCurrentBasicBlock(bbBody);
ctx->SetContinueTarget(bbStep[nDims - 1]);
ctx->AddInstrumentationPoint("foreach loop body");
stmts->EmitCode(ctx);
AssertPos(pos, ctx->GetCurrentBasicBlock() != nullptr);
ctx->BranchInst(bbStep[nDims - 1]);
///////////////////////////////////////////////////////////////////////////
// foreach_exit: All done. Restore the old mask and clean up
ctx->SetCurrentBasicBlock(bbExit);
// Restore execution mask from value that was saved at the beginning
if (execMask != nullptr) {
ctx->XeEndUnmaskedRegion(execMask);
ctx->SetInternalMask(oldMask);
ctx->SetFunctionMask(oldFunctionMask);
}
ctx->EndForeach();
ctx->EndScope();
}
#endif
Stmt *ForeachStmt::TypeCheck() {
for (auto expr : startExprs) {
const Type *t = expr ? expr->GetType() : nullptr;
if (t && t->IsDependentType()) {
return this;
}
}
for (auto expr : endExprs) {
const Type *t = expr ? expr->GetType() : nullptr;
if (t && t->IsDependentType()) {
return this;
}
}
bool anyErrors = false;
for (unsigned int i = 0; i < startExprs.size(); ++i) {
if (startExprs[i] != nullptr)
startExprs[i] = TypeConvertExpr(startExprs[i], AtomicType::UniformInt32, "foreach starting value");
anyErrors |= (startExprs[i] == nullptr);
}
for (unsigned int i = 0; i < endExprs.size(); ++i) {
if (endExprs[i] != nullptr)
endExprs[i] = TypeConvertExpr(endExprs[i], AtomicType::UniformInt32, "foreach ending value");
anyErrors |= (endExprs[i] == nullptr);
}
if (startExprs.size() < dimVariables.size()) {
Error(pos,
"Not enough initial values provided for \"foreach\" loop; "
"got %d, expected %d\n",
(int)startExprs.size(), (int)dimVariables.size());
anyErrors = true;
} else if (startExprs.size() > dimVariables.size()) {
Error(pos,
"Too many initial values provided for \"foreach\" loop; "
"got %d, expected %d\n",
(int)startExprs.size(), (int)dimVariables.size());
anyErrors = true;
}
if (endExprs.size() < dimVariables.size()) {
Error(pos,
"Not enough initial values provided for \"foreach\" loop; "
"got %d, expected %d\n",
(int)endExprs.size(), (int)dimVariables.size());
anyErrors = true;
} else if (endExprs.size() > dimVariables.size()) {
Error(pos,
"Too many initial values provided for \"foreach\" loop; "
"got %d, expected %d\n",
(int)endExprs.size(), (int)dimVariables.size());
anyErrors = true;
}
return anyErrors ? nullptr : this;
}
void ForeachStmt::SetLoopAttribute(std::pair<Globals::pragmaUnrollType, int> lAttr) {
if (loopAttribute.first != Globals::pragmaUnrollType::none) {
Error(pos, "Multiple '#pragma unroll/nounroll' directives used.");
}
loopAttribute = lAttr;
}
int ForeachStmt::EstimateCost() const { return dimVariables.size() * (COST_UNIFORM_LOOP + COST_SIMPLE_ARITH_LOGIC_OP); }
ForeachStmt *ForeachStmt::Instantiate(TemplateInstantiation &templInst) const {
std::vector<Symbol *> instDimVariables;
std::vector<Expr *> instStartExprs;
std::vector<Expr *> instEndExprs;
for (auto dimVar : dimVariables) {
Symbol *instDimVar = templInst.InstantiateSymbol(dimVar);
instDimVariables.push_back(instDimVar);
}
for (auto startExpr : startExprs) {
Expr *instStartExpr = startExpr ? startExpr->Instantiate(templInst) : nullptr;
instStartExprs.push_back(instStartExpr);
}
for (auto endExpr : endExprs) {
Expr *instEndExpr = endExpr ? endExpr->Instantiate(templInst) : nullptr;
instEndExprs.push_back(instEndExpr);
}
Stmt *instStmts = stmts ? stmts->Instantiate(templInst) : nullptr;
ForeachStmt *inst = new ForeachStmt(instDimVariables, instStartExprs, instEndExprs, instStmts, isTiled, pos);
inst->loopAttribute = loopAttribute;
return inst;
}
void ForeachStmt::Print(Indent &indent) const {
indent.PrintLn("ForeachStmt", pos);
int totalChildren = dimVariables.size() + (stmts ? 1 : 0);
indent.pushList(totalChildren);
for (unsigned int i = 0; i < dimVariables.size(); ++i) {
char buffer[15];
snprintf(buffer, 15, "index var %d\n", i);
indent.Print(buffer);
{
indent.pushList(3);
indent.setNextLabel("var");
indent.Print();
if (dimVariables[i] != nullptr) {
printf("%s\n", dimVariables[i]->name.c_str());
} else {
printf("<NULL>\n");
}
indent.Done();
indent.setNextLabel("start value");
if (i < startExprs.size() && startExprs[i] != nullptr) {
startExprs[i]->Print(indent);
} else {
indent.Print("<NULL>");
indent.Done();
}
indent.setNextLabel("end value");
if (i < endExprs.size() && endExprs[i] != nullptr) {
endExprs[i]->Print(indent);
} else {
indent.Print("<NULL>");
indent.Done();
}
}
indent.Done();
}
if (stmts != nullptr) {
indent.setNextLabel("body");
stmts->Print(indent);
}
indent.Done();
}
///////////////////////////////////////////////////////////////////////////
// ForeachActiveStmt
ForeachActiveStmt::ForeachActiveStmt(Symbol *s, Stmt *st, SourcePos pos) : Stmt(pos, ForeachActiveStmtID) {
sym = s;
stmts = st;
}
void ForeachActiveStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
// Allocate storage for the symbol that we'll use for the uniform
// variable that holds the current program instance in each loop
// iteration.
if (sym->type == nullptr) {
Assert(m->errorCount > 0);
return;
}
Assert(Type::Equal(sym->type, AtomicType::UniformInt64->GetAsConstType()));
sym->storageInfo = ctx->AllocaInst(LLVMTypes::Int64Type, sym->name.c_str());
ctx->SetDebugPos(pos);
ctx->EmitVariableDebugInfo(sym);
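    // The loop emitted below is roughly equivalent to:
    //     uint64 remaining = lanemask(current execution mask);
    //     while (remaining != 0) {
    //         uniform int64 i = count_trailing_zeros(remaining); // first active lane
    //         mask = (programIndex == i);                        // run exactly one lane
    //         <loop body>;
    //         remaining &= ~(1ull << i);
    //     }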
// The various basic blocks that we'll need in the below
llvm::BasicBlock *bbFindNext = ctx->CreateBasicBlock("foreach_active_find_next", ctx->GetCurrentBasicBlock());
llvm::BasicBlock *bbBody = ctx->CreateBasicBlock("foreach_active_body", bbFindNext);
llvm::BasicBlock *bbCheckForMore = ctx->CreateBasicBlock("foreach_active_check_for_more", bbBody);
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("foreach_active_done", bbCheckForMore);
// Save the old mask so that we can restore it at the end
llvm::Value *oldInternalMask = ctx->GetInternalMask();
// Now, *maskBitsPtr will maintain a bitmask for the lanes that remain
// to be processed by a pass through the loop body. It starts out with
// the current execution mask (which should never be all off going in
// to this)...
llvm::Value *oldFullMask = nullptr;
bool uniformEmulated = false;
#ifdef ISPC_XE_ENABLED
if (ctx->emitXeHardwareMask()) {
// Emulate uniform to make proper continue handler
uniformEmulated = true;
// Current mask will be calculated according to EM mask
oldFullMask = ctx->XeSimdCFPredicate(LLVMMaskAllOn);
} else
#endif
oldFullMask = ctx->GetFullMask();
AddressInfo *maskBitsPtrInfo = ctx->AllocaInst(LLVMTypes::Int64Type, "mask_bits");
llvm::Value *movmsk = ctx->LaneMask(oldFullMask);
ctx->StoreInst(movmsk, maskBitsPtrInfo);
// Officially start the loop.
ctx->StartScope();
ctx->StartForeach(FunctionEmitContext::FOREACH_ACTIVE, uniformEmulated);
ctx->SetContinueTarget(bbCheckForMore);
    // Onward to find the first program instance to run the loop body for
ctx->BranchInst(bbFindNext);
ctx->SetCurrentBasicBlock(bbFindNext);
{
// Load the bitmask of the lanes left to be processed
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtrInfo, nullptr, "remaining_bits");
// Find the index of the first set bit in the mask
        llvm::Function *cttzFunc = m->module->getFunction("__count_trailing_zeros_i64");
        Assert(cttzFunc != nullptr);
        llvm::Value *firstSet = ctx->CallInst(cttzFunc, nullptr, remainingBits, "first_set");
// Store that value into the storage allocated for the iteration
// variable.
ctx->StoreInst(firstSet, sym->storageInfo, sym->type);
// Now set the execution mask to be only on for the current program
// instance. (TODO: is there a more efficient way to do this? e.g.
// for AVX1, we might want to do this as float rather than int
// math...)
// Get the "program index" vector value
llvm::Value *programIndex = ctx->ProgramIndexVector();
// And smear the current lane out to a vector
llvm::Value *firstSet32 = ctx->TruncInst(firstSet, LLVMTypes::Int32Type, "first_set32");
llvm::Value *firstSet32Smear = ctx->SmearUniform(firstSet32);
// Now set the execution mask based on doing a vector compare of
// these two
llvm::Value *iterMask =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, firstSet32Smear, programIndex);
iterMask = ctx->I1VecToBoolVec(iterMask);
// Don't need to change this mask in XE: execution
// is performed according to Xe EM
if (!ctx->emitXeHardwareMask())
ctx->SetInternalMask(iterMask);
// Also update the bitvector of lanes left to turn off the bit for
// the lane we're about to run.
llvm::Value *setMask =
ctx->BinaryOperator(llvm::Instruction::Shl, LLVMInt64(1), firstSet, WrapSemantics::None, "set_mask");
llvm::Value *notSetMask = ctx->NotOperator(setMask);
llvm::Value *newRemaining = ctx->BinaryOperator(llvm::Instruction::And, remainingBits, notSetMask,
WrapSemantics::None, "new_remaining");
ctx->StoreInst(newRemaining, maskBitsPtrInfo);
// and onward to run the loop body...
// Set Xe EM through simdcf.goto
// The EM will be restored when CheckForMore is reached
if (ctx->emitXeHardwareMask()) {
ctx->BranchInst(bbBody, bbCheckForMore, iterMask);
} else {
ctx->BranchInst(bbBody);
}
}
ctx->SetCurrentBasicBlock(bbBody);
{
ctx->SetBlockEntryMask(ctx->GetFullMask());
// Run the code in the body of the loop. This is easy now.
if (stmts)
stmts->EmitCode(ctx);
Assert(ctx->GetCurrentBasicBlock() != nullptr);
ctx->BranchInst(bbCheckForMore);
}
ctx->SetCurrentBasicBlock(bbCheckForMore);
{
ctx->RestoreContinuedLanes();
// At the end of the loop body (either due to running the
// statements normally, or a continue statement in the middle of
// the loop that jumps to the end, see if there are any lanes left
// to be processed.
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtrInfo, nullptr, "remaining_bits");
llvm::Value *nonZero = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, remainingBits,
LLVMInt64(0), "remaining_ne_zero");
ctx->BranchInst(bbFindNext, bbDone, nonZero);
}
ctx->SetCurrentBasicBlock(bbDone);
ctx->SetInternalMask(oldInternalMask);
ctx->EndForeach();
ctx->EndScope();
}
void ForeachActiveStmt::Print(Indent &indent) const {
indent.PrintLn("ForeachActiveStmt", pos);
indent.pushList(2);
indent.setNextLabel("iter symbol");
indent.Print();
if (sym != nullptr) {
printf("%s", sym->name.c_str());
if (sym->type != nullptr)
printf(" %s", sym->type->GetString().c_str());
} else {
printf("NULL");
}
printf("\n");
indent.Done();
indent.setNextLabel("body");
if (stmts != nullptr) {
stmts->Print(indent);
} else {
indent.Print("<NULL>");
indent.Done();
}
indent.Done();
}
Stmt *ForeachActiveStmt::TypeCheck() {
if (sym == nullptr)
return nullptr;
return this;
}
void ForeachActiveStmt::SetLoopAttribute(std::pair<Globals::pragmaUnrollType, int> lAttr) {
Warning(pos, "'#pragma unroll/nounroll' ignored - not supported for foreach_active loop.");
}
int ForeachActiveStmt::EstimateCost() const { return COST_VARYING_LOOP; }
ForeachActiveStmt *ForeachActiveStmt::Instantiate(TemplateInstantiation &templInst) const {
Symbol *instSym = templInst.InstantiateSymbol(sym);
Stmt *instStmts = stmts ? stmts->Instantiate(templInst) : nullptr;
ForeachActiveStmt *inst = new ForeachActiveStmt(instSym, instStmts, pos);
inst->loopAttribute = loopAttribute;
return inst;
}
///////////////////////////////////////////////////////////////////////////
// ForeachUniqueStmt
ForeachUniqueStmt::ForeachUniqueStmt(const char *iterName, Expr *e, Stmt *s, SourcePos pos)
: Stmt(pos, ForeachUniqueStmtID), expr(e), stmts(s) {
sym = m->symbolTable->LookupVariable(iterName);
}
ForeachUniqueStmt::ForeachUniqueStmt(Symbol *symbol, Expr *e, Stmt *s, SourcePos pos)
: Stmt(pos, ForeachUniqueStmtID), sym(symbol), expr(e), stmts(s) {}
void ForeachUniqueStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
// First, allocate local storage for the symbol that we'll use for the
// uniform variable that holds the current unique value through each
// loop.
if (sym->type == nullptr) {
Assert(m->errorCount > 0);
return;
}
llvm::Type *symType = sym->type->LLVMType(g->ctx);
if (symType == nullptr) {
Assert(m->errorCount > 0);
return;
}
sym->storageInfo = ctx->AllocaInst(sym->type, sym->name.c_str());
ctx->SetDebugPos(pos);
ctx->EmitVariableDebugInfo(sym);
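    // The loop emitted below is roughly equivalent to:
    //     exprVal = <varying expr>;                          // evaluated only once
    //     uint64 remaining = lanemask(current execution mask);
    //     while (remaining != 0) {
    //         uniform int64 i = count_trailing_zeros(remaining);
    //         uniform value = extract(exprVal, i);           // the unique value for this pass
    //         mask = oldMask & (exprVal == value);           // all lanes sharing that value
    //         <loop body>;
    //         remaining &= ~lanemask(mask);
    //     }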
// The various basic blocks that we'll need in the below
llvm::BasicBlock *bbFindNext = ctx->CreateBasicBlock("foreach_find_next", ctx->GetCurrentBasicBlock());
llvm::BasicBlock *bbBody = ctx->CreateBasicBlock("foreach_body", bbFindNext);
llvm::BasicBlock *bbCheckForMore = ctx->CreateBasicBlock("foreach_check_for_more", bbBody);
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("foreach_done", bbCheckForMore);
// Prepare the FunctionEmitContext
ctx->StartScope();
// Save the old internal mask so that we can restore it at the end
llvm::Value *oldMask = ctx->GetInternalMask();
// Now, *maskBitsPtr will maintain a bitmask for the lanes that remain
// to be processed by a pass through the foreach_unique loop body. It
// starts out with the full execution mask (which should never be all
// off going in to this)...
llvm::Value *oldFullMask = nullptr;
bool emulatedUniform = false;
#ifdef ISPC_XE_ENABLED
if (ctx->emitXeHardwareMask()) {
// Emulating uniform behavior for proper continue handling
emulatedUniform = true;
// Current mask will be calculated according to EM mask
oldFullMask = ctx->XeSimdCFPredicate(LLVMMaskAllOn);
} else
#endif
oldFullMask = ctx->GetFullMask();
AddressInfo *maskBitsPtrInfo = ctx->AllocaInst(LLVMTypes::Int64Type, "mask_bits");
llvm::Value *movmsk = ctx->LaneMask(oldFullMask);
ctx->StoreInst(movmsk, maskBitsPtrInfo);
// Officially start the loop.
ctx->StartForeach(FunctionEmitContext::FOREACH_UNIQUE, emulatedUniform);
ctx->SetContinueTarget(bbCheckForMore);
// Evaluate the varying expression we're iterating over just once.
llvm::Value *exprValue = expr->GetValue(ctx);
// And we'll store its value into locally-allocated storage, for ease
// of indexing over it with non-compile-time-constant indices.
const Type *exprType;
if (exprValue == nullptr || (exprType = expr->GetType()) == nullptr ||
llvm::dyn_cast<llvm::VectorType>(exprValue->getType()) == nullptr) {
Assert(m->errorCount > 0);
return;
}
ctx->SetDebugPos(pos);
AddressInfo *exprMem = ctx->AllocaInst(exprType, "expr_mem");
ctx->StoreInst(exprValue, exprMem, exprType);
// Onward to find the first set of lanes to run the loop for
ctx->BranchInst(bbFindNext);
ctx->SetCurrentBasicBlock(bbFindNext);
{
// Load the bitmask of the lanes left to be processed
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtrInfo, nullptr, "remaining_bits");
// Find the index of the first set bit in the mask
        llvm::Function *cttzFunc = m->module->getFunction("__count_trailing_zeros_i64");
        Assert(cttzFunc != nullptr);
        llvm::Value *firstSet = ctx->CallInst(cttzFunc, nullptr, remainingBits, "first_set");
// And load the corresponding element value from the temporary
// memory storing the value of the varying expr.
llvm::Value *uniqueValue;
        // A load plus extractelement (EEI) is preferable to GEP + load for getting
        // the unique value: it allows better register utilization for Xe targets and
        // reduces the number of allocas for both CPU and Xe.
llvm::Value *uniqueValueVec = ctx->LoadInst(exprMem, exprType, "unique_value_vec");
Assert(llvm::dyn_cast<llvm::VectorType>(uniqueValueVec->getType()) != nullptr);
uniqueValue =
llvm::ExtractElementInst::Create(uniqueValueVec, firstSet, "unique_value", ctx->GetCurrentBasicBlock());
// If it's a varying pointer type, need to convert from the int
// type we store in the vector to the actual pointer type
if (llvm::dyn_cast<llvm::PointerType>(symType) != nullptr)
uniqueValue = ctx->IntToPtrInst(uniqueValue, symType);
Assert(uniqueValue != nullptr);
// Store that value in sym's storage so that the iteration variable
// has the right value inside the loop body
ctx->StoreInst(uniqueValue, sym->storageInfo, sym->type);
// Set the execution mask so that it's on for any lane that a) was
// running at the start of the foreach loop, and b) where that
// lane's value of the varying expression is the same as the value
// we've selected to process this time through--i.e.:
// oldMask & (smear(element) == exprValue)
llvm::Value *uniqueSmear = ctx->SmearUniform(uniqueValue, "unique_smear");
llvm::Value *matchingLanes = nullptr;
if (uniqueValue->getType()->isFloatingPointTy())
matchingLanes = ctx->CmpInst(llvm::Instruction::FCmp, llvm::CmpInst::FCMP_OEQ, uniqueSmear, exprValue,
"matching_lanes");
else
matchingLanes =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, uniqueSmear, exprValue, "matching_lanes");
matchingLanes = ctx->I1VecToBoolVec(matchingLanes);
llvm::Value *loopMask = ctx->BinaryOperator(llvm::Instruction::And, oldMask, matchingLanes, WrapSemantics::None,
"foreach_unique_loop_mask");
// Don't need to change this mask in XE: execution
// is performed according to Xe EM
if (!ctx->emitXeHardwareMask())
ctx->SetInternalMask(loopMask);
// Also update the bitvector of lanes left to process in subsequent
// loop iterations:
// remainingBits &= ~movmsk(current mask)
llvm::Value *loopMaskMM = ctx->LaneMask(loopMask);
llvm::Value *notLoopMaskMM = ctx->NotOperator(loopMaskMM);
llvm::Value *newRemaining = ctx->BinaryOperator(llvm::Instruction::And, remainingBits, notLoopMaskMM,
WrapSemantics::None, "new_remaining");
ctx->StoreInst(newRemaining, maskBitsPtrInfo);
// and onward...
// Set Xe EM through simdcf.goto
// The EM will be restored when CheckForMore is reached
if (ctx->emitXeHardwareMask()) {
ctx->BranchInst(bbBody, bbCheckForMore, loopMask);
} else {
ctx->BranchInst(bbBody);
}
}
ctx->SetCurrentBasicBlock(bbBody);
{
ctx->SetBlockEntryMask(ctx->GetFullMask());
// Run the code in the body of the loop. This is easy now.
if (stmts)
stmts->EmitCode(ctx);
Assert(ctx->GetCurrentBasicBlock() != nullptr);
ctx->BranchInst(bbCheckForMore);
}
ctx->SetCurrentBasicBlock(bbCheckForMore);
{
// At the end of the loop body (either due to running the
// statements normally, or a continue statement in the middle of
// the loop that jumps to the end, see if there are any lanes left
// to be processed.
ctx->RestoreContinuedLanes();
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtrInfo, nullptr, "remaining_bits");
llvm::Value *nonZero = ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, remainingBits,
LLVMInt64(0), "remaining_ne_zero");
ctx->BranchInst(bbFindNext, bbDone, nonZero);
}
ctx->SetCurrentBasicBlock(bbDone);
ctx->SetInternalMask(oldMask);
ctx->EndForeach();
ctx->EndScope();
}
void ForeachUniqueStmt::Print(Indent &indent) const {
indent.PrintLn("ForeachUniqueStmt", pos);
indent.pushList(3);
indent.setNextLabel("iter symbol");
indent.Print();
if (sym != nullptr) {
printf("%s", sym->name.c_str());
if (sym->type != nullptr)
printf(" %s", sym->type->GetString().c_str());
} else {
printf("NULL");
}
printf("\n");
indent.Done();
indent.setNextLabel("iter expr");
if (expr != nullptr) {
expr->Print(indent);
} else {
indent.Print("NULL\n");
indent.Done();
}
indent.setNextLabel("body");
if (stmts != nullptr) {
stmts->Print(indent);
} else {
indent.Print("NULL\n");
indent.Done();
}
indent.Done();
}
Stmt *ForeachUniqueStmt::TypeCheck() {
const Type *type;
if (sym == nullptr || expr == nullptr || (type = expr->GetType()) == nullptr)
return nullptr;
if (type->IsDependentType()) {
return this;
}
if (type->IsVaryingType() == false) {
Error(expr->pos,
"Iteration domain type in \"foreach_tiled\" loop "
"must be \"varying\" type, not \"%s\".",
type->GetString().c_str());
return nullptr;
}
if (Type::IsBasicType(type) == false) {
Error(expr->pos,
"Iteration domain type in \"foreach_tiled\" loop "
"must be an atomic, pointer, or enum type, not \"%s\".",
type->GetString().c_str());
return nullptr;
}
return this;
}
void ForeachUniqueStmt::SetLoopAttribute(std::pair<Globals::pragmaUnrollType, int> lAttr) {
Warning(pos, "'#pragma unroll/nounroll' ignored - not supported for foreach_unique loop.");
}
int ForeachUniqueStmt::EstimateCost() const { return COST_VARYING_LOOP; }
ForeachUniqueStmt *ForeachUniqueStmt::Instantiate(TemplateInstantiation &templInst) const {
Expr *instExpr = expr ? expr->Instantiate(templInst) : nullptr;
Stmt *instStmts = stmts ? stmts->Instantiate(templInst) : nullptr;
Symbol *instSym = templInst.InstantiateSymbol(sym);
ForeachUniqueStmt *inst = new ForeachUniqueStmt(instSym, instExpr, instStmts, pos);
inst->loopAttribute = loopAttribute;
return inst;
}
///////////////////////////////////////////////////////////////////////////
// CaseStmt
/** Given the statements following a 'case' or 'default' label, this
function determines whether the mask should be checked to see if it is
"all off" immediately after the label, before executing the code for
the statements.
*/
static bool lCheckMask(Stmt *stmts) {
if (stmts == nullptr)
return false;
int cost = EstimateCost(stmts);
bool safeToRunWithAllLanesOff = SafeToRunWithMaskAllOff(stmts);
// The mask should be checked if the code following the
// 'case'/'default' is relatively complex, or if it would be unsafe to
// run that code with the execution mask all off.
return (cost > PREDICATE_SAFE_IF_STATEMENT_COST || safeToRunWithAllLanesOff == false);
}
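// For example (roughly speaking), a 'case' whose body is a handful of arithmetic on
// local variables can fall through without a check, while one that does relatively
// expensive work, or that may not be safe with no lanes active (e.g. because it
// accesses memory or calls other functions), gets the "any lanes on?" test emitted
// before its code.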
CaseStmt::CaseStmt(int v, Stmt *s, SourcePos pos) : Stmt(pos, CaseStmtID), value(v) { stmts = s; }
void CaseStmt::EmitCode(FunctionEmitContext *ctx) const {
ctx->EmitCaseLabel(value, lCheckMask(stmts), pos);
if (stmts)
stmts->EmitCode(ctx);
}
void CaseStmt::Print(Indent &indent) const {
indent.Print("CaseStmt", pos);
printf("Value: %d\n", value);
indent.pushSingle();
stmts->Print(indent);
indent.Done();
}
Stmt *CaseStmt::TypeCheck() { return this; }
int CaseStmt::EstimateCost() const { return 0; }
CaseStmt *CaseStmt::Instantiate(TemplateInstantiation &templInst) const {
Stmt *instStmts = stmts ? stmts->Instantiate(templInst) : nullptr;
return new CaseStmt(value, instStmts, pos);
}
///////////////////////////////////////////////////////////////////////////
// DefaultStmt
DefaultStmt::DefaultStmt(Stmt *s, SourcePos pos) : Stmt(pos, DefaultStmtID) { stmts = s; }
void DefaultStmt::EmitCode(FunctionEmitContext *ctx) const {
ctx->EmitDefaultLabel(lCheckMask(stmts), pos);
if (stmts)
stmts->EmitCode(ctx);
}
void DefaultStmt::Print(Indent &indent) const {
indent.PrintLn("DefaultStmt", pos);
indent.pushSingle();
stmts->Print(indent);
indent.Done();
}
Stmt *DefaultStmt::TypeCheck() { return this; }
int DefaultStmt::EstimateCost() const { return 0; }
DefaultStmt *DefaultStmt::Instantiate(TemplateInstantiation &templInst) const {
Stmt *instStmts = stmts ? stmts->Instantiate(templInst) : nullptr;
return new DefaultStmt(instStmts, pos);
}
///////////////////////////////////////////////////////////////////////////
// SwitchStmt
SwitchStmt::SwitchStmt(Expr *e, Stmt *s, SourcePos pos) : Stmt(pos, SwitchStmtID) {
expr = e;
stmts = s;
}
/* An instance of this structure is carried along as we traverse the AST
nodes for the statements after a "switch" statement. We use this
structure to record all of the 'case' and 'default' statements after the
"switch". */
struct SwitchVisitInfo {
SwitchVisitInfo(FunctionEmitContext *c) {
ctx = c;
defaultBlock = nullptr;
lastBlock = nullptr;
insertAfter = ctx->GetCurrentBasicBlock();
}
FunctionEmitContext *ctx;
/* Basic block for the code following the "default" label (if any). */
llvm::BasicBlock *defaultBlock;
/* Map from integer values after "case" labels to the basic blocks that
follow the corresponding "case" label. */
std::vector<std::pair<int, llvm::BasicBlock *>> caseBlocks;
/* For each basic block for a "case" label or a "default" label,
nextBlock[block] stores the basic block pointer for the next
subsequent "case" or "default" label in the program. */
std::map<llvm::BasicBlock *, llvm::BasicBlock *> nextBlock;
/* The last basic block created for a "case" or "default" label; when
we create the basic block for the next one, we'll use this to update
the nextBlock map<> above. */
llvm::BasicBlock *lastBlock;
llvm::BasicBlock *insertAfter;
};
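/* For example, for
       switch (x) { case 0: ...; case 3: ...; default: ...; }
   the visitor below ends up with caseBlocks == {(0, "case_0"), (3, "case_3")},
   defaultBlock == the "default" block, and nextBlock mapping "case_0" -> "case_3"
   and "case_3" -> "default"; SwitchStmt::EmitCode() then maps the last block to
   the block following the whole switch. */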
static bool lSwitchASTPreVisit(ASTNode *node, void *d) {
if (llvm::dyn_cast<SwitchStmt>(node) != nullptr)
// don't continue recursively into a nested switch--we only want
// our own case and default statements!
return false;
CaseStmt *cs = llvm::dyn_cast<CaseStmt>(node);
DefaultStmt *ds = llvm::dyn_cast<DefaultStmt>(node);
SwitchVisitInfo *svi = (SwitchVisitInfo *)d;
llvm::BasicBlock *bb = nullptr;
if (cs != nullptr) {
// Complain if we've seen a case statement with the same value
// already
for (int i = 0; i < (int)svi->caseBlocks.size(); ++i) {
if (svi->caseBlocks[i].first == cs->value) {
Error(cs->pos, "Duplicate case value \"%d\".", cs->value);
return true;
}
}
// Otherwise create a new basic block for the code following this
        // 'case' statement and record the mapping between the case label
// value and the basic block
char buf[32];
snprintf(buf, sizeof(buf), "case_%d", cs->value);
bb = svi->ctx->CreateBasicBlock(buf, svi->insertAfter);
svi->caseBlocks.push_back(std::make_pair(cs->value, bb));
} else if (ds != nullptr) {
// And complain if we've seen another 'default' label..
if (svi->defaultBlock != nullptr) {
Error(ds->pos, "Multiple \"default\" lables in switch statement.");
return true;
} else {
// Otherwise create a basic block for the code following the
// "default".
bb = svi->ctx->CreateBasicBlock("default", svi->insertAfter);
svi->defaultBlock = bb;
}
}
// If we saw a "case" or "default" label, then update the map to record
// that the block we just created follows the block created for the
// previous label in the "switch".
if (bb != nullptr) {
svi->nextBlock[svi->lastBlock] = bb;
svi->lastBlock = bb;
svi->insertAfter = bb;
}
return true;
}
void SwitchStmt::EmitCode(FunctionEmitContext *ctx) const {
if (ctx->GetCurrentBasicBlock() == nullptr)
return;
const Type *type;
if (expr == nullptr || ((type = expr->GetType()) == nullptr)) {
AssertPos(pos, m->errorCount > 0);
return;
}
// Basic block we'll end up after the switch statement
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("switch_done", ctx->GetCurrentBasicBlock());
// Walk the AST of the statements after the 'switch' to collect a bunch
// of information about the structure of the 'case' and 'default'
// statements.
SwitchVisitInfo svi(ctx);
WalkAST(stmts, lSwitchASTPreVisit, nullptr, &svi);
// Record that the basic block following the last one created for a
// case/default is the block after the end of the switch statement.
svi.nextBlock[svi.lastBlock] = bbDone;
llvm::Value *exprValue = expr->GetValue(ctx);
if (exprValue == nullptr) {
AssertPos(pos, m->errorCount > 0);
return;
}
bool isUniformCF = (type->IsUniformType() && lHasVaryingBreakOrContinue(stmts) == false);
bool emulateUniform = false;
#ifdef ISPC_XE_ENABLED
if (ctx->emitXeHardwareMask()) {
if (isUniformCF && ctx->inXeSimdCF()) {
// Broadcast value to work with EM. We are doing
// it here because it is too late to make CMP
// broadcast through BranchInst: we need vectorized
// case checks to be able to reenable fall through
// cases under emulated uniform CF.
llvm::Type *vecType = (exprValue->getType() == LLVMTypes::Int32Type) ? LLVMTypes::Int32VectorType
: LLVMTypes::Int64VectorType;
exprValue = ctx->BroadcastValue(exprValue, vecType, "switch_expr_broadcast");
emulateUniform = true;
}
if (!isUniformCF) {
isUniformCF = true;
emulateUniform = true;
}
}
#endif
ctx->StartSwitch(isUniformCF, bbDone, emulateUniform);
ctx->SetBlockEntryMask(ctx->GetFullMask());
ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone, svi.caseBlocks, svi.nextBlock);
if (stmts != nullptr)
stmts->EmitCode(ctx);
if (ctx->GetCurrentBasicBlock() != nullptr)
ctx->BranchInst(bbDone);
ctx->SetCurrentBasicBlock(bbDone);
ctx->EndSwitch();
}
void SwitchStmt::Print(Indent &indent) const {
indent.PrintLn("SwitchStmt", pos);
indent.pushList(2);
indent.setNextLabel("expr");
expr->Print(indent);
indent.setNextLabel("stmts");
stmts->Print(indent);
indent.Done();
}
Stmt *SwitchStmt::TypeCheck() {
const Type *exprType;
if (expr == nullptr || (exprType = expr->GetType()) == nullptr) {
Assert(m->errorCount > 0);
return nullptr;
}
if (exprType->IsDependentType()) {
return this;
}
const Type *toType = nullptr;
exprType = exprType->GetAsConstType();
bool is64bit = (Type::EqualIgnoringConst(exprType->GetAsUniformType(), AtomicType::UniformUInt64) ||
Type::EqualIgnoringConst(exprType->GetAsUniformType(), AtomicType::UniformInt64));
if (exprType->IsUniformType()) {
if (is64bit)
toType = AtomicType::UniformInt64;
else
toType = AtomicType::UniformInt32;
} else {
if (is64bit)
toType = AtomicType::VaryingInt64;
else
toType = AtomicType::VaryingInt32;
}
expr = TypeConvertExpr(expr, toType, "switch expression");
if (expr == nullptr)
return nullptr;
return this;
}
int SwitchStmt::EstimateCost() const {
const Type *type = expr->GetType();
if (type && type->IsVaryingType())
return COST_VARYING_SWITCH;
else
return COST_UNIFORM_SWITCH;
}
SwitchStmt *SwitchStmt::Instantiate(TemplateInstantiation &templInst) const {
Expr *instExpr = expr ? expr->Instantiate(templInst) : nullptr;
Stmt *instStmts = stmts ? stmts->Instantiate(templInst) : nullptr;
return new SwitchStmt(instExpr, instStmts, pos);
}
///////////////////////////////////////////////////////////////////////////
// UnmaskedStmt
UnmaskedStmt::UnmaskedStmt(Stmt *s, SourcePos pos) : Stmt(pos, UnmaskedStmtID) { stmts = s; }
void UnmaskedStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock() || !stmts)
return;
llvm::Value *oldInternalMask = ctx->GetInternalMask();
llvm::Value *oldFunctionMask = ctx->GetFunctionMask();
ctx->SetInternalMask(LLVMMaskAllOn);
ctx->SetFunctionMask(LLVMMaskAllOn);
if (!ctx->emitXeHardwareMask()) {
stmts->EmitCode(ctx);
} else {
#ifdef ISPC_XE_ENABLED
// For Xe we insert special intrinsics at the beginning and end of unmasked region.
// Correct execution mask will be set in CMSIMDCFLowering
llvm::Value *oldInternalMask = ctx->XeStartUnmaskedRegion();
stmts->EmitCode(ctx);
ctx->XeEndUnmaskedRegion(oldInternalMask);
#endif
}
    // Do not restore the old mask if our basic block is over. This happens if we
    // emit code for something like 'unmasked{return;}', for example.
if (ctx->GetCurrentBasicBlock() == nullptr)
return;
ctx->SetInternalMask(oldInternalMask);
ctx->SetFunctionMask(oldFunctionMask);
}
void UnmaskedStmt::Print(Indent &indent) const {
indent.PrintLn("UnmaskedStmt", pos);
indent.pushSingle();
if (stmts != nullptr) {
stmts->Print(indent);
} else {
indent.Print("NULL\n");
indent.Done();
}
indent.Done();
}
Stmt *UnmaskedStmt::TypeCheck() { return this; }
int UnmaskedStmt::EstimateCost() const { return COST_ASSIGN; }
UnmaskedStmt *UnmaskedStmt::Instantiate(TemplateInstantiation &templInst) const {
Stmt *instStmts = stmts ? stmts->Instantiate(templInst) : nullptr;
return new UnmaskedStmt(instStmts, pos);
}
///////////////////////////////////////////////////////////////////////////
// ReturnStmt
ReturnStmt::ReturnStmt(Expr *e, SourcePos p) : Stmt(p, ReturnStmtID), expr(e) {}
void ReturnStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
if (ctx->InForeachLoop()) {
Error(pos, "\"return\" statement is illegal inside a \"foreach\" loop.");
return;
}
// Make sure we're not trying to return a reference to something where
// that doesn't make sense
const Function *func = ctx->GetFunction();
const Type *returnType = func->GetReturnType();
if (IsReferenceType(returnType) == true && IsReferenceType(expr->GetType()) == false) {
const Type *lvType = expr->GetLValueType();
if (lvType == nullptr) {
Error(expr->pos,
"Illegal to return non-lvalue from function "
"returning reference type \"%s\".",
returnType->GetString().c_str());
return;
} else if (lvType->IsUniformType() == false) {
Error(expr->pos,
"Illegal to return varying lvalue type from "
"function returning a reference type \"%s\".",
returnType->GetString().c_str());
return;
}
}
ctx->SetDebugPos(pos);
ctx->CurrentLanesReturned(expr, true);
}
Stmt *ReturnStmt::TypeCheck() { return this; }
int ReturnStmt::EstimateCost() const { return COST_RETURN; }
ReturnStmt *ReturnStmt::Instantiate(TemplateInstantiation &templInst) const {
Expr *instExpr = expr ? expr->Instantiate(templInst) : nullptr;
return new ReturnStmt(instExpr, pos);
}
void ReturnStmt::Print(Indent &indent) const {
indent.Print("ReturnStmt", pos);
if (expr) {
printf("\n");
indent.pushSingle();
expr->Print(indent);
} else {
printf("(void)\n");
}
indent.Done();
}
///////////////////////////////////////////////////////////////////////////
// GotoStmt
GotoStmt::GotoStmt(const char *l, SourcePos gotoPos, SourcePos ip) : Stmt(gotoPos, GotoStmtID) {
label = l;
identifierPos = ip;
}
void GotoStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
#ifdef ISPC_XE_ENABLED
if ((ctx->emitXeHardwareMask() && ctx->inXeSimdCF()) || ctx->VaryingCFDepth() > 0) {
#else
if (ctx->VaryingCFDepth() > 0) {
#endif
Error(pos, "\"goto\" statements are only legal under \"uniform\" "
"control flow.");
return;
}
if (ctx->InForeachLoop()) {
Error(pos, "\"goto\" statements are currently illegal inside "
"\"foreach\" loops.");
return;
}
llvm::BasicBlock *bb = ctx->GetLabeledBasicBlock(label);
if (bb == nullptr) {
/* Label wasn't found. Look for suggestions that are close */
std::vector<std::string> labels = ctx->GetLabels();
std::vector<std::string> matches = MatchStrings(label, labels);
std::string match_output;
if (!matches.empty()) {
/* Print up to 5 matches. Don't want to spew too much */
match_output += "\nDid you mean:";
for (unsigned int i = 0; i < matches.size() && i < 5; i++)
match_output += "\n " + matches[i] + "?";
}
/* Label wasn't found. Emit an error */
Error(identifierPos, "No label named \"%s\" found in current function.%s", label.c_str(), match_output.c_str());
return;
}
ctx->BranchInst(bb);
ctx->SetCurrentBasicBlock(nullptr);
}
void GotoStmt::Print(Indent &indent) const {
indent.Print("GotoStmt", pos);
printf("Label: %s\n", label.c_str());
indent.Done();
}
Stmt *GotoStmt::Optimize() { return this; }
Stmt *GotoStmt::TypeCheck() { return this; }
int GotoStmt::EstimateCost() const { return COST_GOTO; }
GotoStmt *GotoStmt::Instantiate(TemplateInstantiation &templInst) const {
    return new GotoStmt(label.c_str(), pos, identifierPos);
}
///////////////////////////////////////////////////////////////////////////
// LabeledStmt
LabeledStmt::LabeledStmt(const char *n, Stmt *s, SourcePos p) : Stmt(p, LabeledStmtID) {
name = n;
stmt = s;
}
void LabeledStmt::EmitCode(FunctionEmitContext *ctx) const {
llvm::BasicBlock *bblock = ctx->GetLabeledBasicBlock(name);
AssertPos(pos, bblock != nullptr);
// End the current basic block with a jump to our basic block and then
// set things up for emission to continue there. Note that the current
// basic block may validly be nullptr going into this statement due to an
// earlier goto that nullptr'ed it out; that doesn't stop us from
// re-establishing a current basic block starting at the label..
if (ctx->GetCurrentBasicBlock() != nullptr)
ctx->BranchInst(bblock);
ctx->SetCurrentBasicBlock(bblock);
if (stmt != nullptr)
stmt->EmitCode(ctx);
}
void LabeledStmt::Print(Indent &indent) const {
indent.Print("LabeledStmt", pos);
printf("Label: %s\n", name.c_str());
indent.pushSingle();
if (stmt != nullptr) {
stmt->Print(indent);
} else {
indent.Print("<NULL STMT>\n");
indent.Done();
}
indent.Done();
}
Stmt *LabeledStmt::Optimize() { return this; }
Stmt *LabeledStmt::TypeCheck() {
    if (!isalpha(name[0]) && name[0] != '_') {
Error(pos, "Label must start with either alphabetic character or '_'.");
return nullptr;
}
for (unsigned int i = 1; i < name.size(); ++i) {
if (!isalnum(name[i]) && name[i] != '_') {
Error(pos, "Character \"%c\" is illegal in labels.", name[i]);
return nullptr;
}
}
return this;
}
int LabeledStmt::EstimateCost() const { return 0; }
LabeledStmt *LabeledStmt::Instantiate(TemplateInstantiation &templInst) const {
Stmt *instStmt = stmt ? stmt->Instantiate(templInst) : nullptr;
return new LabeledStmt(name.c_str(), instStmt, pos);
}
///////////////////////////////////////////////////////////////////////////
// StmtList
void StmtList::EmitCode(FunctionEmitContext *ctx) const {
ctx->StartScope();
ctx->SetDebugPos(pos);
for (unsigned int i = 0; i < stmts.size(); ++i)
if (stmts[i])
stmts[i]->EmitCode(ctx);
ctx->EndScope();
}
Stmt *StmtList::TypeCheck() { return this; }
int StmtList::EstimateCost() const { return 0; }
StmtList *StmtList::Instantiate(TemplateInstantiation &templInst) const {
StmtList *inst = new StmtList(pos);
for (auto stmt : stmts) {
inst->Add(stmt ? stmt->Instantiate(templInst) : nullptr);
}
return inst;
}
void StmtList::Print(Indent &indent) const {
indent.PrintLn("StmtList", pos);
indent.pushList(stmts.size());
for (unsigned int i = 0; i < stmts.size(); ++i) {
if (stmts[i]) {
stmts[i]->Print(indent);
} else {
indent.Print("<NULL STMT>\n");
indent.Done();
}
}
indent.Done();
}
///////////////////////////////////////////////////////////////////////////
// PrintStmt
PrintStmt::PrintStmt(const std::string &f, Expr *v, SourcePos p) : Stmt(p, PrintStmtID), format(f), values(v) {}
/* Because the pointers to values that are passed to __do_print() are all
void *s (and because ispc print() formatting strings statements don't
encode types), we pass along a string to __do_print() where the i'th
character encodes the type of the i'th value to be printed. The encoding
   is defined by getEncoding4Uniform<T> and getEncoding4Varying<T> functions.
*/
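/* For example, for a (hypothetical) call print("x = %, v = %\n", x, v) where x is
   a uniform int32 and v is a varying float, the type string passed to __do_print()
   is two characters long: getEncoding4Uniform<int>() followed by
   getEncoding4Varying<float>(). */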
static char lEncodeType(const Type *t) {
if (Type::Equal(t, AtomicType::UniformBool))
return PrintInfo::getEncoding4Uniform<bool>();
if (Type::Equal(t, AtomicType::VaryingBool))
return PrintInfo::getEncoding4Varying<bool>();
if (Type::Equal(t, AtomicType::UniformInt32))
return PrintInfo::getEncoding4Uniform<int>();
if (Type::Equal(t, AtomicType::VaryingInt32))
return PrintInfo::getEncoding4Varying<int>();
if (Type::Equal(t, AtomicType::UniformUInt32))
return PrintInfo::getEncoding4Uniform<unsigned>();
if (Type::Equal(t, AtomicType::VaryingUInt32))
return PrintInfo::getEncoding4Varying<unsigned>();
if (Type::Equal(t, AtomicType::UniformFloat))
return PrintInfo::getEncoding4Uniform<float>();
if (Type::Equal(t, AtomicType::VaryingFloat))
return PrintInfo::getEncoding4Varying<float>();
if (Type::Equal(t, AtomicType::UniformInt64))
return PrintInfo::getEncoding4Uniform<long long>();
if (Type::Equal(t, AtomicType::VaryingInt64))
return PrintInfo::getEncoding4Varying<long long>();
if (Type::Equal(t, AtomicType::UniformUInt64))
return PrintInfo::getEncoding4Uniform<unsigned long long>();
if (Type::Equal(t, AtomicType::VaryingUInt64))
return PrintInfo::getEncoding4Varying<unsigned long long>();
if (Type::Equal(t, AtomicType::UniformDouble))
return PrintInfo::getEncoding4Uniform<double>();
if (Type::Equal(t, AtomicType::VaryingDouble))
return PrintInfo::getEncoding4Varying<double>();
if (CastType<PointerType>(t) != nullptr) {
if (t->IsUniformType())
return PrintInfo::getEncoding4Uniform<void *>();
else
return PrintInfo::getEncoding4Varying<void *>();
} else
return '\0';
}
struct ExprWithType {
Expr *expr;
char type;
};
static ExprWithType lProcessPrintArgType(Expr *expr) {
const Type *type = expr->GetType();
if (type == nullptr)
return {nullptr, '\0'};
if (CastType<ReferenceType>(type) != nullptr) {
expr = new RefDerefExpr(expr, expr->pos);
type = expr->GetType();
if (type == nullptr)
return {nullptr, '\0'};
}
    // Promote int8 and int16 types to int32s...
    // Also ensure 'varying bool' is excluded since its baseType can be one
    // of these types.
const Type *baseType = type->GetAsNonConstType()->GetAsUniformType();
if ((Type::Equal(baseType, AtomicType::UniformInt8) || Type::Equal(baseType, AtomicType::UniformUInt8) ||
Type::Equal(baseType, AtomicType::UniformInt16) || Type::Equal(baseType, AtomicType::UniformUInt16)) &&
!type->IsBoolType()) {
expr = new TypeCastExpr(type->IsUniformType() ? AtomicType::UniformInt32 : AtomicType::VaryingInt32, expr,
expr->pos);
type = expr->GetType();
}
if (Type::Equal(baseType, AtomicType::UniformFloat16)) {
expr = new TypeCastExpr(type->IsUniformType() ? AtomicType::UniformFloat : AtomicType::VaryingFloat, expr,
expr->pos);
type = expr->GetType();
}
char t = lEncodeType(type->GetAsNonConstType());
if (t == '\0') {
Error(expr->pos,
"Only atomic types are allowed in print statements; "
"type \"%s\" is illegal.",
type->GetString().c_str());
return {nullptr, '\0'};
}
if (type->IsBoolType()) {
// Blast bools to ints, but do it here to preserve encoding for
// printing 'true' or 'false'
expr = new TypeCastExpr(type->IsUniformType() ? AtomicType::UniformInt32 : AtomicType::VaryingInt32, expr,
expr->pos);
}
return {expr, t};
}
// Returns pointer to __do_print function
static llvm::Function *getPrintImplFunc() {
Assert(g->target->isXeTarget() == false);
llvm::Function *printImplFunc = m->module->getFunction("__do_print");
return printImplFunc;
}
// Check if number of requested arguments in format string corresponds to actual number of arguments
static bool checkFormatString(const std::string &format, const int nArgs, const SourcePos &pos) {
    // ISPC does not support escaping the percent sign as %%, so "%%" is simply counted as two arguments
const int argsInFormat = std::count(format.begin(), format.end(), '%');
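    // e.g. a format string of "x = %, y = %" gives argsInFormat == 2 and therefore
    // requires exactly two argument values.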
if (nArgs < argsInFormat) {
Error(pos, "Not enough arguments are provided in print call");
return false;
} else if (nArgs > argsInFormat) {
Error(pos, "Too much arguments are provided in print call");
return false;
}
return true;
}
#ifdef ISPC_XE_ENABLED
// Builds args for OCL printf function based on ISPC print args.
class PrintArgsBuilder {
// properly dereferenced and size extended value expressions
std::vector<ExprWithType> argExprs;
FunctionEmitContext *ctx;
struct AdditionalData {
llvm::Value *mask{nullptr};
enum { LeftParenthesisIdx = 0, RightParenthesisIdx, EmptyIdx, FalseIdx, TrueIdx, NumStrings };
std::array<llvm::Value *, NumStrings> strings{};
AdditionalData() {}
AdditionalData(FunctionEmitContext *ctx) {
if (ctx->emitXeHardwareMask())
mask = ctx->XeSimdCFPredicate(LLVMMaskAllOn);
else
mask = ctx->GetFullMask();
strings[AdditionalData::LeftParenthesisIdx] =
ctx->XeGetOrCreateConstantString("((", "ispc.print.left.parenthesis");
strings[AdditionalData::RightParenthesisIdx] =
ctx->XeGetOrCreateConstantString("))", "ispc.print.right.parenthesis");
strings[AdditionalData::EmptyIdx] = ctx->XeGetOrCreateConstantString("", "ispc.print.empty");
strings[AdditionalData::FalseIdx] = ctx->XeGetOrCreateConstantString("false", "ispc.print.false");
strings[AdditionalData::TrueIdx] = ctx->XeGetOrCreateConstantString("true", "ispc.print.true");
}
};
AdditionalData data;
public:
PrintArgsBuilder() : ctx{nullptr} {}
PrintArgsBuilder(FunctionEmitContext *ctxIn) : argExprs{}, ctx{ctxIn}, data{ctxIn} {}
template <typename Iter>
PrintArgsBuilder(Iter first, Iter last, FunctionEmitContext *ctxIn) : PrintArgsBuilder{ctxIn} {
std::transform(first, last, std::back_inserter(argExprs),
[](Expr *expr) { return lProcessPrintArgType(expr); });
std::for_each(argExprs.cbegin(), argExprs.cend(),
[](const ExprWithType &elem) { Assert(elem.expr && "must have all values processed"); });
}
    // Returns a new args builder with a subset of the original args.
    // The subset is defined by a pair of indices; the element at index \p end is not included.
PrintArgsBuilder extract(int beg, int end) const {
        Assert(beg >= 0 && beg <= argExprs.size() && end >= 0 && end <= argExprs.size() &&
               "wrong argument: index is out of bounds");
        Assert(beg <= end && "wrong arguments: beg must precede end");
PrintArgsBuilder extraction(ctx, data);
std::copy(std::next(argExprs.begin(), beg), std::next(argExprs.begin(), end),
std::back_inserter(extraction.argExprs));
return extraction;
}
// Combine all arg types into a continuous string.
std::string generateArgTypes() const {
std::string argTypes;
std::transform(argExprs.cbegin(), argExprs.cend(), std::back_inserter(argTypes),
[](const ExprWithType &argInfo) { return argInfo.type; });
return argTypes;
}
// Emit code for OCL printf arguments.
// Each generated arg is returned in the vector.
std::vector<llvm::Value *> emitArgCode() const {
if (argExprs.empty())
return {};
std::vector<llvm::Value *> Args;
        // We will need at least this many args; more if there are vector args.
Args.reserve(argExprs.size());
for (const ExprWithType &argInfo : argExprs)
writeRawArg(*argInfo.expr->GetValue(ctx), static_cast<PrintInfo::Encoding>(argInfo.type),
std::back_inserter(Args));
return Args;
}
private:
PrintArgsBuilder(FunctionEmitContext *ctxIn, AdditionalData data) : argExprs{}, ctx{ctxIn}, data{std::move(data)} {}
    // Emit code for a uniform ISPC print arg.
    // Most arg types pass through unchanged; a boolean arg is transformed into a "true"/"false" string argument.
// Pointers to generated llvm::Value are stored into output iterator \p OutIt.
template <typename OutIter>
OutIter writeRawUniformArg(llvm::Value &rawArg, PrintInfo::Encoding type, OutIter OutIt) const {
if (type != PrintInfo::getEncoding4Uniform<bool>()) {
*OutIt++ = &rawArg;
return OutIt;
}
auto *argAsPred = ctx->CmpInst(llvm::Instruction::OtherOps::ICmp, llvm::CmpInst::Predicate::ICMP_NE, &rawArg,
LLVMInt32(0), "print.arg.bool.cast");
auto *argAsStr = ctx->SelectInst(argAsPred, data.strings[AdditionalData::TrueIdx],
data.strings[AdditionalData::FalseIdx], "print.arg.bool.str");
*OutIt++ = argAsStr;
return OutIt;
}
    // Emit code for a varying ISPC print arg.
    // Each element of the vector is emitted as a separate OCL printf argument and is surrounded by two string
    // arguments that print additional parentheses when the corresponding lane is off. Pointers to generated
    // llvm::Value are stored into the output iterator \p OutIt.
template <typename OutIter>
OutIter writeRawVaryingArg(llvm::Value &rawArg, PrintInfo::Encoding type, OutIter OutIt) const {
auto width = g->target->getVectorWidth();
for (int idx = 0; idx != width; ++idx) {
auto *isLaneOn = ctx->ExtractInst(data.mask, idx, "print.arg.lane");
auto *leftParenthesis =
ctx->SelectInst(isLaneOn, data.strings[AdditionalData::EmptyIdx],
data.strings[AdditionalData::LeftParenthesisIdx], "print.arg.left.par");
auto *rightParenthesis =
ctx->SelectInst(isLaneOn, data.strings[AdditionalData::EmptyIdx],
data.strings[AdditionalData::RightParenthesisIdx], "print.arg.right.par");
auto *argElement = ctx->ExtractInst(&rawArg, idx, "print.arg.elem");
*OutIt++ = leftParenthesis;
OutIt = writeRawUniformArg(*argElement, PrintInfo::getCorrespondingEncoding4Uniform(type), OutIt);
*OutIt++ = rightParenthesis;
}
return OutIt;
}
// Emit printf OCL args (one or more) based on the single ISPC print arg \p rawArg and its \p type.
// Pointers to generated llvm::Value are stored into output iterator \p OutIt.
template <typename OutIter>
OutIter writeRawArg(llvm::Value &rawArg, PrintInfo::Encoding type, OutIter OutIt) const {
if (PrintInfo::isUniformEncoding(type))
return writeRawUniformArg(rawArg, type, OutIt);
return writeRawVaryingArg(rawArg, type, OutIt);
}
};
// When one print is split into several smaller ones,
// this structure holds the info about one split.
struct PrintSliceInfo {
// format holds ISPC-style format string (with just '%')
std::string format_;
PrintArgsBuilder args_;
};
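// Builds printf format strings for the Xe (Level Zero) print implementation
// from ISPC-style format strings. For instance, assuming the int32 specifier
// is "%d", get("v = %\n", ...) for a single varying int32 on a 4-wide target
// yields "v = [%s%d%s,%s%d%s,%s%d%s,%s%d%s]\n", where the "%s" slots receive
// empty strings or the "((" / "))" markers for lanes that are off.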
class PrintLZFormatStrBuilder {
std::array<std::string, PrintInfo::Encoding::Size> specifiers;
const int width;
public:
PrintLZFormatStrBuilder(int widthIn) : width{widthIn} {}
    // Based on the original ISPC format string and the encoded arg types,
    // generates the printf format string.
std::string get(const std::string &ISPCFormat, const std::string &argTypes, const SourcePos &pos) {
std::string format;
if (!checkFormatString(ISPCFormat, argTypes.size(), pos))
return "";
format.reserve(ISPCFormat.size());
auto curISPCFormatIt = ISPCFormat.begin();
for (auto type : argTypes) {
auto percentIt = std::find(curISPCFormatIt, ISPCFormat.end(), '%');
if (percentIt == ISPCFormat.end())
                Error(pos, "Too many arguments are provided in print call");
format.append(curISPCFormatIt, percentIt);
format.append(getOrCreateSpecifier(static_cast<PrintInfo::Encoding>(type)));
curISPCFormatIt = std::next(percentIt);
}
if (std::any_of(curISPCFormatIt, ISPCFormat.end(), [](char ch) { return ch == '%'; }))
Error(pos, "Not enough arguments are provided in print call");
format.append(curISPCFormatIt, ISPCFormat.end());
return format;
}
const std::string &getOrCreateSpecifier(PrintInfo::Encoding type) {
assertEncoding(type);
auto &specifier = getSpecifier(type);
if (specifier.empty())
return createSpecifier(type);
return specifier;
}
private:
static void assertEncoding(PrintInfo::Encoding type) {
Assert(type >= PrintInfo::Encoding::Bool && type <= PrintInfo::Encoding::VecPtr &&
"wrong argument: unsupported type");
}
static void assertEncoding4Uniform(PrintInfo::Encoding type) {
Assert(type >= PrintInfo::Encoding::Bool && type < PrintInfo::Encoding::VecBool &&
"wrong argument: unsupported type");
}
static void assertEncoding4Varying(PrintInfo::Encoding type) {
Assert(type >= PrintInfo::Encoding::VecBool && type <= PrintInfo::Encoding::VecPtr &&
"wrong argument: unsupported type");
}
// helper functor to generate specifier for uniform type
struct FillSpecifier4Uniform {
std::string &str;
FillSpecifier4Uniform(std::string &strIn) : str(strIn) {}
template <typename T> void call() { str = PrintInfo::type2Specifier<T>(); }
};
// tip: don't access specifiers field directly, use this function.
std::string &accessSpecifier(PrintInfo::Encoding type) {
assertEncoding(type);
return specifiers[type - PrintInfo::Bool];
}
const std::string &getSpecifier(PrintInfo::Encoding type) const {
assertEncoding(type);
return const_cast<PrintLZFormatStrBuilder *>(this)->accessSpecifier(type);
}
const std::string &createSpecifier4Uniform(PrintInfo::Encoding type) {
assertEncoding4Uniform(type);
switchEncoding4Uniform(type, FillSpecifier4Uniform{accessSpecifier(type)});
return getSpecifier(type);
}
const std::string &getOrCreateSpecifier4Uniform(PrintInfo::Encoding type) {
assertEncoding4Uniform(type);
auto &specifier = getSpecifier(type);
if (specifier.empty())
return createSpecifier4Uniform(type);
return specifier;
}
const std::string &createSpecifier4Varying(PrintInfo::Encoding type) {
assertEncoding4Varying(type);
const std::string &uniform = getOrCreateSpecifier4Uniform(getCorrespondingEncoding4Uniform(type));
std::stringstream ss;
ss << "[";
for (int i = 0; i < width - 1; ++i)
ss << "%s" << uniform << "%s,";
ss << "%s" << uniform << "%s]";
accessSpecifier(type) = ss.str();
return getSpecifier(type);
}
const std::string &createSpecifier(PrintInfo::Encoding type) {
assertEncoding(type);
if (type < PrintInfo::Encoding::VecBool)
return createSpecifier4Uniform(type);
return createSpecifier4Varying(type);
}
};
// Finds a prefix of a format string whose total weight fits within the limit.
//
// Arguments:
// [\p formatFirst, \p formatLast) - the format string, given as a range.
// [\p typeWeightFirst, ...) - range of weights, one per type to be printed;
// only the beginning of the range is passed, and its length must correspond
// to the number of '%' characters in the format string. A weight represents
// the length of the string with which a '%' will be replaced. In other
// words, every character except '%' has weight 1, and the weight of each
// '%' is taken from this range.
// \p LZPrintFormatLimit - limit on the weight of the resulting string.
//
// Returns a pair of iterators: an iterator into the format string such that
// [\p formatFirst, returned iter) fits within the limit, and an iterator one
// past the last weight element consumed (when no weight info is used,
// \p typeWeightFirst is returned unchanged).
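//
// For example, with LZPrintFormatLimit = 16 (15 usable characters after
// reserving one for '\0'), format "% = %" and specifier weights {12, 12}:
// the first '%' contributes 12 and " = " contributes 3 (total 15, which
// fits), while the second '%' would raise the total to 27, so the returned
// iterators point at the second '%' and at the second, unconsumed weight.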
template <typename FormatIt, typename TypeWeightIt>
std::tuple<FormatIt, TypeWeightIt> splitValidFormat(FormatIt formatFirst, FormatIt formatLast,
TypeWeightIt typeWeightFirst, int LZPrintFormatLimit,
const std::vector<int> argWeights) {
int sum = 0;
// space for '\0'
--LZPrintFormatLimit;
for (; formatFirst != formatLast; ++formatFirst) {
char curCh = *formatFirst;
        // Check that typeWeightFirst can be safely dereferenced here
if ((curCh == '%') && (typeWeightFirst != argWeights.end())) {
sum += *typeWeightFirst;
if (sum <= LZPrintFormatLimit)
++typeWeightFirst;
} else
++sum;
if (sum > LZPrintFormatLimit)
return {formatFirst, typeWeightFirst};
}
return {formatFirst, typeWeightFirst};
}
// Splits the original print into several prints whose format strings fit within the length limit.
static std::vector<PrintSliceInfo> getPrintSlices(const std::string &format, PrintLZFormatStrBuilder &formatBuilder,
const PrintArgsBuilder &args, const int LZPrintFormatLimit,
const SourcePos &pos) {
auto argTypes = args.generateArgTypes();
std::vector<PrintSliceInfo> printSlices;
if (!checkFormatString(format, argTypes.size(), pos)) {
return printSlices;
}
std::vector<int> argWeights(argTypes.size());
std::transform(argTypes.begin(), argTypes.end(), argWeights.begin(), [&formatBuilder](char type) {
return formatBuilder.getOrCreateSpecifier(static_cast<PrintInfo::Encoding>(type)).size();
});
auto firstArgWeight = argWeights.begin();
auto curArgWeight = firstArgWeight;
for (auto curFormat = format.begin(), lastFormat = format.end(); curFormat != lastFormat;) {
auto prevFormat = curFormat;
auto prevArgWeight = curArgWeight;
std::tie(curFormat, curArgWeight) =
splitValidFormat(curFormat, lastFormat, curArgWeight, LZPrintFormatLimit, argWeights);
Assert(curFormat > prevFormat && "haven't managed to split format string");
printSlices.push_back({std::string(prevFormat, curFormat),
args.extract(prevArgWeight - firstArgWeight, curArgWeight - firstArgWeight)});
}
return printSlices;
}
// prepares arguments for __spirv_ocl_printf function
static std::vector<llvm::Value *> getOCLPrintfArgs(const std::string &format, PrintLZFormatStrBuilder &formatBuilder,
const PrintArgsBuilder &args, FunctionEmitContext *ctx,
const SourcePos &pos) {
std::vector<llvm::Value *> allArgs;
auto argTypes = args.generateArgTypes();
allArgs.push_back(ctx->XeCreateConstantString(formatBuilder.get(format, argTypes, pos), "lz_format_str"));
auto valueArgs = args.emitArgCode();
std::move(valueArgs.begin(), valueArgs.end(), std::back_inserter(allArgs));
return allArgs;
}
static PrintArgsBuilder getPrintArgsBuilder(Expr *values, FunctionEmitContext *ctx) {
if (values == nullptr)
return PrintArgsBuilder(ctx);
else {
ExprList *elist = llvm::dyn_cast<ExprList>(values);
if (elist)
return PrintArgsBuilder{elist->exprs.begin(), elist->exprs.end(), ctx};
else
return PrintArgsBuilder{&values, &values + 1, ctx};
}
}
// This name should also be properly mangled. That happens later.
static llvm::FunctionCallee getSPIRVOCLPrintfDecl() {
auto *PrintfTy = llvm::FunctionType::get(LLVMTypes::Int32Type,
llvm::PointerType::get(LLVMTypes::Int8Type, /* const addrspace */ 2),
/* isVarArg */ true);
return m->module->getOrInsertFunction("__spirv_ocl_printf", PrintfTy);
}
static void emitCode4LZPrintSlice(const PrintSliceInfo &printSlice, PrintLZFormatStrBuilder &formatBuilder,
FunctionEmitContext *ctx, const SourcePos &pos) {
auto printImplArgs = getOCLPrintfArgs(printSlice.format_, formatBuilder, printSlice.args_, ctx, pos);
auto printImplFunc = getSPIRVOCLPrintfDecl();
Assert(printImplFunc && "__spirv_ocl_printf declaration wasn't created");
ctx->CallInst(printImplFunc.getCallee(), nullptr, printImplArgs, "");
}
void PrintStmt::emitCode4LZ(FunctionEmitContext *ctx) const {
auto allArgs = getPrintArgsBuilder(values, ctx);
PrintLZFormatStrBuilder formatBuilder(g->target->getVectorWidth());
auto printSlices = getPrintSlices(format, formatBuilder, allArgs, PrintInfo::LZMaxFormatStrSize, pos);
for (const auto &printSlice : printSlices)
emitCode4LZPrintSlice(printSlice, formatBuilder, ctx, pos);
}
#endif // ISPC_XE_ENABLED
/** Given an Expr for a value to be printed, emit the code to evaluate the
    expression and store the result in memory allocated with an alloca.
    Returns an AddressInfo wrapping a void pointer to that storage, or
    nullptr if evaluating the expression fails.
 */
static AddressInfo *lEmitPrintArgCode(Expr *expr, FunctionEmitContext *ctx) {
const Type *type = expr->GetType();
Assert(type);
llvm::Type *llvmExprType = type->LLVMType(g->ctx);
AddressInfo *ptrInfo = ctx->AllocaInst(llvmExprType, "print_arg");
llvm::Value *val = expr->GetValue(ctx);
if (!val)
return nullptr;
ctx->StoreInst(val, ptrInfo);
llvm::Value *ptr = ctx->BitCastInst(ptrInfo->getPointer(), LLVMTypes::VoidPointerType);
return new AddressInfo(ptr, LLVMTypes::VoidPointerType);
}
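// Process a single print() argument: normalize its type, append its encoding
// character to argTypes, evaluate it into alloca'd storage, and store a
// pointer to that storage into slot \p offset of \p argPtrArray. Returns
// false if anything goes wrong.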
static bool lProcessPrintArg(Expr *expr, FunctionEmitContext *ctx, AddressInfo *argPtrArray, int offset,
std::string &argTypes) {
if (!expr)
return false;
auto exprType = lProcessPrintArgType(expr);
expr = exprType.expr;
char type = exprType.type;
if (!expr)
return false;
argTypes.push_back(type);
AddressInfo *ptrInfo = lEmitPrintArgCode(expr, ctx);
if (!ptrInfo)
return false;
llvm::Value *arrayPtr = ctx->AddElementOffset(argPtrArray, offset);
ctx->StoreInst(ptrInfo->getPointer(), new AddressInfo(arrayPtr, ptrInfo->getElementType()));
return true;
}
// prepares arguments for __do_print function
std::vector<llvm::Value *> PrintStmt::getDoPrintArgs(FunctionEmitContext *ctx) const {
std::vector<llvm::Value *> doPrintArgs(STD_NUM_IDX);
std::string argTypes;
if (values == nullptr) {
// Check requested format
checkFormatString(format, 0, pos);
llvm::Type *ptrPtrType = llvm::PointerType::get(LLVMTypes::VoidPointerType, 0);
doPrintArgs[ARGS_IDX] = llvm::Constant::getNullValue(ptrPtrType);
} else {
// Get the values passed to the print() statement evaluated and
// stored in memory so that we set up the array of pointers to them
// for the 5th __do_print() argument
ExprList *elist = llvm::dyn_cast<ExprList>(values);
int nArgs = elist ? elist->exprs.size() : 1;
// Check requested format
checkFormatString(format, nArgs, pos);
// Allocate space for the array of pointers to values to be printed
llvm::Type *argPtrArrayType = llvm::ArrayType::get(LLVMTypes::VoidPointerType, nArgs);
AddressInfo *argPtrArrayInfo = ctx->AllocaInst(argPtrArrayType, "print_arg_ptrs");
// Store the array pointer as a void **, which is what __do_print()
// expects
doPrintArgs[ARGS_IDX] =
ctx->BitCastInst(argPtrArrayInfo->getPointer(), llvm::PointerType::get(LLVMTypes::VoidPointerType, 0));
// Now, for each of the arguments, emit code to evaluate its value
// and store the value into alloca's storage. Then store the
// pointer to the alloca's storage into argPtrArrayInfo.
if (elist) {
for (unsigned int i = 0; i < elist->exprs.size(); ++i) {
Expr *expr = elist->exprs[i];
if (!lProcessPrintArg(expr, ctx, argPtrArrayInfo, i, argTypes)) {
return {};
}
}
} else {
            if (!lProcessPrintArg(values, ctx, argPtrArrayInfo, 0, argTypes)) {
return {};
}
}
}
llvm::Value *mask = ctx->GetFullMask();
// Set up the rest of the parameters to it
doPrintArgs[FORMAT_IDX] = ctx->GetStringPtr(format);
doPrintArgs[TYPES_IDX] = ctx->GetStringPtr(argTypes);
doPrintArgs[WIDTH_IDX] = LLVMInt32(g->target->getVectorWidth());
doPrintArgs[MASK_IDX] = ctx->LaneMask(mask);
return doPrintArgs;
}
/* PrintStmt works closely with the __do_print() function implemented in
the builtins-c-cpu.cpp file. In particular, the EmitCode() method here needs to
take the arguments passed to it from ispc and generate a valid call to
__do_print() with the information that __do_print() then needs to do the
actual printing work at runtime.
*/
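/* For example, print("x = %\n", x) with a varying int32 x lowers to a single
   __do_print() call whose arguments are the format string, the types string
   of per-argument encodings produced by lEncodeType(), the target's vector
   width, the current lane mask, and an array holding a pointer to the stored
   value of x (see getDoPrintArgs() above for the exact argument order).
*/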
void PrintStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
ctx->SetDebugPos(pos);
#ifdef ISPC_XE_ENABLED
if (g->target->isXeTarget()) {
emitCode4LZ(ctx);
return;
}
#endif /* ISPC_XE_ENABLED */
auto printImplArgs = getDoPrintArgs(ctx);
if (printImplArgs.empty()) {
AssertPos(pos, m->errorCount > 0);
return;
}
auto printImplFunc = getPrintImplFunc();
AssertPos(pos, printImplFunc);
ctx->CallInst(printImplFunc, nullptr, printImplArgs, "");
}
void PrintStmt::Print(Indent &indent) const {
indent.Print("PrintStmt", pos);
printf("Format string: \"%s\"\n", format.c_str());
indent.pushSingle();
indent.setNextLabel("args");
if (values) {
values->Print(indent);
} else {
indent.Print("<NULL / NO ARGS>\n");
indent.Done();
}
indent.Done();
}
Stmt *PrintStmt::TypeCheck() { return this; }
int PrintStmt::EstimateCost() const { return COST_FUNCALL; }
PrintStmt *PrintStmt::Instantiate(TemplateInstantiation &templInst) const {
Expr *instValues = values ? values->Instantiate(templInst) : nullptr;
return new PrintStmt(format, instValues, pos);
}
///////////////////////////////////////////////////////////////////////////
// AssertStmt
AssertStmt::AssertStmt(const std::string &msg, Expr *e, SourcePos p) : Stmt(p, AssertStmtID), message(msg), expr(e) {}
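// Emits the code for a failing assert check. For example, assert(x > 0) with
// a varying condition lowers to a call to the __do_assert_varying() builtin,
// passing a "<file>:<line>:<column>: Assertion failed: <message>" string, the
// condition value, and the current execution mask; a uniform condition uses
// __do_assert_uniform() instead.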
void AssertStmt::EmitAssertCode(FunctionEmitContext *ctx, const Type *type) const {
bool isUniform = type->IsUniformType();
// The actual functionality to do the check and then handle failure is
// done via a builtin written in bitcode in builtins/util.m4.
llvm::Function *assertFunc =
isUniform ? m->module->getFunction("__do_assert_uniform") : m->module->getFunction("__do_assert_varying");
AssertPos(pos, assertFunc != nullptr);
char *errorString;
if (asprintf(&errorString, "%s:%d:%d: Assertion failed: %s \n", pos.name, pos.first_line, pos.first_column,
message.c_str()) == -1) {
Error(pos, "Fatal error when generating assert string: asprintf() "
"unable to allocate memory!");
return;
}
std::vector<llvm::Value *> args;
#ifdef ISPC_XE_ENABLED
if (g->target->isXeTarget()) {
PrintLZFormatStrBuilder formatBuilder(g->target->getVectorWidth());
args.push_back(ctx->XeCreateConstantString(errorString, "lz_format_str"));
} else
#endif
args.push_back(ctx->GetStringPtr(errorString));
llvm::Value *exprValue = expr->GetValue(ctx);
if (exprValue == nullptr) {
free(errorString);
AssertPos(pos, m->errorCount > 0);
return;
}
args.push_back(exprValue);
#ifdef ISPC_XE_ENABLED
if (ctx->emitXeHardwareMask())
        // This will create a mask according to the current EM in SIMD CF lowering.
        // The result is effectively mask = select(EM, AllOn, AllFalse).
args.push_back(ctx->XeSimdCFPredicate(LLVMMaskAllOn));
else
#endif
args.push_back(ctx->GetFullMask());
ctx->CallInst(assertFunc, nullptr, args, "");
free(errorString);
}
void AssertStmt::EmitAssumeCode(FunctionEmitContext *ctx, const Type *type) const {
bool isUniform = type->IsUniformType();
// Currently, we insert an assume only for uniform conditions.
if (!isUniform) {
return;
}
// The actual functionality to insert an 'llvm.assume' intrinsic is
// done via a builtin written in bitcode in builtins/util.m4.
llvm::Function *assumeFunc = m->module->getFunction("__do_assume_uniform");
AssertPos(pos, assumeFunc != nullptr);
llvm::Value *exprValue = expr->GetValue(ctx);
if (exprValue == nullptr) {
AssertPos(pos, m->errorCount > 0);
return;
}
ctx->CallInst(assumeFunc, nullptr, exprValue, "");
}
void AssertStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
const Type *type;
if (expr == nullptr || (type = expr->GetType()) == nullptr) {
AssertPos(pos, m->errorCount > 0);
return;
}
if (g->opt.disableAsserts) {
EmitAssumeCode(ctx, type);
} else {
EmitAssertCode(ctx, type);
}
}
void AssertStmt::Print(Indent &indent) const {
indent.Print("AssertStmt", pos);
printf("Message: %s\n", message.c_str());
indent.pushSingle();
if (expr) {
expr->Print(indent);
    } else {
        indent.Print("<NULL>\n");
        indent.Done();
    }
indent.Done();
}
Stmt *AssertStmt::TypeCheck() {
const Type *type;
if (expr && (type = expr->GetType()) != nullptr) {
if (type->IsDependentType()) {
return this;
}
bool isUniform = type->IsUniformType();
expr = TypeConvertExpr(expr, isUniform ? AtomicType::UniformBool : AtomicType::VaryingBool,
"\"assert\" statement");
if (expr == nullptr)
return nullptr;
}
return this;
}
int AssertStmt::EstimateCost() const { return COST_ASSERT; }
AssertStmt *AssertStmt::Instantiate(TemplateInstantiation &templInst) const {
Expr *instExpr = expr ? expr->Instantiate(templInst) : nullptr;
return new AssertStmt(message, instExpr, pos);
}
///////////////////////////////////////////////////////////////////////////
// DeleteStmt
DeleteStmt::DeleteStmt(Expr *e, SourcePos p) : Stmt(p, DeleteStmtID) { expr = e; }
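// Lowers "delete ptr" to the matching runtime builtin: a uniform pointer is
// cast to void * and passed to __delete_uniform_32rt/__delete_uniform_64rt,
// while a varying pointer is passed to __delete_varying_32rt/
// __delete_varying_64rt (zero-extended to a vector of i64 on 32-bit targets).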
void DeleteStmt::EmitCode(FunctionEmitContext *ctx) const {
if (g->target->isXeTarget()) {
Error(pos, "\"delete\" statement is not supported for Xe targets yet.");
return;
}
if (!ctx->GetCurrentBasicBlock())
return;
const Type *exprType;
if (expr == nullptr || ((exprType = expr->GetType()) == nullptr)) {
AssertPos(pos, m->errorCount > 0);
return;
}
llvm::Value *exprValue = expr->GetValue(ctx);
if (exprValue == nullptr) {
AssertPos(pos, m->errorCount > 0);
return;
}
// Typechecking should catch this
AssertPos(pos, CastType<PointerType>(exprType) != nullptr);
if (exprType->IsUniformType()) {
// For deletion of a uniform pointer, we just need to cast the
// pointer type to a void pointer type, to match what
// __delete_uniform() from the builtins expects.
exprValue = ctx->BitCastInst(exprValue, LLVMTypes::VoidPointerType, "ptr_to_void");
llvm::Function *func;
if (g->target->is32Bit()) {
func = m->module->getFunction("__delete_uniform_32rt");
} else {
func = m->module->getFunction("__delete_uniform_64rt");
}
AssertPos(pos, func != nullptr);
ctx->CallInst(func, nullptr, exprValue, "");
} else {
// Varying pointers are arrays of ints, and __delete_varying()
// takes a vector of i64s (even for 32-bit targets). Therefore, we
// only need to extend to 64-bit values on 32-bit targets before
// calling it.
llvm::Function *func;
if (g->target->is32Bit()) {
func = m->module->getFunction("__delete_varying_32rt");
} else {
func = m->module->getFunction("__delete_varying_64rt");
}
AssertPos(pos, func != nullptr);
if (g->target->is32Bit())
exprValue = ctx->ZExtInst(exprValue, LLVMTypes::Int64VectorType, "ptr_to_64");
ctx->CallInst(func, nullptr, exprValue, "");
}
}
void DeleteStmt::Print(Indent &indent) const {
indent.PrintLn("DeleteStmt", pos);
indent.pushSingle();
if (expr) {
expr->Print(indent);
} else {
indent.Print("<NULL>\n");
indent.Done();
}
indent.Done();
}
Stmt *DeleteStmt::TypeCheck() {
const Type *exprType;
if (expr == nullptr || ((exprType = expr->GetType()) == nullptr))
return nullptr;
if (exprType->IsDependentType()) {
return this;
}
if (CastType<PointerType>(exprType) == nullptr) {
Error(pos, "Illegal to delete non-pointer type \"%s\".", exprType->GetString().c_str());
return nullptr;
}
return this;
}
int DeleteStmt::EstimateCost() const { return COST_DELETE; }
DeleteStmt *DeleteStmt::Instantiate(TemplateInstantiation &templInst) const {
Expr *instExpr = expr->Instantiate(templInst);
return new DeleteStmt(instExpr, pos);
}