diff --git a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/RequestContext.cs b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/RequestContext.cs index 1ff30f9d..2207383b 100644 --- a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/RequestContext.cs +++ b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/RequestContext.cs @@ -48,8 +48,8 @@ namespace Microsoft.SqlTools.Hosting.Protocol Code = errorCode }; return this.messageWriter.WriteError( - requestMessage.Id, requestMessage.Method, + requestMessage.Id, error); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 42931751..f589a547 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -412,8 +412,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// creates a new connection. This cannot be used to create a default connection or to create a /// connection if a default connection does not exist. /// + /// URI identifying the resource mapped to this connection + /// + /// What the purpose for this connection is. A single resource + /// such as a SQL file may have multiple connections - one for Intellisense, another for query execution + /// + /// + /// Workaround for .Net Core clone connection issues: should persist security be used so that + /// when SMO clones connections it can do so without breaking on SQL Password connections. + /// This should be removed once the core issue is resolved and clone works as expected + /// /// A DB connection for the connection type requested - public async Task GetOrOpenConnection(string ownerUri, string connectionType) + public async Task GetOrOpenConnection(string ownerUri, string connectionType, bool alwaysPersistSecurity = false) { Validate.IsNotNullOrEmptyString(nameof(ownerUri), ownerUri); Validate.IsNotNullOrEmptyString(nameof(connectionType), connectionType); @@ -439,13 +449,26 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // If the DbConnection does not exist and is not the default connection, create one. // We can't create the default (initial) connection here because we won't have a ConnectionDetails // if Connect() has not yet been called. + bool? originalPersistSecurityInfo = connectionInfo.ConnectionDetails.PersistSecurityInfo; + if (alwaysPersistSecurity) + { + connectionInfo.ConnectionDetails.PersistSecurityInfo = true; + } ConnectParams connectParams = new ConnectParams { OwnerUri = ownerUri, Connection = connectionInfo.ConnectionDetails, Type = connectionType }; - await Connect(connectParams); + try + { + await Connect(connectParams); + } + finally + { + connectionInfo.ConnectionDetails.PersistSecurityInfo = originalPersistSecurityInfo; + } + connectionInfo.TryGetConnection(connectionType, out connection); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditInitializeRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditInitializeRequest.cs index 49533757..ecc1b9f8 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditInitializeRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditInitializeRequest.cs @@ -22,6 +22,11 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts /// public string ObjectName { get; set; } + /// + /// The schema for the object to use + /// + public string SchemaName { get; set; } + /// /// The type of the object to use for generating an edit script /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/EditDataService.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditDataService.cs index 23700cdb..26f163a3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/EditDataService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditDataService.cs @@ -152,7 +152,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData Func executionFailureHandler = (e) => SendSessionReadyEvent(requestContext, initParams.OwnerUri, false, e.Message); Func executionSuccessHandler = () => SendSessionReadyEvent(requestContext, initParams.OwnerUri, true, null); - EditSession.Connector connector = () => connectionService.GetOrOpenConnection(initParams.OwnerUri, ConnectionType.Edit); + EditSession.Connector connector = () => connectionService.GetOrOpenConnection(initParams.OwnerUri, ConnectionType.Edit, alwaysPersistSecurity: true); EditSession.QueryRunner queryRunner = q => SessionInitializeQueryRunner(initParams.OwnerUri, requestContext, q); try diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/EditSession.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditSession.cs index 5886c253..482c615e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/EditSession.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditSession.cs @@ -431,7 +431,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData try { // Step 1) Look up the SMO metadata - string[] namedParts = SqlScriptFormatter.DecodeMultipartIdenfitier(initParams.ObjectName); + string[] namedParts = GetEditTargetName(initParams); objectMetadata = metadataFactory.GetObjectMetadata(await connector(), namedParts, initParams.ObjectType); @@ -459,6 +459,16 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData } } + public static string[] GetEditTargetName(EditInitializeParams initParams) + { + // Step 1) Look up the SMO metadata + if (initParams.SchemaName != null) + { + return new [] { initParams.SchemaName, initParams.ObjectName }; + } + return SqlScriptFormatter.DecodeMultipartIdenfitier(initParams.ObjectName); + } + private async Task CommitEditsInternal(DbConnection connection, Func successHandler, Func errorHandler) { try diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs index d9834cb6..ff613767 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs @@ -58,6 +58,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData // Connect with SMO and get the metadata for the table Server server = new Server(new ServerConnection(sqlConn)); Database db = new Database(server, sqlConn.Database); + TableViewTableTypeBase smoResult; switch (objectType.ToLowerInvariant()) { diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs index 139e2be0..7041f06b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs @@ -23,10 +23,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace { #region Private Fields + private const string UntitledScheme = "untitled"; private static readonly HashSet fileUriSchemes = new HashSet(StringComparer.OrdinalIgnoreCase) { "file", - "untitled", + UntitledScheme, "tsqloutput" }; @@ -101,6 +102,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace ScriptFile scriptFile = null; if (!this.workspaceFiles.TryGetValue(keyName, out scriptFile)) { + if (IsUntitled(resolvedFilePath)) + { + // It's not a registered untitled file, so any attempt to read from disk will fail as it's in memory + return null; + } // This method allows FileNotFoundException to bubble up // if the file isn't found. using (FileStream fileStream = new FileStream(resolvedFilePath, FileMode.Open, FileAccess.Read)) @@ -259,16 +265,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace relativePath)); return combinedPath; - } - internal static bool IsPathInMemoryOrNonFileUri(string path) - { - string scheme = GetScheme(path); - if (!string.IsNullOrEmpty(scheme)) - { - return !scheme.Equals("file"); - } - return false; - } + } + internal static bool IsPathInMemoryOrNonFileUri(string path) + { + string scheme = GetScheme(path); + if (!string.IsNullOrEmpty(scheme)) + { + return !scheme.Equals("file"); + } + return false; + } public static string GetScheme(string uri) { @@ -287,17 +293,27 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace return match.Groups[1].Value; } return null; - } - - private bool IsNonFileUri(string path) - { - string scheme = GetScheme(path); - if (!string.IsNullOrEmpty(scheme)) - { - return !fileUriSchemes.Contains(scheme); ; - } - return false; - } + } + + private bool IsNonFileUri(string path) + { + string scheme = GetScheme(path); + if (!string.IsNullOrEmpty(scheme)) + { + return !fileUriSchemes.Contains(scheme); ; + } + return false; + } + + private bool IsUntitled(string path) + { + string scheme = GetScheme(path); + if (scheme != null && scheme.Length > 0) + { + return string.Compare(UntitledScheme, scheme, StringComparison.OrdinalIgnoreCase) == 0; + } + return false; + } #endregion diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/ServiceIntegrationTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/ServiceIntegrationTests.cs index 4200fc9b..f58803f2 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/ServiceIntegrationTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/ServiceIntegrationTests.cs @@ -277,6 +277,33 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData Assert.Empty(eds.ActiveSessions); } + [Theory] + [InlineData("table", "myschema", new [] { "myschema", "table" })] // Use schema + [InlineData("table", null, new [] { "table" })] // skip schema + [InlineData("schema.table", "myschema", new [] { "myschema", "schema.table"})] // Use schema + [InlineData("schema.table", null, new [] { "schema", "table"})] // Split object name into schema + public void ShouldUseSchemaNameIfDefined(string objName, string schemaName, string[] expectedNameParts) + { + // Setup: Create an edit data service without a session + var eds = new EditDataService(null, null, null); + + // If: + // ... I have init params with an object and schema parameter + var initParams = new EditInitializeParams + { + ObjectName = objName, + SchemaName = schemaName, + OwnerUri = Common.OwnerUri, + ObjectType = "table" + }; + + // ... And I get named parts for that + string[] nameParts = EditSession.GetEditTargetName(initParams); + + // Then: + Assert.Equal(expectedNameParts, nameParts); + } + private static async Task GetDefaultSession() { // ... Create a session with a proper query and metadata