Files
Brandyn / Techy fcc1b09210 init
2026-04-04 15:40:51 -05:00

491 lines
16 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#pragma once
// Vantage-point tree implementation https://en.wikipedia.org/wiki/Vantage-point_tree
// it relies on templated methods to specify the
// - data point type (typename T, usually a multidimensional point like TArray<float>)
// - data source struct (typename TDataSource) that needs to implement
// struct FDataSource
// {
// const T& operator[](int32 Index) const;
// int32 Num() const;
// static float GetDistance(const T& A, const T& B);
// };
// and can be specialized for runtime calls at FindNeighbors to laverage vectorized code with aliged and padded data
// - the result set struct (typename TVPTreeResultSet), see struct FVPTreeResultSet below as example
#include "Containers/ArrayView.h"
#include "Math/RandomStream.h"
#include "PoseSearch/PoseSearchDefines.h"
namespace UE::PoseSearch
{
#define VALIDATE_FINDNEIGHBORS 0
typedef float VPTreeScalar;
struct FIndexDistance
{
FIndexDistance(int32 InIndex = INDEX_NONE, VPTreeScalar InDistance = VPTreeScalar(0))
: Index(InIndex)
, Distance(InDistance)
{
}
bool operator<(const FIndexDistance& Other) const
{
return Distance < Other.Distance;
}
bool operator==(const FIndexDistance& Other) const
{
return Index == Other.Index && Distance == Other.Distance;
}
friend FArchive& operator<<(FArchive& Ar, FIndexDistance& IndexDistance)
{
Ar << IndexDistance.Index;
Ar << IndexDistance.Distance;
return Ar;
}
int32 Index;
VPTreeScalar Distance;
};
struct FVPTreeNode : public FIndexDistance
{
int32 LeftIndex = INDEX_NONE;
int32 RightIndex = INDEX_NONE;
bool operator==(const FVPTreeNode& Other) const
{
return FIndexDistance::operator==(Other) && LeftIndex == Other.LeftIndex && RightIndex == Other.RightIndex;
}
friend FArchive& operator<<(FArchive& Ar, FVPTreeNode& Node)
{
Ar << static_cast<FIndexDistance&>(Node);
Ar << Node.LeftIndex;
Ar << Node.RightIndex;
return Ar;
}
};
struct FVPTreeResultSet
{
FVPTreeResultSet(int32 InNumNeighbors)
: NumNeighbors(InNumNeighbors)
{
#if !NO_LOGGING
if (InNumNeighbors > HeapDefaultMaxSize)
{
UE_LOG(LogPoseSearch, Warning, TEXT("FVPTreeResultSet - preallocated Heap data size 'HeapDefaultMaxSize' of %d is less than requested %d num neighbors. Performances will be negatively impacted"), HeapDefaultMaxSize, InNumNeighbors);
}
#endif // !NO_LOGGING
}
VPTreeScalar GetWorstDistance() const
{
return Heap.HeapTop().Distance;
}
bool IsFull() const
{
return Heap.Num() == NumNeighbors;
}
void AddPoint(float Distance, int32 Index
#if VALIDATE_FINDNEIGHBORS
, int32 BestIndex = INDEX_NONE
#endif // VALIDATE_FINDNEIGHBORS
)
{
if (!IsFull())
{
Heap.HeapPush(FIndexDistance(Index, Distance), HeapCompare());
#if VALIDATE_FINDNEIGHBORS
Heap.VerifyHeap(HeapCompare());
#endif // VALIDATE_FINDNEIGHBORS
}
else if (Distance < GetWorstDistance())
{
// popping the worst FIndexDistance
FIndexDistance Removed;
Heap.HeapPop(Removed, HeapCompare());
#if VALIDATE_FINDNEIGHBORS
if (BestIndex != INDEX_NONE && Removed.Index == BestIndex)
{
UE_LOG(LogPoseSearch, Error, TEXT("FVPTreeResultSet::AddPoint - we removed the best node from the ResultSet..."));
}
#endif // VALIDATE_FINDNEIGHBORS
// and replacing it with FIndexDistance(Index, Distance)
Heap.HeapPush(FIndexDistance(Index, Distance), HeapCompare());
#if VALIDATE_FINDNEIGHBORS
Heap.VerifyHeap(HeapCompare());
#endif // VALIDATE_FINDNEIGHBORS
}
}
TArray<FIndexDistance> GetSortedResults()
{
TArray<FIndexDistance> Results;
Results.SetNumUninitialized(Heap.Num());
int32 Index = 0;
while (!Heap.IsEmpty())
{
Heap.HeapPop(Results[Index++], HeapCompare());
}
return Results;
}
TConstArrayView<FIndexDistance> GetUnsortedResults() const
{
return Heap;
}
bool operator==(const FVPTreeResultSet& Other) const
{
return NumNeighbors == Other.NumNeighbors && Heap == Other.Heap;
}
private:
struct HeapCompare
{
bool operator()(const FIndexDistance& A, const FIndexDistance& B) const
{
// using > to create a heap where Heap.HeapTop().Distance is the greater Distance (to behave like the std::priority_queue)
return A.Distance > B.Distance;
}
};
const int32 NumNeighbors;
enum { HeapDefaultMaxSize = 256 };
TArray<FIndexDistance, TInlineAllocator<HeapDefaultMaxSize>> Heap;
};
struct FVPTree
{
template<typename TDataSource>
void Construct(const TDataSource& DataSource, FRandomStream& RandStream)
{
Reset();
// initializing the DataSourceIndexMapping, array of indeces pointing to the original TDataSource
const int32 Count = DataSource.Num();
TArray<FIndexDistance> DataSourceIndexMapping;
DataSourceIndexMapping.SetNumUninitialized(Count);
for (int32 Index = 0; Index < DataSourceIndexMapping.Num(); ++Index)
{
DataSourceIndexMapping[Index].Index = Index;
}
ConstructRecursive(DataSource, DataSourceIndexMapping, RandStream);
}
template<typename TDataSource>
bool TestConstruct(const TDataSource& DataSource, int32 NodeIndex = 0) const
{
TArray<int32> SubTreeNodeIndexesTempBuffer;
SubTreeNodeIndexesTempBuffer.Reserve(Nodes.Num());
return TestConstructRecursive(DataSource, NodeIndex, SubTreeNodeIndexesTempBuffer);
}
template<typename T, typename TDataSource, typename TVPTreeResultSet>
void FindNeighbors(const T& Query, TVPTreeResultSet& ResultSet, const TDataSource& DataSource) const
{
#if VALIDATE_FINDNEIGHBORS
int32 NumEvaluatedQueryDistances = 0;
// brute forcing to find the best Node
VPTreeScalar BestDistance = VPTreeScalar(UE_BIG_NUMBER);
int32 BestNodeIndex = INDEX_NONE;
TArray<bool> DoWeHaveAllNodes;
DoWeHaveAllNodes.SetNumZeroed(Nodes.Num());
for (int32 NodeIndex = 0; NodeIndex < Nodes.Num(); ++NodeIndex)
{
check(Nodes[NodeIndex].Index != INDEX_NONE);
DoWeHaveAllNodes[Nodes[NodeIndex].Index] = true;
const VPTreeScalar Distance = TDataSource::GetDistance(Query, DataSource[Nodes[NodeIndex].Index]);
if (Distance < BestDistance)
{
// @todo: handle eventual duplicates
BestDistance = Distance;
BestNodeIndex = NodeIndex;
}
}
check(!DoWeHaveAllNodes.Contains(false));
TArray<int32> SubTreeNodeIndexesTempBuffer;
SubTreeNodeIndexesTempBuffer.Reserve(Nodes.Num());
bool bBestNodeIndexFound = false;
#endif // VALIDATE_FINDNEIGHBORS
enum { NodesToSearchDefaultMaxSize = 512 };
TArray<int32, TInlineAllocator<NodesToSearchDefaultMaxSize>> NodesToSearch;
NodesToSearch.Emplace(0);
VPTreeScalar Tau = VPTreeScalar(UE_BIG_NUMBER);
while (!NodesToSearch.IsEmpty())
{
const int32 NodeIndex = NodesToSearch.Pop();
const FVPTreeNode& Node = Nodes[NodeIndex];
// query distance from vantage point
const VPTreeScalar QueryDistance = TDataSource::GetDistance(DataSource[Node.Index], Query);
#if VALIDATE_FINDNEIGHBORS
++NumEvaluatedQueryDistances;
bBestNodeIndexFound |= NodeIndex == BestNodeIndex;
bool bNodeIndexContainsBestNodeIndex = false;
if (!bBestNodeIndexFound)
{
SubTreeNodeIndexesTempBuffer.Reset();
GetSubTreeNodeIndexes(NodeIndex, SubTreeNodeIndexesTempBuffer);
bNodeIndexContainsBestNodeIndex = SubTreeNodeIndexesTempBuffer.Contains(BestNodeIndex);
if (!bNodeIndexContainsBestNodeIndex)
{
bool bOtherNodeIndexContainsBestNodeIndex = false;
for (int32 OtherNodeIndex : NodesToSearch)
{
SubTreeNodeIndexesTempBuffer.Reset();
GetSubTreeNodeIndexes(OtherNodeIndex, SubTreeNodeIndexesTempBuffer);
bOtherNodeIndexContainsBestNodeIndex = SubTreeNodeIndexesTempBuffer.Contains(BestNodeIndex);
if (bOtherNodeIndexContainsBestNodeIndex)
{
break;
}
}
if (!bOtherNodeIndexContainsBestNodeIndex)
{
UE_LOG(LogPoseSearch, Error, TEXT("FVPTree::FindNeighbors - we couldn't find the best node..."));
}
}
}
#endif // VALIDATE_FINDNEIGHBORS
// Tau represents the worst distance of the points in the TVPTreeResultSet
if (QueryDistance <= Tau)
{
#if VALIDATE_FINDNEIGHBORS
if (bBestNodeIndexFound)
{
ResultSet.AddPoint(QueryDistance, Node.Index, BestNodeIndex);
}
else
#endif // VALIDATE_FINDNEIGHBORS
{
ResultSet.AddPoint(QueryDistance, Node.Index);
}
if (ResultSet.IsFull())
{
Tau = ResultSet.GetWorstDistance();
}
}
#if VALIDATE_FINDNEIGHBORS
else if (NodeIndex == BestNodeIndex)
{
UE_LOG(LogPoseSearch, Error, TEXT("FVPTree::FindNeighbors - we had the best node in our hands and let it slipped away..."));
}
bool bIsBestNodeStillReachable = false;
#endif // VALIDATE_FINDNEIGHBORS
if (Node.LeftIndex != INDEX_NONE && QueryDistance < Node.Distance + Tau)
{
NodesToSearch.Emplace(Node.LeftIndex);
#if !NO_LOGGING
if (NodesToSearch.Num() > NodesToSearchDefaultMaxSize)
{
UE_LOG(LogPoseSearch, Warning, TEXT("FVPTree::FindNeighbors - requested more than preallocated NodesToSearch data size 'NodesToSearchDefaultMaxSize' (%d / %d)"), NodesToSearch.Num(), NodesToSearchDefaultMaxSize);
}
#endif // !NO_LOGGING
#if VALIDATE_FINDNEIGHBORS
if (bNodeIndexContainsBestNodeIndex)
{
SubTreeNodeIndexesTempBuffer.Reset();
GetSubTreeNodeIndexes(Node.LeftIndex, SubTreeNodeIndexesTempBuffer);
const bool bLeftIndexContainsBestNodeIndex = SubTreeNodeIndexesTempBuffer.Contains(BestNodeIndex);
bIsBestNodeStillReachable |= bLeftIndexContainsBestNodeIndex;
}
#endif // VALIDATE_FINDNEIGHBORS
}
if (Node.RightIndex != INDEX_NONE && QueryDistance >= Node.Distance - Tau)
{
NodesToSearch.Emplace(Node.RightIndex);
#if !NO_LOGGING
if (NodesToSearch.Num() > NodesToSearchDefaultMaxSize)
{
UE_LOG(LogPoseSearch, Warning, TEXT("FVPTree::FindNeighbors - requested more than preallocated NodesToSearch data size 'NodesToSearchDefaultMaxSize' (%d / %d)"), NodesToSearch.Num(), NodesToSearchDefaultMaxSize);
}
#endif // !NO_LOGGING
#if VALIDATE_FINDNEIGHBORS
if (bNodeIndexContainsBestNodeIndex)
{
SubTreeNodeIndexesTempBuffer.Reset();
GetSubTreeNodeIndexes(Node.RightIndex, SubTreeNodeIndexesTempBuffer);
const bool bRightIndexContainsBestNodeIndex = SubTreeNodeIndexesTempBuffer.Contains(BestNodeIndex);
bIsBestNodeStillReachable |= bRightIndexContainsBestNodeIndex;
}
#endif // VALIDATE_FINDNEIGHBORS
}
#if VALIDATE_FINDNEIGHBORS
if (bNodeIndexContainsBestNodeIndex != bIsBestNodeStillReachable)
{
UE_LOG(LogPoseSearch, Error, TEXT("FVPTree::FindNeighbors - we lost strack of the best node..."));
}
#endif // VALIDATE_FINDNEIGHBORS
}
}
void Reset();
POSESEARCH_API SIZE_T GetAllocatedSize() const;
bool operator==(const FVPTree& Other) const;
friend FArchive& operator<<(FArchive& Ar, FVPTree& VPTree);
private:
void GetSubTreeNodeIndexes(int32 RootNodeIndex, TArray<int32>& NodeIndexes) const
{
if (RootNodeIndex != INDEX_NONE)
{
NodeIndexes.Add(RootNodeIndex);
GetSubTreeNodeIndexes(Nodes[RootNodeIndex].LeftIndex, NodeIndexes);
GetSubTreeNodeIndexes(Nodes[RootNodeIndex].RightIndex, NodeIndexes);
}
}
// choose and swap vantage point into DataSourceIndexMapping[0]
template<typename TDataSource>
void ChooseVantagePointMappingIndex(const TDataSource& DataSource, TArrayView<FIndexDistance> DataSourceIndexMapping, FRandomStream& RandStream) const
{
// choosing a random, but deterministic, vantage point within DataSourceIndexMapping
const int32 VantagePointMappingIndex = FMath::FloorToInt32(RandStream.FRand() * DataSourceIndexMapping.Num());
// swapping DataSourceIndexMapping[VantagePointMappingIndex] with DataSourceIndexMapping[0];
const int32 VantagePointDataIndex = DataSourceIndexMapping[VantagePointMappingIndex].Index;
DataSourceIndexMapping[VantagePointMappingIndex] = DataSourceIndexMapping[0].Index;
DataSourceIndexMapping[0].Index = VantagePointDataIndex;
DataSourceIndexMapping[0].Distance = VPTreeScalar(0);
}
template<typename TDataSource>
void CacheDistancesFromVantagePoint(const TDataSource& DataSource, TArrayView<FIndexDistance> DataSourceIndexMapping) const
{
DataSourceIndexMapping[0].Distance = VPTreeScalar(0);
const int32 VantagePointDataIndex = DataSourceIndexMapping[0].Index;
ParallelFor(DataSourceIndexMapping.Num() - 1, [&DataSource, &DataSourceIndexMapping, VantagePointDataIndex](int32 ParallelForIndex)
{
FIndexDistance& IndexDistance = DataSourceIndexMapping[ParallelForIndex + 1];
IndexDistance.Distance = TDataSource::GetDistance(DataSource[VantagePointDataIndex], DataSource[IndexDistance.Index]);
});
}
template<typename TDataSource>
int32 ConstructRecursive(const TDataSource& DataSource, TArrayView<FIndexDistance> DataSourceIndexMapping, FRandomStream& RandStream)
{
int32 NodeIndex = INDEX_NONE;
const int32 Num = DataSourceIndexMapping.Num();
if (Num > 0)
{
NodeIndex = Nodes.Num();
Nodes.SetNum(NodeIndex + 1);
if (Num > 1)
{
ChooseVantagePointMappingIndex(DataSource, DataSourceIndexMapping, RandStream);
CacheDistancesFromVantagePoint(DataSource, DataSourceIndexMapping);
Nodes[NodeIndex].Index = DataSourceIndexMapping[0].Index;
// @todo: use something like std::nth_element since we don't need to sort the MedianMappingIndex -> end section of the DataSourceIndexMapping array!
DataSourceIndexMapping.Slice(1, Num - 1).StableSort([](const FIndexDistance& A, const FIndexDistance& B)
{
return A.Distance < B.Distance;
});
// @todo: maybe use the average delta distance between MedianMappingIndex - 1 and MedianMappingIndex, to round up eventual search numerical errors
const int32 MedianMappingIndex = Num / 2;
Nodes[NodeIndex].Distance = DataSourceIndexMapping[MedianMappingIndex - 1].Distance;
Nodes[NodeIndex].LeftIndex = ConstructRecursive(DataSource, DataSourceIndexMapping.Slice(1, MedianMappingIndex - 1), RandStream);
Nodes[NodeIndex].RightIndex = ConstructRecursive(DataSource, DataSourceIndexMapping.Slice(MedianMappingIndex, Num - MedianMappingIndex), RandStream);
}
else
{
Nodes[NodeIndex].Index = DataSourceIndexMapping[0].Index;
}
}
return NodeIndex;
}
template<typename TDataSource>
bool TestConstructRecursive(const TDataSource& DataSource, int32 NodeIndex, TArray<int32>& SubTreeNodeIndexesTempBuffer) const
{
if (NodeIndex == INDEX_NONE)
{
return true;
}
const FVPTreeNode Node = Nodes[NodeIndex];
// making sure all the left sub tree nodes are within Nodes[NodeIndex].Distance to the vantage point Nodes[NodeIndex].Index
SubTreeNodeIndexesTempBuffer.Reset();
GetSubTreeNodeIndexes(Node.LeftIndex, SubTreeNodeIndexesTempBuffer);
for (int32 SubTreeNodeIndex : SubTreeNodeIndexesTempBuffer)
{
const VPTreeScalar Distance = TDataSource::GetDistance(DataSource[Nodes[NodeIndex].Index], DataSource[Nodes[SubTreeNodeIndex].Index]);
if (Distance > Nodes[NodeIndex].Distance)
{
return false;
}
}
// making sure all the right sub tree nodes are outside the hypersphere of radius Nodes[NodeIndex].Distance from the vantage point Nodes[NodeIndex].Index
SubTreeNodeIndexesTempBuffer.Reset();
GetSubTreeNodeIndexes(Node.RightIndex, SubTreeNodeIndexesTempBuffer);
for (int32 SubTreeNodeIndex : SubTreeNodeIndexesTempBuffer)
{
const VPTreeScalar Distance = TDataSource::GetDistance(DataSource[Nodes[NodeIndex].Index], DataSource[Nodes[SubTreeNodeIndex].Index]);
if (Distance <= Nodes[NodeIndex].Distance)
{
return false;
}
}
// test recursively left and right sub trees
if (!TestConstructRecursive(DataSource, Node.LeftIndex, SubTreeNodeIndexesTempBuffer))
{
return false;
}
return TestConstructRecursive(DataSource, Node.RightIndex, SubTreeNodeIndexesTempBuffer);
}
TArray<FVPTreeNode> Nodes;
};
} // namespace UE::PoseSearch