diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/UserRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/UserRequest.cs index 2de3dc86..82c2a60c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/UserRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/UserRequest.cs @@ -109,4 +109,27 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts RequestType Type = RequestType.Create("objectManagement/disposeUserView"); } + + /// + /// Script User params + /// + public class ScriptUserParams + { + public string? ContextId { get; set; } + + public UserInfo? User { get; set; } + } + + /// + /// Script User request type + /// + public class ScriptUserRequest + { + /// + /// Request definition + /// + public static readonly + RequestType Type = + RequestType.Create("objectManagement/scriptUser"); + } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs index 671f4704..bebdc46f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs @@ -95,6 +95,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security this.ServiceHost.SetRequestHandler(InitializeUserViewRequest.Type, this.userServiceHandler.HandleInitializeUserViewRequest, true); this.ServiceHost.SetRequestHandler(CreateUserRequest.Type, this.userServiceHandler.HandleCreateUserRequest, true); this.ServiceHost.SetRequestHandler(UpdateUserRequest.Type, this.userServiceHandler.HandleUpdateUserRequest, true); + this.ServiceHost.SetRequestHandler(ScriptUserRequest.Type, this.userServiceHandler.HandleScriptUserRequest, true); this.ServiceHost.SetRequestHandler(DisposeUserViewRequest.Type, this.userServiceHandler.HandleDisposeUserViewRequest, true); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs index 5eb62935..def684a5 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs @@ -24,12 +24,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Security { private class UserViewState { + public bool IsNewObject { get; set; } + public string Database { get; set; } public UserPrototypeData OriginalUserData { get; set; } - public UserViewState(string database, UserPrototypeData originalUserData) + public UserViewState(bool isNewObject, string database, UserPrototypeData originalUserData) { + this.IsNewObject = isNewObject; this.Database = database; this.OriginalUserData = originalUserData; } @@ -229,7 +232,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security this.contextIdToViewState.Add( parameters.ContextId, - new UserViewState(parameters.Database, currentUserPrototype.CurrentState)); + new UserViewState(parameters.IsNewObject, parameters.Database, currentUserPrototype.CurrentState)); await requestContext.SendResult(userViewInfo); } @@ -252,7 +255,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security throw new ArgumentException("Invalid context ID view state"); } - Tuple result = ConfigureUser( + ConfigureUser( parameters.ContextId, parameters.User, ConfigAction.Create, @@ -263,8 +266,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Security await requestContext.SendResult(new CreateUserResult() { User = parameters.User, - Success = result.Item1, - ErrorMessage = result.Item2 + Success = true, + ErrorMessage = string.Empty }); } @@ -286,7 +289,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security throw new ArgumentException("Invalid context ID view state"); } - Tuple result = ConfigureUser( + ConfigureUser( parameters.ContextId, parameters.User, ConfigAction.Update, @@ -296,11 +299,42 @@ namespace Microsoft.SqlTools.ServiceLayer.Security await requestContext.SendResult(new ResultStatus() { - Success = result.Item1, - ErrorMessage = result.Item2 + Success = true, + ErrorMessage = string.Empty }); } + /// + /// Handle request to update a user + /// + internal async Task HandleScriptUserRequest(ScriptUserParams parameters, RequestContext requestContext) + { + if (parameters.ContextId == null) + { + throw new ArgumentException("Invalid context ID"); + } + + UserViewState viewState; + this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); + + if (viewState == null) + { + throw new ArgumentException("Invalid context ID view state"); + } + + // todo: check if it's an existing user + + string sqlScript = ConfigureUser( + parameters.ContextId, + parameters.User, + viewState.IsNewObject ? ConfigAction.Create : ConfigAction.Update, + RunType.ScriptToWindow, + viewState.Database, + viewState.OriginalUserData); + + await requestContext.SendResult(sqlScript); + } + internal async Task HandleDisposeUserViewRequest(DisposeUserViewRequestParams parameters, RequestContext requestContext) { this.ConnectionServiceInstance.Disconnect(new DisconnectParams() @@ -352,7 +386,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security return CDataContainer.CreateDataContainer(connectionInfoWithConnection, xmlDoc); } - internal Tuple ConfigureUser( + internal string ConfigureUser( string? ownerUri, UserInfo? user, ConfigAction configAction, @@ -367,6 +401,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security throw new ArgumentException("Invalid connection URI '{0}'", ownerUri); } + string sqlScript = string.Empty; CDataContainer dataContainer = CreateUserDataContainer(connInfo, user, configAction, databaseName); using (var actions = new UserActions(dataContainer, configAction, user, originalData)) { @@ -376,9 +411,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Security { throw executionHandler.ExecutionFailureException; } + + if (runType == RunType.ScriptToWindow) + { + sqlScript = executionHandler.ScriptTextFromLastRun; + } } - return new Tuple(true, string.Empty); + return sqlScript; } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/SecurityTestUtils.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/SecurityTestUtils.cs index 9d9b8a5a..8066599a 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/SecurityTestUtils.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/SecurityTestUtils.cs @@ -165,7 +165,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security DatabaseUserType userType, string userName = null, string loginName = null, - string databaseName = "master") + string databaseName = "master", + bool scriptUser = false) { string contextId = System.Guid.NewGuid().ToString(); var initializeViewRequestParams = new InitializeUserViewParams @@ -181,6 +182,25 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security .Returns(Task.FromResult(new UserViewInfo())); await service.HandleInitializeUserViewRequest(initializeViewRequestParams, initializeUserViewContext.Object); + + if (scriptUser) + { + var scriptParams = new ScriptUserParams + { + ContextId = contextId, + User = SecurityTestUtils.GetTestUserInfo(userType, userName, loginName) + }; + + var scriptUserContext = new Mock>(); + scriptUserContext.Setup(x => x.SendResult(It.IsAny())) + .Returns(Task.FromResult(new object())); + + await service.HandleScriptUserRequest(scriptParams, scriptUserContext.Object); + + // verify the result + scriptUserContext.Verify(x => x.SendResult(It.Is + (p => p.Contains("CREATE USER")))); + } var userParams = new CreateUserParams { diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/UserTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/UserTests.cs index 7a37884c..ed2c92ab 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/UserTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/UserTests.cs @@ -112,5 +112,29 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetLoginURN(login.Name)); } } + + /// + /// Test the basic Create User method handler + /// + [Test] + public async Task TestScriptUserWithLogin() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + // setup + UserServiceHandlerImpl userService = new UserServiceHandlerImpl(); + LoginServiceHandlerImpl loginService = new LoginServiceHandlerImpl(); + var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); + + var login = await SecurityTestUtils.CreateLogin(loginService, connectionResult); + + var user = await SecurityTestUtils.CreateUser(userService, connectionResult, + DatabaseUserType.WithLogin, null, login.Name, scriptUser: true); + + await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name)); + + await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetLoginURN(login.Name)); + } + } } }