//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.IO;
using System.Collections.Concurrent;
using System.Collections.Specialized;
using System.Data.SqlClient;
using System.Diagnostics;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Metadata.Contracts;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Sdk.Sfc;

namespace Microsoft.SqlTools.ServiceLayer.Scripting
{
    /// <summary>
    /// Main class for Scripting Service functionality. Handles the scripting,
    /// scripting-cancel and list-objects requests and tracks the operations they start.
    /// </summary>
    public sealed class ScriptingService : IDisposable
    {
        /// <summary>
        /// Binding timeout passed to the connection queue for SELECT scripting operations.
        /// NOTE(review): presumably milliseconds — confirm against the binding queue.
        /// </summary>
        private const int ScriptingOperationTimeout = 60000;

        private static readonly Lazy<ScriptingService> LazyInstance = new Lazy<ScriptingService>(() => new ScriptingService());

        /// <summary>
        /// Gets the lazily created singleton instance of the service.
        /// </summary>
        public static ScriptingService Instance => LazyInstance.Value;

        private static ConnectionService connectionService = null;

        private readonly Lazy<ConcurrentDictionary<string, ScriptingOperation>> operations =
            new Lazy<ConcurrentDictionary<string, ScriptingOperation>>(() => new ConcurrentDictionary<string, ScriptingOperation>());

        private bool disposed;

        /// <summary>
        /// Internal for testing purposes only. Defaults to <see cref="ConnectionService.Instance"/>
        /// the first time it is read, unless a test has assigned a replacement.
        /// </summary>
        internal static ConnectionService ConnectionServiceInstance
        {
            get
            {
                if (connectionService == null)
                {
                    connectionService = ConnectionService.Instance;
                }
                return connectionService;
            }
            set
            {
                connectionService = value;
            }
        }

        /// <summary>
        /// The collection of active operations, keyed by operation id. Entries are added by
        /// <see cref="RunTask{T}"/> and removed when the operation finishes or is cancelled.
        /// </summary>
        internal ConcurrentDictionary<string, ScriptingOperation> ActiveOperations => operations.Value;

        /// <summary>
        /// Initializes the Scripting Service instance by registering its request handlers
        /// and a shutdown task that disposes the service.
        /// </summary>
        /// <param name="serviceHost">The host to register the handlers with.</param>
        public void InitializeService(ServiceHost serviceHost)
        {
            serviceHost.SetRequestHandler(ScriptingRequest.Type, this.HandleScriptExecuteRequest);
            serviceHost.SetRequestHandler(ScriptingCancelRequest.Type, this.HandleScriptCancelRequest);
            serviceHost.SetRequestHandler(ScriptingListObjectsRequest.Type, this.HandleListObjectsRequest);

            // Register handler for shutdown event
            serviceHost.RegisterShutdownTask((shutdownParams, requestContext) =>
            {
                this.Dispose();
                return Task.FromResult(0);
            });
        }

        /// <summary>
        /// Handles request to start the list objects operation. The operation runs in the
        /// background; this request is answered immediately with the operation id, and the
        /// result is delivered later through <see cref="ScriptingListObjectsCompleteEvent"/>.
        /// </summary>
        private async Task HandleListObjectsRequest(ScriptingListObjectsParams parameters, RequestContext<ScriptingListObjectsResult> requestContext)
        {
            try
            {
                ScriptingListObjectsOperation operation = new ScriptingListObjectsOperation(parameters);
                operation.CompleteNotification += (sender, e) => requestContext.SendEvent(ScriptingListObjectsCompleteEvent.Type, e);

                RunTask(requestContext, operation);

                await requestContext.SendResult(new ScriptingListObjectsResult { OperationId = operation.OperationId });
            }
            catch (Exception e)
            {
                await requestContext.SendError(e);
            }
        }

