From 78136e53dd92ba17e25ec0bed3fe240c2de0b823 Mon Sep 17 00:00:00 2001 From: Karl Burtram Date: Thu, 13 Apr 2023 18:20:53 -0700 Subject: [PATCH] Enable scripting for Logins (#2003) * Add a LoginActions class * Enable scripting for Login --- .../Security/Contracts/LoginRequest.cs | 23 +++ .../Security/LoginActions.cs | 167 +++++++++++++++--- .../Security/LoginData.cs | 2 +- .../Security/SecurityService.cs | 1 + .../Security/UserActions.cs | 16 +- 5 files changed, 174 insertions(+), 35 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/LoginRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/LoginRequest.cs index 6d80867d..6192481e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/LoginRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/LoginRequest.cs @@ -102,4 +102,27 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts RequestType Type = RequestType.Create("objectManagement/initializeLoginView"); } + + /// + /// Script Login params + /// + public class ScriptLoginParams + { + public string? ContextId { get; set; } + + public LoginInfo? Login { get; set; } + } + + /// + /// Script Login request type + /// + public class ScriptLoginRequest + { + /// + /// Request definition + /// + public static readonly + RequestType Type = + RequestType.Create("objectManagement/scriptLogin"); + } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/LoginActions.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/LoginActions.cs index e34fc05f..c737222e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/LoginActions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/LoginActions.cs @@ -20,7 +20,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Security { internal class LoginServiceHandlerImpl { - private Dictionary contextIdToConnectionUriMap = new Dictionary(); + private class ViewState + { + public bool IsNewObject { get; set; } + + public string ConnectionUri { get; set; } + + public ViewState(bool isNewObject, string connectionUri) + { + this.IsNewObject = isNewObject; + + this.ConnectionUri = connectionUri; + } + } + + private Dictionary contextIdToViewState = new Dictionary(); private ConnectionService? connectionService; @@ -46,10 +60,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Security /// internal async Task HandleCreateLoginRequest(CreateLoginParams parameters, RequestContext requestContext) { + DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow); + + await requestContext.SendResult(new object()); + } + + private string DoHandleCreateLoginRequest( + string contextId, LoginInfo login, RunType runType) + { + ViewState? viewState; + this.contextIdToViewState.TryGetValue(contextId, out viewState); + ConnectionInfo connInfo; - string ownerUri; - contextIdToConnectionUriMap.TryGetValue(parameters.ContextId, out ownerUri); - ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo); + ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo); if (connInfo == null) { @@ -57,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security } CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true); - LoginPrototype prototype = new LoginPrototype(dataContainer.Server, parameters.Login); + LoginPrototype prototype = new LoginPrototype(dataContainer.Server, login); if (prototype.LoginType == SqlServer.Management.Smo.LoginType.SqlLogin) { @@ -77,39 +100,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Security } } - prototype.ApplyGeneralChanges(dataContainer.Server); - // TODO move this to LoginData // TODO support role assignment for Azure - LoginPrototype newPrototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]); - var _ =newPrototype.ServerRoles.ServerRoleNames; - - foreach (string role in parameters.Login.ServerRoles ?? Enumerable.Empty()) + prototype.ServerRoles.PopulateServerRoles(); + foreach (string role in login.ServerRoles ?? Enumerable.Empty()) { - newPrototype.ServerRoles.SetMember(role, true); + prototype.ServerRoles.SetMember(role, true); } - newPrototype.ApplyServerRoleChanges(dataContainer.Server); - await requestContext.SendResult(new object()); + return ConfigureLogin( + dataContainer, + ConfigAction.Create, + runType, + prototype); } internal async Task HandleUpdateLoginRequest(UpdateLoginParams parameters, RequestContext requestContext) { + DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow); + + await requestContext.SendResult(new object()); + } + + private string DoHandleUpdateLoginRequest( + string contextId, LoginInfo login, RunType runType) + { + ViewState? viewState; + this.contextIdToViewState.TryGetValue(contextId, out viewState); + ConnectionInfo connInfo; - string ownerUri; - contextIdToConnectionUriMap.TryGetValue(parameters.ContextId, out ownerUri); - ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo); + ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo); if (connInfo == null) { throw new ArgumentException("Invalid ConnectionUri"); } CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true); - LoginPrototype prototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]); + LoginPrototype prototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[login.Name]); - var login = parameters.Login; prototype.SqlPassword = login.Password; - if (0 != String.Compare(login.DefaultLanguage, SR.DefaultLanguagePlaceholder, StringComparison.Ordinal)) + if (0 != string.Compare(login.DefaultLanguage, SR.DefaultLanguagePlaceholder, StringComparison.Ordinal)) { string[] arr = login.DefaultLanguage?.Split(" - "); if (arr != null && arr.Length > 1) @@ -155,15 +185,44 @@ namespace Microsoft.SqlTools.ServiceLayer.Security prototype.ServerRoles.SetMember(role, true); } - prototype.ApplyGeneralChanges(dataContainer.Server); - prototype.ApplyServerRoleChanges(dataContainer.Server); - prototype.ApplyDatabaseRoleChanges(dataContainer.Server); - await requestContext.SendResult(new object()); + return ConfigureLogin( + dataContainer, + ConfigAction.Update, + runType, + prototype); + } + + /// + /// Handle request to script a user + /// + internal async Task HandleScriptLoginRequest(ScriptLoginParams parameters, RequestContext requestContext) + { + if (parameters.ContextId == null) + { + throw new ArgumentException("Invalid context ID"); + } + + ViewState viewState; + this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); + + if (viewState == null) + { + throw new ArgumentException("Invalid context ID view state"); + } + + string sqlScript = (viewState.IsNewObject) + ? DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow) + : DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow); + + await requestContext.SendResult(sqlScript); } internal async Task HandleInitializeLoginViewRequest(InitializeLoginViewRequestParams parameters, RequestContext requestContext) { - contextIdToConnectionUriMap.Add(parameters.ContextId, parameters.ConnectionUri); + this.contextIdToViewState.Add( + parameters.ContextId, + new ViewState(parameters.IsNewObject, parameters.ConnectionUri)); + ConnectionInfo connInfo; ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo); if (connInfo == null) @@ -192,7 +251,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security : new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Name]); List loginServerRoles = new List(); - foreach(string role in prototype.ServerRoles.ServerRoleNames) + foreach (string role in prototype.ServerRoles.ServerRoleNames) { if (prototype.ServerRoles.IsMember(role)) { @@ -254,5 +313,61 @@ namespace Microsoft.SqlTools.ServiceLayer.Security { await requestContext.SendResult(new object()); } + + internal string ConfigureLogin( + CDataContainer dataContainer, + ConfigAction configAction, + RunType runType, + LoginPrototype prototype) + { + string sqlScript = string.Empty; + using (var actions = new LoginActions(dataContainer, configAction, prototype)) + { + var executionHandler = new ExecutonHandler(actions); + executionHandler.RunNow(runType, this); + if (executionHandler.ExecutionResult == ExecutionMode.Failure) + { + throw executionHandler.ExecutionFailureException; + } + + if (runType == RunType.ScriptToWindow) + { + sqlScript = executionHandler.ScriptTextFromLastRun; + } + } + + return sqlScript; + } + } + + internal class LoginActions : ManagementActionBase + { + private ConfigAction configAction; + + private LoginPrototype prototype; + + /// + /// Handle login create and update actions + /// + public LoginActions(CDataContainer dataContainer, ConfigAction configAction, LoginPrototype prototype) + { + this.DataContainer = dataContainer; + this.configAction = configAction; + this.prototype = prototype; + } + + /// + /// called by the management actions framework to execute the action + /// + /// + public override void OnRunNow(object sender) + { + if (this.configAction != ConfigAction.Drop) + { + prototype.ApplyGeneralChanges(this.DataContainer.Server); + prototype.ApplyServerRoleChanges(this.DataContainer.Server); + prototype.ApplyDatabaseRoleChanges(this.DataContainer.Server); + } + } } } \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/LoginData.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/LoginData.cs index c7f68b4c..a80c0fdf 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/LoginData.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/LoginData.cs @@ -773,7 +773,7 @@ INNER JOIN sys.sql_logins AS sql_logins /// /// Populate the server roles map /// - private void PopulateServerRoles() + internal void PopulateServerRoles() { this.initialized = true; serverRoles.Clear(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs index bebdc46f..8523d907 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs @@ -89,6 +89,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security this.ServiceHost.SetRequestHandler(CreateLoginRequest.Type, this.loginServiceHandler.HandleCreateLoginRequest, true); this.ServiceHost.SetRequestHandler(UpdateLoginRequest.Type, this.loginServiceHandler.HandleUpdateLoginRequest, true); this.ServiceHost.SetRequestHandler(InitializeLoginViewRequest.Type, this.loginServiceHandler.HandleInitializeLoginViewRequest, true); + this.ServiceHost.SetRequestHandler(ScriptLoginRequest.Type, this.loginServiceHandler.HandleScriptLoginRequest, true); this.ServiceHost.SetRequestHandler(DisposeLoginViewRequest.Type, this.loginServiceHandler.HandleDisposeLoginViewRequest, true); // User request handlers diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs index def684a5..3d4edb37 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs @@ -22,7 +22,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security { internal class UserServiceHandlerImpl { - private class UserViewState + private class ViewState { public bool IsNewObject { get; set; } @@ -30,7 +30,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security public UserPrototypeData OriginalUserData { get; set; } - public UserViewState(bool isNewObject, string database, UserPrototypeData originalUserData) + public ViewState(bool isNewObject, string database, UserPrototypeData originalUserData) { this.IsNewObject = isNewObject; this.Database = database; @@ -40,7 +40,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security private ConnectionService? connectionService; - private Dictionary contextIdToViewState = new Dictionary(); + private Dictionary contextIdToViewState = new Dictionary(); /// /// Internal for testing purposes only @@ -232,7 +232,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security this.contextIdToViewState.Add( parameters.ContextId, - new UserViewState(parameters.IsNewObject, parameters.Database, currentUserPrototype.CurrentState)); + new ViewState(parameters.IsNewObject, parameters.Database, currentUserPrototype.CurrentState)); await requestContext.SendResult(userViewInfo); } @@ -247,7 +247,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security throw new ArgumentException("Invalid context ID"); } - UserViewState viewState; + ViewState viewState; this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); if (viewState == null) @@ -281,7 +281,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security throw new ArgumentException("Invalid context ID"); } - UserViewState viewState; + ViewState viewState; this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); if (viewState == null) @@ -305,7 +305,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security } /// - /// Handle request to update a user + /// Handle request to script a user /// internal async Task HandleScriptUserRequest(ScriptUserParams parameters, RequestContext requestContext) { @@ -314,7 +314,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security throw new ArgumentException("Invalid context ID"); } - UserViewState viewState; + ViewState viewState; this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); if (viewState == null)