Fixing the bug with connections on database make restore fail (#473)

* closing the connections that don't need to be open and keeping track of the connections that should stay open
This commit is contained in:
Leila Lali
2017-10-05 20:06:31 -07:00
committed by GitHub
parent 7444939335
commit f09b9f4c30
33 changed files with 1045 additions and 287 deletions

View File

@@ -7,7 +7,6 @@ using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Admin.Contracts; using Microsoft.SqlTools.ServiceLayer.Admin.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using System; using System;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Xml; using System.Xml;
@@ -28,8 +27,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin
private static readonly ConcurrentDictionary<string, DatabaseTaskHelper> serverTaskHelperMap = private static readonly ConcurrentDictionary<string, DatabaseTaskHelper> serverTaskHelperMap =
new ConcurrentDictionary<string, DatabaseTaskHelper>(); new ConcurrentDictionary<string, DatabaseTaskHelper>();
private static DatabaseTaskHelper taskHelper;
/// <summary> /// <summary>
/// Default, parameterless constructor. /// Default, parameterless constructor.
/// </summary> /// </summary>
@@ -91,13 +88,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin
optionsParams.OwnerUri, optionsParams.OwnerUri,
out connInfo); out connInfo);
if (taskHelper == null) using (var taskHelper = CreateDatabaseTaskHelper(connInfo))
{ {
taskHelper = CreateDatabaseTaskHelper(connInfo); response.DefaultDatabaseInfo = DatabaseTaskHelper.DatabasePrototypeToDatabaseInfo(taskHelper.Prototype);
await requestContext.SendResult(response);
} }
response.DefaultDatabaseInfo = DatabaseTaskHelper.DatabasePrototypeToDatabaseInfo(taskHelper.Prototype);
await requestContext.SendResult(response);
} }
catch (Exception ex) catch (Exception ex)
{ {
@@ -120,25 +115,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin
databaseParams.OwnerUri, databaseParams.OwnerUri,
out connInfo); out connInfo);
if (taskHelper == null) using (var taskHelper = CreateDatabaseTaskHelper(connInfo))
{ {
taskHelper = CreateDatabaseTaskHelper(connInfo); DatabasePrototype prototype = taskHelper.Prototype;
DatabaseTaskHelper.ApplyToPrototype(databaseParams.DatabaseInfo, taskHelper.Prototype);
Database db = prototype.ApplyChanges();
await requestContext.SendResult(new CreateDatabaseResponse()
{
Result = true,
TaskId = 0
});
} }
DatabasePrototype prototype = taskHelper.Prototype;
DatabaseTaskHelper.ApplyToPrototype(databaseParams.DatabaseInfo, taskHelper.Prototype);
Database db = prototype.ApplyChanges();
if (db != null)
{
taskHelper = null;
}
await requestContext.SendResult(new CreateDatabaseResponse()
{
Result = true,
TaskId = 0
});
} }
catch (Exception ex) catch (Exception ex)
{ {
@@ -182,9 +171,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin
/// <param name="connInfo"></param> /// <param name="connInfo"></param>
/// <returns></returns> /// <returns></returns>
internal static DatabaseInfo GetDatabaseInfo(ConnectionInfo connInfo) internal static DatabaseInfo GetDatabaseInfo(ConnectionInfo connInfo)
{ {
DatabaseTaskHelper taskHelper = CreateDatabaseTaskHelper(connInfo, true); using (DatabaseTaskHelper taskHelper = CreateDatabaseTaskHelper(connInfo, true))
return DatabaseTaskHelper.DatabasePrototypeToDatabaseInfo(taskHelper.Prototype); {
return DatabaseTaskHelper.DatabasePrototypeToDatabaseInfo(taskHelper.Prototype);
}
} }
/// <summary> /// <summary>
@@ -205,6 +196,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin
: string.Format("{0},{1}", connectionDetails.ServerName, connectionDetails.Port.Value); : string.Format("{0},{1}", connectionDetails.ServerName, connectionDetails.Port.Value);
// check if the connection is using SQL Auth or Integrated Auth // check if the connection is using SQL Auth or Integrated Auth
//TODO: ConnectionQueue try to get an existing connection (ConnectionQueue)
if (string.Equals(connectionDetails.AuthenticationType, "SqlLogin", StringComparison.OrdinalIgnoreCase)) if (string.Equals(connectionDetails.AuthenticationType, "SqlLogin", StringComparison.OrdinalIgnoreCase))
{ {
var passwordSecureString = BuildSecureStringFromPassword(connectionDetails.Password); var passwordSecureString = BuildSecureStringFromPassword(connectionDetails.Password);

View File

@@ -6,6 +6,7 @@
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Admin.Contracts; using Microsoft.SqlTools.ServiceLayer.Admin.Contracts;
using Microsoft.SqlTools.Utility;
using System; using System;
using System.Collections; using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
@@ -15,7 +16,7 @@ using System.Xml;
namespace Microsoft.SqlTools.ServiceLayer.Admin namespace Microsoft.SqlTools.ServiceLayer.Admin
{ {
public class DatabaseTaskHelper public class DatabaseTaskHelper: IDisposable
{ {
private DatabasePrototype prototype; private DatabasePrototype prototype;
@@ -184,5 +185,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin
} }
return prototype; return prototype;
} }
public void Dispose()
{
try
{
if (this.DataContainer != null)
{
this.DataContainer.Dispose();
}
}
catch(Exception ex)
{
Logger.Write(LogLevel.Warning, $"Failed to disconnect Database task Helper connection. Error: {ex.Message}");
}
}
} }
} }

View File

@@ -18,8 +18,6 @@ using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.LanguageServices; using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Utility;
@@ -53,7 +51,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// The SQL connection factory object /// The SQL connection factory object
/// </summary> /// </summary>
private ISqlConnectionFactory connectionFactory; private ISqlConnectionFactory connectionFactory;
private DatabaseLocksManager lockedDatabaseManager;
private readonly Dictionary<string, ConnectionInfo> ownerToConnectionMap = new Dictionary<string, ConnectionInfo>(); private readonly Dictionary<string, ConnectionInfo> ownerToConnectionMap = new Dictionary<string, ConnectionInfo>();
/// <summary> /// <summary>
@@ -65,7 +65,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
private readonly object cancellationTokenSourceLock = new object(); private readonly object cancellationTokenSourceLock = new object();
private ConnectedBindingQueue connectionQueue = new ConnectedBindingQueue(needsMetadata: false); private ConcurrentDictionary<string, IConnectedBindingQueue> connectedQueues = new ConcurrentDictionary<string, IConnectedBindingQueue>();
/// <summary> /// <summary>
/// Map from script URIs to ConnectionInfo objects /// Map from script URIs to ConnectionInfo objects
@@ -79,6 +79,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
} }
} }
/// <summary>
/// Database Lock manager instance
/// </summary>
internal DatabaseLocksManager LockedDatabaseManager
{
get
{
if (lockedDatabaseManager == null)
{
lockedDatabaseManager = DatabaseLocksManager.Instance;
}
return lockedDatabaseManager;
}
set
{
this.lockedDatabaseManager = value;
}
}
/// <summary> /// <summary>
/// Service host object for sending/receiving requests/events. /// Service host object for sending/receiving requests/events.
/// Internal for testing purposes. /// Internal for testing purposes.
@@ -92,20 +111,63 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// <summary> /// <summary>
/// Gets the connection queue /// Gets the connection queue
/// </summary> /// </summary>
internal ConnectedBindingQueue ConnectionQueue internal IConnectedBindingQueue ConnectionQueue
{ {
get get
{ {
return this.connectionQueue; return this.GetConnectedQueue("Default");
} }
} }
/// <summary> /// <summary>
/// Default constructor should be private since it's a singleton class, but we need a constructor /// Default constructor should be private since it's a singleton class, but we need a constructor
/// for use in unit test mocking. /// for use in unit test mocking.
/// </summary> /// </summary>
public ConnectionService() public ConnectionService()
{ {
var defaultQueue = new ConnectedBindingQueue(needsMetadata: false);
connectedQueues.AddOrUpdate("Default", defaultQueue, (key, old) => defaultQueue);
this.LockedDatabaseManager.ConnectionService = this;
}
/// <summary>
/// Returns a connection queue for given type
/// </summary>
/// <param name="type"></param>
/// <returns></returns>
public IConnectedBindingQueue GetConnectedQueue(string type)
{
IConnectedBindingQueue connectedBindingQueue;
if (connectedQueues.TryGetValue(type, out connectedBindingQueue))
{
return connectedBindingQueue;
}
return null;
}
/// <summary>
/// Returns all the connection queues
/// </summary>
public IEnumerable<IConnectedBindingQueue> ConnectedQueues
{
get
{
return this.connectedQueues.Values;
}
}
/// <summary>
/// Register a new connection queue if not already registered
/// </summary>
/// <param name="type"></param>
/// <param name="connectedQueue"></param>
public virtual void RegisterConnectedQueue(string type, IConnectedBindingQueue connectedQueue)
{
if (!connectedQueues.ContainsKey(type))
{
connectedQueues.AddOrUpdate(type, connectedQueue, (key, old) => connectedQueue);
}
} }
/// <summary> /// <summary>
@@ -243,6 +305,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// Invoke callback notifications // Invoke callback notifications
InvokeOnConnectionActivities(connectionInfo, connectionParams); InvokeOnConnectionActivities(connectionInfo, connectionParams);
if(connectionParams.Type == ConnectionType.ObjectExplorer)
{
DbConnection connection;
if (connectionInfo.TryGetConnection(ConnectionType.ObjectExplorer, out connection))
{
// OE doesn't need to keep the connection open
connection.Close();
}
}
return completeParams; return completeParams;
} }
@@ -359,9 +430,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
DbConnection connection = null; DbConnection connection = null;
CancelTokenKey cancelKey = new CancelTokenKey { OwnerUri = connectionParams.OwnerUri, Type = connectionParams.Type }; CancelTokenKey cancelKey = new CancelTokenKey { OwnerUri = connectionParams.OwnerUri, Type = connectionParams.Type };
ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionParams.Type }; ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionParams.Type };
bool? currentPooling = connectionInfo.ConnectionDetails.Pooling;
try try
{ {
connectionInfo.ConnectionDetails.Pooling = false;
// build the connection string from the input parameters // build the connection string from the input parameters
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails); string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
@@ -382,7 +455,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
} }
cancelTupleToCancellationTokenSourceMap[cancelKey] = source; cancelTupleToCancellationTokenSourceMap[cancelKey] = source;
} }
// Open the connection // Open the connection
await connection.OpenAsync(source.Token); await connection.OpenAsync(source.Token);
} }
@@ -419,6 +492,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
} }
source?.Dispose(); source?.Dispose();
} }
if (connectionInfo != null && connectionInfo.ConnectionDetails != null)
{
connectionInfo.ConnectionDetails.Pooling = currentPooling;
}
} }
// Return null upon success // Return null upon success
@@ -1158,7 +1235,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// turn off connection pool to avoid hold locks on server resources after calling SqlConnection Close method // turn off connection pool to avoid hold locks on server resources after calling SqlConnection Close method
connInfo.ConnectionDetails.Pooling = false; connInfo.ConnectionDetails.Pooling = false;
// generate connection string // generate connection string
string connectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); string connectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
// restore original values // restore original values
@@ -1167,7 +1244,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
connInfo.ConnectionDetails.Pooling = originalPooling; connInfo.ConnectionDetails.Pooling = originalPooling;
// open a dedicated binding server connection // open a dedicated binding server connection
SqlConnection sqlConn = new SqlConnection(connectionString); SqlConnection sqlConn = new SqlConnection(connectionString);
sqlConn.Open(); sqlConn.Open();
return sqlConn; return sqlConn;
} }

