// Copyright Epic Games, Inc. All Rights Reserved. #include "LearningAgentsImitationTrainer.h" #include "LearningAgentsCommunicator.h" #include "LearningAgentsInteractor.h" #include "LearningAgentsManager.h" #include "LearningAgentsPolicy.h" #include "LearningAgentsRecording.h" #include "LearningExperience.h" #include "LearningExternalTrainer.h" #include "LearningLog.h" #include "LearningNeuralNetwork.h" #include "LearningTrainer.h" #include "Dom/JsonObject.h" #include "HAL/PlatformMisc.h" #include "Misc/Paths.h" #include UE_INLINE_GENERATED_CPP_BY_NAME(LearningAgentsImitationTrainer) TSharedRef FLearningAgentsImitationTrainerTrainingSettings::AsJsonConfig() const { TSharedRef ConfigObject = MakeShared(); ConfigObject->SetNumberField(TEXT("IterationNum"), NumberOfIterations); ConfigObject->SetNumberField(TEXT("LearningRate"), LearningRate); ConfigObject->SetNumberField(TEXT("LearningRateDecay"), LearningRateDecay); ConfigObject->SetNumberField(TEXT("LearningRateDecayStepSize"), LearningRateDecayStepSize); ConfigObject->SetNumberField(TEXT("WeightDecay"), WeightDecay); ConfigObject->SetNumberField(TEXT("BatchSize"), BatchSize); ConfigObject->SetNumberField(TEXT("Window"), Window); ConfigObject->SetNumberField(TEXT("ActionRegularizationWeight"), ActionRegularizationWeight); ConfigObject->SetNumberField(TEXT("ActionEntropyWeight"), ActionEntropyWeight); ConfigObject->SetNumberField(TEXT("ObservationNoiseScale"), ObservationNoiseScale); ConfigObject->SetNumberField(TEXT("Seed"), RandomSeed); ConfigObject->SetBoolField(TEXT("SaveSnapshots"), bSaveSnapshots); ConfigObject->SetNumberField(TEXT("IterationsPerSnapshot"), IterationsPerSnapshot); // Eval ConfigObject->SetBoolField(TEXT("RunEvaluation"), bRunEvaluation); ConfigObject->SetNumberField(TEXT("TrainEvalDatasetSplit"), TrainEvalDatasetSplit); ConfigObject->SetNumberField(TEXT("EvaluationFrequency"), EvaluationFrequency); ConfigObject->SetNumberField(TEXT("BatchCountPerEvaluation"), BatchCountPerEvaluation); return ConfigObject; } ULearningAgentsImitationTrainer::ULearningAgentsImitationTrainer() : Super(FObjectInitializer::Get()) {} ULearningAgentsImitationTrainer::ULearningAgentsImitationTrainer(FVTableHelper& Helper) : Super(Helper) {} ULearningAgentsImitationTrainer::~ULearningAgentsImitationTrainer() = default; void ULearningAgentsImitationTrainer::BeginDestroy() { if (IsTraining()) { EndTraining(); } Super::BeginDestroy(); } ULearningAgentsImitationTrainer* ULearningAgentsImitationTrainer::MakeImitationTrainer( ULearningAgentsManager*& InManager, ULearningAgentsInteractor*& InInteractor, ULearningAgentsPolicy*& InPolicy, const FLearningAgentsCommunicator& Communicator, TSubclassOf Class, const FName Name) { if (!InManager) { UE_LOG(LogLearning, Error, TEXT("MakeImitationTrainer: InManager is nullptr.")); return nullptr; } if (!Class) { UE_LOG(LogLearning, Error, TEXT("MakeImitationTrainer: Class is nullptr.")); return nullptr; } const FName UniqueName = MakeUniqueObjectName(InManager, Class, Name, EUniqueObjectNameOptions::GloballyUnique); ULearningAgentsImitationTrainer* ImitationTrainer = NewObject(InManager, Class, UniqueName); if (!ImitationTrainer) { return nullptr; } ImitationTrainer->SetupImitationTrainer(InManager, InInteractor, InPolicy, Communicator); return ImitationTrainer->IsSetup() ? ImitationTrainer : nullptr; } void ULearningAgentsImitationTrainer::SetupImitationTrainer( ULearningAgentsManager* InManager, ULearningAgentsInteractor* InInteractor, ULearningAgentsPolicy* InPolicy, const FLearningAgentsCommunicator& Communicator) { if (IsSetup()) { UE_LOG(LogLearning, Error, TEXT("%s: Setup already performed!"), *GetName()); return; } if (!InManager) { UE_LOG(LogLearning, Error, TEXT("%s: InManager is nullptr."), *GetName()); return; } if (!InInteractor) { UE_LOG(LogLearning, Error, TEXT("%s: InInteractor is nullptr."), *GetName()); return; } if (!InInteractor->IsSetup()) { UE_LOG(LogLearning, Error, TEXT("%s: %s's Setup must be run before it can be used."), *GetName(), *InInteractor->GetName()); return; } if (!InPolicy) { UE_LOG(LogLearning, Error, TEXT("%s: InPolicy is nullptr."), *GetName()); return; } if (!InPolicy->IsSetup()) { UE_LOG(LogLearning, Error, TEXT("%s: %s's Setup must be run before it can be used."), *GetName(), *InPolicy->GetName()); return; } if (!Communicator.Trainer) { UE_LOG(LogLearning, Error, TEXT("%s: Communicator's Trainer is nullptr."), *GetName()); return; } Manager = InManager; Interactor = InInteractor; Policy = InPolicy; Trainer = Communicator.Trainer; bIsSetup = true; Manager->AddListener(this); } void ULearningAgentsImitationTrainer::BeginTraining( const ULearningAgentsRecording* Recording, const FLearningAgentsImitationTrainerSettings& ImitationTrainerSettings, const FLearningAgentsImitationTrainerTrainingSettings& ImitationTrainerTrainingSettings, const FLearningAgentsTrainerProcessSettings& ImitationTrainerPathSettings) { if (!IsSetup()) { UE_LOG(LogLearning, Error, TEXT("%s: Setup not complete."), *GetName()); return; } if (IsTraining()) { UE_LOG(LogLearning, Error, TEXT("%s: Cannot begin training as we are already training!"), *GetName()); return; } if (!Recording) { UE_LOG(LogLearning, Error, TEXT("%s: Recording is nullptr."), *GetName()); return; } if (Recording->Records.IsEmpty()) { UE_LOG(LogLearning, Error, TEXT("%s: Recording is empty!"), *GetName()); return; } // Get Number of Steps int32 TotalEpisodeNum = 0; int32 TotalStepNum = 0; for (const FLearningAgentsRecord& Record : Recording->Records) { if (ImitationTrainerSettings.RecordingFilterTag.IsValid() && !Record.Tags.HasTag(ImitationTrainerSettings.RecordingFilterTag)) { UE_LOG(LogLearning, Log, TEXT("Skipping record %s because its missing tag %s."), *Record.Name, *ImitationTrainerSettings.RecordingFilterTag.ToString()); continue; } if (!Recording->Schemas.Contains(Record.SchemaName)) { UE_LOG(LogLearning, Error, TEXT("Schema name %s not found in schema list. Is the schema added and are the names correct?"), *Record.SchemaName.ToString()); continue; } if (!UE::Learning::Trainer::IsObservationSchemaSubsetCompatible(Recording->Schemas[Record.SchemaName].ObservationSchemaJson, Interactor->GetObservationSchema()->ObservationSchema, Interactor->GetObservationSchemaElement().SchemaElement)) { UE_LOG(LogLearning, Warning, TEXT("%s: Record has incompatible observation schema."), *GetName()); continue; } if (!UE::Learning::Trainer::IsActionSchemaSubsetCompatible(Recording->Schemas[Record.SchemaName].ActionSchemaJson, Interactor->GetActionSchema()->ActionSchema, Interactor->GetActionSchemaElement().SchemaElement)) { UE_LOG(LogLearning, Warning, TEXT("%s: Record has incompatible action schema."), *GetName()); continue; } TotalEpisodeNum++; TotalStepNum += Record.StepNum; } if (TotalStepNum == 0) { UE_LOG(LogLearning, Warning, TEXT("%s: Recording contains no valid training data."), *GetName()); return; } // Sizes const int32 ObservationNum = Interactor->GetObservationVectorSize(); const int32 ActionNum = Interactor->GetActionVectorSize(); // Copy into Flat Arrays TLearningArray<1, int32> RecordedEpisodeStarts; TLearningArray<1, int32> RecordedEpisodeLengths; TLearningArray<2, float> RecordedObservations; TLearningArray<2, float> RecordedActions; RecordedEpisodeStarts.SetNumUninitialized({ TotalEpisodeNum }); RecordedEpisodeLengths.SetNumUninitialized({ TotalEpisodeNum }); RecordedObservations.SetNumUninitialized({ TotalStepNum, ObservationNum }); RecordedActions.SetNumUninitialized({ TotalStepNum, ActionNum }); int32 EpisodeIdx = 0; int32 StepIdx = 0; for (const FLearningAgentsRecord& Record : Recording->Records) { if (ImitationTrainerSettings.RecordingFilterTag.IsValid() && !Record.Tags.HasTag(ImitationTrainerSettings.RecordingFilterTag)) { UE_LOG(LogLearning, Log, TEXT("Skipping record %s because its missing tag %s."), *Record.Name, *ImitationTrainerSettings.RecordingFilterTag.ToString()); continue; } if (!Recording->Schemas.Contains(Record.SchemaName)) { UE_LOG(LogLearning, Error, TEXT("Schema name %s not found in schema list. Is the schema added and are the names correct?"), *Record.SchemaName.ToString()); continue; } if (Record.ObservationDimNum == ObservationNum && Record.ActionDimNum == ActionNum) { TLearningArrayView<2, const float> ObservationsView = TLearningArrayView<2, const float>(Record.ObservationData.GetData(), { Record.StepNum, Record.ObservationDimNum }); TLearningArrayView<2, const float> ActionsView = TLearningArrayView<2, const float>(Record.ActionData.GetData(), { Record.StepNum, Record.ActionDimNum }); RecordedEpisodeStarts[EpisodeIdx] = StepIdx; RecordedEpisodeLengths[EpisodeIdx] = Record.StepNum; UE::Learning::Array::Copy(RecordedObservations.Slice(StepIdx, Record.StepNum), ObservationsView); UE::Learning::Array::Copy(RecordedActions.Slice(StepIdx, Record.StepNum), ActionsView); } else if (UE::Learning::Trainer::IsObservationSchemaSubsetCompatible(Recording->Schemas[Record.SchemaName].ObservationSchemaJson, Interactor->GetObservationSchema()->ObservationSchema, Interactor->GetObservationSchemaElement().SchemaElement) && UE::Learning::Trainer::IsActionSchemaSubsetCompatible(Recording->Schemas[Record.SchemaName].ActionSchemaJson, Interactor->GetActionSchema()->ActionSchema, Interactor->GetActionSchemaElement().SchemaElement)) { TLearningArrayView<2, const float> ObservationsView = TLearningArrayView<2, const float>(Record.ObservationData.GetData(), { Record.StepNum, Record.ObservationDimNum }); TLearningArrayView<2, const float> ActionsView = TLearningArrayView<2, const float>(Record.ActionData.GetData(), { Record.StepNum, Record.ActionDimNum }); RecordedEpisodeStarts[EpisodeIdx] = StepIdx; RecordedEpisodeLengths[EpisodeIdx] = Record.StepNum; TArray ObservationIndices = UE::Learning::Trainer::ComputeObservationSchemaSubsetIndices( Recording->Schemas[Record.SchemaName].ObservationSchemaJson, Interactor->GetObservationSchema()->ObservationSchema, Interactor->GetObservationSchemaElement().SchemaElement); check(ObservationIndices.Num() == ObservationNum); for (int Step = 0; Step < Record.StepNum; Step++) { for (int i = 0; i < ObservationIndices.Num(); i++) { RecordedObservations[StepIdx + Step][i] = ObservationsView[Step][ObservationIndices[i]]; } } TArray ActionIndices = UE::Learning::Trainer::ComputeActionSchemaSubsetIndices( Recording->Schemas[Record.SchemaName].ActionSchemaJson, Interactor->GetActionSchema()->ActionSchema, Interactor->GetActionSchemaElement().SchemaElement); check(ActionIndices.Num() == ActionNum); for (int Step = 0; Step < Record.StepNum; Step++) { for (int i = 0; i < ActionIndices.Num(); i++) { RecordedActions[StepIdx + Step][i] = ActionsView[Step][ActionIndices[i]]; } } } else { continue; } EpisodeIdx++; StepIdx += Record.StepNum; } check(EpisodeIdx == TotalEpisodeNum); check(StepIdx == TotalStepNum); const int32 ObservationSchemaId = 0; const int32 ActionSchemaId = 0; // Create Replay Buffer from Records ReplayBuffer = MakeUnique(); ReplayBuffer->AddRecords( TotalEpisodeNum, TotalStepNum, ObservationSchemaId, ObservationNum, ActionSchemaId, ActionNum, RecordedEpisodeStarts, RecordedEpisodeLengths, RecordedObservations, RecordedActions); // We need to setup the trainer prior to sending the config PolicyNetworkId = Trainer->AddNetwork(*Policy->GetPolicyNetworkAsset()->NeuralNetworkData); EncoderNetworkId = Trainer->AddNetwork(*Policy->GetEncoderNetworkAsset()->NeuralNetworkData); DecoderNetworkId = Trainer->AddNetwork(*Policy->GetDecoderNetworkAsset()->NeuralNetworkData); ReplayBufferId = Trainer->AddReplayBuffer(*ReplayBuffer); TSharedRef DataConfigObject = CreateDataConfig(); TSharedRef TrainerConfigObject = CreateTrainerConfig(ImitationTrainerTrainingSettings); UE_LOG(LogLearning, Display, TEXT("%s: Sending configs..."), *GetName()); SendConfigs(DataConfigObject, TrainerConfigObject); UE_LOG(LogLearning, Display, TEXT("%s: Imitation Training Started"), *GetName()); UE::Learning::ETrainerResponse Response = UE::Learning::ETrainerResponse::Success; UE_LOG(LogLearning, Display, TEXT("%s: Sending / Receiving initial policy..."), *GetName()); Response = Trainer->SendNetwork(PolicyNetworkId, *Policy->GetPolicyNetworkAsset()->NeuralNetworkData); if (Response != UE::Learning::ETrainerResponse::Success) { UE_LOG(LogLearning, Error, TEXT("%s: Error sending policy to trainer: %s. Check log for additional errors."), *GetName(), UE::Learning::Trainer::GetResponseString(Response)); bHasTrainingFailed = true; Trainer->Terminate(); return; } Response = Trainer->SendNetwork(EncoderNetworkId, *Policy->GetEncoderNetworkAsset()->NeuralNetworkData); if (Response != UE::Learning::ETrainerResponse::Success) { UE_LOG(LogLearning, Error, TEXT("%s: Error sending encoder to trainer: %s. Check log for additional errors."), *GetName(), UE::Learning::Trainer::GetResponseString(Response)); bHasTrainingFailed = true; Trainer->Terminate(); return; } Response = Trainer->SendNetwork(DecoderNetworkId, *Policy->GetDecoderNetworkAsset()->NeuralNetworkData); if (Response != UE::Learning::ETrainerResponse::Success) { UE_LOG(LogLearning, Error, TEXT("%s: Error sending decoder to trainer: %s. Check log for additional errors."), *GetName(), UE::Learning::Trainer::GetResponseString(Response)); bHasTrainingFailed = true; Trainer->Terminate(); return; } UE_LOG(LogLearning, Display, TEXT("%s: Sending Experience..."), *GetName()); Response = Trainer->SendReplayBuffer(ReplayBufferId, *ReplayBuffer); if (Response != UE::Learning::ETrainerResponse::Success) { UE_LOG(LogLearning, Error, TEXT("%s: Error sending experience to trainer: %s. Check log for additional errors."), *GetName(), UE::Learning::Trainer::GetResponseString(Response)); bHasTrainingFailed = true; Trainer->Terminate(); return; } bIsTraining = true; } TSharedRef ULearningAgentsImitationTrainer::CreateDataConfig() const { TSharedRef ConfigObject = MakeShared(); const int32 ObservationSchemaId = 0; const int32 ActionSchemaId = 0; // Add Neural Network Config Entries TArray> NetworkObjects; // Policy { TSharedPtr NetworkObject = MakeShared(); NetworkObject->SetNumberField(TEXT("Id"), PolicyNetworkId); NetworkObject->SetStringField(TEXT("Name"), Policy->GetPolicyNetworkAsset()->GetFName().ToString()); NetworkObject->SetNumberField(TEXT("MaxByteNum"), Policy->GetPolicyNetworkAsset()->NeuralNetworkData->GetSnapshotByteNum()); TSharedRef JsonValue = MakeShared(NetworkObject); NetworkObjects.Add(JsonValue); } // Encoder { TSharedPtr NetworkObject = MakeShared(); NetworkObject->SetNumberField(TEXT("Id"), EncoderNetworkId); NetworkObject->SetStringField(TEXT("Name"), Policy->GetEncoderNetworkAsset()->GetFName().ToString()); NetworkObject->SetNumberField(TEXT("MaxByteNum"), Policy->GetEncoderNetworkAsset()->NeuralNetworkData->GetSnapshotByteNum()); NetworkObject->SetNumberField(TEXT("InputSchemaId"), ObservationSchemaId); TSharedRef JsonValue = MakeShared(NetworkObject); NetworkObjects.Add(JsonValue); } // Decoder { TSharedPtr NetworkObject = MakeShared(); NetworkObject->SetNumberField(TEXT("Id"), DecoderNetworkId); NetworkObject->SetStringField(TEXT("Name"), Policy->GetDecoderNetworkAsset()->GetFName().ToString()); NetworkObject->SetNumberField(TEXT("MaxByteNum"), Policy->GetDecoderNetworkAsset()->NeuralNetworkData->GetSnapshotByteNum()); NetworkObject->SetNumberField(TEXT("OutputSchemaId"), ActionSchemaId); TSharedRef JsonValue = MakeShared(NetworkObject); NetworkObjects.Add(JsonValue); } ConfigObject->SetArrayField(TEXT("Networks"), NetworkObjects); // Add Replay Buffers Config Entries TArray> ReplayBufferObjects; TSharedRef ReplayBufferJsonValue = MakeShared(ReplayBuffer->AsJsonConfig(ReplayBufferId)); ReplayBufferObjects.Add(ReplayBufferJsonValue); ConfigObject->SetArrayField(TEXT("ReplayBuffers"), ReplayBufferObjects); // Schemas TSharedPtr SchemasObject = MakeShared(); // Add the observation schemas TArray> ObservationSchemaObjects; { // For this trainer, add the one observation schema we have TSharedPtr ObservationSchemaObject = MakeShared(); ObservationSchemaObject->SetNumberField(TEXT("Id"), ObservationSchemaId); ObservationSchemaObject->SetStringField(TEXT("Name"), "Default"); ObservationSchemaObject->SetObjectField(TEXT("Schema"), UE::Learning::Trainer::ConvertObservationSchemaToJSON(Interactor->GetObservationSchema()->ObservationSchema, Interactor->GetObservationSchemaElement().SchemaElement)); TSharedRef JsonValue = MakeShared(ObservationSchemaObject); ObservationSchemaObjects.Add(JsonValue); } SchemasObject->SetArrayField(TEXT("Observations"), ObservationSchemaObjects); // Add the action schemas TArray> ActionSchemaObjects; { // For this trainer, add the one action schema we have TSharedPtr ActionSchemaObject = MakeShared(); ActionSchemaObject->SetNumberField(TEXT("Id"), ActionSchemaId); ActionSchemaObject->SetStringField(TEXT("Name"), "Default"); ActionSchemaObject->SetObjectField(TEXT("Schema"), UE::Learning::Trainer::ConvertActionSchemaToJSON(Interactor->GetActionSchema()->ActionSchema, Interactor->GetActionSchemaElement().SchemaElement)); TSharedRef JsonValue = MakeShared(ActionSchemaObject); ActionSchemaObjects.Add(JsonValue); } SchemasObject->SetArrayField(TEXT("Actions"), ActionSchemaObjects); ConfigObject->SetObjectField(TEXT("Schemas"), SchemasObject); return ConfigObject; } TSharedRef ULearningAgentsImitationTrainer::CreateTrainerConfig(const FLearningAgentsImitationTrainerTrainingSettings& TrainingSettings) const { TSharedRef ConfigObject = MakeShared(); // Add Training Config Entries ConfigObject->SetStringField(TEXT("TrainerMethod"), TEXT("BehaviorCloning")); ConfigObject->SetStringField(TEXT("TimeStamp"), *FDateTime::Now().ToFormattedString(TEXT("%Y-%m-%d_%H-%M-%S"))); // Add Imitation Specific Config Entries ConfigObject->SetObjectField(TEXT("BehaviorCloningSettings"), TrainingSettings.AsJsonConfig()); ConfigObject->SetObjectField(TEXT("PolicySettings"), Policy->AsJsonConfig()); ConfigObject->SetNumberField(TEXT("MemoryStateNum"), Policy->GetMemoryStateSize()); // Add Training Workflow Config Entries ConfigObject->SetStringField(TEXT("Device"), UE::Learning::Trainer::GetDeviceString(UE::Learning::Agents::GetTrainingDevice(TrainingSettings.Device))); ConfigObject->SetBoolField(TEXT("UseTensorBoard"), TrainingSettings.bUseTensorboard); ConfigObject->SetBoolField(TEXT("UseMLflow"), TrainingSettings.bUseMLflow); ConfigObject->SetStringField(TEXT("MLflowTrackingUri"), TrainingSettings.MLflowTrackingUri); return ConfigObject; } void ULearningAgentsImitationTrainer::SendConfigs(const TSharedRef& DataConfigObject, const TSharedRef& TrainerConfigObject) { UE::Learning::ETrainerResponse Response = UE::Learning::ETrainerResponse::Success; Response = Trainer->SendConfigs(DataConfigObject, TrainerConfigObject); if (Response != UE::Learning::ETrainerResponse::Success) { UE_LOG(LogLearning, Error, TEXT("%s: Error sending config to trainer: %s. Check log for additional errors."), *GetName(), UE::Learning::Trainer::GetResponseString(Response)); bHasTrainingFailed = true; Trainer->Terminate(); return; } } void ULearningAgentsImitationTrainer::DoneTraining() { if (!IsSetup()) { UE_LOG(LogLearning, Error, TEXT("%s: Setup not complete."), *GetName()); return; } if (IsTraining()) { // Wait for Trainer to finish Trainer->Wait(); // If not finished in time, terminate Trainer->Terminate(); bIsTraining = false; } } void ULearningAgentsImitationTrainer::EndTraining() { if (!IsSetup()) { UE_LOG(LogLearning, Error, TEXT("%s: Setup not complete."), *GetName()); return; } if (IsTraining()) { UE_LOG(LogLearning, Display, TEXT("%s: Stopping training..."), *GetName()); Trainer->SendStop(); DoneTraining(); } } void ULearningAgentsImitationTrainer::IterateTraining() { TRACE_CPUPROFILER_EVENT_SCOPE(ULearningAgentsImitationTrainer::IterateTraining); if (!IsSetup()) { UE_LOG(LogLearning, Error, TEXT("%s: Setup not complete."), *GetName()); return; } if (!IsTraining()) { UE_LOG(LogLearning, Error, TEXT("%s: Training not running."), *GetName()); return; } if (Trainer->HasNetworkOrCompleted()) { UE_LOG(LogLearning, Display, TEXT("Receiving trained networks...")); UE::Learning::ETrainerResponse Response = Trainer->ReceiveNetwork(PolicyNetworkId, *Policy->GetPolicyNetworkAsset()->NeuralNetworkData); if (Response == UE::Learning::ETrainerResponse::Completed) { UE_LOG(LogLearning, Display, TEXT("%s: Trainer completed training."), *GetName()); DoneTraining(); return; } else if (Response != UE::Learning::ETrainerResponse::Success) { UE_LOG(LogLearning, Error, TEXT("%s: Error receiving policy from trainer: %s. Check log for additional errors."), *GetName(), UE::Learning::Trainer::GetResponseString(Response)); bHasTrainingFailed = true; EndTraining(); return; } Policy->GetPolicyNetworkAsset()->ForceMarkDirty(); Response = Trainer->ReceiveNetwork(EncoderNetworkId, *Policy->GetEncoderNetworkAsset()->NeuralNetworkData); if (Response != UE::Learning::ETrainerResponse::Success) { UE_LOG(LogLearning, Error, TEXT("%s: Error receiving encoder from trainer: %s. Check log for additional errors."), *GetName(), UE::Learning::Trainer::GetResponseString(Response)); bHasTrainingFailed = true; EndTraining(); return; } Policy->GetEncoderNetworkAsset()->ForceMarkDirty(); Response = Trainer->ReceiveNetwork(DecoderNetworkId, *Policy->GetDecoderNetworkAsset()->NeuralNetworkData); if (Response != UE::Learning::ETrainerResponse::Success) { UE_LOG(LogLearning, Error, TEXT("%s: Error receiving decoder from trainer: %s. Check log for additional errors."), *GetName(), UE::Learning::Trainer::GetResponseString(Response)); bHasTrainingFailed = true; EndTraining(); return; } Policy->GetDecoderNetworkAsset()->ForceMarkDirty(); } } void ULearningAgentsImitationTrainer::RunTraining( const ULearningAgentsRecording* Recording, const FLearningAgentsImitationTrainerSettings& ImitationTrainerSettings, const FLearningAgentsImitationTrainerTrainingSettings& ImitationTrainerTrainingSettings, const FLearningAgentsTrainerProcessSettings& ImitationTrainerPathSettings) { if (!IsSetup()) { UE_LOG(LogLearning, Error, TEXT("%s: Setup not complete."), *GetName()); return; } if (bHasTrainingFailed) { UE_LOG(LogLearning, Error, TEXT("%s: Training has failed. Check log for errors."), *GetName()); #if !WITH_EDITOR FGenericPlatformMisc::RequestExitWithStatus(false, 99); #endif return; } // If we aren't training yet, then start training and do the first inference step. if (!IsTraining()) { BeginTraining( Recording, ImitationTrainerSettings, ImitationTrainerTrainingSettings, ImitationTrainerPathSettings); if (!IsTraining()) { // If IsTraining is false, then BeginTraining must have failed and we can't continue. return; } } // Otherwise, do the regular training process. IterateTraining(); } bool ULearningAgentsImitationTrainer::IsTraining() const { return bIsTraining; } bool ULearningAgentsImitationTrainer::HasTrainingFailed() const { return bHasTrainingFailed; }