mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-19 01:25:40 -05:00
Add support for Azure Active Directory connections (#727)
This commit is contained in:
@@ -80,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
});
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns(mockConnection.Object);
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
.Returns(() => Task.Run(() => {}));
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns(mockConnection.Object)
|
||||
.Returns(mockConnection2.Object);
|
||||
|
||||
@@ -209,7 +209,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
});
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns(mockConnection.Object);
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
connectionMock.Setup(c => c.Database).Returns(expectedDbName);
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns(connectionMock.Object);
|
||||
|
||||
var connectionService = new ConnectionService(mockFactory.Object);
|
||||
@@ -321,8 +321,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
var dummySqlConnection = new TestSqlConnection(null);
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
.Returns((string connString) =>
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns((string connString, string azureAccountToken) =>
|
||||
{
|
||||
dummySqlConnection.ConnectionString = connString;
|
||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||
@@ -775,7 +775,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
|
||||
// Setup mock connection factory to inject query results
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns(CreateMockDbConnection(new[] {data}));
|
||||
var connectionService = new ConnectionService(mockFactory.Object);
|
||||
|
||||
@@ -1324,8 +1324,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
var connection = new TestSqlConnection(null);
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
.Returns((string connString) =>
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns((string connString, string azureAccountToken) =>
|
||||
{
|
||||
connection.ConnectionString = connString;
|
||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||
@@ -1374,8 +1374,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
var connection = mockConnection.Object;
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
.Returns((string connString) =>
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns((string connString, string azureAccountToken) =>
|
||||
{
|
||||
connection.ConnectionString = connString;
|
||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||
@@ -1427,8 +1427,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
var connection = mockConnection.Object;
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
.Returns((string connString) =>
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns((string connString, string azureAccountToken) =>
|
||||
{
|
||||
connection.ConnectionString = connString;
|
||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||
@@ -1484,5 +1484,32 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
Assert.Equal(false, details.TrustServerCertificate);
|
||||
Assert.Equal(30, details.ConnectTimeout);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async void ConnectingWithAzureAccountUsesToken()
|
||||
{
|
||||
// Set up mock connection factory
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns(new TestSqlConnection(null));
|
||||
var connectionService = new ConnectionService(mockFactory.Object);
|
||||
|
||||
var details = TestObjects.GetTestConnectionDetails();
|
||||
var azureAccountToken = "testAzureAccountToken";
|
||||
details.AzureAccountToken = azureAccountToken;
|
||||
details.UserName = "";
|
||||
details.Password = "";
|
||||
details.AuthenticationType = "AzureMFA";
|
||||
|
||||
// If I open a connection using connection details that include an account token
|
||||
await connectionService.Connect(new ConnectParams
|
||||
{
|
||||
OwnerUri = "testURI",
|
||||
Connection = details
|
||||
});
|
||||
|
||||
// Then the connection factory got called with details including an account token
|
||||
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
|
||||
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
{
|
||||
/// <summary>
|
||||
/// Tests for ReliableConnection code
|
||||
/// </summary>
|
||||
public class ReliableConnectionTests
|
||||
{
|
||||
[Fact]
|
||||
public void ReliableSqlConnectionUsesAzureToken()
|
||||
{
|
||||
ConnectionDetails details = TestObjects.GetTestConnectionDetails();
|
||||
details.UserName = "";
|
||||
details.Password = "";
|
||||
string connectionString = ConnectionService.BuildConnectionString(details);
|
||||
string azureAccountToken = "testAzureAccountToken";
|
||||
RetryPolicy retryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
||||
|
||||
// If I create a ReliableSqlConnection using an azure account token
|
||||
var reliableConnection = new ReliableSqlConnection(connectionString, retryPolicy, retryPolicy, azureAccountToken);
|
||||
|
||||
// Then the connection's azureAccountToken gets set
|
||||
Assert.Equal(azureAccountToken, reliableConnection.GetUnderlyingConnection().AccessToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,8 +14,8 @@
|
||||
<PackageReference Include="NUnit" Version="3.10.1" />
|
||||
<PackageReference Include="xunit" Version="2.2.0" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
|
||||
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
|
||||
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.5.0" />
|
||||
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
|
||||
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.6.0-preview3-26501-04" />
|
||||
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
|
||||
@@ -408,7 +408,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
|
||||
|
||||
// Stub out the connection to avoid a 30second timeout while attempting to connect.
|
||||
// The tests don't need any connection context anyhow so this doesn't impact the scenario
|
||||
mockConnectionOpener.Setup(b => b.OpenSqlConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
|
||||
mockConnectionOpener.Setup(b => b.OpenServerConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
|
||||
.Throws<Exception>();
|
||||
connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>()))
|
||||
.Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details)));
|
||||
|
||||
@@ -173,7 +173,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
|
||||
private static ISqlConnectionFactory CreateMockFactory(TestResultSet[] data, bool throwOnExecute, bool throwOnRead)
|
||||
{
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||
.Returns(() => CreateTestConnection(data, throwOnExecute, throwOnRead));
|
||||
|
||||
return mockFactory.Object;
|
||||
@@ -184,7 +184,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
|
||||
// Create a connection info and add the default connection to it
|
||||
ISqlConnectionFactory factory = CreateMockFactory(data, throwOnExecute, throwOnRead);
|
||||
ConnectionInfo ci = new ConnectionInfo(factory, Constants.OwnerUri, StandardConnectionDetails);
|
||||
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null);
|
||||
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null, null);
|
||||
return ci;
|
||||
}
|
||||
|
||||
|
||||
@@ -427,7 +427,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
|
||||
|
||||
private static DbConnection GetConnection(ConnectionInfo info)
|
||||
{
|
||||
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
|
||||
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
|
||||
}
|
||||
|
||||
[SuppressMessage("ReSharper", "UnusedParameter.Local")]
|
||||
|
||||
@@ -386,7 +386,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
|
||||
private static DbDataReader GetReader(TestResultSet[] dataSet, bool throwOnRead, string query)
|
||||
{
|
||||
var info = Common.CreateTestConnectionInfo(dataSet, false, throwOnRead);
|
||||
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
|
||||
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
|
||||
var command = connection.CreateCommand();
|
||||
command.CommandText = query;
|
||||
return command.ExecuteReader();
|
||||
|
||||
@@ -174,7 +174,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
|
||||
private static DbDataReader GetReader(TestResultSet[] dataSet, string query)
|
||||
{
|
||||
var info = Common.CreateTestConnectionInfo(dataSet, false, false);
|
||||
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
|
||||
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
|
||||
var command = connection.CreateCommand();
|
||||
command.CommandText = query;
|
||||
return command.ExecuteReader();
|
||||
|
||||
@@ -8,6 +8,8 @@ using System.Collections.Generic;
|
||||
using System.Data;
|
||||
using System.Data.Common;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
|
||||
@@ -273,7 +275,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
|
||||
/// </summary>
|
||||
public class TestSqlConnectionFactory : ISqlConnectionFactory
|
||||
{
|
||||
public DbConnection CreateSqlConnection(string connectionString)
|
||||
public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
|
||||
{
|
||||
return new TestSqlConnection(null)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user