Simplify Object Management APIs (#2015)

* unify requests-wip

* wip

* unify api

* fix test

* add credential handler

* fix credential handler issue.

* generic type update

* fix scripting for user
This commit is contained in:
Alan Ren
2023-04-19 15:43:01 -07:00
committed by GitHub
parent 98ad0197e4
commit e314f839d8
57 changed files with 1802 additions and 2234 deletions

View File

@@ -34,7 +34,6 @@ using Microsoft.SqlTools.ServiceLayer.Profiler;
using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.SchemaCompare; using Microsoft.SqlTools.ServiceLayer.SchemaCompare;
using Microsoft.SqlTools.ServiceLayer.Scripting; using Microsoft.SqlTools.ServiceLayer.Scripting;
using Microsoft.SqlTools.ServiceLayer.Security;
using Microsoft.SqlTools.ServiceLayer.ServerConfigurations; using Microsoft.SqlTools.ServiceLayer.ServerConfigurations;
using Microsoft.SqlTools.ServiceLayer.SqlAssessment; using Microsoft.SqlTools.ServiceLayer.SqlAssessment;
using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.SqlContext;
@@ -130,9 +129,6 @@ namespace Microsoft.SqlTools.ServiceLayer
ProfilerService.Instance.InitializeService(serviceHost); ProfilerService.Instance.InitializeService(serviceHost);
serviceProvider.RegisterSingleService(ProfilerService.Instance); serviceProvider.RegisterSingleService(ProfilerService.Instance);
SecurityService.Instance.InitializeService(serviceHost);
serviceProvider.RegisterSingleService(SecurityService.Instance);
DacFxService.Instance.InitializeService(serviceHost, commandOptions); DacFxService.Instance.InitializeService(serviceHost, commandOptions);
serviceProvider.RegisterSingleService(DacFxService.Instance); serviceProvider.RegisterSingleService(DacFxService.Instance);

View File

@@ -0,0 +1,23 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts
{
public class DisposeViewRequestParams : GeneralRequestDetails
{
public string ContextId { get; set; }
}
public class DisposeViewRequestResponse { }
public class DisposeViewRequest
{
public static readonly RequestType<DisposeViewRequestParams, DisposeViewRequestResponse> Type = RequestType<DisposeViewRequestParams, DisposeViewRequestResponse>.Create("objectManagement/disposeView");
}
}

View File

@@ -11,6 +11,10 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts
{ {
public class DropRequestParams : GeneralRequestDetails public class DropRequestParams : GeneralRequestDetails
{ {
/// <summary>
/// The object type.
/// </summary>
public SqlObjectType ObjectType { get; set; }
/// <summary> /// <summary>
/// SFC (SMO) URN identifying the object /// SFC (SMO) URN identifying the object
/// </summary> /// </summary>
@@ -25,8 +29,10 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts
public bool ThrowIfNotExist { get; set; } = false; public bool ThrowIfNotExist { get; set; } = false;
} }
public class DropRequestResponse { }
public class DropRequest public class DropRequest
{ {
public static readonly RequestType<DropRequestParams, bool> Type = RequestType<DropRequestParams, bool>.Create("objectManagement/drop"); public static readonly RequestType<DropRequestParams, DropRequestResponse> Type = RequestType<DropRequestParams, DropRequestResponse>.Create("objectManagement/drop");
} }
} }

View File

@@ -0,0 +1,48 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts
{
public class InitializeViewRequestParams : GeneralRequestDetails
{
/// <summary>
/// The connection uri.
/// </summary>
public string ConnectionUri { get; set; }
/// <summary>
/// The target database name.
/// </summary>
public string Database { get; set; }
/// <summary>
/// The object type.
/// </summary>
public SqlObjectType ObjectType { get; set; }
/// <summary>
/// Whether the view is for a new object.
/// </summary>
public bool IsNewObject { get; set; }
/// <summary>
/// The object view context id.
/// </summary>
public string ContextId { get; set; }
/// <summary>
/// Urn of the parent object.
/// </summary>
public string ParentUrn { get; set; }
/// <summary>
/// Urn of the object. Only set when the view is for an existing object.
/// </summary>
public string ObjectUrn { get; set; }
}
public class InitializeViewRequest
{
public static readonly RequestType<InitializeViewRequestParams, SqlObjectViewInfo> Type = RequestType<InitializeViewRequestParams, SqlObjectViewInfo>.Create("objectManagement/initializeView");
}
}

View File

@@ -11,6 +11,10 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts
{ {
public class RenameRequestParams : GeneralRequestDetails public class RenameRequestParams : GeneralRequestDetails
{ {
/// <summary>
/// The object type.
/// </summary>
public SqlObjectType ObjectType { get; set; }
/// <summary> /// <summary>
/// SFC (SMO) URN identifying the object /// SFC (SMO) URN identifying the object
/// </summary> /// </summary>
@@ -24,8 +28,11 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts
/// </summary> /// </summary>
public string ConnectionUri { get; set; } public string ConnectionUri { get; set; }
} }
public class RenameRequestResponse { }
public class RenameRequest public class RenameRequest
{ {
public static readonly RequestType<RenameRequestParams, bool> Type = RequestType<RenameRequestParams, bool>.Create("objectManagement/rename"); public static readonly RequestType<RenameRequestParams, RenameRequestResponse> Type = RequestType<RenameRequestParams, RenameRequestResponse>.Create("objectManagement/rename");
} }
} }

View File

@@ -0,0 +1,31 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.Utility;
using Newtonsoft.Json.Linq;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts
{
public class SaveObjectRequestParams : GeneralRequestDetails
{
/// <summary>
/// The context id.
/// </summary>
public string ContextId { get; set; }
/// <summary>
/// The object information.
/// </summary>
public JToken Object { get; set; }
}
public class SaveObjectRequestResponse { }
public class SaveObjectRequest
{
public static readonly RequestType<SaveObjectRequestParams, SaveObjectRequestResponse> Type = RequestType<SaveObjectRequestParams, SaveObjectRequestResponse>.Create("objectManagement/save");
}
}

View File

@@ -0,0 +1,29 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.Utility;
using Newtonsoft.Json.Linq;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts
{
public class ScriptObjectRequestParams : GeneralRequestDetails
{
/// <summary>
/// The context id.
/// </summary>
public string ContextId { get; set; }
/// <summary>
/// The object information.
/// </summary>
public JToken Object { get; set; }
}
public class ScriptObjectRequest
{
public static readonly RequestType<ScriptObjectRequestParams, string> Type = RequestType<ScriptObjectRequestParams, string>.Create("objectManagement/script");
}
}

View File

@@ -6,15 +6,11 @@
#nullable disable #nullable disable
using System; using System;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts; using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility; using System.Collections.Generic;
using Microsoft.SqlTools.Utility; using System.Collections.Concurrent;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
@@ -23,12 +19,21 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
/// </summary> /// </summary>
public class ObjectManagementService public class ObjectManagementService
{ {
private const string ObjectManagementServiceApplicationName = "azdata-object-management"; public const string ApplicationName = "azdata-object-management";
private static Lazy<ObjectManagementService> objectManagementServiceInstance = new Lazy<ObjectManagementService>(() => new ObjectManagementService()); private static Lazy<ObjectManagementService> objectManagementServiceInstance = new Lazy<ObjectManagementService>(() => new ObjectManagementService());
public static ObjectManagementService Instance => objectManagementServiceInstance.Value; public static ObjectManagementService Instance => objectManagementServiceInstance.Value;
public static ConnectionService connectionService; public static ConnectionService connectionService;
private IProtocolEndpoint serviceHost; private IProtocolEndpoint serviceHost;
public ObjectManagementService() { } private List<IObjectTypeHandler> objectTypeHandlers = new List<IObjectTypeHandler>();
private ConcurrentDictionary<string, SqlObjectViewContext> contextMap = new ConcurrentDictionary<string, SqlObjectViewContext>();
public ObjectManagementService()
{
this.objectTypeHandlers.Add(new CommonObjectTypeHandler(ConnectionService.Instance));
this.objectTypeHandlers.Add(new LoginHandler(ConnectionService.Instance));
this.objectTypeHandlers.Add(new UserHandler(ConnectionService.Instance));
this.objectTypeHandlers.Add(new CredentialHandler(ConnectionService.Instance));
}
/// <summary> /// <summary>
/// Internal for testing purposes only /// Internal for testing purposes only
@@ -51,84 +56,81 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
this.serviceHost = serviceHost; this.serviceHost = serviceHost;
this.serviceHost.SetRequestHandler(RenameRequest.Type, HandleRenameRequest, true); this.serviceHost.SetRequestHandler(RenameRequest.Type, HandleRenameRequest, true);
this.serviceHost.SetRequestHandler(DropRequest.Type, HandleDropRequest, true); this.serviceHost.SetRequestHandler(DropRequest.Type, HandleDropRequest, true);
this.serviceHost.SetRequestHandler(InitializeViewRequest.Type, HandleInitializeViewRequest, true);
this.serviceHost.SetRequestHandler(SaveObjectRequest.Type, HandleSaveObjectRequest, true);
this.serviceHost.SetRequestHandler(ScriptObjectRequest.Type, HandleScriptObjectRequest, true);
this.serviceHost.SetRequestHandler(DisposeViewRequest.Type, HandleDisposeViewRequest, true);
} }
/// <summary> internal async Task HandleRenameRequest(RenameRequestParams requestParams, RequestContext<RenameRequestResponse> requestContext)
/// Method to handle the renaming operation
/// </summary>
/// <param name="requestParams">parameters which are needed to execute renaming operation</param>
/// <param name="requestContext">Request Context</param>
/// <returns></returns>
internal async Task HandleRenameRequest(RenameRequestParams requestParams, RequestContext<bool> requestContext)
{ {
Logger.Verbose("Handle Request in HandleRenameRequest()"); var handler = this.GetObjectTypeHandler(requestParams.ObjectType);
ExecuteActionOnObject(requestParams.ConnectionUri, requestParams.ObjectUrn, (dbObject) => await handler.Rename(requestParams.ConnectionUri, requestParams.ObjectUrn, requestParams.NewName);
{ await requestContext.SendResult(new RenameRequestResponse());
var renamable = dbObject as IRenamable;
if (renamable != null)
{
renamable.Rename(requestParams.NewName);
}
else
{
throw new Exception(SR.ObjectNotRenamable(requestParams.ObjectUrn));
}
});
await requestContext.SendResult(true);
} }
/// <summary> internal async Task HandleDropRequest(DropRequestParams requestParams, RequestContext<DropRequestResponse> requestContext)
/// Method to handle the delete object request
/// </summary>
/// <param name="requestParams">parameters which are needed to execute deletion operation</param>
/// <param name="requestContext">Request Context</param>
/// <returns></returns>
internal async Task HandleDropRequest(DropRequestParams requestParams, RequestContext<bool> requestContext)
{ {
Logger.Verbose("Handle Request in HandleDeleteRequest()"); var handler = this.GetObjectTypeHandler(requestParams.ObjectType);
ConnectionInfo connectionInfo = this.GetConnectionInfo(requestParams.ConnectionUri); await handler.Drop(requestParams.ConnectionUri, requestParams.ObjectUrn, requestParams.ThrowIfNotExist);
using (CDataContainer dataContainer = CDataContainer.CreateDataContainer(connectionInfo, databaseExists: true)) await requestContext.SendResult(new DropRequestResponse());
{
try
{
dataContainer.SqlDialogSubject = dataContainer.Server?.GetSmoObject(requestParams.ObjectUrn);
DatabaseUtils.DoDropObject(dataContainer);
}
catch (FailedOperationException ex)
{
if (!(ex.InnerException is MissingObjectException) || (ex.InnerException is MissingObjectException && requestParams.ThrowIfNotExist))
{
throw;
}
}
}
await requestContext.SendResult(true);
} }
private ConnectionInfo GetConnectionInfo(string connectionUri) internal async Task HandleInitializeViewRequest(InitializeViewRequestParams requestParams, RequestContext<SqlObjectViewInfo> requestContext)
{ {
ConnectionInfo connInfo; var handler = this.GetObjectTypeHandler(requestParams.ObjectType);
if (ConnectionServiceInstance.TryFindConnection(connectionUri, out connInfo)) var result = await handler.InitializeObjectView(requestParams);
{ contextMap[requestParams.ContextId] = result.Context;
return connInfo; await requestContext.SendResult(result.ViewInfo);
}
else
{
Logger.Error($"The connection with URI '{connectionUri}' could not be found.");
throw new Exception(SR.ErrorConnectionNotFound);
}
} }
private void ExecuteActionOnObject(string connectionUri, string objectUrn, Action<SqlSmoObject> action) internal async Task HandleSaveObjectRequest(SaveObjectRequestParams requestParams, RequestContext<SaveObjectRequestResponse> requestContext)
{ {
ConnectionInfo connInfo = this.GetConnectionInfo(connectionUri); var context = this.GetContext(requestParams.ContextId);
ServerConnection serverConnection = ConnectionService.OpenServerConnection(connInfo, ObjectManagementServiceApplicationName); var handler = this.GetObjectTypeHandler(context.Parameters.ObjectType);
using (serverConnection.SqlConnectionObject) var obj = requestParams.Object.ToObject(handler.GetObjectType());
{ await handler.Save(context, obj as SqlObject);
Server server = new Server(serverConnection); await requestContext.SendResult(new SaveObjectRequestResponse());
SqlSmoObject dbObject = server.GetSmoObject(new Urn(objectUrn));
action(dbObject);
} }
internal async Task HandleScriptObjectRequest(ScriptObjectRequestParams requestParams, RequestContext<string> requestContext)
{
var context = this.GetContext(requestParams.ContextId);
var handler = this.GetObjectTypeHandler(context.Parameters.ObjectType);
var obj = requestParams.Object.ToObject(handler.GetObjectType());
var script = await handler.Script(context, obj as SqlObject);
await requestContext.SendResult(script);
}
internal async Task HandleDisposeViewRequest(DisposeViewRequestParams requestParams, RequestContext<DisposeViewRequestResponse> requestContext)
{
SqlObjectViewContext context;
if (contextMap.Remove(requestParams.ContextId, out context))
{
context.Dispose();
}
await requestContext.SendResult(new DisposeViewRequestResponse());
}
private IObjectTypeHandler GetObjectTypeHandler(SqlObjectType objectType)
{
foreach (var handler in objectTypeHandlers)
{
if (handler.CanHandleType(objectType))
{
return handler;
}
}
throw new NotSupportedException(objectType.ToString());
}
private SqlObjectViewContext GetContext(string contextId)
{
if (contextMap.TryGetValue(contextId, out SqlObjectViewContext context))
{
return context;
}
throw new ArgumentException($"Context '{contextId}' not found");
} }
} }
} }

View File

@@ -0,0 +1,115 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Threading.Tasks;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
public interface IObjectTypeHandler
{
bool CanHandleType(SqlObjectType objectType);
Task<InitializeViewResult> InitializeObjectView(Contracts.InitializeViewRequestParams requestParams);
Task Save(SqlObjectViewContext context, SqlObject obj);
Task<string> Script(SqlObjectViewContext context, SqlObject obj);
Type GetObjectType();
Task Rename(string connectionUri, string objectUrn, string newName);
Task Drop(string connectionUri, string objectUrn, bool throwIfNotExist);
}
public abstract class ObjectTypeHandler<ObjectType, ContextType> : IObjectTypeHandler
where ObjectType : SqlObject
where ContextType : SqlObjectViewContext
{
protected ConnectionService ConnectionService { get; }
public ObjectTypeHandler(ConnectionService connectionService)
{
this.ConnectionService = connectionService;
}
public abstract bool CanHandleType(SqlObjectType objectType);
public abstract Task<InitializeViewResult> InitializeObjectView(Contracts.InitializeViewRequestParams requestParams);
public abstract Task Save(ContextType context, ObjectType obj);
public abstract Task<string> Script(ContextType context, ObjectType obj);
public Task Save(SqlObjectViewContext context, SqlObject obj)
{
return this.Save((ContextType)context, (ObjectType)obj);
}
public Task<string> Script(SqlObjectViewContext context, SqlObject obj)
{
return this.Script((ContextType)context, (ObjectType)obj);
}
public Type GetObjectType()
{
return typeof(ObjectType);
}
public virtual Task Rename(string connectionUri, string objectUrn, string newName)
{
ConnectionInfo connInfo = this.GetConnectionInfo(connectionUri);
ServerConnection serverConnection = ConnectionService.OpenServerConnection(connInfo, ObjectManagementService.ApplicationName);
using (serverConnection.SqlConnectionObject)
{
Server server = new Server(serverConnection);
SqlSmoObject dbObject = server.GetSmoObject(new Urn(objectUrn));
var renamable = dbObject as IRenamable;
if (renamable != null)
{
renamable.Rename(newName);
}
else
{
throw new Exception(SR.ObjectNotRenamable(objectUrn));
}
}
return Task.CompletedTask;
}
public virtual Task Drop(string connectionUri, string objectUrn, bool throwIfNotExist)
{
ConnectionInfo connectionInfo = this.GetConnectionInfo(connectionUri);
using (CDataContainer dataContainer = CDataContainer.CreateDataContainer(connectionInfo, databaseExists: true))
{
try
{
dataContainer.SqlDialogSubject = dataContainer.Server?.GetSmoObject(objectUrn);
DatabaseUtils.DoDropObject(dataContainer);
}
catch (FailedOperationException ex)
{
if (!(ex.InnerException is MissingObjectException) || (ex.InnerException is MissingObjectException && throwIfNotExist))
{
throw;
}
}
}
return Task.CompletedTask;
}
protected ConnectionInfo GetConnectionInfo(string connectionUri)
{
ConnectionInfo connInfo;
if (this.ConnectionService.TryFindConnection(connectionUri, out connInfo))
{
return connInfo;
}
else
{
Logger.Error($"The connection with URI '{connectionUri}' could not be found.");
throw new Exception(SR.ErrorConnectionNotFound);
}
}
}
}

