//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Data.SqlClient;
using System.Threading.Tasks;
using Microsoft.SqlTools.EditorServices.Utility;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
/// <summary>
/// Information about a single connection owned by a URI: the details it was built from,
/// a unique id, and the underlying database connection once it has been opened.
/// </summary>
public class ConnectionInfo
{
    /// <summary>
    /// Creates a connection info object. No connection is opened until
    /// <see cref="OpenConnection"/> is called.
    /// </summary>
    /// <param name="factory">Factory used by <see cref="OpenConnection"/> to create the connection.</param>
    /// <param name="ownerUri">URI that owns this connection.</param>
    /// <param name="details">Details the connection string is built from.</param>
    public ConnectionInfo(ISqlConnectionFactory factory, string ownerUri, ConnectionDetails details)
    {
        Factory = factory;
        OwnerUri = ownerUri;
        ConnectionDetails = details;
        ConnectionId = Guid.NewGuid();
    }

    /// <summary>
    /// Unique Id, helpful to identify a connection info object
    /// </summary>
    public Guid ConnectionId { get; private set; }

    /// <summary>
    /// URI that owns this connection
    /// </summary>
    public string OwnerUri { get; private set; }

    /// <summary>
    /// Factory used to create the underlying connection
    /// </summary>
    private ISqlConnectionFactory Factory { get; set; }

    /// <summary>
    /// Details used to build the connection string
    /// </summary>
    public ConnectionDetails ConnectionDetails { get; private set; }

    /// <summary>
    /// The opened database connection; null until <see cref="OpenConnection"/> succeeds
    /// </summary>
    public DbConnection SqlConnection { get; private set; }

    /// <summary>
    /// Builds a connection string from <see cref="ConnectionDetails"/>, creates a connection
    /// through the factory and opens it.
    /// </summary>
    /// <remarks>
    /// If opening fails, the newly created connection is disposed, the exception propagates
    /// to the caller, and <see cref="SqlConnection"/> is left unchanged.
    /// </remarks>
    public void OpenConnection()
    {
        // build the connection string from the input parameters
        string connectionString = ConnectionService.BuildConnectionString(ConnectionDetails);

        // create a sql connection instance; keep it local until it is known to be open
        DbConnection connection = Factory.CreateSqlConnection(connectionString);
        try
        {
            connection.Open();
        }
        catch
        {
            // don't leak the connection object when the open fails
            connection.Dispose();
            throw;
        }

        SqlConnection = connection;
    }
}
/// <summary>
/// Main class for the Connection Management services
/// </summary>
public class ConnectionService
{
    /// <summary>
    /// Singleton service instance
    /// </summary>
    private static readonly Lazy<ConnectionService> instance
        = new Lazy<ConnectionService>(() => new ConnectionService());

    /// <summary>
    /// Gets the singleton service instance
    /// </summary>
    public static ConnectionService Instance
    {
        get
        {
            return instance.Value;
        }
    }

    /// <summary>
    /// The SQL connection factory object. Null until first requested through
    /// <see cref="ConnectionFactory"/>, unless injected by the test constructor.
    /// </summary>
    private ISqlConnectionFactory connectionFactory;

    /// <summary>
    /// Active connections, keyed by owner URI
    /// </summary>
    private readonly Dictionary<string, ConnectionInfo> ownerToConnectionMap = new Dictionary<string, ConnectionInfo>();

    /// <summary>
    /// Default constructor is private since it's a singleton class
    /// </summary>
    private ConnectionService()
    {
    }

    /// <summary>
    /// Callback for onconnection handler
    /// </summary>
    /// <param name="info">The connection that was just opened.</param>
    public delegate Task OnConnectionHandler(ConnectionInfo info);

    /// <summary>
    /// List of onconnection handlers
    /// </summary>
    private readonly List<OnConnectionHandler> onConnectionActivities = new List<OnConnectionHandler>();

    /// <summary>
    /// Gets the SQL connection factory instance, creating a default
    /// <see cref="SqlConnectionFactory"/> on first access when none was injected.
    /// </summary>
    public ISqlConnectionFactory ConnectionFactory
    {
        get
        {
            if (this.connectionFactory == null)
            {
                this.connectionFactory = new SqlConnectionFactory();
            }
            return this.connectionFactory;
        }
    }

    /// <summary>
    /// Test constructor that injects dependency interfaces
    /// </summary>
    /// <param name="testFactory">Factory used to create connections.</param>
    public ConnectionService(ISqlConnectionFactory testFactory)
    {
        this.connectionFactory = testFactory;
    }

