// Copyright Epic Games, Inc. All Rights Reserved.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using EpicGames.Horde.Storage.Bundles;
using EpicGames.Horde.Storage.Bundles.V2;
namespace EpicGames.Horde.Storage
{
/// <summary>
/// Request to read a blob from storage
/// </summary>
/// <typeparam name="TUserData">Type of caller-defined data carried alongside the request</typeparam>
/// <param name="Handle">Reference to the blob to read</param>
/// <param name="UserData">Caller-defined data, passed through unchanged to the matching <see cref="BlobResponse{TUserData}"/></param>
public record BlobRequest<TUserData>(IBlobRef Handle, TUserData UserData);
/// <summary>
/// Response from reading a blob from storage
/// </summary>
/// <typeparam name="TUserData">Type of caller-defined data carried alongside the request</typeparam>
/// <param name="BlobData">Data read for the blob; released when this response is disposed</param>
/// <param name="UserData">User data copied from the originating <see cref="BlobRequest{TUserData}"/></param>
public sealed record BlobResponse<TUserData>(BlobData BlobData, TUserData UserData) : IDisposable
{
	/// <summary>
	/// Disposes the blob data held by this response
	/// </summary>
	public void Dispose()
	{
		BlobData.Dispose();
	}
}
/// <summary>
/// Batch of responses from the reader
/// </summary>
/// <typeparam name="TUserData">Type of caller-defined data carried alongside each request</typeparam>
/// <param name="Responses">Responses in this batch, in the order their reads were issued</param>
public sealed record BlobResponseBatch<TUserData>(BlobResponse<TUserData>[] Responses) : IDisposable
{
	/// <summary>
	/// Disposes every response in the batch, front to back
	/// </summary>
	public void Dispose()
	{
		for (int responseIdx = 0; responseIdx < Responses.Length; responseIdx++)
		{
			Responses[responseIdx].Dispose();
		}
	}
}
/// <summary>
/// Options for <see cref="BlobPipeline{TUserData}"/>
/// </summary>
public class BlobPipelineOptions
{
	const int DefaultResponseBufferSize = 200;
	const int DefaultFlushBatchLength = 200;
	const int DefaultNumFetchTasks = 4;

	/// <summary>
	/// Maximum number of response batches to buffer before fetching pauses to let the consumer catch up. Defaults to 200.
	/// </summary>
	public int ResponseBufferSize { get; set; } = DefaultResponseBufferSize;

	/// <summary>
	/// Number of queued requests (across all pending bundles) that must be exceeded before the oldest bundle batch is
	/// flushed while more requests may still arrive. Defaults to 200.
	/// </summary>
	public int FlushBatchLength { get; set; } = DefaultFlushBatchLength;

	/// <summary>
	/// Number of batches to fetch in parallel. Defaults to 4.
	/// </summary>
	public int NumFetchTasks { get; set; } = DefaultNumFetchTasks;
}
/// <summary>
/// Helper class to sort requested reads for optimal coherency within bundles
/// </summary>
/// <remarks>
/// Requests flow through three channels: raw requests are grouped by bundle and sorted, the grouped batches are read by a
/// pool of fetch workers, and completed batches are buffered for the consumer. A failure in any stage is forwarded to the
/// next stage's channel, so it surfaces from <see cref="ReadAllAsync"/> / <see cref="WaitToReadAsync"/> rather than
/// leaving the pipeline waiting forever.
/// </remarks>
/// <typeparam name="TUserData">Type of user data to include with requests</typeparam>
public sealed class BlobPipeline<TUserData> : IAsyncDisposable
{
	readonly BlobPipelineOptions _options;
	readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource();
	readonly List<Task> _tasks = new List<Task>();

	// Requests as added through Add() or by sources registered through AddSource()
	readonly Channel<BlobRequest<TUserData>> _requestChannel;

	// Requests grouped by bundle and sorted for coherent reads; consumed by the fetch workers
	readonly Channel<BlobRequest<TUserData>[]> _requestBatchChannel;

	// Completed reads; bounded so that fetching waits while the consumer falls behind
	readonly Channel<BlobResponseBatch<TUserData>> _responseBatchChannel;

	// Number of outstanding writers to _requestChannel. Starts at one for direct Add() calls, which is released by
	// FinishAdding(); the request channel is completed when this reaches zero.
	int _writerCount = 1;
	bool _writerFinished;

	/// <summary>
	/// Constructor
	/// </summary>
	public BlobPipeline()
		: this(new BlobPipelineOptions())
	{ }

