Enable scripting for Logins (#2003)

* Add a LoginActions class

* Enable scripting for Login
This commit is contained in:
Karl Burtram
2023-04-13 18:20:53 -07:00
committed by GitHub
parent 948ae3903e
commit 78136e53dd
5 changed files with 174 additions and 35 deletions

View File

@@ -20,7 +20,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
{
internal class LoginServiceHandlerImpl
{
private Dictionary<string, string> contextIdToConnectionUriMap = new Dictionary<string, string>();
private class ViewState
{
public bool IsNewObject { get; set; }
public string ConnectionUri { get; set; }
public ViewState(bool isNewObject, string connectionUri)
{
this.IsNewObject = isNewObject;
this.ConnectionUri = connectionUri;
}
}
private Dictionary<string, ViewState> contextIdToViewState = new Dictionary<string, ViewState>();
private ConnectionService? connectionService;
@@ -46,10 +60,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
/// </summary>
internal async Task HandleCreateLoginRequest(CreateLoginParams parameters, RequestContext<object> requestContext)
{
DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow);
await requestContext.SendResult(new object());
}
private string DoHandleCreateLoginRequest(
string contextId, LoginInfo login, RunType runType)
{
ViewState? viewState;
this.contextIdToViewState.TryGetValue(contextId, out viewState);
ConnectionInfo connInfo;
string ownerUri;
contextIdToConnectionUriMap.TryGetValue(parameters.ContextId, out ownerUri);
ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo);
if (connInfo == null)
{
@@ -57,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
}
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, parameters.Login);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, login);
if (prototype.LoginType == SqlServer.Management.Smo.LoginType.SqlLogin)
{
@@ -77,39 +100,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
}
}
prototype.ApplyGeneralChanges(dataContainer.Server);
// TODO move this to LoginData
// TODO support role assignment for Azure
LoginPrototype newPrototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]);
var _ =newPrototype.ServerRoles.ServerRoleNames;
foreach (string role in parameters.Login.ServerRoles ?? Enumerable.Empty<string>())
prototype.ServerRoles.PopulateServerRoles();
foreach (string role in login.ServerRoles ?? Enumerable.Empty<string>())
{
newPrototype.ServerRoles.SetMember(role, true);
prototype.ServerRoles.SetMember(role, true);
}
newPrototype.ApplyServerRoleChanges(dataContainer.Server);
await requestContext.SendResult(new object());
return ConfigureLogin(
dataContainer,
ConfigAction.Create,
runType,
prototype);
}
internal async Task HandleUpdateLoginRequest(UpdateLoginParams parameters, RequestContext<object> requestContext)
{
DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow);
await requestContext.SendResult(new object());
}
private string DoHandleUpdateLoginRequest(
string contextId, LoginInfo login, RunType runType)
{
ViewState? viewState;
this.contextIdToViewState.TryGetValue(contextId, out viewState);
ConnectionInfo connInfo;
string ownerUri;
contextIdToConnectionUriMap.TryGetValue(parameters.ContextId, out ownerUri);
ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo);
if (connInfo == null)
{
throw new ArgumentException("Invalid ConnectionUri");
}
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[login.Name]);
var login = parameters.Login;
prototype.SqlPassword = login.Password;
if (0 != String.Compare(login.DefaultLanguage, SR.DefaultLanguagePlaceholder, StringComparison.Ordinal))
if (0 != string.Compare(login.DefaultLanguage, SR.DefaultLanguagePlaceholder, StringComparison.Ordinal))
{
string[] arr = login.DefaultLanguage?.Split(" - ");
if (arr != null && arr.Length > 1)
@@ -155,15 +185,44 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
prototype.ServerRoles.SetMember(role, true);
}
prototype.ApplyGeneralChanges(dataContainer.Server);
prototype.ApplyServerRoleChanges(dataContainer.Server);
prototype.ApplyDatabaseRoleChanges(dataContainer.Server);
await requestContext.SendResult(new object());
return ConfigureLogin(
dataContainer,
ConfigAction.Update,
runType,
prototype);
}
/// <summary>
/// Handle request to script a user
/// </summary>
internal async Task HandleScriptLoginRequest(ScriptLoginParams parameters, RequestContext<string> requestContext)
{
if (parameters.ContextId == null)
{
throw new ArgumentException("Invalid context ID");
}
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
{
throw new ArgumentException("Invalid context ID view state");
}
string sqlScript = (viewState.IsNewObject)
? DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow)
: DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow);
await requestContext.SendResult(sqlScript);
}
internal async Task HandleInitializeLoginViewRequest(InitializeLoginViewRequestParams parameters, RequestContext<LoginViewInfo> requestContext)
{
contextIdToConnectionUriMap.Add(parameters.ContextId, parameters.ConnectionUri);
this.contextIdToViewState.Add(
parameters.ContextId,
new ViewState(parameters.IsNewObject, parameters.ConnectionUri));
ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo);
if (connInfo == null)
@@ -192,7 +251,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
: new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Name]);
List<string> loginServerRoles = new List<string>();
foreach(string role in prototype.ServerRoles.ServerRoleNames)
foreach (string role in prototype.ServerRoles.ServerRoleNames)
{
if (prototype.ServerRoles.IsMember(role))
{
@@ -254,5 +313,61 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
{
await requestContext.SendResult(new object());
}
internal string ConfigureLogin(
CDataContainer dataContainer,
ConfigAction configAction,
RunType runType,
LoginPrototype prototype)
{
string sqlScript = string.Empty;
using (var actions = new LoginActions(dataContainer, configAction, prototype))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
if (executionHandler.ExecutionResult == ExecutionMode.Failure)
{
throw executionHandler.ExecutionFailureException;
}
if (runType == RunType.ScriptToWindow)
{
sqlScript = executionHandler.ScriptTextFromLastRun;
}
}
return sqlScript;
}
}
internal class LoginActions : ManagementActionBase
{
private ConfigAction configAction;
private LoginPrototype prototype;
/// <summary>
/// Handle login create and update actions
/// </summary>
public LoginActions(CDataContainer dataContainer, ConfigAction configAction, LoginPrototype prototype)
{
this.DataContainer = dataContainer;
this.configAction = configAction;
this.prototype = prototype;
}
/// <summary>
/// called by the management actions framework to execute the action
/// </summary>
/// <param name="node"></param>
public override void OnRunNow(object sender)
{
if (this.configAction != ConfigAction.Drop)
{
prototype.ApplyGeneralChanges(this.DataContainer.Server);
prototype.ApplyServerRoleChanges(this.DataContainer.Server);
prototype.ApplyDatabaseRoleChanges(this.DataContainer.Server);
}
}
}
}