//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts;
using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting;
using Microsoft.SqlTools.Extensibility;
using Microsoft.SqlTools.Utility;
using System.Linq;

namespace Microsoft.SqlTools.ServiceLayer.TaskServices
{
    /// <summary>
    /// Hosted service that exposes the tasks of <see cref="SqlTaskManager"/> to the client:
    /// it answers list/cancel task requests and forwards task creation, status changes and
    /// task messages to the client as notifications.
    /// </summary>
    public class TaskService : HostedService<TaskService>, IComposableService
    {
        private static readonly Lazy<TaskService> instance = new Lazy<TaskService>(() => new TaskService());

        // Task manager whose tasks this service reports on; assigned once, hence readonly.
        private readonly SqlTaskManager taskManager = SqlTaskManager.Instance;

        // Assigned by InitializeService. It is still null when the constructor subscribes to
        // TaskAdded, so every notification path must tolerate a null endpoint.
        private IProtocolEndpoint serviceHost;

        /// <summary>
        /// Default, parameterless constructor. Subscribes to the task manager's TaskAdded event.
        /// </summary>
        public TaskService()
        {
            taskManager.TaskAdded += OnTaskAdded;
        }

        /// <summary>
        /// Gets the singleton instance object (created lazily on first access).
        /// </summary>
        public static TaskService Instance
        {
            get { return instance.Value; }
        }

        /// <summary>
        /// Task Manager Instance to use for testing
        /// </summary>
        internal SqlTaskManager TaskManager
        {
            get { return taskManager; }
        }

        /// <summary>
        /// Initializes the service instance: stores the endpoint used for notifications and
        /// registers the list-tasks and cancel-task request handlers.
        /// </summary>
        /// <param name="serviceHost">Endpoint used to receive requests and send notifications.</param>
        public override void InitializeService(IProtocolEndpoint serviceHost)
        {
            this.serviceHost = serviceHost;
            Logger.Write(LogLevel.Verbose, "TaskService initialized");
            serviceHost.SetRequestHandler(ListTasksRequest.Type, HandleListTasksRequest);
            serviceHost.SetRequestHandler(CancelTaskRequest.Type, HandleCancelTaskRequest);
        }

        /// <summary>
        /// Handles a list tasks request by returning a <see cref="TaskInfo"/> for every task
        /// currently held by the task manager.
        /// </summary>
        /// <param name="listTasksParams">Request parameters; validated to be non-null.</param>
        /// <param name="context">Request context passed to HandleRequestAsync to deliver the result.</param>
        internal async Task HandleListTasksRequest(
            ListTasksParams listTasksParams,
            RequestContext<ListTasksResponse> context)
        {
            Logger.Write(LogLevel.Verbose, "HandleListTasksRequest");

            Func<Task<ListTasksResponse>> getAllTasks = () =>
            {
                Validate.IsNotNull(nameof(listTasksParams), listTasksParams);

                // Task.Run always targets the default scheduler; Task.Factory.StartNew would
                // silently use whatever TaskScheduler.Current is (CA2008).
                return Task.Run(() =>
                {
                    ListTasksResponse response = new ListTasksResponse();
                    response.Tasks = taskManager.Tasks.Select(x => x.ToTaskInfo()).ToArray();
                    return response;
                });
            };

            await HandleRequestAsync(getAllTasks, context, "HandleListTasksRequest");
        }

        /// <summary>
        /// Handles a cancel task request.
        /// </summary>
        /// <param name="cancelTaskParams">Request parameters; validated to be non-null. TaskId is parsed as a GUID.</param>
        /// <param name="context">Request context passed to HandleRequestAsync to deliver the result.</param>
        /// <remarks>
        /// The result is true when TaskId parses as a GUID and the task manager was asked to
        /// cancel that task, and false when TaskId is not a valid GUID.
        /// </remarks>
        internal async Task HandleCancelTaskRequest(CancelTaskParams cancelTaskParams, RequestContext<bool> context)
        {
            Logger.Write(LogLevel.Verbose, "HandleCancelTaskRequest");

            Func<Task<bool>> cancelTask = () =>
            {
                Validate.IsNotNull(nameof(cancelTaskParams), cancelTaskParams);

                return Task.Run(() =>
                {
                    Guid taskId;
                    if (!Guid.TryParse(cancelTaskParams.TaskId, out taskId))
                    {
                        return false;
                    }

                    taskManager.CancelTask(taskId);
                    return true;
                });
            };

            await HandleRequestAsync(cancelTask, context, "HandleCancelTaskRequest");
        }