	/// <summary>
	/// Constructor
	/// </summary>
	/// <param name="options">Options for the pipeline</param>
	/// <exception cref="ArgumentNullException"><paramref name="options"/> is null</exception>
	/// <exception cref="ArgumentOutOfRangeException">
	/// <see cref="BlobPipelineOptions.NumFetchTasks"/> is not positive, <see cref="BlobPipelineOptions.FlushBatchLength"/>
	/// is negative, or <see cref="BlobPipelineOptions.ResponseBufferSize"/> is not positive
	/// </exception>
	public BlobPipeline(BlobPipelineOptions options)
	{
		ArgumentNullException.ThrowIfNull(options);

		// With no fetch workers the response channel would complete immediately and every request would be dropped;
		// a negative flush length would end the batching loop before the request channel has been drained.
		ArgumentOutOfRangeException.ThrowIfNegativeOrZero(options.NumFetchTasks);
		ArgumentOutOfRangeException.ThrowIfNegative(options.FlushBatchLength);

		_options = options;
		_requestChannel = Channel.CreateUnbounded<BlobRequest<TUserData>>();
		_requestBatchChannel = Channel.CreateUnbounded<BlobRequest<TUserData>[]>();
		_responseBatchChannel = Channel.CreateBounded<BlobResponseBatch<TUserData>>(new BoundedChannelOptions(options.ResponseBufferSize) { FullMode = BoundedChannelFullMode.Wait });

		_tasks.Add(Task.Run(() => BatchLoopAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token));
		_tasks.Add(Task.Run(() => FetchLoopAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token));
	}

	/// <summary>
	/// Cancels all background work, waits for it to stop, and disposes any responses that were never read.
	/// </summary>
	public async ValueTask DisposeAsync()
	{
		if (_tasks.Count > 0)
		{
			await _cancellationTokenSource.CancelAsync();
			try
			{
				await Task.WhenAll(_tasks);
			}
			catch (OperationCanceledException)
			{
				// Expected: the background tasks observe the cancellation requested above.
			}
			_tasks.Clear();
		}

		_cancellationTokenSource.Dispose();

		BlobResponseBatch<TUserData>? batch;
		while (_responseBatchChannel.Reader.TryRead(out batch))
		{
			batch.Dispose();
		}
	}

	/// <summary>
	/// Adds a new read request
	/// </summary>
	/// <param name="request">The request to add</param>
	/// <exception cref="InvalidOperationException">The request channel has already been completed</exception>
	public void Add(BlobRequest<TUserData> request)
	{
		if (!_requestChannel.Writer.TryWrite(request))
		{
			throw new InvalidOperationException("Cannot write request to channel");
		}
	}

	/// <summary>
	/// Adds a new request source
	/// </summary>
	/// <param name="factory">Method to construct the sequence of items. Receives the pipeline's cancellation token.</param>
	/// <exception cref="InvalidOperationException">Adding has already been marked complete</exception>
	public void AddSource(Func<CancellationToken, IAsyncEnumerator<BlobRequest<TUserData>>> factory)
	{
		AdjustWriterCount(1);
		_tasks.Add(Task.Run(() => CopyToRequestChannelAsync(factory, _cancellationTokenSource.Token), _cancellationTokenSource.Token));
	}

	// Copies every item from a registered source into the request channel, then releases that source's writer slot.
	async Task CopyToRequestChannelAsync(Func<CancellationToken, IAsyncEnumerator<BlobRequest<TUserData>>> factory, CancellationToken cancellationToken)
	{
		try
		{
			await using IAsyncEnumerator<BlobRequest<TUserData>> source = factory(cancellationToken);
			while (await source.MoveNextAsync())
			{
				await _requestChannel.Writer.WriteAsync(source.Current, cancellationToken);
			}
		}
		catch (Exception ex)
		{
			// Fault the request channel so the error flows through to the reader. Otherwise this writer's slot is never
			// released, the channel never completes, and the whole pipeline waits forever.
			_requestChannel.Writer.TryComplete(ex);
			return;
		}
		AdjustWriterCount(-1);
	}

	/// <summary>
	/// Indicate that we've finished adding new items to the reader
	/// </summary>
	public void FinishAdding()
	{
		if (!_writerFinished)
		{
			AdjustWriterCount(-1);
			_writerFinished = true;
		}
	}

