// Copyright Epic Games, Inc. All Rights Reserved.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using EpicGames.Horde.Storage.Bundles;
using EpicGames.Horde.Storage.Bundles.V2;

namespace EpicGames.Horde.Storage
{
	/// <summary>
	/// Request to read a blob from storage
	/// </summary>
	/// <typeparam name="TUserData">Type of user data to include with the request</typeparam>
	/// <param name="Handle">Reference to the blob to read</param>
	/// <param name="UserData">Caller-supplied data that is copied onto the matching response</param>
	public record class BlobRequest<TUserData>(IBlobRef Handle, TUserData UserData);

	/// <summary>
	/// Response from reading a blob from storage
	/// </summary>
	/// <typeparam name="TUserData">Type of user data included with the request</typeparam>
	/// <param name="BlobData">Data that was read for the blob. Disposed together with the response.</param>
	/// <param name="UserData">User data taken from the originating request</param>
	public sealed record class BlobResponse<TUserData>(BlobData BlobData, TUserData UserData) : IDisposable
	{
		/// <inheritdoc/>
		public void Dispose()
			=> BlobData.Dispose();
	}

	/// <summary>
	/// Batch of responses from the reader
	/// </summary>
	/// <typeparam name="TUserData">Type of user data included with the requests</typeparam>
	/// <param name="Responses">Responses in the batch. Disposing the batch disposes every response.</param>
	public sealed record class BlobResponseBatch<TUserData>(BlobResponse<TUserData>[] Responses) : IDisposable
	{
		/// <inheritdoc/>
		public void Dispose()
		{
			foreach (BlobResponse<TUserData> response in Responses)
			{
				response.Dispose();
			}
		}
	}

	/// <summary>
	/// Options for <see cref="BlobPipeline{TUserData}"/>
	/// </summary>
	public class BlobPipelineOptions
	{
		/// <summary>
		/// Maximum number of response batches to buffer before pausing. This is the capacity of the bounded
		/// response channel, so it counts batches rather than individual responses.
		/// </summary>
		public int ResponseBufferSize { get; set; } = 200;

		/// <summary>
		/// Number of requests to enumerate before flushing the current batch. The check is made when no further
		/// request is immediately available; a batch is flushed once more than this many requests are queued.
		/// </summary>
		public int FlushBatchLength { get; set; } = 200;

		/// <summary>
		/// Number of batches to fetch in parallel
		/// </summary>
		public int NumFetchTasks { get; set; } = 4;
	}

	/// <summary>
	/// Helper class to sort requested reads for optimal coherency within bundles
	/// </summary>
	/// <typeparam name="TUserData">Type of user data to include with requests</typeparam>
	public sealed class BlobPipeline<TUserData> : IAsyncDisposable
	{
		readonly BlobPipelineOptions _options;
		readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource();
		readonly List<Task> _tasks = new List<Task>();

		// Individual requests, as added by the caller or copied from a source
		readonly Channel<BlobRequest<TUserData>> _requestChannel;

		// Requests grouped into batches which are fetched sequentially by one worker
		readonly Channel<BlobRequest<TUserData>[]> _requestBatchChannel;

		// Completed batches waiting to be read by the caller
		readonly Channel<BlobResponseBatch<TUserData>> _responseBatchChannel;

		// Number of outstanding writers to the request channel. Starts at one for the caller itself (released by
		// FinishAdding) and is incremented for each source passed to AddSource.
		int _writerCount = 1;
		bool _writerFinished;

		/// <summary>
		/// Constructor
		/// </summary>
		public BlobPipeline()
			: this(new BlobPipelineOptions())
		{
		}

		/// <summary>
		/// Constructor
		/// </summary>
		/// <param name="options">Options for the pipeline</param>
		public BlobPipeline(BlobPipelineOptions options)
		{
			_options = options;

			_requestChannel = Channel.CreateUnbounded<BlobRequest<TUserData>>();
			_requestBatchChannel = Channel.CreateUnbounded<BlobRequest<TUserData>[]>();
			_responseBatchChannel = Channel.CreateBounded<BlobResponseBatch<TUserData>>(new BoundedChannelOptions(options.ResponseBufferSize) { FullMode = BoundedChannelFullMode.Wait });

			_tasks.Add(Task.Run(() => BatchLoopAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token));
			_tasks.Add(Task.Run(() => FetchLoopAsync(_cancellationTokenSource.Token), _cancellationTokenSource.Token));
		}

