Enable scripting for Logins (#2003)

* Add a LoginActions class

* Enable scripting for Login
This commit is contained in:
Karl Burtram
2023-04-13 18:20:53 -07:00
committed by GitHub
parent 948ae3903e
commit 78136e53dd
5 changed files with 174 additions and 35 deletions

View File

@@ -102,4 +102,27 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
RequestType<InitializeLoginViewRequestParams, LoginViewInfo> Type =
RequestType<InitializeLoginViewRequestParams, LoginViewInfo>.Create("objectManagement/initializeLoginView");
}
/// <summary>
/// Script Login params
/// </summary>
public class ScriptLoginParams
{
public string? ContextId { get; set; }
public LoginInfo? Login { get; set; }
}
/// <summary>
/// Script Login request type
/// </summary>
public class ScriptLoginRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<ScriptLoginParams, string> Type =
RequestType<ScriptLoginParams, string>.Create("objectManagement/scriptLogin");
}
}

View File

@@ -20,7 +20,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
{
internal class LoginServiceHandlerImpl
{
private Dictionary<string, string> contextIdToConnectionUriMap = new Dictionary<string, string>();
private class ViewState
{
public bool IsNewObject { get; set; }
public string ConnectionUri { get; set; }
public ViewState(bool isNewObject, string connectionUri)
{
this.IsNewObject = isNewObject;
this.ConnectionUri = connectionUri;
}
}
private Dictionary<string, ViewState> contextIdToViewState = new Dictionary<string, ViewState>();
private ConnectionService? connectionService;
@@ -46,10 +60,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
/// </summary>
internal async Task HandleCreateLoginRequest(CreateLoginParams parameters, RequestContext<object> requestContext)
{
DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow);
await requestContext.SendResult(new object());
}
private string DoHandleCreateLoginRequest(
string contextId, LoginInfo login, RunType runType)
{
ViewState? viewState;
this.contextIdToViewState.TryGetValue(contextId, out viewState);
ConnectionInfo connInfo;
string ownerUri;
contextIdToConnectionUriMap.TryGetValue(parameters.ContextId, out ownerUri);
ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo);
if (connInfo == null)
{
@@ -57,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
}
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, parameters.Login);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, login);
if (prototype.LoginType == SqlServer.Management.Smo.LoginType.SqlLogin)
{
@@ -77,39 +100,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
}
}
prototype.ApplyGeneralChanges(dataContainer.Server);
// TODO move this to LoginData
// TODO support role assignment for Azure
LoginPrototype newPrototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]);
var _ =newPrototype.ServerRoles.ServerRoleNames;
foreach (string role in parameters.Login.ServerRoles ?? Enumerable.Empty<string>())
prototype.ServerRoles.PopulateServerRoles();
foreach (string role in login.ServerRoles ?? Enumerable.Empty<string>())
{
newPrototype.ServerRoles.SetMember(role, true);
prototype.ServerRoles.SetMember(role, true);
}
newPrototype.ApplyServerRoleChanges(dataContainer.Server);
await requestContext.SendResult(new object());
return ConfigureLogin(
dataContainer,
ConfigAction.Create,
runType,
prototype);
}
internal async Task HandleUpdateLoginRequest(UpdateLoginParams parameters, RequestContext<object> requestContext)
{
DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow);
await requestContext.SendResult(new object());
}
private string DoHandleUpdateLoginRequest(
string contextId, LoginInfo login, RunType runType)
{
ViewState? viewState;
this.contextIdToViewState.TryGetValue(contextId, out viewState);
ConnectionInfo connInfo;
string ownerUri;
contextIdToConnectionUriMap.TryGetValue(parameters.ContextId, out ownerUri);
ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo);
if (connInfo == null)
{
throw new ArgumentException("Invalid ConnectionUri");
}
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[login.Name]);
var login = parameters.Login;
prototype.SqlPassword = login.Password;
if (0 != String.Compare(login.DefaultLanguage, SR.DefaultLanguagePlaceholder, StringComparison.Ordinal))
if (0 != string.Compare(login.DefaultLanguage, SR.DefaultLanguagePlaceholder, StringComparison.Ordinal))
{
string[] arr = login.DefaultLanguage?.Split(" - ");
if (arr != null && arr.Length > 1)
@@ -155,15 +185,44 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
prototype.ServerRoles.SetMember(role, true);
}
prototype.ApplyGeneralChanges(dataContainer.Server);
prototype.ApplyServerRoleChanges(dataContainer.Server);
prototype.ApplyDatabaseRoleChanges(dataContainer.Server);
await requestContext.SendResult(new object());
return ConfigureLogin(
dataContainer,
ConfigAction.Update,
runType,
prototype);
}
/// <summary>
/// Handle request to script a user
/// </summary>
internal async Task HandleScriptLoginRequest(ScriptLoginParams parameters, RequestContext<string> requestContext)
{
if (parameters.ContextId == null)
{
throw new ArgumentException("Invalid context ID");
}
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
{
throw new ArgumentException("Invalid context ID view state");
}
string sqlScript = (viewState.IsNewObject)
? DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow)
: DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow);
await requestContext.SendResult(sqlScript);
}
internal async Task HandleInitializeLoginViewRequest(InitializeLoginViewRequestParams parameters, RequestContext<LoginViewInfo> requestContext)
{
contextIdToConnectionUriMap.Add(parameters.ContextId, parameters.ConnectionUri);
this.contextIdToViewState.Add(
parameters.ContextId,
new ViewState(parameters.IsNewObject, parameters.ConnectionUri));
ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo);
if (connInfo == null)
@@ -192,7 +251,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
: new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Name]);
List<string> loginServerRoles = new List<string>();
foreach(string role in prototype.ServerRoles.ServerRoleNames)
foreach (string role in prototype.ServerRoles.ServerRoleNames)
{
if (prototype.ServerRoles.IsMember(role))
{
@@ -254,5 +313,61 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
{
await requestContext.SendResult(new object());
}
internal string ConfigureLogin(
CDataContainer dataContainer,
ConfigAction configAction,
RunType runType,
LoginPrototype prototype)
{
string sqlScript = string.Empty;
using (var actions = new LoginActions(dataContainer, configAction, prototype))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
if (executionHandler.ExecutionResult == ExecutionMode.Failure)
{
throw executionHandler.ExecutionFailureException;
}
if (runType == RunType.ScriptToWindow)
{
sqlScript = executionHandler.ScriptTextFromLastRun;
}
}
return sqlScript;
}
}
internal class LoginActions : ManagementActionBase
{
private ConfigAction configAction;
private LoginPrototype prototype;
/// <summary>
/// Handle login create and update actions
/// </summary>
public LoginActions(CDataContainer dataContainer, ConfigAction configAction, LoginPrototype prototype)
{
this.DataContainer = dataContainer;
this.configAction = configAction;
this.prototype = prototype;
}
/// <summary>
/// called by the management actions framework to execute the action
/// </summary>
/// <param name="node"></param>
public override void OnRunNow(object sender)
{
if (this.configAction != ConfigAction.Drop)
{
prototype.ApplyGeneralChanges(this.DataContainer.Server);
prototype.ApplyServerRoleChanges(this.DataContainer.Server);
prototype.ApplyDatabaseRoleChanges(this.DataContainer.Server);
}
}
}
}