View File

@@ -0,0 +1,54 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
public class CommonObjectType : SqlObject { }
public class CommonObjectTypeViewContext : SqlObjectViewContext
{
public CommonObjectTypeViewContext(InitializeViewRequestParams parameters) : base(parameters) { }
public override void Dispose() { }
}
/// <summary>
/// A handler for the object types that only has rename/drop support
/// </summary>
public class CommonObjectTypeHandler : ObjectTypeHandler<CommonObjectType, CommonObjectTypeViewContext>
{
// The message is only used in developing time, no need to be localized.
private const string NotSupportedException = "This operation is not supported for this object type";
public CommonObjectTypeHandler(ConnectionService connectionService) : base(connectionService) { }
public override bool CanHandleType(SqlObjectType objectType)
{
return objectType == SqlObjectType.Column ||
objectType == SqlObjectType.Table ||
objectType == SqlObjectType.View;
}
public override Task Save(CommonObjectTypeViewContext context, CommonObjectType obj)
{
throw new NotSupportedException(NotSupportedException);
}
public override Task<InitializeViewResult> InitializeObjectView(Contracts.InitializeViewRequestParams requestParams)
{
throw new NotSupportedException(NotSupportedException);
}
public override Task<string> Script(CommonObjectTypeViewContext context, CommonObjectType obj)
{
throw new NotSupportedException(NotSupportedException);
}
}
}

View File

@@ -0,0 +1,88 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System;
using System.Threading.Tasks;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Management;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
/// <summary>
/// Credential object type handler
/// </summary>
public class CredentialHandler : ObjectTypeHandler<CredentialInfo, CredentialViewContext>
{
public CredentialHandler(ConnectionService connectionService) : base(connectionService)
{
}
public override bool CanHandleType(SqlObjectType objectType)
{
return objectType == SqlObjectType.Credential;
}
public override Task<InitializeViewResult> InitializeObjectView(Contracts.InitializeViewRequestParams parameters)
{
// TODO: this is partially implemented only.
ConnectionInfo connInfo;
this.ConnectionService.TryFindConnection(parameters.ConnectionUri, out connInfo);
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
var credentialInfo = new CredentialInfo();
if (!parameters.IsNewObject)
{
var credential = dataContainer.Server.GetSmoObject(parameters.ObjectUrn) as Credential;
credentialInfo.Name = credential.Name;
credentialInfo.Identity = credential.Identity;
credentialInfo.Id = credential.ID;
credentialInfo.DateLastModified = credential.DateLastModified;
credentialInfo.CreateDate = credential.CreateDate;
credentialInfo.ProviderName = credential.ProviderName;
}
var viewInfo = new CredentialViewInfo() { ObjectInfo = credentialInfo };
var context = new CredentialViewContext(parameters);
var result = new InitializeViewResult { ViewInfo = viewInfo, Context = context };
return Task.FromResult(result);
}
public override async Task Save(CredentialViewContext context, CredentialInfo obj)
{
await ConfigureCredential(context.Parameters.ConnectionUri, obj, ConfigAction.Update, RunType.RunNow);
}
public override Task<string> Script(CredentialViewContext context, CredentialInfo obj)
{
throw new NotImplementedException();
}
private Task<Tuple<bool, string>> ConfigureCredential(string ownerUri, CredentialInfo credential, ConfigAction configAction, RunType runType)
{
return Task<Tuple<bool, string>>.Run(() =>
{
try
{
ConnectionInfo connInfo;
this.ConnectionService.TryFindConnection(ownerUri, out connInfo);
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
using (CredentialActions actions = new CredentialActions(dataContainer, credential, configAction))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
}
return new Tuple<bool, string>(true, string.Empty);
}
catch (Exception ex)
{
return new Tuple<bool, string>(false, ex.ToString());
}
});
}
}
}

View File

@@ -7,16 +7,15 @@
using System; using System;
namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// a class for storing various credential properties /// a class for storing various credential properties
/// </summary> /// </summary>
public class CredentialInfo public class CredentialInfo : SqlObject
{ {
public int Id { get; set; } public int Id { get; set; }
public string Identity { get; set; } public string Identity { get; set; }
public string Name { get; set; }
public DateTime DateLastModified { get; set; } public DateTime DateLastModified { get; set; }
public DateTime CreateDate { get; set; } public DateTime CreateDate { get; set; }
public string ProviderName { get; set; } public string ProviderName { get; set; }

View File

@@ -0,0 +1,16 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
public class CredentialViewContext : SqlObjectViewContext
{
public CredentialViewContext(InitializeViewRequestParams parameters) : base(parameters) { }
public override void Dispose() { }
}
}

View File

@@ -0,0 +1,14 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
/// <summary>
/// The information required to render the credential view.
/// </summary>
public class CredentialViewInfo : SqlObjectViewInfo
{
}
}

View File

@@ -3,132 +3,178 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
// //
#nullable disable
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Data;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts; using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
internal class LoginServiceHandlerImpl
{
private class ViewState
{
public bool IsNewObject { get; set; }
public string ConnectionUri { get; set; }
public ViewState(bool isNewObject, string connectionUri)
{
this.IsNewObject = isNewObject;
this.ConnectionUri = connectionUri;
}
}
private Dictionary<string, ViewState> contextIdToViewState = new Dictionary<string, ViewState>();
private ConnectionService? connectionService;
/// <summary> /// <summary>
/// Internal for testing purposes only /// Login object type handler
/// </summary> /// </summary>
internal ConnectionService ConnectionServiceInstance public class LoginHandler : ObjectTypeHandler<LoginInfo, LoginViewContext>
{ {
get public LoginHandler(ConnectionService connectionService) : base(connectionService) { }
public override bool CanHandleType(SqlObjectType objectType)
{ {
connectionService ??= ConnectionService.Instance; return objectType == SqlObjectType.ServerLevelLogin;
return connectionService;
} }
set public override Task<InitializeViewResult> InitializeObjectView(InitializeViewRequestParams parameters)
{ {
connectionService = value;
}
}
/// <summary>
/// Handle request to create a login
/// </summary>
internal async Task HandleCreateLoginRequest(CreateLoginParams parameters, RequestContext<object> requestContext)
{
DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow);
await requestContext.SendResult(new object());
}
private string DoHandleCreateLoginRequest(
string contextId, LoginInfo login, RunType runType)
{
ViewState? viewState;
this.contextIdToViewState.TryGetValue(contextId, out viewState);
ConnectionInfo connInfo; ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo); this.ConnectionService.TryFindConnection(parameters.ConnectionUri, out connInfo);
if (connInfo == null) if (connInfo == null)
{ {
throw new ArgumentException("Invalid ConnectionUri"); throw new ArgumentException("Invalid ConnectionUri");
} }
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true); CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
LoginPrototype prototype = new LoginPrototype(dataContainer.Server, login); LoginViewInfo loginViewInfo = new LoginViewInfo();
if (prototype.LoginType == SqlServer.Management.Smo.LoginType.SqlLogin) // TODO cache databases and languages
string[] databases = new string[dataContainer.Server.Databases.Count];
for (int i = 0; i < dataContainer.Server.Databases.Count; i++)
{ {
// check that there is a password databases[i] = dataContainer.Server.Databases[i].Name;
// this check is made if policy enforcement is off
// with policy turned on we do not display this message, instead we let server
// return the error associated with null password (coming from policy) - see bug 124377
if (prototype.SqlPassword.Length == 0 && prototype.EnforcePolicy == false)
{
// raise error here
} }
// check that password and confirm password controls' text matches var languageOptions = LanguageUtils.GetDefaultLanguageOptions(dataContainer);
if (0 != string.Compare(prototype.SqlPassword, prototype.SqlPasswordConfirm, StringComparison.Ordinal)) var languageOptionsList = languageOptions.Select(LanguageUtils.FormatLanguageDisplay).ToList();
if (parameters.IsNewObject)
{ {
// raise error here languageOptionsList.Insert(0, SR.DefaultLanguagePlaceholder);
}
string[] languages = languageOptionsList.ToArray();
LoginPrototype prototype = parameters.IsNewObject
? new LoginPrototype(dataContainer.Server)
: new LoginPrototype(dataContainer.Server, dataContainer.Server.GetSmoObject(parameters.ObjectUrn) as Login);
List<string> loginServerRoles = new List<string>();
foreach (string role in prototype.ServerRoles.ServerRoleNames)
{
if (prototype.ServerRoles.IsMember(role))
{
loginServerRoles.Add(role);
} }
} }
// TODO move this to LoginData LoginInfo loginInfo = new LoginInfo()
// TODO support role assignment for Azure
foreach (string role in login.ServerRoles ?? Enumerable.Empty<string>())
{ {
prototype.ServerRoles.SetMember(role, true); Name = prototype.LoginName,
Password = prototype.SqlPassword,
OldPassword = prototype.OldPassword,
AuthenticationType = LoginTypeToAuthenticationType(prototype.LoginType),
EnforcePasswordExpiration = prototype.EnforceExpiration,
EnforcePasswordPolicy = prototype.EnforcePolicy,
MustChangePassword = prototype.MustChange,
DefaultDatabase = prototype.DefaultDatabase,
DefaultLanguage = parameters.IsNewObject ? SR.DefaultLanguagePlaceholder : LanguageUtils.FormatLanguageDisplay(languageOptions.FirstOrDefault(o => o?.Language.Name == prototype.DefaultLanguage || o?.Language.Alias == prototype.DefaultLanguage, null)),
ServerRoles = loginServerRoles.ToArray(),
ConnectPermission = prototype.WindowsGrantAccess,
IsEnabled = !prototype.IsDisabled,
IsLockedOut = prototype.IsLockedOut,
UserMapping = new ServerLoginDatabaseUserMapping[0]
};
var viewInfo = new LoginViewInfo()
{
ObjectInfo = loginInfo,
SupportWindowsAuthentication = prototype.WindowsAuthSupported,
SupportAADAuthentication = prototype.AADAuthSupported,
SupportSQLAuthentication = true, // SQL Auth support for login, not necessarily mean SQL Auth support for CONNECT etc.
CanEditLockedOutState = !parameters.IsNewObject && prototype.IsLockedOut,
Databases = databases,
Languages = languages,
ServerRoles = prototype.ServerRoles.ServerRoleNames,
SupportAdvancedPasswordOptions = dataContainer.Server.DatabaseEngineType == DatabaseEngineType.Standalone || dataContainer.Server.DatabaseEngineEdition == DatabaseEngineEdition.SqlDataWarehouse,
SupportAdvancedOptions = dataContainer.Server.DatabaseEngineType == DatabaseEngineType.Standalone || dataContainer.Server.DatabaseEngineEdition == DatabaseEngineEdition.SqlManagedInstance
};
var context = new LoginViewContext(parameters);
return Task.FromResult(new InitializeViewResult()
{
ViewInfo = viewInfo,
Context = context
});
} }
return ConfigureLogin( public override Task Save(LoginViewContext context, LoginInfo obj)
dataContainer, {
ConfigAction.Create, if (context.Parameters.IsNewObject)
runType, {
prototype); this.DoHandleCreateLoginRequest(context, obj, RunType.RunNow);
}
else
{
this.DoHandleUpdateLoginRequest(context, obj, RunType.RunNow);
}
return Task.CompletedTask;
} }
internal async Task HandleUpdateLoginRequest(UpdateLoginParams parameters, RequestContext<object> requestContext) public override Task<string> Script(LoginViewContext context, LoginInfo obj)
{ {
DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.RunNow); string script;
if (context.Parameters.IsNewObject)
await requestContext.SendResult(new object()); {
script = this.DoHandleCreateLoginRequest(context, obj, RunType.ScriptToWindow);
}
else
{
script = this.DoHandleUpdateLoginRequest(context, obj, RunType.ScriptToWindow);
}
return Task.FromResult(script);
} }
private string DoHandleUpdateLoginRequest( private LoginAuthenticationType LoginTypeToAuthenticationType(LoginType loginType)
string contextId, LoginInfo login, RunType runType)
{ {
ViewState? viewState; switch (loginType)
this.contextIdToViewState.TryGetValue(contextId, out viewState); {
case LoginType.WindowsUser:
case LoginType.WindowsGroup:
return LoginAuthenticationType.Windows;
case LoginType.SqlLogin:
return LoginAuthenticationType.Sql;
case LoginType.ExternalUser:
case LoginType.ExternalGroup:
return LoginAuthenticationType.AAD;
default:
return LoginAuthenticationType.Others;
}
}
private string ConfigureLogin(CDataContainer dataContainer, ConfigAction configAction, RunType runType, LoginPrototype prototype)
{
string sqlScript = string.Empty;
using (var actions = new LoginActions(dataContainer, configAction, prototype))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
if (executionHandler.ExecutionResult == ExecutionMode.Failure)
{
throw executionHandler.ExecutionFailureException;
}
if (runType == RunType.ScriptToWindow)
{
sqlScript = executionHandler.ScriptTextFromLastRun;
}
}
return sqlScript;
}
private string DoHandleUpdateLoginRequest(LoginViewContext context, LoginInfo login, RunType runType)
{
ConnectionInfo connInfo; ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(viewState?.ConnectionUri, out connInfo); this.ConnectionService.TryFindConnection(context.Parameters.ConnectionUri, out connInfo);
if (connInfo == null) if (connInfo == null)
{ {
throw new ArgumentException("Invalid ConnectionUri"); throw new ArgumentException("Invalid ConnectionUri");
@@ -191,182 +237,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
prototype); prototype);
} }
/// <summary> private string DoHandleCreateLoginRequest(LoginViewContext context, LoginInfo login, RunType runType)
/// Handle request to script a user
/// </summary>
internal async Task HandleScriptLoginRequest(ScriptLoginParams parameters, RequestContext<string> requestContext)
{ {
if (parameters.ContextId == null)
{
throw new ArgumentException("Invalid context ID");
}
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
{
throw new ArgumentException("Invalid context ID view state");
}
string sqlScript = (viewState.IsNewObject)
? DoHandleCreateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow)
: DoHandleUpdateLoginRequest(parameters.ContextId, parameters.Login, RunType.ScriptToWindow);
await requestContext.SendResult(sqlScript);
}
internal async Task HandleInitializeLoginViewRequest(InitializeLoginViewRequestParams parameters, RequestContext<LoginViewInfo> requestContext)
{
this.contextIdToViewState.Add(
parameters.ContextId,
new ViewState(parameters.IsNewObject, parameters.ConnectionUri));
ConnectionInfo connInfo; ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out connInfo); this.ConnectionService.TryFindConnection(context.Parameters.ConnectionUri, out connInfo);
if (connInfo == null) if (connInfo == null)
{ {
throw new ArgumentException("Invalid ConnectionUri"); throw new ArgumentException("Invalid ConnectionUri");
} }
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true); CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
LoginViewInfo loginViewInfo = new LoginViewInfo(); LoginPrototype prototype = new LoginPrototype(dataContainer.Server, login);
// TODO cache databases and languages if (prototype.LoginType == SqlServer.Management.Smo.LoginType.SqlLogin)
string[] databases = new string[dataContainer.Server.Databases.Count];
for (int i = 0; i < dataContainer.Server.Databases.Count; i++)
{ {
databases[i] = dataContainer.Server.Databases[i].Name; // check that there is a password
// this check is made if policy enforcement is off
// with policy turned on we do not display this message, instead we let server
// return the error associated with null password (coming from policy) - see bug 124377
if (prototype.SqlPassword.Length == 0 && prototype.EnforcePolicy == false)
{
// raise error here
} }
var languageOptions = LanguageUtils.GetDefaultLanguageOptions(dataContainer); // check that password and confirm password controls' text matches
var languageOptionsList = languageOptions.Select(LanguageUtils.FormatLanguageDisplay).ToList(); if (0 != string.Compare(prototype.SqlPassword, prototype.SqlPasswordConfirm, StringComparison.Ordinal))
if (parameters.IsNewObject)
{ {
languageOptionsList.Insert(0, SR.DefaultLanguagePlaceholder); // raise error here
}
string[] languages = languageOptionsList.ToArray();
LoginPrototype prototype = parameters.IsNewObject
? new LoginPrototype(dataContainer.Server)
: new LoginPrototype(dataContainer.Server, dataContainer.Server.Logins[parameters.Name]);
List<string> loginServerRoles = new List<string>();
foreach (string role in prototype.ServerRoles.ServerRoleNames)
{
if (prototype.ServerRoles.IsMember(role))
{
loginServerRoles.Add(role);
} }
} }
LoginInfo loginInfo = new LoginInfo() // TODO move this to LoginData
// TODO support role assignment for Azure
foreach (string role in login.ServerRoles ?? Enumerable.Empty<string>())
{ {
Name = prototype.LoginName, prototype.ServerRoles.SetMember(role, true);
Password = prototype.SqlPassword,
OldPassword = prototype.OldPassword,
AuthenticationType = LoginTypeToAuthenticationType(prototype.LoginType),
EnforcePasswordExpiration = prototype.EnforceExpiration,
EnforcePasswordPolicy = prototype.EnforcePolicy,
MustChangePassword = prototype.MustChange,
DefaultDatabase = prototype.DefaultDatabase,
DefaultLanguage = parameters.IsNewObject ? SR.DefaultLanguagePlaceholder : LanguageUtils.FormatLanguageDisplay(languageOptions.FirstOrDefault(o => o?.Language.Name == prototype.DefaultLanguage || o?.Language.Alias == prototype.DefaultLanguage, null)),
ServerRoles = loginServerRoles.ToArray(),
ConnectPermission = prototype.WindowsGrantAccess,
IsEnabled = !prototype.IsDisabled,
IsLockedOut = prototype.IsLockedOut,
UserMapping = new ServerLoginDatabaseUserMapping[0]
};
await requestContext.SendResult(new LoginViewInfo()
{
ObjectInfo = loginInfo,
SupportWindowsAuthentication = prototype.WindowsAuthSupported,
SupportAADAuthentication = prototype.AADAuthSupported,
SupportSQLAuthentication = true, // SQL Auth support for login, not necessarily mean SQL Auth support for CONNECT etc.
CanEditLockedOutState = !parameters.IsNewObject && prototype.IsLockedOut,
Databases = databases,
Languages = languages,
ServerRoles = prototype.ServerRoles.ServerRoleNames,
SupportAdvancedPasswordOptions = dataContainer.Server.DatabaseEngineType == DatabaseEngineType.Standalone || dataContainer.Server.DatabaseEngineEdition == DatabaseEngineEdition.SqlDataWarehouse,
SupportAdvancedOptions = dataContainer.Server.DatabaseEngineType == DatabaseEngineType.Standalone || dataContainer.Server.DatabaseEngineEdition == DatabaseEngineEdition.SqlManagedInstance
});
} }
private LoginAuthenticationType LoginTypeToAuthenticationType(LoginType loginType) return ConfigureLogin(dataContainer, ConfigAction.Create, runType, prototype);
{
switch (loginType)
{
case LoginType.WindowsUser:
case LoginType.WindowsGroup:
return LoginAuthenticationType.Windows;
case LoginType.SqlLogin:
return LoginAuthenticationType.Sql;
case LoginType.ExternalUser:
case LoginType.ExternalGroup:
return LoginAuthenticationType.AAD;
default:
return LoginAuthenticationType.Others;
}
} }
internal async Task HandleDisposeLoginViewRequest(DisposeLoginViewRequestParams parameters, RequestContext<object> requestContext)
{
await requestContext.SendResult(new object());
}
internal string ConfigureLogin(
CDataContainer dataContainer,
ConfigAction configAction,
RunType runType,
LoginPrototype prototype)
{
string sqlScript = string.Empty;
using (var actions = new LoginActions(dataContainer, configAction, prototype))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
if (executionHandler.ExecutionResult == ExecutionMode.Failure)
{
throw executionHandler.ExecutionFailureException;
}
if (runType == RunType.ScriptToWindow)
{
sqlScript = executionHandler.ScriptTextFromLastRun;
}
}
return sqlScript;
}
}
internal class LoginActions : ManagementActionBase
{
private ConfigAction configAction;
private LoginPrototype prototype;
/// <summary>
/// Handle login create and update actions
/// </summary>
public LoginActions(CDataContainer dataContainer, ConfigAction configAction, LoginPrototype prototype)
{
this.DataContainer = dataContainer;
this.configAction = configAction;
this.prototype = prototype;
}
/// <summary>
/// called by the management actions framework to execute the action
/// </summary>
/// <param name="node"></param>
public override void OnRunNow(object sender)
{
if (this.configAction != ConfigAction.Drop)
{
prototype.ApplyGeneralChanges(this.DataContainer.Server);
prototype.ApplyServerRoleChanges(this.DataContainer.Server);
prototype.ApplyDatabaseRoleChanges(this.DataContainer.Server);
}
}
} }
} }