    /// <summary>
    /// Attempts to link a URI to an actively used connection for this URI
    /// </summary>
    /// <param name="ownerUri">Owner URI to look up.</param>
    /// <param name="connectionSummary">
    /// A copy of the connection's server, database and user names, or null when no
    /// connection is registered for <paramref name="ownerUri"/>.
    /// </param>
    /// <returns>True if a connection is registered for <paramref name="ownerUri"/>.</returns>
    public bool TryFindConnection(string ownerUri, out ConnectionSummary connectionSummary)
    {
        connectionSummary = null;
        ConnectionInfo connectionInfo;
        if (this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo))
        {
            connectionSummary = CopySummary(connectionInfo.ConnectionDetails);
            return true;
        }
        return false;
    }

    /// <summary>
    /// Creates a new summary holding only the server, database and user names of
    /// <paramref name="summary"/>, so callers never receive the stored details object itself.
    /// </summary>
    private static ConnectionSummary CopySummary(ConnectionSummary summary)
    {
        return new ConnectionSummary()
        {
            ServerName = summary.ServerName,
            DatabaseName = summary.DatabaseName,
            UserName = summary.UserName
        };
    }

    /// <summary>
    /// Open a connection with the specified connection details
    /// </summary>
    /// <param name="connectionParams">Owner URI and connection details for the new connection.</param>
    /// <returns>A response carrying the id of the newly opened connection.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="connectionParams"/> is null.</exception>
    public ConnectResponse Connect(ConnectParams connectionParams)
    {
        if (connectionParams == null)
        {
            throw new ArgumentNullException("connectionParams");
        }

        // Look up any existing connection first, so an invalid owner URI fails
        // before a new connection is opened.
        ConnectionInfo previousConnection;
        ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out previousConnection);

        // Use the property, not the field: the singleton is created without an injected
        // factory, and the property lazily supplies the default one.
        ConnectionInfo connectionInfo = new ConnectionInfo(this.ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);

        // try to connect; a failure throws and leaves any existing connection for this URI registered
        connectionInfo.OpenConnection();

        ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;

        // the replaced connection is no longer reachable through this service, so release it
        if (previousConnection != null && previousConnection.SqlConnection != null)
        {
            previousConnection.SqlConnection.Dispose();
        }

        // invoke callback notifications
        // NOTE: Connect is synchronous, so the returned tasks are not awaited here and
        // failures inside the handlers are not observed by this method.
        foreach (var activity in this.onConnectionActivities)
        {
            activity(connectionInfo);
        }

        // return the connection result
        return new ConnectResponse()
        {
            ConnectionId = connectionInfo.ConnectionId.ToString()
        };
    }

    /// <summary>
    /// Registers this service's request handlers with the host and subscribes to
    /// workspace configuration changes.
    /// </summary>
    /// <param name="serviceHost">Endpoint to register request handlers on.</param>
    public void InitializeService(IProtocolEndpoint serviceHost)
    {
        // Register request and event handlers with the Service Host
        serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest);

        // Register the configuration update handler
        WorkspaceService.Instance.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification);
    }

    /// <summary>
    /// Add a new method to be called when the onconnection request is submitted
    /// </summary>
    /// <param name="activity">Handler invoked with each newly opened connection.</param>
    public void RegisterOnConnectionTask(OnConnectionHandler activity)
    {
        onConnectionActivities.Add(activity);
    }

    /// <summary>
    /// Handle new connection requests
    /// </summary>
    /// <param name="connectParams">Parameters of the connect request.</param>
    /// <param name="requestContext">Context used to send the result or an error message.</param>
    /// <returns>A task that completes once the result or error has been sent.</returns>
    protected async Task HandleConnectRequest(
        ConnectParams connectParams,
        RequestContext<ConnectResponse> requestContext)
    {
        Logger.Write(LogLevel.Verbose, "HandleConnectRequest");

        try
        {
            // open connection on the instance whose handler was registered
            ConnectResponse result = this.Connect(connectParams);
            await requestContext.SendResult(result);
        }
        catch (Exception ex)
        {
            await requestContext.SendError(ex.Message);
        }
    }

    /// <summary>
    /// Configuration change callback; currently performs no work.
    /// </summary>
    public Task HandleDidChangeConfigurationNotification(
        SqlToolsSettings newSettings,
        SqlToolsSettings oldSettings,
        EventContext eventContext)
    {
        return Task.FromResult(true);
    }

    /// <summary>
    /// Build a connection string from a connection details instance
    /// </summary>
    /// <param name="connectionDetails">Server, database and SQL-login credentials.</param>
    /// <returns>The connection string produced by <see cref="SqlConnectionStringBuilder"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="connectionDetails"/> is null.</exception>
    public static string BuildConnectionString(ConnectionDetails connectionDetails)
    {
        if (connectionDetails == null)
        {
            throw new ArgumentNullException("connectionDetails");
        }

        // The builder escapes each value, so user-supplied text cannot inject extra keywords.
        SqlConnectionStringBuilder connectionBuilder = new SqlConnectionStringBuilder();
        connectionBuilder["Data Source"] = connectionDetails.ServerName;
        // Integrated Security is always off: authentication uses the user id/password below.
        connectionBuilder["Integrated Security"] = false;
        connectionBuilder["User Id"] = connectionDetails.UserName;
        connectionBuilder["Password"] = connectionDetails.Password;
        connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName;
        return connectionBuilder.ToString();
    }
}
}