From a457f4d6dfb279569796b62699f53858d210a23d Mon Sep 17 00:00:00 2001 From: Dan Clark Date: Mon, 25 Nov 2024 23:04:23 +0000 Subject: [PATCH] Expand on AddAsyncPipelineBehavior --- .../IServiceCollectionExtensions.cs | 80 ++++++++++--------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/Toolkit.Foundation/IServiceCollectionExtensions.cs b/Toolkit.Foundation/IServiceCollectionExtensions.cs index bbbd070..50a54e2 100644 --- a/Toolkit.Foundation/IServiceCollectionExtensions.cs +++ b/Toolkit.Foundation/IServiceCollectionExtensions.cs @@ -5,15 +5,13 @@ namespace Toolkit.Foundation; public static class IServiceCollectionExtensions { public static IServiceCollection AddAsyncHandler(this IServiceCollection services, - string key) - where THandler : class, IAsyncHandler - where TMessage : class => AddAsyncHandler(services, ServiceLifetime.Transient, key); + string key) where THandler : class, IAsyncHandler + where TMessage : class => AddAsyncHandler(services, ServiceLifetime.Transient, key); public static IServiceCollection AddAsyncHandler(this IServiceCollection services, ServiceLifetime lifetime = ServiceLifetime.Transient, - string? key = null) - where THandler : class, IAsyncHandler - where TMessage : class + string? key = null) where THandler : class, IAsyncHandler + where TMessage : class { if (key is { Length: > 0 }) { @@ -30,15 +28,13 @@ public static class IServiceCollectionExtensions } public static IServiceCollection AddAsyncHandler(this IServiceCollection services, - string key) - where THandler : class, IAsyncHandler - where TMessage : class => AddAsyncHandler(services, ServiceLifetime.Transient, key); + string key) where THandler : class, IAsyncHandler + where TMessage : class => AddAsyncHandler(services, ServiceLifetime.Transient, key); public static IServiceCollection AddAsyncHandler(this IServiceCollection services, ServiceLifetime lifetime = ServiceLifetime.Transient, - string? key = null) - where THandler : class, IAsyncHandler - where TMessage : class + string? key = null) where THandler : class, IAsyncHandler + where TMessage : class { if (key is { Length: > 0 }) { @@ -55,8 +51,7 @@ public static class IServiceCollectionExtensions } public static IServiceCollection AddAsyncInitialization(this IServiceCollection services) - where TInitialization : class, - IAsyncInitialization + where TInitialization : class, IAsyncInitialization { services.AddTransient(); return services; @@ -81,22 +76,19 @@ public static class IServiceCollectionExtensions } public static IServiceCollection AddComponent(this IServiceCollection services) - where TComponent : class, - IComponent + where TComponent : class, IComponent { services.AddTransient(); return services; } public static IServiceCollection AddHandler(this IServiceCollection services, - string key) - where THandler : class, IHandler - where TMessage : class => AddHandler(services, ServiceLifetime.Transient, key); + string key) where THandler : class, IHandler + where TMessage : class => AddHandler(services, ServiceLifetime.Transient, key); public static IServiceCollection AddHandler(this IServiceCollection services, ServiceLifetime lifetime = ServiceLifetime.Transient, - string? key = null) - where THandler : class, IHandler + string? key = null) where THandler : class, IHandler where TMessage : class { if (key is { Length: > 0 }) @@ -114,14 +106,12 @@ public static class IServiceCollectionExtensions } public static IServiceCollection AddHandler(this IServiceCollection services, - string key) - where THandler : class, IHandler - where TMessage : class => AddHandler(services, ServiceLifetime.Transient, key); + string key) where THandler : class, IHandler + where TMessage : class => AddHandler(services, ServiceLifetime.Transient, key); public static IServiceCollection AddHandler(this IServiceCollection services, ServiceLifetime lifetime = ServiceLifetime.Transient, - string? key = null) - where THandler : class, IHandler + string? key = null) where THandler : class, IHandler where TMessage : class { if (key is { Length: > 0 }) @@ -140,10 +130,8 @@ public static class IServiceCollectionExtensions public static IServiceCollection AddInitialization(this IServiceCollection services, ServiceLifetime lifetime = ServiceLifetime.Transient) - where TInitialization : class, - IInitialization - where TInitializationImplementation : class, - TInitialization + where TInitialization : class, IInitialization + where TInitializationImplementation : class, TInitialization { services.Add(new ServiceDescriptor(typeof(TInitialization), typeof(TInitializationImplementation), lifetime)); services.AddTransient(provider => provider.GetRequiredService()); @@ -151,8 +139,7 @@ public static class IServiceCollectionExtensions } public static IServiceCollection AddInitialization(this IServiceCollection services) - where TInitialization : class, - IInitialization + where TInitialization : class, IInitialization { services.AddTransient(); return services; @@ -160,8 +147,7 @@ public static class IServiceCollectionExtensions public static IServiceCollection AddInitialization(this IServiceCollection services, params object[] parameters) - where TInitialization : class, - IInitialization + where TInitialization : class, IInitialization { services.AddTransient(provider => provider.GetRequiredService() .Create(parameters)); @@ -173,13 +159,31 @@ public static class IServiceCollectionExtensions Type behaviorType, string? key = null) { - if (key is { Length: > 0 }) + bool ImplementsInterface(Type type, Type interfaceType) => + type.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == interfaceType); + + if (ImplementsInterface(behaviorType, typeof(IAsyncPipelineBehavior<,>))) { - services.AddKeyedTransient(typeof(IAsyncPipelineBehavior<>), key, behaviorType); + if (key is { Length: > 0 }) + { + services.AddKeyedTransient(typeof(IAsyncPipelineBehavior<,>), key, behaviorType); + } + else + { + services.AddTransient(typeof(IAsyncPipelineBehavior<,>), behaviorType); + } } - else + + if (ImplementsInterface(behaviorType, typeof(IAsyncPipelineBehavior<>))) { - services.AddTransient(typeof(IAsyncPipelineBehavior<>), behaviorType); + if (key is { Length: > 0 }) + { + services.AddKeyedTransient(typeof(IAsyncPipelineBehavior<>), key, behaviorType); + } + else + { + services.AddTransient(typeof(IAsyncPipelineBehavior<>), behaviorType); + } } return services;