Update User edit handlers (#1903)

* WIP

* Update user tests

* WIP updates

* WIP

* Fix several edit use bugs

* Disable failing tests

* minor updates

* Remove unused using
This commit is contained in:
Karl Burtram
2023-03-06 14:43:38 -08:00
committed by GitHub
parent 781fc35a87
commit a074b5bf67
7 changed files with 863 additions and 573 deletions

View File

@@ -5,7 +5,6 @@
using Microsoft.SqlTools.Hosting.Protocol.Contracts; using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
{ {
@@ -18,7 +17,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
public string? ConnectionUri { get; set; } public string? ConnectionUri { get; set; }
public bool isNewObject { get; set; } public bool IsNewObject { get; set; }
public string? Database { get; set; } public string? Database { get; set; }
@@ -41,7 +40,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
/// <summary> /// <summary>
/// Create User parameters /// Create User parameters
/// </summary> /// </summary>
public class CreateUserParams : GeneralRequestDetails public class CreateUserParams
{ {
public string? ContextId { get; set; } public string? ContextId { get; set; }
public UserInfo? User { get; set; } public UserInfo? User { get; set; }
@@ -68,6 +67,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
RequestType<CreateUserParams, CreateUserResult>.Create("objectManagement/createUser"); RequestType<CreateUserParams, CreateUserResult>.Create("objectManagement/createUser");
} }
/// <summary>
/// Update User parameters
/// </summary>
public class UpdateUserParams
{
public string? ContextId { get; set; }
public UserInfo? User { get; set; }
}
/// <summary>
/// Update User request type
/// </summary>
public class UpdateUserRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<UpdateUserParams, ResultStatus> Type =
RequestType<UpdateUserParams, ResultStatus>.Create("objectManagement/updateUser");
}
/// <summary> /// <summary>
/// Delete User params /// Delete User params
/// </summary> /// </summary>
@@ -92,4 +113,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
RequestType<DeleteUserParams, ResultStatus> Type = RequestType<DeleteUserParams, ResultStatus> Type =
RequestType<DeleteUserParams, ResultStatus>.Create("objectManagement/deleteUser"); RequestType<DeleteUserParams, ResultStatus>.Create("objectManagement/deleteUser");
} }
/// <summary>
/// Update User params
/// </summary>
public class DisposeUserViewRequestParams
{
public string? ContextId { get; set; }
}
/// <summary>
/// Update User request type
/// </summary>
public class DisposeUserViewRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<DisposeUserViewRequestParams, ResultStatus> Type =
RequestType<DisposeUserViewRequestParams, ResultStatus>.Create("objectManagement/disposeUserView");
}
} }

View File