		/// <inheritdoc/>
		public async ValueTask DisposeAsync()
		{
			try
			{
				if (_tasks.Count > 0)
				{
					await _cancellationTokenSource.CancelAsync();
					try
					{
						await Task.WhenAll(_tasks);
					}
					catch (OperationCanceledException)
					{
					}
				}
			}
			finally
			{
				// Run the cleanup even if a background task failed with something other than cancellation, so that
				// the token source and any buffered responses are still released while the failure propagates.
				_tasks.Clear();
				_cancellationTokenSource.Dispose();

				BlobResponseBatch<TUserData>? batch;
				while (_responseBatchChannel.Reader.TryRead(out batch))
				{
					batch.Dispose();
				}
			}
		}

		/// <summary>
		/// Adds a new read request
		/// </summary>
		/// <param name="request">The request to add</param>
		/// <exception cref="InvalidOperationException">The request could not be written to the request channel</exception>
		public void Add(BlobRequest<TUserData> request)
		{
			if (!_requestChannel.Writer.TryWrite(request))
			{
				throw new InvalidOperationException("Cannot write request to channel");
			}
		}

		/// <summary>
		/// Adds a new request source
		/// </summary>
		/// <param name="factory">Method to construct the sequence of items</param>
		/// <exception cref="InvalidOperationException">All writers have already finished</exception>
		public void AddSource(Func<CancellationToken, IAsyncEnumerator<BlobRequest<TUserData>>> factory)
		{
			AdjustWriterCount(1);
			_tasks.Add(Task.Run(() => CopyToRequestChannelAsync(factory, _cancellationTokenSource.Token), _cancellationTokenSource.Token));
		}

		/// <summary>
		/// Copies every request produced by a source into the request channel, then releases the writer slot
		/// that was reserved for the source.
		/// </summary>
		/// <param name="factory">Method to construct the sequence of items</param>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		async Task CopyToRequestChannelAsync(Func<CancellationToken, IAsyncEnumerator<BlobRequest<TUserData>>> factory, CancellationToken cancellationToken)
		{
			try
			{
				await using IAsyncEnumerator<BlobRequest<TUserData>> source = factory(cancellationToken);
				while (await source.MoveNextAsync())
				{
					await _requestChannel.Writer.WriteAsync(source.Current, cancellationToken);
				}
			}
			catch (Exception ex)
			{
				// The writer slot for this source is never released on failure, so the request channel would
				// otherwise stay open and readers of the pipeline would wait indefinitely. Fault the channel
				// instead so the error flows through to the response channel.
				_requestChannel.Writer.TryComplete(ex);
				throw;
			}

			AdjustWriterCount(-1);
		}

		/// <summary>
		/// Indicate that we've finished adding new items to the reader
		/// </summary>
		public void FinishAdding()
		{
			if (!_writerFinished)
			{
				AdjustWriterCount(-1);
				_writerFinished = true;
			}
		}

		/// <summary>
		/// Atomically adjusts the number of outstanding writers, completing the request channel when it reaches zero.
		/// </summary>
		/// <param name="delta">Amount to add to the writer count</param>
		/// <exception cref="InvalidOperationException">The writer count has already reached zero</exception>
		void AdjustWriterCount(int delta)
		{
			for (; ; )
			{
				// Compare-exchange with identical values is used purely as an atomic read
				int writerCount = Interlocked.CompareExchange(ref _writerCount, 0, 0);
				if (writerCount == 0)
				{
					throw new InvalidOperationException("Reading has already been marked complete");
				}

				// Retry if another thread changed the count between the read and the update
				if (Interlocked.CompareExchange(ref _writerCount, writerCount + delta, writerCount) == writerCount)
				{
					if (writerCount + delta == 0)
					{
						_requestChannel.Writer.TryComplete();
					}
					break;
				}
			}
		}