View File

@@ -9,7 +9,7 @@ using System.Runtime.Serialization;
using Newtonsoft.Json; using Newtonsoft.Json;
using Newtonsoft.Json.Converters; using Newtonsoft.Json.Converters;
namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
[JsonConverter(typeof(StringEnumConverter))] [JsonConverter(typeof(StringEnumConverter))]
public enum LoginAuthenticationType public enum LoginAuthenticationType
@@ -35,10 +35,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
/// <summary> /// <summary>
/// a class for storing various login properties /// a class for storing various login properties
/// </summary> /// </summary>
public class LoginInfo public class LoginInfo : SqlObject
{ {
public string Name { get; set; }
public LoginAuthenticationType AuthenticationType { get; set; } public LoginAuthenticationType AuthenticationType { get; set; }
public bool WindowsGrantAccess { get; set; } public bool WindowsGrantAccess { get; set; }
@@ -62,7 +60,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
public string DefaultDatabase { get; set; } public string DefaultDatabase { get; set; }
public string[] ServerRoles {get; set;} public string[] ServerRoles { get; set; }
public ServerLoginDatabaseUserMapping[] UserMapping; public ServerLoginDatabaseUserMapping[] UserMapping;
} }

View File

@@ -0,0 +1,18 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
public class LoginViewContext : SqlObjectViewContext
{
public LoginViewContext(Contracts.InitializeViewRequestParams parameters) : base(parameters)
{
}
public override void Dispose()
{
}
}
}

View File