        /// <summary>
        /// Called when the task manager adds a task: subscribes to the task's message and status
        /// events, then sends a TaskCreatedNotification to the client.
        /// </summary>
        /// <remarks>
        /// async void is dictated by the event-handler signature. An exception escaping it cannot be
        /// observed by the event raiser and is rethrown on the thread pool, so failures are logged
        /// here instead.
        /// </remarks>
        private async void OnTaskAdded(object sender, TaskEventArgs<SqlTask> e)
        {
            try
            {
                SqlTask sqlTask = e.TaskData;
                if (sqlTask == null)
                {
                    return;
                }

                TaskInfo taskInfo = sqlTask.ToTaskInfo();
                sqlTask.MessageAdded += OnTaskMessageAdded;
                sqlTask.StatusChanged += OnTaskStatusChanged;

                // Tasks can be added before InitializeService runs; there is no client to notify yet.
                IProtocolEndpoint host = serviceHost;
                if (host != null)
                {
                    await host.SendEvent(TaskCreatedNotification.Type, taskInfo);
                }
            }
            catch (Exception ex)
            {
                Logger.Write(LogLevel.Error, $"TaskService failed to handle a TaskAdded event: {ex}");
            }
        }

        /// <summary>
        /// Called when a task's status changes: sends a TaskStatusChangedNotification carrying the
        /// new status, plus the task duration once the task has completed.
        /// </summary>
        /// <remarks>
        /// async void event handler; failures are logged rather than propagated (see OnTaskAdded).
        /// </remarks>
        private async void OnTaskStatusChanged(object sender, TaskEventArgs<SqlTaskStatus> e)
        {
            try
            {
                SqlTask sqlTask = e.SqlTask;
                if (sqlTask == null)
                {
                    return;
                }

                TaskProgressInfo progressInfo = new TaskProgressInfo
                {
                    TaskId = sqlTask.TaskId.ToString(),
                    Status = e.TaskData
                };

                if (sqlTask.IsCompleted)
                {
                    progressInfo.Duration = sqlTask.Duration;
                }

                IProtocolEndpoint host = serviceHost;
                if (host != null)
                {
                    await host.SendEvent(TaskStatusChangedNotification.Type, progressInfo);
                }
            }
            catch (Exception ex)
            {
                Logger.Write(LogLevel.Error, $"TaskService failed to handle a task StatusChanged event: {ex}");
            }
        }

        /// <summary>
        /// Called when a task reports a message: sends a TaskStatusChangedNotification carrying the
        /// message description and the task's current status.
        /// </summary>
        /// <remarks>
        /// async void event handler; failures are logged rather than propagated (see OnTaskAdded).
        /// </remarks>
        private async void OnTaskMessageAdded(object sender, TaskEventArgs<TaskMessage> e)
        {
            try
            {
                SqlTask sqlTask = e.SqlTask;
                if (sqlTask == null)
                {
                    return;
                }

                TaskProgressInfo progressInfo = new TaskProgressInfo
                {
                    TaskId = sqlTask.TaskId.ToString(),
                    Message = e.TaskData.Description,
                    Status = sqlTask.TaskStatus
                };

                IProtocolEndpoint host = serviceHost;
                if (host != null)
                {
                    await host.SendEvent(TaskStatusChangedNotification.Type, progressInfo);
                }
            }
            catch (Exception ex)
            {
                Logger.Write(LogLevel.Error, $"TaskService failed to handle a task MessageAdded event: {ex}");
            }
        }

        /// <summary>
        /// Unsubscribes from the task manager's TaskAdded event and disposes the task manager.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the task manager comes from SqlTaskManager.Instance; confirm no other
        /// component still uses it after this service is disposed.
        /// </remarks>
        public void Dispose()
        {
            taskManager.TaskAdded -= OnTaskAdded;
            taskManager.Dispose();
        }
    }
}