Enable scripting for Logins (#2003)

* Add a LoginActions class

* Enable scripting for Login
This commit is contained in:
Karl Burtram
2023-04-13 18:20:53 -07:00
committed by GitHub
parent 948ae3903e
commit 78136e53dd
5 changed files with 174 additions and 35 deletions

View File

@@ -102,4 +102,27 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
RequestType<InitializeLoginViewRequestParams, LoginViewInfo> Type = RequestType<InitializeLoginViewRequestParams, LoginViewInfo> Type =
RequestType<InitializeLoginViewRequestParams, LoginViewInfo>.Create("objectManagement/initializeLoginView"); RequestType<InitializeLoginViewRequestParams, LoginViewInfo>.Create("objectManagement/initializeLoginView");
} }
/// <summary>
/// Script Login params
/// </summary>
public class ScriptLoginParams
{
public string? ContextId { get; set; }
public LoginInfo? Login { get; set; }
}
/// <summary>
/// Script Login request type
/// </summary>
public class ScriptLoginRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<ScriptLoginParams, string> Type =
RequestType<ScriptLoginParams, string>.Create("objectManagement/scriptLogin");
}
} }

View File

@@ -20,7 +20,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
{ {
internal class LoginServiceHandlerImpl internal class LoginServiceHandlerImpl
{ {
private Dictionary<string, string> contextIdToConnectionUriMap = new Dictionary<string, string>(); private class ViewState
{
public bool IsNewObject { get; set; }
public string ConnectionUri { get; set; }
public ViewState(bool isNewObject, string connectionUri)
{
this.IsNewObject = isNewObject;
this.ConnectionUri = connectionUri;
}
}
private Dictionary<string, ViewState> contextIdToViewState = new Dictionary<string, ViewState>();
private ConnectionService? connectionService; private ConnectionService? connectionService;
@@ -46,10 +60,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
/// </summary> /// </summary>
internal async Task HandleCreateLoginRequest(CreateLoginParams parameters, RequestContext<object> requestContext) internal async Task HandleCreateLoginRequest(CreateLoginParams parameters, RequestContext<object> requestContext)
{ {
DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow);
await requestContext.SendResult(new object());
}
private string DoHandleCreateLoginRequest(
string contextId, LoginInfo login, RunType runType)
{
ViewState? viewState;
this.contextIdToViewState.TryGetValue(contextId, out viewState);
ConnectionInfo connInfo; ConnectionInfo connInfo;
string ownerUri; ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo);
contextIdToConnectionUriMap.TryGetValue(parameters.ContextId, out ownerUri);
ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
if (connInfo == null) if (connInfo == null)
{ {
@@ -57,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
} }
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true); CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, parameters.Login); LoginPrototype prototype = new LoginPrototype(dataContainer.Server, login);
if (prototype.LoginType == SqlServer.Management.Smo.LoginType.SqlLogin) if (prototype.LoginType == SqlServer.Management.Smo.LoginType.SqlLogin)
{ {
@@ -77,39 +100,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
} }
} }
prototype.ApplyGeneralChanges(dataContainer.Server);
// TODO move this to LoginData // TODO move this to LoginData
// TODO support role assignment for Azure // TODO support role assignment for Azure
LoginPrototype newPrototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]); prototype.ServerRoles.PopulateServerRoles();
var _ =newPrototype.ServerRoles.ServerRoleNames; foreach (string role in login.ServerRoles ?? Enumerable.Empty<string>())
foreach (string role in parameters.Login.ServerRoles ?? Enumerable.Empty<string>())
{ {
newPrototype.ServerRoles.SetMember(role, true); prototype.ServerRoles.SetMember(role, true);
} }
newPrototype.ApplyServerRoleChanges(dataContainer.Server); return ConfigureLogin(
await requestContext.SendResult(new object()); dataContainer,
ConfigAction.Create,
runType,
prototype);
} }
internal async Task HandleUpdateLoginRequest(UpdateLoginParams parameters, RequestContext<object> requestContext) internal async Task HandleUpdateLoginRequest(UpdateLoginParams parameters, RequestContext<object> requestContext)
{ {
DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow);
await requestContext.SendResult(new object());
}
private string DoHandleUpdateLoginRequest(
string contextId, LoginInfo login, RunType runType)
{
ViewState? viewState;
this.contextIdToViewState.TryGetValue(contextId, out viewState);
ConnectionInfo connInfo; ConnectionInfo connInfo;
string ownerUri; ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo);
contextIdToConnectionUriMap.TryGetValue(parameters.ContextId, out ownerUri);
ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
if (connInfo == null) if (connInfo == null)
{ {
throw new ArgumentException("Invalid ConnectionUri"); throw new ArgumentException("Invalid ConnectionUri");
} }
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true); CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]); LoginPrototype prototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[login.Name]);
var login = parameters.Login;
prototype.SqlPassword = login.Password; prototype.SqlPassword = login.Password;
if (0 != String.Compare(login.DefaultLanguage, SR.DefaultLanguagePlaceholder, StringComparison.Ordinal)) if (0 != string.Compare(login.DefaultLanguage, SR.DefaultLanguagePlaceholder, StringComparison.Ordinal))
{ {
string[] arr = login.DefaultLanguage?.Split(" - "); string[] arr = login.DefaultLanguage?.Split(" - ");
if (arr != null && arr.Length > 1) if (arr != null && arr.Length > 1)
@@ -155,15 +185,44 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
prototype.ServerRoles.SetMember(role, true); prototype.ServerRoles.SetMember(role, true);
} }
prototype.ApplyGeneralChanges(dataContainer.Server); return ConfigureLogin(
prototype.ApplyServerRoleChanges(dataContainer.Server); dataContainer,
prototype.ApplyDatabaseRoleChanges(dataContainer.Server); ConfigAction.Update,
await requestContext.SendResult(new object()); runType,
prototype);
}
/// <summary>
/// Handle request to script a user
/// </summary>
internal async Task HandleScriptLoginRequest(ScriptLoginParams parameters, RequestContext<string> requestContext)
{
if (parameters.ContextId == null)
{
throw new ArgumentException("Invalid context ID");
}
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
{
throw new ArgumentException("Invalid context ID view state");
}
string sqlScript = (viewState.IsNewObject)
? DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow)
: DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow);
await requestContext.SendResult(sqlScript);
} }
internal async Task HandleInitializeLoginViewRequest(InitializeLoginViewRequestParams parameters, RequestContext<LoginViewInfo> requestContext) internal async Task HandleInitializeLoginViewRequest(InitializeLoginViewRequestParams parameters, RequestContext<LoginViewInfo> requestContext)
{ {
contextIdToConnectionUriMap.Add(parameters.ContextId, parameters.ConnectionUri); this.contextIdToViewState.Add(
parameters.ContextId,
new ViewState(parameters.IsNewObject, parameters.ConnectionUri));
ConnectionInfo connInfo; ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo); ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo);
if (connInfo == null) if (connInfo == null)
@@ -192,7 +251,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
: new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Name]); : new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Name]);
List<string> loginServerRoles = new List<string>(); List<string> loginServerRoles = new List<string>();
foreach(string role in prototype.ServerRoles.ServerRoleNames) foreach (string role in prototype.ServerRoles.ServerRoleNames)
{ {
if (prototype.ServerRoles.IsMember(role)) if (prototype.ServerRoles.IsMember(role))
{ {
@@ -254,5 +313,61 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
{ {
await requestContext.SendResult(new object()); await requestContext.SendResult(new object());
} }
internal string ConfigureLogin(
CDataContainer dataContainer,
ConfigAction configAction,
RunType runType,
LoginPrototype prototype)
{
string sqlScript = string.Empty;
using (var actions = new LoginActions(dataContainer, configAction, prototype))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
if (executionHandler.ExecutionResult == ExecutionMode.Failure)
{
throw executionHandler.ExecutionFailureException;
}
if (runType == RunType.ScriptToWindow)
{
sqlScript = executionHandler.ScriptTextFromLastRun;
}
}
return sqlScript;
}
}
internal class LoginActions : ManagementActionBase
{
private ConfigAction configAction;
private LoginPrototype prototype;
/// <summary>
/// Handle login create and update actions
/// </summary>
public LoginActions(CDataContainer dataContainer, ConfigAction configAction, LoginPrototype prototype)
{
this.DataContainer = dataContainer;
this.configAction = configAction;
this.prototype = prototype;
}
/// <summary>
/// called by the management actions framework to execute the action
/// </summary>
/// <param name="node"></param>
public override void OnRunNow(object sender)
{
if (this.configAction != ConfigAction.Drop)
{
prototype.ApplyGeneralChanges(this.DataContainer.Server);
prototype.ApplyServerRoleChanges(this.DataContainer.Server);
prototype.ApplyDatabaseRoleChanges(this.DataContainer.Server);
}
}
} }
} }