	// Atomically adds delta to the writer count, completing the request channel when it reaches zero.
	void AdjustWriterCount(int delta)
	{
		for (; ; )
		{
			int writerCount = Interlocked.CompareExchange(ref _writerCount, 0, 0);
			if (writerCount == 0)
			{
				throw new InvalidOperationException("Reading has already been marked complete");
			}
			if (Interlocked.CompareExchange(ref _writerCount, writerCount + delta, writerCount) == writerCount)
			{
				if (writerCount + delta == 0)
				{
					_requestChannel.Writer.TryComplete();
				}
				break;
			}
		}
	}

	record class BundleRequest(BundleHandle BundleHandle, int PacketOffset, int ExportIdx, BlobRequest<TUserData> OriginalRequest);

	record class BundleBatchRequest(BundleHandle BundleHandle, List<BundleRequest> Requests);

	// Runs the batching stage and always completes the request batch channel, forwarding any failure so that the fetch
	// workers stop waiting for more batches.
	async Task BatchLoopAsync(CancellationToken cancellationToken)
	{
		try
		{
			await BatchRequestsAsync(cancellationToken);
			_requestBatchChannel.Writer.TryComplete();
		}
		catch (Exception ex)
		{
			_requestBatchChannel.Writer.TryComplete(ex);
		}
	}

	// Groups requests by bundle and writes sorted batches to the request batch channel until the request channel is
	// complete and every queued request has been flushed.
	async Task BatchRequestsAsync(CancellationToken cancellationToken)
	{
		int queueLength = 0;
		Queue<BundleBatchRequest> bundleBatchQueue = new Queue<BundleBatchRequest>();
		Dictionary<BundleHandle, BundleBatchRequest> bundleHandleToBatch = new Dictionary<BundleHandle, BundleBatchRequest>();

		for (; ; )
		{
			// Add requests to the queue until the channel is drained and either enough requests are queued to flush,
			// or no more requests can arrive.
			for (; ; )
			{
				BlobRequest<TUserData>? request;
				if (_requestChannel.Reader.TryRead(out request))
				{
					if (TryAddBundleRequest(request, bundleBatchQueue, bundleHandleToBatch))
					{
						queueLength++;
					}
					else
					{
						// Not an export in a flushed bundle, so there is nothing to group it with; forward it on its own.
						await _requestBatchChannel.Writer.WriteAsync(new[] { request }, cancellationToken);
					}
				}
				else
				{
					if (queueLength > _options.FlushBatchLength)
					{
						break;
					}
					if (!await _requestChannel.Reader.WaitToReadAsync(cancellationToken))
					{
						break;
					}
				}
			}

			// Exit once we've processed everything and can't get any more items to read.
			if (queueLength == 0)
			{
				return;
			}

			// Flush the bundle that was first seen earliest, ordering its reads by packet offset then export index.
			BundleBatchRequest exportBatch = bundleBatchQueue.Dequeue();
			queueLength -= exportBatch.Requests.Count;
			bundleHandleToBatch.Remove(exportBatch.BundleHandle);

			BlobRequest<TUserData>[] chunkBatch = exportBatch.Requests.OrderBy(x => x.PacketOffset).ThenBy(x => x.ExportIdx).Select(x => x.OriginalRequest).ToArray();
			await _requestBatchChannel.Writer.WriteAsync(chunkBatch, cancellationToken);
		}
	}

	// Queues the request under its bundle if it references an export within a flushed packet. Returns false otherwise.
	static bool TryAddBundleRequest(BlobRequest<TUserData> request, Queue<BundleBatchRequest> bundleBatchQueue, Dictionary<BundleHandle, BundleBatchRequest> bundleHandleToBatch)
	{
		if (request.Handle.Innermost is not ExportHandle exportHandle)
		{
			return false;
		}
		if (exportHandle.Packet is not FlushedPacketHandle packetHandle)
		{
			return false;
		}

		BundleBatchRequest? bundleBatchRequest;
		if (!bundleHandleToBatch.TryGetValue(packetHandle.Bundle, out bundleBatchRequest))
		{
			bundleBatchRequest = new BundleBatchRequest(packetHandle.Bundle, new List<BundleRequest>());
			bundleBatchQueue.Enqueue(bundleBatchRequest);
			bundleHandleToBatch.Add(packetHandle.Bundle, bundleBatchRequest);
		}

		BundleRequest bundleRequest = new BundleRequest(packetHandle.Bundle, packetHandle.PacketOffset, exportHandle.ExportIdx, request);
		bundleBatchRequest.Requests.Add(bundleRequest);
		return true;
	}

