From b0263f8867e2bc166c137abe68e50035146c10c6 Mon Sep 17 00:00:00 2001 From: Leila Lali Date: Mon, 12 Jun 2017 11:02:57 -0700 Subject: [PATCH] Added task service (#374) * Added task service --- .../Hosting/IHostedService.cs | 19 + .../ObjectExplorer/ObjectExplorerService.cs | 17 - .../Contracts/ListTasksRequest.cs | 8 +- .../Contracts/TaskCancelRequest.cs | 24 + .../TaskServices/Contracts/TaskInfo.cs | 48 +- .../Contracts/TaskNotifications.cs | 29 ++ .../TaskServices/Contracts/TaskProgress.cs | 31 ++ .../TaskServices/SqlTask.cs | 413 ++++++++++++++++++ .../TaskServices/SqlTaskManager.cs | 210 +++++++++ .../TaskServices/TaskEventArgs.cs | 40 ++ .../TaskServices/TaskMessage.cs | 14 + .../TaskServices/TaskMetadata.cs | 41 ++ .../TaskServices/TaskResult.cs | 14 + .../TaskServices/TaskService.cs | 128 +++++- .../TaskServices/TaskStatus.cs | 18 + .../ObjectExplorerServiceTests.cs | 25 -- .../ObjectExplorer/ObjectExplorerTestBase.cs | 17 +- .../ServiceTestBase.cs | 71 +++ .../TaskServices/DatabaseOperationStub.cs | 61 +++ .../TaskServices/SqlTaskTests.cs | 139 ++++++ .../TaskServices/TaskManagerTests.cs | 86 ++++ .../TaskServices/TaskServiceTests.cs | 131 ++++++ 22 files changed, 1501 insertions(+), 83 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskCancelRequest.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskNotifications.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskProgress.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskEventArgs.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMessage.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskResult.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskStatus.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceTestBase.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/DatabaseOperationStub.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/SqlTaskTests.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskManagerTests.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs diff --git a/src/Microsoft.SqlTools.Hosting/Hosting/IHostedService.cs b/src/Microsoft.SqlTools.Hosting/Hosting/IHostedService.cs index 25443404..61b73715 100644 --- a/src/Microsoft.SqlTools.Hosting/Hosting/IHostedService.cs +++ b/src/Microsoft.SqlTools.Hosting/Hosting/IHostedService.cs @@ -4,8 +4,10 @@ // using System; +using System.Threading.Tasks; using Microsoft.SqlTools.Extensibility; using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.Hosting { @@ -65,6 +67,23 @@ namespace Microsoft.SqlTools.Hosting } } + protected async Task HandleRequestAsync(Func> handler, RequestContext requestContext, string requestType) + { + Logger.Write(LogLevel.Verbose, requestType); + + try + { + T result = await handler(); + await requestContext.SendResult(result); + return result; + } + catch (Exception ex) + { + await requestContext.SendError(ex.ToString()); + } + return default(T); + } + public abstract void InitializeService(IProtocolEndpoint serviceHost); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs index 49ea5721..4b378392 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs @@ -480,23 +480,6 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer return new ExpandResponse() { SessionId = session.Uri, NodePath = expandParams.NodePath }; } - private async Task HandleRequestAsync(Func> handler, RequestContext requestContext, string requestType) - { - Logger.Write(LogLevel.Verbose, requestType); - - try - { - T result = await handler(); - await requestContext.SendResult(result); - return result; - } - catch (Exception ex) - { - await requestContext.SendError(ex.ToString()); - } - return default(T); - } - /// /// Generates a URI for object explorer using a similar pattern to Mongo DB (which has URI-based database definition) /// as this should ensure uniqueness diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/ListTasksRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/ListTasksRequest.cs index 5bcb0e07..fffa1d7d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/ListTasksRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/ListTasksRequest.cs @@ -4,22 +4,18 @@ // using Microsoft.SqlTools.Hosting.Protocol.Contracts; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; namespace Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts { public class ListTasksParams { - bool ListActiveTasksOnly { get; set; } + bool ListActiveTasksOnly { get; set; } } public class ListTasksResponse { - TaskInfo[] Tasks { get; set; } + public TaskInfo[] Tasks { get; set; } } public class ListTasksRequest diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskCancelRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskCancelRequest.cs new file mode 100644 index 00000000..a9e2603e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskCancelRequest.cs @@ -0,0 +1,24 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts +{ + public class CancelTaskParams + { + /// + /// An id to unify the task + /// + public string TaskId { get; set; } + } + + public class CancelTaskRequest + { + public static readonly + RequestType Type = + RequestType.Create("tasks/canceltask"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskInfo.cs index 47f2a59b..2c1a4bea 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskInfo.cs @@ -1,21 +1,43 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + namespace Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts { - public enum TaskState - { - NotStarted = 0, - Running = 1, - Complete = 2 - } - public class TaskInfo { - public int TaskId { get; set; } + /// + /// An id to unify the task + /// + public string TaskId { get; set; } + + /// + /// Task status + /// + public SqlTaskStatus Status { get; set; } + + /// + /// Database server name this task is created for + /// + public string ServerName { get; set; } + + /// + /// Database name this task is created for + /// + public string DatabaseName { get; set; } + + + /// + /// Task name which defines the type of the task (e.g. CreateDatabase, Backup) + /// + public string Name { get; set; } + + /// + /// Task description + /// + public string Description { get; set; } - public TaskState State { get; set; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskNotifications.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskNotifications.cs new file mode 100644 index 00000000..a900093e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskNotifications.cs @@ -0,0 +1,29 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts +{ + /// + /// Expand notification mapping entry + /// + public class TaskCreatedNotification + { + public static readonly + EventType Type = + EventType.Create("task/newtaskcreated"); + } + + /// + /// Expand notification mapping entry + /// + public class TaskStatusChangedNotification + { + public static readonly + EventType Type = + EventType.Create("task/statuschanged"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskProgress.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskProgress.cs new file mode 100644 index 00000000..26648903 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/Contracts/TaskProgress.cs @@ -0,0 +1,31 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts +{ + public class TaskProgressInfo + { + /// + /// An id to unify the task + /// + public string TaskId { get; set; } + + /// + /// Task status + /// + public SqlTaskStatus Status { get; set; } + + /// + /// Database server name this task is created for + /// + public string Message { get; set; } + + /// + /// The number of millisecond the task was running + /// + public double Duration { get; set; } + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs new file mode 100644 index 00000000..0188ebd2 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs @@ -0,0 +1,413 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts; +using Microsoft.SqlTools.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices +{ + /// + /// A wrapper to a long running database operation. The class holds a refrence to the actual task that's running + /// and keeps track of the task status to send notifications + /// + public class SqlTask : IDisposable + { + private bool isCompleted; + private bool isCanceled; + private bool isDisposed; + private readonly object lockObject = new object(); + private readonly List messages = new List(); + + private DateTime startTime; + private SqlTaskStatus status = SqlTaskStatus.NotStarted; + private DateTime stopTime; + + public event EventHandler> MessageAdded; + public event EventHandler> StatusChanged; + public event EventHandler> TaskCanceled; + + /// + /// Creates new instance of SQL task + /// + /// Task Metadata + /// The function to run to start the task + public SqlTask(TaskMetadata taskMetdata, Func> testToRun) + { + Validate.IsNotNull(nameof(taskMetdata), taskMetdata); + Validate.IsNotNull(nameof(testToRun), testToRun); + + TaskMetadata = taskMetdata; + TaskToRun = testToRun; + StartTime = DateTime.UtcNow; + TaskId = Guid.NewGuid(); + } + + /// + /// Task Metadata + /// + internal TaskMetadata TaskMetadata { get; private set; } + + /// + /// The function to run + /// + private Func> TaskToRun + { + get; + set; + } + + /// + /// Task unique id + /// + public Guid TaskId { get; private set; } + + /// + /// Starts the task and monitor the task progress + /// + public async Task Run() + { + TaskStatus = SqlTaskStatus.InProgress; + await TaskToRun(this).ContinueWith(task => + { + if (task.IsCompleted) + { + TaskResult taskResult = task.Result; + TaskStatus = taskResult.TaskStatus; + } + else if(task.IsCanceled) + { + TaskStatus = SqlTaskStatus.Canceled; + } + else if(task.IsFaulted) + { + TaskStatus = SqlTaskStatus.Failed; + if(task.Exception != null) + { + AddMessage(task.Exception.Message); + } + } + }); + } + + /// + /// Returns true if task has any messages + /// + public bool HasMessages + { + get + { + lock (lockObject) + { + return messages.Any(); + } + } + } + + /// + /// Setting this to True will not change the Slot status. + /// Setting the Slot status to Canceled will set this to true. + /// + public bool IsCanceled + { + get + { + return isCanceled; + } + private set + { + if (isCanceled != value) + { + isCanceled = value; + OnTaskCanceled(); + } + } + } + + /// + /// Returns true if task is canceled, failed or succeed + /// + public bool IsCompleted + { + get + { + return isCompleted; + } + private set + { + if (isCompleted != value) + { + isCompleted = value; + if (isCompleted) + { + StopTime = DateTime.UtcNow; + } + } + } + } + + /// + /// Task Messages + /// + internal ReadOnlyCollection Messages + { + get + { + lock (lockObject) + { + return messages.AsReadOnly(); + } + } + } + + /// + /// Start Time + /// + public DateTime StartTime + { + get + { + return startTime; + } + internal set + { + startTime = value; + } + } + + /// + /// The total number of seconds to run the task + /// + public double Duration + { + get + { + return (stopTime - startTime).TotalMilliseconds; + } + } + + + /// + /// Task Status + /// + public SqlTaskStatus TaskStatus + { + get + { + return status; + } + private set + { + status = value; + switch (status) + { + case SqlTaskStatus.Canceled: + case SqlTaskStatus.Failed: + case SqlTaskStatus.Succeeded: + case SqlTaskStatus.SucceededWithWarning: + IsCompleted = true; + break; + case SqlTaskStatus.InProgress: + case SqlTaskStatus.NotStarted: + IsCompleted = false; + break; + default: + throw new NotSupportedException("IsCompleted is not determined for status: " + status); + } + + if (status == SqlTaskStatus.Canceled) + { + IsCanceled = true; + } + + OnStatusChanged(); + } + } + + /// + /// The date time that the task was complete + /// + public DateTime StopTime + { + get + { + return stopTime; + } + internal set + { + stopTime = value; + } + } + + /// + /// Try to cancel the task, and even to cancel the task will be raised + /// but the status won't change until that task actually get canceled by it's owner + /// + public void Cancel() + { + IsCanceled = true; + } + + /// + /// Adds a new message to the task messages + /// + /// Message description + /// Status of the message + /// If true, the new messages will be added to the top. Default is false + /// + public TaskMessage AddMessage(string description, SqlTaskStatus status = SqlTaskStatus.NotStarted, bool insertAboveLast = false) + { + ValidateNotDisposed(); + + if (!insertAboveLast) + { + // Make sure the last message is set to a completed status if a new message is being added at the bottom + CompleteLastMessageStatus(); + } + + var newMessage = new TaskMessage + { + Description = description, + Status = status, + }; + + lock (lockObject) + { + if (!insertAboveLast || messages.Count == 0) + { + messages.Add(newMessage); + } + else + { + int lastMessageIndex = messages.Count - 1; + messages.Insert(lastMessageIndex, newMessage); + } + } + OnMessageAdded(new TaskEventArgs(newMessage, this)); + + // If the slot is completed, this may be the last message, make sure the message is also set to completed. + if (IsCompleted) + { + CompleteLastMessageStatus(); + } + + return newMessage; + } + + /// + /// Converts the task to Task info to be used in the contracts + /// + /// + public TaskInfo ToTaskInfo() + { + return new TaskInfo + { + DatabaseName = TaskMetadata.DatabaseName, + ServerName = TaskMetadata.ServerName, + Name = TaskMetadata.Name, + Description = TaskMetadata.Description, + TaskId = TaskId.ToString() + }; + } + + /// + /// Makes sure the last message has a 'completed' status if it has a status of InProgress. + /// If success is true, then sets the status to Succeeded. Sets it to Failed if success is false. + /// If success is null (default), then the message status is based on the status of the slot. + /// + private void CompleteLastMessageStatus(bool? success = null) + { + var message = GetLastMessage(); + if (message != null) + { + if (message.Status == SqlTaskStatus.InProgress) + { + // infer the success boolean from the slot status if it's not set + if (success == null) + { + switch (TaskStatus) + { + case SqlTaskStatus.Canceled: + case SqlTaskStatus.Failed: + success = false; + break; + default: + success = true; + break; + } + } + + message.Status = success.Value ? SqlTaskStatus.Succeeded : SqlTaskStatus.Failed; + } + } + } + + private void OnMessageAdded(TaskEventArgs e) + { + var handler = MessageAdded; + if (handler != null) + { + handler(this, e); + } + } + + private void OnStatusChanged() + { + var handler = StatusChanged; + if (handler != null) + { + handler(this, new TaskEventArgs(TaskStatus, this)); + } + } + + private void OnTaskCanceled() + { + var handler = TaskCanceled; + if (handler != null) + { + handler(this, new TaskEventArgs(TaskStatus, this)); + } + } + + public void Dispose() + { + //Dispose + isDisposed = true; + } + + + + protected void ValidateNotDisposed() + { + if (isDisposed) + { + throw new ObjectDisposedException(typeof(SqlTask).FullName); + } + } + + /// + /// Returns the most recently created message. Returns null if there are no messages on the slot. + /// + public TaskMessage GetLastMessage() + { + ValidateNotDisposed(); + + lock (lockObject) + { + if (messages.Count > 0) + { + // get + return messages.Last(); + } + } + + return null; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs new file mode 100644 index 00000000..659aaabb --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs @@ -0,0 +1,210 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Concurrent; +using System.Collections.ObjectModel; +using System.Linq; +using System.Threading.Tasks; + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices +{ + /// + /// A singleton class to manager the current long running operations + /// + public class SqlTaskManager : IDisposable + { + private static SqlTaskManager instance = new SqlTaskManager(); + private static readonly object lockObject = new object(); + private bool isDisposed; + private readonly ConcurrentDictionary tasks = new ConcurrentDictionary(); + + public event EventHandler> TaskAdded; + public event EventHandler> TaskRemoved; + + + /// + /// Constructor to create an instance for test purposes use only + /// + internal SqlTaskManager() + { + + } + + /// + /// Singleton instance + /// + public static SqlTaskManager Instance + { + get + { + return instance; + } + } + + /// + /// Task connections + /// + internal ReadOnlyCollection Tasks + { + get + { + lock (lockObject) + { + return new ReadOnlyCollection(tasks.Values.ToList()); + } + } + } + + /// + /// Clear completed tasks + /// + internal void ClearCompletedTasks() + { + ValidateNotDisposed(); + + lock (lockObject) + { + var tasksToRemove = (from task in tasks.Values + where task.IsCompleted + select task).ToList(); + foreach (var task in tasksToRemove) + { + RemoveCompletedTask(task); + } + } + } + + /// + /// Creates a new task + /// + /// Task Metadata + /// The function to run the operation + /// + public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun) + { + ValidateNotDisposed(); + + var newtask = new SqlTask(taskMetadata, taskToRun ); + + lock (lockObject) + { + tasks.AddOrUpdate(newtask.TaskId, newtask, (key, oldValue) => newtask); + } + OnTaskAdded(new TaskEventArgs(newtask)); + return newtask; + } + + public void Dispose() + { + Dispose(true); + } + + void Dispose(bool disposing) + { + if (isDisposed) + { + return; + } + + if (disposing) + { + lock (lockObject) + { + foreach (var task in tasks.Values) + { + task.Dispose(); + } + tasks.Clear(); + } + } + + isDisposed = true; + } + + /// + /// Returns true if there's any completed task + /// + /// + internal bool HasCompletedTasks() + { + lock (lockObject) + { + return tasks.Values.Any(task => task.IsCompleted); + } + } + + private void OnTaskAdded(TaskEventArgs e) + { + var handler = TaskAdded; + if (handler != null) + { + handler(this, e); + } + } + + private void OnTaskRemoved(TaskEventArgs e) + { + var handler = TaskRemoved; + if (handler != null) + { + handler(this, e); + } + } + + /// + /// Cancel a task + /// + /// + public void CancelTask(Guid taskId) + { + SqlTask taskToCancel; + + lock (lockObject) + { + tasks.TryGetValue(taskId, out taskToCancel); + } + if(taskToCancel != null) + { + taskToCancel.Cancel(); + } + } + + /// + /// Internal for test purposes only. + /// Removes all tasks regardless of status from the model. + /// This is used as a test aid since Monitor is a singleton class. + /// + internal void Reset() + { + foreach (var task in tasks.Values) + { + RemoveTask(task); + } + } + + internal void RemoveCompletedTask(SqlTask task) + { + if (task.IsCompleted) + { + RemoveTask(task); + } + } + + private void RemoveTask(SqlTask task) + { + SqlTask removedTask; + tasks.TryRemove(task.TaskId, out removedTask); + } + + void ValidateNotDisposed() + { + if (isDisposed) + { + throw new ObjectDisposedException(typeof(SqlTaskManager).FullName); + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskEventArgs.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskEventArgs.cs new file mode 100644 index 00000000..4775822e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskEventArgs.cs @@ -0,0 +1,40 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using Microsoft.SqlTools.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices +{ + public sealed class TaskEventArgs : EventArgs + { + readonly T taskData; + + public TaskEventArgs(T taskData, SqlTask sqlTask) + { + Validate.IsNotNull(nameof(taskData), taskData); + + this.taskData = taskData; + SqlTask = sqlTask; + } + + + public TaskEventArgs(SqlTask sqlTask) + { + taskData = (T)Convert.ChangeType(sqlTask, typeof(T)); + SqlTask = sqlTask; + } + + public T TaskData + { + get + { + return taskData; + } + } + + public SqlTask SqlTask { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMessage.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMessage.cs new file mode 100644 index 00000000..11ee61e8 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMessage.cs @@ -0,0 +1,14 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices +{ + public class TaskMessage + { + public SqlTaskStatus Status { get; set; } + + public string Description { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs new file mode 100644 index 00000000..2678c12e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs @@ -0,0 +1,41 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices +{ + public class TaskMetadata + { + /// + /// Task Description + /// + public string Description { get; set; } + + /// + /// Task name to define the type of the task e.g. Create Db, back up + /// + public string Name { get; set; } + + /// + /// The number of seconds to wait before canceling the task. + /// This is a optional field and 0 or negative numbers means no timeout + /// + public int Timeout { get; set; } + + /// + /// Defines if the task can be canceled + /// + public bool IsCancelable { get; set; } + + /// + /// Database server name this task is created for + /// + public string ServerName { get; set; } + + /// + /// Database name this task is created for + /// + public string DatabaseName { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskResult.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskResult.cs new file mode 100644 index 00000000..caa608d8 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskResult.cs @@ -0,0 +1,14 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices +{ + public class TaskResult + { + public SqlTaskStatus TaskStatus { get; set; } + + public string ErrorMessage { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs index a0156c20..d8264b81 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs @@ -4,23 +4,29 @@ // using Microsoft.SqlTools.Hosting.Protocol; -using Microsoft.SqlTools.ServiceLayer.Hosting; -using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts; using System; using System.Threading.Tasks; +using Microsoft.SqlTools.Hosting; +using Microsoft.SqlTools.Extensibility; +using Microsoft.SqlTools.Utility; +using System.Linq; namespace Microsoft.SqlTools.ServiceLayer.TaskServices { - public class TaskService + public class TaskService: HostedService, IComposableService { private static readonly Lazy instance = new Lazy(() => new TaskService()); + private SqlTaskManager taskManager = SqlTaskManager.Instance; + private IProtocolEndpoint serviceHost; + /// /// Default, parameterless constructor. /// - internal TaskService() + public TaskService() { + taskManager.TaskAdded += OnTaskAdded; } /// @@ -31,22 +37,128 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices get { return instance.Value; } } + /// + /// Task Manager Instance to use for testing + /// + internal SqlTaskManager TaskManager + { + get + { + return taskManager; + } + } + /// /// Initializes the service instance /// - public void InitializeService(ServiceHost serviceHost, SqlToolsContext context) + public override void InitializeService(IProtocolEndpoint serviceHost) { + this.serviceHost = serviceHost; + Logger.Write(LogLevel.Verbose, "TaskService initialized"); serviceHost.SetRequestHandler(ListTasksRequest.Type, HandleListTasksRequest); + serviceHost.SetRequestHandler(CancelTaskRequest.Type, HandleCancelTaskRequest); } /// /// Handles a list tasks request /// - internal static async Task HandleListTasksRequest( + internal async Task HandleListTasksRequest( ListTasksParams listTasksParams, - RequestContext requestContext) + RequestContext context) { - await requestContext.SendResult(new ListTasksResponse()); + Logger.Write(LogLevel.Verbose, "HandleListTasksRequest"); + + Func> getAllTasks = () => + { + Validate.IsNotNull(nameof(listTasksParams), listTasksParams); + return Task.Factory.StartNew(() => + { + ListTasksResponse response = new ListTasksResponse(); + response.Tasks = taskManager.Tasks.Select(x => x.ToTaskInfo()).ToArray(); + + return response; + }); + + }; + + await HandleRequestAsync(getAllTasks, context, "HandleListTasksRequest"); + } + + internal async Task HandleCancelTaskRequest(CancelTaskParams cancelTaskParams, RequestContext context) + { + Logger.Write(LogLevel.Verbose, "HandleCancelTaskRequest"); + Func> cancelTask = () => + { + Validate.IsNotNull(nameof(cancelTaskParams), cancelTaskParams); + + return Task.Factory.StartNew(() => + { + Guid taskId; + if (Guid.TryParse(cancelTaskParams.TaskId, out taskId)) + { + taskManager.CancelTask(taskId); + return true; + } + else + { + return false; + } + }); + + }; + + await HandleRequestAsync(cancelTask, context, "HandleCancelTaskRequest"); + } + + private async void OnTaskAdded(object sender, TaskEventArgs e) + { + SqlTask sqlTask = e.TaskData; + if (sqlTask != null) + { + TaskInfo taskInfo = sqlTask.ToTaskInfo(); + sqlTask.MessageAdded += OnTaskMessageAdded; + sqlTask.StatusChanged += OnTaskStatusChanged; + await serviceHost.SendEvent(TaskCreatedNotification.Type, taskInfo); + } + } + + private async void OnTaskStatusChanged(object sender, TaskEventArgs e) + { + SqlTask sqlTask = e.SqlTask; + if (sqlTask != null) + { + TaskProgressInfo progressInfo = new TaskProgressInfo + { + TaskId = sqlTask.TaskId.ToString(), + Status = e.TaskData + + }; + if (sqlTask.IsCompleted) + { + progressInfo.Duration = sqlTask.Duration; + } + await serviceHost.SendEvent(TaskStatusChangedNotification.Type, progressInfo); + } + } + + private async void OnTaskMessageAdded(object sender, TaskEventArgs e) + { + SqlTask sqlTask = e.SqlTask; + if (sqlTask != null) + { + TaskProgressInfo progressInfo = new TaskProgressInfo + { + TaskId = sqlTask.TaskId.ToString(), + Message = e.TaskData.Description + }; + await serviceHost.SendEvent(TaskStatusChangedNotification.Type, progressInfo); + } + } + + public void Dispose() + { + taskManager.TaskAdded -= OnTaskAdded; + taskManager.Dispose(); } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskStatus.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskStatus.cs new file mode 100644 index 00000000..ee697c5c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskStatus.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices +{ + public enum SqlTaskStatus + { + NotStarted, + InProgress, + Succeeded, + SucceededWithWarning, + Failed, + Canceled + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs index fc8bd2bd..da386f92 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs @@ -396,30 +396,5 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer ServerInfo = TestObjects.GetTestServerInfo() }; } - - private async Task RunAndVerify(Func, Task> test, Action verify) - { - T result = default(T); - var contextMock = RequestContextMocks.Create(r => result = r).AddErrorHandling(null); - TResult actualResult = await test(contextMock.Object); - if (actualResult == null && typeof(TResult) == typeof(T)) - { - actualResult = (TResult)Convert.ChangeType(result, typeof(TResult)); - } - VerifyResult(contextMock, verify, actualResult); - } - - private void VerifyResult(Mock> contextMock, Action verify, TResult actual) - { - contextMock.Verify(c => c.SendResult(It.IsAny()), Times.Once); - contextMock.Verify(c => c.SendError(It.IsAny(), It.IsAny()), Times.Never); - verify(actual); - } - - private void VerifyErrorSent(Mock> contextMock) - { - contextMock.Verify(c => c.SendResult(It.IsAny()), Times.Never); - contextMock.Verify(c => c.SendError(It.IsAny(), It.IsAny()), Times.Once); - } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerTestBase.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerTestBase.cs index bf79a129..3d7d9cc0 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerTestBase.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerTestBase.cs @@ -12,27 +12,16 @@ using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel; namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer { // Base class providing common test functionality for OE tests - public abstract class ObjectExplorerTestBase + public abstract class ObjectExplorerTestBase : ServiceTestBase { - protected RegisteredServiceProvider ServiceProvider - { - get; - set; - } - - protected RegisteredServiceProvider CreateServiceProviderWithMinServices() + + protected override RegisteredServiceProvider CreateServiceProviderWithMinServices() { return CreateProvider() .RegisterSingleService(new ConnectionService()) .RegisterSingleService(new ObjectExplorerService()); } - protected RegisteredServiceProvider CreateProvider() - { - ServiceProvider = new RegisteredServiceProvider(); - return ServiceProvider; - } - protected ObjectExplorerService CreateOEService(ConnectionService connService) { CreateProvider() diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceTestBase.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceTestBase.cs new file mode 100644 index 00000000..14971dce --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceTestBase.cs @@ -0,0 +1,71 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.Extensibility; +using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; +using Moq; + +namespace Microsoft.SqlTools.ServiceLayer.UnitTests +{ + public abstract class ServiceTestBase + { + protected RegisteredServiceProvider ServiceProvider + { + get; + set; + } + + protected RegisteredServiceProvider CreateProvider() + { + ServiceProvider = new RegisteredServiceProvider(); + return ServiceProvider; + } + + protected abstract RegisteredServiceProvider CreateServiceProviderWithMinServices(); + + protected async Task RunAndVerify(Func, Task> test, Action verify) + { + T result = default(T); + var contextMock = RequestContextMocks.Create(r => result = r).AddErrorHandling(null); + TResult actualResult = await test(contextMock.Object); + if (actualResult == null && typeof(TResult) == typeof(T)) + { + actualResult = (TResult)Convert.ChangeType(result, typeof(TResult)); + } + VerifyResult(contextMock, verify, actualResult); + } + + protected async Task RunAndVerify(Func, Task> test, Action verify) + { + T result = default(T); + var contextMock = RequestContextMocks.Create(r => result = r).AddErrorHandling(null); + await test(contextMock.Object); + VerifyResult(contextMock, verify, result); + } + + protected void VerifyResult(Mock> contextMock, Action verify, TResult actual) + { + contextMock.Verify(c => c.SendResult(It.IsAny()), Times.Once); + contextMock.Verify(c => c.SendError(It.IsAny(), It.IsAny()), Times.Never); + verify(actual); + } + + protected void VerifyResult(Mock> contextMock, Action verify, T actual) + { + contextMock.Verify(c => c.SendResult(It.IsAny()), Times.Once); + contextMock.Verify(c => c.SendError(It.IsAny(), It.IsAny()), Times.Never); + verify(actual); + } + + protected void VerifyErrorSent(Mock> contextMock) + { + contextMock.Verify(c => c.SendResult(It.IsAny()), Times.Never); + contextMock.Verify(c => c.SendError(It.IsAny(), It.IsAny()), Times.Once); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/DatabaseOperationStub.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/DatabaseOperationStub.cs new file mode 100644 index 00000000..6acb0926 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/DatabaseOperationStub.cs @@ -0,0 +1,61 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.TaskServices; + +namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices +{ + public class DatabaseOperationStub + { + private CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + public void Stop() + { + IsStopped = true; + } + + public void FailTheOperation() + { + Failed = true; + } + + public TaskResult TaskResult { get; set; } + + public bool IsStopped { get; set; } + + public bool Failed { get; set; } + + public async Task FunctionToRun(SqlTask sqlTask) + { + sqlTask.TaskCanceled += OnTaskCanceled; + return await Task.Factory.StartNew(() => + { + while (!IsStopped) + { + //Just keep running + if (cancellationTokenSource.Token.IsCancellationRequested) + { + throw new OperationCanceledException(); + } + if (Failed) + { + throw new InvalidOperationException(); + } + sqlTask.AddMessage("still running", SqlTaskStatus.InProgress, true); + } + sqlTask.AddMessage("done!", SqlTaskStatus.Succeeded); + + return TaskResult; + }); + } + + private void OnTaskCanceled(object sender, TaskEventArgs e) + { + cancellationTokenSource.Cancel(); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/SqlTaskTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/SqlTaskTests.cs new file mode 100644 index 00000000..3b7ccfbe --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/SqlTaskTests.cs @@ -0,0 +1,139 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using Microsoft.SqlTools.ServiceLayer.TaskServices; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices +{ + public class SqlTaskTests + { + [Fact] + public void CreateSqlTaskGivenInvalidArgumentShouldThrowException() + { + Assert.Throws(() => new SqlTask(null, new DatabaseOperationStub().FunctionToRun)); + Assert.Throws(() => new SqlTask(new TaskMetadata(), null)); + } + + [Fact] + public void CreateSqlTaskShouldGenerateANewId() + { + SqlTask sqlTask = new SqlTask(new TaskMetadata(), new DatabaseOperationStub().FunctionToRun); + Assert.NotNull(sqlTask.TaskId); + Assert.True(sqlTask.TaskId != Guid.Empty); + + SqlTask sqlTask2 = new SqlTask(new TaskMetadata(), new DatabaseOperationStub().FunctionToRun); + Assert.False(sqlTask.TaskId.CompareTo(sqlTask2.TaskId) == 0); + } + + [Fact] + public void RunShouldRunTheFunctionAndGetTheResult() + { + SqlTaskStatus expectedStatus = SqlTaskStatus.Succeeded; + DatabaseOperationStub operation = new DatabaseOperationStub(); + operation.TaskResult = new TaskResult + { + TaskStatus = expectedStatus + }; + SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun); + Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.NotStarted); + + sqlTask.Run().ContinueWith(task => { + Assert.Equal(sqlTask.TaskStatus, expectedStatus); + Assert.Equal(sqlTask.IsCompleted, true); + Assert.True(sqlTask.Duration > 0); + }); + Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.InProgress); + operation.Stop(); + } + + [Fact] + public void ToTaskInfoShouldReturnTaskInfo() + { + SqlTaskStatus expectedStatus = SqlTaskStatus.Succeeded; + DatabaseOperationStub operation = new DatabaseOperationStub(); + operation.TaskResult = new TaskResult + { + TaskStatus = expectedStatus + }; + SqlTask sqlTask = new SqlTask(new TaskMetadata + { + ServerName = "server name", + DatabaseName = "database name" + }, operation.FunctionToRun); + + sqlTask.Run().ContinueWith(task => + { + var taskInfo = sqlTask.ToTaskInfo(); + Assert.Equal(taskInfo.TaskId, sqlTask.TaskId.ToString()); + Assert.Equal(taskInfo.ServerName, "server name"); + Assert.Equal(taskInfo.DatabaseName, "database name"); + }); + operation.Stop(); + } + + [Fact] + public void FailedOperationShouldReturnTheFailedResult() + { + SqlTaskStatus expectedStatus = SqlTaskStatus.Failed; + DatabaseOperationStub operation = new DatabaseOperationStub(); + operation.TaskResult = new TaskResult + { + TaskStatus = expectedStatus + }; + SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun); + Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.NotStarted); + + sqlTask.Run().ContinueWith(task => { + Assert.Equal(sqlTask.TaskStatus, expectedStatus); + Assert.Equal(sqlTask.IsCompleted, true); + Assert.True(sqlTask.Duration > 0); + }); + Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.InProgress); + operation.Stop(); + } + + [Fact] + public void CancelingTheTaskShouldCancelTheOperation() + { + SqlTaskStatus expectedStatus = SqlTaskStatus.Canceled; + DatabaseOperationStub operation = new DatabaseOperationStub(); + operation.TaskResult = new TaskResult + { + }; + SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun); + Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.NotStarted); + + sqlTask.Run().ContinueWith(task => { + Assert.Equal(sqlTask.TaskStatus, expectedStatus); + Assert.Equal(sqlTask.IsCanceled, true); + Assert.True(sqlTask.Duration > 0); + }); + Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.InProgress); + sqlTask.Cancel(); + } + + [Fact] + public void FailedOperationShouldFailTheTask() + { + SqlTaskStatus expectedStatus = SqlTaskStatus.Failed; + DatabaseOperationStub operation = new DatabaseOperationStub(); + operation.TaskResult = new TaskResult + { + }; + SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun); + Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.NotStarted); + + sqlTask.Run().ContinueWith(task => { + Assert.Equal(sqlTask.TaskStatus, expectedStatus); + Assert.Equal(sqlTask.IsCanceled, true); + Assert.True(sqlTask.Duration > 0); + }); + Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.InProgress); + operation.FailTheOperation(); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskManagerTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskManagerTests.cs new file mode 100644 index 00000000..a2bb04c4 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskManagerTests.cs @@ -0,0 +1,86 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using Microsoft.SqlTools.ServiceLayer.TaskServices; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices +{ + + public class TaskManagerTests + { + private TaskMetadata taskMetaData = new TaskMetadata + { + ServerName = "server name", + DatabaseName = "database name" + }; + + [Fact] + public void ManagerInstanceWithNoTaskShouldNotBreakOnCancelTask() + { + SqlTaskManager manager = new SqlTaskManager(); + Assert.True(manager.Tasks.Count == 0); + manager.CancelTask(Guid.NewGuid()); + } + + [Fact] + public void VerifyCreateAndRunningTask() + { + using (SqlTaskManager manager = new SqlTaskManager()) + { + bool taskAddedEventRaised = false; + manager.TaskAdded += (object sender, TaskEventArgs e) => + { + taskAddedEventRaised = true; + }; + DatabaseOperationStub operation = new DatabaseOperationStub(); + operation.TaskResult = new TaskResult + { + }; + SqlTask sqlTask = manager.CreateTask(taskMetaData, operation.FunctionToRun); + Assert.NotNull(sqlTask); + Assert.True(taskAddedEventRaised); + + Assert.False(manager.HasCompletedTasks()); + sqlTask.Run().ContinueWith(task => + { + Assert.True(manager.HasCompletedTasks()); + manager.RemoveCompletedTask(sqlTask); + + + }); + operation.Stop(); + } + + } + + [Fact] + public void CancelTaskShouldCancelTheOperation() + { + using (SqlTaskManager manager = new SqlTaskManager()) + { + SqlTaskStatus expectedStatus = SqlTaskStatus.Canceled; + + DatabaseOperationStub operation = new DatabaseOperationStub(); + operation.TaskResult = new TaskResult + { + }; + SqlTask sqlTask = manager.CreateTask(taskMetaData, operation.FunctionToRun); + Assert.NotNull(sqlTask); + + sqlTask.Run().ContinueWith(task => + { + Assert.Equal(sqlTask.TaskStatus, expectedStatus); + Assert.Equal(sqlTask.IsCanceled, true); + manager.Reset(); + + }); + manager.CancelTask(sqlTask.TaskId); + } + + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs new file mode 100644 index 00000000..8b0fb389 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs @@ -0,0 +1,131 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.Extensibility; +using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.TaskServices; +using Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices +{ + public class TaskServiceTests : ServiceTestBase + { + private TaskService service; + private Mock serviceHostMock; + private TaskMetadata taskMetaData = new TaskMetadata + { + ServerName = "server name", + DatabaseName = "database name" + }; + + public TaskServiceTests() + { + serviceHostMock = new Mock(); + service = CreateService(); + service.InitializeService(serviceHostMock.Object); + } + + [Fact] + public async Task TaskListRequestErrorsIfParameterIsNull() + { + object errorResponse = null; + var contextMock = RequestContextMocks.Create(null) + .AddErrorHandling((errorMessage, errorCode) => errorResponse = errorMessage); + + await service.HandleListTasksRequest(null, contextMock.Object); + VerifyErrorSent(contextMock); + Assert.True(((string)errorResponse).Contains("ArgumentNullException")); + } + + [Fact] + public void NewTaskShouldSendNotification() + { + serviceHostMock.AddEventHandling(TaskCreatedNotification.Type, null); + serviceHostMock.AddEventHandling(TaskStatusChangedNotification.Type, null); + DatabaseOperationStub operation = new DatabaseOperationStub(); + SqlTask sqlTask = service.TaskManager.CreateTask(taskMetaData, operation.FunctionToRun); + sqlTask.Run().ContinueWith(task => + { + + }); + + serviceHostMock.Verify(x => x.SendEvent(TaskCreatedNotification.Type, + It.Is(t => t.TaskId == sqlTask.TaskId.ToString())), Times.Once()); + operation.Stop(); + + serviceHostMock.Verify(x => x.SendEvent(TaskStatusChangedNotification.Type, + It.Is(t => t.TaskId == sqlTask.TaskId.ToString())), Times.AtLeastOnce()); + } + + [Fact] + public async Task CancelTaskShouldCancelTheOperationAndSendNotification() + { + serviceHostMock.AddEventHandling(TaskCreatedNotification.Type, null); + serviceHostMock.AddEventHandling(TaskStatusChangedNotification.Type, null); + DatabaseOperationStub operation = new DatabaseOperationStub(); + SqlTask sqlTask = service.TaskManager.CreateTask(taskMetaData, operation.FunctionToRun); + sqlTask.Run().ContinueWith(task => + { + serviceHostMock.Verify(x => x.SendEvent(TaskStatusChangedNotification.Type, + It.Is(t => t.Status == SqlTaskStatus.Canceled)), Times.Once()); + }); + CancelTaskParams cancelParams = new CancelTaskParams + { + TaskId = sqlTask.TaskId.ToString() + }; + + await RunAndVerify( + test: (requestContext) => service.HandleCancelTaskRequest(cancelParams, requestContext), + verify: ((result) => + { + })); + + serviceHostMock.Verify(x => x.SendEvent(TaskCreatedNotification.Type, + It.Is(t => t.TaskId == sqlTask.TaskId.ToString())), Times.Once()); + } + + + [Fact] + public async Task TaskListTaskShouldReturnAllTasks() + { + serviceHostMock.AddEventHandling(TaskCreatedNotification.Type, null); + serviceHostMock.AddEventHandling(TaskStatusChangedNotification.Type, null); + DatabaseOperationStub operation = new DatabaseOperationStub(); + SqlTask sqlTask = service.TaskManager.CreateTask(taskMetaData, operation.FunctionToRun); + sqlTask.Run(); + ListTasksParams listParams = new ListTasksParams + { + }; + + await RunAndVerify( + test: (requestContext) => service.HandleListTasksRequest(listParams, requestContext), + verify: ((result) => + { + Assert.True(result.Tasks.Any(x => x.TaskId == sqlTask.TaskId.ToString())); + })); + + operation.Stop(); + } + + protected TaskService CreateService() + { + CreateServiceProviderWithMinServices(); + + // Create the service using the service provider, which will initialize dependencies + return ServiceProvider.GetService(); + } + + protected override RegisteredServiceProvider CreateServiceProviderWithMinServices() + { + return CreateProvider() + .RegisterSingleService(new TaskService()); + } + } +}