// Copyright Epic Games, Inc. All Rights Reserved.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using EpicGames.Core;
using EpicGames.Horde.Storage.Nodes;
namespace EpicGames.Horde.Storage
{
/// <summary>
/// Blob writer which skips writing blobs whose content hash and type are already known, returning the previously
/// recorded handle instead. Known blobs are held in a bounded cache (oldest keys evicted first) which is shared
/// with any writers created through <see cref="Fork"/>.
/// </summary>
public sealed class DedupeBlobWriter : BlobWriter
{
	/// <summary>
	/// Identifies a blob in the dedupe cache by the hash of its data and its type.
	/// </summary>
	record BlobKey(IoHash Hash, BlobType Type);

	/// <summary>
	/// Bounded map from blob key to handle. Once the limit is reached, the oldest inserted key is evicted.
	/// This type is not synchronized; callers must hold a lock on the cache instance while using it.
	/// </summary>
	sealed class DedupeCache
	{
		readonly int _maxKeys;

		// Keys in insertion order, used to pick eviction candidates. May contain keys already removed via Remove().
		readonly Queue _blobKeys;
		readonly Dictionary _blobKeyToHandle;

		public DedupeCache(int maxKeys)
		{
			_maxKeys = maxKeys;
			_blobKeys = new Queue(maxKeys);
			_blobKeyToHandle = new Dictionary(maxKeys);
		}

		/// <summary>
		/// Adds a handle for the given key, evicting the oldest keys if the cache is full. If the key is already
		/// present, the existing handle is kept.
		/// </summary>
		internal void Add(BlobKey key, IHashedBlobRef handle)
		{
			if (_blobKeyToHandle.ContainsKey(key))
			{
				return;
			}

			// A dequeued key may be stale (already removed, or removed and re-added later); removing it then either
			// does nothing or evicts that key a little early, which only costs a missed dedupe.
			while (_blobKeys.Count >= _maxKeys && _blobKeys.TryDequeue(out BlobKey? evictedKey))
			{
				_blobKeyToHandle.Remove(evictedKey);
			}

			_blobKeyToHandle.Add(key, handle);
			_blobKeys.Enqueue(key);
		}

		/// <summary>
		/// Removes the entry for a key, but only if it currently maps to the given handle instance.
		/// </summary>
		/// <returns>True if the entry was removed</returns>
		internal bool Remove(BlobKey key, IHashedBlobRef handle)
		{
			if (_blobKeyToHandle.TryGetValue(key, out IHashedBlobRef? existing) && ReferenceEquals(existing, handle))
			{
				// The key is left in the eviction queue and discarded when it reaches the front
				return _blobKeyToHandle.Remove(key);
			}
			return false;
		}

		/// <summary>
		/// Looks up the handle recorded for a key.
		/// </summary>
		internal bool TryGetValue(BlobKey key, [NotNullWhen(true)] out IHashedBlobRef? handle) => _blobKeyToHandle.TryGetValue(key, out handle);
	}

	/// <summary>
	/// Handle for a blob written through this writer. It is added to the cache before the underlying write starts,
	/// so other writers sharing the cache can wait for it rather than writing the same blob again. It is only
	/// returned to callers once the underlying write has succeeded.
	/// </summary>
	sealed class WrappedHandle : IHashedBlobRef
	{
		// Completed with true once _inner has been set, or with false if the underlying write failed
		readonly TaskCompletionSource<bool> _writeResult = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

		// Handle returned by the inner writer; published with Volatile.Write so other threads observe it
		IHashedBlobRef? _inner;

		/// <summary>
		/// Whether the underlying write has completed successfully
		/// </summary>
		public bool IsWritten => Volatile.Read(ref _inner) is not null;

		/// <summary>
		/// Waits for the underlying write to finish.
		/// </summary>
		/// <returns>True if the write succeeded, false if it failed</returns>
		public Task<bool> WaitForWriteAsync() => _writeResult.Task;

