Files
UnrealEngine/Engine/Plugins/Experimental/NNERuntimeRDG/Source/NNEHlslShaders/Private/NNEHlslShadersElementWiseVariadicCS.cpp
Brandyn / Techy fcc1b09210 init
2026-04-04 15:40:51 -05:00

49 lines
1.6 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#include "NNEHlslShadersElementWiseVariadicCS.h"
#include "NNEHlslShadersLog.h"
namespace UE::NNEHlslShaders::Internal
{
void TElementWiseVariadicCS::ModifyCompilationEnvironment(const FGlobalShaderPermutationParameters& InParameters, FShaderCompilerEnvironment& OutEnvironment)
{
FGlobalShader::ModifyCompilationEnvironment(InParameters, OutEnvironment);
OutEnvironment.SetDefine(TEXT("THREADGROUP_SIZE_X"), FElementWiseVariadicConstants::NUM_GROUP_THREADS);
FPermutationDomain PermutationVector(InParameters.PermutationId);
const FString OpFunc = GetOpFunc(PermutationVector.Get<FOperatorType>());
OutEnvironment.SetDefine(TEXT("ELEMENTWISE_OP(X,Y)"), *OpFunc);
}
/**
 * Returns the HLSL expression, written in terms of the macro arguments X and Y,
 * that implements the given element-wise variadic operator.
 *
 * @param OpType Operator to look up.
 * @return The expression text, or an empty string (after logging a warning) when
 *         OpType has no expression associated with it, including MAX and any
 *         out-of-range value.
 */
const FString TElementWiseVariadicCS::GetOpFunc(EElementWiseVariadicOperatorType OpType)
{
	// A switch avoids indexing a MAX-sized table with an unchecked enum value
	// (which read out of bounds for OpType >= MAX) and avoids rebuilding an
	// array of strings on every call.
	switch (OpType)
	{
	case EElementWiseVariadicOperatorType::Max:
		return FString(TEXT("max(X,Y)"));

	case EElementWiseVariadicOperatorType::Min:
		return FString(TEXT("min(X,Y)"));

	// Mean deliberately shares the pairwise addition used by Sum.
	// NOTE(review): presumably the division by the input count happens outside
	// this macro (e.g. in the shader or its parameters) - confirm.
	case EElementWiseVariadicOperatorType::Mean:
	case EElementWiseVariadicOperatorType::Sum:
		return FString(TEXT("((X)+(Y))"));

	default:
		break;
	}

	UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("Undefined ElementWise Variadic operator name for operator:%d"), static_cast<int32>(OpType));
	return FString();
}
// Associates TElementWiseVariadicCS with its shader source file and the
// "ElementWiseVariadic" entry point, using the SF_Compute shader frequency.
IMPLEMENT_GLOBAL_SHADER(TElementWiseVariadicCS, "/NNEHlslShaders/NNEHlslShadersElementWiseVariadic.usf", "ElementWiseVariadic", SF_Compute);
} // UE::NNEHlslShaders::Internal