		// A request for an export within a bundle, along with the keys used to order reads within that bundle
		record class BundleRequest(BundleHandle BundleHandle, int PacketOffset, int ExportIdx, BlobRequest<TUserData> OriginalRequest);

		// All the queued requests that target the same bundle
		record class BundleBatchRequest(BundleHandle BundleHandle, List<BundleRequest> Requests);

		/// <summary>
		/// Reads requests from the request channel, groups those that target the same bundle, and writes batches
		/// ordered by packet offset and export index to the request batch channel. Requests that cannot be
		/// grouped are forwarded immediately as single-item batches.
		/// </summary>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		async Task BatchLoopAsync(CancellationToken cancellationToken)
		{
			try
			{
				int queueLength = 0;
				Queue<BundleBatchRequest> bundleBatchQueue = new Queue<BundleBatchRequest>();
				Dictionary<BundleHandle, BundleBatchRequest> bundleHandleToBatch = new Dictionary<BundleHandle, BundleBatchRequest>();

				for (; ; )
				{
					// Add requests to the queue
					for (; ; )
					{
						BlobRequest<TUserData>? request;
						if (_requestChannel.Reader.TryRead(out request))
						{
							if (TryAddBundleRequest(request, bundleBatchQueue, bundleHandleToBatch))
							{
								queueLength++;
							}
							else
							{
								await _requestBatchChannel.Writer.WriteAsync(new[] { request }, cancellationToken);
							}
						}
						else
						{
							// Nothing is immediately available. Flush if enough requests are queued, otherwise
							// wait for more (or for the channel to complete).
							if (queueLength > _options.FlushBatchLength)
							{
								break;
							}
							if (!await _requestChannel.Reader.WaitToReadAsync(cancellationToken))
							{
								break;
							}
						}
					}

					// Exit once we've processed everything and can't get any more items to read.
					if (queueLength == 0)
					{
						_requestBatchChannel.Writer.TryComplete();
						break;
					}

					// Flush the first queue
					BundleBatchRequest exportBatch = bundleBatchQueue.Dequeue();
					queueLength -= exportBatch.Requests.Count;
					bundleHandleToBatch.Remove(exportBatch.BundleHandle);

					BlobRequest<TUserData>[] chunkBatch = exportBatch.Requests.OrderBy(x => x.PacketOffset).ThenBy(x => x.ExportIdx).Select(x => x.OriginalRequest).ToArray();
					await _requestBatchChannel.Writer.WriteAsync(chunkBatch, cancellationToken);
				}
			}
			catch (Exception ex)
			{
				// Forward the failure to the fetch workers; without this the batch channel never completes and
				// they wait for more work until the pipeline is disposed.
				_requestBatchChannel.Writer.TryComplete(ex);
				throw;
			}
		}

		/// <summary>
		/// Attempts to add a request to the batch for the bundle that contains it.
		/// </summary>
		/// <param name="request">The request to add</param>
		/// <param name="bundleBatchQueue">Queue of batches, in the order that each bundle was first seen</param>
		/// <param name="bundleHandleToBatch">Lookup from bundle handle to the queued batch for that bundle</param>
		/// <returns>True if the request was queued; false if its innermost handle is not an export in a flushed packet</returns>
		static bool TryAddBundleRequest(BlobRequest<TUserData> request, Queue<BundleBatchRequest> bundleBatchQueue, Dictionary<BundleHandle, BundleBatchRequest> bundleHandleToBatch)
		{
			if (request.Handle.Innermost is not ExportHandle exportHandle)
			{
				return false;
			}
			if (exportHandle.Packet is not FlushedPacketHandle packetHandle)
			{
				return false;
			}

			BundleBatchRequest? bundleBatchRequest;
			if (!bundleHandleToBatch.TryGetValue(packetHandle.Bundle, out bundleBatchRequest))
			{
				bundleBatchRequest = new BundleBatchRequest(packetHandle.Bundle, new List<BundleRequest>());
				bundleBatchQueue.Enqueue(bundleBatchRequest);
				bundleHandleToBatch.Add(packetHandle.Bundle, bundleBatchRequest);
			}

			BundleRequest bundleRequest = new BundleRequest(packetHandle.Bundle, packetHandle.PacketOffset, exportHandle.ExportIdx, request);
			bundleBatchRequest.Requests.Add(bundleRequest);
			return true;
		}