		/// <summary>
		/// Records the handle from the inner writer and releases any waiters.
		/// </summary>
		public void SetWritten(IHashedBlobRef inner)
		{
			Volatile.Write(ref _inner, inner);
			_writeResult.TrySetResult(true);
		}

		/// <summary>
		/// Signals waiters that the underlying write failed.
		/// </summary>
		public void SetFailed() => _writeResult.TrySetResult(false);

		IHashedBlobRef Inner => Volatile.Read(ref _inner) ?? throw new InvalidOperationException("Blob has not been written yet.");

		/// <inheritdoc/>
		public IBlobRef Innermost
			=> Inner.Innermost;

		/// <inheritdoc/>
		public IoHash Hash
			=> Inner.Hash;

		/// <inheritdoc/>
		public bool TryGetLocator(out BlobLocator locator)
			=> Inner.TryGetLocator(out locator);

		/// <inheritdoc/>
		public ValueTask FlushAsync(CancellationToken cancellationToken)
			=> Inner.FlushAsync(cancellationToken);

		/// <inheritdoc/>
		public ValueTask ReadBlobDataAsync(CancellationToken cancellationToken = default)
			=> Inner.ReadBlobDataAsync(cancellationToken);

		/// <inheritdoc/>
		public override bool Equals(object? obj)
		{
			if (ReferenceEquals(this, obj))
			{
				return true;
			}
			IHashedBlobRef? inner = Volatile.Read(ref _inner);
			return inner is not null && obj is WrappedHandle other && inner == Volatile.Read(ref other._inner);
		}

		/// <inheritdoc/>
		public override int GetHashCode() => HashCode.Combine(Volatile.Read(ref _inner)?.GetHashCode() ?? 0, 1);
	}

	/// <summary>
	/// Default value for maximum number of keys
	/// </summary>
	public const int DefaultMaxKeys = 512 * 1024;

	readonly BlobWriter _inner;

	// Shared with forked writers; also used as the lock guarding the cache and the hit/miss counters below
	readonly DedupeCache _cache;

	int _numHits;
	long _totalHitsSize;
	int _numMisses;
	long _totalMissesSize;
	int _numCacheAdds;

	/// <summary>
	/// Constructor
	/// </summary>
	/// <param name="inner">Writer that receives blobs not found in the cache. Must derive from <see cref="BlobWriter"/>; other implementations cause an <see cref="InvalidCastException"/>.</param>
	/// <param name="maxKeys">Maximum number of keys to keep in the cache</param>
	public DedupeBlobWriter(IBlobWriter inner, int maxKeys = DefaultMaxKeys)
		: base(inner.Options)
	{
		_inner = (BlobWriter)inner;
		_cache = new DedupeCache(maxKeys);
	}

	private DedupeBlobWriter(IBlobWriter inner, DedupeCache cache)
		: base(inner.Options)
	{
		_inner = (BlobWriter)inner;
		_cache = cache;
	}

	/// <inheritdoc/>
	public override ValueTask DisposeAsync() => _inner.DisposeAsync();

	/// <inheritdoc/>
	public override Task FlushAsync(CancellationToken cancellationToken = default) => _inner.FlushAsync(cancellationToken);

	/// <inheritdoc/>
	/// <remarks>The forked writer shares this writer's dedupe cache but keeps its own statistics.</remarks>
	public override IBlobWriter Fork() => new DedupeBlobWriter(_inner.Fork(), _cache);

	/// <inheritdoc/>
	public override Memory GetOutputBuffer(int usedSize, int desiredSize) => _inner.GetOutputBuffer(usedSize, desiredSize);

	/// <summary>
	/// Add a blob to the cache
	/// </summary>
	/// <param name="type">Type of the blob</param>
	/// <param name="blobRef">Reference to the blob data</param>
	public void AddToCache(BlobType type, IHashedBlobRef blobRef)
	{
		BlobKey key = new BlobKey(blobRef.Hash, type);

		// The cache is shared between forked writers, so it must be mutated under the same lock as WriteBlobAsync
		lock (_cache)
		{
			_cache.Add(key, blobRef);
		}
		Interlocked.Increment(ref _numCacheAdds);
	}

