// Copyright Epic Games, Inc. All Rights Reserved. #include "NNEHlslShadersConvCS.h" #include "NNE.h" #include "RHIGlobals.h" namespace UE::NNEHlslShaders::Internal { namespace ConvUtils { TArray GetXBlockShape(TArrayView GroupShape, TArrayView WShape, TArrayView Dilations, TArrayView Strides) { check(WShape.Num() > 2); check(GroupShape.Num() == WShape.Num() - 2); check(Dilations.Num() == 0 || Dilations.Num() == GroupShape.Num()); check(Strides.Num() == 0 || Strides.Num() == GroupShape.Num()); TArray Result; Result.SetNumUninitialized(GroupShape.Num()); for (int32 i = 0; i < GroupShape.Num(); i++) { int32 DilatedKernelSize = (i < Dilations.Num() ? Dilations[i] : 1) * (WShape[2 + i] - 1) + 1; Result[i] = DilatedKernelSize + (GroupShape[i] - 1) * (i < Strides.Num() ? Strides[i] : 1); } return Result; } int32 GetNumThreadsPerGroup(EConvGroupSize GroupSize) { int32 NumThreadsPerGroup = 128; switch (GroupSize) { case EConvGroupSize::Size128: NumThreadsPerGroup = 128; break; case EConvGroupSize::Size256: NumThreadsPerGroup = 256; break; case EConvGroupSize::Size512: NumThreadsPerGroup = 512; break; default: check(false); break; } check(FMath::Log2((float)NumThreadsPerGroup) == FMath::Floor(FMath::Log2((float)NumThreadsPerGroup))); return NumThreadsPerGroup; } TArray GetGridShape(TArrayView YShape, TArrayView GroupShape) { check(YShape.Num() > 2); check(YShape.Num() == (GroupShape.Num() + 2)); TArray Result; Result.SetNumUninitialized(YShape.Num() - 2); for (int32 i = 2; i < YShape.Num(); i++) { Result[i-2] = FMath::DivideAndRoundUp(YShape[i], GroupShape[i-2]); } return Result; } uint32 GetTotalSharedMemoryUsed(uint32 ThreadGroupSize, uint32 Log2NumReadsPerThread) { return ThreadGroupSize * (1 + (1 << Log2NumReadsPerThread)) * sizeof(float); } } // namespace ConvUtils bool FConvCS::ShouldCompilePermutation(const FGlobalShaderPermutationParameters& InParameters) { if (!FHlslShaderBase::ShouldCompilePermutation(InParameters)) { return false; } FPermutationDomain PermutationVector(InParameters.PermutationId); int Log2NumReadsPerThread = PermutationVector.Get(); int32 ConvGroupSize = ConvUtils::GetNumThreadsPerGroup(PermutationVector.Get()); if(ConvUtils::GetTotalSharedMemoryUsed(ConvGroupSize, Log2NumReadsPerThread) > GetMaxComputeSharedMemory()) { return false; } return true; } void FConvCS::ModifyCompilationEnvironment(const FGlobalShaderPermutationParameters& InParameters, FShaderCompilerEnvironment& OutEnvironment) { FGlobalShader::ModifyCompilationEnvironment(InParameters, OutEnvironment); OutEnvironment.SetDefine(TEXT("MAX_NUM_DIMENSIONS"), FConvConstants::MAX_NUM_DIMENSIONS); } // Returns the biggest thread group size that allows to keep in shared memory the portion of the input tensor and the filter necessary to compute the portion of output tended by it (the thread group). // If no suitable thread group size was found, EConvGroupSize::MAX is returned. EConvGroupSize FConvCS::GetBiggestCompatibleGroupSize(TArrayView WShape, TArrayView Dilations, TArrayView Strides) { check(WShape.Num() > 2); check(Dilations.Num() == 0 || Dilations.Num() == WShape.Num() - 2); check(Strides.Num() == 0 || Strides.Num() == WShape.Num() - 2); for(EConvGroupSize TargetGroupSize = (EConvGroupSize)((uint8)EConvGroupSize::MAX - 1); (uint8)TargetGroupSize > 0; TargetGroupSize = (EConvGroupSize)((uint8)TargetGroupSize - 1)) { const int32 NumDimensions = WShape.Num() - 2; const TArray GroupShape = GetGroupShape(TargetGroupSize, NumDimensions); TArray TargetXBlockShape = ConvUtils::GetXBlockShape(GroupShape, WShape, Dilations, Strides); int32 XWindowSize = 1; for (int32 Idx = 0; Idx < NumDimensions; Idx++) { XWindowSize *= TargetXBlockShape[Idx]; } int32 NumReadsPerThread = FMath::DivideAndRoundUp(XWindowSize, ConvUtils::GetNumThreadsPerGroup(TargetGroupSize)); int32 Log2NumReadsPerThread = FMath::Max(FMath::RoundToPositiveInfinity(FMath::Log2((float)NumReadsPerThread)), FConvConstants::MIN_NUM_READS_PER_THREAD_POW2); uint32 TotSharedMemoryUsed = ConvUtils::GetTotalSharedMemoryUsed(ConvUtils::GetNumThreadsPerGroup(TargetGroupSize), Log2NumReadsPerThread); if(TotSharedMemoryUsed <= GetMaxComputeSharedMemory() && Log2NumReadsPerThread <= FConvConstants::MAX_NUM_READS_PER_THREAD_POW2) { return TargetGroupSize; } } return EConvGroupSize::MAX; } TArray FConvCS::GetPadding(TArrayView XShape, TArrayView WShape, EConvAutoPad AutoPad, TArrayView Dilations, TArrayView Strides, TArrayView Pads) { check(XShape.Num() > 2); check(WShape.Num() == XShape.Num()); check(Dilations.Num() == 0 || Dilations.Num() == WShape.Num() - 2); check(Strides.Num() == 0 || Strides.Num() == WShape.Num() - 2); check(AutoPad != EConvAutoPad::NOTSET || Pads.Num() == 2 * (WShape.Num() - 2)); int32 NumDimensions = XShape.Num() - 2; TArray Result; Result.Init(0, 2 * NumDimensions); if (AutoPad == EConvAutoPad::NOTSET) { return TArray{Pads}; } else if (AutoPad == EConvAutoPad::VALID) { return Result; } for (int32 DimensionIndex = 0; DimensionIndex < NumDimensions; DimensionIndex++) { int32 DilatedKernelSize = (DimensionIndex < Dilations.Num() ? Dilations[DimensionIndex] : 1) * (WShape[DimensionIndex + 2] - 1) + 1; int32 LastOutputIdx = ((int32) XShape[DimensionIndex + 2] + (DimensionIndex < Strides.Num() ? Strides[DimensionIndex] : 1) - 1) / (DimensionIndex < Strides.Num() ? Strides[DimensionIndex] : 1) - 1; int32 TotalPad = LastOutputIdx * (DimensionIndex < Strides.Num() ? Strides[DimensionIndex] : 1) + DilatedKernelSize - XShape[DimensionIndex + 2]; TotalPad = TotalPad >= 0 ? TotalPad : 0; if (AutoPad == EConvAutoPad::SAME_LOWER) { Result[DimensionIndex] = (TotalPad + 1) / 2; } else { Result[DimensionIndex] = TotalPad / 2; } Result[NumDimensions + DimensionIndex] = TotalPad - Result[DimensionIndex]; } return Result; } TArray FConvCS::GetOutputShape(TArrayView XShape, TArrayView WShape, EConvAutoPad AutoPad, TArrayView Dilations, TArrayView Strides, TArrayView Pads) { check(XShape.Num() > 2); check(WShape.Num() == XShape.Num()); check(Dilations.Num() == 0 || Dilations.Num() == WShape.Num() - 2); check(Strides.Num() == 0 || Strides.Num() == WShape.Num() - 2); check(Pads.Num() == 0 || Pads.Num() == 2 * (WShape.Num() - 2)); TArray Padding = GetPadding(XShape, WShape, AutoPad, Dilations, Strides, Pads); TArray Result; Result.SetNumUninitialized(XShape.Num()); Result[0] = XShape[0]; Result[1] = WShape[0]; int32 NumDimensions = XShape.Num() - 2; for (int32 DimensionIndex = 0; DimensionIndex < NumDimensions; DimensionIndex++) { int32 PaddedSize = XShape[2 + DimensionIndex] + Padding[DimensionIndex] + Padding[NumDimensions + DimensionIndex]; int32 DilatedKernelSize = (DimensionIndex < Dilations.Num() ? Dilations[DimensionIndex] : 1) * (WShape[2 + DimensionIndex] - 1) + 1; float PaddeSizeMinusDilatedKernelSize = (float)(PaddedSize - DilatedKernelSize); int32 OutputSize = (int32)(PaddeSizeMinusDilatedKernelSize / (DimensionIndex < Strides.Num() ? (float)Strides[DimensionIndex] : 1.0) + 1); Result[2 + DimensionIndex] = OutputSize; } return Result; } void FConvCS::FillInParameters(EConvGroupSize GroupSize, TArrayView XShape, TArrayView WShape, bool HasB, EConvAutoPad AutoPad, int Group, TArrayView Dilations, TArrayView Strides, TArrayView Pads, FConvCS::FParameters& Parameters) { check(XShape.Num() > 2); check(WShape.Num() == XShape.Num()); check(Dilations.Num() == 0 || Dilations.Num() == WShape.Num() - 2); check(Strides.Num() == 0 || Strides.Num() == WShape.Num() - 2); check(Pads.Num() == 0 || Pads.Num() == 2 * (WShape.Num() - 2)); check(GroupSize != EConvGroupSize::MAX); check(WShape[0] > 0) check(WShape[1] > 0) int32 NumDimensions = XShape.Num() - 2; TArray Padding = GetPadding(XShape, WShape, AutoPad, Dilations, Strides, Pads); TArray GroupShape = GetGroupShape(GroupSize, NumDimensions); TArray YShape = GetOutputShape(XShape, WShape, AutoPad, Dilations, Strides, Pads); TArray GridShape = ConvUtils::GetGridShape(YShape, GroupShape); TArray XBlockShape = ConvUtils::GetXBlockShape(GroupShape, WShape, Dilations, Strides); int32 GroupStride = 1; int32 GroupThreadStride = 1; int32 XBlockSize = 1; int32 YMemoryStride = 1; int32 XMemoryStride = 1; int32 WChannelSize = 1; for (int32 i = NumDimensions-1; i >= 0; i--) { int32 Stride = i < Strides.Num() ? Strides[i] : 1; int32 Dilation = i < Dilations.Num() ? Dilations[i] : 1; Parameters.Dilation_Stride_XBlockStartOffset_DilationXBlockStride[i] = FIntVector4(Dilation, Stride, -Padding[i], Dilation * XBlockSize); Parameters.GroupStride_GroupShape_GroupThreadStride_StrideXBlockStride[i] = FIntVector4(GroupStride, GroupShape[i], GroupThreadStride, Stride * XBlockSize); Parameters.YDimension_YMemoryStride_XDimension_XMemoryStride[i] = FIntVector4(YShape[2 + i], YMemoryStride, XShape[2 + i], XMemoryStride); Parameters.XBlockStartStride_XBlockStride_WDimension_WDimensionDilationXBlockStride[i] = FIntVector4(GroupShape[i] * Stride, XBlockSize, WShape[2 + i], WShape[2 + i] * Dilation * XBlockSize); Parameters.OneDiv_GroupStride_GroupThreadStride_XBlockStride[i] = FVector4f(1.0/((float)GroupStride), 1.0/((float)GroupThreadStride), 1.0/((float)XBlockSize), 0.0); GroupStride *= GridShape[i]; GroupThreadStride *= GroupShape[i]; XBlockSize *= XBlockShape[i]; YMemoryStride *= YShape[2 + i]; XMemoryStride *= XShape[2 + i]; WChannelSize *= WShape[2 + i]; } Parameters.NumWFeatures = WShape[0]; Parameters.NumWChannels = WShape[1]; Parameters.YBatchStride = YShape[1] * YMemoryStride; Parameters.YOutputKernelStride = YMemoryStride; Parameters.XBatchStride = XShape[1] * XMemoryStride; Parameters.XChannelStride = XMemoryStride; Parameters.XBlockSize = XBlockSize; Parameters.NumChannelsPerBatch = FMath::Min((int32)((float)GroupThreadStride / (float)WChannelSize), (int32)WShape[1]); check(Parameters.NumChannelsPerBatch > 0) Parameters.NumChannelBatches = FMath::DivideAndRoundUp((int32)WShape[1], Parameters.NumChannelsPerBatch); Parameters.WOutputKernelStride = WShape[1] * WChannelSize; Parameters.WChannelBatchSize = Parameters.NumChannelsPerBatch * WChannelSize; Parameters.WChannelSize = WChannelSize; Parameters.GroupsDivM = (float)Group / (float)WShape[0]; } int32 FConvCS::GetNumReadsPerThread(EConvGroupSize GroupSize, TArrayView WShape, TArrayView Dilations, TArrayView Strides) { check(WShape.Num() > 2); check(Dilations.Num() == 0 || Dilations.Num() == WShape.Num() - 2); check(Strides.Num() == 0 || Strides.Num() == WShape.Num() - 2); int32 NumDimensions = WShape.Num() - 2; TArray GroupShape = GetGroupShape(GroupSize, NumDimensions); int32 NumThreadsPerGroup = 1; for (int32 i = 0; i < NumDimensions; i++) { NumThreadsPerGroup *= GroupShape[i]; } TArray XBlockShape = ConvUtils::GetXBlockShape(GroupShape, WShape, Dilations, Strides); int32 NumXBlockElements = 1; for (int32 i = 0; i < NumDimensions; i++) { NumXBlockElements *= XBlockShape[i]; } int32 NumReads = FMath::DivideAndRoundUp(NumXBlockElements, NumThreadsPerGroup); int32 NumReadsPow2 = FMath::Max(FMath::RoundToPositiveInfinity(FMath::Log2((float)NumReads)), FConvConstants::MIN_NUM_READS_PER_THREAD_POW2); if(NumReadsPow2 <= FConvConstants::MAX_NUM_READS_PER_THREAD_POW2) { return NumReadsPow2; } return -1; } TArray FConvCS::GetGroupShape(EConvGroupSize GroupSize, int32 NumDimensions) { check(NumDimensions > 0); int32 NumThreadsPerGroup = ConvUtils::GetNumThreadsPerGroup(GroupSize); int32 Power = (int32)FMath::Log2((float)NumThreadsPerGroup); int32 MinPowerPerDim = (int32)((float)Power / (float)NumDimensions); int32 PowerReminder = Power - NumDimensions * MinPowerPerDim; TArray Result; Result.Init((int32)FMath::Pow((float)2.0, (float)MinPowerPerDim), NumDimensions); for (int32 i = 0; i < PowerReminder; i++) { Result[NumDimensions-1-i] *= 2; } return Result; } FIntVector FConvCS::GetGroupCount(TArrayView YShape, TArrayView GroupShape) { check(YShape.Num() > 2); check(YShape.Num() == (GroupShape.Num() + 2)); int32 ThreadGroupCountValueX = 1; for (int32 i = 2; i < YShape.Num(); i++) { ThreadGroupCountValueX *= FMath::DivideAndRoundUp(YShape[i], GroupShape[i-2]); } return FIntVector(ThreadGroupCountValueX, YShape[1], YShape[0]); } EConvGroupSize FConvCS::GetMinimalGroupSize(TArrayView WShape) { const int32 NumDimensions = WShape.Num() - 2; int32 WChannelSize = 1; for (int32 i = 0; i < NumDimensions; i++) { WChannelSize *= WShape[2 + i]; } for (int32 i = 0; i < (int32)EConvGroupSize::MAX; i++) { if (ConvUtils::GetNumThreadsPerGroup((EConvGroupSize)i) >= WChannelSize) { return (EConvGroupSize)i; } } return EConvGroupSize::MAX; } void FConvCS::LexFromString(EConvAutoPad& OutValue, const TCHAR* StringVal) { OutValue = EConvAutoPad::NOTSET; if (FCString::Stricmp(StringVal, TEXT("NOTSET")) == 0) OutValue = EConvAutoPad::NOTSET; else if (FCString::Stricmp(StringVal, TEXT("SAME_UPPER")) == 0) OutValue = EConvAutoPad::SAME_UPPER; else if (FCString::Stricmp(StringVal, TEXT("SAME_LOWER")) == 0) OutValue = EConvAutoPad::SAME_LOWER; else if (FCString::Stricmp(StringVal, TEXT("VALID")) == 0) OutValue = EConvAutoPad::VALID; } IMPLEMENT_GLOBAL_SHADER(FConvCS, "/NNEHlslShaders/NNEHlslShadersConv.usf", "Conv", SF_Compute); } // UE::NNEHlslShaders::Internal