diff --git a/src/Microsoft.SqlTools.Hosting/Extensibility/ExtensionServiceProvider.cs b/src/Microsoft.SqlTools.Hosting/Extensibility/ExtensionServiceProvider.cs index a031305b..620354a7 100644 --- a/src/Microsoft.SqlTools.Hosting/Extensibility/ExtensionServiceProvider.cs +++ b/src/Microsoft.SqlTools.Hosting/Extensibility/ExtensionServiceProvider.cs @@ -29,6 +29,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Extensibility public static ExtensionServiceProvider CreateDefaultServiceProvider() { + // only allow loading MEF dependencies from our assemblies until we can + // better seperate out framework assemblies and extension assemblies + string[] inclusionList = + { + "microsoft.sqltools.credentials.dll", + "microsoft.sqltools.hosting.dll", + "microsoft.sqltools.servicelayer.dll" + }; + string assemblyPath = typeof(ExtensionStore).GetTypeInfo().Assembly.Location; string directory = Path.GetDirectoryName(assemblyPath); @@ -38,13 +47,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Extensibility List assemblies = new List(); foreach (var path in assemblyPaths) { + // skip DLL files not in inclusion list + bool isInList = false; + foreach (var item in inclusionList) + { + if (path.EndsWith(item, StringComparison.OrdinalIgnoreCase)) + { + isInList = true; + break; + } + } + + if (!isInList) + { + continue; + } + try { assemblies.Add( context.LoadFromAssemblyName( AssemblyLoadContext.GetAssemblyName(path))); } - catch (System.BadImageFormatException) + catch (Exception) { // we expect exceptions trying to scan all DLLs since directory contains native libraries } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Extensibility/ExtensionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Extensibility/ExtensionTests.cs index bca97a77..4b0b1081 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Extensibility/ExtensionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Extensibility/ExtensionTests.cs @@ -42,12 +42,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Extensibility } [Fact] - public void CreateDefaultServiceProviderShouldFindTypesInAllAssemblies() + public void CreateDefaultServiceProviderShouldFindTypesInAllKnownAssemblies() { // Given a default ExtensionServiceProvider - // Then should not find exports from a different assembly + // Then we should not find exports from a test assembly ExtensionServiceProvider serviceProvider = ExtensionServiceProvider.CreateDefaultServiceProvider(); - Assert.NotEmpty(serviceProvider.GetServices()); + Assert.Empty(serviceProvider.GetServices()); // But should find exports that are defined in the main assembly Assert.NotEmpty(serviceProvider.GetServices());