@@ -9,10 +9,7 @@ using System.Collections.Generic;
using System.Data; using System.Data;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Xml;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Dmf;
using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection;
@@ -32,6 +29,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
private ConnectionService? connectionService; private ConnectionService? connectionService;
private UserServiceHandlerImpl userServiceHandler;
private static readonly Lazy<SecurityService> instance = new Lazy<SecurityService>(() => new SecurityService()); private static readonly Lazy<SecurityService> instance = new Lazy<SecurityService>(() => new SecurityService());
private Dictionary<string, string> contextIdToConnectionUriMap = new Dictionary<string, string>(); private Dictionary<string, string> contextIdToConnectionUriMap = new Dictionary<string, string>();
@@ -41,6 +40,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
/// </summary> /// </summary>
public SecurityService() public SecurityService()
{ {
userServiceHandler = new UserServiceHandlerImpl();
} }
/// <summary> /// <summary>
@@ -99,9 +99,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
this.ServiceHost.SetRequestHandler(DisposeLoginViewRequest.Type, HandleDisposeLoginViewRequest, true); this.ServiceHost.SetRequestHandler(DisposeLoginViewRequest.Type, HandleDisposeLoginViewRequest, true);
// User request handlers // User request handlers
this.ServiceHost.SetRequestHandler(InitializeUserViewRequest.Type, HandleInitializeUserViewRequest, true); this.ServiceHost.SetRequestHandler(InitializeUserViewRequest.Type, this.userServiceHandler.HandleInitializeUserViewRequest, true);
this.ServiceHost.SetRequestHandler(CreateUserRequest.Type, HandleCreateUserRequest, true); this.ServiceHost.SetRequestHandler(CreateUserRequest.Type, this.userServiceHandler.HandleCreateUserRequest, true);
this.ServiceHost.SetRequestHandler(DeleteUserRequest.Type, HandleDeleteUserRequest, true); this.ServiceHost.SetRequestHandler(UpdateUserRequest.Type, this.userServiceHandler.HandleUpdateUserRequest, true);
this.ServiceHost.SetRequestHandler(DeleteUserRequest.Type, this.userServiceHandler.HandleDeleteUserRequest, true);
this.ServiceHost.SetRequestHandler(DisposeUserViewRequest.Type, this.userServiceHandler.HandleDisposeUserViewRequest, true);
} }
@@ -150,7 +152,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
LoginPrototype newPrototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]); LoginPrototype newPrototype = new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Login.Name]);
var _ =newPrototype.ServerRoles.ServerRoleNames; var _ =newPrototype.ServerRoles.ServerRoleNames;
foreach (string role in parameters.Login.ServerRoles) foreach (string role in parameters.Login.ServerRoles ?? Enumerable.Empty<string>())
{ {
newPrototype.ServerRoles.SetMember(role, true); newPrototype.ServerRoles.SetMember(role, true);
} }
@@ -175,7 +177,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
Login login = dataContainer.Server?.Logins[parameters.Name]; Login login = dataContainer.Server?.Logins[parameters.Name];
dataContainer.SqlDialogSubject = login; dataContainer.SqlDialogSubject = login;
DoDropObject(dataContainer); DatabaseUtils.DoDropObject(dataContainer);
await requestContext.SendResult(new ResultStatus() await requestContext.SendResult(new ResultStatus()
{ {
@@ -224,7 +226,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
} }
// check that password and confirm password controls' text matches // check that password and confirm password controls' text matches
if (0 != String.Compare(prototype.SqlPassword, prototype.SqlPasswordConfirm, StringComparison.Ordinal)) if (0 != string.Compare(prototype.SqlPassword, prototype.SqlPasswordConfirm, StringComparison.Ordinal))
{ {
// raise error here // raise error here
} }
@@ -347,237 +349,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
if (l == null) return null; if (l == null) return null;
return string.Format("{0} - {1}", l.Language.Alias, l.Language.Name); return string.Format("{0} - {1}", l.Language.Alias, l.Language.Name);
} }
#endregion
#region "User Handlers"
internal Task<Tuple<bool, string>> ConfigureUser(
string? ownerUri,
UserInfo? user,
ConfigAction configAction,
RunType runType)
{
return Task<Tuple<bool, string>>.Run(() =>
{
try
{
ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
if (connInfo == null)
{
throw new ArgumentException("Invalid connection URI '{0}'", ownerUri);
}
var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer");
var connectionInfoWithConnection = new SqlConnectionInfoWithConnection();
connectionInfoWithConnection.ServerConnection = serverConnection;
string urn = string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']",
Urn.EscapeString(serverConnection.DatabaseName));
ActionContext context = new ActionContext(serverConnection, "new_user", urn);
DataContainerXmlGenerator containerXml = new DataContainerXmlGenerator(context);
containerXml.AddProperty("itemtype", "User");
XmlDocument xmlDoc = containerXml.GenerateXmlDocument();
bool objectExists = configAction != ConfigAction.Create;
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connectionInfoWithConnection, xmlDoc);
using (var actions = new UserActions(dataContainer, user, configAction))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
}
return new Tuple<bool, string>(true, string.Empty);
}
catch (Exception ex)
{
return new Tuple<bool, string>(false, ex.ToString());
}
});
}
private Dictionary<string, SchemaOwnership> LoadSchemas(string databaseName, string dbroleName, ServerConnection serverConnection)
{
bool isPropertiesMode = false;
Dictionary<string, SchemaOwnership> schemaOwnership = new Dictionary<string, SchemaOwnership>();
Enumerator en = new Enumerator();
Request req = new Request();
req.Fields = new String[] { "Name", "Owner" };
req.Urn = "Server/Database[@Name='" + Urn.EscapeString(databaseName) + "']/Schema";
DataTable dt = en.Process(serverConnection, req);
System.Diagnostics.Debug.Assert((dt != null) && (0 < dt.Rows.Count), "enumerator did not return schemas");
System.Diagnostics.Debug.Assert(!isPropertiesMode || (dbroleName.Length != 0), "role name is not known");
foreach (DataRow dr in dt.Rows)
{
string schemaName = Convert.ToString(dr["Name"], System.Globalization.CultureInfo.InvariantCulture);
string schemaOwner = Convert.ToString(dr["Owner"], System.Globalization.CultureInfo.InvariantCulture);
bool roleOwnsSchema = isPropertiesMode && (0 == String.Compare(dbroleName, schemaOwner, StringComparison.Ordinal));
if (schemaName != null)
{
schemaOwnership[schemaName] = new SchemaOwnership(roleOwnsSchema);
}
}
return schemaOwnership;
}
private string[] LoadDatabaseRoles(ServerConnection serverConnection, string databaseName)
{
var server = new Server(serverConnection);
string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(databaseName) + "']";
Database? parent = server.GetSmoObject(new Urn(dbUrn)) as Database;
var roles = new List<string>();
if (parent != null)
{
foreach (DatabaseRole dbRole in parent.Roles)
{
var comparer = parent.GetStringComparer();
if (comparer.Compare(dbRole.Name, "public") != 0)
{
roles.Add(dbRole.Name);
}
}
}
return roles.ToArray();
}
private string[] LoadSqlLogins(ServerConnection serverConnection)
{
return LoadItems(serverConnection, "Server/Login");
}
private string[] LoadItems(ServerConnection serverConnection, string urn)
{
List<string> items = new List<string>();
Request req = new Request();
req.Urn = urn;
req.ResultType = ResultType.IDataReader;
req.Fields = new string[] { "Name" };
Enumerator en = new Enumerator();
using (IDataReader reader = en.Process(serverConnection, req).Data as IDataReader)
{
if (reader != null)
{
string name;
while (reader.Read())
{
// Get the permission name
name = reader.GetString(0);
items.Add(name);
}
}
}
return items.ToArray();
}
/// <summary>
/// Handle request to initialize user view
/// </summary>
internal async Task HandleInitializeUserViewRequest(InitializeUserViewParams parameters, RequestContext<UserViewInfo> requestContext)
{
ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo);
if (connInfo == null)
{
throw new ArgumentException("Invalid connection URI '{0}'", parameters.ConnectionUri);
}
if (parameters.ContextId != null && parameters.ConnectionUri != null)
{
this.contextIdToConnectionUriMap.Add(parameters.ContextId, parameters.ConnectionUri);
}
var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer");
string databaseName = parameters.Database ?? "master";
var schemaMap = LoadSchemas(databaseName, string.Empty, serverConnection);
UserViewInfo userViewInfo = new UserViewInfo()
{
ObjectInfo = new UserInfo()
{
Type = DatabaseUserType.WithLogin,
Name = string.Empty,
LoginName = string.Empty,
Password = string.Empty,
DefaultSchema = string.Empty,
OwnedSchemas = new string[] { },
DatabaseRoles = new string[] { },
},
SupportContainedUser = true,
SupportWindowsAuthentication = true,
SupportAADAuthentication = true,
SupportSQLAuthentication = true,
Languages = new string[] { },
Schemas = schemaMap.Keys.ToArray(),
Logins = LoadSqlLogins(serverConnection),
DatabaseRoles = LoadDatabaseRoles(serverConnection, databaseName)
};
await requestContext.SendResult(userViewInfo);
}
/// <summary>
/// Handle request to create a user
/// </summary>
internal async Task HandleCreateUserRequest(CreateUserParams parameters, RequestContext<CreateUserResult> requestContext)
{
if (parameters.ContextId == null || !this.contextIdToConnectionUriMap.ContainsKey(parameters.ContextId))
{
throw new ArgumentException("Invalid context ID");
}
string connectionUri = this.contextIdToConnectionUriMap[parameters.ContextId];
var result = await ConfigureUser(connectionUri,
parameters.User,
ConfigAction.Create,
RunType.RunNow);
await requestContext.SendResult(new CreateUserResult()
{
User = parameters.User,
Success = result.Item1,
ErrorMessage = result.Item2
});
}
/// <summary>
/// Handle request to delete a user
/// </summary>
internal async Task HandleDeleteUserRequest(DeleteUserParams parameters, RequestContext<ResultStatus> requestContext)
{
ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo);
// if (connInfo == null)
// {
// // raise an error
// }
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(parameters.Database) + "']";
Database? parent = dataContainer.Server.GetSmoObject(new Urn(dbUrn)) as Database;
User user = parent.Users[parameters.Name];
dataContainer.SqlDialogSubject = user;
DoDropObject(dataContainer);
await requestContext.SendResult(new ResultStatus()
{
Success = true,
ErrorMessage = string.Empty
});
}
private IList<LanguageDisplay> GetDefaultLanguageOptions(CDataContainer dataContainer) private IList<LanguageDisplay> GetDefaultLanguageOptions(CDataContainer dataContainer)
{ {
// this.defaultLanguageComboBox.Items.Clear();
// this.defaultLanguageComboBox.Items.Add(defaultLanguagePlaceholder);
// sort the languages alphabetically by alias // sort the languages alphabetically by alias
SortedList sortedLanguages = new SortedList(Comparer.Default); SortedList sortedLanguages = new SortedList(Comparer.Default);
@@ -600,97 +374,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
return res; return res;
} }
// code needs to be ported into the useraction class
// public void UserMemberships_OnRunNow(object sender, CDataContainer dataContainer)
// {
// UserPrototype currentPrototype = UserPrototypeFactory.GetInstance(dataContainer).CurrentPrototype;
// //In case the UserGeneral/OwnedSchemas pages are loaded,
// //those will takes care of applying membership changes also.
// //Hence, we only need to apply changes in this method when those are not loaded.
// if (!currentPrototype.IsRoleMembershipChangesApplied)
// {
// //base.OnRunNow(sender);
// User user = currentPrototype.ApplyChanges();
// //this.ExecutionMode = ExecutionMode.Success;
// dataContainer.ObjectName = currentPrototype.Name;
// dataContainer.SqlDialogSubject = user;
// }
// //setting back to original after changes are applied
// currentPrototype.IsRoleMembershipChangesApplied = false;
// }
// /// <summary>
// /// implementation of OnPanelRunNow
// /// </summary>
// /// <param name="node"></param>
// public void UserOwnedSchemas_OnRunNow(object sender, CDataContainer dataContainer)
// {
// UserPrototype currentPrototype = UserPrototypeFactory.GetInstance(dataContainer).CurrentPrototype;
// //In case the UserGeneral/Membership pages are loaded,
// //those will takes care of applying schema ownership changes also.
// //Hence, we only need to apply changes in this method when those are not loaded.
// if (!currentPrototype.IsSchemaOwnershipChangesApplied)
// {
// //base.OnRunNow(sender);
// User user = currentPrototype.ApplyChanges();
// //this.ExecutionMode = ExecutionMode.Success;
// dataContainer.ObjectName = currentPrototype.Name;
// dataContainer.SqlDialogSubject = user;
// }
// //setting back to original after changes are applied
// currentPrototype.IsSchemaOwnershipChangesApplied = false;
// }
// how to populate defaults from prototype, will delete once refactored
// private void InitializeValuesInUiControls()
// {
// IUserPrototypeWithDefaultLanguage defaultLanguagePrototype = this.currentUserPrototype
// as IUserPrototypeWithDefaultLanguage;
// if (defaultLanguagePrototype != null
// && defaultLanguagePrototype.IsDefaultLanguageSupported)
// {
// string defaultLanguageAlias = defaultLanguagePrototype.DefaultLanguageAlias;
// //If engine returns default language as empty or null, that means the default language of
// //database will be used.
// //Default language is not applicable for users inside an uncontained authentication.
// if (string.IsNullOrEmpty(defaultLanguageAlias)
// && (this.DataContainer.Server.GetSmoObject(this.parentDbUrn) as Database).ContainmentType != ContainmentType.None)
// {
// defaultLanguageAlias = this.defaultLanguagePlaceholder;
// }
// this.defaultLanguageComboBox.Text = defaultLanguageAlias;
// }
// IUserPrototypeWithDefaultSchema defaultSchemaPrototype = this.currentUserPrototype
// as IUserPrototypeWithDefaultSchema;
// if (defaultSchemaPrototype != null
// && defaultSchemaPrototype.IsDefaultSchemaSupported)
// {
// this.defaultSchemaTextBox.Text = defaultSchemaPrototype.DefaultSchema;
// }
// IUserPrototypeWithPassword userWithPwdPrototype = this.currentUserPrototype
// as IUserPrototypeWithPassword;
// if (userWithPwdPrototype != null
// && !this.DataContainer.IsNewObject)
// {
// this.passwordTextBox.Text = FAKE_PASSWORD;
// this.confirmPwdTextBox.Text = FAKE_PASSWORD;
// }
// }
#endregion #endregion
#region "Credential Handlers" #region "Credential Handlers"
@@ -833,167 +516,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
}); });
} }
/// <summary>
/// this is the main method that is called by DropAllObjects for every object
/// in the grid
/// </summary>
/// <param name="objectRowNumber"></param>
private void DoDropObject(CDataContainer dataContainer)
{
// if a server isn't connected then there is nothing to do
if (dataContainer.Server == null)
{
return;
}
var executionMode = dataContainer.Server.ConnectionContext.SqlExecutionModes;
var subjectExecutionMode = executionMode;
//For Azure the ExecutionManager is different depending on which ExecutionManager
//used - one at the Server level and one at the Database level. So to ensure we
//don't use the wrong execution mode we need to set the mode for both (for on-prem
//this will essentially be a no-op)
SqlSmoObject sqlDialogSubject = null;
try
{
sqlDialogSubject = dataContainer.SqlDialogSubject;
}
catch (System.Exception)
{
//We may not have a valid dialog subject here (such as if the object hasn't been created yet)
//so in that case we'll just ignore it as that's a normal scenario.
}
if (sqlDialogSubject != null)
{
subjectExecutionMode =
sqlDialogSubject.ExecutionManager.ConnectionContext.SqlExecutionModes;
}
Urn objUrn = sqlDialogSubject?.Urn;
System.Diagnostics.Debug.Assert(objUrn != null);
SfcObjectQuery objectQuery = new SfcObjectQuery(dataContainer.Server);
IDroppable droppableObj = null;
string[] fields = null;
foreach (object obj in objectQuery.ExecuteIterator(new SfcQueryExpression(objUrn.ToString()), fields, null))
{
System.Diagnostics.Debug.Assert(droppableObj == null, "there is only one object");
droppableObj = obj as IDroppable;
}
// For Azure databases, the SfcObjectQuery executions above may have overwritten our desired execution mode, so restore it
dataContainer.Server.ConnectionContext.SqlExecutionModes = executionMode;
if (sqlDialogSubject != null)
{
sqlDialogSubject.ExecutionManager.ConnectionContext.SqlExecutionModes = subjectExecutionMode;
}
if (droppableObj == null)
{
string objectName = objUrn.GetAttribute("Name");
objectName ??= string.Empty;
throw new Microsoft.SqlServer.Management.Smo.MissingObjectException("DropObjectsSR.ObjectDoesNotExist(objUrn.Type, objectName)");
}
//special case database drop - see if we need to delete backup and restore history
SpecialPreDropActionsForObject(dataContainer, droppableObj,
deleteBackupRestoreOrDisableAuditSpecOrDisableAudit: false,
dropOpenConnections: false);
droppableObj.Drop();
//special case Resource Governor reconfigure - for pool, external pool, group Drop(), we need to issue
SpecialPostDropActionsForObject(dataContainer, droppableObj);
}
private void SpecialPreDropActionsForObject(CDataContainer dataContainer, IDroppable droppableObj,
bool deleteBackupRestoreOrDisableAuditSpecOrDisableAudit, bool dropOpenConnections)
{
// if a server isn't connected then there is nothing to do
if (dataContainer.Server == null)
{
return;
}
Database db = droppableObj as Database;
if (deleteBackupRestoreOrDisableAuditSpecOrDisableAudit)
{
if (db != null)
{
dataContainer.Server.DeleteBackupHistory(db.Name);
}
else
{
// else droppable object should be a server or database audit specification
ServerAuditSpecification sas = droppableObj as ServerAuditSpecification;
if (sas != null)
{
sas.Disable();
}
else
{
DatabaseAuditSpecification das = droppableObj as DatabaseAuditSpecification;
if (das != null)
{
das.Disable();
}
else
{
Audit aud = droppableObj as Audit;
if (aud != null)
{
aud.Disable();
}
}
}
}
}
// special case database drop - drop existing connections to the database other than this one
if (dropOpenConnections)
{
if (db?.ActiveConnections > 0)
{
// force the database to be single user
db.DatabaseOptions.UserAccess = DatabaseUserAccess.Single;
db.Alter(TerminationClause.RollbackTransactionsImmediately);
}
}
}
private void SpecialPostDropActionsForObject(CDataContainer dataContainer, IDroppable droppableObj)
{
// if a server isn't connected then there is nothing to do
if (dataContainer.Server == null)
{
return;
}
if (droppableObj is Policy)
{
Policy policyToDrop = (Policy)droppableObj;
if (!string.IsNullOrEmpty(policyToDrop.ObjectSet))
{
ObjectSet objectSet = policyToDrop.Parent.ObjectSets[policyToDrop.ObjectSet];
objectSet.Drop();
}
}
ResourcePool rp = droppableObj as ResourcePool;
ExternalResourcePool erp = droppableObj as ExternalResourcePool;
WorkloadGroup wg = droppableObj as WorkloadGroup;
if (null != rp || null != erp || null != wg)
{
// Alter() Resource Governor to reconfigure
dataContainer.Server.ResourceGovernor.Alter();
}
}
#endregion // "Helpers" #endregion // "Helpers"
// some potentially useful code for working with server & db roles to be refactored later // some potentially useful code for working with server & db roles to be refactored later

