From 60aad9cf7e76c8df45bf4545a0d56dff3ae6fbc9 Mon Sep 17 00:00:00 2001 From: Karl Burtram Date: Thu, 21 Sep 2017 08:26:13 -0700 Subject: [PATCH] Allow connections on non-default port (#462) * Allow custom port * Update unit tests for port property --- .../Admin/AdminService.cs | 131 +++++++++++------- .../Connection/ConnectionService.cs | 10 +- .../Connection/Contracts/ConnectionDetails.cs | 16 +++ .../Contracts/ConnectionDetailsExtensions.cs | 3 +- .../Connection/ConnectionDetailsTests.cs | 4 + 5 files changed, 109 insertions(+), 55 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Admin/AdminService.cs b/src/Microsoft.SqlTools.ServiceLayer/Admin/AdminService.cs index 5b28d9e2..ca5c350b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Admin/AdminService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Admin/AdminService.cs @@ -83,19 +83,26 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin DefaultDatabaseInfoParams optionsParams, RequestContext requestContext) { - var response = new DefaultDatabaseInfoResponse(); - ConnectionInfo connInfo; - AdminService.ConnectionServiceInstance.TryFindConnection( - optionsParams.OwnerUri, - out connInfo); - - if (taskHelper == null) + try { - taskHelper = CreateDatabaseTaskHelper(connInfo); - } + var response = new DefaultDatabaseInfoResponse(); + ConnectionInfo connInfo; + AdminService.ConnectionServiceInstance.TryFindConnection( + optionsParams.OwnerUri, + out connInfo); - response.DefaultDatabaseInfo = DatabaseTaskHelper.DatabasePrototypeToDatabaseInfo(taskHelper.Prototype); - await requestContext.SendResult(response); + if (taskHelper == null) + { + taskHelper = CreateDatabaseTaskHelper(connInfo); + } + + response.DefaultDatabaseInfo = DatabaseTaskHelper.DatabasePrototypeToDatabaseInfo(taskHelper.Prototype); + await requestContext.SendResult(response); + } + catch (Exception ex) + { + await requestContext.SendError(ex.ToString()); + } } /// @@ -105,31 +112,38 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin CreateDatabaseParams databaseParams, RequestContext requestContext) { - var response = new DefaultDatabaseInfoResponse(); - ConnectionInfo connInfo; - AdminService.ConnectionServiceInstance.TryFindConnection( - databaseParams.OwnerUri, - out connInfo); + try + { + var response = new DefaultDatabaseInfoResponse(); + ConnectionInfo connInfo; + AdminService.ConnectionServiceInstance.TryFindConnection( + databaseParams.OwnerUri, + out connInfo); - if (taskHelper == null) - { - taskHelper = CreateDatabaseTaskHelper(connInfo); + if (taskHelper == null) + { + taskHelper = CreateDatabaseTaskHelper(connInfo); + } + + DatabasePrototype prototype = taskHelper.Prototype; + DatabaseTaskHelper.ApplyToPrototype(databaseParams.DatabaseInfo, taskHelper.Prototype); + + Database db = prototype.ApplyChanges(); + if (db != null) + { + taskHelper = null; + } + + await requestContext.SendResult(new CreateDatabaseResponse() + { + Result = true, + TaskId = 0 + }); } - - DatabasePrototype prototype = taskHelper.Prototype; - DatabaseTaskHelper.ApplyToPrototype(databaseParams.DatabaseInfo, taskHelper.Prototype); - - Database db = prototype.ApplyChanges(); - if (db != null) + catch (Exception ex) { - taskHelper = null; - } - - await requestContext.SendResult(new CreateDatabaseResponse() - { - Result = true, - TaskId = 0 - }); + await requestContext.SendError(ex.ToString()); + } } /// @@ -138,21 +152,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin internal static async Task HandleGetDatabaseInfoRequest( GetDatabaseInfoParams databaseParams, RequestContext requestContext) - { - ConnectionInfo connInfo; - AdminService.ConnectionServiceInstance.TryFindConnection( - databaseParams.OwnerUri, - out connInfo); - DatabaseInfo info = null; - - if (connInfo != null) - { - info = GetDatabaseInfo(connInfo); - } + { + try + { + ConnectionInfo connInfo; + AdminService.ConnectionServiceInstance.TryFindConnection( + databaseParams.OwnerUri, + out connInfo); + DatabaseInfo info = null; + + if (connInfo != null) + { + info = GetDatabaseInfo(connInfo); + } - await requestContext.SendResult(new GetDatabaseInfoResponse(){ - DatabaseInfo = info - }); + await requestContext.SendResult(new GetDatabaseInfoResponse(){ + DatabaseInfo = info + }); + } + catch (Exception ex) + { + await requestContext.SendError(ex.ToString()); + } } /// @@ -177,15 +198,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin XmlDocument xmlDoc = CreateDataContainerDocument(connInfo, databaseExists); CDataContainer dataContainer; + // add alternate port to server name property if provided + var connectionDetails = connInfo.ConnectionDetails; + string serverName = !connectionDetails.Port.HasValue + ? connectionDetails.ServerName + : string.Format("{0},{1}", connectionDetails.ServerName, connectionDetails.Port.Value); + // check if the connection is using SQL Auth or Integrated Auth - if (string.Equals(connInfo.ConnectionDetails.AuthenticationType, "SqlLogin", StringComparison.OrdinalIgnoreCase)) + if (string.Equals(connectionDetails.AuthenticationType, "SqlLogin", StringComparison.OrdinalIgnoreCase)) { - var passwordSecureString = BuildSecureStringFromPassword(connInfo.ConnectionDetails.Password); + var passwordSecureString = BuildSecureStringFromPassword(connectionDetails.Password); dataContainer = new CDataContainer( CDataContainer.ServerType.SQL, - connInfo.ConnectionDetails.ServerName, + serverName, false, - connInfo.ConnectionDetails.UserName, + connectionDetails.UserName, passwordSecureString, xmlDoc.InnerXml); } @@ -193,7 +220,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin { dataContainer = new CDataContainer( CDataContainer.ServerType.SQL, - connInfo.ConnectionDetails.ServerName, + serverName, true, null, null, diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index f09e39a9..a8355b3b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -886,10 +886,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { connectionBuilder = new SqlConnectionStringBuilder(connectionDetails.ConnectionString); } - else { + else + { + // add alternate port to data source property if provided + string dataSource = !connectionDetails.Port.HasValue + ? connectionDetails.ServerName + : string.Format("{0},{1}", connectionDetails.ServerName, connectionDetails.Port.Value); + connectionBuilder = new SqlConnectionStringBuilder { - ["Data Source"] = connectionDetails.ServerName, + ["Data Source"] = dataSource, ["User Id"] = connectionDetails.UserName, ["Password"] = connectionDetails.Password }; diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs index 1cf7bf92..e0b865ad 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -418,6 +418,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts } } + /// + /// Gets or sets the port to use for the TCP/IP connection + /// + public int? Port + { + get + { + return GetOptionValue("port"); + } + + set + { + SetOptionValue("port", value); + } + } + /// /// Gets or sets a string value that indicates the type system the application expects. /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs index 80f3ad9a..36c2dde4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs @@ -43,7 +43,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts MultipleActiveResultSets = details.MultipleActiveResultSets, PacketSize = details.PacketSize, TypeSystemVersion = details.TypeSystemVersion, - ConnectionString = details.ConnectionString + ConnectionString = details.ConnectionString, + Port = details.Port }; } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionDetailsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionDetailsTests.cs index c19acf9a..6f32b3e7 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionDetailsTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionDetailsTests.cs @@ -51,6 +51,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.Equal(details.Pooling, expectedForBoolean); Assert.Equal(details.Replication, expectedForBoolean); Assert.Equal(details.TrustServerCertificate, expectedForBoolean); + Assert.Equal(details.Port, expectedForInt); } [Fact] @@ -87,6 +88,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection details.Pooling = (index++ % 2 == 0); details.Replication = (index++ % 2 == 0); details.TrustServerCertificate = (index++ % 2 == 0); + details.Port = expectedForInt + index++; index = 0; Assert.Equal(details.ApplicationIntent, expectedForStrings + index++); @@ -115,6 +117,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.Equal(details.Pooling, (index++ % 2 == 0)); Assert.Equal(details.Replication, (index++ % 2 == 0)); Assert.Equal(details.TrustServerCertificate, (index++ % 2 == 0)); + Assert.Equal(details.Port, (expectedForInt + index++)); } [Fact] @@ -152,6 +155,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection details.Pooling = (index++ % 2 == 0); details.Replication = (index++ % 2 == 0); details.TrustServerCertificate = (index++ % 2 == 0); + details.Port = expectedForInt + index++; if(optionMetadata.Options.Count() != details.Options.Count) {