@@ -2,13 +2,12 @@
// Copyright (c) Microsoft. All rights reserved. // Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
// //
#nullable disable
namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
public class LoginViewInfo public class LoginViewInfo : SqlObjectViewInfo
{ {
public LoginInfo ObjectInfo { get; set; }
public bool SupportWindowsAuthentication { get; set; } public bool SupportWindowsAuthentication { get; set; }
public bool SupportAADAuthentication { get; set; } public bool SupportAADAuthentication { get; set; }
public bool SupportSQLAuthentication { get; set; } public bool SupportSQLAuthentication { get; set; }

View File

@@ -15,7 +15,7 @@ using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// AppRoleGeneral - main app role page /// AppRoleGeneral - main app role page

View File

@@ -6,24 +6,23 @@
#nullable disable #nullable disable
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
internal class CredentialActions : ManagementActionBase internal class CredentialActions : ManagementActionBase
{ {
#region Constants #region Constants
private const int MAX_SQL_SYS_NAME_LENGTH = 128; // max sql sys name length private const int MAX_SQL_SYS_NAME_LENGTH = 128; // max sql sys name length
#endregion #endregion
#region Variables #region Variables
private CredentialData credentialData = null; private CredentialData credentialData = null;
private CredentialInfo credential; private CredentialInfo credential;
private ConfigAction configAction; private ConfigAction configAction;
#endregion #endregion
#region Constructors / Dispose #region Constructors / Dispose
/// <summary> /// <summary>
/// required when loading from Object Explorer context /// required when loading from Object Explorer context
/// </summary> /// </summary>
@@ -44,9 +43,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
/// <summary> /// <summary>
/// Clean up any resources being used. /// Clean up any resources being used.
/// </summary> /// </summary>
protected override void Dispose( bool disposing ) protected override void Dispose(bool disposing)
{ {
base.Dispose( disposing ); base.Dispose(disposing);
if (disposing == true) if (disposing == true)
{ {
if (this.credentialData != null) if (this.credentialData != null)
@@ -55,7 +54,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
} }
} }
} }
#endregion #endregion
/// <summary> /// <summary>
/// called on background thread by the framework to execute the action /// called on background thread by the framework to execute the action

View File

@@ -10,9 +10,8 @@ using System.Security;
using Microsoft.SqlServer.Management.Sdk.Sfc; using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
internal class CredentialData : IDisposable internal class CredentialData : IDisposable
{ {
@@ -146,7 +145,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
bool isKatmaiAndNotMatrix = (this.context.Server.Version.Major >= 10); bool isKatmaiAndNotMatrix = (this.context.Server.Version.Major >= 10);
Urn urn = new Urn("Server/Credential[@Name='" + Urn.EscapeString(this.CredentialName) + "']"); Urn urn = new Urn("Server/Credential[@Name='" + Urn.EscapeString(this.CredentialName) + "']");
string [] fields; string[] fields;
if (isKatmaiAndNotMatrix) if (isKatmaiAndNotMatrix)
{ {
fields = new string[] { ENUMERATOR_FIELD_IDENTITY, ENUMERATOR_FIELD_PROVIDER_NAME }; fields = new string[] { ENUMERATOR_FIELD_IDENTITY, ENUMERATOR_FIELD_PROVIDER_NAME };
@@ -206,7 +205,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Security
/// </summary> /// </summary>
private void SendToServerCreateCredential() private void SendToServerCreateCredential()
{ {
Microsoft.SqlServer.Management.Smo.Credential smoCredential = new Microsoft.SqlServer.Management.Smo.Credential ( Microsoft.SqlServer.Management.Smo.Credential smoCredential = new Microsoft.SqlServer.Management.Smo.Credential(
this.Context.Server, this.Context.Server,
this.CredentialName); this.CredentialName);
if (this.isEncryptionByProvider) if (this.isEncryptionByProvider)

View File

@@ -15,8 +15,7 @@ using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// DatabaseRoleGeneral - main panel for database role /// DatabaseRoleGeneral - main panel for database role

View File

@@ -0,0 +1,40 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.Management;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
internal class LoginActions : ManagementActionBase
{
private ConfigAction configAction;
private LoginPrototype prototype;
/// <summary>
/// Handle login create and update actions
/// </summary>
public LoginActions(CDataContainer dataContainer, ConfigAction configAction, LoginPrototype prototype)
{
this.DataContainer = dataContainer;
this.configAction = configAction;
this.prototype = prototype;
}
/// <summary>
/// called by the management actions framework to execute the action
/// </summary>
/// <param name="node"></param>
public override void OnRunNow(object sender)
{
if (this.configAction != ConfigAction.Drop)
{
prototype.ApplyGeneralChanges(this.DataContainer.Server);
prototype.ApplyServerRoleChanges(this.DataContainer.Server);
prototype.ApplyDatabaseRoleChanges(this.DataContainer.Server);
}
}
}
}

View File

@@ -16,9 +16,8 @@ using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Sdk.Sfc; using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// Encapsulates database roles, access and default schema /// Encapsulates database roles, access and default schema

View File

@@ -10,7 +10,7 @@ using System.Globalization;
using SMO = Microsoft.SqlServer.Management.Smo; using SMO = Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// NetStandard compatible helpers /// NetStandard compatible helpers

View File

@@ -17,7 +17,7 @@ using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlServer.Management.Smo.Broker; using Microsoft.SqlServer.Management.Smo.Broker;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// Enumeration of sql object types that can have permissions /// Enumeration of sql object types that can have permissions

View File

@@ -9,7 +9,7 @@ using System;
using System.Linq; using System.Linq;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
internal static class PermissionsDataExtensions internal static class PermissionsDataExtensions
{ {

View File

@@ -12,7 +12,7 @@ using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Sdk.Sfc; using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// An attribute for sqlmgmt\src\permissionsdata.cs!SecurableType that maps it to the corresponding SMO /// An attribute for sqlmgmt\src\permissionsdata.cs!SecurableType that maps it to the corresponding SMO

View File

@@ -14,7 +14,7 @@ using System.Collections.Specialized;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
public class ServerRoleManageTaskFormComponent public class ServerRoleManageTaskFormComponent
{ {

View File

@@ -13,7 +13,7 @@ using Microsoft.SqlServer.Management.Smo;
#endregion #endregion
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// String comparer that uses the case sensitivity and other settings /// String comparer that uses the case sensitivity and other settings

View File

@@ -15,7 +15,7 @@ using System.Globalization;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlServer.Management.Sdk.Sfc; using Microsoft.SqlServer.Management.Sdk.Sfc;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// An enumeration of the SQL object types the search dialog knows how to look for /// An enumeration of the SQL object types the search dialog knows how to look for

View File

@@ -0,0 +1,167 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Management;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
internal class UserActions : ManagementActionBase
{
#region Variables
private UserPrototype userPrototype;
private ConfigAction configAction;
#endregion
#region Constructors / Dispose
/// <summary>
/// Handle user create and update actions
/// </summary>
public UserActions(
CDataContainer dataContainer,
ConfigAction configAction,
UserInfo user,
UserPrototypeData? originalData)
{
this.DataContainer = dataContainer;
this.IsDatabaseOperation = true;
this.configAction = configAction;
ExhaustiveUserTypes currentUserType;
if (dataContainer.IsNewObject)
{
currentUserType = UserActions.GetUserTypeForUserInfo(user);
}
else
{
currentUserType = UserActions.GetCurrentUserTypeForExistingUser(dataContainer.Server.GetSmoObject(dataContainer.ObjectUrn) as User);
}
this.userPrototype = UserPrototypeFactory.GetUserPrototype(dataContainer, user, originalData, currentUserType);
}
// /// <summary>
// /// Clean up any resources being used.
// /// </summary>
// protected override void Dispose(bool disposing)
// {
// base.Dispose(disposing);
// }
#endregion
/// <summary>
/// called by the management actions framework to execute the action
/// </summary>
/// <param name="node"></param>
public override void OnRunNow(object sender)
{
if (this.configAction != ConfigAction.Drop)
{
this.userPrototype.ApplyChanges(this.ParentDb);
}
}
internal static ExhaustiveUserTypes GetUserTypeForUserInfo(UserInfo user)
{
ExhaustiveUserTypes userType = ExhaustiveUserTypes.LoginMappedUser;
switch (user.Type)
{
case DatabaseUserType.WithLogin:
userType = ExhaustiveUserTypes.LoginMappedUser;
break;
case DatabaseUserType.WithWindowsGroupLogin:
userType = ExhaustiveUserTypes.WindowsUser;
break;
case DatabaseUserType.Contained:
if (user.AuthenticationType == ServerAuthenticationType.AzureActiveDirectory)
{
userType = ExhaustiveUserTypes.ExternalUser;
}
else
{
userType = ExhaustiveUserTypes.SqlUserWithPassword;
}
break;
case DatabaseUserType.NoConnectAccess:
userType = ExhaustiveUserTypes.SqlUserWithoutLogin;
break;
}
return userType;
}
internal static DatabaseUserType GetDatabaseUserTypeForUserType(ExhaustiveUserTypes userType)
{
DatabaseUserType databaseUserType = DatabaseUserType.WithLogin;
switch (userType)
{
case ExhaustiveUserTypes.LoginMappedUser:
databaseUserType = DatabaseUserType.WithLogin;
break;
case ExhaustiveUserTypes.WindowsUser:
databaseUserType = DatabaseUserType.WithWindowsGroupLogin;
break;
case ExhaustiveUserTypes.SqlUserWithPassword:
databaseUserType = DatabaseUserType.Contained;
break;
case ExhaustiveUserTypes.SqlUserWithoutLogin:
databaseUserType = DatabaseUserType.NoConnectAccess;
break;
case ExhaustiveUserTypes.ExternalUser:
databaseUserType = DatabaseUserType.Contained;
break;
}
return databaseUserType;
}
internal static ExhaustiveUserTypes GetCurrentUserTypeForExistingUser(User? user)
{
if (user == null)
{
return ExhaustiveUserTypes.Unknown;
}
switch (user.UserType)
{
case UserType.SqlUser:
if (user.IsSupportedProperty("AuthenticationType"))
{
if (user.AuthenticationType == AuthenticationType.Windows)
{
return ExhaustiveUserTypes.WindowsUser;
}
else if (user.AuthenticationType == AuthenticationType.Database)
{
return ExhaustiveUserTypes.SqlUserWithPassword;
}
}
return ExhaustiveUserTypes.LoginMappedUser;
case UserType.NoLogin:
return ExhaustiveUserTypes.SqlUserWithoutLogin;
case UserType.Certificate:
return ExhaustiveUserTypes.CertificateMappedUser;
case UserType.AsymmetricKey:
return ExhaustiveUserTypes.AsymmetricKeyMappedUser;
case UserType.External:
return ExhaustiveUserTypes.ExternalUser;
default:
return ExhaustiveUserTypes.Unknown;
}
}
internal static bool IsParentDatabaseContained(Urn parentDbUrn, Server server)
{
string parentDbName = parentDbUrn.GetNameForType("Database");
return IsParentDatabaseContained(server.Databases[parentDbName]);
}
internal static bool IsParentDatabaseContained(Database parentDatabase)
{
return parentDatabase.IsSupportedProperty("ContainmentType")
&& parentDatabase.ContainmentType == ContainmentType.Partial;
}
}
}

View File

@@ -11,10 +11,9 @@ using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlServer.Management.Sdk.Sfc; using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Security namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
/// <summary> /// <summary>
/// Defines the common behavior of all types of database user objects. /// Defines the common behavior of all types of database user objects.

View File

@@ -0,0 +1,283 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Xml;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
/// <summary>
/// User object type handler
/// </summary>
public class UserHandler : ObjectTypeHandler<UserInfo, UserViewContext>
{
public UserHandler(ConnectionService connectionService) : base(connectionService)
{
}
public override bool CanHandleType(SqlObjectType objectType)
{
return objectType == SqlObjectType.User;
}
public override async Task<InitializeViewResult> InitializeObjectView(Contracts.InitializeViewRequestParams parameters)
{
// check input parameters
if (string.IsNullOrWhiteSpace(parameters.Database))
{
throw new ArgumentNullException("parameters.Database");
}
// open a connection for running the user dialog and associated task
ConnectionInfo originalConnInfo;
this.ConnectionService.TryFindConnection(parameters.ConnectionUri, out originalConnInfo);
if (originalConnInfo == null)
{
throw new ArgumentException("Invalid connection URI '{0}'", parameters.ConnectionUri);
}
string originalDatabaseName = originalConnInfo.ConnectionDetails.DatabaseName;
try
{
originalConnInfo.ConnectionDetails.DatabaseName = parameters.Database;
ConnectParams connectParams = new ConnectParams
{
OwnerUri = parameters.ContextId,
Connection = originalConnInfo.ConnectionDetails,
Type = Connection.ConnectionType.Default
};
await this.ConnectionService.Connect(connectParams);
}
finally
{
originalConnInfo.ConnectionDetails.DatabaseName = originalDatabaseName;
}
ConnectionInfo connInfo;
this.ConnectionService.TryFindConnection(parameters.ContextId, out connInfo);
// create a default user data context and database object
CDataContainer dataContainer = CreateUserDataContainer(connInfo, null, ConfigAction.Create, parameters.Database);
string databaseUrn = string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']", Urn.EscapeString(parameters.Database));
Database parentDb = dataContainer.Server.GetSmoObject(databaseUrn) as Database;
var languageOptions = LanguageUtils.GetDefaultLanguageOptions(dataContainer);
var languageOptionsList = languageOptions.Select(LanguageUtils.FormatLanguageDisplay).ToList();
languageOptionsList.Insert(0, SR.DefaultLanguagePlaceholder);
// if viewing an exisitng user then populate some properties
UserInfo userInfo = null;
string defaultLanguageAlias = null;
ExhaustiveUserTypes userType = ExhaustiveUserTypes.LoginMappedUser;
if (!parameters.IsNewObject)
{
User existingUser = dataContainer.Server.GetSmoObject(parameters.ObjectUrn) as User;
userType = UserActions.GetCurrentUserTypeForExistingUser(existingUser);
DatabaseUserType databaseUserType = UserActions.GetDatabaseUserTypeForUserType(userType);
// if contained user determine if SQL or AAD auth type
ServerAuthenticationType authenticationType =
(databaseUserType == DatabaseUserType.Contained && userType == ExhaustiveUserTypes.ExternalUser)
? ServerAuthenticationType.AzureActiveDirectory : ServerAuthenticationType.Sql;
userInfo = new UserInfo()
{
Type = databaseUserType,
AuthenticationType = authenticationType,
Name = existingUser.Name,
LoginName = existingUser.Login,
DefaultSchema = existingUser.DefaultSchema,
};
// Default language is only applicable for users inside a contained database.
if (LanguageUtils.IsDefaultLanguageSupported(dataContainer.Server)
&& parentDb.ContainmentType != ContainmentType.None)
{
defaultLanguageAlias = LanguageUtils.GetLanguageAliasFromName(
existingUser.Parent.Parent,
existingUser.DefaultLanguage.Name);
}
}
// generate a user prototype
UserPrototype currentUserPrototype = UserPrototypeFactory.GetUserPrototype(dataContainer, userInfo, originalData: null, userType);
// get the default schema if available
string defaultSchema = null;
IUserPrototypeWithDefaultSchema defaultSchemaPrototype = currentUserPrototype as IUserPrototypeWithDefaultSchema;
if (defaultSchemaPrototype != null && defaultSchemaPrototype.IsDefaultSchemaSupported)
{
defaultSchema = defaultSchemaPrototype.DefaultSchema;
}
ServerConnection serverConnection = dataContainer.ServerConnection;
bool isSqlAzure = serverConnection.DatabaseEngineType == DatabaseEngineType.SqlAzureDatabase;
bool supportsContainedUser = isSqlAzure || UserActions.IsParentDatabaseContained(parentDb);
// set default alias to <default> if needed
if (string.IsNullOrEmpty(defaultLanguageAlias)
&& supportsContainedUser
&& LanguageUtils.IsDefaultLanguageSupported(dataContainer.Server))
{
defaultLanguageAlias = SR.DefaultLanguagePlaceholder;
}
// set the fake password placeholder when editing an existing user
string password = null;
IUserPrototypeWithPassword userWithPwdPrototype = currentUserPrototype as IUserPrototypeWithPassword;
if (userWithPwdPrototype != null && !parameters.IsNewObject)
{
userWithPwdPrototype.Password = DatabaseUtils.GetReadOnlySecureString(LoginPrototype.fakePassword);
userWithPwdPrototype.PasswordConfirm = DatabaseUtils.GetReadOnlySecureString(LoginPrototype.fakePassword);
password = LoginPrototype.fakePassword;
}
// get the login name if it exists
string loginName = null;
IUserPrototypeWithMappedLogin mappedLoginPrototype = currentUserPrototype as IUserPrototypeWithMappedLogin;
if (mappedLoginPrototype != null)
{
loginName = mappedLoginPrototype.LoginName;
}
// populate user's role assignments
List<string> databaseRoles = new List<string>();
foreach (string role in currentUserPrototype.DatabaseRoleNames)
{
if (currentUserPrototype.IsRoleMember(role))
{
databaseRoles.Add(role);
}
}
// populate user's schema ownerships
List<string> schemaNames = new List<string>();
foreach (string schema in currentUserPrototype.SchemaNames)
{
if (currentUserPrototype.IsSchemaOwner(schema))
{
schemaNames.Add(schema);
}
}
UserViewInfo userViewInfo = new UserViewInfo()
{
ObjectInfo = new UserInfo()
{
Type = userInfo?.Type ?? DatabaseUserType.WithLogin,
AuthenticationType = userInfo?.AuthenticationType ?? ServerAuthenticationType.Sql,
Name = currentUserPrototype.Name,
LoginName = loginName,
Password = password,
DefaultSchema = defaultSchema,
OwnedSchemas = schemaNames.ToArray(),
DatabaseRoles = databaseRoles.ToArray(),
DefaultLanguage = LanguageUtils.FormatLanguageDisplay(
languageOptions.FirstOrDefault(o => o?.Language.Name == defaultLanguageAlias || o?.Language.Alias == defaultLanguageAlias, null)),
},
SupportContainedUser = supportsContainedUser,
SupportWindowsAuthentication = false,
SupportAADAuthentication = currentUserPrototype.AADAuthSupported,
SupportSQLAuthentication = true,
Languages = languageOptionsList.ToArray(),
Schemas = currentUserPrototype.SchemaNames.ToArray(),
Logins = DatabaseUtils.LoadSqlLogins(serverConnection),
DatabaseRoles = currentUserPrototype.DatabaseRoleNames.ToArray()
};
var context = new UserViewContext(parameters, serverConnection, currentUserPrototype.CurrentState);
return new InitializeViewResult { ViewInfo = userViewInfo, Context = context };
}
public override Task Save(UserViewContext context, UserInfo obj)
{
ConfigureUser(
context.Parameters.ContextId,
obj,
context.Parameters.IsNewObject ? ConfigAction.Create : ConfigAction.Update,
RunType.RunNow,
context.Parameters.Database,
context.OriginalUserData);
return Task.CompletedTask;
}
public override Task<string> Script(UserViewContext context, UserInfo obj)
{
var script = ConfigureUser(
context.Parameters.ContextId,
obj,
context.Parameters.IsNewObject ? ConfigAction.Create : ConfigAction.Update,
RunType.ScriptToWindow,
context.Parameters.Database,
context.OriginalUserData);
return Task.FromResult(script);
}
internal CDataContainer CreateUserDataContainer(ConnectionInfo connInfo, UserInfo user, ConfigAction configAction, string databaseName)
{
var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer");
var connectionInfoWithConnection = new SqlConnectionInfoWithConnection();
connectionInfoWithConnection.ServerConnection = serverConnection;
string urn = (configAction == ConfigAction.Update && user != null)
? string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']/User[@Name='{1}']",
Urn.EscapeString(databaseName),
Urn.EscapeString(user.Name))
: string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']",
Urn.EscapeString(databaseName));
ActionContext context = new ActionContext(serverConnection, "User", urn);
DataContainerXmlGenerator containerXml = new DataContainerXmlGenerator(context);
if (configAction == ConfigAction.Create)
{
containerXml.AddProperty("itemtype", "User");
}
XmlDocument xmlDoc = containerXml.GenerateXmlDocument();
return CDataContainer.CreateDataContainer(connectionInfoWithConnection, xmlDoc);
}
internal string ConfigureUser(string ownerUri, UserInfo user, ConfigAction configAction, RunType runType, string databaseName, UserPrototypeData originalData)
{
ConnectionInfo connInfo;
this.ConnectionService.TryFindConnection(ownerUri, out connInfo);
if (connInfo == null)
{
throw new ArgumentException("Invalid connection URI '{0}'", ownerUri);
}
string sqlScript = string.Empty;
CDataContainer dataContainer = CreateUserDataContainer(connInfo, user, configAction, databaseName);
using (var actions = new UserActions(dataContainer, configAction, user, originalData))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
if (executionHandler.ExecutionResult == ExecutionMode.Failure)
{
throw executionHandler.ExecutionFailureException;
}
if (runType == RunType.ScriptToWindow)
{
sqlScript = executionHandler.ScriptTextFromLastRun;
}
}
return sqlScript;
}
}
}

View File

@@ -7,7 +7,7 @@ using System.Runtime.Serialization;
using Newtonsoft.Json; using Newtonsoft.Json;
using Newtonsoft.Json.Converters; using Newtonsoft.Json.Converters;
namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{ {
[JsonConverter(typeof(StringEnumConverter))] [JsonConverter(typeof(StringEnumConverter))]
public enum ServerAuthenticationType public enum ServerAuthenticationType
@@ -41,12 +41,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
/// <summary> /// <summary>
/// a class for storing various user properties /// a class for storing various user properties
/// </summary> /// </summary>
public class UserInfo public class UserInfo : SqlObject
{ {
public DatabaseUserType? Type { get; set; } public DatabaseUserType? Type { get; set; }
public string? Name { get; set; }
public string? LoginName { get; set; } public string? LoginName { get; set; }
public string? Password { get; set; } public string? Password { get; set; }
@@ -61,28 +59,4 @@ namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
public string? DefaultLanguage { get; set; } public string? DefaultLanguage { get; set; }
} }
/// <summary>
/// The information required to render the user view.
/// </summary>
public class UserViewInfo
{
public UserInfo? ObjectInfo { get; set; }
public bool SupportContainedUser { get; set; }
public bool SupportWindowsAuthentication { get; set; }
public bool SupportAADAuthentication { get; set; }
public bool SupportSQLAuthentication { get; set; }
public string[]? Languages { get; set; }
public string[]? Schemas { get; set; }
public string[]? Logins { get; set; }
public string[]? DatabaseRoles { get; set; }
}
} }

View File

@@ -0,0 +1,35 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
public class UserViewContext : SqlObjectViewContext
{
public UserViewContext(InitializeViewRequestParams parameters, ServerConnection connection, UserPrototypeData originalUserData) : base(parameters)
{
this.OriginalUserData = originalUserData;
this.Connection = connection;
}
public UserPrototypeData OriginalUserData { get; }
public ServerConnection Connection { get; }
public override void Dispose()
{
try
{
this.Connection.Disconnect();
}
catch
{
// ignore
}
}
}
}

View File

@@ -0,0 +1,29 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
/// <summary>
/// The information required to render the user view.
/// </summary>
public class UserViewInfo : SqlObjectViewInfo
{
public bool SupportContainedUser { get; set; }
public bool SupportWindowsAuthentication { get; set; }
public bool SupportAADAuthentication { get; set; }
public bool SupportSQLAuthentication { get; set; }
public string[]? Languages { get; set; }
public string[]? Schemas { get; set; }
public string[]? Logins { get; set; }
public string[]? DatabaseRoles { get; set; }
}
}

View File

@@ -0,0 +1,12 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
public abstract class SqlObject
{
public string? Name { get; set; }
}
}

View File

@@ -0,0 +1,28 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System.Runtime.Serialization;
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
[JsonConverter(typeof(StringEnumConverter))]
public enum SqlObjectType
{
[EnumMember(Value = "Column")]
Column,
[EnumMember(Value = "Credential")]
Credential,
[EnumMember(Value = "ServerLevelLogin")]
ServerLevelLogin,
[EnumMember(Value = "Table")]
Table,
[EnumMember(Value = "User")]
User,
[EnumMember(Value = "View")]
View
}
}

View File

@@ -0,0 +1,28 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
public abstract class SqlObjectViewContext : IDisposable
{
public SqlObjectViewContext(InitializeViewRequestParams parameters)
{
this.Parameters = parameters;
}
public InitializeViewRequestParams Parameters { get; }
public abstract void Dispose();
}
public class InitializeViewResult
{
public SqlObjectViewContext Context { get; set; }
public SqlObjectViewInfo ViewInfo { get; set; }
}
}

View File

@@ -0,0 +1,13 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement
{
public abstract class SqlObjectViewInfo
{
public SqlObject ObjectInfo { get; set; }
}
}

View File

@@ -1,117 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
{
/// <summary>
/// Get Credential parameters
/// </summary>
public class GetCredentialsParams: GeneralRequestDetails
{
public string OwnerUri { get; set; }
}
public class GetCredentialsResult: ResultStatus
{
public CredentialInfo[] Credentials { get; set; }
}
/// <summary>
/// SQL Agent Credentials request type
/// </summary>
public class GetCredentialsRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<GetCredentialsParams, GetCredentialsResult> Type =
RequestType<GetCredentialsParams, GetCredentialsResult>.Create("security/credentials");
}
/// <summary>
/// Create Credential parameters
/// </summary>
public class CreateCredentialParams : GeneralRequestDetails
{
public string OwnerUri { get; set; }
public CredentialInfo Credential { get; set; }
}
/// <summary>
/// Create Credential result
/// </summary>
public class CredentialResult : ResultStatus
{
public CredentialInfo Credential { get; set; }
}
/// <summary>
/// Create Credential request type
/// </summary>
public class CreateCredentialRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<CreateCredentialParams, CredentialResult> Type =
RequestType<CreateCredentialParams, CredentialResult>.Create("security/createcredential");
}
/// <summary>
/// Update Credential params
/// </summary>
public class UpdateCredentialParams : GeneralRequestDetails
{
public string OwnerUri { get; set; }
public CredentialInfo Credential { get; set; }
}
/// <summary>
/// Update Credential request type
/// </summary>
public class UpdateCredentialRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<UpdateCredentialParams, CredentialResult> Type =
RequestType<UpdateCredentialParams, CredentialResult>.Create("security/updatecredential");
}
/// <summary>
/// Delete Credential params
/// </summary>
public class DeleteCredentialParams : GeneralRequestDetails
{
public string OwnerUri { get; set; }
public CredentialInfo Credential { get; set; }
}
/// <summary>
/// Delete Credential request type
/// </summary>
public class DeleteCredentialRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<DeleteCredentialParams, ResultStatus> Type =
RequestType<DeleteCredentialParams, ResultStatus>.Create("security/deletecredential");
}
}