View File

@@ -773,7 +773,7 @@ INNER JOIN sys.sql_logins AS sql_logins
/// <summary>
/// Populate the server roles map
/// </summary>
private void PopulateServerRoles()
internal void PopulateServerRoles()
{
this.initialized = true;
serverRoles.Clear();

View File

@@ -89,6 +89,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
this.ServiceHost.SetRequestHandler(CreateLoginRequest.Type, this.loginServiceHandler.HandleCreateLoginRequest, true);
this.ServiceHost.SetRequestHandler(UpdateLoginRequest.Type, this.loginServiceHandler.HandleUpdateLoginRequest, true);
this.ServiceHost.SetRequestHandler(InitializeLoginViewRequest.Type, this.loginServiceHandler.HandleInitializeLoginViewRequest, true);
this.ServiceHost.SetRequestHandler(ScriptLoginRequest.Type, this.loginServiceHandler.HandleScriptLoginRequest, true);
this.ServiceHost.SetRequestHandler(DisposeLoginViewRequest.Type, this.loginServiceHandler.HandleDisposeLoginViewRequest, true);
// User request handlers

View File

@@ -22,7 +22,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
{
internal class UserServiceHandlerImpl
{
private class UserViewState
private class ViewState
{
public bool IsNewObject { get; set; }
@@ -30,7 +30,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
public UserPrototypeData OriginalUserData { get; set; }
public UserViewState(bool isNewObject, string database, UserPrototypeData originalUserData)
public ViewState(bool isNewObject, string database, UserPrototypeData originalUserData)
{
this.IsNewObject = isNewObject;
this.Database = database;
@@ -40,7 +40,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
private ConnectionService? connectionService;
private Dictionary<string, UserViewState> contextIdToViewState = new Dictionary<string, UserViewState>();
private Dictionary<string, ViewState> contextIdToViewState = new Dictionary<string, ViewState>();
/// <summary>
/// Internal for testing purposes only
@@ -232,7 +232,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
this.contextIdToViewState.Add(
parameters.ContextId,
new UserViewState(parameters.IsNewObject, parameters.Database, currentUserPrototype.CurrentState));
new ViewState(parameters.IsNewObject, parameters.Database, currentUserPrototype.CurrentState));
await requestContext.SendResult(userViewInfo);
}
@@ -247,7 +247,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
throw new ArgumentException("Invalid context ID");
}
UserViewState viewState;
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
@@ -281,7 +281,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
throw new ArgumentException("Invalid context ID");
}
UserViewState viewState;
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
@@ -305,7 +305,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
}
/// <summary>
/// Handle request to update a user
/// Handle request to script a user
/// </summary>
internal async Task HandleScriptUserRequest(ScriptUserParams parameters, RequestContext<string> requestContext)
{
@@ -314,7 +314,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
throw new ArgumentException("Invalid context ID");
}
UserViewState viewState;
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)