Add support for Azure Active Directory connections (#727)

This commit is contained in:
Matt Irvine
2018-11-13 11:50:30 -08:00
committed by GitHub
parent 2cb7f682c5
commit 7f28f249de
32 changed files with 291 additions and 121 deletions

View File

@@ -6,6 +6,7 @@
using System;
using System.Collections.Generic;
using System.Data.SqlClient;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility;
using Microsoft.SqlServer.Management.Common;
@@ -42,10 +43,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
ServerConnection = serverConnection;
}
public ScriptAsScriptingOperation(ScriptingParams parameters) : base(parameters)
public ScriptAsScriptingOperation(ScriptingParams parameters, string azureAccountToken) : base(parameters)
{
SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString);
if (azureAccountToken != null)
{
sqlConnection.AccessToken = azureAccountToken;
}
ServerConnection = new ServerConnection(sqlConnection);
if (azureAccountToken != null)
{
ServerConnection.AccessToken = new AzureAccessToken(azureAccountToken);
}
disconnectAtDispose = true;
}

View File

@@ -26,8 +26,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
private int eventSequenceNumber = 1;
public ScriptingScriptOperation(ScriptingParams parameters): base(parameters)
private string azureAccessToken;
public ScriptingScriptOperation(ScriptingParams parameters, string azureAccessToken): base(parameters)
{
this.azureAccessToken = azureAccessToken;
}
public override void Execute()
@@ -200,7 +203,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
selectedObjects.Count(),
string.Join(", ", selectedObjects)));
string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString);
string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString, this.azureAccessToken);
string database = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog;
foreach (ScriptingObject scriptingObject in selectedObjects)

View File

@@ -111,12 +111,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
// use the owner uri property to lookup its associated ConnectionInfo
// and then build a connection string out of that
ConnectionInfo connInfo = null;
string accessToken = null;
if (parameters.ConnectionString == null)
{
ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo);
if (connInfo != null)
{
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
accessToken = connInfo.ConnectionDetails.AzureAccountToken;
}
else
{
@@ -126,11 +128,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
if (!ShouldCreateScriptAsOperation(parameters))
{
operation = new ScriptingScriptOperation(parameters);
operation = new ScriptingScriptOperation(parameters, accessToken);
}
else
{
operation = new ScriptAsScriptingOperation(parameters);
operation = new ScriptAsScriptingOperation(parameters, accessToken);
}
operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait();

View File

@@ -4,6 +4,7 @@
//
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility;
using System;
@@ -71,17 +72,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
parameters.OperationId = this.OperationId;
}
protected string GetServerNameFromLiveInstance(string connectionString)
protected string GetServerNameFromLiveInstance(string connectionString, string azureAccessToken)
{
string serverName = null;
using (SqlConnection connection = new SqlConnection(connectionString))
{
if (azureAccessToken != null)
{
connection.AccessToken = azureAccessToken;
}
connection.Open();
try
{
ServerConnection serverConnection = new ServerConnection(connection);
ServerConnection serverConnection;
if (azureAccessToken == null)
{
serverConnection = new ServerConnection(connection);
}
else
{
serverConnection = new ServerConnection(connection, new AzureAccessToken(azureAccessToken));
}
serverName = serverConnection.TrueName;
}
catch (SqlException e)