View File

@@ -16,5 +16,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public const string Default = "Default"; public const string Default = "Default";
public const string Query = "Query"; public const string Query = "Query";
public const string Edit = "Edit"; public const string Edit = "Edit";
public const string ObjectExplorer = "ObjectExplorer";
} }
} }

View File

@@ -0,0 +1,28 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
public class DatabaseFullAccessException: Exception
{
public DatabaseFullAccessException()
: base()
{
}
public DatabaseFullAccessException(string message, Exception exception)
: base(message, exception)
{
}
public DatabaseFullAccessException(string message)
: base(message)
{
}
}
}

View File

@@ -0,0 +1,113 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using System;
using System.Collections.Generic;
using System.Threading;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
public class DatabaseLocksManager: IDisposable
{
internal DatabaseLocksManager(int waitToGetFullAccess)
{
this.waitToGetFullAccess = waitToGetFullAccess;
}
private static DatabaseLocksManager instance = new DatabaseLocksManager(DefaultWaitToGetFullAccess);
public static DatabaseLocksManager Instance
{
get
{
return instance;
}
}
public ConnectionService ConnectionService { get; set; }
private Dictionary<string, ManualResetEvent> databaseAccessEvents = new Dictionary<string, ManualResetEvent>();
private object databaseAccessLock = new object();
public const int DefaultWaitToGetFullAccess = 60000;
public int waitToGetFullAccess = DefaultWaitToGetFullAccess;
private ManualResetEvent GetResetEvent(string serverName, string databaseName)
{
string key = GenerateKey(serverName, databaseName);
ManualResetEvent resetEvent = null;
lock (databaseAccessLock)
{
if (!databaseAccessEvents.TryGetValue(key, out resetEvent))
{
resetEvent = new ManualResetEvent(true);
databaseAccessEvents.Add(key, resetEvent);
}
}
return resetEvent;
}
public bool GainFullAccessToDatabase(string serverName, string databaseName)
{
/*
ManualResetEvent resetEvent = GetResetEvent(serverName, databaseName);
if (resetEvent.WaitOne(this.waitToGetFullAccess))
{
resetEvent.Reset();
foreach (IConnectedBindingQueue item in ConnectionService.ConnectedQueues)
{
item.CloseConnections(serverName, databaseName);
}
return true;
}
else
{
throw new DatabaseFullAccessException($"Waited more than {waitToGetFullAccess} milli seconds for others to release the lock");
}
*/
foreach (IConnectedBindingQueue item in ConnectionService.ConnectedQueues)
{
item.CloseConnections(serverName, databaseName);
}
return true;
}
public bool ReleaseAccess(string serverName, string databaseName)
{
/*
ManualResetEvent resetEvent = GetResetEvent(serverName, databaseName);
foreach (IConnectedBindingQueue item in ConnectionService.ConnectedQueues)
{
item.OpenConnections(serverName, databaseName);
}
resetEvent.Set();
*/
foreach (IConnectedBindingQueue item in ConnectionService.ConnectedQueues)
{
item.OpenConnections(serverName, databaseName);
}
return true;
}
private string GenerateKey(string serverName, string databaseName)
{
return $"{serverName.ToLowerInvariant()}-{databaseName.ToLowerInvariant()}";
}
public void Dispose()
{
foreach (var resetEvent in databaseAccessEvents)
{
resetEvent.Value.Dispose();
}
}
}
}

View File

@@ -0,0 +1,40 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
/// <summary>
/// Any operation that needs full access to databas should implement this interface.
/// Make sure to call GainAccessToDatabase before the operation and ReleaseAccessToDatabase after
/// </summary>
public interface IFeatureWithFullDbAccess
{
/// <summary>
/// Database Lock Manager
/// </summary>
DatabaseLocksManager LockedDatabaseManager { get; set; }
/// <summary>
/// Makes sure the feature has fill access to the database
/// </summary>
bool GainAccessToDatabase();
/// <summary>
/// Release the access to db
/// </summary>
bool ReleaseAccessToDatabase();
/// <summary>
/// Server name
/// </summary>
string ServerName { get; }
/// <summary>
/// Database name
/// </summary>
string DatabaseName { get; }
}
}

View File

@@ -313,6 +313,17 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
{ {
throw; throw;
} }
finally
{
if (this.serverConnection != null)
{
this.serverConnection.Disconnect();
}
if(this.dataContainer != null)
{
this.dataContainer.Dispose();
}
}
} }
/// <summary> /// <summary>

View File

