// Copyright Epic Games, Inc. All Rights Reserved. #include "NNERuntimeRDGUtilsModelOptimizerONNX.h" #include "HAL/IConsoleManager.h" #include "HAL/PlatformFileManager.h" #include "HAL/PlatformTime.h" #include "Misc/FileHelper.h" #include "Misc/Paths.h" #include "NNEHlslShadersLog.h" #include "NNEOnnxruntimeEditor.h" THIRD_PARTY_INCLUDES_START #include #include THIRD_PARTY_INCLUDES_END namespace UE::NNERuntimeRDGUtils::Private { class FOnnxRuntimeModelOptimizerPass : public Internal::IModelOptimizerPass { public: FOnnxRuntimeModelOptimizerPass() { } virtual FString GetName() const { return TEXT("Onnx runtime model optimization"); } virtual bool ApplyPass(TArray& ModelData) const { onnx::ModelProto ModelProto; const bool result = ModelProto.ParseFromArray(ModelData.GetData(), ModelData.Num()); if (!result) { UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("%s could not parse the input model as a ModelProto."), *(GetName())); return false; } static const auto CVarHlslModelOptimization = IConsoleManager::Get().FindTConsoleVariableDataInt(TEXT("nne.hlsl.ModelOptimization")); if (CVarHlslModelOptimization && CVarHlslModelOptimization->GetValueOnAnyThread() == 0) { return true; } // We disable ONNX Runtime optimizations if a model uses FP16 Tensors // since these optimizations would add cast operators from and to FP16 // at the beginning and end of the network and convert all other operators to FP32. if (HasFP16Tensor(ModelProto)) { return true; } FString ProjIntermediateDir = FPaths::ConvertRelativePathToFull(FPaths::ProjectIntermediateDir()); FString ModelToOptimizePath = FPaths::CreateTempFilename(*ProjIntermediateDir, TEXT("ORTOptimizerPass_ToOptimize"), TEXT(".onnx")); FString ModelOptimizedPath = FPaths::CreateTempFilename(*ProjIntermediateDir, TEXT("ORTOptimizerPass_Optimized"), TEXT(".onnx")); //See https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html //We only enable all the optimization when going to ORT format itself for the CPU provider GraphOptimizationLevel OptimizationLevel = ORT_ENABLE_BASIC; FFileHelper::SaveArrayToFile(ModelData, *ModelToOptimizePath); { Ort::ThreadingOptions ThreadingOptions; ThreadingOptions.SetGlobalIntraOpNumThreads(1); ThreadingOptions.SetGlobalInterOpNumThreads(1); Ort::Env Env(ThreadingOptions); Ort::SessionOptions SessionOptions; SessionOptions.DisablePerSessionThreads(); SessionOptions.SetGraphOptimizationLevel(OptimizationLevel); #if PLATFORM_WINDOWS SessionOptions.SetOptimizedModelFilePath(*ModelOptimizedPath); Ort::Session Session(Env, *ModelToOptimizePath, SessionOptions); #else SessionOptions.SetOptimizedModelFilePath(TCHAR_TO_ANSI(*ModelOptimizedPath)); Ort::Session Session(Env, TCHAR_TO_ANSI(*ModelToOptimizePath), SessionOptions); #endif } FFileHelper::LoadFileToArray(ModelData, *ModelOptimizedPath); IFileManager::Get().Delete(*ModelToOptimizePath); IFileManager::Get().Delete(*ModelOptimizedPath); return true; } private: bool HasFP16Tensor(const onnx::ModelProto& Model) const { const onnx::GraphProto& Graph = Model.graph(); for (const onnx::TensorProto& Tensor : Graph.initializer()) { if (Tensor.data_type() == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { return true; } } for (onnx::ValueInfoList Tensors : {Graph.input(), Graph.output()}) { for (const onnx::ValueInfoProto& Tensor : Tensors) { if (!Tensor.has_type()) { continue; } const onnx::TypeProto Type = Tensor.type(); if (!Type.has_tensor_type()) { continue; } const onnx::TypeProto_Tensor TensorType = Type.tensor_type(); if (!TensorType.has_elem_type()) { continue; } if (TensorType.elem_type() == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { return true; } } } return false; } }; class FOnnxDomainCleanupModelOptimizerPass : public Internal::IModelOptimizerPass { public: virtual FString GetName() const { return TEXT("Onnx domain cleanup"); } virtual bool ApplyPass(TArray& ModelData) const { onnx::ModelProto ModelProto; const bool result = ModelProto.ParseFromArray(ModelData.GetData(), ModelData.Num()); if (!result) { UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("%s could not parse the input model as a ModelProto."), *(GetName())); return false; } TArray UsedDomains; const onnx::GraphProto& graph = ModelProto.graph(); for (const onnx::NodeProto& node : graph.node()) { const char* DomainPtr = node.domain().c_str(); FString Domain = DomainPtr; UsedDomains.AddUnique(Domain); } google::protobuf::RepeatedPtrField UsedOperatorSet; for (const onnx::OperatorSetIdProto& OpSet : ModelProto.opset_import()) { const char* DomainPtr = OpSet.domain().c_str(); FString Domain = DomainPtr; bool IsUsed = UsedDomains.Contains(Domain); // We additionally add the OpSet from models that don't have any UsedDomains/Nodes // since we otherwise would get an invalid model if (IsUsed || UsedDomains.IsEmpty()) { UsedOperatorSet.Add()->CopyFrom(OpSet); } } ModelProto.mutable_opset_import()->Clear(); ModelProto.mutable_opset_import()->Add(UsedOperatorSet.cbegin(), UsedOperatorSet.cend()); ModelData.SetNumUninitialized(ModelProto.ByteSizeLong()); ModelProto.SerializeToArray(ModelData.GetData(), ModelData.Num()); return true; } }; class FOnnxShapeInferenceModelOptimizerPass : public Internal::IModelOptimizerPass { public: virtual FString GetName() const { return TEXT("Onnx shape inference"); } virtual bool ApplyPass(TArray& ModelData) const { onnx::ModelProto ModelProto; const bool result = ModelProto.ParseFromArray(ModelData.GetData(), ModelData.Num()); if (!result) { UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("%s could not parse the input model as a ModelProto."), *(GetName())); return false; } #ifdef ONNX_NO_EXCEPTIONS UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("ONNX Shape inference can't be run as exception are disabled.")); return true; #else static onnx::OpSchemaRegistry* OnnxSchemaRegistry = onnx::OpSchemaRegistry::Instance(); try { onnx::shape_inference::InferShapes(ModelProto, OnnxSchemaRegistry); } catch (onnx::InferenceError& e) { UE_LOG(LogNNERuntimeRDGHlsl, Warning, TEXT("Shape inference failed with : %s."), ANSI_TO_TCHAR(e.what())); } #endif ModelProto.SerializeToArray(ModelData.GetData(), ModelData.Num()); return true; } }; FModelOptimizerONNXToONNX::FModelOptimizerONNXToONNX() { AddOptimizationPass(MakeShared()); AddOptimizationPass(MakeShared()); AddOptimizationPass(MakeShared()); AddValidator(MakeShared()); } } // namespace UE::NNERuntimeRDGUtils::Private