View File

@@ -1,128 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
{
/// <summary>
/// Create Login parameters
/// </summary>
public class CreateLoginParams
{
public string ContextId { get; set; }
public LoginInfo Login { get; set; }
}
/// <summary>
/// Create Login request type
/// </summary>
public class CreateLoginRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<CreateLoginParams, object> Type =
RequestType<CreateLoginParams, object>.Create("objectManagement/createLogin");
}
/// <summary>
/// Update Login params
/// </summary>
public class UpdateLoginParams
{
public string ContextId { get; set; }
public LoginInfo Login { get; set; }
}
/// <summary>
/// Update Login request type
/// </summary>
public class UpdateLoginRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<UpdateLoginParams, object> Type =
RequestType<UpdateLoginParams, object>.Create("objectManagement/updateLogin");
}
/// <summary>
/// Update Login params
/// </summary>
public class DisposeLoginViewRequestParams
{
public string ContextId { get; set; }
}
/// <summary>
/// Update Login request type
/// </summary>
public class DisposeLoginViewRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<DisposeLoginViewRequestParams, object> Type =
RequestType<DisposeLoginViewRequestParams, object>.Create("objectManagement/disposeLoginView");
}
/// <summary>
/// Initialize Login View Request params
/// </summary>
public class InitializeLoginViewRequestParams
{
public string ConnectionUri { get; set; }
public string ContextId { get; set; }
public bool IsNewObject { get; set; }
public string Name { get; set; }
}
/// <summary>
/// Initialize Login View request type
/// </summary>
public class InitializeLoginViewRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<InitializeLoginViewRequestParams, LoginViewInfo> Type =
RequestType<InitializeLoginViewRequestParams, LoginViewInfo>.Create("objectManagement/initializeLoginView");
}
/// <summary>
/// Script Login params
/// </summary>
public class ScriptLoginParams
{
public string? ContextId { get; set; }
public LoginInfo? Login { get; set; }
}
/// <summary>
/// Script Login request type
/// </summary>
public class ScriptLoginRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<ScriptLoginParams, string> Type =
RequestType<ScriptLoginParams, string>.Create("objectManagement/scriptLogin");
}
}

View File

@@ -1,135 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Security.Contracts
{
/// <summary>
/// Initialize User View parameters
/// </summary>
public class InitializeUserViewParams
{
public string? ContextId { get; set; }
public string? ConnectionUri { get; set; }
public bool IsNewObject { get; set; }
public string? Database { get; set; }
public string? Name { get; set; }
}
/// <summary>
/// Initialize User View request type
/// </summary>
public class InitializeUserViewRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<InitializeUserViewParams, UserViewInfo> Type =
RequestType<InitializeUserViewParams, UserViewInfo>.Create("objectManagement/initializeUserView");
}
/// <summary>
/// Create User parameters
/// </summary>
public class CreateUserParams
{
public string? ContextId { get; set; }
public UserInfo? User { get; set; }
}
/// <summary>
/// Create User result
/// </summary>
public class CreateUserResult : ResultStatus
{
public UserInfo? User { get; set; }
}
/// <summary>
/// Create User request type
/// </summary>
public class CreateUserRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<CreateUserParams, CreateUserResult> Type =
RequestType<CreateUserParams, CreateUserResult>.Create("objectManagement/createUser");
}
/// <summary>
/// Update User parameters
/// </summary>
public class UpdateUserParams
{
public string? ContextId { get; set; }
public UserInfo? User { get; set; }
}
/// <summary>
/// Update User request type
/// </summary>
public class UpdateUserRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<UpdateUserParams, ResultStatus> Type =
RequestType<UpdateUserParams, ResultStatus>.Create("objectManagement/updateUser");
}
/// <summary>
/// Update User params
/// </summary>
public class DisposeUserViewRequestParams
{
public string? ContextId { get; set; }
}
/// <summary>
/// Update User request type
/// </summary>
public class DisposeUserViewRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<DisposeUserViewRequestParams, ResultStatus> Type =
RequestType<DisposeUserViewRequestParams, ResultStatus>.Create("objectManagement/disposeUserView");
}
/// <summary>
/// Script User params
/// </summary>
public class ScriptUserParams
{
public string? ContextId { get; set; }
public UserInfo? User { get; set; }
}
/// <summary>
/// Script User request type
/// </summary>
public class ScriptUserRequest
{
/// <summary>
/// Request definition
/// </summary>
public static readonly
RequestType<ScriptUserParams, string> Type =
RequestType<ScriptUserParams, string>.Create("objectManagement/scriptUser");
}
}

View File

@@ -1,223 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.Security
{
/// <summary>
/// Main class for Security Service functionality
/// </summary>
public sealed class SecurityService : IDisposable
{
private bool disposed;
private ConnectionService? connectionService;
private UserServiceHandlerImpl userServiceHandler;
private LoginServiceHandlerImpl loginServiceHandler;
private static readonly Lazy<SecurityService> instance = new Lazy<SecurityService>(() => new SecurityService());
/// <summary>
/// Construct a new SecurityService instance with default parameters
/// </summary>
public SecurityService()
{
userServiceHandler = new UserServiceHandlerImpl();
loginServiceHandler = new LoginServiceHandlerImpl();
}
/// <summary>
/// Gets the singleton instance object
/// </summary>
public static SecurityService Instance
{
get { return instance.Value; }
}
/// <summary>
/// Internal for testing purposes only
/// </summary>
internal ConnectionService ConnectionServiceInstance
{
get
{
connectionService ??= ConnectionService.Instance;
return connectionService;
}
set
{
connectionService = value;
}
}
/// <summary>
/// Service host object for sending/receiving requests/events.
/// Internal for testing purposes.
/// </summary>
internal IProtocolEndpoint? ServiceHost
{
get;
set;
}
/// <summary>
/// Initializes the Security Service instance
/// </summary>
public void InitializeService(ServiceHost serviceHost)
{
this.ServiceHost = serviceHost;
// Credential request handlers
this.ServiceHost.SetRequestHandler(CreateCredentialRequest.Type, HandleCreateCredentialRequest, true);
this.ServiceHost.SetRequestHandler(UpdateCredentialRequest.Type, HandleUpdateCredentialRequest, true);
this.ServiceHost.SetRequestHandler(GetCredentialsRequest.Type, HandleGetCredentialsRequest, true);
// Login request handlers
this.ServiceHost.SetRequestHandler(CreateLoginRequest.Type, this.loginServiceHandler.HandleCreateLoginRequest, true);
this.ServiceHost.SetRequestHandler(UpdateLoginRequest.Type, this.loginServiceHandler.HandleUpdateLoginRequest, true);
this.ServiceHost.SetRequestHandler(InitializeLoginViewRequest.Type, this.loginServiceHandler.HandleInitializeLoginViewRequest, true);
this.ServiceHost.SetRequestHandler(ScriptLoginRequest.Type, this.loginServiceHandler.HandleScriptLoginRequest, true);
this.ServiceHost.SetRequestHandler(DisposeLoginViewRequest.Type, this.loginServiceHandler.HandleDisposeLoginViewRequest, true);
// User request handlers
this.ServiceHost.SetRequestHandler(InitializeUserViewRequest.Type, this.userServiceHandler.HandleInitializeUserViewRequest, true);
this.ServiceHost.SetRequestHandler(CreateUserRequest.Type, this.userServiceHandler.HandleCreateUserRequest, true);
this.ServiceHost.SetRequestHandler(UpdateUserRequest.Type, this.userServiceHandler.HandleUpdateUserRequest, true);
this.ServiceHost.SetRequestHandler(ScriptUserRequest.Type, this.userServiceHandler.HandleScriptUserRequest, true);
this.ServiceHost.SetRequestHandler(DisposeUserViewRequest.Type, this.userServiceHandler.HandleDisposeUserViewRequest, true);
}
#region "Credential Handlers"
/// <summary>
/// Handle request to create a credential
/// </summary>
internal async Task HandleCreateCredentialRequest(CreateCredentialParams parameters, RequestContext<CredentialResult> requestContext)
{
var result = await ConfigureCredential(parameters.OwnerUri,
parameters.Credential,
ConfigAction.Create,
RunType.RunNow);
await requestContext.SendResult(new CredentialResult()
{
Credential = parameters.Credential,
Success = result.Item1,
ErrorMessage = result.Item2
});
}
/// <summary>
/// Handle request to update a credential
/// </summary>
internal async Task HandleUpdateCredentialRequest(UpdateCredentialParams parameters, RequestContext<CredentialResult> requestContext)
{
var result = await ConfigureCredential(parameters.OwnerUri,
parameters.Credential,
ConfigAction.Update,
RunType.RunNow);
await requestContext.SendResult(new CredentialResult()
{
Credential = parameters.Credential,
Success = result.Item1,
ErrorMessage = result.Item2
});
}
/// <summary>
/// Handle request to get all credentials
/// </summary>
internal async Task HandleGetCredentialsRequest(GetCredentialsParams parameters, RequestContext<GetCredentialsResult> requestContext)
{
var result = new GetCredentialsResult();
try
{
ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo);
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
var credentials = dataContainer.Server?.Credentials;
int credentialsCount = credentials != null ? credentials.Count : 0;
CredentialInfo[] credentialsInfos = new CredentialInfo[credentialsCount];
if (credentials != null)
{
for (int i = 0; i < credentialsCount; ++i)
{
credentialsInfos[i] = new CredentialInfo();
credentialsInfos[i].Name = credentials[i].Name;
credentialsInfos[i].Identity = credentials[i].Identity;
credentialsInfos[i].Id = credentials[i].ID;
credentialsInfos[i].DateLastModified = credentials[i].DateLastModified;
credentialsInfos[i].CreateDate = credentials[i].CreateDate;
credentialsInfos[i].ProviderName = credentials[i].ProviderName;
}
}
result.Credentials = credentialsInfos;
result.Success = true;
}
catch (Exception ex)
{
result.Success = false;
result.ErrorMessage = ex.ToString();
}
await requestContext.SendResult(result);
}
/// <summary>
/// Disposes the service
/// </summary>
public void Dispose()
{
if (!disposed)
{
disposed = true;
}
}
internal Task<Tuple<bool, string>> ConfigureCredential(
string ownerUri,
CredentialInfo credential,
ConfigAction configAction,
RunType runType)
{
return Task<Tuple<bool, string>>.Run(() =>
{
try
{
ConnectionInfo connInfo;
ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
CDataContainer dataContainer = CDataContainer.CreateDataContainer(connInfo, databaseExists: true);
using (CredentialActions actions = new CredentialActions(dataContainer, credential, configAction))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
}
return new Tuple<bool, string>(true, string.Empty);
}
catch (Exception ex)
{
return new Tuple<bool, string>(false, ex.ToString());
}
});
}
#endregion // "Credential Handlers"
}
}

View File