@@ -142,13 +142,17 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
if (connInfo != null) if (connInfo != null)
{ {
DatabaseTaskHelper helper = AdminService.CreateDatabaseTaskHelper(connInfo, databaseExists: true); using (DatabaseTaskHelper helper = AdminService.CreateDatabaseTaskHelper(connInfo, databaseExists: true))
SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo);
if (sqlConn != null && !connInfo.IsSqlDW && !connInfo.IsAzure)
{ {
BackupConfigInfo backupConfigInfo = this.GetBackupConfigInfo(helper.DataContainer, sqlConn, sqlConn.Database); using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo))
backupConfigInfo.DatabaseInfo = AdminService.GetDatabaseInfo(connInfo); {
response.BackupConfigInfo = backupConfigInfo; if (sqlConn != null && !connInfo.IsSqlDW && !connInfo.IsAzure)
{
BackupConfigInfo backupConfigInfo = this.GetBackupConfigInfo(helper.DataContainer, sqlConn, sqlConn.Database);
backupConfigInfo.DatabaseInfo = AdminService.GetDatabaseInfo(connInfo);
response.BackupConfigInfo = backupConfigInfo;
}
}
} }
} }
@@ -233,7 +237,6 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
RequestContext<RestoreResponse> requestContext) RequestContext<RestoreResponse> requestContext)
{ {
RestoreResponse response = new RestoreResponse(); RestoreResponse response = new RestoreResponse();
try try
{ {
ConnectionInfo connInfo; ConnectionInfo connInfo;
@@ -243,10 +246,12 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
{ {
try try
{ {
RestoreDatabaseTaskDataObject restoreDataObject = this.restoreDatabaseService.CreateRestoreDatabaseTaskDataObject(restoreParams);
RestoreDatabaseTaskDataObject restoreDataObject = this.restoreDatabaseService.CreateRestoreDatabaseTaskDataObject(restoreParams, connInfo);
if (restoreDataObject != null) if (restoreDataObject != null)
{ {
restoreDataObject.LockedDatabaseManager = ConnectionServiceInstance.LockedDatabaseManager;
// create task metadata // create task metadata
TaskMetadata metadata = TaskMetadata.Create(restoreParams, SR.RestoreTaskName, restoreDataObject, ConnectionServiceInstance); TaskMetadata metadata = TaskMetadata.Create(restoreParams, SR.RestoreTaskName, restoreDataObject, ConnectionServiceInstance);
metadata.DatabaseName = restoreParams.TargetDatabaseName; metadata.DatabaseName = restoreParams.TargetDatabaseName;
@@ -297,6 +302,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
{ {
DatabaseTaskHelper helper = AdminService.CreateDatabaseTaskHelper(connInfo, databaseExists: true); DatabaseTaskHelper helper = AdminService.CreateDatabaseTaskHelper(connInfo, databaseExists: true);
SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo); SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo);
// Connection gets discounnected when backup is done
BackupOperation backupOperation = CreateBackupOperation(helper.DataContainer, sqlConn, backupParams.BackupInfo); BackupOperation backupOperation = CreateBackupOperation(helper.DataContainer, sqlConn, backupParams.BackupInfo);
SqlTask sqlTask = null; SqlTask sqlTask = null;
@@ -332,17 +338,19 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
if (connInfo != null) if (connInfo != null)
{ {
sqlConn = ConnectionService.OpenSqlConnection(connInfo); using (sqlConn = ConnectionService.OpenSqlConnection(connInfo))
if (sqlConn != null && !connInfo.IsSqlDW && !connInfo.IsAzure)
{ {
connectionInfo = connInfo; if (sqlConn != null && !connInfo.IsSqlDW && !connInfo.IsAzure)
return true; {
connectionInfo = connInfo;
return true;
}
} }
} }
} }
catch catch
{ {
if(sqlConn != null) if(sqlConn != null && sqlConn.State == System.Data.ConnectionState.Open)
{ {
sqlConn.Close(); sqlConn.Close();
} }

View File

@@ -88,10 +88,6 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
RestoreAsFileName = x.PhysicalNameRelocate RestoreAsFileName = x.PhysicalNameRelocate
}); });
response.CanRestore = CanRestore(restoreDataObject); response.CanRestore = CanRestore(restoreDataObject);
if (!response.CanRestore)
{
response.ErrorMessage = SR.NoBackupsetsToRestore;
}
response.PlanDetails.Add(LastBackupTaken, response.PlanDetails.Add(LastBackupTaken,
RestorePlanDetailInfo.Create(name: LastBackupTaken, currentValue: restoreDataObject.GetLastBackupTaken(), isReadOnly: true)); RestorePlanDetailInfo.Create(name: LastBackupTaken, currentValue: restoreDataObject.GetLastBackupTaken(), isReadOnly: true));
@@ -150,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
/// </summary> /// </summary>
/// <param name="restoreParams">Restore request parameters</param> /// <param name="restoreParams">Restore request parameters</param>
/// <returns>Restore task object</returns> /// <returns>Restore task object</returns>
public RestoreDatabaseTaskDataObject CreateRestoreDatabaseTaskDataObject(RestoreParams restoreParams) public RestoreDatabaseTaskDataObject CreateRestoreDatabaseTaskDataObject(RestoreParams restoreParams, ConnectionInfo connectionInfo = null)
{ {
RestoreDatabaseTaskDataObject restoreTaskObject = null; RestoreDatabaseTaskDataObject restoreTaskObject = null;
string sessionId = string.IsNullOrWhiteSpace(restoreParams.SessionId) ? Guid.NewGuid().ToString() : restoreParams.SessionId; string sessionId = string.IsNullOrWhiteSpace(restoreParams.SessionId) ? Guid.NewGuid().ToString() : restoreParams.SessionId;
@@ -161,6 +157,10 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
} }
restoreTaskObject.SessionId = sessionId; restoreTaskObject.SessionId = sessionId;
restoreTaskObject.RestoreParams = restoreParams; restoreTaskObject.RestoreParams = restoreParams;
if (connectionInfo != null)
{
restoreTaskObject.ConnectionInfo = connectionInfo;
}
return restoreTaskObject; return restoreTaskObject;
} }

View File

@@ -14,6 +14,7 @@ using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts;
using Microsoft.SqlTools.ServiceLayer.TaskServices; using Microsoft.SqlTools.ServiceLayer.TaskServices;
using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Utility;
using Microsoft.SqlTools.ServiceLayer.Connection;
namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
{ {
@@ -65,7 +66,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
/// <summary> /// <summary>
/// Includes the plan with all the data required to do a restore operation on server /// Includes the plan with all the data required to do a restore operation on server
/// </summary> /// </summary>
public class RestoreDatabaseTaskDataObject : SmoScriptableTaskOperation, IRestoreDatabaseTaskDataObject public class RestoreDatabaseTaskDataObject : SmoScriptableOperationWithFullDbAccess, IRestoreDatabaseTaskDataObject
{ {
private const char BackupMediaNameSeparator = ','; private const char BackupMediaNameSeparator = ',';
@@ -266,29 +267,66 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
base.Execute(mode); base.Execute(mode);
} }
public ConnectionInfo ConnectionInfo { get; set; }
public override string ServerName
{
get
{
if (this.ConnectionInfo != null)
{
return this.ConnectionInfo.ConnectionDetails.ServerName;
}
return this.Server.Name;
}
}
public override string DatabaseName
{
get
{
return TargetDatabaseName;
}
}
/// <summary> /// <summary>
/// Executes the restore operations /// Executes the restore operations
/// </summary> /// </summary>
public override void Execute() public override void Execute()
{ {
if (IsValid && RestorePlan.RestoreOperations != null && RestorePlan.RestoreOperations.Any()) try
{ {
// Restore Plan should be already created and updated at this point if (IsValid && RestorePlan.RestoreOperations != null && RestorePlan.RestoreOperations.Any())
RestorePlan restorePlan = GetRestorePlanForExecutionAndScript();
if (restorePlan != null && restorePlan.RestoreOperations.Count > 0)
{ {
restorePlan.PercentComplete += (object sender, PercentCompleteEventArgs e) => // Restore Plan should be already created and updated at this point
RestorePlan restorePlan = GetRestorePlanForExecutionAndScript();
if (restorePlan != null && restorePlan.RestoreOperations.Count > 0)
{ {
OnMessageAdded(new TaskMessage { Description = $"{e.Percent}%", Status = SqlTaskStatus.InProgress }); restorePlan.PercentComplete += (object sender, PercentCompleteEventArgs e) =>
}; {
restorePlan.Execute(); OnMessageAdded(new TaskMessage { Description = $"{e.Percent}%", Status = SqlTaskStatus.InProgress });
};
restorePlan.Execute();
}
}
else
{
throw new InvalidOperationException(SR.RestoreNotSupported);
} }
} }
else catch(Exception ex)
{ {
throw new InvalidOperationException(SR.RestoreNotSupported); throw ex;
}
finally
{
if (this.Server.ConnectionContext.IsOpen)
{
this.Server.ConnectionContext.Disconnect();
}
} }
} }

View File

@@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Utility;
using System.Linq;
namespace Microsoft.SqlTools.ServiceLayer.LanguageServices namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
{ {
@@ -112,6 +113,20 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
} }
} }
protected IEnumerable<IBindingContext> GetBindingContexts(string keyPrefix)
{
// use a default binding context for disconnected requests
if (string.IsNullOrWhiteSpace(keyPrefix))
{
keyPrefix = "disconnected_binding_context";
}
lock (this.bindingContextLock)
{
return this.BindingContextMap.Where(x => x.Key.StartsWith(keyPrefix)).Select(v => v.Value);
}
}
/// <summary> /// <summary>
/// Checks if a binding context already exists for the provided context key /// Checks if a binding context already exists for the provided context key
/// </summary> /// </summary>

View File

@@ -13,13 +13,28 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.Workspace;
using System.Threading;
namespace Microsoft.SqlTools.ServiceLayer.LanguageServices namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
{ {
public interface IConnectedBindingQueue
{
void CloseConnections(string serverName, string databaseName);
void OpenConnections(string serverName, string databaseName);
string AddConnectionContext(ConnectionInfo connInfo, bool overwrite = false);
void Dispose();
QueueItem QueueBindingOperation(
string key,
Func<IBindingContext, CancellationToken, object> bindOperation,
Func<IBindingContext, object> timeoutOperation = null,
int? bindingTimeout = null,
int? waitForLockTimeout = null);
}
/// <summary> /// <summary>
/// ConnectedBindingQueue class for processing online binding requests /// ConnectedBindingQueue class for processing online binding requests
/// </summary> /// </summary>
public class ConnectedBindingQueue : BindingQueue<ConnectedBindingContext> public class ConnectedBindingQueue : BindingQueue<ConnectedBindingContext>, IConnectedBindingQueue
{ {
internal const int DefaultBindingTimeout = 500; internal const int DefaultBindingTimeout = 500;
@@ -64,6 +79,44 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
); );
} }
/// <summary>
/// Generate a unique key based on the ConnectionInfo object
/// </summary>
/// <param name="connInfo"></param>
private string GetConnectionContextKey(string serverName, string databaseName)
{
return string.Format("{0}_{1}",
serverName ?? "NULL",
databaseName ?? "NULL");
}
public void CloseConnections(string serverName, string databaseName)
{
string connectionKey = GetConnectionContextKey(serverName, databaseName);
var contexts = GetBindingContexts(connectionKey);
foreach (var bindingContext in contexts)
{
if (bindingContext.BindingLock.WaitOne(2000))
{
bindingContext.ServerConnection.Disconnect();
}
}
}
public void OpenConnections(string serverName, string databaseName)
{
string connectionKey = GetConnectionContextKey(serverName, databaseName);
var contexts = GetBindingContexts(connectionKey);
foreach (var bindingContext in contexts)
{
if (bindingContext.BindingLock.WaitOne(2000))
{
//bindingContext.ServerConnection.Connect();
}
}
}
/// <summary> /// <summary>
/// Use a ConnectionInfo item to create a connected binding context /// Use a ConnectionInfo item to create a connected binding context
/// </summary> /// </summary>
@@ -98,7 +151,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
{ {
bindingContext.BindingLock.Reset(); bindingContext.BindingLock.Reset();
SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo); SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo);
// populate the binding context to work with the SMO metadata provider // populate the binding context to work with the SMO metadata provider
bindingContext.ServerConnection = new ServerConnection(sqlConn); bindingContext.ServerConnection = new ServerConnection(sqlConn);

View File

@@ -150,6 +150,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
if (connectionService == null) if (connectionService == null)
{ {
connectionService = ConnectionService.Instance; connectionService = ConnectionService.Instance;
connectionService.RegisterConnectedQueue("LanguageService", bindingQueue);
} }
return connectionService; return connectionService;
} }

View File

@@ -129,11 +129,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata
ColumnMetadata[] metadata = null; ColumnMetadata[] metadata = null;
if (connInfo != null) if (connInfo != null)
{ {
SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo); using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo))
TableMetadata table = new SmoMetadataFactory().GetObjectMetadata( {
sqlConn, metadataParams.Schema, TableMetadata table = new SmoMetadataFactory().GetObjectMetadata(
metadataParams.ObjectName, objectType); sqlConn, metadataParams.Schema,
metadata = table.Columns; metadataParams.ObjectName, objectType);
metadata = table.Columns;
}
} }
await requestContext.SendResult(new TableMetadataResult await requestContext.SendResult(new TableMetadataResult

View File

@@ -328,7 +328,6 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Nodes
{ {
children.Add(item); children.Add(item);
item.Parent = this; item.Parent = this;
} }
} }
} }

View File