View File

@@ -3,19 +3,441 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
// //
using System;
using System.Collections.Generic;
using System.Data;
using System.Threading.Tasks;
using System.Xml;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Sdk.Sfc; using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts; using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.Security
{ {
internal class UserServiceHandlerImpl
{
private class UserViewState
{
public string Database { get; set; }
public UserPrototypeData OriginalUserData { get; set; }
public UserViewState(string database, UserPrototypeData originalUserData)
{
this.Database = database;
this.OriginalUserData = originalUserData;
}
}
private ConnectionService? connectionService;
private Dictionary<string, UserViewState> contextIdToViewState = new Dictionary<string, UserViewState>();
/// <summary>
/// Internal for testing purposes only
/// </summary>
internal ConnectionService ConnectionServiceInstance
{
get
{
connectionService ??= ConnectionService.Instance;
return connectionService;
}
set
{
connectionService = value;
}
}
/// <summary>
/// Handle request to initialize user view
/// </summary>
internal async Task HandleInitializeUserViewRequest(InitializeUserViewParams parameters, RequestContext<UserViewInfo> requestContext)
{
if (string.IsNullOrWhiteSpace(parameters.Database))
{
throw new ArgumentNullException("parameters.Database");
}
if (string.IsNullOrWhiteSpace(parameters.ContextId))
{
throw new ArgumentNullException("parameters.ContextId");
}
ConnectionInfo originalConnInfo;
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out originalConnInfo);
if (originalConnInfo == null)
{
throw new ArgumentException("Invalid connection URI '{0}'", parameters.ConnectionUri);
}
string originalDatabaseName = originalConnInfo.ConnectionDetails.DatabaseName;
try
{
originalConnInfo.ConnectionDetails.DatabaseName = parameters.Database;
ConnectParams connectParams = new ConnectParams
{
OwnerUri = parameters.ContextId,
Connection = originalConnInfo.ConnectionDetails,
Type = Connection.ConnectionType.Default
};
await this.ConnectionServiceInstance.Connect(connectParams);
}
finally
{
originalConnInfo.ConnectionDetails.DatabaseName = originalDatabaseName;
}
ConnectionInfo connInfo;
this.ConnectionServiceInstance.TryFindConnection(parameters.ContextId, out connInfo);
CDataContainer dataContainer = CreateUserDataContainer(connInfo, null, ConfigAction.Create, parameters.Database);
UserInfo? userInfo = null;
if (!parameters.IsNewObject)
{
User? existingUser = null;
string databaseUrn = string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']",
Urn.EscapeString(parameters.Database));
Database? parentDb = dataContainer.Server.GetSmoObject(databaseUrn) as Database;
existingUser = dataContainer.Server.Databases[parentDb.Name].Users[parameters.Name];
userInfo = new UserInfo()
{
Name = parameters.Name,
LoginName = existingUser.Login,
DefaultSchema = existingUser.DefaultSchema
};
}
UserPrototypeFactory userPrototypeFactory = UserPrototypeFactory.GetInstance(dataContainer, userInfo, originalData: null);
UserPrototype currentUserPrototype = userPrototypeFactory.GetUserPrototype(ExhaustiveUserTypes.LoginMappedUser);
IUserPrototypeWithDefaultLanguage defaultLanguagePrototype = currentUserPrototype as IUserPrototypeWithDefaultLanguage;
string? defaultLanguageAlias = null;
if (defaultLanguagePrototype != null && defaultLanguagePrototype.IsDefaultLanguageSupported)
{
string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(parameters.Database) + "']";
defaultLanguageAlias = defaultLanguagePrototype.DefaultLanguageAlias;
//If engine returns default language as empty or null, that means the default language of
//database will be used.
//Default language is not applicable for users inside an uncontained authentication.
if (string.IsNullOrEmpty(defaultLanguageAlias)
&& (dataContainer.Server.GetSmoObject(dbUrn) as Database).ContainmentType != ContainmentType.None)
{
defaultLanguageAlias = SR.DefaultLanguagePlaceholder;
}
}
string? defaultSchema = null;
IUserPrototypeWithDefaultSchema defaultSchemaPrototype = currentUserPrototype as IUserPrototypeWithDefaultSchema;
if (defaultSchemaPrototype != null && defaultSchemaPrototype.IsDefaultSchemaSupported)
{
defaultSchema = defaultSchemaPrototype.DefaultSchema;
}
// IUserPrototypeWithPassword userWithPwdPrototype = currentUserPrototype as IUserPrototypeWithPassword;
// if (userWithPwdPrototype != null && !this.DataContainer.IsNewObject)
// {
// this.passwordTextBox.Text = FAKE_PASSWORD;
// this.confirmPwdTextBox.Text = FAKE_PASSWORD;
// }
string? loginName = null;
IUserPrototypeWithMappedLogin mappedLoginPrototype = currentUserPrototype as IUserPrototypeWithMappedLogin;
if (mappedLoginPrototype != null)
{
loginName = mappedLoginPrototype.LoginName;
}
List<string> databaseRoles = new List<string>();
foreach (string role in currentUserPrototype.DatabaseRoleNames)
{
if (currentUserPrototype.IsRoleMember(role))
{
databaseRoles.Add(role);
}
}
List<string> schemaNames = new List<string>();
foreach (string schema in currentUserPrototype.SchemaNames)
{
if (currentUserPrototype.IsSchemaOwner(schema))
{
schemaNames.Add(schema);
}
}
// default to dbo schema, if there isn't already a default
if (string.IsNullOrWhiteSpace(defaultSchema) && currentUserPrototype.SchemaNames.Contains("dbo"))
{
defaultSchema = "dbo";
}
ServerConnection serverConnection = dataContainer.ServerConnection;
UserViewInfo userViewInfo = new UserViewInfo()
{
ObjectInfo = new UserInfo()
{
Type = DatabaseUserType.WithLogin,
Name = currentUserPrototype.Name,
LoginName = loginName,
Password = string.Empty,
DefaultSchema = defaultSchema,
OwnedSchemas = schemaNames.ToArray(),
DatabaseRoles = databaseRoles.ToArray(),
DefaultLanguage = defaultLanguageAlias
},
SupportContainedUser = false, // support for these will be added later
SupportWindowsAuthentication = false,
SupportAADAuthentication = false,
SupportSQLAuthentication = true,
Languages = new string[] { },
Schemas = currentUserPrototype.SchemaNames.ToArray(),
Logins = LoadSqlLogins(serverConnection),
DatabaseRoles = currentUserPrototype.DatabaseRoleNames.ToArray()
};
this.contextIdToViewState.Add(
parameters.ContextId,
new UserViewState(parameters.Database, currentUserPrototype.CurrentState));
await requestContext.SendResult(userViewInfo);
}
/// <summary>
/// Handle request to create a user
/// </summary>
internal async Task HandleCreateUserRequest(CreateUserParams parameters, RequestContext<CreateUserResult> requestContext)
{
if (parameters.ContextId == null)
{
throw new ArgumentException("Invalid context ID");
}
UserViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
{
throw new ArgumentException("Invalid context ID view state");
}
Tuple<bool, string> result = ConfigureUser(
parameters.ContextId,
parameters.User,
ConfigAction.Create,
RunType.RunNow,
viewState.Database,
viewState.OriginalUserData);
await requestContext.SendResult(new CreateUserResult()
{
User = parameters.User,
Success = result.Item1,
ErrorMessage = result.Item2
});
}
/// <summary>
/// Handle request to update a user
/// </summary>
internal async Task HandleUpdateUserRequest(UpdateUserParams parameters, RequestContext<ResultStatus> requestContext)
{
if (parameters.ContextId == null)
{
throw new ArgumentException("Invalid context ID");
}
UserViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
{
throw new ArgumentException("Invalid context ID view state");
}
Tuple<bool, string> result = ConfigureUser(
parameters.ContextId,
parameters.User,
ConfigAction.Update,
RunType.RunNow,
viewState.Database,
viewState.OriginalUserData);
await requestContext.SendResult(new ResultStatus()
{
Success = result.Item1,
ErrorMessage = result.Item2
});
}
/// <summary>
/// Handle request to delete a user
/// </summary>
internal async Task HandleDeleteUserRequest(DeleteUserParams parameters, RequestContext<ResultStatus> requestContext)
{
ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo);
if (connInfo == null)
{
throw new ArgumentException("Invalid ConnectionUri");
}
if (string.IsNullOrWhiteSpace(parameters.Name) || string.IsNullOrWhiteSpace(parameters.Database))
{
throw new ArgumentException("Invalid null parameter");
}
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(parameters.Database) + "']";
Database? parent = dataContainer.Server.GetSmoObject(new Urn(dbUrn)) as Database;
User user = parent.Users[parameters.Name];
dataContainer.SqlDialogSubject = user;
CheckForSchemaOwnerships(parent, user);
DatabaseUtils.DoDropObject(dataContainer);
await requestContext.SendResult(new ResultStatus()
{
Success = true,
ErrorMessage = string.Empty
});
}
internal async Task HandleDisposeUserViewRequest(DisposeUserViewRequestParams parameters, RequestContext<ResultStatus> requestContext)
{
this.ConnectionServiceInstance.Disconnect(new DisconnectParams(){
OwnerUri = parameters.ContextId,
Type = null
});
if (parameters.ContextId != null)
{
this.contextIdToViewState.Remove(parameters.ContextId);
}
await requestContext.SendResult(new ResultStatus()
{
Success = true,
ErrorMessage = string.Empty
});
}
internal CDataContainer CreateUserDataContainer(
ConnectionInfo connInfo,
UserInfo? user,
ConfigAction configAction,
string databaseName)
{
var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer");
var connectionInfoWithConnection = new SqlConnectionInfoWithConnection();
connectionInfoWithConnection.ServerConnection = serverConnection;
string urn = (configAction == ConfigAction.Update && user != null)
? string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']/User[@Name='{1}']",
Urn.EscapeString(databaseName),
Urn.EscapeString(user.Name))
: string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']",
Urn.EscapeString(databaseName));
ActionContext context = new ActionContext(serverConnection, "User", urn);
DataContainerXmlGenerator containerXml = new DataContainerXmlGenerator(context);
if (configAction == ConfigAction.Create)
{
containerXml.AddProperty("itemtype", "User");
}
XmlDocument xmlDoc = containerXml.GenerateXmlDocument();
return CDataContainer.CreateDataContainer(connectionInfoWithConnection, xmlDoc);
}
internal Tuple<bool, string> ConfigureUser(
string? ownerUri,
UserInfo? user,
ConfigAction configAction,
RunType runType,
string databaseName,
UserPrototypeData? originalData)
{
ConnectionInfo connInfo;
this.ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
if (connInfo == null)
{
throw new ArgumentException("Invalid connection URI '{0}'", ownerUri);
}
CDataContainer dataContainer = CreateUserDataContainer(connInfo, user, configAction, databaseName);
using (var actions = new UserActions(dataContainer, user, configAction, originalData))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
if (executionHandler.ExecutionResult == ExecutionMode.Failure)
{
throw executionHandler.ExecutionFailureException;
}
}
return new Tuple<bool, string>(true, string.Empty);
}
private void CheckForSchemaOwnerships(Database parentDb, User existingUser)
{
foreach (Schema sch in parentDb.Schemas)
{
var comparer = parentDb.GetStringComparer();
if (comparer.Compare(sch.Owner, existingUser.Name) == 0)
{
throw new ApplicationException("Cannot drop user since it owns a schema");
}
}
}
private string[] LoadSqlLogins(ServerConnection serverConnection)
{
return LoadItems(serverConnection, "Server/Login");
}
private string[] LoadItems(ServerConnection serverConnection, string urn)
{
List<string> items = new List<string>();
Request req = new Request();
req.Urn = urn;
req.ResultType = ResultType.IDataReader;
req.Fields = new string[] { "Name" };
Enumerator en = new Enumerator();
using (IDataReader reader = en.Process(serverConnection, req).Data as IDataReader)
{
if (reader != null)
{
string name;
while (reader.Read())
{
// Get the permission name
name = reader.GetString(0);
items.Add(name);
}
}
}
return items.ToArray();
}
}
internal class UserActions : ManagementActionBase internal class UserActions : ManagementActionBase
{ {
#region Variables #region Variables
//private UserPrototypeData userData; //private UserPrototypeData userData;
private UserPrototype userPrototype; private UserPrototype userPrototype;
private UserInfo user; private UserInfo? user;
private ConfigAction configAction; private ConfigAction configAction;
#endregion #endregion
@@ -27,13 +449,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
public UserActions( public UserActions(
CDataContainer context, CDataContainer context,
UserInfo? user, UserInfo? user,
ConfigAction configAction) ConfigAction configAction,
UserPrototypeData? originalData)
{ {
this.DataContainer = context; this.DataContainer = context;
this.user = user; this.user = user;
this.configAction = configAction; this.configAction = configAction;
this.userPrototype = InitUserNew(context, user); this.userPrototype = InitUserPrototype(context, user, originalData);
} }
// /// <summary> // /// <summary>
@@ -65,14 +488,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
} }
} }
private UserPrototype InitUserNew(CDataContainer dataContainer, UserInfo user) private UserPrototype InitUserPrototype(CDataContainer dataContainer, UserInfo user, UserPrototypeData? originalData)
{ {
ExhaustiveUserTypes currentUserType; ExhaustiveUserTypes currentUserType;
UserPrototypeFactory userPrototypeFactory = UserPrototypeFactory.GetInstance(dataContainer, user); UserPrototypeFactory userPrototypeFactory = UserPrototypeFactory.GetInstance(dataContainer, user, originalData);
if (dataContainer.IsNewObject) if (dataContainer.IsNewObject)
{ {
if (IsParentDatabaseContained(dataContainer.ParentUrn, dataContainer)) if (dataContainer.Server != null && IsParentDatabaseContained(dataContainer.ParentUrn, dataContainer.Server))
{ {
currentUserType = ExhaustiveUserTypes.SqlUserWithPassword; currentUserType = ExhaustiveUserTypes.SqlUserWithPassword;
} }
@@ -91,8 +514,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
return currentUserPrototype; return currentUserPrototype;
} }
private ExhaustiveUserTypes GetCurrentUserTypeForExistingUser(User user) private ExhaustiveUserTypes GetCurrentUserTypeForExistingUser(User? user)
{ {
if (user == null)
{
return ExhaustiveUserTypes.Unknown;
}
switch (user.UserType) switch (user.UserType)
{ {
case UserType.SqlUser: case UserType.SqlUser:
@@ -124,10 +552,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
} }
} }
private bool IsParentDatabaseContained(Urn parentDbUrn, CDataContainer dataContainer) private static bool IsParentDatabaseContained(Urn parentDbUrn, Server server)
{ {
string parentDbName = parentDbUrn.GetNameForType("Database"); string parentDbName = parentDbUrn.GetNameForType("Database");
Database parentDatabase = dataContainer.Server.Databases[parentDbName]; Database parentDatabase = server.Databases[parentDbName];
if (parentDatabase.IsSupportedProperty("ContainmentType") if (parentDatabase.IsSupportedProperty("ContainmentType")
&& parentDatabase.ContainmentType == ContainmentType.Partial) && parentDatabase.ContainmentType == ContainmentType.Partial)

View File

@@ -103,7 +103,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
this.isMember = new Dictionary<string, bool>(); this.isMember = new Dictionary<string, bool>();
} }
public UserPrototypeData(CDataContainer context, UserInfo userInfo) public UserPrototypeData(CDataContainer context, UserInfo? userInfo)
{ {
this.isSchemaOwned = new Dictionary<string, bool>(); this.isSchemaOwned = new Dictionary<string, bool>();
this.isMember = new Dictionary<string, bool>(); this.isMember = new Dictionary<string, bool>();
@@ -114,15 +114,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
} }
else else
{ {
this.name = userInfo.Name; if (userInfo != null)
this.mappedLoginName = userInfo.LoginName; {
this.defaultSchemaName = userInfo.DefaultSchema; this.name = userInfo.Name;
this.password = DatabaseUtils.GetReadOnlySecureString(userInfo.Password); this.mappedLoginName = userInfo.LoginName;
this.defaultSchemaName = userInfo.DefaultSchema;
if (!string.IsNullOrEmpty(userInfo.Password))
{
this.password = DatabaseUtils.GetReadOnlySecureString(userInfo.Password);
}
}
} }
this.LoadRoleMembership(context); this.LoadRoleMembership(context, userInfo);
this.LoadSchemaData(context); this.LoadSchemaData(context, userInfo);
} }
public UserPrototypeData Clone() public UserPrototypeData Clone()
@@ -145,20 +151,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
foreach (string key in this.isMember?.Keys ?? Enumerable.Empty<string>()) foreach (string key in this.isMember?.Keys ?? Enumerable.Empty<string>())
{ {
if (result.isMember?.ContainsKey(key) == true result.isMember[key] = this.isMember[key];
&& this.isMember?.ContainsKey(key) == true)
{
result.isMember[key] = this.isMember[key];
}
} }
foreach (string key in this.isSchemaOwned?.Keys ?? Enumerable.Empty<string>()) foreach (string key in this.isSchemaOwned?.Keys ?? Enumerable.Empty<string>())
{ {
if (result.isSchemaOwned?.ContainsKey(key) == true result.isSchemaOwned[key] = this.isSchemaOwned[key];
&& this.isSchemaOwned?.ContainsKey(key) == true)
{
result.isSchemaOwned[key] = this.isSchemaOwned[key];
}
} }
return result; return result;
@@ -248,7 +246,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
/// Loads role membership of a database user. /// Loads role membership of a database user.
/// </summary> /// </summary>
/// <param name="context"></param> /// <param name="context"></param>
private void LoadRoleMembership(CDataContainer context) private void LoadRoleMembership(CDataContainer context, UserInfo? userInfo)
{ {
Urn objUrn = new Urn(context.ObjectUrn); Urn objUrn = new Urn(context.ObjectUrn);
Urn databaseUrn = objUrn.Parent; Urn databaseUrn = objUrn.Parent;
@@ -259,20 +257,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
return; return;
} }
User existingUser = context.Server.Databases[parentDb.Name].Users[objUrn.GetNameForType("User")]; string userName = userInfo?.Name ?? objUrn.GetNameForType("User");
User existingUser = context.Server.Databases[parentDb.Name].Users[userName];
foreach (DatabaseRole dbRole in parentDb.Roles) foreach (DatabaseRole dbRole in parentDb.Roles)
{ {
var comparer = parentDb.GetStringComparer(); var comparer = parentDb.GetStringComparer();
if (comparer.Compare(dbRole.Name, "public") != 0) if (comparer.Compare(dbRole.Name, "public") != 0)
{ {
if (context.IsNewObject) if (userInfo != null && userInfo.DatabaseRoles != null)
{ {
this.isMember[dbRole.Name] = false; this.isMember[dbRole.Name] = userInfo.DatabaseRoles.Contains(dbRole.Name);
}
else if (existingUser != null)
{
this.isMember[dbRole.Name] = existingUser.IsMember(dbRole.Name);
} }
else else
{ {
this.isMember[dbRole.Name] = existingUser.IsMember(dbRole.Name); this.isMember[dbRole.Name] = false;
} }
} }
} }
@@ -282,7 +285,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
/// Loads schema ownership related data. /// Loads schema ownership related data.
/// </summary> /// </summary>
/// <param name="context"></param> /// <param name="context"></param>
private void LoadSchemaData(CDataContainer context) private void LoadSchemaData(CDataContainer context, UserInfo? userInfo)
{ {
Urn objUrn = new Urn(context.ObjectUrn); Urn objUrn = new Urn(context.ObjectUrn);
Urn databaseUrn = objUrn.Parent; Urn databaseUrn = objUrn.Parent;
@@ -293,7 +296,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
return; return;
} }
User existingUser = context.Server.Databases[parentDb.Name].Users[objUrn.GetNameForType("User")]; string userName = userInfo?.Name ?? objUrn.GetNameForType("User");
User existingUser = context.Server.Databases[parentDb.Name].Users[userName];
if (!SqlMgmtUtils.IsYukonOrAbove(context.Server) if (!SqlMgmtUtils.IsYukonOrAbove(context.Server)
|| parentDb.CompatibilityLevel <= CompatibilityLevel.Version80) || parentDb.CompatibilityLevel <= CompatibilityLevel.Version80)
@@ -303,15 +307,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
foreach (Schema sch in parentDb.Schemas) foreach (Schema sch in parentDb.Schemas)
{ {
if (context.IsNewObject) if (userInfo != null && userInfo.OwnedSchemas != null)
{ {
this.isSchemaOwned[sch.Name] = false; this.isSchemaOwned[sch.Name] = userInfo.OwnedSchemas.Contains(sch.Name);
} }
else else if (existingUser != null)
{ {
var comparer = parentDb.GetStringComparer(); var comparer = parentDb.GetStringComparer();
this.isSchemaOwned[sch.Name] = comparer.Compare(sch.Owner, existingUser.Name) == 0; this.isSchemaOwned[sch.Name] = comparer.Compare(sch.Owner, existingUser.Name) == 0;
} }
else
{
this.isSchemaOwned[sch.Name] = false;
}
} }
} }
} }
@@ -334,6 +342,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
#region IUserPrototype Members #region IUserPrototype Members
public UserPrototypeData CurrentState
{
get
{
return this.currentState;
}
}
public string Name public string Name
{ {
get get
@@ -996,15 +1012,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
} }
} }
private UserPrototypeFactory(CDataContainer context, UserInfo user) private UserPrototypeFactory(CDataContainer context, UserInfo user, UserPrototypeData? originalData)
{ {
this.context = context; this.context = context;
this.originalData = new UserPrototypeData(this.context, user); this.currentData = new UserPrototypeData(this.context, user);
this.currentData = this.originalData.Clone(); this.originalData = originalData ?? this.currentData.Clone();
} }
public static UserPrototypeFactory GetInstance(CDataContainer context, UserInfo user) public static UserPrototypeFactory GetInstance(CDataContainer context, UserInfo? user, UserPrototypeData? originalData)
{ {
if (singletonInstance != null if (singletonInstance != null
&& singletonInstance.context != context) && singletonInstance.context != context)
@@ -1012,7 +1028,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
singletonInstance = null; singletonInstance = null;
} }
singletonInstance ??= new UserPrototypeFactory(context, user); singletonInstance ??= new UserPrototypeFactory(context, user, originalData);
return singletonInstance; return singletonInstance;
} }