@@ -1,582 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Xml;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Security
{
internal class UserServiceHandlerImpl
{
private class ViewState
{
public bool IsNewObject { get; set; }
public string Database { get; set; }
public UserPrototypeData OriginalUserData { get; set; }
public ViewState(bool isNewObject, string database, UserPrototypeData originalUserData)
{
this.IsNewObject = isNewObject;
this.Database = database;
this.OriginalUserData = originalUserData;
}
}
private ConnectionService? connectionService;
private Dictionary<string, ViewState> contextIdToViewState = new Dictionary<string, ViewState>();
/// <summary>
/// Internal for testing purposes only
/// </summary>
internal ConnectionService ConnectionServiceInstance
{
get
{
connectionService ??= ConnectionService.Instance;
return connectionService;
}
set
{
connectionService = value;
}
}
/// <summary>
/// Handle request to initialize user view
/// </summary>
internal async Task HandleInitializeUserViewRequest(InitializeUserViewParams parameters, RequestContext<UserViewInfo> requestContext)
{
// check input parameters
if (string.IsNullOrWhiteSpace(parameters.Database))
{
throw new ArgumentNullException("parameters.Database");
}
if (string.IsNullOrWhiteSpace(parameters.ContextId))
{
throw new ArgumentNullException("parameters.ContextId");
}
// open a connection for running the user dialog and associated task
ConnectionInfo originalConnInfo;
ConnectionServiceInstance.TryFindConnection(parameters.ConnectionUri, out originalConnInfo);
if (originalConnInfo == null)
{
throw new ArgumentException("Invalid connection URI '{0}'", parameters.ConnectionUri);
}
string originalDatabaseName = originalConnInfo.ConnectionDetails.DatabaseName;
try
{
originalConnInfo.ConnectionDetails.DatabaseName = parameters.Database;
ConnectParams connectParams = new ConnectParams
{
OwnerUri = parameters.ContextId,
Connection = originalConnInfo.ConnectionDetails,
Type = Connection.ConnectionType.Default
};
await this.ConnectionServiceInstance.Connect(connectParams);
}
finally
{
originalConnInfo.ConnectionDetails.DatabaseName = originalDatabaseName;
}
ConnectionInfo connInfo;
this.ConnectionServiceInstance.TryFindConnection(parameters.ContextId, out connInfo);
// create a default user data context and database object
CDataContainer dataContainer = CreateUserDataContainer(connInfo, null, ConfigAction.Create, parameters.Database);
string databaseUrn = string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']", Urn.EscapeString(parameters.Database));
Database? parentDb = dataContainer.Server.GetSmoObject(databaseUrn) as Database;
var languageOptions = LanguageUtils.GetDefaultLanguageOptions(dataContainer);
var languageOptionsList = languageOptions.Select(LanguageUtils.FormatLanguageDisplay).ToList();
languageOptionsList.Insert(0, SR.DefaultLanguagePlaceholder);
// if viewing an exisitng user then populate some properties
UserInfo? userInfo = null;
string? defaultLanguageAlias = null;
ExhaustiveUserTypes userType = ExhaustiveUserTypes.LoginMappedUser;
if (!parameters.IsNewObject)
{
User existingUser = dataContainer.Server.Databases[parentDb.Name].Users[parameters.Name];
userType = UserActions.GetCurrentUserTypeForExistingUser(existingUser);
DatabaseUserType databaseUserType = UserActions.GetDatabaseUserTypeForUserType(userType);
// if contained user determine if SQL or AAD auth type
ServerAuthenticationType authenticationType =
(databaseUserType == DatabaseUserType.Contained && userType == ExhaustiveUserTypes.ExternalUser)
? ServerAuthenticationType.AzureActiveDirectory : ServerAuthenticationType.Sql;
userInfo = new UserInfo()
{
Type = databaseUserType,
AuthenticationType = authenticationType,
Name = parameters.Name,
LoginName = existingUser.Login,
DefaultSchema = existingUser.DefaultSchema,
};
// Default language is only applicable for users inside a contained database.
if (LanguageUtils.IsDefaultLanguageSupported(dataContainer.Server)
&& parentDb.ContainmentType != ContainmentType.None)
{
defaultLanguageAlias = LanguageUtils.GetLanguageAliasFromName(
existingUser.Parent.Parent,
existingUser.DefaultLanguage.Name);
}
}
// generate a user prototype
UserPrototype currentUserPrototype = UserPrototypeFactory.GetUserPrototype(dataContainer, userInfo, originalData: null, userType);
// get the default schema if available
string? defaultSchema = null;
IUserPrototypeWithDefaultSchema defaultSchemaPrototype = currentUserPrototype as IUserPrototypeWithDefaultSchema;
if (defaultSchemaPrototype != null && defaultSchemaPrototype.IsDefaultSchemaSupported)
{
defaultSchema = defaultSchemaPrototype.DefaultSchema;
}
ServerConnection serverConnection = dataContainer.ServerConnection;
bool isSqlAzure = serverConnection.DatabaseEngineType == DatabaseEngineType.SqlAzureDatabase;
bool supportsContainedUser = isSqlAzure || UserActions.IsParentDatabaseContained(parentDb);
// set default alias to <default> if needed
if (string.IsNullOrEmpty(defaultLanguageAlias)
&& supportsContainedUser
&& LanguageUtils.IsDefaultLanguageSupported(dataContainer.Server))
{
defaultLanguageAlias = SR.DefaultLanguagePlaceholder;
}
// set the fake password placeholder when editing an existing user
string? password = null;
IUserPrototypeWithPassword userWithPwdPrototype = currentUserPrototype as IUserPrototypeWithPassword;
if (userWithPwdPrototype != null && !parameters.IsNewObject)
{
userWithPwdPrototype.Password = DatabaseUtils.GetReadOnlySecureString(LoginPrototype.fakePassword);
userWithPwdPrototype.PasswordConfirm = DatabaseUtils.GetReadOnlySecureString(LoginPrototype.fakePassword);
password = LoginPrototype.fakePassword;
}
// get the login name if it exists
string? loginName = null;
IUserPrototypeWithMappedLogin mappedLoginPrototype = currentUserPrototype as IUserPrototypeWithMappedLogin;
if (mappedLoginPrototype != null)
{
loginName = mappedLoginPrototype.LoginName;
}
// populate user's role assignments
List<string> databaseRoles = new List<string>();
foreach (string role in currentUserPrototype.DatabaseRoleNames)
{
if (currentUserPrototype.IsRoleMember(role))
{
databaseRoles.Add(role);
}
}
// populate user's schema ownerships
List<string> schemaNames = new List<string>();
foreach (string schema in currentUserPrototype.SchemaNames)
{
if (currentUserPrototype.IsSchemaOwner(schema))
{
schemaNames.Add(schema);
}
}
UserViewInfo userViewInfo = new UserViewInfo()
{
ObjectInfo = new UserInfo()
{
Type = userInfo?.Type ?? DatabaseUserType.WithLogin,
AuthenticationType = userInfo?.AuthenticationType ?? ServerAuthenticationType.Sql,
Name = currentUserPrototype.Name,
LoginName = loginName,
Password = password,
DefaultSchema = defaultSchema,
OwnedSchemas = schemaNames.ToArray(),
DatabaseRoles = databaseRoles.ToArray(),
DefaultLanguage = LanguageUtils.FormatLanguageDisplay(
languageOptions.FirstOrDefault(o => o?.Language.Name == defaultLanguageAlias || o?.Language.Alias == defaultLanguageAlias, null)),
},
SupportContainedUser = supportsContainedUser,
SupportWindowsAuthentication = false,
SupportAADAuthentication = currentUserPrototype.AADAuthSupported,
SupportSQLAuthentication = true,
Languages = languageOptionsList.ToArray(),
Schemas = currentUserPrototype.SchemaNames.ToArray(),
Logins = DatabaseUtils.LoadSqlLogins(serverConnection),
DatabaseRoles = currentUserPrototype.DatabaseRoleNames.ToArray()
};
this.contextIdToViewState.Add(
parameters.ContextId,
new ViewState(parameters.IsNewObject, parameters.Database, currentUserPrototype.CurrentState));
await requestContext.SendResult(userViewInfo);
}
/// <summary>
/// Handle request to create a user
/// </summary>
internal async Task HandleCreateUserRequest(CreateUserParams parameters, RequestContext<CreateUserResult> requestContext)
{
if (parameters.ContextId == null)
{
throw new ArgumentException("Invalid context ID");
}
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
{
throw new ArgumentException("Invalid context ID view state");
}
ConfigureUser(
parameters.ContextId,
parameters.User,
ConfigAction.Create,
RunType.RunNow,
viewState.Database,
viewState.OriginalUserData);
await requestContext.SendResult(new CreateUserResult()
{
User = parameters.User,
Success = true,
ErrorMessage = string.Empty
});
}
/// <summary>
/// Handle request to update a user
/// </summary>
internal async Task HandleUpdateUserRequest(UpdateUserParams parameters, RequestContext<ResultStatus> requestContext)
{
if (parameters.ContextId == null)
{
throw new ArgumentException("Invalid context ID");
}
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
{
throw new ArgumentException("Invalid context ID view state");
}
ConfigureUser(
parameters.ContextId,
parameters.User,
ConfigAction.Update,
RunType.RunNow,
viewState.Database,
viewState.OriginalUserData);
await requestContext.SendResult(new ResultStatus()
{
Success = true,
ErrorMessage = string.Empty
});
}
/// <summary>
/// Handle request to script a user
/// </summary>
internal async Task HandleScriptUserRequest(ScriptUserParams parameters, RequestContext<string> requestContext)
{
if (parameters.ContextId == null)
{
throw new ArgumentException("Invalid context ID");
}
ViewState viewState;
this.contextIdToViewState.TryGetValue(parameters.ContextId, out viewState);
if (viewState == null)
{
throw new ArgumentException("Invalid context ID view state");
}
// todo: check if it's an existing user
string sqlScript = ConfigureUser(
parameters.ContextId,
parameters.User,
viewState.IsNewObject ? ConfigAction.Create : ConfigAction.Update,
RunType.ScriptToWindow,
viewState.Database,
viewState.OriginalUserData);
await requestContext.SendResult(sqlScript);
}
internal async Task HandleDisposeUserViewRequest(DisposeUserViewRequestParams parameters, RequestContext<ResultStatus> requestContext)
{
this.ConnectionServiceInstance.Disconnect(new DisconnectParams()
{
OwnerUri = parameters.ContextId,
Type = null
});
if (parameters.ContextId != null)
{
this.contextIdToViewState.Remove(parameters.ContextId);
}
await requestContext.SendResult(new ResultStatus()
{
Success = true,
ErrorMessage = string.Empty
});
}
internal CDataContainer CreateUserDataContainer(
ConnectionInfo connInfo,
UserInfo? user,
ConfigAction configAction,
string databaseName)
{
var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer");
var connectionInfoWithConnection = new SqlConnectionInfoWithConnection();
connectionInfoWithConnection.ServerConnection = serverConnection;
string urn = (configAction == ConfigAction.Update && user != null)
? string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']/User[@Name='{1}']",
Urn.EscapeString(databaseName),
Urn.EscapeString(user.Name))
: string.Format(System.Globalization.CultureInfo.InvariantCulture,
"Server/Database[@Name='{0}']",
Urn.EscapeString(databaseName));
ActionContext context = new ActionContext(serverConnection, "User", urn);
DataContainerXmlGenerator containerXml = new DataContainerXmlGenerator(context);
if (configAction == ConfigAction.Create)
{
containerXml.AddProperty("itemtype", "User");
}
XmlDocument xmlDoc = containerXml.GenerateXmlDocument();
return CDataContainer.CreateDataContainer(connectionInfoWithConnection, xmlDoc);
}
internal string ConfigureUser(
string? ownerUri,
UserInfo? user,
ConfigAction configAction,
RunType runType,
string databaseName,
UserPrototypeData? originalData)
{
ConnectionInfo connInfo;
this.ConnectionServiceInstance.TryFindConnection(ownerUri, out connInfo);
if (connInfo == null)
{
throw new ArgumentException("Invalid connection URI '{0}'", ownerUri);
}
string sqlScript = string.Empty;
CDataContainer dataContainer = CreateUserDataContainer(connInfo, user, configAction, databaseName);
using (var actions = new UserActions(dataContainer, configAction, user, originalData))
{
var executionHandler = new ExecutonHandler(actions);
executionHandler.RunNow(runType, this);
if (executionHandler.ExecutionResult == ExecutionMode.Failure)
{
throw executionHandler.ExecutionFailureException;
}
if (runType == RunType.ScriptToWindow)
{
sqlScript = executionHandler.ScriptTextFromLastRun;
}
}
return sqlScript;
}
}
internal class UserActions : ManagementActionBase
{
#region Variables
private UserPrototype userPrototype;
private ConfigAction configAction;
#endregion
#region Constructors / Dispose
/// <summary>
/// Handle user create and update actions
/// </summary>
public UserActions(
CDataContainer dataContainer,
ConfigAction configAction,
UserInfo? user,
UserPrototypeData? originalData)
{
this.DataContainer = dataContainer;
this.IsDatabaseOperation = true;
this.configAction = configAction;
ExhaustiveUserTypes currentUserType;
if (dataContainer.IsNewObject)
{
currentUserType = UserActions.GetUserTypeForUserInfo(user);
}
else
{
currentUserType = UserActions.GetCurrentUserTypeForExistingUser(
dataContainer.Server.GetSmoObject(dataContainer.ObjectUrn) as User);
}
this.userPrototype = UserPrototypeFactory.GetUserPrototype(dataContainer, user, originalData, currentUserType);
}
// /// <summary>
// /// Clean up any resources being used.
// /// </summary>
// protected override void Dispose(bool disposing)
// {
// base.Dispose(disposing);
// }
#endregion
/// <summary>
/// called by the management actions framework to execute the action
/// </summary>
/// <param name="node"></param>
public override void OnRunNow(object sender)
{
if (this.configAction != ConfigAction.Drop)
{
this.userPrototype.ApplyChanges(this.ParentDb);
}
}
internal static ExhaustiveUserTypes GetUserTypeForUserInfo(UserInfo user)
{
ExhaustiveUserTypes userType = ExhaustiveUserTypes.LoginMappedUser;
switch (user.Type)
{
case DatabaseUserType.WithLogin:
userType = ExhaustiveUserTypes.LoginMappedUser;
break;
case DatabaseUserType.WithWindowsGroupLogin:
userType = ExhaustiveUserTypes.WindowsUser;
break;
case DatabaseUserType.Contained:
if (user.AuthenticationType == ServerAuthenticationType.AzureActiveDirectory)
{
userType = ExhaustiveUserTypes.ExternalUser;
}
else
{
userType = ExhaustiveUserTypes.SqlUserWithPassword;
}
break;
case DatabaseUserType.NoConnectAccess:
userType = ExhaustiveUserTypes.SqlUserWithoutLogin;
break;
}
return userType;
}
internal static DatabaseUserType GetDatabaseUserTypeForUserType(ExhaustiveUserTypes userType)
{
DatabaseUserType databaseUserType = DatabaseUserType.WithLogin;
switch (userType)
{
case ExhaustiveUserTypes.LoginMappedUser:
databaseUserType = DatabaseUserType.WithLogin;
break;
case ExhaustiveUserTypes.WindowsUser:
databaseUserType = DatabaseUserType.WithWindowsGroupLogin;
break;
case ExhaustiveUserTypes.SqlUserWithPassword:
databaseUserType = DatabaseUserType.Contained;
break;
case ExhaustiveUserTypes.SqlUserWithoutLogin:
databaseUserType = DatabaseUserType.NoConnectAccess;
break;
case ExhaustiveUserTypes.ExternalUser:
databaseUserType = DatabaseUserType.Contained;
break;
}
return databaseUserType;
}
internal static ExhaustiveUserTypes GetCurrentUserTypeForExistingUser(User? user)
{
if (user == null)
{
return ExhaustiveUserTypes.Unknown;
}
switch (user.UserType)
{
case UserType.SqlUser:
if (user.IsSupportedProperty("AuthenticationType"))
{
if (user.AuthenticationType == AuthenticationType.Windows)
{
return ExhaustiveUserTypes.WindowsUser;
}
else if (user.AuthenticationType == AuthenticationType.Database)
{
return ExhaustiveUserTypes.SqlUserWithPassword;
}
}
return ExhaustiveUserTypes.LoginMappedUser;
case UserType.NoLogin:
return ExhaustiveUserTypes.SqlUserWithoutLogin;
case UserType.Certificate:
return ExhaustiveUserTypes.CertificateMappedUser;
case UserType.AsymmetricKey:
return ExhaustiveUserTypes.AsymmetricKeyMappedUser;
case UserType.External:
return ExhaustiveUserTypes.ExternalUser;
default:
return ExhaustiveUserTypes.Unknown;
}
}
internal static bool IsParentDatabaseContained(Urn parentDbUrn, Server server)
{
string parentDbName = parentDbUrn.GetNameForType("Database");
return IsParentDatabaseContained(server.Databases[parentDbName]);
}
internal static bool IsParentDatabaseContained(Database parentDatabase)
{
return parentDatabase.IsSupportedProperty("ContainmentType")
&& parentDatabase.ContainmentType == ContainmentType.Partial;
}
}
}

View File

@@ -9,7 +9,7 @@ using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Agent; using Microsoft.SqlTools.ServiceLayer.Agent;
using Microsoft.SqlTools.ServiceLayer.Agent.Contracts; using Microsoft.SqlTools.ServiceLayer.Agent.Contracts;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security; using Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.Test.Common; using Microsoft.SqlTools.ServiceLayer.Test.Common;
using Moq; using Moq;
@@ -50,7 +50,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Agent
{ {
// setup // setup
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var credential = await SecurityTestUtils.SetupCredential(connectionResult); var credential = await ObjectManagementTestUtils.SetupCredential(connectionResult.ConnectionInfo.OwnerUri);
var service = new AgentService(); var service = new AgentService();
var proxy = AgentTestUtils.GetTestProxyInfo(); var proxy = AgentTestUtils.GetTestProxyInfo();
await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy); await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy);
@@ -60,7 +60,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Agent
// cleanup // cleanup
await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy); await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy);
await SecurityTestUtils.CleanupCredential(connectionResult, credential); await ObjectManagementTestUtils.CleanupCredential(connectionResult.ConnectionInfo.OwnerUri, credential);
} }
} }
@@ -74,7 +74,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Agent
{ {
// setup // setup
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var credential = await SecurityTestUtils.SetupCredential(connectionResult); var credential = await ObjectManagementTestUtils.SetupCredential(connectionResult.ConnectionInfo.OwnerUri);
var service = new AgentService(); var service = new AgentService();
var proxy = AgentTestUtils.GetTestProxyInfo(); var proxy = AgentTestUtils.GetTestProxyInfo();
await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy); await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy);
@@ -87,7 +87,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Agent
// cleanup // cleanup
await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy); await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy);
await SecurityTestUtils.CleanupCredential(connectionResult, credential); await ObjectManagementTestUtils.CleanupCredential(connectionResult.ConnectionInfo.OwnerUri, credential);
} }
} }
@@ -101,13 +101,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Agent
{ {
// setup // setup
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var credential = await SecurityTestUtils.SetupCredential(connectionResult); var credential = await ObjectManagementTestUtils.SetupCredential(connectionResult.ConnectionInfo.OwnerUri);
var service = new AgentService(); var service = new AgentService();
var proxy = AgentTestUtils.GetTestProxyInfo(); var proxy = AgentTestUtils.GetTestProxyInfo();
// test // test
await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy); await AgentTestUtils.DeleteAgentProxy(service, connectionResult, proxy);
await SecurityTestUtils.CleanupCredential(connectionResult, credential); await ObjectManagementTestUtils.CleanupCredential(connectionResult.ConnectionInfo.OwnerUri, credential);
} }
} }
} }

View File

@@ -12,7 +12,7 @@ using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Agent; using Microsoft.SqlTools.ServiceLayer.Agent;
using Microsoft.SqlTools.ServiceLayer.Agent.Contracts; using Microsoft.SqlTools.ServiceLayer.Agent.Contracts;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security; using Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement;
using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Management;
using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Utility;
using Moq; using Moq;
@@ -81,7 +81,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Agent
return new AgentProxyInfo() return new AgentProxyInfo()
{ {
AccountName = "Test Proxy", AccountName = "Test Proxy",
CredentialName = SecurityTestUtils.TestCredentialName, CredentialName = ObjectManagementTestUtils.TestCredentialName,
Description = "Test proxy description", Description = "Test proxy description",
IsEnabled = true IsEnabled = true
}; };

View File

@@ -0,0 +1,46 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
using NUnit.Framework;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
{
/// <summary>
/// Tests for the Credential management component
/// </summary>
public class CredentialTests
{
/// <summary>
/// TestHandleCreateCredentialRequest
/// </summary>
[Test]
public async Task TestCredentialOperations()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup, drop credential if exists.
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var credential = ObjectManagementTestUtils.GetTestCredentialInfo();
var objUrn = ObjectManagementTestUtils.GetCredentialURN(credential.Name);
await ObjectManagementTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, objUrn);
// create and update
var parametersForCreation = ObjectManagementTestUtils.GetInitializeViewRequestParams(connectionResult.ConnectionInfo.OwnerUri, "master", true, SqlObjectType.Credential, "", "");
await ObjectManagementTestUtils.SaveObject(parametersForCreation, credential);
var parametersForUpdate = ObjectManagementTestUtils.GetInitializeViewRequestParams(connectionResult.ConnectionInfo.OwnerUri, "master", false, SqlObjectType.Credential, "", objUrn);
await ObjectManagementTestUtils.SaveObject(parametersForUpdate, credential);
// cleanup
await ObjectManagementTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, objUrn);
}
}
}
}

View File

@@ -0,0 +1,46 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
{
/// <summary>
/// Tests for the Login management component
/// </summary>
public class LoginTests
{
/// <summary>
/// Test the basic Create Login method handler
/// </summary>
// [Test]
public async Task TestHandleCreateLoginRequest()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup, drop credential if exists.
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var testLogin = ObjectManagementTestUtils.GetTestLoginInfo();
var objUrn = ObjectManagementTestUtils.GetLoginURN(testLogin.Name);
await ObjectManagementTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, objUrn);
// create and update
var parametersForCreation = ObjectManagementTestUtils.GetInitializeViewRequestParams(connectionResult.ConnectionInfo.OwnerUri, "master", true, SqlObjectType.ServerLevelLogin, "", "");
await ObjectManagementTestUtils.SaveObject(parametersForCreation, testLogin);
var parametersForUpdate = ObjectManagementTestUtils.GetInitializeViewRequestParams(connectionResult.ConnectionInfo.OwnerUri, "master", false, SqlObjectType.ServerLevelLogin, "", objUrn);
await ObjectManagementTestUtils.SaveObject(parametersForUpdate, testLogin);
// cleanup
await ObjectManagementTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, objUrn);
}
}
}
}