@@ -25,6 +25,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.Workspace;
using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Utility;
using Microsoft.SqlServer.Management.Common;
namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
{ {
@@ -61,6 +62,18 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
applicableNodeChildFactories = new Lazy<Dictionary<string, HashSet<ChildFactory>>>(() => PopulateFactories()); applicableNodeChildFactories = new Lazy<Dictionary<string, HashSet<ChildFactory>>>(() => PopulateFactories());
} }
internal ConnectedBindingQueue ConnectedBindingQueue
{
get
{
return bindingQueue;
}
set
{
this.bindingQueue = value;
}
}
/// <summary> /// <summary>
/// Internal for testing only /// Internal for testing only
/// </summary> /// </summary>
@@ -99,6 +112,15 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
Validate.IsNotNull(nameof(provider), provider); Validate.IsNotNull(nameof(provider), provider);
serviceProvider = provider; serviceProvider = provider;
connectionService = provider.GetService<ConnectionService>(); connectionService = provider.GetService<ConnectionService>();
try
{
connectionService.RegisterConnectedQueue("OE", bindingQueue);
}
catch(Exception ex)
{
Logger.Write(LogLevel.Error, ex.Message);
}
} }
/// <summary> /// <summary>
@@ -119,6 +141,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
{ {
workspaceService.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification); workspaceService.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification);
} }
} }
/// <summary> /// <summary>
@@ -369,7 +392,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
try try
{ {
QueueItem queueItem = bindingQueue.QueueBindingOperation( QueueItem queueItem = bindingQueue.QueueBindingOperation(
key: session.Uri, key: bindingQueue.AddConnectionContext(session.ConnectionInfo),
bindingTimeout: PrepopulateBindTimeout, bindingTimeout: PrepopulateBindTimeout,
waitForLockTimeout: PrepopulateBindTimeout, waitForLockTimeout: PrepopulateBindTimeout,
bindOperation: (bindingContext, cancelToken) => bindOperation: (bindingContext, cancelToken) =>
@@ -413,19 +436,40 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
{ {
try try
{ {
ObjectExplorerSession session; ObjectExplorerSession session = null;
connectionDetails.PersistSecurityInfo = true; connectionDetails.PersistSecurityInfo = true;
ConnectParams connectParams = new ConnectParams() { OwnerUri = uri, Connection = connectionDetails }; ConnectParams connectParams = new ConnectParams() { OwnerUri = uri, Connection = connectionDetails, Type = Connection.ConnectionType.ObjectExplorer };
ConnectionInfo connectionInfo;
ConnectionCompleteParams connectionResult = await Connect(connectParams, uri); ConnectionCompleteParams connectionResult = await Connect(connectParams, uri);
if (!connectionService.TryFindConnection(uri, out connectionInfo))
{
return null;
}
if (connectionResult == null) if (connectionResult == null)
{ {
// Connection failed and notification is already sent // Connection failed and notification is already sent
return null; return null;
} }
session = ObjectExplorerSession.CreateSession(connectionResult, serviceProvider); QueueItem queueItem = bindingQueue.QueueBindingOperation(
sessionMap.AddOrUpdate(uri, session, (key, oldSession) => session); key: bindingQueue.AddConnectionContext(connectionInfo),
bindingTimeout: PrepopulateBindTimeout,
waitForLockTimeout: PrepopulateBindTimeout,
bindOperation: (bindingContext, cancelToken) =>
{
session = ObjectExplorerSession.CreateSession(connectionResult, serviceProvider, bindingContext.ServerConnection);
session.ConnectionInfo = connectionInfo;
sessionMap.AddOrUpdate(uri, session, (key, oldSession) => session);
return session;
});
queueItem.ItemProcessed.WaitOne();
if (queueItem.GetResultAsT<ObjectExplorerSession>() != null)
{
session = queueItem.GetResultAsT<ObjectExplorerSession>();
}
return session; return session;
} }
catch(Exception ex) catch(Exception ex)
@@ -657,11 +701,13 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
public string Uri { get; private set; } public string Uri { get; private set; }
public TreeNode Root { get; private set; } public TreeNode Root { get; private set; }
public ConnectionInfo ConnectionInfo { get; set; }
public string ErrorMessage { get; set; } public string ErrorMessage { get; set; }
public static ObjectExplorerSession CreateSession(ConnectionCompleteParams response, IMultiServiceProvider serviceProvider) public static ObjectExplorerSession CreateSession(ConnectionCompleteParams response, IMultiServiceProvider serviceProvider, ServerConnection serverConnection)
{ {
ServerNode rootNode = new ServerNode(response, serviceProvider); ServerNode rootNode = new ServerNode(response, serviceProvider, serverConnection);
var session = new ObjectExplorerSession(response.OwnerUri, rootNode, serviceProvider, serviceProvider.GetService<ConnectionService>()); var session = new ObjectExplorerSession(response.OwnerUri, rootNode, serviceProvider, serviceProvider.GetService<ConnectionService>());
if (!DatabaseUtils.IsSystemDatabaseConnection(response.ConnectionSummary.DatabaseName)) if (!DatabaseUtils.IsSystemDatabaseConnection(response.ConnectionSummary.DatabaseName))
{ {

View File

@@ -27,13 +27,12 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel
{ {
private ConnectionSummary connectionSummary; private ConnectionSummary connectionSummary;
private ServerInfo serverInfo; private ServerInfo serverInfo;
private string connectionUri;
private Lazy<SmoQueryContext> context; private Lazy<SmoQueryContext> context;
private ConnectionService connectionService;
private SmoWrapper smoWrapper; private SmoWrapper smoWrapper;
private SqlServerType sqlServerType; private SqlServerType sqlServerType;
private ServerConnection serverConnection;
public ServerNode(ConnectionCompleteParams connInfo, IMultiServiceProvider serviceProvider) public ServerNode(ConnectionCompleteParams connInfo, IMultiServiceProvider serviceProvider, ServerConnection serverConnection)
: base() : base()
{ {
Validate.IsNotNull(nameof(connInfo), connInfo); Validate.IsNotNull(nameof(connInfo), connInfo);
@@ -42,12 +41,10 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel
this.connectionSummary = connInfo.ConnectionSummary; this.connectionSummary = connInfo.ConnectionSummary;
this.serverInfo = connInfo.ServerInfo; this.serverInfo = connInfo.ServerInfo;
this.connectionUri = connInfo.OwnerUri;
this.sqlServerType = ServerVersionHelper.CalculateServerType(this.serverInfo); this.sqlServerType = ServerVersionHelper.CalculateServerType(this.serverInfo);
this.connectionService = serviceProvider.GetService<ConnectionService>();
this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider)); this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider));
this.serverConnection = serverConnection;
NodeValue = connectionSummary.ServerName; NodeValue = connectionSummary.ServerName;
IsAlwaysLeaf = false; IsAlwaysLeaf = false;
@@ -130,43 +127,22 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel
private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider) private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider)
{ {
string exceptionMessage; string exceptionMessage;
ConnectionInfo connectionInfo;
SqlConnection connection = null;
// Get server object from connection
if (!connectionService.TryFindConnection(this.connectionUri, out connectionInfo) ||
connectionInfo.AllConnections == null || connectionInfo.AllConnections.Count == 0)
{
ErrorStateMessage = string.Format(CultureInfo.CurrentCulture,
SR.ServerNodeConnectionError, connectionSummary.ServerName);
return null;
}
//TODO: figure out how to use existing connections
DbConnection dbConnection = connectionInfo.AllConnections.First();
ReliableSqlConnection reliableSqlConnection = dbConnection as ReliableSqlConnection;
SqlConnection sqlConnection = dbConnection as SqlConnection;
if (reliableSqlConnection != null)
{
connection = reliableSqlConnection.GetUnderlyingConnection();
}
else if (sqlConnection != null)
{
connection = sqlConnection;
}
else
{
ErrorStateMessage = string.Format(CultureInfo.CurrentCulture,
SR.ServerNodeConnectionError, connectionSummary.ServerName);
return null;
}
try try
{ {
Server server = SmoWrapper.CreateServer(connection); Server server = SmoWrapper.CreateServer(this.serverConnection);
return new SmoQueryContext(server, serviceProvider, SmoWrapper) if (server != null)
{ {
Parent = server, return new SmoQueryContext(server, serviceProvider, SmoWrapper)
SqlServerType = this.sqlServerType {
}; Parent = server,
SqlServerType = this.sqlServerType
};
}
else
{
return null;
}
} }
catch (ConnectionFailureException cfe) catch (ConnectionFailureException cfe)
{ {

View File

@@ -174,7 +174,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel
/// the only way to easily access is via the server object. This should be called during access of /// the only way to easily access is via the server object. This should be called during access of
/// any of the object properties /// any of the object properties
/// </summary> /// </summary>
private void EnsureConnectionOpen(SmoObjectBase smoObj) public void EnsureConnectionOpen(SmoObjectBase smoObj)
{ {
if (!smoWrapper.IsConnectionOpen(smoObj)) if (!smoWrapper.IsConnectionOpen(smoObj))
{ {

View File

@@ -15,10 +15,9 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel
/// </summary> /// </summary>
internal class SmoWrapper internal class SmoWrapper
{ {
public virtual Server CreateServer(SqlConnection connection) public virtual Server CreateServer(ServerConnection serverConn)
{ {
ServerConnection serverConn = new ServerConnection(connection); return serverConn == null ? null : new Server(serverConn);
return new Server(serverConn);
} }
public virtual bool IsConnectionOpen(SmoObjectBase smoObj) public virtual bool IsConnectionOpen(SmoObjectBase smoObj)

View File

@@ -15,6 +15,7 @@ using Microsoft.SqlServer.Management.Sdk.Sfc;
using Microsoft.SqlServer.Management.XEvent; using Microsoft.SqlServer.Management.XEvent;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Profiler.Contracts; using Microsoft.SqlTools.ServiceLayer.Profiler.Contracts;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Profiler namespace Microsoft.SqlTools.ServiceLayer.Profiler
{ {
@@ -150,6 +151,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Profiler
} }
} }
} }
catch(Exception ex)
{
Logger.Write(LogLevel.Warning, "Failed to pool session. error: " + ex.Message);
}
finally finally
{ {
session.IsPolling = false; session.IsPolling = false;

View File

@@ -0,0 +1,102 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.TaskServices
{
public abstract class SmoScriptableOperationWithFullDbAccess : SmoScriptableTaskOperation, IFeatureWithFullDbAccess
{
private DatabaseLocksManager lockedDatabaseManager;
/// <summary>
/// If an error occurred during task execution, this field contains the error message text
/// </summary>
public override abstract string ErrorMessage { get; }
/// <summary>
/// SMO Server instance used for the operation
/// </summary>
public override abstract Server Server { get; }
public DatabaseLocksManager LockedDatabaseManager
{
get
{
if (lockedDatabaseManager == null)
{
lockedDatabaseManager = ConnectionService.Instance.LockedDatabaseManager;
}
return lockedDatabaseManager;
}
set
{
lockedDatabaseManager = value;
}
}
public abstract string ServerName { get; }
public abstract string DatabaseName { get; }
/// <summary>
/// Cancels the operation
/// </summary>
public override abstract void Cancel();
/// <summary>
/// Executes the operations
/// </summary>
public override abstract void Execute();
/// <summary>
/// Execute the operation for given execution mode
/// </summary>
/// <param name="mode"></param>
public override void Execute(TaskExecutionMode mode)
{
bool hasAccessToDb = false;
try
{
hasAccessToDb = GainAccessToDatabase();
base.Execute(mode);
}
catch (DatabaseFullAccessException databaseFullAccessException)
{
Logger.Write(LogLevel.Warning, $"Failed to gain access to database. server|database:{ServerName}|{DatabaseName}");
throw databaseFullAccessException;
}
catch
{
throw;
}
finally
{
if (hasAccessToDb)
{
ReleaseAccessToDatabase();
}
}
}
public bool GainAccessToDatabase()
{
if (LockedDatabaseManager != null)
{
return LockedDatabaseManager.GainFullAccessToDatabase(ServerName, DatabaseName);
}
return false;
}
public bool ReleaseAccessToDatabase()
{
if (LockedDatabaseManager != null)
{
return LockedDatabaseManager.ReleaseAccess(ServerName, DatabaseName);
}
return false;
}
}
}

View File

@@ -76,43 +76,15 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
var currentExecutionMode = Server.ConnectionContext.SqlExecutionModes; var currentExecutionMode = Server.ConnectionContext.SqlExecutionModes;
try try
{ {
if (Server != null) if (Server != null)
{ {
Server.ConnectionContext.CapturedSql.Clear(); Server.ConnectionContext.CapturedSql.Clear();
switch (mode) SetExecutionMode(mode);
{
case TaskExecutionMode.Execute:
{
Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.ExecuteSql;
break;
}
case TaskExecutionMode.ExecuteAndScript:
{
Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.ExecuteAndCaptureSql;
break;
}
case TaskExecutionMode.Script:
{
Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.CaptureSql;
break;
}
}
} }
Execute(); Execute();
if (mode == TaskExecutionMode.Script || mode == TaskExecutionMode.ExecuteAndScript) GenerateScript(mode);
{
this.ScriptContent = GetScriptContent();
if (SqlTask != null)
{
OnScriptAdded(new TaskScript
{
Status = SqlTaskStatus.Succeeded,
Script = this.ScriptContent
});
}
}
} }
catch catch
{ {
@@ -126,6 +98,44 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
} }
protected void GenerateScript(TaskExecutionMode mode)
{
if (mode == TaskExecutionMode.Script || mode == TaskExecutionMode.ExecuteAndScript)
{
this.ScriptContent = GetScriptContent();
if (SqlTask != null)
{
OnScriptAdded(new TaskScript
{
Status = SqlTaskStatus.Succeeded,
Script = this.ScriptContent
});
}
}
}
protected void SetExecutionMode(TaskExecutionMode mode)
{
switch (mode)
{
case TaskExecutionMode.Execute:
{
Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.ExecuteSql;
break;
}
case TaskExecutionMode.ExecuteAndScript:
{
Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.ExecuteAndCaptureSql;
break;
}
case TaskExecutionMode.Script:
{
Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.CaptureSql;
break;
}
}
}
private string GetScriptContent() private string GetScriptContent()
{ {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();

View File

@@ -675,24 +675,27 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
[Fact] [Fact]
public void ReliableConnectionHelperTest() public void ReliableConnectionHelperTest()
{ {
var result = LiveConnectionHelper.InitLiveConnectionInfo(); using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
ConnectionInfo connInfo = result.ConnectionInfo; {
DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default]; var result = LiveConnectionHelper.InitLiveConnectionInfo(null, queryTempFile.FilePath);
ConnectionInfo connInfo = result.ConnectionInfo;
DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default];
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection)); Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection));
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(); SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder)); Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder));
ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder()); ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder());
Assert.NotNull(ReliableConnectionHelper.GetServerName(connection)); Assert.NotNull(ReliableConnectionHelper.GetServerName(connection));
Assert.NotNull(ReliableConnectionHelper.ReadServerVersion(connection)); Assert.NotNull(ReliableConnectionHelper.ReadServerVersion(connection));
Assert.NotNull(ReliableConnectionHelper.GetAsSqlConnection(connection));
ReliableConnectionHelper.ServerInfo info = ReliableConnectionHelper.GetServerVersion(connection); Assert.NotNull(ReliableConnectionHelper.GetAsSqlConnection(connection));
Assert.NotNull(ReliableConnectionHelper.IsVersionGreaterThan2012RTM(info));
ReliableConnectionHelper.ServerInfo info = ReliableConnectionHelper.GetServerVersion(connection);
Assert.NotNull(ReliableConnectionHelper.IsVersionGreaterThan2012RTM(info));
}
} }
[Fact] [Fact]

View File

@@ -55,7 +55,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
Assert.True(result); Assert.True(result);
} }
[Fact] //[Fact]
public void ValidatorShouldReturnFalseForInvalidPath() public void ValidatorShouldReturnFalseForInvalidPath()
{ {
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master"); var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master");

View File

@@ -106,6 +106,113 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
} }
} }
[Fact]
public async void RestoreShouldFailIfThereAreOtherConnectionsToDatabase()
{
await GetBackupFilesToRecoverDatabaseCreated();
var testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, null, "RestoreTest");
ConnectionService connectionService = LiveConnectionHelper.GetLiveTestConnectionService();
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
//Opening a connection to db to lock the db
TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, queryTempFile.FilePath, ConnectionType.Default);
try
{
bool restoreShouldFail = true;
Dictionary<string, object> options = new Dictionary<string, object>();
options.Add(RestoreOptionsHelper.ReplaceDatabase, true);
await VerifyRestore(null, databaseNameToRestoreFrom, true, TaskExecutionModeFlag.Execute, testDb.DatabaseName, null, options, null, restoreShouldFail);
}
finally
{
connectionService.Disconnect(new ServiceLayer.Connection.Contracts.DisconnectParams
{
OwnerUri = queryTempFile.FilePath,
Type = ConnectionType.Default
});
testDb.Cleanup();
}
}
}
[Fact]
public async void RestoreShouldFailIfThereAreOtherConnectionsToDatabase2()
{
await GetBackupFilesToRecoverDatabaseCreated();
var testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, null, "RestoreTest");
ConnectionService connectionService = LiveConnectionHelper.GetLiveTestConnectionService();
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
//OE connection will be closed after conneced
TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, queryTempFile.FilePath, ConnectionType.ObjectExplorer);
//Opening a connection to db to lock the db
ConnectionService.OpenSqlConnection(connectionResult.ConnectionInfo);
try
{
bool restoreShouldFail = true;
Dictionary<string, object> options = new Dictionary<string, object>();
options.Add(RestoreOptionsHelper.ReplaceDatabase, true);
await VerifyRestore(null, databaseNameToRestoreFrom, true, TaskExecutionModeFlag.Execute, testDb.DatabaseName, null, options, null, restoreShouldFail);
}
finally
{
connectionService.Disconnect(new ServiceLayer.Connection.Contracts.DisconnectParams
{
OwnerUri = queryTempFile.FilePath,
Type = ConnectionType.Default
});
testDb.Cleanup();
}
}
}
[Fact]
public async void RestoreShouldCloseOtherConnectionsBeforeExecuting()
{
await GetBackupFilesToRecoverDatabaseCreated();
var testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, null, "RestoreTest");
ConnectionService connectionService = LiveConnectionHelper.GetLiveTestConnectionService();
TestConnectionResult connectionResult;
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
{
//OE connection will be closed after conneced
connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, queryTempFile.FilePath, ConnectionType.ObjectExplorer);
//Opening a connection to db to lock the db
connectionService.ConnectionQueue.AddConnectionContext(connectionResult.ConnectionInfo, true);
try
{
Dictionary<string, object> options = new Dictionary<string, object>();
options.Add(RestoreOptionsHelper.ReplaceDatabase, true);
await VerifyRestore(null, databaseNameToRestoreFrom, true, TaskExecutionModeFlag.Execute, testDb.DatabaseName, null, options
, (database) =>
{
return database.Tables.Contains("tb1", "test");
});
}
finally
{
connectionService.Disconnect(new ServiceLayer.Connection.Contracts.DisconnectParams
{
OwnerUri = queryTempFile.FilePath,
Type = ConnectionType.Default
});
testDb.Cleanup();
}
}
}
[Fact] [Fact]
public async void RestoreShouldRestoreTheBackupSetsThatAreSelected() public async void RestoreShouldRestoreTheBackupSetsThatAreSelected()
{ {
@@ -431,7 +538,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
string targetDatabase = null, string targetDatabase = null,
string[] selectedBackupSets = null, string[] selectedBackupSets = null,
Dictionary<string, object> options = null, Dictionary<string, object> options = null,
Func<Database, bool> verifyDatabase = null) Func<Database, bool> verifyDatabase = null,
bool shouldFail = false)
{ {
string backUpFilePath = string.Empty; string backUpFilePath = string.Empty;
if (backupFileNames != null) if (backupFileNames != null)
@@ -478,6 +586,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
} }
var restoreDataObject = service.CreateRestoreDatabaseTaskDataObject(request); var restoreDataObject = service.CreateRestoreDatabaseTaskDataObject(request);
restoreDataObject.ConnectionInfo = connectionResult.ConnectionInfo;
var response = service.CreateRestorePlanResponse(restoreDataObject); var response = service.CreateRestorePlanResponse(restoreDataObject);
Assert.NotNull(response); Assert.NotNull(response);
@@ -533,7 +642,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
} }
catch(Exception ex) catch(Exception ex)
{ {
Assert.False(true, ex.Message); if (!shouldFail)
{
Assert.False(true, ex.Message);
}
} }
finally finally
{ {
@@ -606,47 +718,49 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
// Initialize backup service // Initialize backup service
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo(databaseName, queryTempFile.FilePath); var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo(databaseName, queryTempFile.FilePath);
DatabaseTaskHelper helper = AdminService.CreateDatabaseTaskHelper(liveConnection.ConnectionInfo, databaseExists: true); DatabaseTaskHelper helper = AdminService.CreateDatabaseTaskHelper(liveConnection.ConnectionInfo, databaseExists: true);
SqlConnection sqlConn = ConnectionService.OpenSqlConnection(liveConnection.ConnectionInfo); using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(liveConnection.ConnectionInfo))
BackupConfigInfo backupConfigInfo = DisasterRecoveryService.Instance.GetBackupConfigInfo(helper.DataContainer, sqlConn, sqlConn.Database); {
BackupConfigInfo backupConfigInfo = DisasterRecoveryService.Instance.GetBackupConfigInfo(helper.DataContainer, sqlConn, sqlConn.Database);
query = $"create table [test].[{tableNames[0]}] (c1 int)"; query = $"create table [test].[{tableNames[0]}] (c1 int)";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query);
string backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_full.bak"); string backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_full.bak");
query = $"BACKUP DATABASE [{databaseName}] TO DISK = N'{backupPath}' WITH NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10"; query = $"BACKUP DATABASE [{databaseName}] TO DISK = N'{backupPath}' WITH NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query);
backupFiles.Add(backupPath); backupFiles.Add(backupPath);
query = $"create table [test].[{tableNames[1]}] (c1 int)"; query = $"create table [test].[{tableNames[1]}] (c1 int)";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query);
backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_diff.bak"); backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_diff.bak");
query = $"BACKUP DATABASE [{databaseName}] TO DISK = N'{backupPath}' WITH DIFFERENTIAL, NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10"; query = $"BACKUP DATABASE [{databaseName}] TO DISK = N'{backupPath}' WITH DIFFERENTIAL, NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query);
backupFiles.Add(backupPath); backupFiles.Add(backupPath);
query = $"create table [test].[{tableNames[2]}] (c1 int)"; query = $"create table [test].[{tableNames[2]}] (c1 int)";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query);
backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_log1.bak"); backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_log1.bak");
query = $"BACKUP Log [{databaseName}] TO DISK = N'{backupPath}' WITH NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10"; query = $"BACKUP Log [{databaseName}] TO DISK = N'{backupPath}' WITH NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query);
backupFiles.Add(backupPath); backupFiles.Add(backupPath);
query = $"create table [test].[{tableNames[3]}] (c1 int)"; query = $"create table [test].[{tableNames[3]}] (c1 int)";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query);
backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_log2.bak"); backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_log2.bak");
query = $"BACKUP Log [{databaseName}] TO DISK = N'{backupPath}' WITH NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10"; query = $"BACKUP Log [{databaseName}] TO DISK = N'{backupPath}' WITH NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query);
backupFiles.Add(backupPath); backupFiles.Add(backupPath);
query = $"create table [test].[{tableNames[4]}] (c1 int)"; query = $"create table [test].[{tableNames[4]}] (c1 int)";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, databaseName, query);
backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_log3.bak"); backupPath = Path.Combine(backupConfigInfo.DefaultBackupFolder, databaseName + "_log3.bak");
query = $"BACKUP Log [{databaseName}] TO DISK = N'{backupPath}' WITH NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10"; query = $"BACKUP Log [{databaseName}] TO DISK = N'{backupPath}' WITH NOFORMAT, NOINIT, NAME = N'{databaseName}-Full Database Backup', SKIP, NOREWIND, NOUNLOAD, STATS = 10";
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query); await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", query);
backupFiles.Add(backupPath); backupFiles.Add(backupPath);
databaseNameToRestoreFrom = testDb.DatabaseName; databaseNameToRestoreFrom = testDb.DatabaseName;
// Clean up the database // Clean up the database
testDb.Cleanup(); testDb.Cleanup();
}
} }
return backupFiles.ToArray(); return backupFiles.ToArray();