View File

@@ -5,6 +5,10 @@
#nullable disable #nullable disable
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Dmf;
using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
@@ -80,5 +84,166 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
return ss; return ss;
} }
/// <summary>
/// this is the main method that is called by DropAllObjects for every object
/// in the grid
/// </summary>
/// <param name="objectRowNumber"></param>
public static void DoDropObject(CDataContainer dataContainer)
{
// if a server isn't connected then there is nothing to do
if (dataContainer.Server == null)
{
return;
}
var executionMode = dataContainer.Server.ConnectionContext.SqlExecutionModes;
var subjectExecutionMode = executionMode;
//For Azure the ExecutionManager is different depending on which ExecutionManager
//used - one at the Server level and one at the Database level. So to ensure we
//don't use the wrong execution mode we need to set the mode for both (for on-prem
//this will essentially be a no-op)
SqlSmoObject sqlDialogSubject = null;
try
{
sqlDialogSubject = dataContainer.SqlDialogSubject;
}
catch (System.Exception)
{
//We may not have a valid dialog subject here (such as if the object hasn't been created yet)
//so in that case we'll just ignore it as that's a normal scenario.
}
if (sqlDialogSubject != null)
{
subjectExecutionMode =
sqlDialogSubject.ExecutionManager.ConnectionContext.SqlExecutionModes;
}
Urn objUrn = sqlDialogSubject?.Urn;
System.Diagnostics.Debug.Assert(objUrn != null);
SfcObjectQuery objectQuery = new SfcObjectQuery(dataContainer.Server);
IDroppable droppableObj = null;
string[] fields = null;
foreach (object obj in objectQuery.ExecuteIterator(new SfcQueryExpression(objUrn.ToString()), fields, null))
{
System.Diagnostics.Debug.Assert(droppableObj == null, "there is only one object");
droppableObj = obj as IDroppable;
}
// For Azure databases, the SfcObjectQuery executions above may have overwritten our desired execution mode, so restore it
dataContainer.Server.ConnectionContext.SqlExecutionModes = executionMode;
if (sqlDialogSubject != null)
{
sqlDialogSubject.ExecutionManager.ConnectionContext.SqlExecutionModes = subjectExecutionMode;
}
if (droppableObj == null)
{
string objectName = objUrn.GetAttribute("Name");
objectName ??= string.Empty;
throw new Microsoft.SqlServer.Management.Smo.MissingObjectException("DropObjectsSR.ObjectDoesNotExist(objUrn.Type, objectName)");
}
//special case database drop - see if we need to delete backup and restore history
SpecialPreDropActionsForObject(dataContainer, droppableObj,
deleteBackupRestoreOrDisableAuditSpecOrDisableAudit: false,
dropOpenConnections: false);
droppableObj.Drop();
//special case Resource Governor reconfigure - for pool, external pool, group Drop(), we need to issue
SpecialPostDropActionsForObject(dataContainer, droppableObj);
}
private static void SpecialPreDropActionsForObject(CDataContainer dataContainer, IDroppable droppableObj,
bool deleteBackupRestoreOrDisableAuditSpecOrDisableAudit, bool dropOpenConnections)
{
// if a server isn't connected then there is nothing to do
if (dataContainer.Server == null)
{
return;
}
Database db = droppableObj as Database;
if (deleteBackupRestoreOrDisableAuditSpecOrDisableAudit)
{
if (db != null)
{
dataContainer.Server.DeleteBackupHistory(db.Name);
}
else
{
// else droppable object should be a server or database audit specification
ServerAuditSpecification sas = droppableObj as ServerAuditSpecification;
if (sas != null)
{
sas.Disable();
}
else
{
DatabaseAuditSpecification das = droppableObj as DatabaseAuditSpecification;
if (das != null)
{
das.Disable();
}
else
{
Audit aud = droppableObj as Audit;
if (aud != null)
{
aud.Disable();
}
}
}
}
}
// special case database drop - drop existing connections to the database other than this one
if (dropOpenConnections)
{
if (db?.ActiveConnections > 0)
{
// force the database to be single user
db.DatabaseOptions.UserAccess = DatabaseUserAccess.Single;
db.Alter(TerminationClause.RollbackTransactionsImmediately);
}
}
}
private static void SpecialPostDropActionsForObject(CDataContainer dataContainer, IDroppable droppableObj)
{
// if a server isn't connected then there is nothing to do
if (dataContainer.Server == null)
{
return;
}
if (droppableObj is Policy)
{
Policy policyToDrop = (Policy)droppableObj;
if (!string.IsNullOrEmpty(policyToDrop.ObjectSet))
{
ObjectSet objectSet = policyToDrop.Parent.ObjectSets[policyToDrop.ObjectSet];
objectSet.Drop();
}
}
ResourcePool rp = droppableObj as ResourcePool;
ExternalResourcePool erp = droppableObj as ExternalResourcePool;
WorkloadGroup wg = droppableObj as WorkloadGroup;
if (null != rp || null != erp || null != wg)
{
// Alter() Resource Governor to reconfigure
dataContainer.Server.ResourceGovernor.Alter();
}
}
} }
} }