View File

@@ -11,7 +11,6 @@ using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts; using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
@@ -27,22 +26,15 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
{ {
private const string TableQuery = @"CREATE TABLE testTable1_RenamingTable (c1 int)"; private const string TableQuery = @"CREATE TABLE testTable1_RenamingTable (c1 int)";
private const string OwnerUri = "testDB"; private const string OwnerUri = "testDB";
private ObjectManagementService objectManagementService;
private SqlTestDb testDb; private SqlTestDb testDb;
private Mock<RequestContext<bool>> requestContextMock; private Mock<RequestContext<RenameRequestResponse>> requestContextMock;
[SetUp] [SetUp]
public async Task TestInitialize() public async Task TestInitialize()
{ {
this.testDb = await SqlTestDb.CreateNewAsync(serverType: TestServerType.OnPrem, query: TableQuery, dbNamePrefix: "RenameTest"); this.testDb = await SqlTestDb.CreateNewAsync(serverType: TestServerType.OnPrem, query: TableQuery, dbNamePrefix: "RenameTest");
requestContextMock = new Mock<RequestContext<RenameRequestResponse>>();
requestContextMock = new Mock<RequestContext<bool>>();
ConnectionService connectionService = LiveConnectionHelper.GetLiveTestConnectionService();
TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, OwnerUri, ConnectionType.Default); TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, OwnerUri, ConnectionType.Default);
ObjectManagementService.ConnectionServiceInstance = connectionService;
this.objectManagementService = new ObjectManagementService();
} }
[TearDown] [TearDown]
@@ -55,10 +47,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
public async Task TestRenameTable() public async Task TestRenameTable()
{ {
//arrange & act //arrange & act
await objectManagementService.HandleRenameRequest(this.InitRequestParams("RenamingTable", String.Format("Server/Database[@Name='{0}']/Table[@Name='testTable1_RenamingTable' and @Schema='dbo']", testDb.DatabaseName)), requestContextMock.Object); await ObjectManagementTestUtils.Service.HandleRenameRequest(this.InitRequestParams("RenamingTable", String.Format("Server/Database[@Name='{0}']/Table[@Name='testTable1_RenamingTable' and @Schema='dbo']", testDb.DatabaseName)), requestContextMock.Object);
//assert //assert
requestContextMock.Verify(x => x.SendResult(It.Is<bool>(r => r == true))); requestContextMock.Verify(x => x.SendResult(It.IsAny<RenameRequestResponse>()));
Query queryRenameObject = ExecuteQuery("SELECT * FROM " + testDb.DatabaseName + ".sys.tables WHERE name='RenamingTable'"); Query queryRenameObject = ExecuteQuery("SELECT * FROM " + testDb.DatabaseName + ".sys.tables WHERE name='RenamingTable'");
Assert.That(queryRenameObject.HasExecuted, Is.True, "The query to check for the renamed table was not executed"); Assert.That(queryRenameObject.HasExecuted, Is.True, "The query to check for the renamed table was not executed");
@@ -75,10 +67,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
public async Task TestRenameColumn() public async Task TestRenameColumn()
{ {
//arrange & act //arrange & act
await objectManagementService.HandleRenameRequest(this.InitRequestParams("RenameColumn", String.Format("Server/Database[@Name='{0}']/Table[@Name='testTable1_RenamingTable' and @Schema='dbo']/Column[@Name='C1']", testDb.DatabaseName)), requestContextMock.Object); await ObjectManagementTestUtils.Service.HandleRenameRequest(this.InitRequestParams("RenameColumn", String.Format("Server/Database[@Name='{0}']/Table[@Name='testTable1_RenamingTable' and @Schema='dbo']/Column[@Name='C1']", testDb.DatabaseName)), requestContextMock.Object);
//assert //assert
requestContextMock.Verify(x => x.SendResult(It.Is<bool>(r => r == true))); requestContextMock.Verify(x => x.SendResult(It.IsAny<RenameRequestResponse>()));
Query queryRenameObject = ExecuteQuery("SELECT * FROM " + testDb.DatabaseName + ".sys.columns WHERE name='RenameColumn'"); Query queryRenameObject = ExecuteQuery("SELECT * FROM " + testDb.DatabaseName + ".sys.columns WHERE name='RenameColumn'");
Assert.That(queryRenameObject.HasExecuted, Is.True, "The query to check for the renamed column was not executed"); Assert.That(queryRenameObject.HasExecuted, Is.True, "The query to check for the renamed column was not executed");
@@ -96,7 +88,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
{ {
Assert.That(async () => Assert.That(async () =>
{ {
await objectManagementService.HandleRenameRequest(this.InitRequestParams("RenameColumn", String.Format("Server/Database[@Name='{0}']/Table[@Name='testTable1_RenamingTable' and @Schema='dbo']/Column[@Name='C1_NOT']", testDb.DatabaseName)), requestContextMock.Object); await ObjectManagementTestUtils.Service.HandleRenameRequest(this.InitRequestParams("RenameColumn", String.Format("Server/Database[@Name='{0}']/Table[@Name='testTable1_RenamingTable' and @Schema='dbo']/Column[@Name='C1_NOT']", testDb.DatabaseName)), requestContextMock.Object);
}, Throws.Exception.TypeOf<FailedOperationException>(), "Did find the column, which should not have existed"); }, Throws.Exception.TypeOf<FailedOperationException>(), "Did find the column, which should not have existed");
} }
@@ -105,7 +97,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
{ {
Assert.That(async () => Assert.That(async () =>
{ {
await objectManagementService.HandleRenameRequest(this.InitRequestParams("RenamingTable", String.Format("Server/Database[@Name='{0}']/Table[@Name='testTable1_Not' and @Schema='dbo']", testDb.DatabaseName)), requestContextMock.Object); await ObjectManagementTestUtils.Service.HandleRenameRequest(this.InitRequestParams("RenamingTable", String.Format("Server/Database[@Name='{0}']/Table[@Name='testTable1_Not' and @Schema='dbo']", testDb.DatabaseName)), requestContextMock.Object);
}, Throws.Exception.TypeOf<FailedOperationException>(), "Did find the table, which should not have existed"); }, Throws.Exception.TypeOf<FailedOperationException>(), "Did find the table, which should not have existed");
} }
@@ -120,7 +112,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
}; };
Assert.That(async () => Assert.That(async () =>
{ {
await objectManagementService.HandleRenameRequest(testRenameRequestParams, requestContextMock.Object); await ObjectManagementTestUtils.Service.HandleRenameRequest(testRenameRequestParams, requestContextMock.Object);
}, Throws.Exception.TypeOf<Exception>(), "Did find the connection, which should not have existed"); }, Throws.Exception.TypeOf<Exception>(), "Did find the connection, which should not have existed");
} }

View File

@@ -0,0 +1,206 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts;
using Moq;
using Newtonsoft.Json.Linq;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
{
public static class ObjectManagementTestUtils
{
private static ObjectManagementService _objectManagementService;
static ObjectManagementTestUtils()
{
ObjectManagementService.ConnectionServiceInstance = LiveConnectionHelper.GetLiveTestConnectionService();
_objectManagementService = new ObjectManagementService();
}
internal static ObjectManagementService Service
{
get
{
return _objectManagementService;
}
}
public static string TestCredentialName = "Current User";
internal static string GetCurrentUserIdentity()
{
return string.Format(@"{0}\{1}", Environment.UserDomainName, Environment.UserName);
}
internal static string GetLoginURN(string name)
{
return string.Format("Server/Login[@Name='{0}']", name);
}
internal static string GetUserURN(string database, string name)
{
return string.Format("Server/Database[@Name='{0}']/User[@Name='{1}']", database, name);
}
internal static string GetCredentialURN(string name)
{
return string.Format("Server/Credential[@Name = '{0}']", name);
}
internal static LoginInfo GetTestLoginInfo()
{
return new LoginInfo()
{
Name = "TestLoginName_" + new Random().NextInt64(10000000, 90000000).ToString(),
AuthenticationType = LoginAuthenticationType.Sql,
WindowsGrantAccess = true,
MustChangePassword = false,
IsEnabled = false,
IsLockedOut = false,
EnforcePasswordPolicy = false,
EnforcePasswordExpiration = false,
Password = "placeholder" + new Random().NextInt64(10000000, 90000000).ToString() + "!*PLACEHOLDER",
OldPassword = "placeholder" + new Random().NextInt64(10000000, 90000000).ToString() + "!*PLACEHOLDER",
DefaultLanguage = "English - us_english",
DefaultDatabase = "master"
};
}
internal static UserInfo GetTestUserInfo(DatabaseUserType userType, string userName = null, string loginName = null)
{
return new UserInfo()
{
Type = userType,
AuthenticationType = ServerAuthenticationType.Sql,
Name = userName ?? "TestUserName_" + new Random().NextInt64(10000000, 90000000).ToString(),
LoginName = loginName,
Password = "placeholder" + new Random().NextInt64(10000000, 90000000).ToString() + "!*PLACEHOLDER",
DefaultSchema = "dbo",
OwnedSchemas = new string[] { "" }
};
}
internal static CredentialInfo GetTestCredentialInfo()
{
return new CredentialInfo()
{
Identity = GetCurrentUserIdentity(),
Name = TestCredentialName
};
}
internal static InitializeViewRequestParams GetInitializeViewRequestParams(string connectionUri, string database, bool isNewObject, SqlObjectType objectType, string parentUrn, string objectUrn)
{
return new InitializeViewRequestParams()
{
ConnectionUri = connectionUri,
Database = database,
IsNewObject = isNewObject,
ObjectType = objectType,
ContextId = Guid.NewGuid().ToString(),
ParentUrn = parentUrn,
ObjectUrn = objectUrn
};
}
internal static async Task SaveObject(InitializeViewRequestParams parameters, SqlObject obj)
{
// Initialize the view
var initViewRequestContext = new Mock<RequestContext<SqlObjectViewInfo>>();
initViewRequestContext.Setup(x => x.SendResult(It.IsAny<SqlObjectViewInfo>()))
.Returns(Task.FromResult<SqlObjectViewInfo>(null));
await Service.HandleInitializeViewRequest(parameters, initViewRequestContext.Object);
// Save the object
var saveObjectRequestContext = new Mock<RequestContext<SaveObjectRequestResponse>>();
saveObjectRequestContext.Setup(x => x.SendResult(It.IsAny<SaveObjectRequestResponse>()))
.Returns(Task.FromResult<SaveObjectRequestResponse>(new SaveObjectRequestResponse()));
await Service.HandleSaveObjectRequest(new SaveObjectRequestParams { ContextId = parameters.ContextId, Object = JToken.FromObject(obj) }, saveObjectRequestContext.Object);
// Dispose the view
var disposeViewRequestContext = new Mock<RequestContext<DisposeViewRequestResponse>>();
disposeViewRequestContext.Setup(x => x.SendResult(It.IsAny<DisposeViewRequestResponse>()))
.Returns(Task.FromResult<DisposeViewRequestResponse>(new DisposeViewRequestResponse()));
await Service.HandleDisposeViewRequest(new DisposeViewRequestParams { ContextId = parameters.ContextId }, disposeViewRequestContext.Object);
}
internal static async Task ScriptObject(InitializeViewRequestParams parameters, SqlObject obj)
{
// Initialize the view
var initViewRequestContext = new Mock<RequestContext<SqlObjectViewInfo>>();
initViewRequestContext.Setup(x => x.SendResult(It.IsAny<SqlObjectViewInfo>()))
.Returns(Task.FromResult<SqlObjectViewInfo>(null));
await Service.HandleInitializeViewRequest(parameters, initViewRequestContext.Object);
// Script the object
var scriptObjectRequestContext = new Mock<RequestContext<string>>();
scriptObjectRequestContext.Setup(x => x.SendResult(It.IsAny<string>()))
.Returns(Task.FromResult<string>(""));
await Service.HandleScriptObjectRequest(new ScriptObjectRequestParams { ContextId = parameters.ContextId, Object = JToken.FromObject(obj) }, scriptObjectRequestContext.Object);
// Dispose the view
var disposeViewRequestContext = new Mock<RequestContext<DisposeViewRequestResponse>>();
disposeViewRequestContext.Setup(x => x.SendResult(It.IsAny<DisposeViewRequestResponse>()))
.Returns(Task.FromResult<DisposeViewRequestResponse>(new DisposeViewRequestResponse()));
await Service.HandleDisposeViewRequest(new DisposeViewRequestParams { ContextId = parameters.ContextId }, disposeViewRequestContext.Object);
}
internal static async Task DropObject(string connectionUri, string objectUrn)
{
var dropParams = new DropRequestParams
{
ConnectionUri = connectionUri,
ObjectUrn = objectUrn
};
var dropRequestContext = new Mock<RequestContext<DropRequestResponse>>();
dropRequestContext.Setup(x => x.SendResult(It.IsAny<DropRequestResponse>()))
.Returns(Task.FromResult(new DropRequestResponse()));
await Service.HandleDropRequest(dropParams, dropRequestContext.Object);
}
internal static async Task<LoginInfo> CreateTestLogin(string connectionUri)
{
var testLogin = GetTestLoginInfo();
var parametersForCreation = GetInitializeViewRequestParams(connectionUri, "master", true, SqlObjectType.ServerLevelLogin, "", "");
await SaveObject(parametersForCreation, testLogin);
return testLogin;
}
internal static async Task<UserInfo> CreateTestUser(string connectionUri, DatabaseUserType userType,
string userName = null,
string loginName = null,
string databaseName = "master",
bool scriptUser = false)
{
var testUser = GetTestUserInfo(userType, userName, loginName);
var parametersForCreation = GetInitializeViewRequestParams(connectionUri, databaseName, true, SqlObjectType.User, "", "");
await SaveObject(parametersForCreation, testUser);
return testUser;
}
internal static async Task<CredentialInfo> SetupCredential(string connectionUri)
{
var credential = GetTestCredentialInfo();
var parametersForCreation = ObjectManagementTestUtils.GetInitializeViewRequestParams(connectionUri, "master", true, SqlObjectType.Credential, "", "");
await DropObject(connectionUri, GetCredentialURN(credential.Name));
await ObjectManagementTestUtils.SaveObject(parametersForCreation, credential);
return credential;
}
internal static async Task CleanupCredential(string connectionUri, CredentialInfo credential)
{
await DropObject(connectionUri, GetCredentialURN(credential.Name));
}
}
}

View File

@@ -0,0 +1,97 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
using NUnit.Framework;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement
{
/// <summary>
/// Tests for the User management component
/// </summary>
public class UserTests
{
/// <summary>
/// Test the basic Create User method handler
/// </summary>
[Test]
public async Task TestHandleSaveUserWithLoginRequest()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var connectionUri = connectionResult.ConnectionInfo.OwnerUri;
var login = await ObjectManagementTestUtils.CreateTestLogin(connectionUri);
var user = await ObjectManagementTestUtils.CreateTestUser(connectionUri, DatabaseUserType.WithLogin, null, login.Name);
var userUrn = ObjectManagementTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name);
var parameters = ObjectManagementTestUtils.GetInitializeViewRequestParams(connectionUri, "master", false, SqlObjectType.User, "", userUrn);
await ObjectManagementTestUtils.SaveObject(parameters, user);
await ObjectManagementTestUtils.DropObject(connectionUri, userUrn);
await ObjectManagementTestUtils.DropObject(connectionUri, ObjectManagementTestUtils.GetLoginURN(login.Name));
}
}
/// <summary>
/// Test the basic Create User method handler
/// </summary>
// [Test] - Windows-only
public async Task TestHandleCreateUserWithWindowsGroup()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var connectionUri = connectionResult.ConnectionInfo.OwnerUri;
var user = await ObjectManagementTestUtils.CreateTestUser(connectionUri, DatabaseUserType.WithWindowsGroupLogin, $"{Environment.MachineName}\\Administrator");
await ObjectManagementTestUtils.DropObject(connectionUri, ObjectManagementTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name));
}
}
/// <summary>
/// Test the basic Create User method handler
/// </summary>
// [Test] - needs contained DB
public async Task TestHandleCreateUserWithContainedSqlPassword()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
string databaseName = "CRM";
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(databaseName, queryTempFile.FilePath);
var connectionUri = connectionResult.ConnectionInfo.OwnerUri;
var user = await ObjectManagementTestUtils.CreateTestUser(connectionUri, DatabaseUserType.Contained,
userName: null,
loginName: null,
databaseName: connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName);
await ObjectManagementTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, ObjectManagementTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name));
}
}
/// <summary>
/// Test the basic Create User method handler
/// </summary>
[Test]
public async Task TestScriptUserWithLogin()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var connectionUri = connectionResult.ConnectionInfo.OwnerUri;
var login = await ObjectManagementTestUtils.CreateTestLogin(connectionUri);
var user = await ObjectManagementTestUtils.CreateTestUser(connectionUri, DatabaseUserType.WithLogin, null, login.Name);
var userUrn = ObjectManagementTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name);
var parameters = ObjectManagementTestUtils.GetInitializeViewRequestParams(connectionUri, "master", false, SqlObjectType.User, "", userUrn);
await ObjectManagementTestUtils.ScriptObject(parameters, user);
await ObjectManagementTestUtils.DropObject(connectionUri, userUrn);
await ObjectManagementTestUtils.DropObject(connectionUri, ObjectManagementTestUtils.GetLoginURN(login.Name));
}
}
}
}