		/// <summary>
		/// Runs the configured number of fetch workers and completes the response channel once they have all
		/// finished, passing on the exception if any of them failed.
		/// </summary>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		async Task FetchLoopAsync(CancellationToken cancellationToken)
		{
			try
			{
				List<Task> tasks = new List<Task>();
				for (int idx = 0; idx < _options.NumFetchTasks; idx++)
				{
					tasks.Add(Task.Run(() => FetchWorkerAsync(cancellationToken), cancellationToken));
				}
				await Task.WhenAll(tasks);
				_responseBatchChannel.Writer.TryComplete();
			}
			catch (Exception ex)
			{
				_responseBatchChannel.Writer.TryComplete(ex);
			}
		}

		/// <summary>
		/// Takes batches from the request batch channel, reads each blob in order, and writes the resulting
		/// response batch to the response channel.
		/// </summary>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		async Task FetchWorkerAsync(CancellationToken cancellationToken)
		{
			while (await _requestBatchChannel.Reader.WaitToReadAsync(cancellationToken))
			{
				BlobRequest<TUserData>[]? requestBatch;
				if (_requestBatchChannel.Reader.TryRead(out requestBatch))
				{
					BlobResponse<TUserData>[] responseBatch = new BlobResponse<TUserData>[requestBatch.Length];

					// Number of entries in responseBatch that have been populated so far
					int numResponses = 0;
					try
					{
						for (; numResponses < requestBatch.Length; numResponses++)
						{
							BlobRequest<TUserData> request = requestBatch[numResponses];
							BlobData blobData = await request.Handle.ReadBlobDataAsync(cancellationToken);
							responseBatch[numResponses] = new BlobResponse<TUserData>(blobData, request.UserData);
						}

#pragma warning disable CA2000
						await _responseBatchChannel.Writer.WriteAsync(new BlobResponseBatch<TUserData>(responseBatch), cancellationToken);
#pragma warning restore CA2000
					}
					catch
					{
						// The batch never reached the response channel, so nothing else will dispose the data that
						// was already read for it.
						for (int idx = 0; idx < numResponses; idx++)
						{
							responseBatch[idx].Dispose();
						}
						throw;
					}
				}
			}
		}

		/// <summary>
		/// Reads all responses from the reader
		/// </summary>
		/// <remarks>
		/// If enumeration stops part-way through a batch, the responses in that batch from the current position
		/// onwards are disposed. That range includes the response most recently yielded when enumeration is
		/// abandoned at that point.
		/// </remarks>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <returns>Sequence of responses, in the order they appear in each batch</returns>
		public async IAsyncEnumerable<BlobResponse<TUserData>> ReadAllAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
		{
			while (await WaitToReadAsync(cancellationToken))
			{
#pragma warning disable CA2000
				BlobResponseBatch<TUserData>? batch;
				if (TryReadBatch(out batch))
				{
					int idx = 0;
					try
					{
						for (; idx < batch.Responses.Length; idx++)
						{
							yield return batch.Responses[idx];
						}
					}
					finally
					{
						for (; idx < batch.Responses.Length; idx++)
						{
							batch.Responses[idx].Dispose();
						}
					}
				}
#pragma warning restore CA2000
			}
		}

		/// <summary>
		/// Attempts to read a batch from the queue
		/// </summary>
		/// <param name="batch">Batch of responses</param>
		/// <returns>True if a batch was read</returns>
		public bool TryReadBatch([NotNullWhen(true)] out BlobResponseBatch<TUserData>? batch)
			=> _responseBatchChannel.Reader.TryRead(out batch);

		/// <summary>
		/// Waits until there is data available to read
		/// </summary>
		/// <param name="cancellationToken">Cancellation token for the operation</param>
		/// <returns>Result of waiting on the response channel reader: true when a batch can be read, false when the channel has completed</returns>
		public ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken)
			=> _responseBatchChannel.Reader.WaitToReadAsync(cancellationToken);
	}
}