using Microsoft.Extensions.DependencyInjection;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace Toolkit.Foundation;

/// <summary>
/// Dispatches messages to handlers resolved from the service provider and from an
/// <see cref="IHandlerProvider"/>, and returns the handlers' responses.
/// </summary>
/// <remarks>
/// <para>
/// For a message whose run-time type is <c>M</c> and a response type <c>R</c>, the candidate handlers are
/// the services registered for <c>HandlerWrapper&lt;M, R&gt;</c> under the key <c>"{key}:{handlerType}"</c>
/// (just <c>"{handlerType}"</c> when no key is supplied), followed by whatever the
/// <see cref="IHandlerProvider"/> returns for that same key. Null entries are dropped and the remaining
/// handlers are grouped by concrete type, groups ordered by first appearance.
/// </para>
/// <para>
/// Each candidate is invoked through reflection via its public <c>Handle(M, CancellationToken)</c> method;
/// candidates without such a method are skipped. Because the call goes through <c>MethodInfo.Invoke</c>,
/// an exception thrown synchronously by <c>Handle</c> arrives wrapped in a
/// <see cref="TargetInvocationException"/>.
/// </para>
/// </remarks>
public class Mediator(IHandlerProvider handlerProvider,
    IServiceProvider provider) :
    IMediator
{
    /// <summary>
    /// Sends <paramref name="message"/> to the first handler that can handle it and returns its response.
    /// </summary>
    /// <typeparam name="TMessage">The static message type; handlers are resolved from the message's run-time type.</typeparam>
    /// <typeparam name="TResponse">The response type produced by the handler.</typeparam>
    /// <param name="message">The message to dispatch.</param>
    /// <param name="key">Optional key that prefixes the handler lookup key.</param>
    /// <param name="cancellationToken">Token forwarded to the handler.</param>
    /// <returns>The first matching handler's response, or <c>default</c> when no handler matches.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="message"/> is <c>null</c>.</exception>
    public async Task<TResponse?> Handle<TMessage, TResponse>(TMessage message,
        object? key = null,
        CancellationToken cancellationToken = default)
        where TMessage : notnull
    {
        if (ResolveHandlers(message, typeof(TResponse), key).FirstOrDefault() is ({ } handler, { } handleMethod))
        {
            return await (Task<TResponse?>)InvokeHandle(handleMethod, handler, message, cancellationToken);
        }

        return default;
    }

    /// <summary>
    /// Non-generic counterpart of <c>Handle&lt;TMessage, TResponse&gt;</c> for callers that only know the
    /// response type at run time.
    /// </summary>
    /// <param name="responseType">The response type the handler produces.</param>
    /// <param name="message">The message to dispatch.</param>
    /// <param name="key">Optional key that prefixes the handler lookup key.</param>
    /// <param name="cancellationToken">Token forwarded to the handler.</param>
    /// <returns>The first matching handler's response (boxed), or <c>null</c> when no handler matches.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="message"/> or <paramref name="responseType"/> is <c>null</c>.</exception>
    public async Task<object?> Handle(Type responseType,
        object message,
        object? key = null,
        CancellationToken cancellationToken = default)
    {
        if (ResolveHandlers(message, responseType, key).FirstOrDefault() is ({ } handler, { } handleMethod))
        {
            return await InvokeHandleForResult(handleMethod, handler, message, cancellationToken);
        }

        return default;
    }

    /// <summary>
    /// Sends <paramref name="message"/> to every matching handler and collects all responses in handler order.
    /// </summary>
    /// <param name="responseType">The response type the handlers produce.</param>
    /// <param name="message">The message to dispatch.</param>
    /// <param name="key">Optional key that prefixes the handler lookup key.</param>
    /// <param name="cancellationToken">Token forwarded to each handler and checked between handlers.</param>
    /// <returns>One (boxed) response per matching handler; empty when none match.</returns>
    public async Task<IList<object?>> HandleMany(Type responseType,
        object message,
        object? key = null,
        CancellationToken cancellationToken = default)
    {
        List<object?> responses = [];
        await foreach (object? response in HandleManyAsync(responseType, message, key, cancellationToken))
        {
            responses.Add(response);
        }

        return responses;
    }

    /// <summary>
    /// Sends <paramref name="message"/> to every matching handler and collects all responses in handler order.
    /// </summary>
    /// <typeparam name="TMessage">The static message type; handlers are resolved from the message's run-time type.</typeparam>
    /// <typeparam name="TResponse">The response type produced by the handlers.</typeparam>
    /// <param name="message">The message to dispatch.</param>
    /// <param name="key">Optional key that prefixes the handler lookup key.</param>
    /// <param name="cancellationToken">Token forwarded to each handler and checked between handlers.</param>
    /// <returns>One response per matching handler; empty when none match.</returns>
    public async Task<IList<TResponse?>> HandleMany<TMessage, TResponse>(TMessage message,
        object? key = null,
        CancellationToken cancellationToken = default)
        where TMessage : notnull
    {
        List<TResponse?> responses = [];
        await foreach (TResponse? response in HandleManyAsync<TMessage, TResponse>(message, key, cancellationToken))
        {
            responses.Add(response);
        }

        return responses;
    }

    /// <summary>
    /// Streams the response of each matching handler as soon as that handler completes.
    /// </summary>
    /// <param name="responseType">The response type the handlers produce.</param>
    /// <param name="message">The message to dispatch.</param>
    /// <param name="key">Optional key that prefixes the handler lookup key.</param>
    /// <param name="cancellationToken">Token forwarded to each handler and checked before each one runs.</param>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is cancelled between handlers.</exception>
    public async IAsyncEnumerable<object?> HandleManyAsync(Type responseType,
        object message,
        object? key = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        foreach ((object handler, MethodInfo handleMethod) in ResolveHandlers(message, responseType, key))
        {
            cancellationToken.ThrowIfCancellationRequested();

            // The handler's Task<R> cannot be cast to Task<object?> (Task<T> is invariant), so read it reflectively.
            yield return await InvokeHandleForResult(handleMethod, handler, message, cancellationToken);
        }
    }

    /// <summary>
    /// Streams the response of each matching handler as soon as that handler completes.
    /// </summary>
    /// <typeparam name="TMessage">The static message type; handlers are resolved from the message's run-time type.</typeparam>
    /// <typeparam name="TResponse">The response type produced by the handlers.</typeparam>
    /// <param name="message">The message to dispatch.</param>
    /// <param name="key">Optional key that prefixes the handler lookup key.</param>
    /// <param name="cancellationToken">Token forwarded to each handler and checked before each one runs.</param>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is cancelled between handlers.</exception>
    public async IAsyncEnumerable<TResponse?> HandleManyAsync<TMessage, TResponse>(TMessage message,
        object? key = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
        where TMessage : notnull
    {
        foreach ((object handler, MethodInfo handleMethod) in ResolveHandlers(message, typeof(TResponse), key))
        {
            cancellationToken.ThrowIfCancellationRequested();
            yield return await (Task<TResponse?>)InvokeHandle(handleMethod, handler, message, cancellationToken);
        }
    }

    /// <summary>
    /// Resolves the candidate handlers for <paramref name="message"/> and pairs each with its
    /// <c>Handle(messageType, CancellationToken)</c> method, skipping candidates that lack one.
    /// </summary>
    /// <remarks>
    /// Handler instances are resolved when this method is called; the per-handler method lookup happens
    /// lazily as the returned sequence is enumerated.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="message"/> or <paramref name="responseType"/> is <c>null</c>.</exception>
    private IEnumerable<(object Handler, MethodInfo HandleMethod)> ResolveHandlers(object message,
        Type responseType,
        object? key)
    {
        ArgumentNullException.ThrowIfNull(message);
        ArgumentNullException.ThrowIfNull(responseType);

        Type messageType = message.GetType();
        Type handlerType = typeof(HandlerWrapper<,>).MakeGenericType(messageType, responseType);

        // The lookup key is the closed wrapper type's name, optionally prefixed by the caller-supplied key.
        string scopedKey = key is not null ? $"{key}:{handlerType}" : $"{handlerType}";
        Type[] parameterTypes = [messageType, typeof(CancellationToken)];

        return WithHandleMethod(GetHandlers(handlerType, scopedKey));

        IEnumerable<(object Handler, MethodInfo HandleMethod)> WithHandleMethod(List<object> handlers)
        {
            foreach (object handler in handlers)
            {
                if (handler.GetType().GetMethod("Handle", parameterTypes) is { } handleMethod)
                {
                    yield return (handler, handleMethod);
                }
            }
        }
    }

    /// <summary>
    /// Collects the handler instances for <paramref name="handlerWrapperType"/>: services from the provider
    /// (keyed when <paramref name="key"/> is non-null) followed by those returned by the
    /// <see cref="IHandlerProvider"/> for the same key. Nulls are dropped and instances are grouped by
    /// concrete type, with groups in order of first appearance.
    /// </summary>
    private List<object> GetHandlers(Type handlerWrapperType, object? key)
    {
        IEnumerable<object?> services = key is not null
            ? provider.GetKeyedServices(handlerWrapperType, key)
            : provider.GetServices(handlerWrapperType);

        // Materialize the provider's services before querying the handler provider.
        List<object?> candidates = [.. services];
        candidates.AddRange(handlerProvider.Get(key));

        // GroupBy yields groups in order of each key's first occurrence and keeps element order within a group.
        return candidates
            .OfType<object>()
            .GroupBy(handler => handler.GetType())
            .SelectMany(group => group)
            .ToList();
    }

    /// <summary>Invokes the handler's <c>Handle(message, cancellationToken)</c> method and returns the task it produces.</summary>
    private static Task InvokeHandle(MethodInfo handleMethod,
        object handler,
        object message,
        CancellationToken cancellationToken) =>
        (Task)handleMethod.Invoke(handler, [message, cancellationToken])!;

    /// <summary>
    /// Invokes <paramref name="handleMethod"/>, awaits the task it returns and reads the task's result reflectively.
    /// </summary>
    /// <remarks>
    /// <see cref="Task{TResult}"/> is a class and therefore not covariant, so e.g. a <c>Task&lt;string&gt;</c> cannot be
    /// cast to <c>Task&lt;object?&gt;</c>; reading <c>Result</c> through reflection works for any <c>TResult</c>.
    /// Returns <c>null</c> when the task type exposes no public <c>Result</c> property.
    /// </remarks>
    private static async Task<object?> InvokeHandleForResult(MethodInfo handleMethod,
        object handler,
        object message,
        CancellationToken cancellationToken)
    {
        Task task = InvokeHandle(handleMethod, handler, message, cancellationToken);
        await task;

        return task.GetType().GetProperty("Result")?.GetValue(task);
    }
}