View File

@@ -39,7 +39,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security
EnforcePasswordExpiration = false, EnforcePasswordExpiration = false,
Password = "placeholder", Password = "placeholder",
OldPassword = "placeholder", OldPassword = "placeholder",
DefaultLanguage = "us_english", DefaultLanguage = "English - us_english",
DefaultDatabase = "master" DefaultDatabase = "master"
}; };
} }
@@ -53,7 +53,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security
LoginName = loginName, LoginName = loginName,
Password = "placeholder", Password = "placeholder",
DefaultSchema = "dbo", DefaultSchema = "dbo",
OwnedSchemas = new string[] { "dbo" } OwnedSchemas = new string[] { "" }
}; };
} }
@@ -124,5 +124,131 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security
var service = new SecurityService(); var service = new SecurityService();
await SecurityTestUtils.DeleteCredential(service, connectionResult, credential); await SecurityTestUtils.DeleteCredential(service, connectionResult, credential);
} }
internal static async Task<LoginInfo> CreateLogin(SecurityService service, TestConnectionResult connectionResult, string contextId)
{
var initializeLoginViewRequestParams = new InitializeLoginViewRequestParams
{
ConnectionUri = connectionResult.ConnectionInfo.OwnerUri,
ContextId = contextId,
IsNewObject = true
};
var loginParams = new CreateLoginParams
{
ContextId = contextId,
Login = SecurityTestUtils.GetTestLoginInfo()
};
var createLoginContext = new Mock<RequestContext<object>>();
createLoginContext.Setup(x => x.SendResult(It.IsAny<object>()))
.Returns(Task.FromResult(new object()));
var initializeLoginViewContext = new Mock<RequestContext<LoginViewInfo>>();
initializeLoginViewContext.Setup(x => x.SendResult(It.IsAny<LoginViewInfo>()))
.Returns(Task.FromResult(new LoginViewInfo()));
// call the create login method
await service.HandleInitializeLoginViewRequest(initializeLoginViewRequestParams, initializeLoginViewContext.Object);
await service.HandleCreateLoginRequest(loginParams, createLoginContext.Object);
return loginParams.Login;
}
internal static async Task DeleteLogin(SecurityService service, TestConnectionResult connectionResult, LoginInfo login)
{
// cleanup created login
var deleteParams = new DeleteLoginParams
{
ConnectionUri = connectionResult.ConnectionInfo.OwnerUri,
Name = login.Name
};
var deleteContext = new Mock<RequestContext<object>>();
deleteContext.Setup(x => x.SendResult(It.IsAny<object>()))
.Returns(Task.FromResult(new object()));
// call the create login method
await service.HandleDeleteLoginRequest(deleteParams, deleteContext.Object);
}
internal static async Task<UserInfo> CreateUser(
UserServiceHandlerImpl service,
TestConnectionResult connectionResult,
string contextId,
LoginInfo login)
{
var initializeViewRequestParams = new InitializeUserViewParams
{
ConnectionUri = connectionResult.ConnectionInfo.OwnerUri,
ContextId = contextId,
IsNewObject = true,
Database = "master"
};
var initializeUserViewContext = new Mock<RequestContext<UserViewInfo>>();
initializeUserViewContext.Setup(x => x.SendResult(It.IsAny<UserViewInfo>()))
.Returns(Task.FromResult(new LoginViewInfo()));
await service.HandleInitializeUserViewRequest(initializeViewRequestParams, initializeUserViewContext.Object);
var userParams = new CreateUserParams
{
ContextId = contextId,
User = SecurityTestUtils.GetTestUserInfo(login.Name)
};
var createUserContext = new Mock<RequestContext<CreateUserResult>>();
createUserContext.Setup(x => x.SendResult(It.IsAny<CreateUserResult>()))
.Returns(Task.FromResult(new object()));
// call the create login method
await service.HandleCreateUserRequest(userParams, createUserContext.Object);
// verify the result
createUserContext.Verify(x => x.SendResult(It.Is<CreateUserResult>
(p => p.Success && p.User.Name != string.Empty)));
return userParams.User;
}
internal static async Task UpdateUser(
UserServiceHandlerImpl service,
TestConnectionResult connectionResult,
string contextId,
UserInfo user)
{
// update the user
user.OwnedSchemas = new string[] { "dbo" };
var updateParams = new UpdateUserParams
{
ContextId = contextId,
User = user
};
var updateUserContext = new Mock<RequestContext<ResultStatus>>();
// call the create login method
await service.HandleUpdateUserRequest(updateParams, updateUserContext.Object);
// verify the result
updateUserContext.Verify(x => x.SendResult(It.Is<ResultStatus>(p => p.Success)));
}
internal static async Task DeleteUser(UserServiceHandlerImpl service, TestConnectionResult connectionResult, UserInfo user)
{
// cleanup created user
var deleteParams = new DeleteUserParams
{
ConnectionUri = connectionResult.ConnectionInfo.OwnerUri,
Name = user.Name,
Database = connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName
};
var deleteContext = new Mock<RequestContext<ResultStatus>>();
deleteContext.Setup(x => x.SendResult(It.IsAny<ResultStatus>()))
.Returns(Task.FromResult(new object()));
// call the create user method
await service.HandleDeleteUserRequest(deleteParams, deleteContext.Object);
deleteContext.Verify(x => x.SendResult(It.Is<ResultStatus>(p => p.Success)));
}
} }
} }