View File

@@ -773,7 +773,7 @@ INNER JOIN sys.sql_logins AS sql_logins
/// <summary> /// <summary>
/// Populate the server roles map /// Populate the server roles map
/// </summary> /// </summary>
private void PopulateServerRoles() internal void PopulateServerRoles()
{ {
this.initialized = true; this.initialized = true;
serverRoles.Clear(); serverRoles.Clear();

View File

@@ -89,6 +89,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
this.ServiceHost.SetRequestHandler(CreateLoginRequest.Type, this.loginServiceHandler.HandleCreateLoginRequest, true); this.ServiceHost.SetRequestHandler(CreateLoginRequest.Type, this.loginServiceHandler.HandleCreateLoginRequest, true);
this.ServiceHost.SetRequestHandler(UpdateLoginRequest.Type, this.loginServiceHandler.HandleUpdateLoginRequest, true); this.ServiceHost.SetRequestHandler(UpdateLoginRequest.Type, this.loginServiceHandler.HandleUpdateLoginRequest, true);
this.ServiceHost.SetRequestHandler(InitializeLoginViewRequest.Type, this.loginServiceHandler.HandleInitializeLoginViewRequest, true); this.ServiceHost.SetRequestHandler(InitializeLoginViewRequest.Type, this.loginServiceHandler.HandleInitializeLoginViewRequest, true);
this.ServiceHost.SetRequestHandler(ScriptLoginRequest.Type, this.loginServiceHandler.HandleScriptLoginRequest, true);
this.ServiceHost.SetRequestHandler(DisposeLoginViewRequest.Type, this.loginServiceHandler.HandleDisposeLoginViewRequest, true); this.ServiceHost.SetRequestHandler(DisposeLoginViewRequest.Type, this.loginServiceHandler.HandleDisposeLoginViewRequest, true);
// User request handlers // User request handlers

View File

@@ -22,7 +22,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
{ {
internal class UserServiceHandlerImpl internal class UserServiceHandlerImpl
{ {
private class UserViewState private class ViewState
{ {
public bool IsNewObject { get; set; } public bool IsNewObject { get; set; }
@@ -30,7 +30,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
public UserPrototypeData OriginalUserData { get; set; } public UserPrototypeData OriginalUserData { get; set; }
public UserViewState(bool isNewObject, string database, UserPrototypeData originalUserData) public ViewState(bool isNewObject, string database, UserPrototypeData originalUserData)
{ {
this.IsNewObject = isNewObject; this.IsNewObject = isNewObject;
this.Database = database; this.Database = database;
@@ -40,7 +40,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
private ConnectionService? connectionService; private ConnectionService? connectionService;
private Dictionary<string, UserViewState> contextIdToViewState = new Dictionary<string, UserViewState>(); private Dictionary<string, ViewState> contextIdToViewState = new Dictionary<string, ViewState>();
/// <summary> /// <summary>
/// Internal for testing purposes only /// Internal for testing purposes only
@@ -232,7 +232,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
this.contextIdToViewState.Add( this.contextIdToViewState.Add(
parameters.ContextId, parameters.ContextId,
new UserViewState(parameters.IsNewObject, parameters.Database, currentUserPrototype.CurrentState)); new ViewState(parameters.IsNewObject, parameters.Database, currentUserPrototype.CurrentState));
await requestContext.SendResult(userViewInfo); await requestContext.SendResult(userViewInfo);
} }
@@ -247,7 +247,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
throw new ArgumentException("Invalid context ID"); throw new ArgumentException("Invalid context ID");
} }
UserViewState viewState; ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null) if (viewState == null)
@@ -281,7 +281,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
throw new ArgumentException("Invalid context ID"); throw new ArgumentException("Invalid context ID");
} }
UserViewState viewState; ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null) if (viewState == null)
@@ -305,7 +305,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
} }
/// <summary> /// <summary>
/// Handle request to update a user /// Handle request to script a user
/// </summary> /// </summary>
internal async Task HandleScriptUserRequest(ScriptUserParams parameters, RequestContext<string> requestContext) internal async Task HandleScriptUserRequest(ScriptUserParams parameters, RequestContext<string> requestContext)
{ {
@@ -314,7 +314,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
throw new ArgumentException("Invalid context ID"); throw new ArgumentException("Invalid context ID");
} }
UserViewState viewState; ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null) if (viewState == null)