Skip to content

Commit 8f943cf

Browse files
committed
Fix thread safety when OpenAPI document is downloaded in parallel
1 parent daec9e7 commit 8f943cf

File tree

4 files changed

+64
-11
lines changed

4 files changed

+64
-11
lines changed

src/JsonApiDotNetCore.OpenApi.Swashbuckle/JsonApiActionDescriptorCollectionProvider.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.Collections.Concurrent;
12
using System.Net;
23
using System.Reflection;
34
using JsonApiDotNetCore.Configuration;
@@ -36,8 +37,10 @@ internal sealed partial class JsonApiActionDescriptorCollectionProvider : IActio
3637
private readonly JsonApiEndpointMetadataProvider _jsonApiEndpointMetadataProvider;
3738
private readonly IJsonApiOptions _options;
3839
private readonly ILogger<JsonApiActionDescriptorCollectionProvider> _logger;
40+
private readonly ConcurrentDictionary<int, Lazy<ActionDescriptorCollection>> _versionedActionDescriptorCache = new();
3941

40-
public ActionDescriptorCollection ActionDescriptors => GetActionDescriptors();
42+
public ActionDescriptorCollection ActionDescriptors =>
43+
_versionedActionDescriptorCache.GetOrAdd(_defaultProvider.ActionDescriptors.Version, LazyGetActionDescriptors).Value;
4144

4245
public JsonApiActionDescriptorCollectionProvider(IActionDescriptorCollectionProvider defaultProvider, IControllerResourceMapping controllerResourceMapping,
4346
JsonApiEndpointMetadataProvider jsonApiEndpointMetadataProvider, IJsonApiOptions options, ILogger<JsonApiActionDescriptorCollectionProvider> logger)
@@ -55,7 +58,13 @@ public JsonApiActionDescriptorCollectionProvider(IActionDescriptorCollectionProv
5558
_logger = logger;
5659
}
5760

58-
private ActionDescriptorCollection GetActionDescriptors()
61+
private Lazy<ActionDescriptorCollection> LazyGetActionDescriptors(int version)
62+
{
63+
// https://andrewlock.net/making-getoradd-on-concurrentdictionary-thread-safe-using-lazy/
64+
return new Lazy<ActionDescriptorCollection>(() => GetActionDescriptors(version), LazyThreadSafetyMode.ExecutionAndPublication);
65+
}
66+
67+
private ActionDescriptorCollection GetActionDescriptors(int version)
5968
{
6069
List<ActionDescriptor> descriptors = [];
6170

@@ -106,8 +115,7 @@ private ActionDescriptorCollection GetActionDescriptors()
106115
descriptors.Add(descriptor);
107116
}
108117

109-
int descriptorVersion = _defaultProvider.ActionDescriptors.Version;
110-
return new ActionDescriptorCollection(descriptors.AsReadOnly(), descriptorVersion);
118+
return new ActionDescriptorCollection(descriptors.AsReadOnly(), version);
111119
}
112120

113121
internal static bool IsVisibleEndpoint(ActionDescriptor descriptor)

src/JsonApiDotNetCore.OpenApi.Swashbuckle/SchemaGenerationTracer.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.Runtime.CompilerServices;
12
using JsonApiDotNetCore.Resources.Annotations;
23
using JsonApiDotNetCore.Serialization.Objects;
34
using Microsoft.Extensions.Logging;
@@ -87,7 +88,7 @@ private static string GetSchemaTypeName(Type type)
8788

8889
private sealed partial class SchemaGenerationTraceScope : ISchemaGenerationTraceScope
8990
{
90-
private static readonly AsyncLocal<int> RecursionDepthAsyncLocal = new();
91+
private static readonly AsyncLocal<StrongBox<int>> RecursionDepthAsyncLocal = new();
9192

9293
private readonly ILogger _logger;
9394
private readonly string _schemaTypeName;
@@ -101,8 +102,10 @@ public SchemaGenerationTraceScope(ILogger logger, string schemaTypeName)
101102
_logger = logger;
102103
_schemaTypeName = schemaTypeName;
103104

104-
RecursionDepthAsyncLocal.Value++;
105-
LogStarted(RecursionDepthAsyncLocal.Value, _schemaTypeName);
105+
RecursionDepthAsyncLocal.Value ??= new StrongBox<int>(0);
106+
int depth = Interlocked.Increment(ref RecursionDepthAsyncLocal.Value.Value);
107+
108+
LogStarted(depth, _schemaTypeName);
106109
}
107110

108111
public void TraceSucceeded(string schemaId)
@@ -112,16 +115,18 @@ public void TraceSucceeded(string schemaId)
112115

113116
public void Dispose()
114117
{
118+
int depth = RecursionDepthAsyncLocal.Value!.Value;
119+
115120
if (_schemaId != null)
116121
{
117-
LogSucceeded(RecursionDepthAsyncLocal.Value, _schemaTypeName, _schemaId);
122+
LogSucceeded(depth, _schemaTypeName, _schemaId);
118123
}
119124
else
120125
{
121-
LogFailed(RecursionDepthAsyncLocal.Value, _schemaTypeName);
126+
LogFailed(depth, _schemaTypeName);
122127
}
123128

124-
RecursionDepthAsyncLocal.Value--;
129+
Interlocked.Decrement(ref RecursionDepthAsyncLocal.Value.Value);
125130
}
126131

127132
[LoggerMessage(Level = LogLevel.Trace, SkipEnabledCheck = true, Message = "({Depth:D2}) Started for {SchemaTypeName}.")]

test/OpenApiTests/OpenApiTestContext.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ internal async Task<JsonElement> GetSwaggerDocumentAsync()
2828
return await _lazySwaggerDocument.Value;
2929
}
3030

31-
private async Task<JsonElement> CreateSwaggerDocumentAsync()
31+
internal async Task<JsonElement> CreateSwaggerDocumentAsync()
3232
{
3333
string content = await GetAsync("/swagger/v1/swagger.json");
3434

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using FluentAssertions;
2+
using Microsoft.Extensions.DependencyInjection;
3+
using Microsoft.Extensions.Logging;
4+
using Xunit;
5+
using Xunit.Abstractions;
6+
7+
namespace OpenApiTests.ResourceInheritance;
8+
9+
public sealed class ConcurrencyTests : ResourceInheritanceTests
10+
{
11+
private readonly OpenApiTestContext<OpenApiStartup<ResourceInheritanceDbContext>, ResourceInheritanceDbContext> _testContext;
12+
13+
public ConcurrencyTests(OpenApiTestContext<OpenApiStartup<ResourceInheritanceDbContext>, ResourceInheritanceDbContext> testContext,
14+
ITestOutputHelper testOutputHelper)
15+
: base(testContext, testOutputHelper, true, false)
16+
{
17+
_testContext = testContext;
18+
19+
testContext.ConfigureServices(services => services.AddLogging(loggingBuilder => loggingBuilder.ClearProviders()));
20+
}
21+
22+
[Fact]
23+
public async Task Can_download_OpenAPI_documents_in_parallel()
24+
{
25+
// Arrange
26+
const int count = 15;
27+
var downloadTasks = new Task[count];
28+
29+
for (int index = 0; index < count; index++)
30+
{
31+
downloadTasks[index] = _testContext.CreateSwaggerDocumentAsync();
32+
}
33+
34+
// Act
35+
Func<Task> action = async () => await Task.WhenAll(downloadTasks);
36+
37+
// Assert
38+
await action.Should().NotThrowAsync();
39+
}
40+
}

0 commit comments

Comments
 (0)