View File

@@ -4,12 +4,9 @@
// //
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.Security; using Microsoft.SqlTools.ServiceLayer.Security;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
using Microsoft.SqlTools.ServiceLayer.Test.Common; using Microsoft.SqlTools.ServiceLayer.Test.Common;
using Moq;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security
{ {
@@ -21,56 +18,50 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security
/// <summary> /// <summary>
/// Test the basic Create User method handler /// Test the basic Create User method handler
/// </summary> /// </summary>
// [Test] //[Test] - enable tests in separate change
public async Task TestHandleCreateUserRequest() public async Task TestHandleCreateUserWithLoginRequest()
{ {
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{ {
// setup // setup
SecurityService service = new SecurityService();
UserServiceHandlerImpl userService = new UserServiceHandlerImpl();
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var contextId = System.Guid.NewGuid().ToString(); var contextId = System.Guid.NewGuid().ToString();
var initializeLoginViewRequestParams = new InitializeLoginViewRequestParams var login = await SecurityTestUtils.CreateLogin(service, connectionResult, contextId);
{
ConnectionUri = connectionResult.ConnectionInfo.OwnerUri,
ContextId = contextId,
IsNewObject = true
};
var loginParams = new CreateLoginParams var user = await SecurityTestUtils.CreateUser(userService, connectionResult, contextId, login);
{
ContextId = contextId,
Login = SecurityTestUtils.GetTestLoginInfo()
};
var createLoginContext = new Mock<RequestContext<object>>(); await SecurityTestUtils.DeleteUser(userService, connectionResult, user);
createLoginContext.Setup(x => x.SendResult(It.IsAny<object>()))
.Returns(Task.FromResult(new object()));
var initializeLoginViewContext = new Mock<RequestContext<LoginViewInfo>>();
initializeLoginViewContext.Setup(x => x.SendResult(It.IsAny<LoginViewInfo>()))
.Returns(Task.FromResult(new LoginViewInfo()));
// call the create login method await SecurityTestUtils.DeleteLogin(service, connectionResult, login);
}
}
/// <summary>
/// Test the basic Update User method handler
/// </summary>
//[Test] - enable tests in separate change
public async Task TestHandleUpdateUserWithLoginRequest()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
SecurityService service = new SecurityService(); SecurityService service = new SecurityService();
await service.HandleInitializeLoginViewRequest(initializeLoginViewRequestParams, initializeLoginViewContext.Object); UserServiceHandlerImpl userService = new UserServiceHandlerImpl();
await service.HandleCreateLoginRequest(loginParams, createLoginContext.Object); var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var contextId = System.Guid.NewGuid().ToString();
var userParams = new CreateUserParams var login = await SecurityTestUtils.CreateLogin(service, connectionResult, contextId);
{
ContextId = connectionResult.ConnectionInfo.OwnerUri,
User = SecurityTestUtils.GetTestUserInfo(loginParams.Login.Name)
};
var createUserContext = new Mock<RequestContext<CreateUserResult>>(); var user = await SecurityTestUtils.CreateUser(userService, connectionResult, contextId, login);
createUserContext.Setup(x => x.SendResult(It.IsAny<CreateUserResult>()))
.Returns(Task.FromResult(new object()));
// call the create login method await SecurityTestUtils.UpdateUser(userService, connectionResult, contextId, user);
await service.HandleCreateUserRequest(userParams, createUserContext.Object);
// verify the result await SecurityTestUtils.DeleteUser(userService, connectionResult, user);
createUserContext.Verify(x => x.SendResult(It.Is<CreateUserResult>
(p => p.Success && p.User.Name != string.Empty))); await SecurityTestUtils.DeleteLogin(service, connectionResult, login);
} }
} }
} }