Add support for Azure Active Directory connections (#727)

This commit is contained in:
Matt Irvine
2018-11-13 11:50:30 -08:00
committed by GitHub
parent 2cb7f682c5
commit 7f28f249de
32 changed files with 291 additions and 121 deletions

View File

@@ -258,7 +258,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy);
ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy, null);
return connection;
}
@@ -283,7 +283,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
logPath = ReliableConnectionHelper.GetDefaultDatabaseLogPath(conn);
},
catchException: null,
useRetry: false);
useRetry: false,
azureAccountToken: null);
Assert.False(string.IsNullOrWhiteSpace(filePath));
Assert.False(string.IsNullOrWhiteSpace(logPath));
@@ -342,7 +343,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
var connectionBuilder = CreateTestConnectionStringBuilder();
Assert.NotNull(connectionBuilder);
bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder);
bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder, null);
Assert.False(isReadOnly);
}
@@ -352,7 +353,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
[Fact]
public void TestIsDatabaseReadonlyWithNullBuilder()
{
Assert.Throws<ArgumentNullException>(() => ReliableConnectionHelper.IsDatabaseReadonly(null));
Assert.Throws<ArgumentNullException>(() => ReliableConnectionHelper.IsDatabaseReadonly(null, null));
}
/// <summary>
@@ -361,7 +362,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
[Fact]
public void VerifyAnsiNullAndQuotedIdentifierSettingsReplayed()
{
using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true))
using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true, azureAccountToken: null))
{
VerifySessionSettings(conn, true);
VerifySessionSettings(conn, false);
@@ -506,7 +507,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
"SET NOCOUNT ON; SET NOCOUNT OFF;",
ReliableConnectionHelper.SetCommandTimeout,
null,
true
true,
null
);
Assert.NotNull(result);
}
@@ -519,7 +521,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
{
ReliableConnectionHelper.ServerInfo info = null;
var connBuilder = CreateTestConnectionStringBuilder();
Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info));
Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info, null));
Assert.NotNull(info);
Assert.NotNull(info.ServerVersion);
@@ -535,7 +537,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
RunIfWrapper.RunIfWindows(() =>
{
ReliableConnectionHelper.ServerInfo info = null;
Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info));
Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info, null));
});
}
@@ -679,12 +681,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
var result = LiveConnectionHelper.InitLiveConnectionInfo(null, queryTempFile.FilePath);
ConnectionInfo connInfo = result.ConnectionInfo;
DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default];
connection.Open();
Assert.True(connection.State == ConnectionState.Open, "Connection should be open.");
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection));
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder));
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder, null));
ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder());
Assert.NotNull(ReliableConnectionHelper.GetServerName(connection));

View File

@@ -33,7 +33,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
<PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@@ -12,7 +12,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
<PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@@ -12,7 +12,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
<PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@@ -80,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
});
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object);
@@ -146,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
.Returns(() => Task.Run(() => {}));
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object)
.Returns(mockConnection2.Object);
@@ -209,7 +209,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
});
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object);
@@ -282,7 +282,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
connectionMock.Setup(c => c.Database).Returns(expectedDbName);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(connectionMock.Object);
var connectionService = new ConnectionService(mockFactory.Object);
@@ -321,8 +321,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var dummySqlConnection = new TestSqlConnection(null);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
dummySqlConnection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -775,7 +775,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
// Setup mock connection factory to inject query results
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(CreateMockDbConnection(new[] {data}));
var connectionService = new ConnectionService(mockFactory.Object);
@@ -1324,8 +1324,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = new TestSqlConnection(null);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1374,8 +1374,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1427,8 +1427,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1484,5 +1484,32 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.Equal(false, details.TrustServerCertificate);
Assert.Equal(30, details.ConnectTimeout);
}
[Fact]
public async void ConnectingWithAzureAccountUsesToken()
{
// Set up mock connection factory
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(new TestSqlConnection(null));
var connectionService = new ConnectionService(mockFactory.Object);
var details = TestObjects.GetTestConnectionDetails();
var azureAccountToken = "testAzureAccountToken";
details.AzureAccountToken = azureAccountToken;
details.UserName = "";
details.Password = "";
details.AuthenticationType = "AzureMFA";
// If I open a connection using connection details that include an account token
await connectionService.Connect(new ConnectParams
{
OwnerUri = "testURI",
Connection = details
});
// Then the connection factory got called with details including an account token
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once());
}
}
}

View File

@@ -0,0 +1,36 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
{
/// <summary>
/// Tests for ReliableConnection code
/// </summary>
public class ReliableConnectionTests
{
[Fact]
public void ReliableSqlConnectionUsesAzureToken()
{
ConnectionDetails details = TestObjects.GetTestConnectionDetails();
details.UserName = "";
details.Password = "";
string connectionString = ConnectionService.BuildConnectionString(details);
string azureAccountToken = "testAzureAccountToken";
RetryPolicy retryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
// If I create a ReliableSqlConnection using an azure account token
var reliableConnection = new ReliableSqlConnection(connectionString, retryPolicy, retryPolicy, azureAccountToken);
// Then the connection's azureAccountToken gets set
Assert.Equal(azureAccountToken, reliableConnection.GetUnderlyingConnection().AccessToken);
}
}
}

View File

@@ -14,8 +14,8 @@
<PackageReference Include="NUnit" Version="3.10.1" />
<PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.6.0-preview3-26501-04" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@@ -408,7 +408,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
// Stub out the connection to avoid a 30second timeout while attempting to connect.
// The tests don't need any connection context anyhow so this doesn't impact the scenario
mockConnectionOpener.Setup(b => b.OpenSqlConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
mockConnectionOpener.Setup(b => b.OpenServerConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
.Throws<Exception>();
connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>()))
.Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details)));

View File

@@ -173,7 +173,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
private static ISqlConnectionFactory CreateMockFactory(TestResultSet[] data, bool throwOnExecute, bool throwOnRead)
{
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(() => CreateTestConnection(data, throwOnExecute, throwOnRead));
return mockFactory.Object;
@@ -184,7 +184,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
// Create a connection info and add the default connection to it
ISqlConnectionFactory factory = CreateMockFactory(data, throwOnExecute, throwOnRead);
ConnectionInfo ci = new ConnectionInfo(factory, Constants.OwnerUri, StandardConnectionDetails);
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null);
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null, null);
return ci;
}

View File

@@ -427,7 +427,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
private static DbConnection GetConnection(ConnectionInfo info)
{
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
}
[SuppressMessage("ReSharper", "UnusedParameter.Local")]

View File

@@ -386,7 +386,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
private static DbDataReader GetReader(TestResultSet[] dataSet, bool throwOnRead, string query)
{
var info = Common.CreateTestConnectionInfo(dataSet, false, throwOnRead);
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
var command = connection.CreateCommand();
command.CommandText = query;
return command.ExecuteReader();

View File

@@ -174,7 +174,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
private static DbDataReader GetReader(TestResultSet[] dataSet, string query)
{
var info = Common.CreateTestConnectionInfo(dataSet, false, false);
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
var command = connection.CreateCommand();
command.CommandText = query;
return command.ExecuteReader();

View File

@@ -8,6 +8,8 @@ using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
@@ -273,7 +275,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
/// </summary>
public class TestSqlConnectionFactory : ISqlConnectionFactory
{
public DbConnection CreateSqlConnection(string connectionString)
public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
{
return new TestSqlConnection(null)
{