diff --git a/framework/src/Volo.Abp.EntityFrameworkCore/Microsoft/Extensions/DependencyInjection/AbpEfCoreServiceCollectionExtensions.cs b/framework/src/Volo.Abp.EntityFrameworkCore/Microsoft/Extensions/DependencyInjection/AbpEfCoreServiceCollectionExtensions.cs index d39c5db121..538614ee22 100644 --- a/framework/src/Volo.Abp.EntityFrameworkCore/Microsoft/Extensions/DependencyInjection/AbpEfCoreServiceCollectionExtensions.cs +++ b/framework/src/Volo.Abp.EntityFrameworkCore/Microsoft/Extensions/DependencyInjection/AbpEfCoreServiceCollectionExtensions.cs @@ -21,7 +21,17 @@ namespace Microsoft.Extensions.DependencyInjection foreach (var dbContextType in options.ReplacedDbContextTypes) { - services.Replace(ServiceDescriptor.Transient(dbContextType, typeof(TDbContext))); + services.Replace( + ServiceDescriptor.Transient( + dbContextType, + sp => sp.GetRequiredService(typeof(TDbContext)) + ) + ); + + services.Configure(opts => + { + opts.DbContextReplacements[dbContextType] = typeof(TDbContext); + }); } new EfCoreRepositoryRegistrar(options).AddRepositories(); diff --git a/framework/src/Volo.Abp.EntityFrameworkCore/Properties/AssemblyInfo.cs b/framework/src/Volo.Abp.EntityFrameworkCore/Properties/AssemblyInfo.cs index 9345c565d8..9d3bf27070 100644 --- a/framework/src/Volo.Abp.EntityFrameworkCore/Properties/AssemblyInfo.cs +++ b/framework/src/Volo.Abp.EntityFrameworkCore/Properties/AssemblyInfo.cs @@ -1,4 +1,5 @@ using System.Reflection; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; // General Information about an assembly is controlled through the following @@ -8,6 +9,7 @@ using System.Runtime.InteropServices; [assembly: AssemblyCompany("")] [assembly: AssemblyProduct("Volo.Abp.EntityFrameworkCore")] [assembly: AssemblyTrademark("")] +[assembly: InternalsVisibleTo("Volo.Abp.EntityFrameworkCore.Tests")] // Setting ComVisible to false makes the types in this assembly not visible // to COM components. If you need to access a type in this assembly from diff --git a/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/EntityFrameworkCore/AbpDbContextOptions.cs b/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/EntityFrameworkCore/AbpDbContextOptions.cs index b1b84f7236..e3326ac7c5 100644 --- a/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/EntityFrameworkCore/AbpDbContextOptions.cs +++ b/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/EntityFrameworkCore/AbpDbContextOptions.cs @@ -7,19 +7,22 @@ namespace Volo.Abp.EntityFrameworkCore { public class AbpDbContextOptions { - internal List> DefaultPreConfigureActions { get; set; } + internal List> DefaultPreConfigureActions { get; } internal Action DefaultConfigureAction { get; set; } - internal Dictionary> PreConfigureActions { get; set; } + internal Dictionary> PreConfigureActions { get; } - internal Dictionary ConfigureActions { get; set; } + internal Dictionary ConfigureActions { get; } + + internal Dictionary DbContextReplacements { get; } public AbpDbContextOptions() { DefaultPreConfigureActions = new List>(); PreConfigureActions = new Dictionary>(); ConfigureActions = new Dictionary(); + DbContextReplacements = new Dictionary(); } public void PreConfigure([NotNull] Action action) @@ -57,5 +60,20 @@ namespace Volo.Abp.EntityFrameworkCore ConfigureActions[typeof(TDbContext)] = action; } + + internal Type GetReplacedTypeOrSelf(Type dbContextType) + { + while (true) + { + if (DbContextReplacements.TryGetValue(dbContextType, out var foundType)) + { + dbContextType = foundType; + } + else + { + return dbContextType; + } + } + } } } \ No newline at end of file diff --git a/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/Uow/EntityFrameworkCore/EfCoreDatabaseApi.cs b/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/Uow/EntityFrameworkCore/EfCoreDatabaseApi.cs index cffda462b0..88a2428318 100644 --- a/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/Uow/EntityFrameworkCore/EfCoreDatabaseApi.cs +++ b/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/Uow/EntityFrameworkCore/EfCoreDatabaseApi.cs @@ -1,15 +1,15 @@ using System.Threading; using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; using Volo.Abp.EntityFrameworkCore; namespace Volo.Abp.Uow.EntityFrameworkCore { - public class EfCoreDatabaseApi : IDatabaseApi, ISupportsSavingChanges - where TDbContext : IEfCoreDbContext + public class EfCoreDatabaseApi : IDatabaseApi, ISupportsSavingChanges { - public TDbContext DbContext { get; } + public IEfCoreDbContext DbContext { get; } - public EfCoreDatabaseApi(TDbContext dbContext) + public EfCoreDatabaseApi(IEfCoreDbContext dbContext) { DbContext = dbContext; } diff --git a/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/Uow/EntityFrameworkCore/UnitOfWorkDbContextProvider.cs b/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/Uow/EntityFrameworkCore/UnitOfWorkDbContextProvider.cs index 89cc69fb74..07f1d200aa 100644 --- a/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/Uow/EntityFrameworkCore/UnitOfWorkDbContextProvider.cs +++ b/framework/src/Volo.Abp.EntityFrameworkCore/Volo/Abp/Uow/EntityFrameworkCore/UnitOfWorkDbContextProvider.cs @@ -6,6 +6,7 @@ using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; using Volo.Abp.Data; using Volo.Abp.EntityFrameworkCore; using Volo.Abp.EntityFrameworkCore.DependencyInjection; @@ -14,8 +15,6 @@ using Volo.Abp.Threading; namespace Volo.Abp.Uow.EntityFrameworkCore { - //TODO: Implement logic in DefaultDbContextResolver.Resolve in old ABP. - public class UnitOfWorkDbContextProvider : IDbContextProvider where TDbContext : IEfCoreDbContext { @@ -25,17 +24,20 @@ namespace Volo.Abp.Uow.EntityFrameworkCore private readonly IConnectionStringResolver _connectionStringResolver; private readonly ICancellationTokenProvider _cancellationTokenProvider; private readonly ICurrentTenant _currentTenant; + private readonly AbpDbContextOptions _options; public UnitOfWorkDbContextProvider( IUnitOfWorkManager unitOfWorkManager, IConnectionStringResolver connectionStringResolver, ICancellationTokenProvider cancellationTokenProvider, - ICurrentTenant currentTenant) + ICurrentTenant currentTenant, + IOptions options) { _unitOfWorkManager = unitOfWorkManager; _connectionStringResolver = connectionStringResolver; _cancellationTokenProvider = cancellationTokenProvider; _currentTenant = currentTenant; + _options = options.Value; Logger = NullLogger>.Instance; } @@ -63,15 +65,16 @@ namespace Volo.Abp.Uow.EntityFrameworkCore var connectionStringName = ConnectionStringNameAttribute.GetConnStringName(); var connectionString = ResolveConnectionString(connectionStringName); - var dbContextKey = $"{typeof(TDbContext).FullName}_{connectionString}"; + var targetDbContextType = _options.GetReplacedTypeOrSelf(typeof(TDbContext)); + var dbContextKey = $"{targetDbContextType.FullName}_{connectionString}"; var databaseApi = unitOfWork.GetOrAddDatabaseApi( dbContextKey, - () => new EfCoreDatabaseApi( + () => new EfCoreDatabaseApi( CreateDbContext(unitOfWork, connectionStringName, connectionString) )); - return ((EfCoreDatabaseApi)databaseApi).DbContext; + return (TDbContext)((EfCoreDatabaseApi)databaseApi).DbContext; } public async Task GetDbContextAsync() @@ -85,20 +88,21 @@ namespace Volo.Abp.Uow.EntityFrameworkCore var connectionStringName = ConnectionStringNameAttribute.GetConnStringName(); var connectionString = await ResolveConnectionStringAsync(connectionStringName); - var dbContextKey = $"{typeof(TDbContext).FullName}_{connectionString}"; + var targetDbContextType = _options.GetReplacedTypeOrSelf(typeof(TDbContext)); + var dbContextKey = $"{targetDbContextType.FullName}_{connectionString}"; var databaseApi = unitOfWork.FindDatabaseApi(dbContextKey); if (databaseApi == null) { - databaseApi = new EfCoreDatabaseApi( + databaseApi = new EfCoreDatabaseApi( await CreateDbContextAsync(unitOfWork, connectionStringName, connectionString) ); unitOfWork.AddDatabaseApi(dbContextKey, databaseApi); } - return ((EfCoreDatabaseApi)databaseApi).DbContext; + return (TDbContext)((EfCoreDatabaseApi)databaseApi).DbContext; } private TDbContext CreateDbContext(IUnitOfWork unitOfWork, string connectionStringName, string connectionString) diff --git a/framework/test/Volo.Abp.EntityFrameworkCore.Tests/Volo/Abp/EntityFrameworkCore/DbContext_Replace_Tests.cs b/framework/test/Volo.Abp.EntityFrameworkCore.Tests/Volo/Abp/EntityFrameworkCore/DbContext_Replace_Tests.cs index 977cb8522a..8fd88f23fe 100644 --- a/framework/test/Volo.Abp.EntityFrameworkCore.Tests/Volo/Abp/EntityFrameworkCore/DbContext_Replace_Tests.cs +++ b/framework/test/Volo.Abp.EntityFrameworkCore.Tests/Volo/Abp/EntityFrameworkCore/DbContext_Replace_Tests.cs @@ -1,9 +1,11 @@ using System; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; using Shouldly; using Volo.Abp.Domain.Repositories; using Volo.Abp.EntityFrameworkCore.TestApp.ThirdDbContext; +using Volo.Abp.TestApp.Domain; using Volo.Abp.TestApp.EntityFrameworkCore; using Volo.Abp.Uow; using Xunit; @@ -13,24 +15,40 @@ namespace Volo.Abp.EntityFrameworkCore public class DbContext_Replace_Tests : EntityFrameworkCoreTestBase { private readonly IBasicRepository _dummyRepository; + private readonly IPersonRepository _personRepository; private readonly IUnitOfWorkManager _unitOfWorkManager; + private readonly AbpDbContextOptions _options; public DbContext_Replace_Tests() { - _dummyRepository = ServiceProvider.GetRequiredService>(); - _unitOfWorkManager = ServiceProvider.GetRequiredService(); + _dummyRepository = GetRequiredService>(); + _personRepository = GetRequiredService(); + _unitOfWorkManager = GetRequiredService(); + _options = GetRequiredService>().Value; } [Fact] public async Task Should_Replace_DbContext() { + _options.GetReplacedTypeOrSelf(typeof(IThirdDbContext)).ShouldBe(typeof(TestAppDbContext)); + (ServiceProvider.GetRequiredService() is TestAppDbContext).ShouldBeTrue(); using (var uow = _unitOfWorkManager.Begin()) { - ((await _dummyRepository.GetDbContextAsync()) is IThirdDbContext).ShouldBeTrue(); - ((await _dummyRepository.GetDbContextAsync()) is TestAppDbContext).ShouldBeTrue(); + var instance1 = await _dummyRepository.GetDbContextAsync(); + (instance1 is IThirdDbContext).ShouldBeTrue(); + var instance2 = await _dummyRepository.GetDbContextAsync(); + (instance2 is TestAppDbContext).ShouldBeTrue(); + + var instance3 = await _personRepository.GetDbContextAsync(); + (instance3 is TestAppDbContext).ShouldBeTrue(); + + // All instances should be the same! + instance3.ShouldBe(instance1); + instance3.ShouldBe(instance2); + await uow.CompleteAsync(); } }