Add support for Azure Active Directory connections (#727)

This commit is contained in:
Matt Irvine
2018-11-13 11:50:30 -08:00
committed by GitHub
parent 2cb7f682c5
commit 7f28f249de
32 changed files with 291 additions and 121 deletions

View File

@@ -80,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
});
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object);
@@ -146,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
.Returns(() => Task.Run(() => {}));
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object)
.Returns(mockConnection2.Object);
@@ -209,7 +209,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
});
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object);
@@ -282,7 +282,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
connectionMock.Setup(c => c.Database).Returns(expectedDbName);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(connectionMock.Object);
var connectionService = new ConnectionService(mockFactory.Object);
@@ -321,8 +321,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var dummySqlConnection = new TestSqlConnection(null);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
dummySqlConnection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -775,7 +775,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
// Setup mock connection factory to inject query results
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(CreateMockDbConnection(new[] {data}));
var connectionService = new ConnectionService(mockFactory.Object);
@@ -1324,8 +1324,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = new TestSqlConnection(null);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1374,8 +1374,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1427,8 +1427,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1484,5 +1484,32 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.Equal(false, details.TrustServerCertificate);
Assert.Equal(30, details.ConnectTimeout);
}
[Fact]
public async void ConnectingWithAzureAccountUsesToken()
{
// Set up mock connection factory
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(new TestSqlConnection(null));
var connectionService = new ConnectionService(mockFactory.Object);
var details = TestObjects.GetTestConnectionDetails();
var azureAccountToken = "testAzureAccountToken";
details.AzureAccountToken = azureAccountToken;
details.UserName = "";
details.Password = "";
details.AuthenticationType = "AzureMFA";
// If I open a connection using connection details that include an account token
await connectionService.Connect(new ConnectParams
{
OwnerUri = "testURI",
Connection = details
});
// Then the connection factory got called with details including an account token
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once());
}
}
}

View File

@@ -0,0 +1,36 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
{
/// <summary>
/// Tests for ReliableConnection code
/// </summary>
public class ReliableConnectionTests
{
[Fact]
public void ReliableSqlConnectionUsesAzureToken()
{
ConnectionDetails details = TestObjects.GetTestConnectionDetails();
details.UserName = "";
details.Password = "";
string connectionString = ConnectionService.BuildConnectionString(details);
string azureAccountToken = "testAzureAccountToken";
RetryPolicy retryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
// If I create a ReliableSqlConnection using an azure account token
var reliableConnection = new ReliableSqlConnection(connectionString, retryPolicy, retryPolicy, azureAccountToken);
// Then the connection's azureAccountToken gets set
Assert.Equal(azureAccountToken, reliableConnection.GetUnderlyingConnection().AccessToken);
}
}
}

View File

@@ -14,8 +14,8 @@
<PackageReference Include="NUnit" Version="3.10.1" />
<PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.6.0-preview3-26501-04" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@@ -408,7 +408,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
// Stub out the connection to avoid a 30second timeout while attempting to connect.
// The tests don't need any connection context anyhow so this doesn't impact the scenario
mockConnectionOpener.Setup(b => b.OpenSqlConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
mockConnectionOpener.Setup(b => b.OpenServerConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
.Throws<Exception>();
connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>()))
.Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details)));

View File

@@ -173,7 +173,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
private static ISqlConnectionFactory CreateMockFactory(TestResultSet[] data, bool throwOnExecute, bool throwOnRead)
{
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(() => CreateTestConnection(data, throwOnExecute, throwOnRead));
return mockFactory.Object;
@@ -184,7 +184,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
// Create a connection info and add the default connection to it
ISqlConnectionFactory factory = CreateMockFactory(data, throwOnExecute, throwOnRead);
ConnectionInfo ci = new ConnectionInfo(factory, Constants.OwnerUri, StandardConnectionDetails);
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null);
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null, null);
return ci;
}

View File

@@ -427,7 +427,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
private static DbConnection GetConnection(ConnectionInfo info)
{
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
}
[SuppressMessage("ReSharper", "UnusedParameter.Local")]

View File

@@ -386,7 +386,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
private static DbDataReader GetReader(TestResultSet[] dataSet, bool throwOnRead, string query)
{
var info = Common.CreateTestConnectionInfo(dataSet, false, throwOnRead);
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
var command = connection.CreateCommand();
command.CommandText = query;
return command.ExecuteReader();

View File

@@ -174,7 +174,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
private static DbDataReader GetReader(TestResultSet[] dataSet, string query)
{
var info = Common.CreateTestConnectionInfo(dataSet, false, false);
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
var command = connection.CreateCommand();
command.CommandText = query;
return command.ExecuteReader();

View File

@@ -8,6 +8,8 @@ using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
@@ -273,7 +275,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
/// </summary>
public class TestSqlConnectionFactory : ISqlConnectionFactory
{
public DbConnection CreateSqlConnection(string connectionString)
public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
{
return new TestSqlConnection(null)
{