	// Runs the configured number of fetch workers and completes the response channel once they have all finished,
	// forwarding the first failure (if any) to the reader.
	async Task FetchLoopAsync(CancellationToken cancellationToken)
	{
		try
		{
			List<Task> tasks = new List<Task>();
			for (int idx = 0; idx < _options.NumFetchTasks; idx++)
			{
				tasks.Add(Task.Run(() => FetchWorkerAsync(cancellationToken), cancellationToken));
			}
			await Task.WhenAll(tasks);
			_responseBatchChannel.Writer.TryComplete();
		}
		catch (Exception ex)
		{
			_responseBatchChannel.Writer.TryComplete(ex);
		}
	}

	// Reads request batches until the request batch channel is complete, writing one response batch per request batch.
	async Task FetchWorkerAsync(CancellationToken cancellationToken)
	{
		while (await _requestBatchChannel.Reader.WaitToReadAsync(cancellationToken))
		{
			BlobRequest<TUserData>[]? requestBatch;
			if (!_requestBatchChannel.Reader.TryRead(out requestBatch))
			{
				// Another worker took the batch first
				continue;
			}

			BlobResponseBatch<TUserData> responseBatch = await ReadBatchAsync(requestBatch, cancellationToken);
			try
			{
				await _responseBatchChannel.Writer.WriteAsync(responseBatch, cancellationToken);
			}
			catch
			{
				// The batch never reached the channel, so nobody else will dispose it.
				responseBatch.Dispose();
				throw;
			}
		}
	}

	// Reads every blob in the batch in order. If any read fails, the blob data already read is disposed before rethrowing.
	static async Task<BlobResponseBatch<TUserData>> ReadBatchAsync(BlobRequest<TUserData>[] requestBatch, CancellationToken cancellationToken)
	{
		BlobResponse<TUserData>[] responses = new BlobResponse<TUserData>[requestBatch.Length];
		int numRead = 0;
		try
		{
			for (; numRead < requestBatch.Length; numRead++)
			{
				BlobRequest<TUserData> request = requestBatch[numRead];
				BlobData blobData = await request.Handle.ReadBlobDataAsync(cancellationToken);
				responses[numRead] = new BlobResponse<TUserData>(blobData, request.UserData);
			}
		}
		catch
		{
			for (int idx = 0; idx < numRead; idx++)
			{
				responses[idx].Dispose();
			}
			throw;
		}
		return new BlobResponseBatch<TUserData>(responses);
	}

	/// <summary>
	/// Reads all responses from the reader
	/// </summary>
	/// <remarks>
	/// If enumeration stops partway through a batch, the remaining responses in that batch are disposed.
	/// NOTE(review): this includes the response most recently yielded, since the index is only advanced when the
	/// enumerator resumes. Confirm callers do not use or separately dispose a response after breaking out of the loop.
	/// </remarks>
	/// <param name="cancellationToken">Cancellation token for the operation</param>
	public async IAsyncEnumerable<BlobResponse<TUserData>> ReadAllAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
	{
		while (await WaitToReadAsync(cancellationToken))
		{
#pragma warning disable CA2000
			BlobResponseBatch<TUserData>? batch;
			if (TryReadBatch(out batch))
			{
				int idx = 0;
				try
				{
					for (; idx < batch.Responses.Length; idx++)
					{
						yield return batch.Responses[idx];
					}
				}
				finally
				{
					for (; idx < batch.Responses.Length; idx++)
					{
						batch.Responses[idx].Dispose();
					}
				}
			}
#pragma warning restore CA2000
		}
	}

	/// <summary>
	/// Attempts to read a batch from the queue
	/// </summary>
	/// <param name="batch">Batch of responses; the caller is responsible for disposing it</param>
	/// <returns>True if a batch was available</returns>
	public bool TryReadBatch([NotNullWhen(true)] out BlobResponseBatch<TUserData>? batch)
		=> _responseBatchChannel.Reader.TryRead(out batch);

	/// <summary>
	/// Waits until there is data available to read
	/// </summary>
	/// <param name="cancellationToken">Cancellation token for the operation</param>
	/// <returns>
	/// True when a batch is available, false once the pipeline has completed. Faults with the pipeline's error if any stage failed.
	/// </returns>
	public ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken)
		=> _responseBatchChannel.Reader.WaitToReadAsync(cancellationToken);
}
}