	/// <inheritdoc/>
	/// <remarks>
	/// If a blob with the same hash and type is already cached, its handle is returned and nothing is written. If another
	/// writer sharing the cache is currently writing the same blob, this waits for that write to finish. The wait does
	/// not observe <paramref name="cancellationToken"/>; it is bounded by the other writer's own write. If that write
	/// fails, this retries and writes the blob itself.
	/// </remarks>
	public override async ValueTask WriteBlobAsync(BlobType type, int size, IReadOnlyList imports, IReadOnlyList aliases, CancellationToken cancellationToken = default)
	{
		ReadOnlyMemory data = _inner.GetOutputBuffer(size, size).Slice(0, size);
		IoHash hash = IoHash.Compute(data.Span);
		BlobKey key = new BlobKey(hash, type);

		for (; ; )
		{
			WrappedHandle pending;
			bool writeHere;
			lock (_cache)
			{
				if (_cache.TryGetValue(key, out IHashedBlobRef? handle))
				{
					if (handle is not WrappedHandle wrapped || wrapped.IsWritten)
					{
						_numHits++;
						_totalHitsSize += size;
						return handle;
					}

					// Another writer is still writing this blob
					pending = wrapped;
					writeHere = false;
				}
				else
				{
					_numMisses++;
					_totalMissesSize += size;

					// Reserve the key so concurrent writers wait for this write instead of duplicating it
					pending = new WrappedHandle();
					_cache.Add(key, pending);
					writeHere = true;
				}
			}

			if (writeHere)
			{
				try
				{
					IHashedBlobRef written = await _inner.WriteBlobAsync(type, size, imports.ConvertAll(x => x.Innermost), aliases, cancellationToken);
					pending.SetWritten(written);
					return pending;
				}
				catch
				{
					// Never leave a placeholder in the cache that can not be resolved; wake waiters so they retry
					lock (_cache)
					{
						_cache.Remove(key, pending);
					}
					pending.SetFailed();
					throw;
				}
			}

			if (await pending.WaitForWriteAsync())
			{
				lock (_cache)
				{
					_numHits++;
					_totalHitsSize += size;
				}
				return pending;
			}

			// The other write failed and its placeholder was removed from the cache; look up the key again
			cancellationToken.ThrowIfCancellationRequested();
		}
	}

	/// <summary>
	/// Gets stats for the copy operation
	/// </summary>
	public StorageStats GetStats()
	{
		// Snapshot the counters under the lock they are updated under, so 64-bit values are read consistently
		int numHits;
		long totalHitsSize;
		int numMisses;
		long totalMissesSize;
		lock (_cache)
		{
			numHits = _numHits;
			totalHitsSize = _totalHitsSize;
			numMisses = _numMisses;
			totalMissesSize = _totalMissesSize;
		}

		StorageStats stats = new StorageStats();
		stats.Add("Cache hits", numHits);
		stats.Add("Cache hits size", totalHitsSize);
		stats.Add("Cache misses", numMisses);
		stats.Add("Cache misses size", totalMissesSize);
		stats.Add("Cache adds", Volatile.Read(ref _numCacheAdds));
		return stats;
	}
}
/// <summary>
/// Extension methods for <see cref="DedupeBlobWriter"/>
/// </summary>
public static class DedupeBlobWriterExtensions
{
	/// <summary>
	/// Wraps an <see cref="IBlobWriter"/> with a <see cref="DedupeBlobWriter"/>
	/// </summary>
	/// <param name="writer">Writer to wrap</param>
	/// <param name="maxKeys">Maximum number of keys to include in the cache</param>
	public static DedupeBlobWriter WithDedupe(this IBlobWriter writer, int maxKeys = DedupeBlobWriter.DefaultMaxKeys)
		=> new DedupeBlobWriter(writer, maxKeys);

