From a074b5bf67e2f5b734d2f7571013eaf37b6046ab Mon Sep 17 00:00:00 2001 From: Karl Burtram Date: Mon, 6 Mar 2023 14:43:38 -0800 Subject: [PATCH] Update User edit handlers (#1903) * WIP * Update user tests * WIP updates * WIP * Fix several edit use bugs * Disable failing tests * minor updates * Remove unused using --- .../Security/Contracts/UserRequest.cs | 48 +- .../Security/SecurityService.cs | 500 +----------------- .../Security/UserActions.cs | 446 +++++++++++++++- .../Security/UserData.cs | 80 +-- .../Utility/DatabaseUtils.cs | 165 ++++++ .../Security/SecurityTestUtils.cs | 130 ++++- .../Security/UserTests.cs | 67 +-- 7 files changed, 863 insertions(+), 573 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/UserRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/UserRequest.cs index 45de2825..ba993f4c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/UserRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/Contracts/UserRequest.cs @@ -5,7 +5,6 @@ using Microsoft.SqlTools.Hosting.Protocol.Contracts; using Microsoft.SqlTools.ServiceLayer.Utility; -using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts { @@ -18,7 +17,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts public string? ConnectionUri { get; set; } - public bool isNewObject { get; set; } + public bool IsNewObject { get; set; } public string? Database { get; set; } @@ -41,7 +40,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts /// /// Create User parameters /// - public class CreateUserParams : GeneralRequestDetails + public class CreateUserParams { public string? ContextId { get; set; } public UserInfo? User { get; set; } @@ -68,6 +67,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts RequestType.Create("objectManagement/createUser"); } + /// + /// Update User parameters + /// + public class UpdateUserParams + { + public string? ContextId { get; set; } + public UserInfo? User { get; set; } + } + + /// + /// Update User request type + /// + public class UpdateUserRequest + { + /// + /// Request definition + /// + public static readonly + RequestType Type = + RequestType.Create("objectManagement/updateUser"); + } + /// /// Delete User params /// @@ -92,4 +113,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts RequestType Type = RequestType.Create("objectManagement/deleteUser"); } + + /// + /// Update User params + /// + public class DisposeUserViewRequestParams + { + public string? ContextId { get; set; } + } + + /// + /// Update User request type + /// + public class DisposeUserViewRequest + { + /// + /// Request definition + /// + public static readonly + RequestType Type = + RequestType.Create("objectManagement/disposeUserView"); + } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs index b827a858..ad3f9f82 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/SecurityService.cs @@ -9,10 +9,7 @@ using System.Collections.Generic; using System.Data; using System.Linq; using System.Threading.Tasks; -using System.Xml; using Microsoft.SqlServer.Management.Common; -using Microsoft.SqlServer.Management.Dmf; -using Microsoft.SqlServer.Management.Sdk.Sfc; using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.Connection; @@ -32,6 +29,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Security private ConnectionService? connectionService; + private UserServiceHandlerImpl userServiceHandler; + private static readonly Lazy instance = new Lazy(() => new SecurityService()); private Dictionary contextIdToConnectionUriMap = new Dictionary(); @@ -41,6 +40,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security /// public SecurityService() { + userServiceHandler = new UserServiceHandlerImpl(); } /// @@ -99,9 +99,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Security this.ServiceHost.SetRequestHandler(DisposeLoginViewRequest.Type, HandleDisposeLoginViewRequest, true); // User request handlers - this.ServiceHost.SetRequestHandler(InitializeUserViewRequest.Type, HandleInitializeUserViewRequest, true); - this.ServiceHost.SetRequestHandler(CreateUserRequest.Type, HandleCreateUserRequest, true); - this.ServiceHost.SetRequestHandler(DeleteUserRequest.Type, HandleDeleteUserRequest, true); + this.ServiceHost.SetRequestHandler(InitializeUserViewRequest.Type, this.userServiceHandler.HandleInitializeUserViewRequest, true); + this.ServiceHost.SetRequestHandler(CreateUserRequest.Type, this.userServiceHandler.HandleCreateUserRequest, true); + this.ServiceHost.SetRequestHandler(UpdateUserRequest.Type, this.userServiceHandler.HandleUpdateUserRequest, true); + this.ServiceHost.SetRequestHandler(DeleteUserRequest.Type, this.userServiceHandler.HandleDeleteUserRequest, true); + this.ServiceHost.SetRequestHandler(DisposeUserViewRequest.Type, this.userServiceHandler.HandleDisposeUserViewRequest, true); } @@ -150,7 +152,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security LoginPrototype newPrototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]); var _ =newPrototype.ServerRoles.ServerRoleNames; - foreach (string role in parameters.Login.ServerRoles) + foreach (string role in parameters.Login.ServerRoles ?? Enumerable.Empty()) { newPrototype.ServerRoles.SetMember(role, true); } @@ -175,7 +177,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security Login login = dataContainer.Server?.Logins[parameters.Name]; dataContainer.SqlDialogSubject = login; - DoDropObject(dataContainer); + DatabaseUtils.DoDropObject(dataContainer); await requestContext.SendResult(new ResultStatus() { @@ -224,7 +226,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security } // check that password and confirm password controls' text matches - if (0 != String.Compare(prototype.SqlPassword, prototype.SqlPasswordConfirm, StringComparison.Ordinal)) + if (0 != string.Compare(prototype.SqlPassword, prototype.SqlPasswordConfirm, StringComparison.Ordinal)) { // raise error here } @@ -347,237 +349,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Security if (l == null) return null; return string.Format("{0} - {1}", l.Language.Alias, l.Language.Name); } - #endregion - - #region "User Handlers" - - internal Task> ConfigureUser( - string? ownerUri, - UserInfo? user, - ConfigAction configAction, - RunType runType) - { - return Task>.Run(() => - { - try - { - ConnectionInfo connInfo; - ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo); - if (connInfo == null) - { - throw new ArgumentException("Invalid connection URI '{0}'", ownerUri); - } - - var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer"); - var connectionInfoWithConnection = new SqlConnectionInfoWithConnection(); - connectionInfoWithConnection.ServerConnection = serverConnection; - - string urn = string.Format(System.Globalization.CultureInfo.InvariantCulture, - "Server/Database[@Name='{0}']", - Urn.EscapeString(serverConnection.DatabaseName)); - - ActionContext context = new ActionContext(serverConnection, "new_user", urn); - DataContainerXmlGenerator containerXml = new DataContainerXmlGenerator(context); - containerXml.AddProperty("itemtype", "User"); - - XmlDocument xmlDoc = containerXml.GenerateXmlDocument(); - bool objectExists = configAction != ConfigAction.Create; - CDataContainer dataContainer = CDataContainer.CreateDataContainer(connectionInfoWithConnection, xmlDoc); - - using (var actions = new UserActions(dataContainer, user, configAction)) - { - var executionHandler = new ExecutonHandler(actions); - executionHandler.RunNow(runType, this); - } - - return new Tuple(true, string.Empty); - } - catch (Exception ex) - { - return new Tuple(false, ex.ToString()); - } - }); - } - - private Dictionary LoadSchemas(string databaseName, string dbroleName, ServerConnection serverConnection) - { - bool isPropertiesMode = false; - Dictionary schemaOwnership = new Dictionary(); - - Enumerator en = new Enumerator(); - Request req = new Request(); - req.Fields = new String[] { "Name", "Owner" }; - req.Urn = "Server/Database[@Name='" + Urn.EscapeString(databaseName) + "']/Schema"; - - DataTable dt = en.Process(serverConnection, req); - System.Diagnostics.Debug.Assert((dt != null) && (0 < dt.Rows.Count), "enumerator did not return schemas"); - System.Diagnostics.Debug.Assert(!isPropertiesMode || (dbroleName.Length != 0), "role name is not known"); - - foreach (DataRow dr in dt.Rows) - { - string schemaName = Convert.ToString(dr["Name"], System.Globalization.CultureInfo.InvariantCulture); - string schemaOwner = Convert.ToString(dr["Owner"], System.Globalization.CultureInfo.InvariantCulture); - bool roleOwnsSchema = isPropertiesMode && (0 == String.Compare(dbroleName, schemaOwner, StringComparison.Ordinal)); - if (schemaName != null) - { - schemaOwnership[schemaName] = new SchemaOwnership(roleOwnsSchema); - } - } - return schemaOwnership; - } - - private string[] LoadDatabaseRoles(ServerConnection serverConnection, string databaseName) - { - var server = new Server(serverConnection); - string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(databaseName) + "']"; - Database? parent = server.GetSmoObject(new Urn(dbUrn)) as Database; - var roles = new List(); - if (parent != null) - { - foreach (DatabaseRole dbRole in parent.Roles) - { - var comparer = parent.GetStringComparer(); - if (comparer.Compare(dbRole.Name, "public") != 0) - { - roles.Add(dbRole.Name); - } - } - } - return roles.ToArray(); - } - - private string[] LoadSqlLogins(ServerConnection serverConnection) - { - return LoadItems(serverConnection, "Server/Login"); - } - - private string[] LoadItems(ServerConnection serverConnection, string urn) - { - List items = new List(); - Request req = new Request(); - req.Urn = urn; - req.ResultType = ResultType.IDataReader; - req.Fields = new string[] { "Name" }; - - Enumerator en = new Enumerator(); - using (IDataReader reader = en.Process(serverConnection, req).Data as IDataReader) - { - if (reader != null) - { - string name; - while (reader.Read()) - { - // Get the permission name - name = reader.GetString(0); - items.Add(name); - } - } - } - return items.ToArray(); - } - - /// - /// Handle request to initialize user view - /// - internal async Task HandleInitializeUserViewRequest(InitializeUserViewParams parameters, RequestContext requestContext) - { - ConnectionInfo connInfo; - ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo); - if (connInfo == null) - { - throw new ArgumentException("Invalid connection URI '{0}'", parameters.ConnectionUri); - } - - if (parameters.ContextId != null && parameters.ConnectionUri != null) - { - this.contextIdToConnectionUriMap.Add(parameters.ContextId, parameters.ConnectionUri); - } - - var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer"); - - string databaseName = parameters.Database ?? "master"; - var schemaMap = LoadSchemas(databaseName, string.Empty, serverConnection); - - UserViewInfo userViewInfo = new UserViewInfo() - { - ObjectInfo = new UserInfo() - { - Type = DatabaseUserType.WithLogin, - Name = string.Empty, - LoginName = string.Empty, - Password = string.Empty, - DefaultSchema = string.Empty, - OwnedSchemas = new string[] { }, - DatabaseRoles = new string[] { }, - }, - SupportContainedUser = true, - SupportWindowsAuthentication = true, - SupportAADAuthentication = true, - SupportSQLAuthentication = true, - Languages = new string[] { }, - Schemas = schemaMap.Keys.ToArray(), - Logins = LoadSqlLogins(serverConnection), - DatabaseRoles = LoadDatabaseRoles(serverConnection, databaseName) - }; - - await requestContext.SendResult(userViewInfo); - } - - /// - /// Handle request to create a user - /// - internal async Task HandleCreateUserRequest(CreateUserParams parameters, RequestContext requestContext) - { - if (parameters.ContextId == null || !this.contextIdToConnectionUriMap.ContainsKey(parameters.ContextId)) - { - throw new ArgumentException("Invalid context ID"); - } - - string connectionUri = this.contextIdToConnectionUriMap[parameters.ContextId]; - var result = await ConfigureUser(connectionUri, - parameters.User, - ConfigAction.Create, - RunType.RunNow); - - await requestContext.SendResult(new CreateUserResult() - { - User = parameters.User, - Success = result.Item1, - ErrorMessage = result.Item2 - }); - } - - /// - /// Handle request to delete a user - /// - internal async Task HandleDeleteUserRequest(DeleteUserParams parameters, RequestContext requestContext) - { - ConnectionInfo connInfo; - ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo); - // if (connInfo == null) - // { - // // raise an error - // } - - CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true); - string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(parameters.Database) + "']"; - Database? parent = dataContainer.Server.GetSmoObject(new Urn(dbUrn)) as Database; - User user = parent.Users[parameters.Name]; - dataContainer.SqlDialogSubject = user; - DoDropObject(dataContainer); - - await requestContext.SendResult(new ResultStatus() - { - Success = true, - ErrorMessage = string.Empty - }); - } private IList GetDefaultLanguageOptions(CDataContainer dataContainer) { - // this.defaultLanguageComboBox.Items.Clear(); - // this.defaultLanguageComboBox.Items.Add(defaultLanguagePlaceholder); - // sort the languages alphabetically by alias SortedList sortedLanguages = new SortedList(Comparer.Default); @@ -600,97 +374,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Security return res; } - // code needs to be ported into the useraction class - // public void UserMemberships_OnRunNow(object sender, CDataContainer dataContainer) - // { - // UserPrototype currentPrototype = UserPrototypeFactory.GetInstance(dataContainer).CurrentPrototype; - - // //In case the UserGeneral/OwnedSchemas pages are loaded, - // //those will takes care of applying membership changes also. - // //Hence, we only need to apply changes in this method when those are not loaded. - // if (!currentPrototype.IsRoleMembershipChangesApplied) - // { - // //base.OnRunNow(sender); - - // User user = currentPrototype.ApplyChanges(); - - // //this.ExecutionMode = ExecutionMode.Success; - // dataContainer.ObjectName = currentPrototype.Name; - // dataContainer.SqlDialogSubject = user; - // } - - // //setting back to original after changes are applied - // currentPrototype.IsRoleMembershipChangesApplied = false; - // } - - // /// - // /// implementation of OnPanelRunNow - // /// - // /// - // public void UserOwnedSchemas_OnRunNow(object sender, CDataContainer dataContainer) - // { - // UserPrototype currentPrototype = UserPrototypeFactory.GetInstance(dataContainer).CurrentPrototype; - - // //In case the UserGeneral/Membership pages are loaded, - // //those will takes care of applying schema ownership changes also. - // //Hence, we only need to apply changes in this method when those are not loaded. - // if (!currentPrototype.IsSchemaOwnershipChangesApplied) - // { - // //base.OnRunNow(sender); - - // User user = currentPrototype.ApplyChanges(); - - // //this.ExecutionMode = ExecutionMode.Success; - // dataContainer.ObjectName = currentPrototype.Name; - // dataContainer.SqlDialogSubject = user; - // } - - // //setting back to original after changes are applied - // currentPrototype.IsSchemaOwnershipChangesApplied = false; - // } - - // how to populate defaults from prototype, will delete once refactored - // private void InitializeValuesInUiControls() - // { - - // IUserPrototypeWithDefaultLanguage defaultLanguagePrototype = this.currentUserPrototype - // as IUserPrototypeWithDefaultLanguage; - // if (defaultLanguagePrototype != null - // && defaultLanguagePrototype.IsDefaultLanguageSupported) - // { - // string defaultLanguageAlias = defaultLanguagePrototype.DefaultLanguageAlias; - - // //If engine returns default language as empty or null, that means the default language of - // //database will be used. - // //Default language is not applicable for users inside an uncontained authentication. - // if (string.IsNullOrEmpty(defaultLanguageAlias) - // && (this.DataContainer.Server.GetSmoObject(this.parentDbUrn) as Database).ContainmentType != ContainmentType.None) - // { - // defaultLanguageAlias = this.defaultLanguagePlaceholder; - // } - // this.defaultLanguageComboBox.Text = defaultLanguageAlias; - // } - - // IUserPrototypeWithDefaultSchema defaultSchemaPrototype = this.currentUserPrototype - // as IUserPrototypeWithDefaultSchema; - // if (defaultSchemaPrototype != null - // && defaultSchemaPrototype.IsDefaultSchemaSupported) - // { - // this.defaultSchemaTextBox.Text = defaultSchemaPrototype.DefaultSchema; - // } - // IUserPrototypeWithPassword userWithPwdPrototype = this.currentUserPrototype - // as IUserPrototypeWithPassword; - // if (userWithPwdPrototype != null - // && !this.DataContainer.IsNewObject) - // { - // this.passwordTextBox.Text = FAKE_PASSWORD; - // this.confirmPwdTextBox.Text = FAKE_PASSWORD; - // } - // } - - - - #endregion #region "Credential Handlers" @@ -833,167 +516,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Security }); } - /// - /// this is the main method that is called by DropAllObjects for every object - /// in the grid - /// - /// - private void DoDropObject(CDataContainer dataContainer) - { - // if a server isn't connected then there is nothing to do - if (dataContainer.Server == null) - { - return; - } - - var executionMode = dataContainer.Server.ConnectionContext.SqlExecutionModes; - var subjectExecutionMode = executionMode; - - //For Azure the ExecutionManager is different depending on which ExecutionManager - //used - one at the Server level and one at the Database level. So to ensure we - //don't use the wrong execution mode we need to set the mode for both (for on-prem - //this will essentially be a no-op) - SqlSmoObject sqlDialogSubject = null; - try - { - sqlDialogSubject = dataContainer.SqlDialogSubject; - } - catch (System.Exception) - { - //We may not have a valid dialog subject here (such as if the object hasn't been created yet) - //so in that case we'll just ignore it as that's a normal scenario. - } - if (sqlDialogSubject != null) - { - subjectExecutionMode = - sqlDialogSubject.ExecutionManager.ConnectionContext.SqlExecutionModes; - } - - Urn objUrn = sqlDialogSubject?.Urn; - System.Diagnostics.Debug.Assert(objUrn != null); - - SfcObjectQuery objectQuery = new SfcObjectQuery(dataContainer.Server); - - IDroppable droppableObj = null; - string[] fields = null; - - foreach (object obj in objectQuery.ExecuteIterator(new SfcQueryExpression(objUrn.ToString()), fields, null)) - { - System.Diagnostics.Debug.Assert(droppableObj == null, "there is only one object"); - droppableObj = obj as IDroppable; - } - - // For Azure databases, the SfcObjectQuery executions above may have overwritten our desired execution mode, so restore it - dataContainer.Server.ConnectionContext.SqlExecutionModes = executionMode; - if (sqlDialogSubject != null) - { - sqlDialogSubject.ExecutionManager.ConnectionContext.SqlExecutionModes = subjectExecutionMode; - } - - if (droppableObj == null) - { - string objectName = objUrn.GetAttribute("Name"); - objectName ??= string.Empty; - throw new Microsoft.SqlServer.Management.Smo.MissingObjectException("DropObjectsSR.ObjectDoesNotExist(objUrn.Type, objectName)"); - } - - //special case database drop - see if we need to delete backup and restore history - SpecialPreDropActionsForObject(dataContainer, droppableObj, - deleteBackupRestoreOrDisableAuditSpecOrDisableAudit: false, - dropOpenConnections: false); - - droppableObj.Drop(); - - //special case Resource Governor reconfigure - for pool, external pool, group Drop(), we need to issue - SpecialPostDropActionsForObject(dataContainer, droppableObj); - - } - - private void SpecialPreDropActionsForObject(CDataContainer dataContainer, IDroppable droppableObj, - bool deleteBackupRestoreOrDisableAuditSpecOrDisableAudit, bool dropOpenConnections) - { - // if a server isn't connected then there is nothing to do - if (dataContainer.Server == null) - { - return; - } - - Database db = droppableObj as Database; - - if (deleteBackupRestoreOrDisableAuditSpecOrDisableAudit) - { - if (db != null) - { - dataContainer.Server.DeleteBackupHistory(db.Name); - } - else - { - // else droppable object should be a server or database audit specification - ServerAuditSpecification sas = droppableObj as ServerAuditSpecification; - if (sas != null) - { - sas.Disable(); - } - else - { - DatabaseAuditSpecification das = droppableObj as DatabaseAuditSpecification; - if (das != null) - { - das.Disable(); - } - else - { - Audit aud = droppableObj as Audit; - if (aud != null) - { - aud.Disable(); - } - } - } - } - } - - // special case database drop - drop existing connections to the database other than this one - if (dropOpenConnections) - { - if (db?.ActiveConnections > 0) - { - // force the database to be single user - db.DatabaseOptions.UserAccess = DatabaseUserAccess.Single; - db.Alter(TerminationClause.RollbackTransactionsImmediately); - } - } - } - - private void SpecialPostDropActionsForObject(CDataContainer dataContainer, IDroppable droppableObj) - { - // if a server isn't connected then there is nothing to do - if (dataContainer.Server == null) - { - return; - } - - if (droppableObj is Policy) - { - Policy policyToDrop = (Policy)droppableObj; - if (!string.IsNullOrEmpty(policyToDrop.ObjectSet)) - { - ObjectSet objectSet = policyToDrop.Parent.ObjectSets[policyToDrop.ObjectSet]; - objectSet.Drop(); - } - } - - ResourcePool rp = droppableObj as ResourcePool; - ExternalResourcePool erp = droppableObj as ExternalResourcePool; - WorkloadGroup wg = droppableObj as WorkloadGroup; - - if (null != rp || null != erp || null != wg) - { - // Alter() Resource Governor to reconfigure - dataContainer.Server.ResourceGovernor.Alter(); - } - } - #endregion // "Helpers" // some potentially useful code for working with server & db roles to be refactored later diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs index acee754b..fbec0a65 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/UserActions.cs @@ -3,19 +3,441 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using System; +using System.Collections.Generic; +using System.Data; +using System.Threading.Tasks; +using System.Xml; +using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Sdk.Sfc; using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Security.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.Security { + internal class UserServiceHandlerImpl + { + private class UserViewState + { + public string Database { get; set; } + + public UserPrototypeData OriginalUserData { get; set; } + + public UserViewState(string database, UserPrototypeData originalUserData) + { + this.Database = database; + this.OriginalUserData = originalUserData; + } + } + + private ConnectionService? connectionService; + + private Dictionary contextIdToViewState = new Dictionary(); + + /// + /// Internal for testing purposes only + /// + internal ConnectionService ConnectionServiceInstance + { + get + { + connectionService ??= ConnectionService.Instance; + return connectionService; + } + + set + { + connectionService = value; + } + } + + /// + /// Handle request to initialize user view + /// + internal async Task HandleInitializeUserViewRequest(InitializeUserViewParams parameters, RequestContext requestContext) + { + if (string.IsNullOrWhiteSpace(parameters.Database)) + { + throw new ArgumentNullException("parameters.Database"); + } + + if (string.IsNullOrWhiteSpace(parameters.ContextId)) + { + throw new ArgumentNullException("parameters.ContextId"); + } + + ConnectionInfo originalConnInfo; + ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out originalConnInfo); + if (originalConnInfo == null) + { + throw new ArgumentException("Invalid connection URI '{0}'", parameters.ConnectionUri); + } + + string originalDatabaseName = originalConnInfo.ConnectionDetails.DatabaseName; + try + { + originalConnInfo.ConnectionDetails.DatabaseName = parameters.Database; + ConnectParams connectParams = new ConnectParams + { + OwnerUri = parameters.ContextId, + Connection = originalConnInfo.ConnectionDetails, + Type = Connection.ConnectionType.Default + }; + await this.ConnectionServiceInstance.Connect(connectParams); + } + finally + { + originalConnInfo.ConnectionDetails.DatabaseName = originalDatabaseName; + } + + ConnectionInfo connInfo; + this.ConnectionServiceInstance.TryFindConnection(parameters.ContextId, out connInfo); + CDataContainer dataContainer = CreateUserDataContainer(connInfo, null, ConfigAction.Create, parameters.Database); + + UserInfo? userInfo = null; + if (!parameters.IsNewObject) + { + User? existingUser = null; + string databaseUrn = string.Format(System.Globalization.CultureInfo.InvariantCulture, + "Server/Database[@Name='{0}']", + Urn.EscapeString(parameters.Database)); + Database? parentDb = dataContainer.Server.GetSmoObject(databaseUrn) as Database; + existingUser = dataContainer.Server.Databases[parentDb.Name].Users[parameters.Name]; + userInfo = new UserInfo() + { + Name = parameters.Name, + LoginName = existingUser.Login, + DefaultSchema = existingUser.DefaultSchema + }; + } + + UserPrototypeFactory userPrototypeFactory = UserPrototypeFactory.GetInstance(dataContainer, userInfo, originalData: null); + UserPrototype currentUserPrototype = userPrototypeFactory.GetUserPrototype(ExhaustiveUserTypes.LoginMappedUser); + + IUserPrototypeWithDefaultLanguage defaultLanguagePrototype = currentUserPrototype as IUserPrototypeWithDefaultLanguage; + string? defaultLanguageAlias = null; + if (defaultLanguagePrototype != null && defaultLanguagePrototype.IsDefaultLanguageSupported) + { + string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(parameters.Database) + "']"; + defaultLanguageAlias = defaultLanguagePrototype.DefaultLanguageAlias; + //If engine returns default language as empty or null, that means the default language of + //database will be used. + //Default language is not applicable for users inside an uncontained authentication. + if (string.IsNullOrEmpty(defaultLanguageAlias) + && (dataContainer.Server.GetSmoObject(dbUrn) as Database).ContainmentType != ContainmentType.None) + { + defaultLanguageAlias = SR.DefaultLanguagePlaceholder; + } + } + + string? defaultSchema = null; + IUserPrototypeWithDefaultSchema defaultSchemaPrototype = currentUserPrototype as IUserPrototypeWithDefaultSchema; + if (defaultSchemaPrototype != null && defaultSchemaPrototype.IsDefaultSchemaSupported) + { + defaultSchema = defaultSchemaPrototype.DefaultSchema; + } + // IUserPrototypeWithPassword userWithPwdPrototype = currentUserPrototype as IUserPrototypeWithPassword; + // if (userWithPwdPrototype != null && !this.DataContainer.IsNewObject) + // { + // this.passwordTextBox.Text = FAKE_PASSWORD; + // this.confirmPwdTextBox.Text = FAKE_PASSWORD; + // } + + string? loginName = null; + IUserPrototypeWithMappedLogin mappedLoginPrototype = currentUserPrototype as IUserPrototypeWithMappedLogin; + if (mappedLoginPrototype != null) + { + loginName = mappedLoginPrototype.LoginName; + } + + List databaseRoles = new List(); + foreach (string role in currentUserPrototype.DatabaseRoleNames) + { + if (currentUserPrototype.IsRoleMember(role)) + { + databaseRoles.Add(role); + } + } + + List schemaNames = new List(); + foreach (string schema in currentUserPrototype.SchemaNames) + { + if (currentUserPrototype.IsSchemaOwner(schema)) + { + schemaNames.Add(schema); + } + } + + // default to dbo schema, if there isn't already a default + if (string.IsNullOrWhiteSpace(defaultSchema) && currentUserPrototype.SchemaNames.Contains("dbo")) + { + defaultSchema = "dbo"; + } + + ServerConnection serverConnection = dataContainer.ServerConnection; + UserViewInfo userViewInfo = new UserViewInfo() + { + ObjectInfo = new UserInfo() + { + Type = DatabaseUserType.WithLogin, + Name = currentUserPrototype.Name, + LoginName = loginName, + Password = string.Empty, + DefaultSchema = defaultSchema, + OwnedSchemas = schemaNames.ToArray(), + DatabaseRoles = databaseRoles.ToArray(), + DefaultLanguage = defaultLanguageAlias + }, + SupportContainedUser = false, // support for these will be added later + SupportWindowsAuthentication = false, + SupportAADAuthentication = false, + SupportSQLAuthentication = true, + Languages = new string[] { }, + Schemas = currentUserPrototype.SchemaNames.ToArray(), + Logins = LoadSqlLogins(serverConnection), + DatabaseRoles = currentUserPrototype.DatabaseRoleNames.ToArray() + }; + + this.contextIdToViewState.Add( + parameters.ContextId, + new UserViewState(parameters.Database, currentUserPrototype.CurrentState)); + + await requestContext.SendResult(userViewInfo); + } + + /// + /// Handle request to create a user + /// + internal async Task HandleCreateUserRequest(CreateUserParams parameters, RequestContext requestContext) + { + if (parameters.ContextId == null) + { + throw new ArgumentException("Invalid context ID"); + } + + UserViewState viewState; + this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); + + if (viewState == null) + { + throw new ArgumentException("Invalid context ID view state"); + } + + Tuple result = ConfigureUser( + parameters.ContextId, + parameters.User, + ConfigAction.Create, + RunType.RunNow, + viewState.Database, + viewState.OriginalUserData); + + await requestContext.SendResult(new CreateUserResult() + { + User = parameters.User, + Success = result.Item1, + ErrorMessage = result.Item2 + }); + } + + /// + /// Handle request to update a user + /// + internal async Task HandleUpdateUserRequest(UpdateUserParams parameters, RequestContext requestContext) + { + if (parameters.ContextId == null) + { + throw new ArgumentException("Invalid context ID"); + } + + UserViewState viewState; + this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState); + + if (viewState == null) + { + throw new ArgumentException("Invalid context ID view state"); + } + + Tuple result = ConfigureUser( + parameters.ContextId, + parameters.User, + ConfigAction.Update, + RunType.RunNow, + viewState.Database, + viewState.OriginalUserData); + + await requestContext.SendResult(new ResultStatus() + { + Success = result.Item1, + ErrorMessage = result.Item2 + }); + } + + /// + /// Handle request to delete a user + /// + internal async Task HandleDeleteUserRequest(DeleteUserParams parameters, RequestContext requestContext) + { + ConnectionInfo connInfo; + ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo); + if (connInfo == null) + { + throw new ArgumentException("Invalid ConnectionUri"); + } + + if (string.IsNullOrWhiteSpace(parameters.Name) || string.IsNullOrWhiteSpace(parameters.Database)) + { + throw new ArgumentException("Invalid null parameter"); + } + + CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true); + string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(parameters.Database) + "']"; + Database? parent = dataContainer.Server.GetSmoObject(new Urn(dbUrn)) as Database; + User user = parent.Users[parameters.Name]; + dataContainer.SqlDialogSubject = user; + + CheckForSchemaOwnerships(parent, user); + DatabaseUtils.DoDropObject(dataContainer); + + await requestContext.SendResult(new ResultStatus() + { + Success = true, + ErrorMessage = string.Empty + }); + } + + internal async Task HandleDisposeUserViewRequest(DisposeUserViewRequestParams parameters, RequestContext requestContext) + { + this.ConnectionServiceInstance.Disconnect(new DisconnectParams(){ + OwnerUri = parameters.ContextId, + Type = null + }); + + if (parameters.ContextId != null) + { + this.contextIdToViewState.Remove(parameters.ContextId); + } + + await requestContext.SendResult(new ResultStatus() + { + Success = true, + ErrorMessage = string.Empty + }); + } + + internal CDataContainer CreateUserDataContainer( + ConnectionInfo connInfo, + UserInfo? user, + ConfigAction configAction, + string databaseName) + { + var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer"); + var connectionInfoWithConnection = new SqlConnectionInfoWithConnection(); + connectionInfoWithConnection.ServerConnection = serverConnection; + + string urn = (configAction == ConfigAction.Update && user != null) + ? string.Format(System.Globalization.CultureInfo.InvariantCulture, + "Server/Database[@Name='{0}']/User[@Name='{1}']", + Urn.EscapeString(databaseName), + Urn.EscapeString(user.Name)) + : string.Format(System.Globalization.CultureInfo.InvariantCulture, + "Server/Database[@Name='{0}']", + Urn.EscapeString(databaseName)); + + ActionContext context = new ActionContext(serverConnection, "User", urn); + DataContainerXmlGenerator containerXml = new DataContainerXmlGenerator(context); + + if (configAction == ConfigAction.Create) + { + containerXml.AddProperty("itemtype", "User"); + } + + XmlDocument xmlDoc = containerXml.GenerateXmlDocument(); + return CDataContainer.CreateDataContainer(connectionInfoWithConnection, xmlDoc); + } + + internal Tuple ConfigureUser( + string? ownerUri, + UserInfo? user, + ConfigAction configAction, + RunType runType, + string databaseName, + UserPrototypeData? originalData) + { + ConnectionInfo connInfo; + this.ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo); + if (connInfo == null) + { + throw new ArgumentException("Invalid connection URI '{0}'", ownerUri); + } + + CDataContainer dataContainer = CreateUserDataContainer(connInfo, user, configAction, databaseName); + using (var actions = new UserActions(dataContainer, user, configAction, originalData)) + { + var executionHandler = new ExecutonHandler(actions); + executionHandler.RunNow(runType, this); + if (executionHandler.ExecutionResult == ExecutionMode.Failure) + { + throw executionHandler.ExecutionFailureException; + } + } + + return new Tuple(true, string.Empty); + } + + private void CheckForSchemaOwnerships(Database parentDb, User existingUser) + { + foreach (Schema sch in parentDb.Schemas) + { + var comparer = parentDb.GetStringComparer(); + if (comparer.Compare(sch.Owner, existingUser.Name) == 0) + { + throw new ApplicationException("Cannot drop user since it owns a schema"); + } + } + } + + private string[] LoadSqlLogins(ServerConnection serverConnection) + { + return LoadItems(serverConnection, "Server/Login"); + } + + private string[] LoadItems(ServerConnection serverConnection, string urn) + { + List items = new List(); + Request req = new Request(); + req.Urn = urn; + req.ResultType = ResultType.IDataReader; + req.Fields = new string[] { "Name" }; + + Enumerator en = new Enumerator(); + using (IDataReader reader = en.Process(serverConnection, req).Data as IDataReader) + { + if (reader != null) + { + string name; + while (reader.Read()) + { + // Get the permission name + name = reader.GetString(0); + items.Add(name); + } + } + } + return items.ToArray(); + } + } + internal class UserActions : ManagementActionBase { #region Variables //private UserPrototypeData userData; private UserPrototype userPrototype; - private UserInfo user; + private UserInfo? user; private ConfigAction configAction; #endregion @@ -27,13 +449,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Security public UserActions( CDataContainer context, UserInfo? user, - ConfigAction configAction) + ConfigAction configAction, + UserPrototypeData? originalData) { this.DataContainer = context; this.user = user; this.configAction = configAction; - this.userPrototype = InitUserNew(context, user); + this.userPrototype = InitUserPrototype(context, user, originalData); } // /// @@ -65,14 +488,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Security } } - private UserPrototype InitUserNew(CDataContainer dataContainer, UserInfo user) + private UserPrototype InitUserPrototype(CDataContainer dataContainer, UserInfo user, UserPrototypeData? originalData) { ExhaustiveUserTypes currentUserType; - UserPrototypeFactory userPrototypeFactory = UserPrototypeFactory.GetInstance(dataContainer, user); + UserPrototypeFactory userPrototypeFactory = UserPrototypeFactory.GetInstance(dataContainer, user, originalData); if (dataContainer.IsNewObject) { - if (IsParentDatabaseContained(dataContainer.ParentUrn, dataContainer)) + if (dataContainer.Server != null && IsParentDatabaseContained(dataContainer.ParentUrn, dataContainer.Server)) { currentUserType = ExhaustiveUserTypes.SqlUserWithPassword; } @@ -91,8 +514,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Security return currentUserPrototype; } - private ExhaustiveUserTypes GetCurrentUserTypeForExistingUser(User user) + private ExhaustiveUserTypes GetCurrentUserTypeForExistingUser(User? user) { + if (user == null) + { + return ExhaustiveUserTypes.Unknown; + } + switch (user.UserType) { case UserType.SqlUser: @@ -124,10 +552,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Security } } - private bool IsParentDatabaseContained(Urn parentDbUrn, CDataContainer dataContainer) + private static bool IsParentDatabaseContained(Urn parentDbUrn, Server server) { string parentDbName = parentDbUrn.GetNameForType("Database"); - Database parentDatabase = dataContainer.Server.Databases[parentDbName]; + Database parentDatabase = server.Databases[parentDbName]; if (parentDatabase.IsSupportedProperty("ContainmentType") && parentDatabase.ContainmentType == ContainmentType.Partial) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Security/UserData.cs b/src/Microsoft.SqlTools.ServiceLayer/Security/UserData.cs index 28066a4c..f81500aa 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Security/UserData.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Security/UserData.cs @@ -103,7 +103,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security this.isMember = new Dictionary(); } - public UserPrototypeData(CDataContainer context, UserInfo userInfo) + public UserPrototypeData(CDataContainer context, UserInfo? userInfo) { this.isSchemaOwned = new Dictionary(); this.isMember = new Dictionary(); @@ -114,15 +114,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Security } else { - this.name = userInfo.Name; - this.mappedLoginName = userInfo.LoginName; - this.defaultSchemaName = userInfo.DefaultSchema; - this.password = DatabaseUtils.GetReadOnlySecureString(userInfo.Password); + if (userInfo != null) + { + this.name = userInfo.Name; + this.mappedLoginName = userInfo.LoginName; + this.defaultSchemaName = userInfo.DefaultSchema; + if (!string.IsNullOrEmpty(userInfo.Password)) + { + this.password = DatabaseUtils.GetReadOnlySecureString(userInfo.Password); + } + } } - this.LoadRoleMembership(context); + this.LoadRoleMembership(context, userInfo); - this.LoadSchemaData(context); + this.LoadSchemaData(context, userInfo); } public UserPrototypeData Clone() @@ -145,20 +151,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Security foreach (string key in this.isMember?.Keys ?? Enumerable.Empty()) { - if (result.isMember?.ContainsKey(key) == true - && this.isMember?.ContainsKey(key) == true) - { - result.isMember[key] = this.isMember[key]; - } + result.isMember[key] = this.isMember[key]; } foreach (string key in this.isSchemaOwned?.Keys ?? Enumerable.Empty()) { - if (result.isSchemaOwned?.ContainsKey(key) == true - && this.isSchemaOwned?.ContainsKey(key) == true) - { - result.isSchemaOwned[key] = this.isSchemaOwned[key]; - } + result.isSchemaOwned[key] = this.isSchemaOwned[key]; } return result; @@ -248,7 +246,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security /// Loads role membership of a database user. /// /// - private void LoadRoleMembership(CDataContainer context) + private void LoadRoleMembership(CDataContainer context, UserInfo? userInfo) { Urn objUrn = new Urn(context.ObjectUrn); Urn databaseUrn = objUrn.Parent; @@ -259,20 +257,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Security return; } - User existingUser = context.Server.Databases[parentDb.Name].Users[objUrn.GetNameForType("User")]; + string userName = userInfo?.Name ?? objUrn.GetNameForType("User"); + User existingUser = context.Server.Databases[parentDb.Name].Users[userName]; foreach (DatabaseRole dbRole in parentDb.Roles) { var comparer = parentDb.GetStringComparer(); if (comparer.Compare(dbRole.Name, "public") != 0) { - if (context.IsNewObject) + if (userInfo != null && userInfo.DatabaseRoles != null) { - this.isMember[dbRole.Name] = false; + this.isMember[dbRole.Name] = userInfo.DatabaseRoles.Contains(dbRole.Name); + } + else if (existingUser != null) + { + this.isMember[dbRole.Name] = existingUser.IsMember(dbRole.Name); } else { - this.isMember[dbRole.Name] = existingUser.IsMember(dbRole.Name); + this.isMember[dbRole.Name] = false; } } } @@ -282,7 +285,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security /// Loads schema ownership related data. /// /// - private void LoadSchemaData(CDataContainer context) + private void LoadSchemaData(CDataContainer context, UserInfo? userInfo) { Urn objUrn = new Urn(context.ObjectUrn); Urn databaseUrn = objUrn.Parent; @@ -293,7 +296,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Security return; } - User existingUser = context.Server.Databases[parentDb.Name].Users[objUrn.GetNameForType("User")]; + string userName = userInfo?.Name ?? objUrn.GetNameForType("User"); + User existingUser = context.Server.Databases[parentDb.Name].Users[userName]; if (!SqlMgmtUtils.IsYukonOrAbove(context.Server) || parentDb.CompatibilityLevel <= CompatibilityLevel.Version80) @@ -303,15 +307,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Security foreach (Schema sch in parentDb.Schemas) { - if (context.IsNewObject) + if (userInfo != null && userInfo.OwnedSchemas != null) { - this.isSchemaOwned[sch.Name] = false; + this.isSchemaOwned[sch.Name] = userInfo.OwnedSchemas.Contains(sch.Name); } - else + else if (existingUser != null) { var comparer = parentDb.GetStringComparer(); this.isSchemaOwned[sch.Name] = comparer.Compare(sch.Owner, existingUser.Name) == 0; } + else + { + this.isSchemaOwned[sch.Name] = false; + } } } } @@ -334,6 +342,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Security #region IUserPrototype Members + public UserPrototypeData CurrentState + { + get + { + return this.currentState; + } + } + public string Name { get @@ -996,15 +1012,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Security } } - private UserPrototypeFactory(CDataContainer context, UserInfo user) + private UserPrototypeFactory(CDataContainer context, UserInfo user, UserPrototypeData? originalData) { this.context = context; - this.originalData = new UserPrototypeData(this.context, user); - this.currentData = this.originalData.Clone(); + this.currentData = new UserPrototypeData(this.context, user); + this.originalData = originalData ?? this.currentData.Clone(); } - public static UserPrototypeFactory GetInstance(CDataContainer context, UserInfo user) + public static UserPrototypeFactory GetInstance(CDataContainer context, UserInfo? user, UserPrototypeData? originalData) { if (singletonInstance != null && singletonInstance.context != context) @@ -1012,7 +1028,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security singletonInstance = null; } - singletonInstance ??= new UserPrototypeFactory(context, user); + singletonInstance ??= new UserPrototypeFactory(context, user, originalData); return singletonInstance; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs index 7e6d2895..8367c4a2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs @@ -5,6 +5,10 @@ #nullable disable +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.Dmf; +using Microsoft.SqlServer.Management.Sdk.Sfc; +using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlTools.ServiceLayer.Management; using System; using System.Collections.Generic; @@ -80,5 +84,166 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility return ss; } + + /// + /// this is the main method that is called by DropAllObjects for every object + /// in the grid + /// + /// + public static void DoDropObject(CDataContainer dataContainer) + { + // if a server isn't connected then there is nothing to do + if (dataContainer.Server == null) + { + return; + } + + var executionMode = dataContainer.Server.ConnectionContext.SqlExecutionModes; + var subjectExecutionMode = executionMode; + + //For Azure the ExecutionManager is different depending on which ExecutionManager + //used - one at the Server level and one at the Database level. So to ensure we + //don't use the wrong execution mode we need to set the mode for both (for on-prem + //this will essentially be a no-op) + SqlSmoObject sqlDialogSubject = null; + try + { + sqlDialogSubject = dataContainer.SqlDialogSubject; + } + catch (System.Exception) + { + //We may not have a valid dialog subject here (such as if the object hasn't been created yet) + //so in that case we'll just ignore it as that's a normal scenario. + } + if (sqlDialogSubject != null) + { + subjectExecutionMode = + sqlDialogSubject.ExecutionManager.ConnectionContext.SqlExecutionModes; + } + + Urn objUrn = sqlDialogSubject?.Urn; + System.Diagnostics.Debug.Assert(objUrn != null); + + SfcObjectQuery objectQuery = new SfcObjectQuery(dataContainer.Server); + + IDroppable droppableObj = null; + string[] fields = null; + + foreach (object obj in objectQuery.ExecuteIterator(new SfcQueryExpression(objUrn.ToString()), fields, null)) + { + System.Diagnostics.Debug.Assert(droppableObj == null, "there is only one object"); + droppableObj = obj as IDroppable; + } + + // For Azure databases, the SfcObjectQuery executions above may have overwritten our desired execution mode, so restore it + dataContainer.Server.ConnectionContext.SqlExecutionModes = executionMode; + if (sqlDialogSubject != null) + { + sqlDialogSubject.ExecutionManager.ConnectionContext.SqlExecutionModes = subjectExecutionMode; + } + + if (droppableObj == null) + { + string objectName = objUrn.GetAttribute("Name"); + objectName ??= string.Empty; + throw new Microsoft.SqlServer.Management.Smo.MissingObjectException("DropObjectsSR.ObjectDoesNotExist(objUrn.Type, objectName)"); + } + + //special case database drop - see if we need to delete backup and restore history + SpecialPreDropActionsForObject(dataContainer, droppableObj, + deleteBackupRestoreOrDisableAuditSpecOrDisableAudit: false, + dropOpenConnections: false); + + droppableObj.Drop(); + + //special case Resource Governor reconfigure - for pool, external pool, group Drop(), we need to issue + SpecialPostDropActionsForObject(dataContainer, droppableObj); + + } + + private static void SpecialPreDropActionsForObject(CDataContainer dataContainer, IDroppable droppableObj, + bool deleteBackupRestoreOrDisableAuditSpecOrDisableAudit, bool dropOpenConnections) + { + // if a server isn't connected then there is nothing to do + if (dataContainer.Server == null) + { + return; + } + + Database db = droppableObj as Database; + + if (deleteBackupRestoreOrDisableAuditSpecOrDisableAudit) + { + if (db != null) + { + dataContainer.Server.DeleteBackupHistory(db.Name); + } + else + { + // else droppable object should be a server or database audit specification + ServerAuditSpecification sas = droppableObj as ServerAuditSpecification; + if (sas != null) + { + sas.Disable(); + } + else + { + DatabaseAuditSpecification das = droppableObj as DatabaseAuditSpecification; + if (das != null) + { + das.Disable(); + } + else + { + Audit aud = droppableObj as Audit; + if (aud != null) + { + aud.Disable(); + } + } + } + } + } + + // special case database drop - drop existing connections to the database other than this one + if (dropOpenConnections) + { + if (db?.ActiveConnections > 0) + { + // force the database to be single user + db.DatabaseOptions.UserAccess = DatabaseUserAccess.Single; + db.Alter(TerminationClause.RollbackTransactionsImmediately); + } + } + } + + private static void SpecialPostDropActionsForObject(CDataContainer dataContainer, IDroppable droppableObj) + { + // if a server isn't connected then there is nothing to do + if (dataContainer.Server == null) + { + return; + } + + if (droppableObj is Policy) + { + Policy policyToDrop = (Policy)droppableObj; + if (!string.IsNullOrEmpty(policyToDrop.ObjectSet)) + { + ObjectSet objectSet = policyToDrop.Parent.ObjectSets[policyToDrop.ObjectSet]; + objectSet.Drop(); + } + } + + ResourcePool rp = droppableObj as ResourcePool; + ExternalResourcePool erp = droppableObj as ExternalResourcePool; + WorkloadGroup wg = droppableObj as WorkloadGroup; + + if (null != rp || null != erp || null != wg) + { + // Alter() Resource Governor to reconfigure + dataContainer.Server.ResourceGovernor.Alter(); + } + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/SecurityTestUtils.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/SecurityTestUtils.cs index e9fca086..b3d7eb16 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/SecurityTestUtils.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/SecurityTestUtils.cs @@ -39,7 +39,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security EnforcePasswordExpiration = false, Password = "placeholder", OldPassword = "placeholder", - DefaultLanguage = "us_english", + DefaultLanguage = "English - us_english", DefaultDatabase = "master" }; } @@ -53,7 +53,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security LoginName = loginName, Password = "placeholder", DefaultSchema = "dbo", - OwnedSchemas = new string[] { "dbo" } + OwnedSchemas = new string[] { "" } }; } @@ -124,5 +124,131 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security var service = new SecurityService(); await SecurityTestUtils.DeleteCredential(service, connectionResult, credential); } + + internal static async Task CreateLogin(SecurityService service, TestConnectionResult connectionResult, string contextId) + { + var initializeLoginViewRequestParams = new InitializeLoginViewRequestParams + { + ConnectionUri = connectionResult.ConnectionInfo.OwnerUri, + ContextId = contextId, + IsNewObject = true + }; + + var loginParams = new CreateLoginParams + { + ContextId = contextId, + Login = SecurityTestUtils.GetTestLoginInfo() + }; + + var createLoginContext = new Mock>(); + createLoginContext.Setup(x => x.SendResult(It.IsAny())) + .Returns(Task.FromResult(new object())); + var initializeLoginViewContext = new Mock>(); + initializeLoginViewContext.Setup(x => x.SendResult(It.IsAny())) + .Returns(Task.FromResult(new LoginViewInfo())); + + // call the create login method + await service.HandleInitializeLoginViewRequest(initializeLoginViewRequestParams, initializeLoginViewContext.Object); + await service.HandleCreateLoginRequest(loginParams, createLoginContext.Object); + + return loginParams.Login; + } + + internal static async Task DeleteLogin(SecurityService service, TestConnectionResult connectionResult, LoginInfo login) + { + // cleanup created login + var deleteParams = new DeleteLoginParams + { + ConnectionUri = connectionResult.ConnectionInfo.OwnerUri, + Name = login.Name + }; + + var deleteContext = new Mock>(); + deleteContext.Setup(x => x.SendResult(It.IsAny())) + .Returns(Task.FromResult(new object())); + + // call the create login method + await service.HandleDeleteLoginRequest(deleteParams, deleteContext.Object); + } + + internal static async Task CreateUser( + UserServiceHandlerImpl service, + TestConnectionResult connectionResult, + string contextId, + LoginInfo login) + { + var initializeViewRequestParams = new InitializeUserViewParams + { + ConnectionUri = connectionResult.ConnectionInfo.OwnerUri, + ContextId = contextId, + IsNewObject = true, + Database = "master" + }; + + var initializeUserViewContext = new Mock>(); + initializeUserViewContext.Setup(x => x.SendResult(It.IsAny())) + .Returns(Task.FromResult(new LoginViewInfo())); + + await service.HandleInitializeUserViewRequest(initializeViewRequestParams, initializeUserViewContext.Object); + + var userParams = new CreateUserParams + { + ContextId = contextId, + User = SecurityTestUtils.GetTestUserInfo(login.Name) + }; + + var createUserContext = new Mock>(); + createUserContext.Setup(x => x.SendResult(It.IsAny())) + .Returns(Task.FromResult(new object())); + + // call the create login method + await service.HandleCreateUserRequest(userParams, createUserContext.Object); + + // verify the result + createUserContext.Verify(x => x.SendResult(It.Is + (p => p.Success && p.User.Name != string.Empty))); + + return userParams.User; + } + + internal static async Task UpdateUser( + UserServiceHandlerImpl service, + TestConnectionResult connectionResult, + string contextId, + UserInfo user) + { + // update the user + user.OwnedSchemas = new string[] { "dbo" }; + var updateParams = new UpdateUserParams + { + ContextId = contextId, + User = user + }; + var updateUserContext = new Mock>(); + // call the create login method + await service.HandleUpdateUserRequest(updateParams, updateUserContext.Object); + // verify the result + updateUserContext.Verify(x => x.SendResult(It.Is(p => p.Success))); + } + + internal static async Task DeleteUser(UserServiceHandlerImpl service, TestConnectionResult connectionResult, UserInfo user) + { + // cleanup created user + var deleteParams = new DeleteUserParams + { + ConnectionUri = connectionResult.ConnectionInfo.OwnerUri, + Name = user.Name, + Database = connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName + }; + + var deleteContext = new Mock>(); + deleteContext.Setup(x => x.SendResult(It.IsAny())) + .Returns(Task.FromResult(new object())); + + // call the create user method + await service.HandleDeleteUserRequest(deleteParams, deleteContext.Object); + + deleteContext.Verify(x => x.SendResult(It.Is(p => p.Success))); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/UserTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/UserTests.cs index 817bdf90..bc55643c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/UserTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Security/UserTests.cs @@ -4,12 +4,9 @@ // using System.Threading.Tasks; -using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; using Microsoft.SqlTools.ServiceLayer.Security; -using Microsoft.SqlTools.ServiceLayer.Security.Contracts; using Microsoft.SqlTools.ServiceLayer.Test.Common; -using Moq; namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security { @@ -21,56 +18,50 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security /// /// Test the basic Create User method handler /// - // [Test] - public async Task TestHandleCreateUserRequest() + //[Test] - enable tests in separate change + public async Task TestHandleCreateUserWithLoginRequest() { using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) { // setup + SecurityService service = new SecurityService(); + UserServiceHandlerImpl userService = new UserServiceHandlerImpl(); var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); var contextId = System.Guid.NewGuid().ToString(); - var initializeLoginViewRequestParams = new InitializeLoginViewRequestParams - { - ConnectionUri = connectionResult.ConnectionInfo.OwnerUri, - ContextId = contextId, - IsNewObject = true - }; + var login = await SecurityTestUtils.CreateLogin(service, connectionResult, contextId); - var loginParams = new CreateLoginParams - { - ContextId = contextId, - Login = SecurityTestUtils.GetTestLoginInfo() - }; + var user = await SecurityTestUtils.CreateUser(userService, connectionResult, contextId, login); - var createLoginContext = new Mock>(); - createLoginContext.Setup(x => x.SendResult(It.IsAny())) - .Returns(Task.FromResult(new object())); - var initializeLoginViewContext = new Mock>(); - initializeLoginViewContext.Setup(x => x.SendResult(It.IsAny())) - .Returns(Task.FromResult(new LoginViewInfo())); + await SecurityTestUtils.DeleteUser(userService, connectionResult, user); - // call the create login method + await SecurityTestUtils.DeleteLogin(service, connectionResult, login); + } + } + + /// + /// Test the basic Update User method handler + /// + //[Test] - enable tests in separate change + public async Task TestHandleUpdateUserWithLoginRequest() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + // setup SecurityService service = new SecurityService(); - await service.HandleInitializeLoginViewRequest(initializeLoginViewRequestParams, initializeLoginViewContext.Object); - await service.HandleCreateLoginRequest(loginParams, createLoginContext.Object); + UserServiceHandlerImpl userService = new UserServiceHandlerImpl(); + var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); + var contextId = System.Guid.NewGuid().ToString(); - var userParams = new CreateUserParams - { - ContextId = connectionResult.ConnectionInfo.OwnerUri, - User = SecurityTestUtils.GetTestUserInfo(loginParams.Login.Name) - }; + var login = await SecurityTestUtils.CreateLogin(service, connectionResult, contextId); - var createUserContext = new Mock>(); - createUserContext.Setup(x => x.SendResult(It.IsAny())) - .Returns(Task.FromResult(new object())); + var user = await SecurityTestUtils.CreateUser(userService, connectionResult, contextId, login); - // call the create login method - await service.HandleCreateUserRequest(userParams, createUserContext.Object); + await SecurityTestUtils.UpdateUser(userService, connectionResult, contextId, user); - // verify the result - createUserContext.Verify(x => x.SendResult(It.Is - (p => p.Success && p.User.Name != string.Empty))); + await SecurityTestUtils.DeleteUser(userService, connectionResult, user); + + await SecurityTestUtils.DeleteLogin(service, connectionResult, login); } } }