mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-15 09:35:37 -05:00
Add support for Azure Active Directory connections (#727)
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Data.SqlClient;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
|
||||
using Microsoft.SqlTools.Utility;
|
||||
using Microsoft.SqlServer.Management.Common;
|
||||
@@ -42,10 +43,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
||||
ServerConnection = serverConnection;
|
||||
}
|
||||
|
||||
public ScriptAsScriptingOperation(ScriptingParams parameters) : base(parameters)
|
||||
public ScriptAsScriptingOperation(ScriptingParams parameters, string azureAccountToken) : base(parameters)
|
||||
{
|
||||
SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString);
|
||||
if (azureAccountToken != null)
|
||||
{
|
||||
sqlConnection.AccessToken = azureAccountToken;
|
||||
}
|
||||
|
||||
ServerConnection = new ServerConnection(sqlConnection);
|
||||
if (azureAccountToken != null)
|
||||
{
|
||||
ServerConnection.AccessToken = new AzureAccessToken(azureAccountToken);
|
||||
}
|
||||
|
||||
disconnectAtDispose = true;
|
||||
}
|
||||
|
||||
|
||||
@@ -26,8 +26,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
||||
|
||||
private int eventSequenceNumber = 1;
|
||||
|
||||
public ScriptingScriptOperation(ScriptingParams parameters): base(parameters)
|
||||
private string azureAccessToken;
|
||||
|
||||
public ScriptingScriptOperation(ScriptingParams parameters, string azureAccessToken): base(parameters)
|
||||
{
|
||||
this.azureAccessToken = azureAccessToken;
|
||||
}
|
||||
|
||||
public override void Execute()
|
||||
@@ -200,7 +203,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
||||
selectedObjects.Count(),
|
||||
string.Join(", ", selectedObjects)));
|
||||
|
||||
string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString);
|
||||
string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString, this.azureAccessToken);
|
||||
string database = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog;
|
||||
|
||||
foreach (ScriptingObject scriptingObject in selectedObjects)
|
||||
|
||||
@@ -111,12 +111,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
||||
// use the owner uri property to lookup its associated ConnectionInfo
|
||||
// and then build a connection string out of that
|
||||
ConnectionInfo connInfo = null;
|
||||
string accessToken = null;
|
||||
if (parameters.ConnectionString == null)
|
||||
{
|
||||
ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo);
|
||||
if (connInfo != null)
|
||||
{
|
||||
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
|
||||
accessToken = connInfo.ConnectionDetails.AzureAccountToken;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -126,11 +128,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
||||
|
||||
if (!ShouldCreateScriptAsOperation(parameters))
|
||||
{
|
||||
operation = new ScriptingScriptOperation(parameters);
|
||||
operation = new ScriptingScriptOperation(parameters, accessToken);
|
||||
}
|
||||
else
|
||||
{
|
||||
operation = new ScriptAsScriptingOperation(parameters);
|
||||
operation = new ScriptAsScriptingOperation(parameters, accessToken);
|
||||
}
|
||||
|
||||
operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait();
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
//
|
||||
|
||||
using Microsoft.SqlServer.Management.Common;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
|
||||
using Microsoft.SqlTools.Utility;
|
||||
using System;
|
||||
@@ -71,17 +72,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
||||
parameters.OperationId = this.OperationId;
|
||||
}
|
||||
|
||||
protected string GetServerNameFromLiveInstance(string connectionString)
|
||||
protected string GetServerNameFromLiveInstance(string connectionString, string azureAccessToken)
|
||||
{
|
||||
string serverName = null;
|
||||
using (SqlConnection connection = new SqlConnection(connectionString))
|
||||
{
|
||||
if (azureAccessToken != null)
|
||||
{
|
||||
connection.AccessToken = azureAccessToken;
|
||||
}
|
||||
connection.Open();
|
||||
|
||||
try
|
||||
{
|
||||
|
||||
ServerConnection serverConnection = new ServerConnection(connection);
|
||||
ServerConnection serverConnection;
|
||||
if (azureAccessToken == null)
|
||||
{
|
||||
serverConnection = new ServerConnection(connection);
|
||||
}
|
||||
else
|
||||
{
|
||||
serverConnection = new ServerConnection(connection, new AzureAccessToken(azureAccessToken));
|
||||
}
|
||||
serverName = serverConnection.TrueName;
|
||||
}
|
||||
catch (SqlException e)
|
||||
|
||||
Reference in New Issue
Block a user