View File

@@ -1,86 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.Security;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
using NUnit.Framework;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security
{
/// <summary>
/// Tests for the Credential management component
/// </summary>
public class CredentialTests
{
/// <summary>
/// TestHandleCreateCredentialRequest
/// </summary>
[Test]
public async Task TestHandleCreateCredentialRequest()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var service = new SecurityService();
var credential = SecurityTestUtils.GetTestCredentialInfo();
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetCredentialURN(credential.Name));
// test
await SecurityTestUtils.CreateCredential(service, connectionResult, credential);
// cleanup
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetCredentialURN(credential.Name));
}
}
/// <summary>
/// TestHandleUpdateCredentialRequest
/// </summary>
[Test]
public async Task TestHandleUpdateCredentialRequest()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var service = new SecurityService();
var credential = SecurityTestUtils.GetTestCredentialInfo();
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetCredentialURN(credential.Name));
await SecurityTestUtils.CreateCredential(service, connectionResult, credential);
// test
await SecurityTestUtils.UpdateCredential(service, connectionResult, credential);
// cleanup
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetCredentialURN(credential.Name));
}
}
/// <summary>
/// TestHandleDeleteCredentialRequest
/// </summary>
[Test]
public async Task TestHandleDeleteCredentialRequest()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var service = new SecurityService();
var credential = SecurityTestUtils.GetTestCredentialInfo();
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetCredentialURN(credential.Name));
await SecurityTestUtils.CreateCredential(service, connectionResult, credential);
// test
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetCredentialURN(credential.Name));
}
}
}
}

View File

@@ -1,65 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.Security;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
// using Microsoft.SqlTools.ServiceLayer.Utility;
using Moq;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security
{
/// <summary>
/// Tests for the Login management component
/// </summary>
public class LoginTests
{
/// <summary>
/// Test the basic Create Login method handler
/// </summary>
// [Test]
public async Task TestHandleCreateLoginRequest()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var contextId = System.Guid.NewGuid().ToString();
var initializeLoginViewRequestParams = new InitializeLoginViewRequestParams
{
ConnectionUri = connectionResult.ConnectionInfo.OwnerUri,
ContextId = contextId,
IsNewObject = true
};
var loginParams = new CreateLoginParams
{
ContextId = contextId,
Login = SecurityTestUtils.GetTestLoginInfo()
};
var createLoginContext = new Mock<RequestContext<object>>();
createLoginContext.Setup(x => x.SendResult(It.IsAny<object>()))
.Returns(Task.FromResult(new object()));
var initializeLoginViewContext = new Mock<RequestContext<LoginViewInfo>>();
initializeLoginViewContext.Setup(x => x.SendResult(It.IsAny<LoginViewInfo>()))
.Returns(Task.FromResult(new LoginViewInfo()));
// call the create login method
LoginServiceHandlerImpl service = new LoginServiceHandlerImpl();
await service.HandleInitializeLoginViewRequest(initializeLoginViewRequestParams, initializeLoginViewContext.Object);
await service.HandleCreateLoginRequest(loginParams, createLoginContext.Object);
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetLoginURN(loginParams.Login.Name));
}
}
}
}

View File

@@ -1,300 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement;
using Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts;
using Microsoft.SqlTools.ServiceLayer.Security;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Moq;
using static Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility.LiveConnectionHelper;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security
{
public static class SecurityTestUtils
{
public static string TestCredentialName = "Current User";
internal static string GetCurrentUserIdentity()
{
return string.Format(@"{0}\{1}", Environment.UserDomainName, Environment.UserName);
}
internal static string GetLoginURN(string name)
{
return string.Format("Server/Login[@Name='{0}']", name);
}
internal static string GetUserURN(string database, string name)
{
return string.Format("Server/Database[@Name='{0}']/User[@Name='{1}']", database, name);
}
internal static string GetCredentialURN(string name)
{
return string.Format("Server/Credential[@Name = '{0}']", name);
}
internal static LoginInfo GetTestLoginInfo()
{
return new LoginInfo()
{
Name = "TestLoginName_" + new Random().NextInt64(10000000, 90000000).ToString(),
AuthenticationType = LoginAuthenticationType.Sql,
WindowsGrantAccess = true,
MustChangePassword = false,
IsEnabled = false,
IsLockedOut = false,
EnforcePasswordPolicy = false,
EnforcePasswordExpiration = false,
Password = "placeholder" + new Random().NextInt64(10000000, 90000000).ToString() + "!*PLACEHOLDER",
OldPassword = "placeholder" + new Random().NextInt64(10000000, 90000000).ToString() + "!*PLACEHOLDER",
DefaultLanguage = "English - us_english",
DefaultDatabase = "master"
};
}
internal static UserInfo GetTestUserInfo(DatabaseUserType userType, string userName = null, string loginName = null)
{
return new UserInfo()
{
Type = userType,
AuthenticationType = ServerAuthenticationType.Sql,
Name = userName ?? "TestUserName_" + new Random().NextInt64(10000000, 90000000).ToString(),
LoginName = loginName,
Password = "placeholder" + new Random().NextInt64(10000000, 90000000).ToString() + "!*PLACEHOLDER",
DefaultSchema = "dbo",
OwnedSchemas = new string[] { "" }
};
}
internal static CredentialInfo GetTestCredentialInfo()
{
return new CredentialInfo()
{
Identity = GetCurrentUserIdentity(),
Name = TestCredentialName
};
}
internal static async Task CreateCredential(
SecurityService service,
TestConnectionResult connectionResult,
CredentialInfo credential)
{
var context = new Mock<RequestContext<CredentialResult>>();
await service.HandleCreateCredentialRequest(new CreateCredentialParams
{
OwnerUri = connectionResult.ConnectionInfo.OwnerUri,
Credential = credential
}, context.Object);
context.VerifyAll();
}
internal static async Task UpdateCredential(
SecurityService service,
TestConnectionResult connectionResult,
CredentialInfo credential)
{
var context = new Mock<RequestContext<CredentialResult>>();
await service.HandleUpdateCredentialRequest(new UpdateCredentialParams
{
OwnerUri = connectionResult.ConnectionInfo.OwnerUri,
Credential = credential
}, context.Object);
context.VerifyAll();
}
public static async Task<CredentialInfo> SetupCredential(TestConnectionResult connectionResult)
{
var service = new SecurityService();
var credential = SecurityTestUtils.GetTestCredentialInfo();
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetCredentialURN(credential.Name));
await SecurityTestUtils.CreateCredential(service, connectionResult, credential);
return credential;
}
public static async Task CleanupCredential(
TestConnectionResult connectionResult,
CredentialInfo credential)
{
var service = new SecurityService();
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetCredentialURN(credential.Name));
}
internal static async Task<LoginInfo> CreateLogin(LoginServiceHandlerImpl service, TestConnectionResult connectionResult)
{
string contextId = System.Guid.NewGuid().ToString();
var initializeLoginViewRequestParams = new InitializeLoginViewRequestParams
{
ConnectionUri = connectionResult.ConnectionInfo.OwnerUri,
ContextId = contextId,
IsNewObject = true
};
var loginParams = new CreateLoginParams
{
ContextId = contextId,
Login = SecurityTestUtils.GetTestLoginInfo()
};
var createLoginContext = new Mock<RequestContext<object>>();
createLoginContext.Setup(x => x.SendResult(It.IsAny<object>()))
.Returns(Task.FromResult(new object()));
var initializeLoginViewContext = new Mock<RequestContext<LoginViewInfo>>();
initializeLoginViewContext.Setup(x => x.SendResult(It.IsAny<LoginViewInfo>()))
.Returns(Task.FromResult(new LoginViewInfo()));
// call the create login method
await service.HandleInitializeLoginViewRequest(initializeLoginViewRequestParams, initializeLoginViewContext.Object);
await service.HandleCreateLoginRequest(loginParams, createLoginContext.Object);
return loginParams.Login;
}
internal static async Task<UserInfo> CreateUser(
UserServiceHandlerImpl service,
TestConnectionResult connectionResult,
DatabaseUserType userType,
string userName = null,
string loginName = null,
string databaseName = "master",
bool scriptUser = false)
{
string contextId = System.Guid.NewGuid().ToString();
var initializeViewRequestParams = new InitializeUserViewParams
{
ConnectionUri = connectionResult.ConnectionInfo.OwnerUri,
ContextId = contextId,
IsNewObject = true,
Database = databaseName
};
var initializeUserViewContext = new Mock<RequestContext<UserViewInfo>>();
initializeUserViewContext.Setup(x => x.SendResult(It.IsAny<UserViewInfo>()))
.Returns(Task.FromResult(new UserViewInfo()));
await service.HandleInitializeUserViewRequest(initializeViewRequestParams, initializeUserViewContext.Object);
if (scriptUser)
{
var scriptParams = new ScriptUserParams
{
ContextId = contextId,
User = SecurityTestUtils.GetTestUserInfo(userType, userName, loginName)
};
var scriptUserContext = new Mock<RequestContext<string>>();
scriptUserContext.Setup(x => x.SendResult(It.IsAny<string>()))
.Returns(Task.FromResult(new object()));
await service.HandleScriptUserRequest(scriptParams, scriptUserContext.Object);
// verify the result
scriptUserContext.Verify(x => x.SendResult(It.Is<string>
(p => p.Contains("CREATE USER"))));
}
var userParams = new CreateUserParams
{
ContextId = contextId,
User = SecurityTestUtils.GetTestUserInfo(userType, userName, loginName)
};
var createUserContext = new Mock<RequestContext<CreateUserResult>>();
createUserContext.Setup(x => x.SendResult(It.IsAny<CreateUserResult>()))
.Returns(Task.FromResult(new object()));
// call the create login method
await service.HandleCreateUserRequest(userParams, createUserContext.Object);
// verify the result
createUserContext.Verify(x => x.SendResult(It.Is<CreateUserResult>
(p => p.Success && p.User.Name != string.Empty)));
var disposeViewRequestParams = new DisposeUserViewRequestParams
{
ContextId = contextId
};
var disposeUserViewContext = new Mock<RequestContext<ResultStatus>>();
disposeUserViewContext.Setup(x => x.SendResult(It.IsAny<ResultStatus>()))
.Returns(Task.FromResult(new object()));
await service.HandleDisposeUserViewRequest(disposeViewRequestParams, disposeUserViewContext.Object);
return userParams.User;
}
internal static async Task<UserInfo> UpdateUser(
UserServiceHandlerImpl service,
TestConnectionResult connectionResult,
UserInfo user)
{
string contextId = System.Guid.NewGuid().ToString();
var initializeViewRequestParams = new InitializeUserViewParams
{
ConnectionUri = connectionResult.ConnectionInfo.OwnerUri,
ContextId = contextId,
IsNewObject = false,
Database = "master",
Name = user.Name
};
var initializeUserViewContext = new Mock<RequestContext<UserViewInfo>>();
initializeUserViewContext.Setup(x => x.SendResult(It.IsAny<UserViewInfo>()))
.Returns(Task.FromResult(new UserViewInfo()));
await service.HandleInitializeUserViewRequest(initializeViewRequestParams, initializeUserViewContext.Object);
// update the user
user.DatabaseRoles = new string[] { "db_datareader" };
var updateParams = new UpdateUserParams
{
ContextId = contextId,
User = user
};
var updateUserContext = new Mock<RequestContext<ResultStatus>>();
// call the create login method
await service.HandleUpdateUserRequest(updateParams, updateUserContext.Object);
// verify the result
updateUserContext.Verify(x => x.SendResult(It.Is<ResultStatus>(p => p.Success)));
var disposeViewRequestParams = new DisposeUserViewRequestParams
{
ContextId = contextId
};
var disposeUserViewContext = new Mock<RequestContext<ResultStatus>>();
disposeUserViewContext.Setup(x => x.SendResult(It.IsAny<ResultStatus>()))
.Returns(Task.FromResult(new object()));
await service.HandleDisposeUserViewRequest(disposeViewRequestParams, disposeUserViewContext.Object);
return updateParams.User;
}
internal static async Task DropObject(string connectionUri, string objectUrn)
{
ObjectManagementService objectManagementService = new ObjectManagementService();
var dropParams = new DropRequestParams
{
ConnectionUri = connectionUri,
ObjectUrn = objectUrn
};
var dropRequestContext = new Mock<RequestContext<bool>>();
dropRequestContext.Setup(x => x.SendResult(It.IsAny<bool>()))
.Returns(Task.FromResult(true));
await objectManagementService.HandleDropRequest(dropParams, dropRequestContext.Object);
}
}
}

View File

@@ -1,140 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.Security;
using Microsoft.SqlTools.ServiceLayer.Security.Contracts;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
using NUnit.Framework;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Security
{
/// <summary>
/// Tests for the User management component
/// </summary>
public class UserTests
{
/// <summary>
/// Test the basic Create User method handler
/// </summary>
[Test]
public async Task TestHandleCreateUserWithLoginRequest()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
UserServiceHandlerImpl userService = new UserServiceHandlerImpl();
LoginServiceHandlerImpl loginService = new LoginServiceHandlerImpl();
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var login = await SecurityTestUtils.CreateLogin(loginService, connectionResult);
var user = await SecurityTestUtils.CreateUser(userService, connectionResult, DatabaseUserType.WithLogin, null, login.Name);
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name));
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetLoginURN(login.Name));
}
}
/// <summary>
/// Test the basic Create User method handler
/// </summary>
// [Test] - Windows-only
public async Task TestHandleCreateUserWithWindowsGroup()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
UserServiceHandlerImpl userService = new UserServiceHandlerImpl();
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var user = await SecurityTestUtils.CreateUser(
userService,
connectionResult,
DatabaseUserType.WithWindowsGroupLogin,
$"{Environment.MachineName}\\Administrator");
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name));
}
}
/// <summary>
/// Test the basic Create User method handler
/// </summary>
// [Test] - needs contained DB
public async Task TestHandleCreateUserWithContainedSqlPassword()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
UserServiceHandlerImpl userService = new UserServiceHandlerImpl();
string databaseName = "CRM";
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(databaseName, queryTempFile.FilePath);
var user = await SecurityTestUtils.CreateUser(
userService,
connectionResult,
DatabaseUserType.Contained,
userName: null,
loginName: null,
databaseName: connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName);
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name));
}
}
/// <summary>
/// Test the basic Update User method handler
/// </summary>
[Test]
public async Task TestHandleUpdateUserWithLoginRequest()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
UserServiceHandlerImpl userService = new UserServiceHandlerImpl();
LoginServiceHandlerImpl loginService = new LoginServiceHandlerImpl();
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var login = await SecurityTestUtils.CreateLogin(loginService, connectionResult);
var user = await SecurityTestUtils.CreateUser(userService, connectionResult, DatabaseUserType.WithLogin, null, login.Name);
await SecurityTestUtils.UpdateUser(userService, connectionResult, user);
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name));
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetLoginURN(login.Name));
}
}
/// <summary>
/// Test the basic Create User method handler
/// </summary>
[Test]
public async Task TestScriptUserWithLogin()
{
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
// setup
UserServiceHandlerImpl userService = new UserServiceHandlerImpl();
LoginServiceHandlerImpl loginService = new LoginServiceHandlerImpl();
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
var login = await SecurityTestUtils.CreateLogin(loginService, connectionResult);
var user = await SecurityTestUtils.CreateUser(userService, connectionResult,
DatabaseUserType.WithLogin, null, login.Name, scriptUser: true);
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetUserURN(connectionResult.ConnectionInfo.ConnectionDetails.DatabaseName, user.Name));
await SecurityTestUtils.DropObject(connectionResult.ConnectionInfo.OwnerUri, SecurityTestUtils.GetLoginURN(login.Name));
}
}
}
}