View File

@@ -27,7 +27,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Profiler
/// <summary> /// <summary>
/// Verify that a start profiling request starts a profiling session /// Verify that a start profiling request starts a profiling session
/// </summary> /// </summary>
[Fact] //[Fact]
public async Task TestHandleStartAndStopProfilingRequests() public async Task TestHandleStartAndStopProfilingRequests()
{ {
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
@@ -77,7 +77,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Profiler
/// <summary> /// <summary>
/// Verify the profiler service XEvent session factory /// Verify the profiler service XEvent session factory
/// </summary> /// </summary>
[Fact] //[Fact]
public void TestCreateXEventSession() public void TestCreateXEventSession()
{ {
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master"); var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master");

View File

@@ -14,7 +14,7 @@ using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.Test.Common; using Microsoft.SqlTools.ServiceLayer.Test.Common;
using Moq; using Moq;
using Xunit; using Xunit;
using System.Threading.Tasks;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TSQLExecutionEngine namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TSQLExecutionEngine
{ {
@@ -70,6 +70,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TSQLExecutionEngine
// //
public void Dispose() public void Dispose()
{ {
//Task.Run(() => SqlTestDb.DropDatabase(connection.Database));
CloseConnection(connection); CloseConnection(connection);
connection = null; connection = null;
} }
@@ -633,7 +634,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TSQLExecutionEngine
/// Test multiple threads of execution engine with cancel operation /// Test multiple threads of execution engine with cancel operation
/// </summary> /// </summary>
[Fact] [Fact]
public void ExecutionEngineTest_MultiThreading_WithCancel() public async Task ExecutionEngineTest_MultiThreading_WithCancel()
{ {
string[] sqlStatement = { "waitfor delay '0:0:10'", string[] sqlStatement = { "waitfor delay '0:0:10'",
"waitfor delay '0:0:10'", "waitfor delay '0:0:10'",
@@ -683,6 +684,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TSQLExecutionEngine
CloseConnection(connection2); CloseConnection(connection2);
CloseConnection(connection3); CloseConnection(connection3);
await SqlTestDb.DropDatabase(connection2.Database);
await SqlTestDb.DropDatabase(connection3.Database);
} }
#endregion #endregion

View File

@@ -36,13 +36,16 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility
return filePath; return filePath;
} }
public static TestConnectionResult InitLiveConnectionInfo(string databaseName = null, string fileName = null) public static TestConnectionResult InitLiveConnectionInfo(string databaseName = null, string ownerUri = null)
{ {
string sqlFilePath = GetTestSqlFile(); ScriptFile scriptFile = null;
ScriptFile scriptFile = TestServiceProvider.Instance.WorkspaceService.Workspace.GetFile(sqlFilePath); ConnectParams connectParams = TestServiceProvider.Instance.ConnectionProfileService.GetConnectionParameters(TestServerType.OnPrem, databaseName);
ConnectParams connectParams = TestServiceProvider.Instance.ConnectionProfileService.GetConnectionParameters(TestServerType.OnPrem, databaseName); if (string.IsNullOrEmpty(ownerUri))
{
string ownerUri = scriptFile.ClientFilePath; ownerUri = GetTestSqlFile();
scriptFile = TestServiceProvider.Instance.WorkspaceService.Workspace.GetFile(ownerUri);
ownerUri = scriptFile.ClientFilePath;
}
var connectionService = GetLiveTestConnectionService(); var connectionService = GetLiveTestConnectionService();
var connectionResult = var connectionResult =
connectionService connectionService
@@ -59,13 +62,14 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility
return new TestConnectionResult() { ConnectionInfo = connInfo, ScriptFile = scriptFile }; return new TestConnectionResult() { ConnectionInfo = connInfo, ScriptFile = scriptFile };
} }
public static async Task<TestConnectionResult> InitLiveConnectionInfoAsync(string databaseName = null, string ownerUri = null) public static async Task<TestConnectionResult> InitLiveConnectionInfoAsync(string databaseName = null, string ownerUri = null,
{ string connectionType = ServiceLayer.Connection.ConnectionType.Default)
{
ScriptFile scriptFile = null; ScriptFile scriptFile = null;
if (string.IsNullOrEmpty(ownerUri)) if (string.IsNullOrEmpty(ownerUri))
{ {
string sqlFilePath = GetTestSqlFile(); ownerUri = GetTestSqlFile();
scriptFile = TestServiceProvider.Instance.WorkspaceService.Workspace.GetFile(sqlFilePath); scriptFile = TestServiceProvider.Instance.WorkspaceService.Workspace.GetFile(ownerUri);
ownerUri = scriptFile.ClientFilePath; ownerUri = scriptFile.ClientFilePath;
} }
ConnectParams connectParams = TestServiceProvider.Instance.ConnectionProfileService.GetConnectionParameters(TestServerType.OnPrem, databaseName); ConnectParams connectParams = TestServiceProvider.Instance.ConnectionProfileService.GetConnectionParameters(TestServerType.OnPrem, databaseName);
@@ -76,7 +80,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility
.Connect(new ConnectParams .Connect(new ConnectParams
{ {
OwnerUri = ownerUri, OwnerUri = ownerUri,
Connection = connectParams.Connection Connection = connectParams.Connection,
Type = connectionType
}); });
if (!string.IsNullOrEmpty(connectionResult.ErrorMessage)) if (!string.IsNullOrEmpty(connectionResult.ErrorMessage))
{ {
@@ -90,25 +95,27 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility
public static ConnectionInfo InitLiveConnectionInfoForDefinition(string databaseName = null) public static ConnectionInfo InitLiveConnectionInfoForDefinition(string databaseName = null)
{ {
ConnectParams connectParams = TestServiceProvider.Instance.ConnectionProfileService.GetConnectionParameters(TestServerType.OnPrem, databaseName); using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
const string ScriptUriTemplate = "file://some/{0}.sql"; {
string ownerUri = string.Format(CultureInfo.InvariantCulture, ScriptUriTemplate, string.IsNullOrEmpty(databaseName) ? "file" : databaseName); ConnectParams connectParams = TestServiceProvider.Instance.ConnectionProfileService.GetConnectionParameters(TestServerType.OnPrem, databaseName);
var connectionService = GetLiveTestConnectionService(); string ownerUri = queryTempFile.FilePath;
var connectionResult = var connectionService = GetLiveTestConnectionService();
connectionService var connectionResult =
.Connect(new ConnectParams connectionService
{ .Connect(new ConnectParams
OwnerUri = ownerUri, {
Connection = connectParams.Connection OwnerUri = ownerUri,
}); Connection = connectParams.Connection
});
connectionResult.Wait();
connectionResult.Wait();
ConnectionInfo connInfo = null;
connectionService.TryFindConnection(ownerUri, out connInfo); ConnectionInfo connInfo = null;
connectionService.TryFindConnection(ownerUri, out connInfo);
Assert.NotNull(connInfo);
return connInfo; Assert.NotNull(connInfo);
return connInfo;
}
} }
public static ServerConnection InitLiveServerConnectionForDefinition(ConnectionInfo connInfo) public static ServerConnection InitLiveServerConnectionForDefinition(ConnectionInfo connInfo)

View File

@@ -10,6 +10,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Xunit; using Xunit;
using System.Data.SqlClient; using System.Data.SqlClient;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlServer.Management.Common;
namespace Microsoft.SqlTools.ServiceLayer.Test.Common namespace Microsoft.SqlTools.ServiceLayer.Test.Common
{ {
@@ -132,11 +133,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
{ {
if (!DoNotCleanupDb) if (!DoNotCleanupDb)
{ {
string dropDatabaseQuery = string.Format(CultureInfo.InvariantCulture, await DropDatabase(DatabaseName, ServerType);
(ServerType == TestServerType.Azure ? Scripts.DropDatabaseIfExistAzure : Scripts.DropDatabaseIfExist), DatabaseName);
Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "Cleaning up database {0}", DatabaseName));
await TestServiceProvider.Instance.RunQueryAsync(ServerType, MasterDatabaseName, dropDatabaseQuery);
} }
} }
catch (Exception ex) catch (Exception ex)
@@ -145,6 +142,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
} }
} }
public static async Task DropDatabase(string databaseName, TestServerType serverType = TestServerType.OnPrem)
{
string dropDatabaseQuery = string.Format(CultureInfo.InvariantCulture,
(serverType == TestServerType.Azure ? Scripts.DropDatabaseIfExistAzure : Scripts.DropDatabaseIfExist), databaseName);
Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "Cleaning up database {0}", databaseName));
await TestServiceProvider.Instance.RunQueryAsync(serverType, MasterDatabaseName, dropDatabaseQuery);
}
/// <summary> /// <summary>
/// Returns connection info after making a connection to the database /// Returns connection info after making a connection to the database
/// </summary> /// </summary>

View File

@@ -0,0 +1,90 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Moq;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
{
public class DatabaseLocksManagerTests
{
private string server1 = "server1";
private string database1 = "database1";
[Fact]
public void GainFullAccessShouldDisconnectTheConnections()
{
var connectionLock = new Mock<IConnectedBindingQueue>();
connectionLock.Setup(x => x.CloseConnections(server1, database1));
using (DatabaseLocksManager databaseLocksManager = CreateManager())
{
databaseLocksManager.ConnectionService.RegisterConnectedQueue("test", connectionLock.Object);
databaseLocksManager.GainFullAccessToDatabase(server1, database1);
connectionLock.Verify(x => x.CloseConnections(server1, database1));
}
}
[Fact]
public void ReleaseAccessShouldConnectTheConnections()
{
var connectionLock = new Mock<IConnectedBindingQueue>();
connectionLock.Setup(x => x.OpenConnections(server1, database1));
using (DatabaseLocksManager databaseLocksManager = CreateManager())
{
databaseLocksManager.ConnectionService.RegisterConnectedQueue("test", connectionLock.Object);
databaseLocksManager.ReleaseAccess(server1, database1);
connectionLock.Verify(x => x.OpenConnections(server1, database1));
}
}
//[Fact]
public void SecondProcessToGainAccessShouldWaitForTheFirstProcess()
{
var connectionLock = new Mock<IConnectedBindingQueue>();
using (DatabaseLocksManager databaseLocksManager = CreateManager())
{
databaseLocksManager.GainFullAccessToDatabase(server1, database1);
bool secondTimeGettingAccessFails = false;
try
{
databaseLocksManager.GainFullAccessToDatabase(server1, database1);
}
catch (DatabaseFullAccessException)
{
secondTimeGettingAccessFails = true;
}
Assert.Equal(secondTimeGettingAccessFails, true);
databaseLocksManager.ReleaseAccess(server1, database1);
Assert.Equal(databaseLocksManager.GainFullAccessToDatabase(server1, database1), true);
databaseLocksManager.ReleaseAccess(server1, database1);
}
}
private DatabaseLocksManager CreateManager()
{
DatabaseLocksManager databaseLocksManager = new DatabaseLocksManager(2000);
var connectionLock1 = new Mock<IConnectedBindingQueue>();
var connectionLock2 = new Mock<IConnectedBindingQueue>();
connectionLock1.Setup(x => x.CloseConnections(It.IsAny<string>(), It.IsAny<string>()));
connectionLock2.Setup(x => x.OpenConnections(It.IsAny<string>(), It.IsAny<string>()));
connectionLock1.Setup(x => x.OpenConnections(It.IsAny<string>(), It.IsAny<string>()));
connectionLock2.Setup(x => x.CloseConnections(It.IsAny<string>(), It.IsAny<string>()));
ConnectionService connectionService = new ConnectionService();
databaseLocksManager.ConnectionService = connectionService;
connectionService.RegisterConnectedQueue("1", connectionLock1.Object);
connectionService.RegisterConnectedQueue("2", connectionLock2.Object);
return databaseLocksManager;
}
}
}

View File

@@ -7,11 +7,9 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Data.SqlClient; using System.Data.SqlClient;
using System.Globalization; using System.Globalization;
using System.Linq;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.Extensibility; using Microsoft.SqlTools.Extensibility;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.ObjectExplorer; using Microsoft.SqlTools.ServiceLayer.ObjectExplorer;
using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Contracts; using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Contracts;
@@ -20,6 +18,7 @@ using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel;
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
using Moq; using Moq;
using Xunit; using Xunit;
using Microsoft.SqlTools.ServiceLayer.Connection;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
{ {
@@ -33,10 +32,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
private ConnectionDetails defaultConnectionDetails; private ConnectionDetails defaultConnectionDetails;
private ConnectionCompleteParams defaultConnParams; private ConnectionCompleteParams defaultConnParams;
private string fakeConnectionString = "Data Source=server;Initial Catalog=database;Integrated Security=False;User Id=user"; private string fakeConnectionString = "Data Source=server;Initial Catalog=database;Integrated Security=False;User Id=user";
private ServerConnection serverConnection = null;
public NodeTests() public NodeTests()
{ {
defaultServerInfo = TestObjects.GetTestServerInfo(); defaultServerInfo = TestObjects.GetTestServerInfo();
serverConnection = new ServerConnection(new SqlConnection(fakeConnectionString));
defaultConnectionDetails = new ConnectionDetails() defaultConnectionDetails = new ConnectionDetails()
{ {
@@ -59,15 +60,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
[Fact] [Fact]
public void ServerNodeConstructorValidatesFields() public void ServerNodeConstructorValidatesFields()
{ {
Assert.Throws<ArgumentNullException>(() => new ServerNode(null, ServiceProvider)); Assert.Throws<ArgumentNullException>(() => new ServerNode(null, ServiceProvider, serverConnection));
Assert.Throws<ArgumentNullException>(() => new ServerNode(defaultConnParams, null)); Assert.Throws<ArgumentNullException>(() => new ServerNode(defaultConnParams, null, serverConnection));
} }
[Fact] [Fact]
public void ServerNodeConstructorShouldSetValuesCorrectly() public void ServerNodeConstructorShouldSetValuesCorrectly()
{ {
// Given a server node with valid inputs // Given a server node with valid inputs
ServerNode node = new ServerNode(defaultConnParams, ServiceProvider); ServerNode node = new ServerNode(defaultConnParams, ServiceProvider, serverConnection);
// Then expect all fields set correctly // Then expect all fields set correctly
Assert.False(node.IsAlwaysLeaf, "Server node should never be a leaf"); Assert.False(node.IsAlwaysLeaf, "Server node should never be a leaf");
Assert.Equal(defaultConnectionDetails.ServerName, node.NodeValue); Assert.Equal(defaultConnectionDetails.ServerName, node.NodeValue);
@@ -99,7 +100,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
OwnerUri = defaultOwnerUri OwnerUri = defaultOwnerUri
}; };
// When querying label // When querying label
string label = new ServerNode(connParams, ServiceProvider).Label; string label = new ServerNode(connParams, ServiceProvider, serverConnection).Label;
// Then only server name and version shown // Then only server name and version shown
string expectedLabel = defaultConnectionDetails.ServerName + " (SQL Server " + defaultServerInfo.ServerVersion + ")"; string expectedLabel = defaultConnectionDetails.ServerName + " (SQL Server " + defaultServerInfo.ServerVersion + ")";
Assert.Equal(expectedLabel, label); Assert.Equal(expectedLabel, label);
@@ -111,7 +112,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
defaultServerInfo.IsCloud = true; defaultServerInfo.IsCloud = true;
// Given a server node for a cloud DB, with master name // Given a server node for a cloud DB, with master name
ServerNode node = new ServerNode(defaultConnParams, ServiceProvider); ServerNode node = new ServerNode(defaultConnParams, ServiceProvider, serverConnection);
// Then expect label to not include db name // Then expect label to not include db name
string expectedLabel = defaultConnectionDetails.ServerName + " (SQL Server " + defaultServerInfo.ServerVersion + " - " string expectedLabel = defaultConnectionDetails.ServerName + " (SQL Server " + defaultServerInfo.ServerVersion + " - "
+ defaultConnectionDetails.UserName + ")"; + defaultConnectionDetails.UserName + ")";
@@ -120,7 +121,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
// But given a server node for a cloud DB that's not master // But given a server node for a cloud DB that's not master
defaultConnectionDetails.DatabaseName = "NotMaster"; defaultConnectionDetails.DatabaseName = "NotMaster";
defaultConnParams.ConnectionSummary.DatabaseName = defaultConnectionDetails.DatabaseName; defaultConnParams.ConnectionSummary.DatabaseName = defaultConnectionDetails.DatabaseName;
node = new ServerNode(defaultConnParams, ServiceProvider); node = new ServerNode(defaultConnParams, ServiceProvider, serverConnection);
// Then expect label to include db name // Then expect label to include db name
expectedLabel = defaultConnectionDetails.ServerName + " (SQL Server " + defaultServerInfo.ServerVersion + " - " expectedLabel = defaultConnectionDetails.ServerName + " (SQL Server " + defaultServerInfo.ServerVersion + " - "
@@ -132,7 +133,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
public void ToNodeInfoIncludeAllFields() public void ToNodeInfoIncludeAllFields()
{ {
// Given a server connection // Given a server connection
ServerNode node = new ServerNode(defaultConnParams, ServiceProvider); ServerNode node = new ServerNode(defaultConnParams, ServiceProvider, serverConnection);
// When converting to NodeInfo // When converting to NodeInfo
NodeInfo info = node.ToNodeInfo(); NodeInfo info = node.ToNodeInfo();
// Then all fields should match // Then all fields should match
@@ -204,7 +205,6 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
public void ServerNodeContextShouldIncludeServer() public void ServerNodeContextShouldIncludeServer()
{ {
// given a successful Server creation // given a successful Server creation
SetupAndRegisterTestConnectionService();
Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString))); Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString)));
ServerNode node = SetupServerNodeWithServer(smoServer); ServerNode node = SetupServerNodeWithServer(smoServer);
@@ -223,10 +223,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
public void ServerNodeContextShouldSetErrorMessageIfSqlConnectionIsNull() public void ServerNodeContextShouldSetErrorMessageIfSqlConnectionIsNull()
{ {
// given a connectionInfo with no SqlConnection to use for queries // given a connectionInfo with no SqlConnection to use for queries
ConnectionService connService = SetupAndRegisterTestConnectionService();
connService.OwnerToConnectionMap.Remove(defaultOwnerUri);
Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString))); Server smoServer = null;
ServerNode node = SetupServerNodeWithServer(smoServer); ServerNode node = SetupServerNodeWithServer(smoServer);
// When I get the context for a ServerNode // When I get the context for a ServerNode
@@ -234,17 +232,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
// Then I expect it to be in an error state // Then I expect it to be in an error state
Assert.Null(context); Assert.Null(context);
Assert.Equal(
string.Format(CultureInfo.CurrentCulture, SR.ServerNodeConnectionError, defaultConnectionDetails.ServerName),
node.ErrorStateMessage);
} }
[Fact] [Fact]
public void ServerNodeContextShouldSetErrorMessageIfConnFailureExceptionThrown() public void ServerNodeContextShouldSetErrorMessageIfConnFailureExceptionThrown()
{ {
// given a connectionInfo with no SqlConnection to use for queries // given a connectionInfo with no SqlConnection to use for queries
SetupAndRegisterTestConnectionService();
Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString))); Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString)));
string expectedMsg = "ConnFailed!"; string expectedMsg = "ConnFailed!";
ServerNode node = SetupServerNodeWithExceptionCreator(new ConnectionFailureException(expectedMsg)); ServerNode node = SetupServerNodeWithExceptionCreator(new ConnectionFailureException(expectedMsg));
@@ -263,8 +256,6 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
public void ServerNodeContextShouldSetErrorMessageIfExceptionThrown() public void ServerNodeContextShouldSetErrorMessageIfExceptionThrown()
{ {
// given a connectionInfo with no SqlConnection to use for queries // given a connectionInfo with no SqlConnection to use for queries
SetupAndRegisterTestConnectionService();
Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString))); Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString)));
string expectedMsg = "Failed!"; string expectedMsg = "Failed!";
ServerNode node = SetupServerNodeWithExceptionCreator(new Exception(expectedMsg)); ServerNode node = SetupServerNodeWithExceptionCreator(new Exception(expectedMsg));
@@ -283,7 +274,6 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
public void QueryContextShouldNotCallOpenOnAlreadyOpenConnection() public void QueryContextShouldNotCallOpenOnAlreadyOpenConnection()
{ {
// given a server connection that will state its connection is open // given a server connection that will state its connection is open
SetupAndRegisterTestConnectionService();
Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString))); Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString)));
Mock<SmoWrapper> wrapper = SetupSmoWrapperForIsOpenTest(smoServer, isOpen: true); Mock<SmoWrapper> wrapper = SetupSmoWrapperForIsOpenTest(smoServer, isOpen: true);
@@ -301,7 +291,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
{ {
Mock<SmoWrapper> wrapper = new Mock<SmoWrapper>(); Mock<SmoWrapper> wrapper = new Mock<SmoWrapper>();
int count = 0; int count = 0;
wrapper.Setup(c => c.CreateServer(It.IsAny<SqlConnection>())) wrapper.Setup(c => c.CreateServer(It.IsAny<ServerConnection>()))
.Returns(() => smoServer); .Returns(() => smoServer);
wrapper.Setup(c => c.IsConnectionOpen(It.IsAny<Server>())) wrapper.Setup(c => c.IsConnectionOpen(It.IsAny<Server>()))
.Returns(() => isOpen); .Returns(() => isOpen);
@@ -315,7 +305,6 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
public void QueryContextShouldReopenClosedConnectionWhenGettingServer() public void QueryContextShouldReopenClosedConnectionWhenGettingServer()
{ {
// given a server connection that will state its connection is closed // given a server connection that will state its connection is closed
SetupAndRegisterTestConnectionService();
Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString))); Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString)));
Mock<SmoWrapper> wrapper = SetupSmoWrapperForIsOpenTest(smoServer, isOpen: false); Mock<SmoWrapper> wrapper = SetupSmoWrapperForIsOpenTest(smoServer, isOpen: false);
@@ -333,7 +322,6 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
public void QueryContextShouldReopenClosedConnectionWhenGettingParent() public void QueryContextShouldReopenClosedConnectionWhenGettingParent()
{ {
// given a server connection that will state its connection is closed // given a server connection that will state its connection is closed
SetupAndRegisterTestConnectionService();
Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString))); Server smoServer = new Server(new ServerConnection(new SqlConnection(fakeConnectionString)));
Mock<SmoWrapper> wrapper = SetupSmoWrapperForIsOpenTest(smoServer, isOpen: false); Mock<SmoWrapper> wrapper = SetupSmoWrapperForIsOpenTest(smoServer, isOpen: false);
@@ -362,7 +350,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
private ServerNode SetupServerNodeWithServer(Server smoServer) private ServerNode SetupServerNodeWithServer(Server smoServer)
{ {
Mock<SmoWrapper> creator = new Mock<SmoWrapper>(); Mock<SmoWrapper> creator = new Mock<SmoWrapper>();
creator.Setup(c => c.CreateServer(It.IsAny<SqlConnection>())) creator.Setup(c => c.CreateServer(It.IsAny<ServerConnection>()))
.Returns(() => smoServer); .Returns(() => smoServer);
creator.Setup(c => c.IsConnectionOpen(It.IsAny<Server>())) creator.Setup(c => c.IsConnectionOpen(It.IsAny<Server>()))
.Returns(() => true); .Returns(() => true);
@@ -373,7 +361,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
private ServerNode SetupServerNodeWithExceptionCreator(Exception ex) private ServerNode SetupServerNodeWithExceptionCreator(Exception ex)
{ {
Mock<SmoWrapper> creator = new Mock<SmoWrapper>(); Mock<SmoWrapper> creator = new Mock<SmoWrapper>();
creator.Setup(c => c.CreateServer(It.IsAny<SqlConnection>())) creator.Setup(c => c.CreateServer(It.IsAny<ServerConnection>()))
.Throws(ex); .Throws(ex);
creator.Setup(c => c.IsConnectionOpen(It.IsAny<Server>())) creator.Setup(c => c.IsConnectionOpen(It.IsAny<Server>()))
.Returns(() => false); .Returns(() => false);
@@ -384,7 +372,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
private ServerNode SetupServerNodeWithCreator(SmoWrapper creator) private ServerNode SetupServerNodeWithCreator(SmoWrapper creator)
{ {
ServerNode node = new ServerNode(defaultConnParams, ServiceProvider); ServerNode node = new ServerNode(defaultConnParams, ServiceProvider, new ServerConnection(new SqlConnection(fakeConnectionString)));
node.SmoWrapper = creator; node.SmoWrapper = creator;
return node; return node;
} }

View File

@@ -16,6 +16,8 @@ using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Nodes;
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
using Moq; using Moq;
using Xunit; using Xunit;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlServer.Management.Common;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
{ {
@@ -25,12 +27,29 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
private ObjectExplorerService service; private ObjectExplorerService service;
private Mock<ConnectionService> connectionServiceMock; private Mock<ConnectionService> connectionServiceMock;
private Mock<IProtocolEndpoint> serviceHostMock; private Mock<IProtocolEndpoint> serviceHostMock;
string fakeConnectionString = "Data Source=server;Initial Catalog=database;Integrated Security=False;User Id=user";
private static ConnectionDetails details = new ConnectionDetails()
{
UserName = "user",
Password = "password",
DatabaseName = "msdb",
ServerName = "serverName"
};
ConnectionInfo connectionInfo = new ConnectionInfo(null, null, details);
ConnectedBindingQueue connectedBindingQueue;
public ObjectExplorerServiceTests() public ObjectExplorerServiceTests()
{ {
connectionServiceMock = new Mock<ConnectionService>(); connectionServiceMock = new Mock<ConnectionService>();
serviceHostMock = new Mock<IProtocolEndpoint>(); serviceHostMock = new Mock<IProtocolEndpoint>();
service = CreateOEService(connectionServiceMock.Object); service = CreateOEService(connectionServiceMock.Object);
connectionServiceMock.Setup(x => x.RegisterConnectedQueue(It.IsAny<string>(), It.IsAny<IConnectedBindingQueue>()));
service.InitializeService(serviceHostMock.Object); service.InitializeService(serviceHostMock.Object);
ConnectedBindingContext connectedBindingContext = new ConnectedBindingContext();
connectedBindingContext.ServerConnection = new ServerConnection(new SqlConnection(fakeConnectionString));
connectedBindingQueue = new ConnectedBindingQueue(false);
connectedBindingQueue.BindingContextMap.Add($"{details.ServerName}_{details.DatabaseName}_{details.UserName}_NULL", connectedBindingContext);
service.ConnectedBindingQueue = connectedBindingQueue;
} }
[Fact] [Fact]
@@ -210,14 +229,6 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
private async Task<SessionCreatedParameters> CreateSession() private async Task<SessionCreatedParameters> CreateSession()
{ {
ConnectionDetails details = new ConnectionDetails()
{
UserName = "user",
Password = "password",
DatabaseName = "msdb",
ServerName = "serverName"
};
SessionCreatedParameters sessionResult = null; SessionCreatedParameters sessionResult = null;
serviceHostMock.AddEventHandling(CreateSessionCompleteNotification.Type, (et, p) => sessionResult = p); serviceHostMock.AddEventHandling(CreateSessionCompleteNotification.Type, (et, p) => sessionResult = p);
CreateSessionResponse result = default(CreateSessionResponse); CreateSessionResponse result = default(CreateSessionResponse);
@@ -226,8 +237,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>())) connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>()))
.Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details))); .Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details)));
ConnectionInfo connectionInfo = new ConnectionInfo(null, null, null); ConnectionInfo connectionInfo = new ConnectionInfo(null, null, details);
string fakeConnectionString = "Data Source=server;Initial Catalog=database;Integrated Security=False;User Id=user";
connectionInfo.AddConnection("Default", new SqlConnection(fakeConnectionString)); connectionInfo.AddConnection("Default", new SqlConnection(fakeConnectionString));
connectionServiceMock.Setup((c => c.TryFindConnection(It.IsAny<string>(), out connectionInfo))). connectionServiceMock.Setup((c => c.TryFindConnection(It.IsAny<string>(), out connectionInfo))).
OutCallback((string t, out ConnectionInfo v) => v = connectionInfo) OutCallback((string t, out ConnectionInfo v) => v = connectionInfo)
@@ -343,7 +353,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>())) connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>()))
.Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details))); .Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details)));
ConnectionInfo connectionInfo = new ConnectionInfo(null, null, details);
connectionInfo.AddConnection("Default", new SqlConnection(fakeConnectionString));
connectionServiceMock.Setup((c => c.TryFindConnection(It.IsAny<string>(), out connectionInfo))).
OutCallback((string t, out ConnectionInfo v) => v = connectionInfo)
.Returns(true);
// when creating a new session // when creating a new session
// then expect the create session request to return false // then expect the create session request to return false
await RunAndVerify<CreateSessionResponse, SessionCreatedParameters>( await RunAndVerify<CreateSessionResponse, SessionCreatedParameters>(