mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-18 01:25:41 -05:00
Update User edit handlers (#1903)
* WIP * Update user tests * WIP updates * WIP * Fix several edit use bugs * Disable failing tests * minor updates * Remove unused using
This commit is contained in:
@@ -3,19 +3,441 @@
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Data;
|
||||
using System.Threading.Tasks;
|
||||
using System.Xml;
|
||||
using Microsoft.SqlServer.Management.Common;
|
||||
using Microsoft.SqlServer.Management.Sdk.Sfc;
|
||||
using Microsoft.SqlServer.Management.Smo;
|
||||
using Microsoft.SqlTools.Hosting.Protocol;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.Management;
|
||||
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Security
|
||||
{
|
||||
internal class UserServiceHandlerImpl
|
||||
{
|
||||
private class UserViewState
|
||||
{
|
||||
public string Database { get; set; }
|
||||
|
||||
public UserPrototypeData OriginalUserData { get; set; }
|
||||
|
||||
public UserViewState(string database, UserPrototypeData originalUserData)
|
||||
{
|
||||
this.Database = database;
|
||||
this.OriginalUserData = originalUserData;
|
||||
}
|
||||
}
|
||||
|
||||
private ConnectionService? connectionService;
|
||||
|
||||
private Dictionary<string, UserViewState> contextIdToViewState = new Dictionary<string, UserViewState>();
|
||||
|
||||
/// <summary>
|
||||
/// Internal for testing purposes only
|
||||
/// </summary>
|
||||
internal ConnectionService ConnectionServiceInstance
|
||||
{
|
||||
get
|
||||
{
|
||||
connectionService ??= ConnectionService.Instance;
|
||||
return connectionService;
|
||||
}
|
||||
|
||||
set
|
||||
{
|
||||
connectionService = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Handle request to initialize user view
|
||||
/// </summary>
|
||||
internal async Task HandleInitializeUserViewRequest(InitializeUserViewParams parameters, RequestContext<UserViewInfo> requestContext)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(parameters.Database))
|
||||
{
|
||||
throw new ArgumentNullException("parameters.Database");
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(parameters.ContextId))
|
||||
{
|
||||
throw new ArgumentNullException("parameters.ContextId");
|
||||
}
|
||||
|
||||
ConnectionInfo originalConnInfo;
|
||||
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out originalConnInfo);
|
||||
if (originalConnInfo == null)
|
||||
{
|
||||
throw new ArgumentException("Invalid connection URI '{0}'", parameters.ConnectionUri);
|
||||
}
|
||||
|
||||
string originalDatabaseName = originalConnInfo.ConnectionDetails.DatabaseName;
|
||||
try
|
||||
{
|
||||
originalConnInfo.ConnectionDetails.DatabaseName = parameters.Database;
|
||||
ConnectParams connectParams = new ConnectParams
|
||||
{
|
||||
OwnerUri = parameters.ContextId,
|
||||
Connection = originalConnInfo.ConnectionDetails,
|
||||
Type = Connection.ConnectionType.Default
|
||||
};
|
||||
await this.ConnectionServiceInstance.Connect(connectParams);
|
||||
}
|
||||
finally
|
||||
{
|
||||
originalConnInfo.ConnectionDetails.DatabaseName = originalDatabaseName;
|
||||
}
|
||||
|
||||
ConnectionInfo connInfo;
|
||||
this.ConnectionServiceInstance.TryFindConnection(parameters.ContextId, out connInfo);
|
||||
CDataContainer dataContainer = CreateUserDataContainer(connInfo, null, ConfigAction.Create, parameters.Database);
|
||||
|
||||
UserInfo? userInfo = null;
|
||||
if (!parameters.IsNewObject)
|
||||
{
|
||||
User? existingUser = null;
|
||||
string databaseUrn = string.Format(System.Globalization.CultureInfo.InvariantCulture,
|
||||
"Server/Database[@Name='{0}']",
|
||||
Urn.EscapeString(parameters.Database));
|
||||
Database? parentDb = dataContainer.Server.GetSmoObject(databaseUrn) as Database;
|
||||
existingUser = dataContainer.Server.Databases[parentDb.Name].Users[parameters.Name];
|
||||
userInfo = new UserInfo()
|
||||
{
|
||||
Name = parameters.Name,
|
||||
LoginName = existingUser.Login,
|
||||
DefaultSchema = existingUser.DefaultSchema
|
||||
};
|
||||
}
|
||||
|
||||
UserPrototypeFactory userPrototypeFactory = UserPrototypeFactory.GetInstance(dataContainer, userInfo, originalData: null);
|
||||
UserPrototype currentUserPrototype = userPrototypeFactory.GetUserPrototype(ExhaustiveUserTypes.LoginMappedUser);
|
||||
|
||||
IUserPrototypeWithDefaultLanguage defaultLanguagePrototype = currentUserPrototype as IUserPrototypeWithDefaultLanguage;
|
||||
string? defaultLanguageAlias = null;
|
||||
if (defaultLanguagePrototype != null && defaultLanguagePrototype.IsDefaultLanguageSupported)
|
||||
{
|
||||
string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(parameters.Database) + "']";
|
||||
defaultLanguageAlias = defaultLanguagePrototype.DefaultLanguageAlias;
|
||||
//If engine returns default language as empty or null, that means the default language of
|
||||
//database will be used.
|
||||
//Default language is not applicable for users inside an uncontained authentication.
|
||||
if (string.IsNullOrEmpty(defaultLanguageAlias)
|
||||
&& (dataContainer.Server.GetSmoObject(dbUrn) as Database).ContainmentType != ContainmentType.None)
|
||||
{
|
||||
defaultLanguageAlias = SR.DefaultLanguagePlaceholder;
|
||||
}
|
||||
}
|
||||
|
||||
string? defaultSchema = null;
|
||||
IUserPrototypeWithDefaultSchema defaultSchemaPrototype = currentUserPrototype as IUserPrototypeWithDefaultSchema;
|
||||
if (defaultSchemaPrototype != null && defaultSchemaPrototype.IsDefaultSchemaSupported)
|
||||
{
|
||||
defaultSchema = defaultSchemaPrototype.DefaultSchema;
|
||||
}
|
||||
// IUserPrototypeWithPassword userWithPwdPrototype = currentUserPrototype as IUserPrototypeWithPassword;
|
||||
// if (userWithPwdPrototype != null && !this.DataContainer.IsNewObject)
|
||||
// {
|
||||
// this.passwordTextBox.Text = FAKE_PASSWORD;
|
||||
// this.confirmPwdTextBox.Text = FAKE_PASSWORD;
|
||||
// }
|
||||
|
||||
string? loginName = null;
|
||||
IUserPrototypeWithMappedLogin mappedLoginPrototype = currentUserPrototype as IUserPrototypeWithMappedLogin;
|
||||
if (mappedLoginPrototype != null)
|
||||
{
|
||||
loginName = mappedLoginPrototype.LoginName;
|
||||
}
|
||||
|
||||
List<string> databaseRoles = new List<string>();
|
||||
foreach (string role in currentUserPrototype.DatabaseRoleNames)
|
||||
{
|
||||
if (currentUserPrototype.IsRoleMember(role))
|
||||
{
|
||||
databaseRoles.Add(role);
|
||||
}
|
||||
}
|
||||
|
||||
List<string> schemaNames = new List<string>();
|
||||
foreach (string schema in currentUserPrototype.SchemaNames)
|
||||
{
|
||||
if (currentUserPrototype.IsSchemaOwner(schema))
|
||||
{
|
||||
schemaNames.Add(schema);
|
||||
}
|
||||
}
|
||||
|
||||
// default to dbo schema, if there isn't already a default
|
||||
if (string.IsNullOrWhiteSpace(defaultSchema) && currentUserPrototype.SchemaNames.Contains("dbo"))
|
||||
{
|
||||
defaultSchema = "dbo";
|
||||
}
|
||||
|
||||
ServerConnection serverConnection = dataContainer.ServerConnection;
|
||||
UserViewInfo userViewInfo = new UserViewInfo()
|
||||
{
|
||||
ObjectInfo = new UserInfo()
|
||||
{
|
||||
Type = DatabaseUserType.WithLogin,
|
||||
Name = currentUserPrototype.Name,
|
||||
LoginName = loginName,
|
||||
Password = string.Empty,
|
||||
DefaultSchema = defaultSchema,
|
||||
OwnedSchemas = schemaNames.ToArray(),
|
||||
DatabaseRoles = databaseRoles.ToArray(),
|
||||
DefaultLanguage = defaultLanguageAlias
|
||||
},
|
||||
SupportContainedUser = false, // support for these will be added later
|
||||
SupportWindowsAuthentication = false,
|
||||
SupportAADAuthentication = false,
|
||||
SupportSQLAuthentication = true,
|
||||
Languages = new string[] { },
|
||||
Schemas = currentUserPrototype.SchemaNames.ToArray(),
|
||||
Logins = LoadSqlLogins(serverConnection),
|
||||
DatabaseRoles = currentUserPrototype.DatabaseRoleNames.ToArray()
|
||||
};
|
||||
|
||||
this.contextIdToViewState.Add(
|
||||
parameters.ContextId,
|
||||
new UserViewState(parameters.Database, currentUserPrototype.CurrentState));
|
||||
|
||||
await requestContext.SendResult(userViewInfo);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Handle request to create a user
|
||||
/// </summary>
|
||||
internal async Task HandleCreateUserRequest(CreateUserParams parameters, RequestContext<CreateUserResult> requestContext)
|
||||
{
|
||||
if (parameters.ContextId == null)
|
||||
{
|
||||
throw new ArgumentException("Invalid context ID");
|
||||
}
|
||||
|
||||
UserViewState viewState;
|
||||
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
|
||||
|
||||
if (viewState == null)
|
||||
{
|
||||
throw new ArgumentException("Invalid context ID view state");
|
||||
}
|
||||
|
||||
Tuple<bool, string> result = ConfigureUser(
|
||||
parameters.ContextId,
|
||||
parameters.User,
|
||||
ConfigAction.Create,
|
||||
RunType.RunNow,
|
||||
viewState.Database,
|
||||
viewState.OriginalUserData);
|
||||
|
||||
await requestContext.SendResult(new CreateUserResult()
|
||||
{
|
||||
User = parameters.User,
|
||||
Success = result.Item1,
|
||||
ErrorMessage = result.Item2
|
||||
});
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Handle request to update a user
|
||||
/// </summary>
|
||||
internal async Task HandleUpdateUserRequest(UpdateUserParams parameters, RequestContext<ResultStatus> requestContext)
|
||||
{
|
||||
if (parameters.ContextId == null)
|
||||
{
|
||||
throw new ArgumentException("Invalid context ID");
|
||||
}
|
||||
|
||||
UserViewState viewState;
|
||||
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
|
||||
|
||||
if (viewState == null)
|
||||
{
|
||||
throw new ArgumentException("Invalid context ID view state");
|
||||
}
|
||||
|
||||
Tuple<bool, string> result = ConfigureUser(
|
||||
parameters.ContextId,
|
||||
parameters.User,
|
||||
ConfigAction.Update,
|
||||
RunType.RunNow,
|
||||
viewState.Database,
|
||||
viewState.OriginalUserData);
|
||||
|
||||
await requestContext.SendResult(new ResultStatus()
|
||||
{
|
||||
Success = result.Item1,
|
||||
ErrorMessage = result.Item2
|
||||
});
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Handle request to delete a user
|
||||
/// </summary>
|
||||
internal async Task HandleDeleteUserRequest(DeleteUserParams parameters, RequestContext<ResultStatus> requestContext)
|
||||
{
|
||||
ConnectionInfo connInfo;
|
||||
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo);
|
||||
if (connInfo == null)
|
||||
{
|
||||
throw new ArgumentException("Invalid ConnectionUri");
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(parameters.Name) || string.IsNullOrWhiteSpace(parameters.Database))
|
||||
{
|
||||
throw new ArgumentException("Invalid null parameter");
|
||||
}
|
||||
|
||||
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
|
||||
string dbUrn = "Server/Database[@Name='" + Urn.EscapeString(parameters.Database) + "']";
|
||||
Database? parent = dataContainer.Server.GetSmoObject(new Urn(dbUrn)) as Database;
|
||||
User user = parent.Users[parameters.Name];
|
||||
dataContainer.SqlDialogSubject = user;
|
||||
|
||||
CheckForSchemaOwnerships(parent, user);
|
||||
DatabaseUtils.DoDropObject(dataContainer);
|
||||
|
||||
await requestContext.SendResult(new ResultStatus()
|
||||
{
|
||||
Success = true,
|
||||
ErrorMessage = string.Empty
|
||||
});
|
||||
}
|
||||
|
||||
internal async Task HandleDisposeUserViewRequest(DisposeUserViewRequestParams parameters, RequestContext<ResultStatus> requestContext)
|
||||
{
|
||||
this.ConnectionServiceInstance.Disconnect(new DisconnectParams(){
|
||||
OwnerUri = parameters.ContextId,
|
||||
Type = null
|
||||
});
|
||||
|
||||
if (parameters.ContextId != null)
|
||||
{
|
||||
this.contextIdToViewState.Remove(parameters.ContextId);
|
||||
}
|
||||
|
||||
await requestContext.SendResult(new ResultStatus()
|
||||
{
|
||||
Success = true,
|
||||
ErrorMessage = string.Empty
|
||||
});
|
||||
}
|
||||
|
||||
internal CDataContainer CreateUserDataContainer(
|
||||
ConnectionInfo connInfo,
|
||||
UserInfo? user,
|
||||
ConfigAction configAction,
|
||||
string databaseName)
|
||||
{
|
||||
var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer");
|
||||
var connectionInfoWithConnection = new SqlConnectionInfoWithConnection();
|
||||
connectionInfoWithConnection.ServerConnection = serverConnection;
|
||||
|
||||
string urn = (configAction == ConfigAction.Update && user != null)
|
||||
? string.Format(System.Globalization.CultureInfo.InvariantCulture,
|
||||
"Server/Database[@Name='{0}']/User[@Name='{1}']",
|
||||
Urn.EscapeString(databaseName),
|
||||
Urn.EscapeString(user.Name))
|
||||
: string.Format(System.Globalization.CultureInfo.InvariantCulture,
|
||||
"Server/Database[@Name='{0}']",
|
||||
Urn.EscapeString(databaseName));
|
||||
|
||||
ActionContext context = new ActionContext(serverConnection, "User", urn);
|
||||
DataContainerXmlGenerator containerXml = new DataContainerXmlGenerator(context);
|
||||
|
||||
if (configAction == ConfigAction.Create)
|
||||
{
|
||||
containerXml.AddProperty("itemtype", "User");
|
||||
}
|
||||
|
||||
XmlDocument xmlDoc = containerXml.GenerateXmlDocument();
|
||||
return CDataContainer.CreateDataContainer(connectionInfoWithConnection, xmlDoc);
|
||||
}
|
||||
|
||||
internal Tuple<bool, string> ConfigureUser(
|
||||
string? ownerUri,
|
||||
UserInfo? user,
|
||||
ConfigAction configAction,
|
||||
RunType runType,
|
||||
string databaseName,
|
||||
UserPrototypeData? originalData)
|
||||
{
|
||||
ConnectionInfo connInfo;
|
||||
this.ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
|
||||
if (connInfo == null)
|
||||
{
|
||||
throw new ArgumentException("Invalid connection URI '{0}'", ownerUri);
|
||||
}
|
||||
|
||||
CDataContainer dataContainer = CreateUserDataContainer(connInfo, user, configAction, databaseName);
|
||||
using (var actions = new UserActions(dataContainer, user, configAction, originalData))
|
||||
{
|
||||
var executionHandler = new ExecutonHandler(actions);
|
||||
executionHandler.RunNow(runType, this);
|
||||
if (executionHandler.ExecutionResult == ExecutionMode.Failure)
|
||||
{
|
||||
throw executionHandler.ExecutionFailureException;
|
||||
}
|
||||
}
|
||||
|
||||
return new Tuple<bool, string>(true, string.Empty);
|
||||
}
|
||||
|
||||
private void CheckForSchemaOwnerships(Database parentDb, User existingUser)
|
||||
{
|
||||
foreach (Schema sch in parentDb.Schemas)
|
||||
{
|
||||
var comparer = parentDb.GetStringComparer();
|
||||
if (comparer.Compare(sch.Owner, existingUser.Name) == 0)
|
||||
{
|
||||
throw new ApplicationException("Cannot drop user since it owns a schema");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private string[] LoadSqlLogins(ServerConnection serverConnection)
|
||||
{
|
||||
return LoadItems(serverConnection, "Server/Login");
|
||||
}
|
||||
|
||||
private string[] LoadItems(ServerConnection serverConnection, string urn)
|
||||
{
|
||||
List<string> items = new List<string>();
|
||||
Request req = new Request();
|
||||
req.Urn = urn;
|
||||
req.ResultType = ResultType.IDataReader;
|
||||
req.Fields = new string[] { "Name" };
|
||||
|
||||
Enumerator en = new Enumerator();
|
||||
using (IDataReader reader = en.Process(serverConnection, req).Data as IDataReader)
|
||||
{
|
||||
if (reader != null)
|
||||
{
|
||||
string name;
|
||||
while (reader.Read())
|
||||
{
|
||||
// Get the permission name
|
||||
name = reader.GetString(0);
|
||||
items.Add(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
return items.ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
internal class UserActions : ManagementActionBase
|
||||
{
|
||||
#region Variables
|
||||
//private UserPrototypeData userData;
|
||||
private UserPrototype userPrototype;
|
||||
private UserInfo user;
|
||||
private UserInfo? user;
|
||||
private ConfigAction configAction;
|
||||
#endregion
|
||||
|
||||
@@ -27,13 +449,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
|
||||
public UserActions(
|
||||
CDataContainer context,
|
||||
UserInfo? user,
|
||||
ConfigAction configAction)
|
||||
ConfigAction configAction,
|
||||
UserPrototypeData? originalData)
|
||||
{
|
||||
this.DataContainer = context;
|
||||
this.user = user;
|
||||
this.configAction = configAction;
|
||||
|
||||
this.userPrototype = InitUserNew(context, user);
|
||||
this.userPrototype = InitUserPrototype(context, user, originalData);
|
||||
}
|
||||
|
||||
// /// <summary>
|
||||
@@ -65,14 +488,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
|
||||
}
|
||||
}
|
||||
|
||||
private UserPrototype InitUserNew(CDataContainer dataContainer, UserInfo user)
|
||||
private UserPrototype InitUserPrototype(CDataContainer dataContainer, UserInfo user, UserPrototypeData? originalData)
|
||||
{
|
||||
ExhaustiveUserTypes currentUserType;
|
||||
UserPrototypeFactory userPrototypeFactory = UserPrototypeFactory.GetInstance(dataContainer, user);
|
||||
UserPrototypeFactory userPrototypeFactory = UserPrototypeFactory.GetInstance(dataContainer, user, originalData);
|
||||
|
||||
if (dataContainer.IsNewObject)
|
||||
{
|
||||
if (IsParentDatabaseContained(dataContainer.ParentUrn, dataContainer))
|
||||
if (dataContainer.Server != null && IsParentDatabaseContained(dataContainer.ParentUrn, dataContainer.Server))
|
||||
{
|
||||
currentUserType = ExhaustiveUserTypes.SqlUserWithPassword;
|
||||
}
|
||||
@@ -91,8 +514,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
|
||||
return currentUserPrototype;
|
||||
}
|
||||
|
||||
private ExhaustiveUserTypes GetCurrentUserTypeForExistingUser(User user)
|
||||
private ExhaustiveUserTypes GetCurrentUserTypeForExistingUser(User? user)
|
||||
{
|
||||
if (user == null)
|
||||
{
|
||||
return ExhaustiveUserTypes.Unknown;
|
||||
}
|
||||
|
||||
switch (user.UserType)
|
||||
{
|
||||
case UserType.SqlUser:
|
||||
@@ -124,10 +552,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
|
||||
}
|
||||
}
|
||||
|
||||
private bool IsParentDatabaseContained(Urn parentDbUrn, CDataContainer dataContainer)
|
||||
private static bool IsParentDatabaseContained(Urn parentDbUrn, Server server)
|
||||
{
|
||||
string parentDbName = parentDbUrn.GetNameForType("Database");
|
||||
Database parentDatabase = dataContainer.Server.Databases[parentDbName];
|
||||
Database parentDatabase = server.Databases[parentDbName];
|
||||
|
||||
if (parentDatabase.IsSupportedProperty("ContainmentType")
|
||||
&& parentDatabase.ContainmentType == ContainmentType.Partial)
|
||||
|
||||
Reference in New Issue
Block a user