// Copyright Epic Games, Inc. All Rights Reserved. #if (defined(__AUTORTFM) && __AUTORTFM) #include "Transaction.h" #include "TransactionInlines.h" #include "CallNestInlines.h" #include "ContextStatus.h" #include "HitSet.h" #include "Utils.h" #include #include #include namespace AutoRTFM { void FTransaction::Resurrect(FContext* const InContext) { AUTORTFM_ASSERT(Context == InContext) Parent = nullptr; StatDepth = 1; RecordedWriteHash = 0; NumWriteLogsHashed = 0; CurrentMemoryValidationLevel = EMemoryValidationLevel::Disabled; CurrentOpenReturnAddress = nullptr; CurrentState = EState::Uninitialized; bIsStackScoped = false; bIsInAllocateFn = false; } void FTransaction::Suppress() { CurrentState = EState::Done; Reset(); } FTransaction** FTransaction::GetIntrusiveAddress() { return &Parent; } FTransaction::FTransaction(FContext* InContext) : Context(InContext) , CommitTasks(Context->GetTaskPool()) , PreAbortTasks(Context->GetTaskPool()) , AbortTasks(Context->GetTaskPool()) { Resurrect(Context); } FTransaction::~FTransaction() { Suppress(); } void FTransaction::Initialize(FTransaction* InParent, bool bInIsStackScoped, FStackRange InStackRange) { AUTORTFM_ASSERT(CurrentState == EState::Uninitialized); Parent = InParent; bIsStackScoped = bInIsStackScoped; StackRange = InStackRange; if (Parent) { // For stats, record the nested depth of the transaction. StatDepth = Parent->StatDepth + 1; } else { // Root-level transactions require a completion task list. CompletionTasks.emplace(Context->GetTaskPool()); } Stats.Collect(StatDepth); Stats.Collect(StatDepth); } void FTransaction::Reset() { AUTORTFM_ASSERT(IsDone()); CommitTasks.Reset(); PreAbortTasks.Reset(); AbortTasks.Reset(); if (CompletionTasks != std::nullopt) { CompletionTasks->Reset(); CompletionTasks = std::nullopt; } HitSet.Reset(); NewMemoryTracker.Reset(); WriteLog.Reset(); CurrentMemoryValidationLevel = EMemoryValidationLevel::Disabled; // Reset to the initial state. CurrentState = EState::Uninitialized; DeferredPopOnCommitHandlers.Reset(); DeferredPopOnAbortHandlers.Reset(); DeferredPopAllOnCommitHandlers.Reset(); DeferredPopAllOnAbortHandlers.Reset(); AUTORTFM_ASSERT(IsFresh()); } bool FTransaction::IsFresh() const { return HitSet.IsEmpty() && NewMemoryTracker.IsEmpty() && WriteLog.IsEmpty() && CommitTasks.IsEmpty() && PreAbortTasks.IsEmpty() && AbortTasks.IsEmpty() && (CompletionTasks == std::nullopt || CompletionTasks->IsEmpty()) && !IsDone() && DeferredPopOnCommitHandlers.IsEmpty() && DeferredPopOnAbortHandlers.IsEmpty() && DeferredPopAllOnCommitHandlers.IsEmpty() && DeferredPopAllOnAbortHandlers.IsEmpty(); } void FTransaction::AbortWithoutThrowing() { AUTORTFM_VERBOSE("Aborting '%hs'!", GetContextStatusName(Context->GetStatus())); AUTORTFM_ASSERT(Context->IsAborting()); Stats.Collect(); CollectStats(); // Ensure that we enter the done state before applying the commit, as this // will ensure the open-memory validation is performed before the write log // is cleared. SetDone(); const EMemoryValidationLevel MemoryValidationLevel = ForTheRuntime::GetMemoryValidationLevel(); // Call the destructors of all the OnCommit functors before undoing the transactional memory and // calling the OnAbort callbacks. This is important as the callback functions may have captured // variables that are depending on the allocated memory. if (AUTORTFM_LIKELY(MemoryValidationLevel == EMemoryValidationLevel::Disabled)) { CommitTasks.Reset(); } else { ResetCommitTasksWithValidation(); } // Call each of the pre memory rollback on-abort handlers. if (AUTORTFM_LIKELY(MemoryValidationLevel == EMemoryValidationLevel::Disabled)) { PreAbortTasks.RemoveEachBackward([&](TTask& Task) { Task(); }); } else { CallAbortTasksWithValidation(PreAbortTasks); } // Revert all memory in the write-log. Undo(); // Call each of the post memory rollback on-abort handlers. if (AUTORTFM_LIKELY(MemoryValidationLevel == EMemoryValidationLevel::Disabled)) { AbortTasks.RemoveEachBackward([&](TTask& Task) { Task(); }); } else { CallAbortTasksWithValidation(AbortTasks); } if (IsNested()) { AUTORTFM_ASSERT(Parent); } else { // Completion tasks don't use memory validation; by definition, they happen after the transaction is over. CompletionTasks->RemoveEachForward([&](TTask& Task) { Task(); }); AUTORTFM_ASSERT(Context->IsAborting()); } } void FTransaction::ResetCommitTasksWithValidation() { FOpenHashThrottler& OpenHashThrottler = Context->GetOpenHashThrottler(); const bool bThrottlingEnabled = AutoRTFM::ForTheRuntime::GetMemoryValidationThrottlingEnabled(); std::optional OldHash; CommitTasks.RemoveEachBackward([&](TTask& Task) { void* const FunctionAddress = Task.FunctionAddress(); const bool bShouldHash = !bThrottlingEnabled || OpenHashThrottler.ShouldHashFor(FunctionAddress); if (bShouldHash && !OldHash) { OldHash = CalculateNestedWriteHash(FunctionAddress); } Task.Reset(); if (bShouldHash) { const FWriteHash NewHash = CalculateNestedWriteHash(FunctionAddress); if (OldHash != NewHash) { OnValidatorWriteHashMismatch(ForTheRuntime::GetMemoryValidationLevel(), "Memory modified in a transaction was also modified in the destructor of the on-commit handler: %s\n" "This may lead to memory corruption if the transaction is aborted.", GetFunctionDescription(FunctionAddress).c_str()); OldHash = NewHash; } } else { // Task was reset without validation, so invalidate the snapshot hash. OldHash = std::nullopt; } }); } void FTransaction::CallAbortTasksWithValidation(FTaskArray& Tasks) { FOpenHashThrottler& OpenHashThrottler = Context->GetOpenHashThrottler(); const bool bThrottlingEnabled = AutoRTFM::ForTheRuntime::GetMemoryValidationThrottlingEnabled(); std::optional OldHash; auto MaybeValidate = [&](void* FunctionAddress, const char* Message, auto&& Fn) { const bool bShouldHash = !bThrottlingEnabled || OpenHashThrottler.ShouldHashFor(FunctionAddress); if (bShouldHash && !OldHash) { OldHash = CalculateNestedWriteHash(FunctionAddress); } Fn(); if (bShouldHash) { const FWriteHash NewHash = CalculateNestedWriteHash(FunctionAddress); if (OldHash != NewHash) { OnValidatorWriteHashMismatch(ForTheRuntime::GetMemoryValidationLevel(), Message, GetFunctionDescription(FunctionAddress).c_str()); OldHash = NewHash; } } else { // Fn() was run without validation, so invalidate the snapshot hash. OldHash = std::nullopt; } }; Tasks.RemoveEachBackward([&](TTask& Task) { void* const FunctionAddress = Task.FunctionAddress(); MaybeValidate(FunctionAddress, "Memory modified in a transaction was also modified in the on-abort handler: %s\n" "This may lead to memory corruption.", [&]{ Task(); }); MaybeValidate(FunctionAddress, "Memory modified in a transaction was also modified in the destructor of the on-abort handler: %s\n" "This may lead to memory corruption if the transaction is aborted.", [&]{ Task.Reset(); }); }); } void FTransaction::AbortAndThrow() { AbortWithoutThrowing(); Context->Throw(); } bool FTransaction::AttemptToCommit() { AUTORTFM_ASSERT(Context->GetStatus() == EContextStatus::Committing); AUTORTFM_ASSERT(Context->GetCurrentTransaction() == this); Stats.Collect(); CollectStats(); // Ensure that we enter the done state before applying the commit, as this // will ensure the open-memory validation is performed before the write log // is cleared. SetDone(); bool bResult; if (IsNested()) { CommitNested(); bResult = true; } else { bResult = AttemptToCommitOuterNest(); } return bResult; } void FTransaction::Undo() { AUTORTFM_VERBOSE("Undoing a transaction..."); AUTORTFM_ASSERT(IsDone()); for(auto Iter = WriteLog.rbegin(); Iter != WriteLog.rend(); ++Iter) { FWriteLogEntry Entry = *Iter; // No write records should be within the transaction's stack range. AUTORTFM_ENSURE(!IsOnStack(Entry.LogicalAddress)); memcpy(Entry.LogicalAddress, Entry.Data, Entry.Size); } AUTORTFM_VERBOSE("Undone a transaction!"); } void FTransaction::CommitNested() { AUTORTFM_ASSERT(Parent); AUTORTFM_ASSERT(CompletionTasks == std::nullopt); // We need to pass our write log to our parent transaction, but with care! // We need to discard any writes if the memory location is on the parent // transaction's stack range. for (FWriteLogEntry Write : WriteLog) { if (Parent->IsOnStack(Write.LogicalAddress)) { continue; } if (Write.Size <= FHitSet::MaxSize) { FHitSetEntry HitSetEntry{}; HitSetEntry.Address = reinterpret_cast(Write.LogicalAddress); HitSetEntry.Size = static_cast(Write.Size); HitSetEntry.bNoMemoryValidation = Write.bNoMemoryValidation; if (Parent->HitSet.FindOrTryInsert(HitSetEntry)) { continue; // Don't duplicate the write-log entry. } } Parent->WriteLog.Push(Write); } // For all the deferred calls to `PopOnCommitHandler` that we couldn't // process (because our transaction nest didn't `PushOnCommitHandler`) // we need to move these to the parent now to handle them. for (const void* Key : DeferredPopOnCommitHandlers) { Parent->PopDeferUntilCommitHandler(Key); } DeferredPopOnCommitHandlers.Reset(); // For all the deferred calls to `PopOnAbortHandler` that we couldn't // process (because our transaction nest didn't `PushOnAbortHandler`) // we need to move these to the parent now to handle them. for (const void* Key : DeferredPopOnAbortHandlers) { Parent->PopDeferUntilAbortHandler(Key); } DeferredPopOnAbortHandlers.Reset(); // For all the calls to `PopAllOnCommitHandlers` we need to run these // again on parent now to handle them there too. for (const void* Key : DeferredPopAllOnCommitHandlers) { Parent->PopAllDeferUntilCommitHandlers(Key); } DeferredPopAllOnCommitHandlers.Reset(); // For all the calls to `PopAllOnAbortHandlers` we need to run these // again on parent now to handle them there too. for (const void* Key : DeferredPopAllOnAbortHandlers) { Parent->PopAllDeferUntilAbortHandlers(Key); } DeferredPopAllOnAbortHandlers.Reset(); Parent->CommitTasks.AddAll(std::move(CommitTasks)); Parent->PreAbortTasks.AddAll(std::move(PreAbortTasks)); Parent->AbortTasks.AddAll(std::move(AbortTasks)); Parent->NewMemoryTracker.Merge(NewMemoryTracker); } bool FTransaction::AttemptToCommitOuterNest() { AUTORTFM_ASSERT(!Parent); AUTORTFM_VERBOSE("About to run commit tasks!"); Context->DumpState(); AUTORTFM_VERBOSE("Running commit tasks..."); PreAbortTasks.Reset(); AbortTasks.Reset(); CommitTasks.RemoveEachForward([] (TTask& Task) { Task(); }); CompletionTasks->RemoveEachForward([] (TTask& Task) { Task(); }); return true; } void FTransaction::SetOpenActiveValidatorEnabled(EMemoryValidationLevel NewMemoryValidationLevel, const void* ReturnAddress) { AUTORTFM_ASSERT(NewMemoryValidationLevel != EMemoryValidationLevel::Disabled); CurrentMemoryValidationLevel = NewMemoryValidationLevel; CurrentOpenReturnAddress = ReturnAddress; if (AutoRTFM::ForTheRuntime::GetMemoryValidationThrottlingEnabled()) { FOpenHashThrottler& Throttler = Context->GetOpenHashThrottler(); if (!Throttler.ShouldHashFor(ReturnAddress)) { CurrentMemoryValidationLevel = EMemoryValidationLevel::Disabled; } Throttler.Update(); } SetState(); } void FTransaction::SetOpenActive(EMemoryValidationLevel NewMemoryValidationLevel, const void* ReturnAddress) { if (AUTORTFM_UNLIKELY(NewMemoryValidationLevel != EMemoryValidationLevel::Disabled)) { AUTORTFM_MUST_TAIL return SetOpenActiveValidatorEnabled(NewMemoryValidationLevel, ReturnAddress); } CurrentMemoryValidationLevel = NewMemoryValidationLevel; CurrentOpenReturnAddress = ReturnAddress; // TODO: Validate if open -> open with different validation levels. RecordedWriteHash = 0; NumWriteLogsHashed = 0; CurrentOpenReturnAddress = nullptr; SetState(); } void FTransaction::SetClosedActive() { SetState(); } void FTransaction::SetOpenInactive() { SetState(); } void FTransaction::SetClosedInactive() { SetState(); } void FTransaction::SetActive() { switch (CurrentState) { case EState::OpenActive: case EState::ClosedActive: break; case EState::OpenInactive: SetState(); break; case EState::ClosedInactive: SetState(); break; default: AUTORTFM_FATAL("Invalid state"); } } void FTransaction::SetInactive() { switch (CurrentState) { case EState::OpenInactive: case EState::ClosedInactive: break; case EState::OpenActive: SetState(); break; case EState::ClosedActive: SetState(); break; default: AUTORTFM_FATAL("Invalid state"); } } void FTransaction::SetDone() { SetState(); } template void FTransaction::SetState() { AUTORTFM_ASSERT(NewState != CurrentState); switch (CurrentState) { case EState::Uninitialized: AUTORTFM_ASSERT(NewState == EState::OpenActive || NewState == EState::ClosedActive); break; // OpenActive -> OpenInactive, ClosedActive or Done case EState::OpenActive: AUTORTFM_ASSERT(NewState == EState::OpenInactive || NewState == EState::ClosedActive || NewState == EState::Done); if (CurrentMemoryValidationLevel != EMemoryValidationLevel::Disabled) { ValidateWriteHash(); RecordedWriteHash = 0; NumWriteLogsHashed = 0; } else { AUTORTFM_ASSERT(RecordedWriteHash == 0 && NumWriteLogsHashed == 0); } break; // ClosedActive -> ClosedInactive, OpenActive or Done case EState::ClosedActive: AUTORTFM_ASSERT(NewState == EState::ClosedInactive || NewState == EState::OpenActive || NewState == EState::Done); break; // ClosedActive -> OpenActive case EState::OpenInactive: AUTORTFM_ASSERT(NewState == EState::OpenActive); break; // ClosedInactive -> ClosedActive case EState::ClosedInactive: AUTORTFM_ASSERT(NewState == EState::ClosedActive); break; case EState::Done: AUTORTFM_FATAL("Once Done, the transaction cannot change state without a call to Reset()"); break; default: AUTORTFM_FATAL("Invalid state"); break; } // OpenInactive, ClosedActive or Done -> OpenActive if (NewState == EState::OpenActive) { if (CurrentMemoryValidationLevel != EMemoryValidationLevel::Disabled) { AUTORTFM_ASSERT(RecordedWriteHash == 0 && NumWriteLogsHashed == 0); RecordWriteHash(); } } CurrentState = NewState; } void FTransaction::DebugBreakIfMemoryValidationFails() { if (CurrentMemoryValidationLevel != EMemoryValidationLevel::Disabled) { FWriteHash OldHash = RecordedWriteHash; FWriteHash NewHash = CalculateNestedWriteHash(CurrentOpenReturnAddress); if (OldHash != NewHash) { AUTORTFM_WARN("DebugBreakIfInvalidMemoryHash() detected a change in hash"); __builtin_debugtrap(); } } } void FTransaction::RecordWriteHash() { NumWriteLogsHashed = WriteLog.Num(); RecordedWriteHash = CalculateNestedWriteHash(CurrentOpenReturnAddress); } void FTransaction::ValidateWriteHash() const { const FWriteHash OldHash = RecordedWriteHash; const FWriteHash NewHash = CalculateNestedWriteHash(CurrentOpenReturnAddress); if (OldHash != NewHash) { OnValidatorWriteHashMismatch(CurrentMemoryValidationLevel, "Memory modified in a transaction was also modified in an call to AutoRTFM::Open().\n" "This may lead to memory corruption if the transaction is aborted."); } } void FTransaction::OnValidatorWriteHashMismatch(EMemoryValidationLevel MemoryValidationLevel, const char* Message, ...) const { va_list Args; va_start(Args, Message); switch (MemoryValidationLevel) { case EMemoryValidationLevel::Default: case EMemoryValidationLevel::Disabled: break; case EMemoryValidationLevel::Warn: AUTORTFM_WARN_V(Message, Args); break; case EMemoryValidationLevel::Error: if (!ForTheRuntime::GetEnsureOnInternalAbort()) { AUTORTFM_FATAL_V(Message, Args); } else { AUTORTFM_ENSURE_MSG_V(false, Message, Args); } break; } va_end(Args); } FTransaction::FWriteHash FTransaction::CalculateNestedWriteHash(const void *FunctionAddress) const { const FWriteHash Hash = CalculateNestedWriteHashWithLimit(NumWriteLogsHashed, FunctionAddress); Context->GetOpenHashThrottler().Update(); return Hash; } FTransaction::FWriteHash FTransaction::CalculateNestedWriteHashWithLimit(size_t NumWriteEntries, const void *FunctionAddress) const { FWriteHash Hash = 0; if (nullptr != Parent) { Hash = 31 * Parent->CalculateNestedWriteHashWithLimit(Parent->WriteLog.Num(), FunctionAddress); } { FOpenHashThrottler::FHashScope Profile(Context->GetOpenHashThrottler(), FunctionAddress, WriteLog); Hash ^= WriteLog.Hash(NumWriteEntries); } return Hash; } } // namespace AutoRTFM #endif // defined(__AUTORTFM) && __AUTORTFM