	/// <summary>
	/// Creates a dedupe writer on top of a new blob writer for the store
	/// </summary>
	/// <param name="store">The store instance to write to</param>
	/// <param name="maxKeys">Maximum number of keys to include in the cache</param>
	public static DedupeBlobWriter CreateDedupeBlobWriter(this IStorageNamespace store, int maxKeys = DedupeBlobWriter.DefaultMaxKeys)
		=> new DedupeBlobWriter(store.CreateBlobWriter(), maxKeys);

	/// <summary>
	/// Creates a dedupe writer on top of a new blob writer which uses a ref name as its base path
	/// </summary>
	/// <param name="store">The store instance to write to</param>
	/// <param name="refName">Ref name to use as a base path</param>
	/// <param name="maxKeys">Maximum number of keys to include in the cache</param>
	public static DedupeBlobWriter CreateDedupeBlobWriter(this IStorageNamespace store, RefName refName, int maxKeys = DedupeBlobWriter.DefaultMaxKeys)
		=> new DedupeBlobWriter(store.CreateBlobWriter(refName.ToString()), maxKeys);

	/// <summary>
	/// Recursively registers a directory tree with the dedupe cache: the directory blob itself, every
	/// subdirectory, and the chunked data of every file.
	/// </summary>
	/// <param name="dedupeWriter">Dedupe writer to operate on</param>
	/// <param name="directoryNodeRef">Reference to the directory to add</param>
	/// <param name="cancellationToken">Cancellation token for the operation</param>
	public static async Task AddToCacheAsync(this DedupeBlobWriter dedupeWriter, IBlobRef directoryNodeRef, CancellationToken cancellationToken = default)
	{
		// Register and decode this directory; the blob data is disposed before visiting any children
		DirectoryNode node;
		using (BlobData data = await directoryNodeRef.ReadBlobDataAsync(cancellationToken))
		{
			IoHash contentHash = IoHash.Compute(data.Data.Span);
			IHashedBlobRef cacheEntry = HashedBlobRef.Create(contentHash, directoryNodeRef);
			dedupeWriter.AddToCache(data.Type, cacheEntry);
			node = BlobSerializer.Deserialize(data);
		}

		// Subdirectories first, then file contents
		foreach (DirectoryEntry subDirectory in node.Directories)
		{
			await AddToCacheAsync(dedupeWriter, subDirectory.Handle, cancellationToken);
		}
		foreach (FileEntry file in node.Files)
		{
			await AddToCacheAsync(dedupeWriter, file.Target, cancellationToken);
		}
	}

	/// <summary>
	/// Recursively registers a chunked data stream with the dedupe cache. Leaf nodes are registered
	/// without reading their data; interior nodes are read to find their children. Nodes of any other
	/// type are ignored.
	/// </summary>
	/// <param name="dedupeWriter">Dedupe writer to operate on</param>
	/// <param name="dataNodeRef">Reference to a data node in the stream</param>
	/// <param name="cancellationToken">Cancellation token for the operation</param>
	public static async Task AddToCacheAsync(this DedupeBlobWriter dedupeWriter, ChunkedDataNodeRef dataNodeRef, CancellationToken cancellationToken = default)
	{
		if (dataNodeRef.Type == ChunkedDataNodeType.Leaf)
		{
			dedupeWriter.AddToCache(LeafChunkedDataNodeConverter.BlobType, dataNodeRef);
			return;
		}
		if (dataNodeRef.Type != ChunkedDataNodeType.Interior)
		{
			return;
		}

		// The interior node's blob data stays alive until all of its children have been visited
		using BlobData data = await dataNodeRef.ReadBlobDataAsync(cancellationToken);
		dedupeWriter.AddToCache(data.Type, dataNodeRef);

		InteriorChunkedDataNode interior = BlobSerializer.Deserialize(data);
		foreach (ChunkedDataNodeRef child in interior.Children)
		{
			await AddToCacheAsync(dedupeWriter, child, cancellationToken);
		}
	}
}
}