        /// <summary>
        /// Handles request to start the scripting operation.
        /// When <c>ScriptOptions.ScriptCreateDrop</c> is "ScriptSelect", a SELECT statement is built
        /// directly; otherwise an SMO-based <see cref="ScriptingScriptOperation"/> is started and
        /// the request is answered when that operation completes.
        /// </summary>
        public async Task HandleScriptExecuteRequest(ScriptingParams parameters, RequestContext<ScriptingResult> requestContext)
        {
            try
            {
                // If a connection string wasn't provided as a parameter, use the owner uri
                // to look up its associated ConnectionInfo and build a connection string from it.
                ConnectionInfo connInfo;
                ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo);
                if (parameters.ConnectionString == null)
                {
                    if (connInfo == null)
                    {
                        throw new InvalidOperationException("Could not find ConnectionInfo");
                    }

                    parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
                }

                // SELECT scripts are built directly; all other scripting operations go through SMO.
                if (parameters.ScriptOptions.ScriptCreateDrop == "ScriptSelect")
                {
                    RunSelectTask(connInfo, parameters, requestContext);
                }
                else
                {
                    ScriptingScriptOperation operation = new ScriptingScriptOperation(parameters);
                    operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e);
                    operation.ProgressNotification += (sender, e) => requestContext.SendEvent(ScriptingProgressNotificationEvent.Type, e);
                    operation.CompleteNotification += (sender, e) => this.SendScriptingCompleteEvent(requestContext, ScriptingCompleteEvent.Type, e, operation, parameters.ScriptDestination);

                    RunTask(requestContext, operation);
                }
            }
            catch (Exception e)
            {
                await requestContext.SendError(e);
            }
        }

        /// <summary>
        /// Handles request to cancel a script operation. Removes the operation from
        /// <see cref="ActiveOperations"/> and cancels it; an unknown id is only logged.
        /// A <see cref="ScriptingCancelResult"/> is sent in both cases.
        /// </summary>
        public async Task HandleScriptCancelRequest(ScriptingCancelParams parameters, RequestContext<ScriptingCancelResult> requestContext)
        {
            try
            {
                ScriptingOperation operation = null;
                if (this.ActiveOperations.TryRemove(parameters.OperationId, out operation))
                {
                    operation.Cancel();
                }
                else
                {
                    // 'operation' is null when TryRemove fails, so log the requested id instead.
                    Logger.Write(LogLevel.Normal, string.Format("Operation {0} was not found", parameters.OperationId));
                }

                await requestContext.SendResult(new ScriptingCancelResult());
            }
            catch (Exception e)
            {
                await requestContext.SendError(e);
            }
        }

        /// <summary>
        /// Sends the completion event for a scripting operation, then answers the original
        /// request according to <paramref name="scriptDestination"/>: "ToEditor" returns the
        /// script text, "ToSingleFile" returns only the operation id, anything else is an error.
        /// Failures while sending are logged because no caller observes the returned task.
        /// </summary>
        private async Task SendScriptingCompleteEvent<TParams>(RequestContext<ScriptingResult> requestContext, EventType<TParams> eventType, TParams parameters,
                                                               ScriptingScriptOperation operation, string scriptDestination)
        {
            try
            {
                await requestContext.SendEvent(eventType, parameters);
                if (scriptDestination == "ToEditor")
                {
                    await requestContext.SendResult(new ScriptingResult { OperationId = operation.OperationId, Script = operation.ScriptText });
                }
                else if (scriptDestination == "ToSingleFile")
                {
                    await requestContext.SendResult(new ScriptingResult { OperationId = operation.OperationId });
                }
                else
                {
                    await requestContext.SendError(string.Format("Operation {0} failed", operation.OperationId));
                }
            }
            catch (Exception e)
            {
                Logger.Write(LogLevel.Error, string.Format("Failed to send scripting completion for operation {0}: {1}", operation.OperationId, e));
            }
        }

        /// <summary>
        /// Escapes a value for use inside a single-quoted URN filter by doubling embedded
        /// single quotes. A null value is returned unchanged.
        /// </summary>
        private static string EscapeUrnValue(string value)
        {
            return value?.Replace("'", "''");
        }

        /// <summary>
        /// Builds the SMO URN that identifies <paramref name="scriptingObject"/> in the database
        /// named by the connection string's initial catalog.
        /// </summary>
        private Urn BuildScriptingObjectUrn(
            Server server,
            SqlConnectionStringBuilder connectionStringBuilder,
            ScriptingObject scriptingObject)
        {
            string serverName = server.Name.ToUpperInvariant();

            // remove the port from server name if specified
            int commaPos = serverName.IndexOf(',');
            if (commaPos >= 0)
            {
                serverName = serverName.Substring(0, commaPos);
            }

            // build the URN; quoted values are escaped so names containing ' stay well-formed
            string urnString = string.Format(
                "Server[@Name='{0}']/Database[@Name='{1}']/{2}[@Name='{3}' {4}]",
                EscapeUrnValue(serverName),
                EscapeUrnValue(connectionStringBuilder.InitialCatalog),
                scriptingObject.Type,
                EscapeUrnValue(scriptingObject.Name),
                scriptingObject.Schema != null
                    ? string.Format("and @Schema = '{0}'", EscapeUrnValue(scriptingObject.Schema))
                    : string.Empty);

            return new Urn(urnString);
        }

        /// <summary>
        /// Queues a binding operation that builds a SELECT script for the first scripting
        /// object in <paramref name="parameters"/> and sends it (or the error) on
        /// <paramref name="requestContext"/>.
        /// </summary>
        private void RunSelectTask(ConnectionInfo connInfo, ScriptingParams parameters, RequestContext<ScriptingResult> requestContext)
        {
            ConnectionServiceInstance.ConnectionQueue.QueueBindingOperation(
                key: ConnectionServiceInstance.ConnectionQueue.AddConnectionContext(connInfo),
                bindingTimeout: ScriptingOperationTimeout,
                bindOperation: (bindingContext, cancelToken) =>
                {
                    string script = string.Empty;
                    ScriptingObject scriptingObject = parameters.ScriptingObjects[0];
                    try
                    {
                        Server server = new Server(bindingContext.ServerConnection);
                        server.DefaultTextMode = true;

                        // build object URN
                        SqlConnectionStringBuilder connectionStringBuilder = new SqlConnectionStringBuilder(parameters.ConnectionString);
                        Urn objectUrn = BuildScriptingObjectUrn(server, connectionStringBuilder, scriptingObject);
                        string typeName = objectUrn.GetNameForType(scriptingObject.Type);

                        if (string.Equals(typeName, "ServiceBroker", StringComparison.OrdinalIgnoreCase))
                        {
                            // select from service broker
                            script = Scripter.SelectAllValuesFromTransmissionQueue(objectUrn);
                        }
                        else if (string.Equals(typeName, "Queues", StringComparison.OrdinalIgnoreCase)
                              || string.Equals(typeName, "SystemQueues", StringComparison.OrdinalIgnoreCase))
                        {
                            // select from queues
                            script = Scripter.SelectAllValues(objectUrn);
                        }
                        else
                        {
                            // select from table or view
                            Database db = server.Databases[connectionStringBuilder.InitialCatalog];
                            bool isDw = db.IsSqlDw;
                            script = new Scripter().SelectFromTableOrView(server, objectUrn, isDw);
                        }

                        // send script result to client (fire-and-forget: this callback is synchronous)
                        requestContext.SendResult(new ScriptingResult { Script = script });
                    }
                    catch (Exception e)
                    {
                        requestContext.SendError(e);
                    }

                    return null;
                });
        }

        /// <summary>
        /// Starts <paramref name="operation"/> on a background task. The operation is registered
        /// in <see cref="ActiveOperations"/> before the task is scheduled so that a cancel request
        /// can find it immediately, and is removed once execution ends. Exceptions thrown by the
        /// operation are reported on <paramref name="context"/>.
        /// </summary>
        private void RunTask<T>(RequestContext<T> context, ScriptingOperation operation)
        {
            this.ActiveOperations[operation.OperationId] = operation;

            Task.Run(async () =>
            {
                try
                {
                    operation.Execute();
                }
                catch (Exception e)
                {
                    await context.SendError(e);
                }
                finally
                {
                    ScriptingOperation temp;
                    this.ActiveOperations.TryRemove(operation.OperationId, out temp);
                }
            });
        }

        /// <summary>
        /// Disposes the scripting service and all active scripting operations.
        /// Only the first call has an effect.
        /// </summary>
        public void Dispose()
        {
            if (!disposed)
            {
                disposed = true;

                // ActiveOperations also holds non-script operations (e.g. list-objects), so don't
                // cast every entry to ScriptingScriptOperation — that throws InvalidCastException.
                foreach (ScriptingOperation operation in this.ActiveOperations.Values)
                {
                    ScriptingScriptOperation scriptOperation = operation as ScriptingScriptOperation;
                    if (scriptOperation != null)
                    {
                        scriptOperation.Dispose();
                    }
                    else
                    {
                        (operation as IDisposable)?.Dispose();
                    }
                }
            }
        }
    }
}