mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-02-16 10:58:30 -05:00
Update connection logic to handle multiple connections per URI (#176)
* Add CancelTokenKey for uniquely identifying cancelations of Connections associated with an OwnerUri and ConnectionType string. * Update ConnectionInfo to use ConcurrentDictionary of DbConnection instances. Add wrapper functions for the ConcurrentDictionary. * Refactor Connect and Disconnect in ConnectionService. * Update ConnectionService: Handle multiple connections per ConnectionInfo. Handle cancelation tokens uniquely identified with CancelTokenKey. Add GetOrOpenConnection() for other services to request an existing or create a new DbConnection. * Add ConnectionType.cs for ConnectionType strings. * Add ConnectionType string to ConnectParams, ConnectionCompleteNotification, DisconnectParams. * Update Query ExecuteInternal to use the dedicated query connection and GetOrOpenConnection(). * Update test library to account for multiple connections in ConnectionInfo. * Write tests ensuring multiple connections don’t create redundant data. * Write tests ensuring database changes affect all connections of a given ConnectionInfo. * Write tests for TRANSACTION statements and temp tables.
This commit is contained in:
@@ -0,0 +1,36 @@
|
|||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Linq;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Used to uniquely identify a CancellationTokenSource associated with both
|
||||||
|
/// a string URI and a string connection type.
|
||||||
|
/// </summary>
|
||||||
|
public class CancelTokenKey : CancelConnectParams, IEquatable<CancelTokenKey>
|
||||||
|
{
|
||||||
|
public override bool Equals(object obj)
|
||||||
|
{
|
||||||
|
CancelTokenKey other = obj as CancelTokenKey;
|
||||||
|
if (other == null)
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return other.OwnerUri == OwnerUri && other.Type == Type;
|
||||||
|
}
|
||||||
|
|
||||||
|
public bool Equals(CancelTokenKey obj)
|
||||||
|
{
|
||||||
|
return obj.OwnerUri == OwnerUri && obj.Type == Type;
|
||||||
|
}
|
||||||
|
|
||||||
|
public override int GetHashCode()
|
||||||
|
{
|
||||||
|
return OwnerUri.GetHashCode() ^ Type.GetHashCode();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,8 +4,12 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
|
using System.Collections.Concurrent;
|
||||||
|
using System.Collections.Generic;
|
||||||
using System.Data.Common;
|
using System.Data.Common;
|
||||||
|
using System.Linq;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.ServiceLayer.Connection
|
namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||||
{
|
{
|
||||||
@@ -47,9 +51,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
public ConnectionDetails ConnectionDetails { get; private set; }
|
public ConnectionDetails ConnectionDetails { get; private set; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// The connection to the SQL database that commands will be run against.
|
/// A map containing all connections to the database that are associated with
|
||||||
|
/// this ConnectionInfo's OwnerUri.
|
||||||
|
/// This is internal for testing access only
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public DbConnection SqlConnection { get; set; }
|
internal readonly ConcurrentDictionary<string, DbConnection> ConnectionTypeToConnectionMap =
|
||||||
|
new ConcurrentDictionary<string, DbConnection>();
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Intellisense Metrics
|
/// Intellisense Metrics
|
||||||
@@ -61,7 +68,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
/// </summary>
|
/// </summary>
|
||||||
public bool IsAzure { get; set; }
|
public bool IsAzure { get; set; }
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// Returns true if the sql connection is to a DW instance
|
/// Returns true if the sql connection is to a DW instance
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public bool IsSqlDW { get; set; }
|
public bool IsSqlDW { get; set; }
|
||||||
@@ -71,5 +77,86 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
/// </summary>
|
/// </summary>
|
||||||
public int MajorVersion { get; set; }
|
public int MajorVersion { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// All DbConnection instances held by this ConnectionInfo
|
||||||
|
/// </summary>
|
||||||
|
public ICollection<DbConnection> AllConnections
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
return ConnectionTypeToConnectionMap.Values;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// All connection type strings held by this ConnectionInfo
|
||||||
|
/// </summary>
|
||||||
|
/// <returns></returns>
|
||||||
|
public ICollection<string> AllConnectionTypes
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
return ConnectionTypeToConnectionMap.Keys;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The count of DbConnectioninstances held by this ConnectionInfo
|
||||||
|
/// </summary>
|
||||||
|
public int CountConnections
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
return ConnectionTypeToConnectionMap.Count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Try to get the DbConnection associated with the given connection type string.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if a connection with type connectionType was located and out connection was set,
|
||||||
|
/// false otherwise </returns>
|
||||||
|
/// <exception cref="ArgumentException">Thrown when connectionType is null or empty</exception>
|
||||||
|
public bool TryGetConnection(string connectionType, out DbConnection connection)
|
||||||
|
{
|
||||||
|
Validate.IsNotNullOrEmptyString("Connection Type", connectionType);
|
||||||
|
return ConnectionTypeToConnectionMap.TryGetValue(connectionType, out connection);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Adds a DbConnection to this object and associates it with the given
|
||||||
|
/// connection type string. If a connection already exists with an identical
|
||||||
|
/// connection type string, it is not overwritten. Ignores calls where connectionType = null
|
||||||
|
/// </summary>
|
||||||
|
/// <exception cref="ArgumentException">Thrown when connectionType is null or empty</exception>
|
||||||
|
public void AddConnection(string connectionType, DbConnection connection)
|
||||||
|
{
|
||||||
|
Validate.IsNotNullOrEmptyString("Connection Type", connectionType);
|
||||||
|
ConnectionTypeToConnectionMap.TryAdd(connectionType, connection);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Removes the single DbConnection instance associated with string connectionType
|
||||||
|
/// </summary>
|
||||||
|
/// <exception cref="ArgumentException">Thrown when connectionType is null or empty</exception>
|
||||||
|
public void RemoveConnection(string connectionType)
|
||||||
|
{
|
||||||
|
Validate.IsNotNullOrEmptyString("Connection Type", connectionType);
|
||||||
|
DbConnection connection;
|
||||||
|
ConnectionTypeToConnectionMap.TryRemove(connectionType, out connection);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Removes all DbConnection instances held by this object
|
||||||
|
/// </summary>
|
||||||
|
public void RemoveAllConnections()
|
||||||
|
{
|
||||||
|
foreach (var type in AllConnectionTypes)
|
||||||
|
{
|
||||||
|
DbConnection connection;
|
||||||
|
ConnectionTypeToConnectionMap.TryRemove(type, out connection);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
|
|
||||||
private readonly Dictionary<string, ConnectionInfo> ownerToConnectionMap = new Dictionary<string, ConnectionInfo>();
|
private readonly Dictionary<string, ConnectionInfo> ownerToConnectionMap = new Dictionary<string, ConnectionInfo>();
|
||||||
|
|
||||||
private readonly ConcurrentDictionary<string, CancellationTokenSource> ownerToCancellationTokenSourceMap = new ConcurrentDictionary<string, CancellationTokenSource>();
|
/// <summary>
|
||||||
|
/// A map containing all CancellationTokenSource objects that are associated with a given URI/ConnectionType pair.
|
||||||
|
/// Entries in this map correspond to DbConnection instances that are in the process of connecting.
|
||||||
|
/// </summary>
|
||||||
|
private readonly ConcurrentDictionary<CancelTokenKey, CancellationTokenSource> cancelTupleToCancellationTokenSourceMap =
|
||||||
|
new ConcurrentDictionary<CancelTokenKey, CancellationTokenSource>();
|
||||||
|
|
||||||
private readonly object cancellationTokenSourceLock = new object();
|
private readonly object cancellationTokenSourceLock = new object();
|
||||||
|
|
||||||
@@ -120,6 +125,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
}
|
}
|
||||||
return this.connectionFactory;
|
return this.connectionFactory;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal set { this.connectionFactory = value; }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -138,12 +145,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Open a connection with the specified connection details
|
/// Validates the given ConnectParams object.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="connectionParams"></param>
|
/// <param name="connectionParams">The params to validate</param>
|
||||||
public async Task<ConnectionCompleteParams> Connect(ConnectParams connectionParams)
|
/// <returns>A ConnectionCompleteParams object upon validation error,
|
||||||
|
/// null upon validation success</returns>
|
||||||
|
public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams)
|
||||||
{
|
{
|
||||||
// Validate parameters
|
|
||||||
string paramValidationErrorMessage;
|
string paramValidationErrorMessage;
|
||||||
if (connectionParams == null)
|
if (connectionParams == null)
|
||||||
{
|
{
|
||||||
@@ -161,29 +169,139 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve if it is an existing connection
|
// return null upon success
|
||||||
// Disconnect active connection if the URI is already connected
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Open a connection with the specified ConnectParams
|
||||||
|
/// </summary>
|
||||||
|
public async Task<ConnectionCompleteParams> Connect(ConnectParams connectionParams)
|
||||||
|
{
|
||||||
|
// Validate parameters
|
||||||
|
ConnectionCompleteParams validationResults = ValidateConnectParams(connectionParams);
|
||||||
|
if (validationResults != null)
|
||||||
|
{
|
||||||
|
return validationResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is no ConnectionInfo in the map, create a new ConnectionInfo,
|
||||||
|
// but wait until later when we are connected to add it to the map.
|
||||||
ConnectionInfo connectionInfo;
|
ConnectionInfo connectionInfo;
|
||||||
if (ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo) )
|
if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo))
|
||||||
|
{
|
||||||
|
connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve if it is an existing connection
|
||||||
|
// Disconnect active connection if the URI is already connected for this connection type
|
||||||
|
DbConnection existingConnection;
|
||||||
|
if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection))
|
||||||
{
|
{
|
||||||
var disconnectParams = new DisconnectParams()
|
var disconnectParams = new DisconnectParams()
|
||||||
{
|
{
|
||||||
OwnerUri = connectionParams.OwnerUri
|
OwnerUri = connectionParams.OwnerUri,
|
||||||
|
Type = connectionParams.Type
|
||||||
};
|
};
|
||||||
Disconnect(disconnectParams);
|
Disconnect(disconnectParams);
|
||||||
}
|
}
|
||||||
connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
|
|
||||||
|
|
||||||
// try to connect
|
// Try to open a connection with the given ConnectParams
|
||||||
var response = new ConnectionCompleteParams {OwnerUri = connectionParams.OwnerUri};
|
ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams);
|
||||||
|
if (response != null)
|
||||||
|
{
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is the first connection for this URI, add the ConnectionInfo to the map
|
||||||
|
if (!ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri))
|
||||||
|
{
|
||||||
|
ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invoke callback notifications
|
||||||
|
InvokeOnConnectionActivities(connectionInfo, connectionParams);
|
||||||
|
|
||||||
|
// Return information about the connected SQL Server instance
|
||||||
|
return GetConnectionCompleteParams(connectionParams.Type, connectionInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates a ConnectionCompleteParams as a response to a successful connection.
|
||||||
|
/// Also sets the DatabaseName and IsAzure properties of ConnectionInfo.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>A ConnectionCompleteParams in response to the successful connection</returns>
|
||||||
|
private ConnectionCompleteParams GetConnectionCompleteParams(string connectionType, ConnectionInfo connectionInfo)
|
||||||
|
{
|
||||||
|
ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionType };
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
DbConnection connection;
|
||||||
|
connectionInfo.TryGetConnection(connectionType, out connection);
|
||||||
|
|
||||||
|
// Update with the actual database name in connectionInfo and result
|
||||||
|
// Doing this here as we know the connection is open - expect to do this only on connecting
|
||||||
|
connectionInfo.ConnectionDetails.DatabaseName = connection.Database;
|
||||||
|
response.ConnectionSummary = new ConnectionSummary
|
||||||
|
{
|
||||||
|
ServerName = connectionInfo.ConnectionDetails.ServerName,
|
||||||
|
DatabaseName = connectionInfo.ConnectionDetails.DatabaseName,
|
||||||
|
UserName = connectionInfo.ConnectionDetails.UserName,
|
||||||
|
};
|
||||||
|
|
||||||
|
response.ConnectionId = connectionInfo.ConnectionId.ToString();
|
||||||
|
|
||||||
|
var reliableConnection = connection as ReliableSqlConnection;
|
||||||
|
DbConnection underlyingConnection = reliableConnection != null
|
||||||
|
? reliableConnection.GetUnderlyingConnection()
|
||||||
|
: connection;
|
||||||
|
|
||||||
|
ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(underlyingConnection);
|
||||||
|
response.ServerInfo = new ServerInfo
|
||||||
|
{
|
||||||
|
ServerMajorVersion = serverInfo.ServerMajorVersion,
|
||||||
|
ServerMinorVersion = serverInfo.ServerMinorVersion,
|
||||||
|
ServerReleaseVersion = serverInfo.ServerReleaseVersion,
|
||||||
|
EngineEditionId = serverInfo.EngineEditionId,
|
||||||
|
ServerVersion = serverInfo.ServerVersion,
|
||||||
|
ServerLevel = serverInfo.ServerLevel,
|
||||||
|
ServerEdition = serverInfo.ServerEdition,
|
||||||
|
IsCloud = serverInfo.IsCloud,
|
||||||
|
AzureVersion = serverInfo.AzureVersion,
|
||||||
|
OsVersion = serverInfo.OsVersion
|
||||||
|
};
|
||||||
|
connectionInfo.IsAzure = serverInfo.IsCloud;
|
||||||
|
connectionInfo.MajorVersion = serverInfo.ServerMajorVersion;
|
||||||
|
connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
response.Messages = ex.ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Tries to create and open a connection with the given ConnectParams.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>null upon success, a ConnectionCompleteParams detailing the error upon failure</returns>
|
||||||
|
private async Task<ConnectionCompleteParams> TryOpenConnection(ConnectionInfo connectionInfo, ConnectParams connectionParams)
|
||||||
|
{
|
||||||
CancellationTokenSource source = null;
|
CancellationTokenSource source = null;
|
||||||
|
DbConnection connection = null;
|
||||||
|
CancelTokenKey cancelKey = new CancelTokenKey { OwnerUri = connectionParams.OwnerUri, Type = connectionParams.Type };
|
||||||
|
ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionParams.Type };
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
// build the connection string from the input parameters
|
// build the connection string from the input parameters
|
||||||
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
|
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
|
||||||
|
|
||||||
// create a sql connection instance
|
// create a sql connection instance
|
||||||
connectionInfo.SqlConnection = connectionInfo.Factory.CreateSqlConnection(connectionString);
|
connection = connectionInfo.Factory.CreateSqlConnection(connectionString);
|
||||||
|
connectionInfo.AddConnection(connectionParams.Type, connection);
|
||||||
|
|
||||||
// Add a cancellation token source so that the connection OpenAsync() can be cancelled
|
// Add a cancellation token source so that the connection OpenAsync() can be cancelled
|
||||||
using (source = new CancellationTokenSource())
|
using (source = new CancellationTokenSource())
|
||||||
@@ -193,11 +311,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
{
|
{
|
||||||
// If the URI is currently connecting from a different request, cancel it before we try to connect
|
// If the URI is currently connecting from a different request, cancel it before we try to connect
|
||||||
CancellationTokenSource currentSource;
|
CancellationTokenSource currentSource;
|
||||||
if (ownerToCancellationTokenSourceMap.TryGetValue(connectionParams.OwnerUri, out currentSource))
|
if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource))
|
||||||
{
|
{
|
||||||
currentSource.Cancel();
|
currentSource.Cancel();
|
||||||
}
|
}
|
||||||
ownerToCancellationTokenSourceMap[connectionParams.OwnerUri] = source;
|
cancelTupleToCancellationTokenSourceMap[cancelKey] = source;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a task to handle cancellation requests
|
// Create a task to handle cancellation requests
|
||||||
@@ -214,8 +332,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
var openTask = Task.Run(async () => {
|
var openTask = Task.Run(async () =>
|
||||||
await connectionInfo.SqlConnection.OpenAsync(source.Token);
|
{
|
||||||
|
await connection.OpenAsync(source.Token);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Open the connection
|
// Open the connection
|
||||||
@@ -250,60 +369,61 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
{
|
{
|
||||||
// Only remove the token from the map if it is the same one created by this request
|
// Only remove the token from the map if it is the same one created by this request
|
||||||
CancellationTokenSource sourceValue;
|
CancellationTokenSource sourceValue;
|
||||||
if (ownerToCancellationTokenSourceMap.TryGetValue(connectionParams.OwnerUri, out sourceValue) && sourceValue == source)
|
if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out sourceValue) && sourceValue == source)
|
||||||
{
|
{
|
||||||
ownerToCancellationTokenSourceMap.TryRemove(connectionParams.OwnerUri, out sourceValue);
|
cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
|
// Return null upon success
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
// Update with the actual database name in connectionInfo and result
|
/// <summary>
|
||||||
// Doing this here as we know the connection is open - expect to do this only on connecting
|
/// Gets the existing connection with the given URI and connection type string. If none exists,
|
||||||
connectionInfo.ConnectionDetails.DatabaseName = connectionInfo.SqlConnection.Database;
|
/// creates a new connection. This cannot be used to create a default connection or to create a
|
||||||
response.ConnectionSummary = new ConnectionSummary
|
/// connection if a default connection does not exist.
|
||||||
|
/// </summary>
|
||||||
|
public async Task<DbConnection> GetOrOpenConnection(string ownerUri, string connectionType)
|
||||||
{
|
{
|
||||||
ServerName = connectionInfo.ConnectionDetails.ServerName,
|
if (string.IsNullOrEmpty(ownerUri) || string.IsNullOrEmpty(connectionType))
|
||||||
DatabaseName = connectionInfo.ConnectionDetails.DatabaseName,
|
{
|
||||||
UserName = connectionInfo.ConnectionDetails.UserName,
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get the ConnectionInfo, if it exists
|
||||||
|
ConnectionInfo connectionInfo = ownerToConnectionMap[ownerUri];
|
||||||
|
if (connectionInfo == null)
|
||||||
|
{
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure a default connection exists
|
||||||
|
DbConnection defaultConnection;
|
||||||
|
if (!connectionInfo.TryGetConnection(ConnectionType.Default, out defaultConnection))
|
||||||
|
{
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get the DbConnection
|
||||||
|
DbConnection connection;
|
||||||
|
if (!connectionInfo.TryGetConnection(connectionType, out connection) && ConnectionType.Default != connectionType)
|
||||||
|
{
|
||||||
|
// If the DbConnection does not exist and is not the default connection, create one.
|
||||||
|
// We can't create the default (initial) connection here because we won't have a ConnectionDetails
|
||||||
|
// if Connect() has not yet been called.
|
||||||
|
ConnectParams connectParams = new ConnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = ownerUri,
|
||||||
|
Connection = connectionInfo.ConnectionDetails,
|
||||||
|
Type = connectionType
|
||||||
};
|
};
|
||||||
|
await Connect(connectParams);
|
||||||
// invoke callback notifications
|
connectionInfo.TryGetConnection(connectionType, out connection);
|
||||||
InvokeOnConnectionActivities(connectionInfo);
|
|
||||||
|
|
||||||
// try to get information about the connected SQL Server instance
|
|
||||||
try
|
|
||||||
{
|
|
||||||
var reliableConnection = connectionInfo.SqlConnection as ReliableSqlConnection;
|
|
||||||
DbConnection connection = reliableConnection != null ? reliableConnection.GetUnderlyingConnection() : connectionInfo.SqlConnection;
|
|
||||||
|
|
||||||
ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(connection);
|
|
||||||
response.ServerInfo = new ServerInfo
|
|
||||||
{
|
|
||||||
ServerMajorVersion = serverInfo.ServerMajorVersion,
|
|
||||||
ServerMinorVersion = serverInfo.ServerMinorVersion,
|
|
||||||
ServerReleaseVersion = serverInfo.ServerReleaseVersion,
|
|
||||||
EngineEditionId = serverInfo.EngineEditionId,
|
|
||||||
ServerVersion = serverInfo.ServerVersion,
|
|
||||||
ServerLevel = serverInfo.ServerLevel,
|
|
||||||
ServerEdition = serverInfo.ServerEdition,
|
|
||||||
IsCloud = serverInfo.IsCloud,
|
|
||||||
AzureVersion = serverInfo.AzureVersion,
|
|
||||||
OsVersion = serverInfo.OsVersion
|
|
||||||
};
|
|
||||||
connectionInfo.IsAzure = serverInfo.IsCloud;
|
|
||||||
connectionInfo.MajorVersion = serverInfo.ServerMajorVersion;
|
|
||||||
connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
|
|
||||||
}
|
|
||||||
catch(Exception ex)
|
|
||||||
{
|
|
||||||
response.Messages = ex.ToString();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// return the connection result
|
return connection;
|
||||||
response.ConnectionId = connectionInfo.ConnectionId.ToString();
|
|
||||||
return response;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -317,9 +437,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CancelTokenKey cancelKey = new CancelTokenKey
|
||||||
|
{
|
||||||
|
OwnerUri = cancelParams.OwnerUri,
|
||||||
|
Type = cancelParams.Type
|
||||||
|
};
|
||||||
|
|
||||||
// Cancel any current connection attempts for this URI
|
// Cancel any current connection attempts for this URI
|
||||||
CancellationTokenSource source;
|
CancellationTokenSource source;
|
||||||
if (ownerToCancellationTokenSourceMap.TryGetValue(cancelParams.OwnerUri, out source))
|
if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out source))
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
@@ -331,11 +457,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
|
||||||
{
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Close a connection with the specified connection details.
|
/// Close a connection with the specified connection details.
|
||||||
@@ -349,55 +473,125 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cancel if we are in the middle of connecting
|
// Cancel if we are in the middle of connecting
|
||||||
if (CancelConnect(new CancelConnectParams() { OwnerUri = disconnectParams.OwnerUri }))
|
if (CancelConnections(disconnectParams.OwnerUri, disconnectParams.Type))
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lookup the connection owned by the URI
|
// Lookup the ConnectionInfo owned by the URI
|
||||||
ConnectionInfo info;
|
ConnectionInfo info;
|
||||||
if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info))
|
if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info))
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ServiceHost != null)
|
// Call Close() on the connections we want to disconnect
|
||||||
|
// If no connections were located, return false
|
||||||
|
if (!CloseConnections(info, disconnectParams.Type))
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the disconnected connections from the ConnectionInfo map
|
||||||
|
if (disconnectParams.Type == null)
|
||||||
|
{
|
||||||
|
info.RemoveAllConnections();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
info.RemoveConnection(disconnectParams.Type);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the ConnectionInfo has no more connections, remove the ConnectionInfo
|
||||||
|
if (info.CountConnections == 0)
|
||||||
|
{
|
||||||
|
ownerToConnectionMap.Remove(disconnectParams.OwnerUri);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Telemetry disconnect events if we are disconnecting the default connection
|
||||||
|
if (disconnectParams.Type == null || disconnectParams.Type == ConnectionType.Default)
|
||||||
|
{
|
||||||
|
HandleDisconnectTelemetry(info);
|
||||||
|
InvokeOnDisconnectionActivities(info);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return true upon success
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Cancel connections associated with the given ownerUri.
|
||||||
|
/// If connectionType is not null, cancel the connection with the given connectionType
|
||||||
|
/// If connectionType is null, cancel all pending connections associated with ownerUri.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if a single pending connection associated with the non-null connectionType was
|
||||||
|
/// found and cancelled, false otherwise</returns>
|
||||||
|
private bool CancelConnections(string ownerUri, string connectionType)
|
||||||
|
{
|
||||||
|
// Cancel the connection of the given type
|
||||||
|
if (connectionType != null)
|
||||||
|
{
|
||||||
|
// If we are trying to disconnect a specific connection and it was just cancelled,
|
||||||
|
// this will return true
|
||||||
|
return CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = connectionType });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel all pending connections
|
||||||
|
foreach (var entry in cancelTupleToCancellationTokenSourceMap)
|
||||||
|
{
|
||||||
|
string entryConnectionUri = entry.Key.OwnerUri;
|
||||||
|
string entryConnectionType = entry.Key.Type;
|
||||||
|
if (ownerUri.Equals(entryConnectionUri))
|
||||||
|
{
|
||||||
|
CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = entryConnectionType });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Closes DbConnections associated with the given ConnectionInfo.
|
||||||
|
/// If connectionType is not null, closes the DbConnection with the type given by connectionType.
|
||||||
|
/// If connectionType is null, closes all DbConnections.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>true if connections were found and attempted to be closed,
|
||||||
|
/// false if no connections were found</returns>
|
||||||
|
private bool CloseConnections(ConnectionInfo connectionInfo, string connectionType)
|
||||||
|
{
|
||||||
|
ICollection<DbConnection> connectionsToDisconnect = new List<DbConnection>();
|
||||||
|
if (connectionType == null)
|
||||||
|
{
|
||||||
|
connectionsToDisconnect = connectionInfo.AllConnections;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Make sure there is an existing connection of this type
|
||||||
|
DbConnection connection;
|
||||||
|
if (!connectionInfo.TryGetConnection(connectionType, out connection))
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
connectionsToDisconnect.Add(connection);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (connectionsToDisconnect.Count == 0)
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
foreach (DbConnection connection in connectionsToDisconnect)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
// Send a telemetry notification for intellisense performance metrics
|
connection.Close();
|
||||||
ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams()
|
|
||||||
{
|
|
||||||
Params = new TelemetryProperties
|
|
||||||
{
|
|
||||||
Properties = new Dictionary<string, string>
|
|
||||||
{
|
|
||||||
{ "IsAzure", info.IsAzure ? "1" : "0" }
|
|
||||||
},
|
|
||||||
EventName = TelemetryEventNames.IntellisenseQuantile,
|
|
||||||
Measures = info.IntellisenseMetrics.Quantile
|
|
||||||
}
|
}
|
||||||
});
|
catch (Exception)
|
||||||
}
|
|
||||||
catch (Exception ex)
|
|
||||||
{
|
{
|
||||||
Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString());
|
// Ignore
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the connection
|
|
||||||
info.SqlConnection.Close();
|
|
||||||
|
|
||||||
// Remove URI mapping
|
|
||||||
ownerToConnectionMap.Remove(disconnectParams.OwnerUri);
|
|
||||||
|
|
||||||
// Invoke callback notifications
|
|
||||||
foreach (var activity in this.onDisconnectActivities)
|
|
||||||
{
|
|
||||||
activity(info.ConnectionDetails, disconnectParams.OwnerUri);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Success
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -490,7 +684,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Handle new connection requests
|
/// Handle new connection requests
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="connectionDetails"></param>
|
/// <param name="connectParams"></param>
|
||||||
/// <param name="requestContext"></param>
|
/// <param name="requestContext"></param>
|
||||||
/// <returns></returns>
|
/// <returns></returns>
|
||||||
protected async Task HandleConnectRequest(
|
protected async Task HandleConnectRequest(
|
||||||
@@ -517,8 +711,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
|
// result is null if the ConnectParams was successfully validated
|
||||||
|
ConnectionCompleteParams result = ValidateConnectParams(connectParams);
|
||||||
|
if (result != null)
|
||||||
|
{
|
||||||
|
await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// open connection based on request details
|
// open connection based on request details
|
||||||
ConnectionCompleteParams result = await Instance.Connect(connectParams);
|
result = await Instance.Connect(connectParams);
|
||||||
await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
|
await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
@@ -744,10 +946,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
if (info.SqlConnection.State == ConnectionState.Open)
|
foreach (DbConnection connection in info.AllConnections)
|
||||||
{
|
{
|
||||||
info.SqlConnection.ChangeDatabase(newDatabaseName);
|
if (connection.State == ConnectionState.Open)
|
||||||
|
{
|
||||||
|
connection.ChangeDatabase(newDatabaseName);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
info.ConnectionDetails.DatabaseName = newDatabaseName;
|
info.ConnectionDetails.DatabaseName = newDatabaseName;
|
||||||
|
|
||||||
// Fire a connection changed event
|
// Fire a connection changed event
|
||||||
@@ -771,13 +977,65 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo)
|
/// <summary>
|
||||||
|
/// Invokes the initial on-connect activities if the provided ConnectParams represents the default
|
||||||
|
/// connection.
|
||||||
|
/// </summary>
|
||||||
|
private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo, ConnectParams connectParams)
|
||||||
{
|
{
|
||||||
|
if (connectParams.Type != ConnectionType.Default)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
foreach (var activity in this.onConnectionActivities)
|
foreach (var activity in this.onConnectionActivities)
|
||||||
{
|
{
|
||||||
// not awaiting here to allow handlers to run in the background
|
// not awaiting here to allow handlers to run in the background
|
||||||
activity(connectionInfo);
|
activity(connectionInfo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Invokes the fianl on-disconnect activities if the provided DisconnectParams represents the default
|
||||||
|
/// connection or is null - representing that all connections are being disconnected.
|
||||||
|
/// </summary>
|
||||||
|
private void InvokeOnDisconnectionActivities(ConnectionInfo connectionInfo)
|
||||||
|
{
|
||||||
|
foreach (var activity in this.onDisconnectActivities)
|
||||||
|
{
|
||||||
|
activity(connectionInfo.ConnectionDetails, connectionInfo.OwnerUri);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Handles the Telemetry events that occur upon disconnect.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="info"></param>
|
||||||
|
private void HandleDisconnectTelemetry(ConnectionInfo connectionInfo)
|
||||||
|
{
|
||||||
|
if (ServiceHost != null)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
// Send a telemetry notification for intellisense performance metrics
|
||||||
|
ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams()
|
||||||
|
{
|
||||||
|
Params = new TelemetryProperties
|
||||||
|
{
|
||||||
|
Properties = new Dictionary<string, string>
|
||||||
|
{
|
||||||
|
{"IsAzure", connectionInfo.IsAzure ? "1" : "0"}
|
||||||
|
},
|
||||||
|
EventName = TelemetryEventNames.IntellisenseQuantile,
|
||||||
|
Measures = connectionInfo.IntellisenseMetrics.Quantile
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// String constants that represent connection types.
|
||||||
|
///
|
||||||
|
/// Default: Connection used by the editor. Opened by the editor upon the initial connection.
|
||||||
|
/// Query: Connection used for executing queries. Opened when the first query is executed.
|
||||||
|
/// </summary>
|
||||||
|
public static class ConnectionType
|
||||||
|
{
|
||||||
|
public const string Default = "Default";
|
||||||
|
public const string Query = "Query";
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,5 +15,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
|||||||
/// or a virtual file representing an object in a database.
|
/// or a virtual file representing an object in a database.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public string OwnerUri { get; set; }
|
public string OwnerUri { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The type of connection we are trying to cancel
|
||||||
|
/// </summary>
|
||||||
|
public string Type { get; set; } = ConnectionType.Default;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,5 +22,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
|||||||
/// connection properties to the same database.
|
/// connection properties to the same database.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public ConnectionDetails Connection { get; set; }
|
public ConnectionDetails Connection { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The type of this connection. By default, this is set to ConnectionType.Default.
|
||||||
|
/// </summary>
|
||||||
|
public string Type { get; set; } = ConnectionType.Default;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,6 +47,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
|||||||
/// Gets or sets the actual Connection established, including Database Name
|
/// Gets or sets the actual Connection established, including Database Name
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public ConnectionSummary ConnectionSummary { get; set; }
|
public ConnectionSummary ConnectionSummary { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The type of connection that this notification is for
|
||||||
|
/// </summary>
|
||||||
|
public string Type { get; set; } = ConnectionType.Default;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|||||||
@@ -15,5 +15,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
|||||||
/// or a virtual file representing an object in a database.
|
/// or a virtual file representing an object in a database.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public string OwnerUri { get; set; }
|
public string OwnerUri { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The type of connection we are disconnecting. If null, we will disconnect all connections.
|
||||||
|
/// connections.
|
||||||
|
/// </summary>
|
||||||
|
public string Type { get; set; }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
|
|||||||
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
|
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
|
||||||
using Microsoft.SqlTools.ServiceLayer.SqlContext;
|
using Microsoft.SqlTools.ServiceLayer.SqlContext;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Utility;
|
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
|
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
|
||||||
@@ -369,26 +370,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open up a connection for querying the database
|
// Locate and setup the connection
|
||||||
string connectionString = ConnectionService.BuildConnectionString(editorConnection.ConnectionDetails);
|
DbConnection queryConnection = await ConnectionService.Instance.GetOrOpenConnection(editorConnection.OwnerUri, ConnectionType.Query);
|
||||||
// TODO: Don't create a new connection every time, see TFS #834978
|
ReliableSqlConnection sqlConn = queryConnection as ReliableSqlConnection;
|
||||||
using (DbConnection conn = editorConnection.Factory.CreateSqlConnection(connectionString))
|
|
||||||
{
|
|
||||||
try
|
|
||||||
{
|
|
||||||
await conn.OpenAsync();
|
|
||||||
}
|
|
||||||
catch (Exception exception)
|
|
||||||
{
|
|
||||||
this.HasExecuted = true;
|
|
||||||
if (QueryConnectionException != null)
|
|
||||||
{
|
|
||||||
await QueryConnectionException(exception.Message);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ReliableSqlConnection sqlConn = conn as ReliableSqlConnection;
|
|
||||||
if (sqlConn != null)
|
if (sqlConn != null)
|
||||||
{
|
{
|
||||||
// Subscribe to database informational messages
|
// Subscribe to database informational messages
|
||||||
@@ -400,7 +384,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
|
|||||||
// Execute beforeBatches synchronously, before the user defined batches
|
// Execute beforeBatches synchronously, before the user defined batches
|
||||||
foreach (Batch b in BeforeBatches)
|
foreach (Batch b in BeforeBatches)
|
||||||
{
|
{
|
||||||
await b.Execute(conn, cancellationSource.Token);
|
await b.Execute(queryConnection, cancellationSource.Token);
|
||||||
}
|
}
|
||||||
|
|
||||||
// We need these to execute synchronously, otherwise the user will be very unhappy
|
// We need these to execute synchronously, otherwise the user will be very unhappy
|
||||||
@@ -411,13 +395,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
|
|||||||
b.BatchCompletion += BatchCompleted;
|
b.BatchCompletion += BatchCompleted;
|
||||||
b.BatchMessageSent += BatchMessageSent;
|
b.BatchMessageSent += BatchMessageSent;
|
||||||
b.ResultSetCompletion += ResultSetCompleted;
|
b.ResultSetCompletion += ResultSetCompleted;
|
||||||
await b.Execute(conn, cancellationSource.Token);
|
await b.Execute(queryConnection, cancellationSource.Token);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute afterBatches synchronously, after the user defined batches
|
// Execute afterBatches synchronously, after the user defined batches
|
||||||
foreach (Batch b in AfterBatches)
|
foreach (Batch b in AfterBatches)
|
||||||
{
|
{
|
||||||
await b.Execute(conn, cancellationSource.Token);
|
await b.Execute(queryConnection, cancellationSource.Token);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the query execution callback
|
// Call the query execution callback
|
||||||
@@ -442,9 +426,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
|
|||||||
sqlConn.GetUnderlyingConnection().InfoMessage -= OnInfoMessage;
|
sqlConn.GetUnderlyingConnection().InfoMessage -= OnInfoMessage;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Close connection after eliminating using statement for above TODO
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|||||||
@@ -0,0 +1,104 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Data.Common;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||||
|
using Microsoft.SqlTools.Test.Utility;
|
||||||
|
using Xunit;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.SqlContext;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Tests for the ServiceHost Connection Service tests that require a live database connection
|
||||||
|
/// </summary>
|
||||||
|
public class ConnectionServiceTests
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public async Task RunningMultipleQueriesCreatesOnlyOneConnection()
|
||||||
|
{
|
||||||
|
// Connect/disconnect twice to ensure reconnection can occur
|
||||||
|
ConnectionService service = ConnectionService.Instance;
|
||||||
|
service.OwnerToConnectionMap.Clear();
|
||||||
|
for (int i = 0; i < 2; i++)
|
||||||
|
{
|
||||||
|
var result = await TestObjects.InitLiveConnectionInfo();
|
||||||
|
ConnectionInfo connectionInfo = result.ConnectionInfo;
|
||||||
|
string uri = connectionInfo.OwnerUri;
|
||||||
|
|
||||||
|
// We should see one ConnectionInfo and one DbConnection
|
||||||
|
Assert.Equal(1, connectionInfo.CountConnections);
|
||||||
|
Assert.Equal(1, service.OwnerToConnectionMap.Count);
|
||||||
|
|
||||||
|
// If we run a query
|
||||||
|
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
||||||
|
Query query = new Query(Common.StandardQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
|
||||||
|
query.Execute();
|
||||||
|
query.ExecutionTask.Wait();
|
||||||
|
|
||||||
|
// We should see two DbConnections
|
||||||
|
Assert.Equal(2, connectionInfo.CountConnections);
|
||||||
|
|
||||||
|
// If we run another query
|
||||||
|
query = new Query(Common.StandardQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
|
||||||
|
query.Execute();
|
||||||
|
query.ExecutionTask.Wait();
|
||||||
|
|
||||||
|
// We should still have 2 DbConnections
|
||||||
|
Assert.Equal(2, connectionInfo.CountConnections);
|
||||||
|
|
||||||
|
// If we disconnect, we should remain in a consistent state to do it over again
|
||||||
|
// e.g. loop and do it over again
|
||||||
|
service.Disconnect(new DisconnectParams() { OwnerUri = connectionInfo.OwnerUri });
|
||||||
|
|
||||||
|
// We should be left with an empty connection map
|
||||||
|
Assert.Equal(0, service.OwnerToConnectionMap.Count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task DatabaseChangesAffectAllConnections()
|
||||||
|
{
|
||||||
|
// If we make a connection to a live database
|
||||||
|
ConnectionService service = ConnectionService.Instance;
|
||||||
|
var result = await TestObjects.InitLiveConnectionInfo();
|
||||||
|
ConnectionInfo connectionInfo = result.ConnectionInfo;
|
||||||
|
ConnectionDetails details = connectionInfo.ConnectionDetails;
|
||||||
|
string uri = connectionInfo.OwnerUri;
|
||||||
|
string initialDatabaseName = details.DatabaseName;
|
||||||
|
string newDatabaseName = "tempdb";
|
||||||
|
string changeDatabaseQuery = "use " + newDatabaseName;
|
||||||
|
|
||||||
|
// Then run any query to create a query DbConnection
|
||||||
|
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
||||||
|
Query query = new Query(Common.StandardQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
|
||||||
|
query.Execute();
|
||||||
|
query.ExecutionTask.Wait();
|
||||||
|
|
||||||
|
// All open DbConnections (Query and Default) should have initialDatabaseName as their database
|
||||||
|
foreach (DbConnection connection in connectionInfo.AllConnections)
|
||||||
|
{
|
||||||
|
Assert.Equal(connection.Database, initialDatabaseName);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we run a query to change the database
|
||||||
|
query = new Query(changeDatabaseQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
|
||||||
|
query.Execute();
|
||||||
|
query.ExecutionTask.Wait();
|
||||||
|
|
||||||
|
// All open DbConnections (Query and Default) should have newDatabaseName as their database
|
||||||
|
foreach (DbConnection connection in connectionInfo.AllConnections)
|
||||||
|
{
|
||||||
|
Assert.Equal(connection.Database, newDatabaseName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -685,19 +685,21 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
{
|
{
|
||||||
var result = await TestObjects.InitLiveConnectionInfo();
|
var result = await TestObjects.InitLiveConnectionInfo();
|
||||||
ConnectionInfo connInfo = result.ConnectionInfo;
|
ConnectionInfo connInfo = result.ConnectionInfo;
|
||||||
|
DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default];
|
||||||
|
|
||||||
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connInfo.SqlConnection));
|
|
||||||
|
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection));
|
||||||
|
|
||||||
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
|
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
|
||||||
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder));
|
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder));
|
||||||
ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder());
|
ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder());
|
||||||
|
|
||||||
Assert.NotNull(ReliableConnectionHelper.GetServerName(connInfo.SqlConnection));
|
Assert.NotNull(ReliableConnectionHelper.GetServerName(connection));
|
||||||
Assert.NotNull(ReliableConnectionHelper.ReadServerVersion(connInfo.SqlConnection));
|
Assert.NotNull(ReliableConnectionHelper.ReadServerVersion(connection));
|
||||||
|
|
||||||
Assert.NotNull(ReliableConnectionHelper.GetAsSqlConnection(connInfo.SqlConnection));
|
Assert.NotNull(ReliableConnectionHelper.GetAsSqlConnection(connection));
|
||||||
|
|
||||||
ReliableConnectionHelper.ServerInfo info = ReliableConnectionHelper.GetServerVersion(connInfo.SqlConnection);
|
ReliableConnectionHelper.ServerInfo info = ReliableConnectionHelper.GetServerVersion(connection);
|
||||||
Assert.NotNull(ReliableConnectionHelper.IsVersionGreaterThan2012RTM(info));
|
Assert.NotNull(ReliableConnectionHelper.IsVersionGreaterThan2012RTM(info));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -728,8 +730,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
{
|
{
|
||||||
var result = await TestObjects.InitLiveConnectionInfo();
|
var result = await TestObjects.InitLiveConnectionInfo();
|
||||||
ConnectionInfo connInfo = result.ConnectionInfo;
|
ConnectionInfo connInfo = result.ConnectionInfo;
|
||||||
|
DbConnection dbConnection;
|
||||||
|
connInfo.TryGetConnection(ConnectionType.Default, out dbConnection);
|
||||||
|
|
||||||
var connection = connInfo.SqlConnection as ReliableSqlConnection;
|
var connection = dbConnection as ReliableSqlConnection;
|
||||||
var command = new ReliableSqlConnection.ReliableSqlCommand(connection);
|
var command = new ReliableSqlConnection.ReliableSqlCommand(connection);
|
||||||
Assert.NotNull(command.Connection);
|
Assert.NotNull(command.Connection);
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.QueryExecution.DataSt
|
|||||||
private async Task<StorageDataReader> GetTestStorageDataReader(string query)
|
private async Task<StorageDataReader> GetTestStorageDataReader(string query)
|
||||||
{
|
{
|
||||||
var result = await TestObjects.InitLiveConnectionInfo();
|
var result = await TestObjects.InitLiveConnectionInfo();
|
||||||
|
DbConnection connection;
|
||||||
|
result.ConnectionInfo.TryGetConnection(ConnectionType.Default, out connection);
|
||||||
|
|
||||||
var command = result.ConnectionInfo.SqlConnection.CreateCommand();
|
var command = connection.CreateCommand();
|
||||||
command.CommandText = query;
|
command.CommandText = query;
|
||||||
DbDataReader reader = command.ExecuteReader();
|
DbDataReader reader = command.ExecuteReader();
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,112 @@
|
|||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Data.Common;
|
||||||
|
using System.Linq;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.SqlContext;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
|
||||||
|
using Microsoft.SqlTools.Test.Utility;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.QueryExecution
|
||||||
|
{
|
||||||
|
public class ExecuteTests
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public async Task RollbackTransactionFailsWithoutBeginTransaction()
|
||||||
|
{
|
||||||
|
const string refactorText = "ROLLBACK TRANSACTION";
|
||||||
|
|
||||||
|
// Given a connection to a live database
|
||||||
|
var result = await TestObjects.InitLiveConnectionInfo();
|
||||||
|
ConnectionInfo connInfo = result.ConnectionInfo;
|
||||||
|
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
||||||
|
|
||||||
|
// If I run a "ROLLBACK TRANSACTION" query
|
||||||
|
Query query = new Query(refactorText, connInfo, new QueryExecutionSettings(), fileStreamFactory);
|
||||||
|
query.Execute();
|
||||||
|
query.ExecutionTask.Wait();
|
||||||
|
|
||||||
|
// There should be an error
|
||||||
|
Assert.True(query.Batches[0].HasError);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task TransactionsSucceedAcrossQueries()
|
||||||
|
{
|
||||||
|
const string beginText = "BEGIN TRANSACTION";
|
||||||
|
const string rollbackText = "ROLLBACK TRANSACTION";
|
||||||
|
|
||||||
|
// Given a connection to a live database
|
||||||
|
var result = await TestObjects.InitLiveConnectionInfo();
|
||||||
|
ConnectionInfo connInfo = result.ConnectionInfo;
|
||||||
|
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
||||||
|
|
||||||
|
// If I run a "BEGIN TRANSACTION" query
|
||||||
|
CreateAndExecuteQuery(beginText, connInfo, fileStreamFactory);
|
||||||
|
|
||||||
|
// Then I run a "ROLLBACK TRANSACTION" query, there should be no errors
|
||||||
|
Query rollbackQuery = CreateAndExecuteQuery(rollbackText, connInfo, fileStreamFactory);
|
||||||
|
Assert.False(rollbackQuery.Batches[0].HasError);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task TempTablesPersistAcrossQueries()
|
||||||
|
{
|
||||||
|
const string createTempText = "CREATE TABLE #someTempTable (id int)";
|
||||||
|
const string insertTempText = "INSERT INTO #someTempTable VALUES(1)";
|
||||||
|
|
||||||
|
// Given a connection to a live database
|
||||||
|
var result = await TestObjects.InitLiveConnectionInfo();
|
||||||
|
ConnectionInfo connInfo = result.ConnectionInfo;
|
||||||
|
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
||||||
|
|
||||||
|
// If I run a query creating a temp table
|
||||||
|
CreateAndExecuteQuery(createTempText, connInfo, fileStreamFactory);
|
||||||
|
|
||||||
|
// Then I run a different query using that temp table, there should be no errors
|
||||||
|
Query insertTempQuery = CreateAndExecuteQuery(insertTempText, connInfo, fileStreamFactory);
|
||||||
|
Assert.False(insertTempQuery.Batches[0].HasError);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task DatabaseChangesWhenCallingUseDatabase()
|
||||||
|
{
|
||||||
|
const string master = "master";
|
||||||
|
const string tempdb = "tempdb";
|
||||||
|
const string useQuery = "USE {0}";
|
||||||
|
|
||||||
|
// Given a connection to a live database
|
||||||
|
var result = await TestObjects.InitLiveConnectionInfo();
|
||||||
|
ConnectionInfo connInfo = result.ConnectionInfo;
|
||||||
|
DbConnection connection;
|
||||||
|
connInfo.TryGetConnection(ConnectionType.Default, out connection);
|
||||||
|
|
||||||
|
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
||||||
|
|
||||||
|
// If I use master, the current database should be master
|
||||||
|
CreateAndExecuteQuery(string.Format(useQuery, master), connInfo, fileStreamFactory);
|
||||||
|
Assert.Equal(master, connection.Database);
|
||||||
|
|
||||||
|
// If I use tempdb, the current database should be tempdb
|
||||||
|
CreateAndExecuteQuery(string.Format(useQuery, tempdb), connInfo, fileStreamFactory);
|
||||||
|
Assert.Equal(tempdb, connection.Database);
|
||||||
|
|
||||||
|
// If I switch back to master, the current database should be master
|
||||||
|
CreateAndExecuteQuery(string.Format(useQuery, master), connInfo, fileStreamFactory);
|
||||||
|
Assert.Equal(master, connection.Database);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Query CreateAndExecuteQuery(string queryText, ConnectionInfo connectionInfo, IFileStreamFactory fileStreamFactory)
|
||||||
|
{
|
||||||
|
Query query = new Query(queryText, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
|
||||||
|
query.Execute();
|
||||||
|
query.ExecutionTask.Wait();
|
||||||
|
return query;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,10 @@ using Microsoft.SqlTools.Test.Utility;
|
|||||||
using Moq;
|
using Moq;
|
||||||
using Moq.Protected;
|
using Moq.Protected;
|
||||||
using Xunit;
|
using Xunit;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.SqlContext;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
|
namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
|
||||||
{
|
{
|
||||||
@@ -994,5 +998,161 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
|
|||||||
Assert.NotNull(errorMessage);
|
Assert.NotNull(errorMessage);
|
||||||
Assert.NotEmpty(errorMessage);
|
Assert.NotEmpty(errorMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async void ConnectingTwiceWithTheSameUriDoesNotCreateAnotherDbConnection()
|
||||||
|
{
|
||||||
|
// Setup the connect and disconnect params
|
||||||
|
var connectParamsSame1 = new ConnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = "connectParamsSame",
|
||||||
|
Connection = TestObjects.GetTestConnectionDetails()
|
||||||
|
};
|
||||||
|
var connectParamsSame2 = new ConnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = "connectParamsSame",
|
||||||
|
Connection = TestObjects.GetTestConnectionDetails()
|
||||||
|
};
|
||||||
|
var disconnectParamsSame = new DisconnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = connectParamsSame1.OwnerUri
|
||||||
|
};
|
||||||
|
var connectParamsDifferent = new ConnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = "connectParamsDifferent",
|
||||||
|
Connection = TestObjects.GetTestConnectionDetails()
|
||||||
|
};
|
||||||
|
var disconnectParamsDifferent = new DisconnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = connectParamsDifferent.OwnerUri
|
||||||
|
};
|
||||||
|
|
||||||
|
// Given a request to connect to a database, there should be no initial connections in the map
|
||||||
|
var service = TestObjects.GetTestConnectionService();
|
||||||
|
Dictionary<string, ConnectionInfo> ownerToConnectionMap = service.OwnerToConnectionMap;
|
||||||
|
Assert.Equal(0, ownerToConnectionMap.Count);
|
||||||
|
|
||||||
|
// If we connect to the service, there should be 1 connection
|
||||||
|
await service.Connect(connectParamsSame1);
|
||||||
|
Assert.Equal(1, ownerToConnectionMap.Count);
|
||||||
|
|
||||||
|
// If we connect again with the same URI, there should still be 1 connection
|
||||||
|
await service.Connect(connectParamsSame2);
|
||||||
|
Assert.Equal(1, ownerToConnectionMap.Count);
|
||||||
|
|
||||||
|
// If we connect with a different URI, there should be 2 connections
|
||||||
|
await service.Connect(connectParamsDifferent);
|
||||||
|
Assert.Equal(2, ownerToConnectionMap.Count);
|
||||||
|
|
||||||
|
// If we disconenct with the unique URI, there should be 1 connection
|
||||||
|
service.Disconnect(disconnectParamsDifferent);
|
||||||
|
Assert.Equal(1, ownerToConnectionMap.Count);
|
||||||
|
|
||||||
|
// If we disconenct with the duplicate URI, there should be 0 connections
|
||||||
|
service.Disconnect(disconnectParamsSame);
|
||||||
|
Assert.Equal(0, ownerToConnectionMap.Count);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async void DbConnectionDoesntLeakUponDisconnect()
|
||||||
|
{
|
||||||
|
// If we connect with a single URI and 2 connection types
|
||||||
|
var connectParamsDefault = new ConnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = "connectParams",
|
||||||
|
Connection = TestObjects.GetTestConnectionDetails(),
|
||||||
|
Type = ConnectionType.Default
|
||||||
|
};
|
||||||
|
var connectParamsQuery = new ConnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = "connectParams",
|
||||||
|
Connection = TestObjects.GetTestConnectionDetails(),
|
||||||
|
Type = ConnectionType.Query
|
||||||
|
};
|
||||||
|
var disconnectParams = new DisconnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = connectParamsDefault.OwnerUri
|
||||||
|
};
|
||||||
|
var service = TestObjects.GetTestConnectionService();
|
||||||
|
await service.Connect(connectParamsDefault);
|
||||||
|
await service.Connect(connectParamsQuery);
|
||||||
|
|
||||||
|
// We should have one ConnectionInfo and 2 DbConnections
|
||||||
|
ConnectionInfo connectionInfo = service.OwnerToConnectionMap[connectParamsDefault.OwnerUri];
|
||||||
|
Assert.Equal(2, connectionInfo.CountConnections);
|
||||||
|
Assert.Equal(1, service.OwnerToConnectionMap.Count);
|
||||||
|
|
||||||
|
// If we record when the Default connecton calls Close()
|
||||||
|
bool defaultDisconnectCalled = false;
|
||||||
|
var mockDefaultConnection = new Mock<DbConnection> { CallBase = true };
|
||||||
|
mockDefaultConnection.Setup(x => x.Close())
|
||||||
|
.Callback(() =>
|
||||||
|
{
|
||||||
|
defaultDisconnectCalled = true;
|
||||||
|
});
|
||||||
|
connectionInfo.ConnectionTypeToConnectionMap[ConnectionType.Default] = mockDefaultConnection.Object;
|
||||||
|
|
||||||
|
// And when the Query connecton calls Close()
|
||||||
|
bool queryDisconnectCalled = false;
|
||||||
|
var mockQueryConnection = new Mock<DbConnection> { CallBase = true };
|
||||||
|
mockQueryConnection.Setup(x => x.Close())
|
||||||
|
.Callback(() =>
|
||||||
|
{
|
||||||
|
queryDisconnectCalled = true;
|
||||||
|
});
|
||||||
|
connectionInfo.ConnectionTypeToConnectionMap[ConnectionType.Query] = mockQueryConnection.Object;
|
||||||
|
|
||||||
|
// If we disconnect all open connections with the same URI as used above
|
||||||
|
service.Disconnect(disconnectParams);
|
||||||
|
|
||||||
|
// Close() should have gotten called for both DbConnections
|
||||||
|
Assert.True(defaultDisconnectCalled);
|
||||||
|
Assert.True(queryDisconnectCalled);
|
||||||
|
|
||||||
|
// And the maps that hold connection data should be empty
|
||||||
|
Assert.Equal(0, connectionInfo.CountConnections);
|
||||||
|
Assert.Equal(0, service.OwnerToConnectionMap.Count);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async void ClosingQueryConnectionShouldLeaveDefaultConnectionOpen()
|
||||||
|
{
|
||||||
|
// Setup the connect and disconnect params
|
||||||
|
var connectParamsDefault = new ConnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = "connectParamsSame",
|
||||||
|
Connection = TestObjects.GetTestConnectionDetails(),
|
||||||
|
Type = ConnectionType.Default
|
||||||
|
};
|
||||||
|
var connectParamsQuery = new ConnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = connectParamsDefault.OwnerUri,
|
||||||
|
Connection = TestObjects.GetTestConnectionDetails(),
|
||||||
|
Type = ConnectionType.Query
|
||||||
|
};
|
||||||
|
var disconnectParamsQuery = new DisconnectParams()
|
||||||
|
{
|
||||||
|
OwnerUri = connectParamsDefault.OwnerUri,
|
||||||
|
Type = connectParamsQuery.Type
|
||||||
|
};
|
||||||
|
|
||||||
|
// If I connect a Default and a Query connection
|
||||||
|
var service = TestObjects.GetTestConnectionService();
|
||||||
|
Dictionary<string, ConnectionInfo> ownerToConnectionMap = service.OwnerToConnectionMap;
|
||||||
|
await service.Connect(connectParamsDefault);
|
||||||
|
await service.Connect(connectParamsQuery);
|
||||||
|
ConnectionInfo connectionInfo = service.OwnerToConnectionMap[connectParamsDefault.OwnerUri];
|
||||||
|
|
||||||
|
// There should be 2 connections in the map
|
||||||
|
Assert.Equal(2, connectionInfo.CountConnections);
|
||||||
|
|
||||||
|
// If I Disconnect only the Query connection, there should be 1 connection in the map
|
||||||
|
service.Disconnect(disconnectParamsQuery);
|
||||||
|
Assert.Equal(1, connectionInfo.CountConnections);
|
||||||
|
|
||||||
|
// If I reconnect, there should be 2 again
|
||||||
|
await service.Connect(connectParamsQuery);
|
||||||
|
Assert.Equal(2, connectionInfo.CountConnections);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,6 +82,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
|
|||||||
public static Query GetBasicExecutedQuery()
|
public static Query GetBasicExecutedQuery()
|
||||||
{
|
{
|
||||||
ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false);
|
ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false);
|
||||||
|
|
||||||
|
// Query won't be able to request a new query DbConnection unless the ConnectionService has a
|
||||||
|
// ConnectionInfo with the same URI as the query, so we will manually set it
|
||||||
|
ConnectionService.Instance.OwnerToConnectionMap[ci.OwnerUri] = ci;
|
||||||
|
|
||||||
Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory(new Dictionary<string, byte[]>()));
|
Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory(new Dictionary<string, byte[]>()));
|
||||||
query.Execute();
|
query.Execute();
|
||||||
query.ExecutionTask.Wait();
|
query.ExecutionTask.Wait();
|
||||||
@@ -222,6 +227,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
|
|||||||
return new ConnectionInfo(CreateMockFactory(data, throwOnRead), OwnerUri, StandardConnectionDetails);
|
return new ConnectionInfo(CreateMockFactory(data, throwOnRead), OwnerUri, StandardConnectionDetails);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static ConnectionInfo CreateConnectedConnectionInfo(Dictionary<string, string>[][] data, bool throwOnRead, string type = ConnectionType.Default)
|
||||||
|
{
|
||||||
|
ConnectionService connectionService = ConnectionService.Instance;
|
||||||
|
connectionService.OwnerToConnectionMap.Clear();
|
||||||
|
connectionService.ConnectionFactory = CreateMockFactory(data, throwOnRead);
|
||||||
|
|
||||||
|
ConnectParams connectParams = new ConnectParams()
|
||||||
|
{
|
||||||
|
Connection = StandardConnectionDetails,
|
||||||
|
OwnerUri = Common.OwnerUri,
|
||||||
|
Type = type
|
||||||
|
};
|
||||||
|
|
||||||
|
connectionService.Connect(connectParams).Wait();
|
||||||
|
return connectionService.OwnerToConnectionMap[OwnerUri];
|
||||||
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
#region Service Mocking
|
#region Service Mocking
|
||||||
@@ -233,12 +255,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
|
|||||||
// Create a place for the temp "files" to be written
|
// Create a place for the temp "files" to be written
|
||||||
storage = new Dictionary<string, byte[]>();
|
storage = new Dictionary<string, byte[]>();
|
||||||
|
|
||||||
// Create the connection factory with the dataset
|
|
||||||
var factory = CreateTestConnectionInfo(data, throwOnRead).Factory;
|
|
||||||
|
|
||||||
// Mock the connection service
|
// Mock the connection service
|
||||||
var connectionService = new Mock<ConnectionService>();
|
var connectionService = new Mock<ConnectionService>();
|
||||||
ConnectionInfo ci = new ConnectionInfo(factory, OwnerUri, StandardConnectionDetails);
|
ConnectionInfo ci = CreateConnectedConnectionInfo(data, throwOnRead);
|
||||||
ConnectionInfo outValMock;
|
ConnectionInfo outValMock;
|
||||||
connectionService
|
connectionService
|
||||||
.Setup(service => service.TryFindConnection(It.IsAny<string>(), out outValMock))
|
.Setup(service => service.TryFindConnection(It.IsAny<string>(), out outValMock))
|
||||||
|
|||||||
@@ -7,10 +7,16 @@
|
|||||||
|
|
||||||
using System.Data.Common;
|
using System.Data.Common;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Connection;
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
|
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
|
||||||
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
|
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
|
||||||
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
|
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.SqlContext;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
|
||||||
|
using Microsoft.SqlTools.Test.Utility;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
|
namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
|
||||||
{
|
{
|
||||||
@@ -24,12 +30,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
|
|||||||
// If:
|
// If:
|
||||||
// ... I create a query with a udt column in the result set
|
// ... I create a query with a udt column in the result set
|
||||||
ConnectionInfo connectionInfo = TestObjects.GetTestConnectionInfo();
|
ConnectionInfo connectionInfo = TestObjects.GetTestConnectionInfo();
|
||||||
Query query = new Query(Common.UdtQuery, connectionInfo, new QueryExecutionSettings(), Common.GetFileStreamFactory());
|
Query query = new Query(Common.UdtQuery, connectionInfo, new QueryExecutionSettings(), Common.GetFileStreamFactory(new Dictionary<string, byte[]>()));
|
||||||
|
|
||||||
// If:
|
// If:
|
||||||
// ... I then execute the query
|
// ... I then execute the query
|
||||||
DateTime startTime = DateTime.Now;
|
DateTime startTime = DateTime.Now;
|
||||||
query.Execute().Wait();
|
query.Execute();
|
||||||
|
query.ExecutionTask.Wait();
|
||||||
|
|
||||||
// Then:
|
// Then:
|
||||||
// ... The query should complete within 2 seconds since retry logic should not kick in
|
// ... The query should complete within 2 seconds since retry logic should not kick in
|
||||||
|
|||||||
@@ -163,7 +163,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution
|
|||||||
|
|
||||||
// If:
|
// If:
|
||||||
// ... I create a query from two batches (with separator)
|
// ... I create a query from two batches (with separator)
|
||||||
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
|
ConnectionInfo ci = Common.CreateConnectedConnectionInfo(null, false);
|
||||||
|
|
||||||
string queryText = string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery);
|
string queryText = string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery);
|
||||||
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
||||||
Query query = new Query(queryText, ci, new QueryExecutionSettings(), fileStreamFactory);
|
Query query = new Query(queryText, ci, new QueryExecutionSettings(), fileStreamFactory);
|
||||||
@@ -280,6 +281,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution
|
|||||||
// If:
|
// If:
|
||||||
// ... I create a query from an invalid batch
|
// ... I create a query from an invalid batch
|
||||||
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true);
|
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true);
|
||||||
|
ConnectionService.Instance.OwnerToConnectionMap[ci.OwnerUri] = ci;
|
||||||
|
|
||||||
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary<string, byte[]>());
|
||||||
Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings(), fileStreamFactory);
|
Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings(), fileStreamFactory);
|
||||||
BatchCallbackHelper(query,
|
BatchCallbackHelper(query,
|
||||||
|
|||||||
Reference in New Issue
Block a user