UnrealEngine/Engine/Plugins/Experimental/NNERuntimeRDG/Source/NNERuntimeRDGUtils/Private/NNERuntimeRDGUtilsModelOptimizerONNX.cpp

// Copyright Epic Games, Inc. All Rights Reserved.
#include "NNERuntimeRDGUtilsModelOptimizerONNX.h"
#include "HAL/IConsoleManager.h"
#include "HAL/PlatformFileManager.h"
#include "HAL/PlatformTime.h"
#include "Misc/FileHelper.h"
#include "Misc/Paths.h"
#include "NNEHlslShadersLog.h"
#include "NNEOnnxruntimeEditor.h"
THIRD_PARTY_INCLUDES_START
#include <onnx/onnx_pb.h>
#include <onnx/shape_inference/implementation.h>
THIRD_PARTY_INCLUDES_END
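
// This file implements three ONNX-to-ONNX optimizer passes:
// - FOnnxRuntimeModelOptimizerPass: runs ONNX Runtime's graph optimizations by round-tripping the model through disk.
// - FOnnxDomainCleanupModelOptimizerPass: strips opset imports whose domains no node references.
// - FOnnxShapeInferenceModelOptimizerPass: runs ONNX shape inference to populate value_info.
// FModelOptimizerONNXToONNX (at the bottom of the file) chains them in that order.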
namespace UE::NNERuntimeRDGUtils::Private
{

class FOnnxRuntimeModelOptimizerPass : public Internal::IModelOptimizerPass
{
public:
	FOnnxRuntimeModelOptimizerPass()
	{
	}

	virtual FString GetName() const override
	{
		return TEXT("Onnx runtime model optimization");
	}

	virtual bool ApplyPass(TArray<uint8>& ModelData) const override
	{
		onnx::ModelProto ModelProto;
		const bool bParsed = ModelProto.ParseFromArray(ModelData.GetData(), ModelData.Num());
		if (!bParsed)
		{
			UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("%s could not parse the input model as a ModelProto."), *GetName());
			return false;
		}
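
		// The nne.hlsl.ModelOptimization console variable gates this pass: a value of 0 skips
		// ONNX Runtime optimization while still returning success so the remaining passes run.
		// For example, from the editor console: nne.hlsl.ModelOptimization 0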
		static const auto CVarHlslModelOptimization = IConsoleManager::Get().FindTConsoleVariableDataInt(TEXT("nne.hlsl.ModelOptimization"));
		if (CVarHlslModelOptimization && CVarHlslModelOptimization->GetValueOnAnyThread() == 0)
		{
			return true;
		}

		// We disable ONNX Runtime optimizations if a model uses FP16 tensors,
		// since these optimizations would add cast operators from and to FP16
		// at the beginning and end of the network and convert all other operators to FP32.
		if (HasFP16Tensor(ModelProto))
		{
			return true;
		}
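
		// ONNX Runtime exposes its offline optimizer through session creation, so the pass
		// round-trips the model through temp files in the project's intermediate directory.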
		FString ProjIntermediateDir = FPaths::ConvertRelativePathToFull(FPaths::ProjectIntermediateDir());
		FString ModelToOptimizePath = FPaths::CreateTempFilename(*ProjIntermediateDir, TEXT("ORTOptimizerPass_ToOptimize"), TEXT(".onnx"));
		FString ModelOptimizedPath = FPaths::CreateTempFilename(*ProjIntermediateDir, TEXT("ORTOptimizerPass_Optimized"), TEXT(".onnx"));

		// See https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html
		// We only enable the full set of optimizations when converting to the ORT format itself for
		// the CPU provider, so this pass sticks to ORT_ENABLE_BASIC.
		GraphOptimizationLevel OptimizationLevel = ORT_ENABLE_BASIC;

		FFileHelper::SaveArrayToFile(ModelData, *ModelToOptimizePath);
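
		// The session below is created solely for its side effect of writing the optimized model to
		// ModelOptimizedPath; the scope ensures the session is destroyed before the file is read back.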
		{
			Ort::ThreadingOptions ThreadingOptions;
			ThreadingOptions.SetGlobalIntraOpNumThreads(1);
			ThreadingOptions.SetGlobalInterOpNumThreads(1);

			Ort::Env Env(ThreadingOptions);
			Ort::SessionOptions SessionOptions;
			SessionOptions.DisablePerSessionThreads();
			SessionOptions.SetGraphOptimizationLevel(OptimizationLevel);
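
			// Ort::Session takes ORTCHAR_T paths: wchar_t on Windows, where TCHAR can be passed
			// directly, and char elsewhere, hence the TCHAR_TO_ANSI conversions below.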
#if PLATFORM_WINDOWS
			SessionOptions.SetOptimizedModelFilePath(*ModelOptimizedPath);
			Ort::Session Session(Env, *ModelToOptimizePath, SessionOptions);
#else
			SessionOptions.SetOptimizedModelFilePath(TCHAR_TO_ANSI(*ModelOptimizedPath));
			Ort::Session Session(Env, TCHAR_TO_ANSI(*ModelToOptimizePath), SessionOptions);
#endif
		}

		const bool bLoaded = FFileHelper::LoadFileToArray(ModelData, *ModelOptimizedPath);
		IFileManager::Get().Delete(*ModelToOptimizePath);
		IFileManager::Get().Delete(*ModelOptimizedPath);
		if (!bLoaded)
		{
			UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("%s could not load the optimized model back from disk."), *GetName());
			return false;
		}

		return true;
	}

private:
	bool HasFP16Tensor(const onnx::ModelProto& Model) const
	{
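		// A model counts as FP16 if any initializer, graph input, or graph output has a
		// FLOAT16 element type; intermediate tensors inside the graph are not inspected.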
		const onnx::GraphProto& Graph = Model.graph();

		for (const onnx::TensorProto& Tensor : Graph.initializer())
		{
			if (Tensor.data_type() == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
			{
				return true;
			}
		}

		for (const onnx::ValueInfoList& Tensors : {Graph.input(), Graph.output()})
		{
			for (const onnx::ValueInfoProto& Tensor : Tensors)
			{
				if (!Tensor.has_type())
				{
					continue;
				}
				const onnx::TypeProto& Type = Tensor.type();
				if (!Type.has_tensor_type())
				{
					continue;
				}
				const onnx::TypeProto_Tensor& TensorType = Type.tensor_type();
				if (!TensorType.has_elem_type())
				{
					continue;
				}
				if (TensorType.elem_type() == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
				{
					return true;
				}
			}
		}

		return false;
	}
};

class FOnnxDomainCleanupModelOptimizerPass : public Internal::IModelOptimizerPass
{
public:
	virtual FString GetName() const override
	{
		return TEXT("Onnx domain cleanup");
	}

	virtual bool ApplyPass(TArray<uint8>& ModelData) const override
	{
		onnx::ModelProto ModelProto;
		const bool bParsed = ModelProto.ParseFromArray(ModelData.GetData(), ModelData.Num());
		if (!bParsed)
		{
			UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("%s could not parse the input model as a ModelProto."), *GetName());
			return false;
		}
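
		// First collect every domain actually referenced by a node, then rebuild opset_import,
		// keeping only the operator sets belonging to those domains.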
		TArray<FString> UsedDomains;
		const onnx::GraphProto& Graph = ModelProto.graph();
		for (const onnx::NodeProto& Node : Graph.node())
		{
			UsedDomains.AddUnique(FString(Node.domain().c_str()));
		}

		google::protobuf::RepeatedPtrField<onnx::OperatorSetIdProto> UsedOperatorSet;
		for (const onnx::OperatorSetIdProto& OpSet : ModelProto.opset_import())
		{
			const FString Domain(OpSet.domain().c_str());
			const bool bIsUsed = UsedDomains.Contains(Domain);
			// We additionally keep the OpSets of models that don't have any nodes (and thus no used
			// domains), since we would otherwise produce an invalid model.
			if (bIsUsed || UsedDomains.IsEmpty())
			{
				UsedOperatorSet.Add()->CopyFrom(OpSet);
			}
		}
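
		// Swap in the filtered operator sets and serialize the modified model back into the
		// caller's buffer, resizing it to the new byte size first.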
		ModelProto.mutable_opset_import()->Clear();
		ModelProto.mutable_opset_import()->Add(UsedOperatorSet.cbegin(), UsedOperatorSet.cend());

		ModelData.SetNumUninitialized(ModelProto.ByteSizeLong());
		ModelProto.SerializeToArray(ModelData.GetData(), ModelData.Num());
		return true;
	}
};

class FOnnxShapeInferenceModelOptimizerPass : public Internal::IModelOptimizerPass
{
public:
	virtual FString GetName() const override
	{
		return TEXT("Onnx shape inference");
	}

	virtual bool ApplyPass(TArray<uint8>& ModelData) const override
	{
		onnx::ModelProto ModelProto;
		const bool bParsed = ModelProto.ParseFromArray(ModelData.GetData(), ModelData.Num());
		if (!bParsed)
		{
			UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("%s could not parse the input model as a ModelProto."), *GetName());
			return false;
		}
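
		// onnx::shape_inference::InferShapes reports failures via exceptions, so the pass can only
		// run when the ONNX library is built with exception support (ONNX_NO_EXCEPTIONS undefined).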
#ifdef ONNX_NO_EXCEPTIONS
		UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("ONNX shape inference can't be run as exceptions are disabled."));
		return true;
#else
		static onnx::OpSchemaRegistry* OnnxSchemaRegistry = onnx::OpSchemaRegistry::Instance();
		try
		{
			onnx::shape_inference::InferShapes(ModelProto, OnnxSchemaRegistry);
		}
		catch (onnx::InferenceError& e)
		{
			UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("Shape inference failed with: %s."), ANSI_TO_TCHAR(e.what()));
		}
#endif

		// Shape inference typically adds value_info entries and grows the model, so the buffer
		// must be resized before serializing back, otherwise SerializeToArray would fail.
		ModelData.SetNumUninitialized(ModelProto.ByteSizeLong());
		ModelProto.SerializeToArray(ModelData.GetData(), ModelData.Num());
		return true;
	}
};
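
// FModelOptimizerONNXToONNX chains the passes in order (ORT graph optimization, domain cleanup,
// shape inference) and validates the result with FModelValidatorONNX.
// A minimal usage sketch of a single pass, assuming the caller already holds the serialized
// .onnx bytes (normally the surrounding optimizer machinery drives this; the path is hypothetical):
//
//   TArray<uint8> ModelData;
//   FFileHelper::LoadFileToArray(ModelData, TEXT("MyModel.onnx")); // hypothetical path
//   FOnnxShapeInferenceModelOptimizerPass Pass;
//   if (Pass.ApplyPass(ModelData))
//   {
//       // ModelData now holds the shape-inferred model
//   }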
FModelOptimizerONNXToONNX::FModelOptimizerONNXToONNX()
{
	AddOptimizationPass(MakeShared<FOnnxRuntimeModelOptimizerPass>());
	AddOptimizationPass(MakeShared<FOnnxDomainCleanupModelOptimizerPass>());
	AddOptimizationPass(MakeShared<FOnnxShapeInferenceModelOptimizerPass>());
	AddValidator(MakeShared<FModelValidatorONNX>());
}

} // namespace UE::NNERuntimeRDGUtils::Private