diff --git a/README.md b/README.md index 396802d6..e915aca8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ [![Travis CI](https://travis-ci.org/Microsoft/sqltoolsservice.svg?branch=dev)](https://travis-ci.org/Microsoft/sqltoolsservice) [![AppVeyor](https://ci.appveyor.com/api/projects/status/github/Microsoft/sqltoolsservice?svg=true&retina=true&branch=dev)](https://ci.appveyor.com/project/kburtram/sqltoolsservice) -[![Coverage Status](https://coveralls.io/repos/github/Microsoft/sqltoolsservice/badge.svg?branch=dev)](https://coveralls.io/github/Microsoft/sqltoolsservice?branch=dev) +[![Coverage Status](https://coveralls.io/repos/github/Microsoft/sqltoolsservice/badge.svg?branch=dev)](https://coveralls.io/github/Microsoft/sqltoolsservice?branch=master) # Microsoft SQL Tools Service The SQL Tools Service is an application that provides core functionality for various SQL Server tools. These features include the following: diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/BackupOperation/BackupOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/BackupOperation/BackupOperation.cs index dae69d45..ca3c5127 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/BackupOperation/BackupOperation.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/BackupOperation/BackupOperation.cs @@ -52,7 +52,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery /// this is used when the backup dialog is launched in the context of a backup device /// The InitialBackupDestination will be loaded in LoadData private string initialBackupDestination = string.Empty; - + // Helps in populating the properties of an Azure blob given its URI private class BlobProperties { @@ -163,6 +163,19 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } } + /// + /// The error occurred during backup operation + /// + public string ErrorMessage + { + get + { + return string.Empty; + } + } + + public SqlTask SqlTask { get; set; } + /// /// Execute backup /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequestParams.cs index b761ced1..de32a85d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequestParams.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequestParams.cs @@ -4,6 +4,7 @@ // using System.Collections.Generic; +using Microsoft.SqlTools.ServiceLayer.TaskServices; using Microsoft.SqlTools.ServiceLayer.Utility; using Newtonsoft.Json.Linq; @@ -12,7 +13,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts /// /// Restore request parameters /// - public class RestoreParams : GeneralRequestDetails + public class RestoreParams : GeneralRequestDetails, IScriptableRequestParams { /// /// Restore session id. The parameter is optional and if passed, an existing plan will be used @@ -140,6 +141,26 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts SetOptionValue(RestoreOptionsHelper.SelectedBackupSets, value); } } + + /// + /// The executation mode for the operation. default is execution + /// + public TaskExecutionMode TaskExecutionMode { get; set; } + + /// + /// Same as Target Database name. Used by task manager to create task info + /// + public string DatabaseName + { + get + { + return TargetDatabaseName; + } + set + { + TargetDatabaseName = value; + } + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs index 46c713f6..9d5d9403 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs @@ -224,15 +224,10 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery if (restoreDataObject != null) { // create task metadata - TaskMetadata metadata = new TaskMetadata(); - metadata.ServerName = connInfo.ConnectionDetails.ServerName; - metadata.DatabaseName = restoreParams.TargetDatabaseName; - metadata.Name = SR.RestoreTaskName; - metadata.IsCancelable = true; - metadata.Data = restoreDataObject; + TaskMetadata metadata = TaskMetadata.Create(restoreParams, SR.RestoreTaskName, restoreDataObject, ConnectionServiceInstance); // create restore task and perform - SqlTask sqlTask = SqlTaskManagerInstance.CreateAndRun(metadata, this.restoreDatabaseService.RestoreTaskAsync, restoreDatabaseService.CancelTaskAsync); + SqlTask sqlTask = SqlTaskManagerInstance.CreateAndRun(metadata); response.TaskId = sqlTask.TaskId.ToString(); } else @@ -285,8 +280,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery TaskMetadata metadata = new TaskMetadata(); metadata.ServerName = connInfo.ConnectionDetails.ServerName; metadata.DatabaseName = connInfo.ConnectionDetails.DatabaseName; - metadata.Data = backupOperation; - metadata.IsCancelable = true; + metadata.TaskOperation = backupOperation; if (backupParams.IsScripting) { @@ -414,7 +408,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery /// internal async Task PerformBackupTaskAsync(SqlTask sqlTask) { - IBackupOperation backupOperation = sqlTask.TaskMetadata.Data as IBackupOperation; + IBackupOperation backupOperation = sqlTask.TaskMetadata.TaskOperation as IBackupOperation; TaskResult result = new TaskResult(); // Create a task to perform backup @@ -463,7 +457,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery /// internal async Task CancelBackupTaskAsync(SqlTask sqlTask) { - IBackupOperation backupOperation = sqlTask.TaskMetadata.Data as IBackupOperation; + IBackupOperation backupOperation = sqlTask.TaskMetadata.TaskOperation as IBackupOperation; TaskResult result = new TaskResult(); await Task.Factory.StartNew(() => diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs index f2b934e3..370ff474 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs @@ -7,13 +7,11 @@ using System; using System.Linq; using System.Data.Common; using System.Data.SqlClient; -using System.Threading.Tasks; using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts; -using Microsoft.SqlTools.ServiceLayer.TaskServices; using Microsoft.SqlTools.Utility; using System.Collections.Concurrent; using Microsoft.SqlTools.ServiceLayer.Utility; @@ -28,107 +26,6 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation public const string LastBackupTaken = "lastBackupTaken"; private ConcurrentDictionary sessions = new ConcurrentDictionary(); - /// - /// Create a backup task for execution and cancellation - /// - /// - /// - internal async Task RestoreTaskAsync(SqlTask sqlTask) - { - sqlTask.AddMessage(SR.TaskInProgress, SqlTaskStatus.InProgress, true); - RestoreDatabaseTaskDataObject restoreDataObject = sqlTask.TaskMetadata.Data as RestoreDatabaseTaskDataObject; - TaskResult taskResult = null; - - if (restoreDataObject != null) - { - // Create a task to perform backup - return await Task.Factory.StartNew(() => - { - TaskResult result = new TaskResult(); - try - { - if (restoreDataObject.IsValid) - { - ExecuteRestore(restoreDataObject, sqlTask); - result.TaskStatus = SqlTaskStatus.Succeeded; - } - else - { - result.TaskStatus = SqlTaskStatus.Failed; - if (restoreDataObject.ActiveException != null) - { - result.ErrorMessage = restoreDataObject.ActiveException.Message; - } - else - { - result.ErrorMessage = SR.RestoreNotSupported; - } - } - } - catch (Exception ex) - { - result.TaskStatus = SqlTaskStatus.Failed; - result.ErrorMessage = ex.Message; - if (ex.InnerException != null) - { - result.ErrorMessage += Environment.NewLine + ex.InnerException.Message; - } - if (restoreDataObject != null && restoreDataObject.ActiveException != null) - { - result.ErrorMessage += Environment.NewLine + restoreDataObject.ActiveException.Message; - } - } - return result; - }); - } - else - { - taskResult = new TaskResult(); - taskResult.TaskStatus = SqlTaskStatus.Failed; - } - - return taskResult; - } - - - - /// - /// Async task to cancel restore - /// - public async Task CancelTaskAsync(SqlTask sqlTask) - { - RestoreDatabaseTaskDataObject restoreDataObject = sqlTask.TaskMetadata.Data as RestoreDatabaseTaskDataObject; - TaskResult taskResult = null; - - - if (restoreDataObject != null && restoreDataObject.IsValid) - { - // Create a task for backup cancellation request - return await Task.Factory.StartNew(() => - { - - foreach (Restore restore in restoreDataObject.RestorePlan.RestoreOperations) - { - restore.Abort(); - } - - - return new TaskResult - { - TaskStatus = SqlTaskStatus.Canceled - }; - - }); - } - else - { - taskResult = new TaskResult(); - taskResult.TaskStatus = SqlTaskStatus.Failed; - } - - return taskResult; - } - /// /// Creates response which includes information about the server given to restore (default data location, db names with backupsets) /// @@ -166,7 +63,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation { if (restoreDataObject != null && restoreDataObject.IsValid) { - UpdateRestorePlan(restoreDataObject); + restoreDataObject.UpdateRestoreTaskObject(); if (restoreDataObject != null && restoreDataObject.IsValid) { @@ -227,6 +124,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation response.ErrorMessage += Environment.NewLine; response.ErrorMessage += ex.InnerException.Message; } + Logger.Write(LogLevel.Normal, $"Failed to create restore plan. error: { response.ErrorMessage}"); } return response; @@ -295,63 +193,11 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation return null; } - /// - /// Create a restore data object that includes the plan to do the restore operation - /// - /// - /// - private void UpdateRestorePlan(RestoreDatabaseTaskDataObject restoreDataObject) - { - bool shouldCreateNewPlan = restoreDataObject.ShouldCreateNewPlan(); - - if (!string.IsNullOrEmpty(restoreDataObject.RestoreParams.BackupFilePaths)) - { - restoreDataObject.AddFiles(restoreDataObject.RestoreParams.BackupFilePaths); - } - restoreDataObject.RestorePlanner.ReadHeaderFromMedia = restoreDataObject.RestoreParams.ReadHeaderFromMedia; - - RestoreOptionFactory.Instance.SetAndValidate(RestoreOptionsHelper.SourceDatabaseName, restoreDataObject); - RestoreOptionFactory.Instance.SetAndValidate(RestoreOptionsHelper.TargetDatabaseName, restoreDataObject); - - if (shouldCreateNewPlan) - { - restoreDataObject.CreateNewRestorePlan(); - } - - restoreDataObject.UpdateRestorePlan(); - - } + private bool CanChangeTargetDatabase(RestoreDatabaseTaskDataObject restoreDataObject) { return DatabaseUtils.IsSystemDatabaseConnection(restoreDataObject.Server.ConnectionContext.DatabaseName); } - - /// - /// Executes the restore operation - /// - /// - public void ExecuteRestore(RestoreDatabaseTaskDataObject restoreDataObject, SqlTask sqlTask = null) - { - // Restore Plan should be already created and updated at this point - UpdateRestorePlan(restoreDataObject); - - if (restoreDataObject != null && CanRestore(restoreDataObject)) - { - try - { - restoreDataObject.SqlTask = sqlTask; - restoreDataObject.Execute(); - } - catch(Exception ex) - { - throw ex; - } - } - else - { - throw new InvalidOperationException(SR.RestoreNotSupported); - } - } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs index 046db852..a12b6d6a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs @@ -13,6 +13,7 @@ using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts; using Microsoft.SqlTools.ServiceLayer.TaskServices; using Microsoft.SqlTools.ServiceLayer.Utility; +using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation { @@ -61,7 +62,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation /// /// Includes the plan with all the data required to do a restore operation on server /// - public class RestoreDatabaseTaskDataObject : IRestoreDatabaseTaskDataObject + public class RestoreDatabaseTaskDataObject : SmoScriptableTaskOperation, IRestoreDatabaseTaskDataObject { private const char BackupMediaNameSeparator = ','; @@ -71,11 +72,12 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation private bool? isTailLogBackupPossible = false; private bool? isTailLogBackupWithNoRecoveryPossible = false; private string backupMediaList = string.Empty; + private Server server; public RestoreDatabaseTaskDataObject(Server server, String databaseName) { PlanUpdateRequired = true; - this.Server = server; + this.server = server; this.Util = new RestoreUtil(server); restorePlanner = new DatabaseRestorePlanner(server); @@ -105,11 +107,6 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation /// public string SessionId { get; set; } - /// - /// Sql task assigned to the restore object - /// - public SqlTask SqlTask { get; set; } - public string TargetDatabaseName { get @@ -180,7 +177,13 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation /// /// Current sqlserver instance /// - public Server Server; + public override Server Server + { + get + { + return server; + } + } /// /// Recent exception that was thrown @@ -254,23 +257,36 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation } + public override void Execute(TaskExecutionMode mode) + { + UpdateRestoreTaskObject(); + + base.Execute(mode); + } + /// /// Executes the restore operations /// - public void Execute() + public override void Execute() { - RestorePlan restorePlan = GetRestorePlanForExecutionAndScript(); - - if (restorePlan != null && restorePlan.RestoreOperations.Count > 0) + if (IsValid && RestorePlan.RestoreOperations != null && RestorePlan.RestoreOperations.Any()) { - restorePlan.PercentComplete += (object sender, PercentCompleteEventArgs e) => + // Restore Plan should be already created and updated at this point + + RestorePlan restorePlan = GetRestorePlanForExecutionAndScript(); + + if (restorePlan != null && restorePlan.RestoreOperations.Count > 0) { - if (SqlTask != null) + restorePlan.PercentComplete += (object sender, PercentCompleteEventArgs e) => { - SqlTask.AddMessage($"{e.Percent}%", SqlTaskStatus.InProgress); - } - }; - restorePlan.Execute(); + OnMessageAdded(new TaskMessage { Description = $"{e.Percent}%", Status = SqlTaskStatus.InProgress }); + }; + restorePlan.Execute(); + } + } + else + { + throw new InvalidOperationException(SR.RestoreNotSupported); } } @@ -657,10 +673,6 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation { get { - if (this.restorePlan == null) - { - this.UpdateRestorePlan(); - } return this.restorePlan; } internal set @@ -797,20 +809,27 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation { Database db = null; List ret = new List(); - if (!this.RestorePlanner.ReadHeaderFromMedia) + try { - db = this.Server.Databases[this.RestorePlanner.DatabaseName]; + if (!this.RestorePlanner.ReadHeaderFromMedia) + { + db = this.Server.Databases[this.RestorePlanner.DatabaseName]; + } + if (restorePlan != null && restorePlan.RestoreOperations.Count > 0) + { + if (db != null && db.Status == DatabaseStatus.Normal) + { + ret = this.Util.GetDbFiles(db); + } + else + { + ret = this.Util.GetDbFiles(restorePlan.RestoreOperations[0]); + } + } } - if (restorePlan != null && restorePlan.RestoreOperations.Count > 0) + catch(Exception ex ) { - if (db != null && db.Status == DatabaseStatus.Normal) - { - ret = this.Util.GetDbFiles(db); - } - else - { - ret = this.Util.GetDbFiles(restorePlan.RestoreOperations[0]); - } + Logger.Write(LogLevel.Normal, $"Failed to get restore db files. error: {ex.Message}"); } return ret; } @@ -1033,10 +1052,22 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation } } + /// + /// Returns the restore plan error message + /// + public override string ErrorMessage + { + get + { + if (ActiveException != null) + { + return ActiveException.Message; + } + return string.Empty; + } + } - - private bool IsAnyFullBackupSetSelected() { bool isSelected = false; @@ -1210,5 +1241,47 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation } return true; } + + /// + /// Cancels the restore operations + /// + public override void Cancel() + { + foreach (Restore restore in RestorePlan.RestoreOperations) + { + restore.Abort(); + } + } + + /// + /// Create a restore data object that includes the plan to do the restore operation + /// + /// + /// + internal void UpdateRestoreTaskObject() + { + bool shouldCreateNewPlan = ShouldCreateNewPlan(); + + if (!string.IsNullOrEmpty(RestoreParams.BackupFilePaths) && RestoreParams.ReadHeaderFromMedia) + { + AddFiles(RestoreParams.BackupFilePaths); + } + else + { + RestorePlanner.BackupMediaList.Clear(); + } + RestorePlanner.ReadHeaderFromMedia = RestoreParams.ReadHeaderFromMedia; + + RestoreOptionFactory.Instance.SetAndValidate(RestoreOptionsHelper.SourceDatabaseName, this); + RestoreOptionFactory.Instance.SetAndValidate(RestoreOptionsHelper.TargetDatabaseName, this); + + if (shouldCreateNewPlan) + { + CreateNewRestorePlan(); + } + + UpdateRestorePlan(); + + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreOptionFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreOptionFactory.cs index fc4117bc..2ac1ffb2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreOptionFactory.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreOptionFactory.cs @@ -538,10 +538,15 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation }, ValidateFunction = (IRestoreDatabaseTaskDataObject restoreDataObject, object currentValue, object defaultValue) => { - + string errorMessage = string.Empty; + if (currentValue!= null && DatabaseUtils.IsSystemDatabaseConnection(currentValue.ToString())) + { + errorMessage = "Cannot restore to system database"; + } return new OptionValidationResult() { - IsReadOnly = !restoreDataObject.CanChangeTargetDatabase + IsReadOnly = !restoreDataObject.CanChangeTargetDatabase, + ErrorMessage = errorMessage }; }, SetValueFunction = (IRestoreDatabaseTaskDataObject restoreDataObject, object value) => diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/ITaskOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/ITaskOperation.cs index 1f8eb862..ed83f72b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/ITaskOperation.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/ITaskOperation.cs @@ -1,6 +1,7 @@ -using System; -using System.Collections.Generic; -using System.Text; +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// namespace Microsoft.SqlTools.ServiceLayer.TaskServices { @@ -19,6 +20,16 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices /// Cancel a task /// void Cancel(); + + /// + /// If an error occurred during task execution, this field contains the error message text + /// + string ErrorMessage { get; } + + /// + /// The sql task that's executing the operation + /// + SqlTask SqlTask { get; set; } } /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SmoScriptableTaskOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SmoScriptableTaskOperation.cs new file mode 100644 index 00000000..7c265525 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SmoScriptableTaskOperation.cs @@ -0,0 +1,146 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Text; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.Smo; + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices +{ + /// + /// Any SMO operation that supports scripting should implement this class. + /// It provides all of the configuration needed to choose between scripting or execution mode, + /// hook into the Task manager framework, and send success / completion notifications to the caller. + /// + public abstract class SmoScriptableTaskOperation : IScriptableTaskOperation + { + /// + /// Script content + /// + public string ScriptContent + { + get; set; + } + + /// + /// If an error occurred during task execution, this field contains the error message text + /// + public abstract string ErrorMessage { get; } + + /// + /// SMO Server instance used for the operation + /// + public abstract Server Server { get; } + + /// + /// Cancels the operation + /// + public abstract void Cancel(); + + /// + /// Updates messages in sql task given new progress message + /// + /// + public void OnMessageAdded(TaskMessage message) + { + if (this.SqlTask != null) + { + this.SqlTask.AddMessage(message.Description, message.Status); + } + } + + /// + /// Updates scripts in sql task given new script + /// + /// + public void OnScriptAdded(TaskScript script) + { + this.SqlTask.AddScript(script.Status, script.Script, script.ErrorMessage); + } + + /// + /// Executes the operations + /// + public abstract void Execute(); + + + /// + /// Execute the operation for given execution mode + /// + /// + public virtual void Execute(TaskExecutionMode mode) + { + var currentExecutionMode = Server.ConnectionContext.SqlExecutionModes; + try + { + + if (Server != null) + { + Server.ConnectionContext.CapturedSql.Clear(); + switch (mode) + { + case TaskExecutionMode.Execute: + { + Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.ExecuteSql; + break; + } + case TaskExecutionMode.ExecuteAndScript: + { + Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.ExecuteAndCaptureSql; + break; + } + case TaskExecutionMode.Script: + { + Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.CaptureSql; + break; + } + } + } + + Execute(); + if (mode == TaskExecutionMode.Script || mode == TaskExecutionMode.ExecuteAndScript) + { + this.ScriptContent = GetScriptContent(); + if (SqlTask != null) + { + OnScriptAdded(new TaskScript + { + Status = SqlTaskStatus.Succeeded, + Script = this.ScriptContent + }); + } + } + } + catch + { + throw; + } + finally + { + Server.ConnectionContext.CapturedSql.Clear(); + Server.ConnectionContext.SqlExecutionModes = currentExecutionMode; + } + + } + + private string GetScriptContent() + { + StringBuilder sb = new StringBuilder(); + foreach (String s in this.Server.ConnectionContext.CapturedSql.Text) + { + sb.Append(s); + sb.Append(Environment.NewLine); + } + return sb.ToString(); + } + + /// + /// The sql task to run the operations + /// + public SqlTask SqlTask { get; set; } + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs index 6e4628b9..0e654b6e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs @@ -34,6 +34,14 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices public event EventHandler> MessageAdded; public event EventHandler> StatusChanged; + /// + /// Default constructor to create the geenric type. calling Initialize method is required after creating + /// the insance + /// + public SqlTask() + { + } + /// /// Creates new instance of SQL task /// @@ -41,10 +49,21 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices /// The function to run to start the task public SqlTask(TaskMetadata taskMetdata, Func> taskToRun, Func> taskToCancel) { - Validate.IsNotNull(nameof(taskMetdata), taskMetdata); + Init(taskMetdata, taskToRun, taskToCancel); + } + + /// + /// Initializes the Sql task + /// + /// Task metadata + /// The function to execute the operation + /// The function to cancel the operation + public virtual void Init(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) + { + Validate.IsNotNull(nameof(taskMetadata), taskMetadata); Validate.IsNotNull(nameof(taskToRun), taskToRun); - TaskMetadata = taskMetdata; + TaskMetadata = taskMetadata; TaskToRun = taskToRun; TaskToCancel = taskToCancel; StartTime = DateTime.UtcNow; @@ -120,8 +139,9 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices /// internal async Task RunAndCancel() { + TokenSource = new CancellationTokenSource(); AddMessage(SR.TaskInProgress, SqlTaskStatus.InProgress, true); - + TaskResult taskResult = new TaskResult(); Task performTask = TaskToRun(this); Task completedTask = null; @@ -452,7 +472,7 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices Name = TaskMetadata.Name, Description = TaskMetadata.Description, TaskExecutionMode = TaskMetadata.TaskExecutionMode, - IsCancelable = TaskMetadata.IsCancelable, + IsCancelable = this.TaskToCancel != null, }; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs index b1d2198d..9705608f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs @@ -8,6 +8,7 @@ using System.Collections.Concurrent; using System.Collections.ObjectModel; using System.Linq; using System.Threading.Tasks; +using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.TaskServices { @@ -83,19 +84,58 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices /// Task Metadata /// The function to run the operation /// The function to cancel the operation - /// - public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) + /// The new sql task + public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) + { + return CreateTask(taskMetadata, taskToRun, taskToCancel); + } + + /// + /// Creates a new task + /// + /// Task Metadata + /// The new sql task + public SqlTask CreateTask(TaskMetadata taskMetadata) where T : SqlTask, new() + { + Validate.IsNotNull(nameof(taskMetadata), taskMetadata); + return CreateTask(taskMetadata, TaskOperationHelper.ExecuteTaskAsync, TaskOperationHelper.CancelTaskAsync); + } + + /// + /// Creates a new task + /// + /// Task Metadata + /// The function to run the operation + /// The function to cancel the operation + /// The new sql task + public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) where T : SqlTask, new() { ValidateNotDisposed(); - var newtask = new SqlTask(taskMetadata, taskToRun, taskToCancel); + var newTask = new T(); + newTask.Init(taskMetadata, taskToRun, taskToCancel); + if (taskMetadata != null && taskMetadata.TaskOperation != null) + { + taskMetadata.TaskOperation.SqlTask = newTask; + } lock (lockObject) { - tasks.AddOrUpdate(newtask.TaskId, newtask, (key, oldValue) => newtask); + tasks.AddOrUpdate(newTask.TaskId, newTask, (key, oldValue) => newTask); } - OnTaskAdded(new TaskEventArgs(newtask)); - return newtask; + OnTaskAdded(new TaskEventArgs(newTask)); + return newTask; + } + + /// + /// Creates a new task + /// + /// Task Metadata + /// The function to run the operation + /// The new sql task + public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun) + { + return CreateTask(taskMetadata, taskToRun); } /// @@ -104,9 +144,9 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices /// Task Metadata /// The function to run the operation /// - public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun) + public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun) where T : SqlTask, new() { - return CreateTask(taskMetadata, taskToRun, null); + return CreateTask(taskMetadata, taskToRun, null); } /// @@ -118,7 +158,26 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices /// public SqlTask CreateAndRun(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) { - var sqlTask = CreateTask(taskMetadata, taskToRun, taskToCancel); + return CreateAndRun(taskMetadata, taskToRun, taskToCancel); + } + + public SqlTask CreateAndRun(TaskMetadata taskMetadata) where T : SqlTask, new() + { + var sqlTask = CreateTask(taskMetadata); + sqlTask.Run(); + return sqlTask; + } + + /// + /// Creates a new task and starts the task + /// + /// Task Metadata + /// The function to run the operation + /// The function to cancel the operation + /// + public SqlTask CreateAndRun(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) where T : SqlTask, new() + { + var sqlTask = CreateTask(taskMetadata, taskToRun, taskToCancel); sqlTask.Run(); return sqlTask; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs index ca485524..26ebcb4b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs @@ -3,6 +3,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Utility; + namespace Microsoft.SqlTools.ServiceLayer.TaskServices { public class TaskMetadata @@ -28,11 +31,6 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices /// public int Timeout { get; set; } - /// - /// Defines if the task can be canceled - /// - public bool IsCancelable { get; set; } - /// /// Database server name this task is created for /// @@ -46,6 +44,52 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices /// /// Data required to perform the task /// - public object Data { get; set; } + public ITaskOperation TaskOperation { get; set; } + + /// + /// Creates task metadata given the request parameters + /// + /// Request parameters + /// Task name + /// Task operation + /// Connection Service + /// Task metadata + public static TaskMetadata Create(IRequestParams requestParam, string taskName, ITaskOperation taskOperation, ConnectionService connectionService) + { + TaskMetadata taskMetadata = new TaskMetadata(); + ConnectionInfo connInfo; + connectionService.TryFindConnection( + requestParam.OwnerUri, + out connInfo); + + if (connInfo != null) + { + taskMetadata.ServerName = connInfo.ConnectionDetails.ServerName; + } + + if (!string.IsNullOrEmpty(requestParam.DatabaseName)) + { + taskMetadata.DatabaseName = requestParam.DatabaseName; + } + else if (connInfo != null) + { + taskMetadata.DatabaseName = connInfo.ConnectionDetails.DatabaseName; + } + + IScriptableRequestParams scriptableRequestParams = requestParam as IScriptableRequestParams; + if (scriptableRequestParams != null && scriptableRequestParams.TaskExecutionMode != TaskExecutionMode.Execute) + { + taskMetadata.Name = string.Format("{0} {1}", taskName, SR.ScriptTaskName); + } + else + { + taskMetadata.Name = taskName; + } + taskMetadata.TaskExecutionMode = scriptableRequestParams.TaskExecutionMode; + + taskMetadata.TaskOperation = taskOperation; + return taskMetadata; + } + } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskOperationHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskOperationHelper.cs new file mode 100644 index 00000000..50b69630 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskOperationHelper.cs @@ -0,0 +1,113 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; + +namespace Microsoft.SqlTools.ServiceLayer.TaskServices +{ + /// + /// Helper class for task operations + /// + public static class TaskOperationHelper + { + /// + /// Async method to execute the operation + /// + /// Sql Task + /// Task Result + public static async Task ExecuteTaskAsync(SqlTask sqlTask) + { + sqlTask.AddMessage(SR.TaskInProgress, SqlTaskStatus.InProgress, true); + ITaskOperation taskOperation = sqlTask.TaskMetadata.TaskOperation as ITaskOperation; + TaskResult taskResult = null; + + if (taskOperation != null) + { + taskOperation.SqlTask = sqlTask; + + return await Task.Factory.StartNew(() => + { + TaskResult result = new TaskResult(); + try + { + if (string.IsNullOrEmpty(taskOperation.ErrorMessage)) + { + taskOperation.Execute(sqlTask.TaskMetadata.TaskExecutionMode); + result.TaskStatus = SqlTaskStatus.Succeeded; + } + else + { + result.TaskStatus = SqlTaskStatus.Failed; + result.ErrorMessage = taskOperation.ErrorMessage; + } + } + catch (Exception ex) + { + result.TaskStatus = SqlTaskStatus.Failed; + result.ErrorMessage = ex.Message; + if (ex.InnerException != null) + { + result.ErrorMessage += Environment.NewLine + ex.InnerException.Message; + } + if (taskOperation != null && taskOperation.ErrorMessage != null) + { + result.ErrorMessage += Environment.NewLine + taskOperation.ErrorMessage; + } + } + return result; + }); + } + else + { + taskResult = new TaskResult(); + taskResult.TaskStatus = SqlTaskStatus.Failed; + } + + return taskResult; + } + + /// + /// Async method to cancel the operations + /// + public static async Task CancelTaskAsync(SqlTask sqlTask) + { + ITaskOperation taskOperation = sqlTask.TaskMetadata.TaskOperation as ITaskOperation; + TaskResult taskResult = null; + + if (taskOperation != null) + { + + return await Task.Factory.StartNew(() => + { + try + { + taskOperation.Cancel(); + + return new TaskResult + { + TaskStatus = SqlTaskStatus.Canceled + }; + } + catch (Exception ex) + { + return new TaskResult + { + TaskStatus = SqlTaskStatus.Failed, + ErrorMessage = ex.Message + }; + } + }); + } + else + { + taskResult = new TaskResult(); + taskResult.TaskStatus = SqlTaskStatus.Failed; + } + + return taskResult; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs index 15f9b513..b0614444 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs @@ -17,18 +17,9 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices public class TaskService: HostedService, IComposableService { private static readonly Lazy instance = new Lazy(() => new TaskService()); - private SqlTaskManager taskManager = SqlTaskManager.Instance; + private SqlTaskManager taskManager = null; private IProtocolEndpoint serviceHost; - - /// - /// Default, parameterless constructor. - /// - public TaskService() - { - taskManager.TaskAdded += OnTaskAdded; - } - /// /// Gets the singleton instance object /// @@ -44,8 +35,16 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices { get { + if(taskManager == null) + { + taskManager = SqlTaskManager.Instance; + } return taskManager; } + set + { + taskManager = value; + } } /// @@ -57,6 +56,7 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices Logger.Write(LogLevel.Verbose, "TaskService initialized"); serviceHost.SetRequestHandler(ListTasksRequest.Type, HandleListTasksRequest); serviceHost.SetRequestHandler(CancelTaskRequest.Type, HandleCancelTaskRequest); + TaskManager.TaskAdded += OnTaskAdded; } /// @@ -74,7 +74,7 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices return Task.Factory.StartNew(() => { ListTasksResponse response = new ListTasksResponse(); - response.Tasks = taskManager.Tasks.Select(x => x.ToTaskInfo()).ToArray(); + response.Tasks = TaskManager.Tasks.Select(x => x.ToTaskInfo()).ToArray(); return response; }); @@ -96,7 +96,7 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices Guid taskId; if (Guid.TryParse(cancelTaskParams.TaskId, out taskId)) { - taskManager.CancelTask(taskId); + TaskManager.CancelTask(taskId); return true; } else @@ -176,8 +176,8 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices public void Dispose() { - taskManager.TaskAdded -= OnTaskAdded; - taskManager.Dispose(); + TaskManager.TaskAdded -= OnTaskAdded; + TaskManager.Dispose(); } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/IRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/IRequestParams.cs new file mode 100644 index 00000000..183acb13 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/IRequestParams.cs @@ -0,0 +1,21 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + + +namespace Microsoft.SqlTools.ServiceLayer.Utility +{ + public interface IRequestParams + { + /// + /// The Uri to find the connection to do the restore operations + /// + string OwnerUri { get; set; } + + /// + /// Database name + /// + string DatabaseName { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/IScriptableRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/IScriptableRequestParams.cs new file mode 100644 index 00000000..338ae782 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/IScriptableRequestParams.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + + +using Microsoft.SqlTools.ServiceLayer.TaskServices; + +namespace Microsoft.SqlTools.ServiceLayer.Utility +{ + public interface IScriptableRequestParams : IRequestParams + { + /// + /// The executation mode for the operation. default is execution + /// + TaskExecutionMode TaskExecutionMode { get; set; } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs index aa766be2..50552d4e 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs @@ -82,7 +82,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery //Verify that all backupsets are restored int[] expectedTable = new int[] { }; - await VerifyRestoreMultipleBackupSets(backupFiles, indexToDelete, expectedTable); + await VerifyRestoreMultipleBackupSets(backupFiles, indexToDelete, expectedTable, TaskExecutionModeFlag.Execute); } [Fact] @@ -95,7 +95,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery { Dictionary options = new Dictionary(); options.Add(RestoreOptionsHelper.ReplaceDatabase, true); - await VerifyRestore(null, databaseNameToRestoreFrom, true, true, testDb.DatabaseName, null, options, (database) => + await VerifyRestore(null, databaseNameToRestoreFrom, true, TaskExecutionModeFlag.ExecuteAndScript, testDb.DatabaseName, null, options, (database) => { return database.Tables.Contains("tb1", "test"); }); @@ -129,14 +129,14 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery await VerifyRestoreMultipleBackupSets(backupFiles, indexToDelete, expectedTable); } - private async Task VerifyRestoreMultipleBackupSets(string[] backupFiles, int backupSetIndexToDelete, int[] expectedSelectedIndexes) + private async Task VerifyRestoreMultipleBackupSets(string[] backupFiles, int backupSetIndexToDelete, int[] expectedSelectedIndexes, TaskExecutionModeFlag executionMode = TaskExecutionModeFlag.ExecuteAndScript) { var testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, null, "RestoreTest"); try { string targetDbName = testDb.DatabaseName; bool canRestore = true; - var response = await VerifyRestore(backupFiles, null, canRestore, false, targetDbName, null, null); + var response = await VerifyRestore(backupFiles, null, canRestore, TaskExecutionModeFlag.None, targetDbName, null, null); Assert.True(response.BackupSetsToRestore.Count() >= 2); var allIds = response.BackupSetsToRestore.Select(x => x.Id).ToList(); if (backupSetIndexToDelete >= 0) @@ -146,20 +146,24 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery string[] selectedIds = allIds.ToArray(); Dictionary options = new Dictionary(); options.Add(RestoreOptionsHelper.ReplaceDatabase, true); - response = await VerifyRestore(backupFiles, null, canRestore, true, targetDbName, selectedIds, options, (database) => + response = await VerifyRestore(backupFiles, null, canRestore, executionMode, targetDbName, selectedIds, options, (database) => { - bool tablesFound = true; - for (int i = 0; i < tableNames.Length; i++) + if (executionMode.HasFlag(TaskExecutionModeFlag.Execute)) { - string tableName = tableNames[i]; - if (!database.Tables.Contains(tableName, "test") && expectedSelectedIndexes.Contains(i)) + bool tablesFound = true; + for (int i = 0; i < tableNames.Length; i++) { - tablesFound = false; - break; + string tableName = tableNames[i]; + if (!database.Tables.Contains(tableName, "test") && expectedSelectedIndexes.Contains(i)) + { + tablesFound = false; + break; + } } + bool numberOfTableCreatedIsCorrect = database.Tables.Count == expectedSelectedIndexes.Length; + return numberOfTableCreatedIsCorrect && tablesFound; } - bool numberOfTableCreatedIsCorrect = database.Tables.Count == expectedSelectedIndexes.Length; - return numberOfTableCreatedIsCorrect && tablesFound; + return true; }); for (int i = 0; i < response.BackupSetsToRestore.Count(); i++) @@ -190,7 +194,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery Dictionary options = new Dictionary(); options.Add(RestoreOptionsHelper.ReplaceDatabase, true); - await VerifyRestore(new string[] { fullBackupFilePath }, null, canRestore, true, testDb.DatabaseName, null, options); + await VerifyRestore(new string[] { fullBackupFilePath }, null, canRestore, TaskExecutionModeFlag.ExecuteAndScript, testDb.DatabaseName, null, options); } finally { @@ -212,7 +216,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery await VerifyBackupFileCreated(); bool canRestore = true; - await VerifyRestore(new string[] { fullBackupFilePath }, null, canRestore, false, testDb.DatabaseName, null, null); + await VerifyRestore(new string[] { fullBackupFilePath }, null, canRestore, TaskExecutionModeFlag.None, testDb.DatabaseName, null, null); } finally { @@ -229,7 +233,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery string[] backupFileNames = new string[] { "FullBackup.bak", "DiffBackup.bak" }; bool canRestore = true; - var response = await VerifyRestore(backupFileNames, null, canRestore, false, "RestoredFromTwoBackupFile"); + var response = await VerifyRestore(backupFileNames, null, canRestore, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile"); Assert.True(response.BackupSetsToRestore.Count() == 2); } @@ -239,13 +243,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery string[] backupFileNames = new string[] { "FullBackup.bak", "DiffBackup.bak" }; bool canRestore = true; - var response = await VerifyRestore(backupFileNames, null, canRestore, false, "RestoredFromTwoBackupFile"); + var response = await VerifyRestore(backupFileNames, null, canRestore, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile"); Assert.True(response.BackupSetsToRestore.Count() == 2); var fileInfo = response.BackupSetsToRestore.FirstOrDefault(x => x.GetPropertyValueAsString(BackupSetInfo.BackupTypePropertyName) != RestoreConstants.TypeFull); if(fileInfo != null) { var selectedBackupSets = new string[] { fileInfo.Id }; - await VerifyRestore(backupFileNames, null, true, false, "RestoredFromTwoBackupFile", selectedBackupSets); + await VerifyRestore(backupFileNames, null, true, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile", selectedBackupSets); } } @@ -255,13 +259,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery string[] backupFileNames = new string[] { "FullBackup.bak", "DiffBackup.bak" }; bool canRestore = true; - var response = await VerifyRestore(backupFileNames, null, canRestore, false, "RestoredFromTwoBackupFile"); + var response = await VerifyRestore(backupFileNames, null, canRestore, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile"); Assert.True(response.BackupSetsToRestore.Count() == 2); var fileInfo = response.BackupSetsToRestore.FirstOrDefault(x => x.GetPropertyValueAsString(BackupSetInfo.BackupTypePropertyName) == RestoreConstants.TypeFull); if (fileInfo != null) { var selectedBackupSets = new string[] { fileInfo.Id }; - await VerifyRestore(backupFileNames, null, true, false, "RestoredFromTwoBackupFile2", selectedBackupSets); + await VerifyRestore(backupFileNames, null, true, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile2", selectedBackupSets); } } @@ -272,7 +276,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery string backupFileName = fullBackupFilePath; bool canRestore = true; - var restorePlan = await VerifyRestore(backupFileName, canRestore, true); + var restorePlan = await VerifyRestore(backupFileName, canRestore, TaskExecutionModeFlag.Execute); Assert.NotNull(restorePlan.BackupSetsToRestore); } @@ -283,7 +287,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery string backupFileName = fullBackupFilePath; bool canRestore = true; - var restorePlan = await VerifyRestore(backupFileName, canRestore, true, "NewRestoredDatabase"); + var restorePlan = await VerifyRestore(backupFileName, canRestore, TaskExecutionModeFlag.ExecuteAndScript, "NewRestoredDatabase"); } [Fact] @@ -414,16 +418,16 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", dropDatabaseQuery); } - private async Task VerifyRestore(string backupFileName, bool canRestore, bool execute = false, string targetDatabase = null) + private async Task VerifyRestore(string backupFileName, bool canRestore, TaskExecutionModeFlag executionMode = TaskExecutionModeFlag.None, string targetDatabase = null) { - return await VerifyRestore(new string[] { backupFileName }, null, canRestore, execute, targetDatabase); + return await VerifyRestore(new string[] { backupFileName }, null, canRestore, executionMode, targetDatabase); } private async Task VerifyRestore( string[] backupFileNames = null, string sourceDbName = null, - bool canRestore = true, - bool execute = false, + bool canRestore = true, + TaskExecutionModeFlag executionMode = TaskExecutionModeFlag.None, string targetDatabase = null, string[] selectedBackupSets = null, Dictionary options = null, @@ -496,7 +500,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery Assert.NotNull(response.PlanDetails[RestoreOptionsHelper.StandbyFile]); Assert.NotNull(response.PlanDetails[RestoreOptionsHelper.StandbyFile]); - if(execute) + if(executionMode != TaskExecutionModeFlag.None) { try { @@ -504,21 +508,28 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery restoreDataObject = service.CreateRestoreDatabaseTaskDataObject(request); Assert.Equal(response.SessionId, restoreDataObject.SessionId); request.RelocateDbFiles = !restoreDataObject.DbFilesLocationAreValid(); - service.ExecuteRestore(restoreDataObject); - Assert.True(restoreDataObject.Server.Databases.Contains(targetDatabase)); + restoreDataObject.Execute((TaskExecutionMode)Enum.Parse(typeof(TaskExecutionMode), executionMode.ToString())); - if (verifyDatabase != null) + if (executionMode.HasFlag(TaskExecutionModeFlag.Execute)) { - Assert.True(verifyDatabase(restoreDataObject.Server.Databases[targetDatabase])); - } + Assert.True(restoreDataObject.Server.Databases.Contains(targetDatabase)); - //To verify the backupset that are restored, verifying the database is a better options. - //Some tests still verify the number of backup sets that are executed which in some cases can be less than the selected list - if (verifyDatabase == null && selectedBackupSets != null) + if (verifyDatabase != null) + { + Assert.True(verifyDatabase(restoreDataObject.Server.Databases[targetDatabase])); + } + + //To verify the backupset that are restored, verifying the database is a better options. + //Some tests still verify the number of backup sets that are executed which in some cases can be less than the selected list + if (verifyDatabase == null && selectedBackupSets != null) + { + Assert.Equal(selectedBackupSets.Count(), restoreDataObject.RestorePlanToExecute.RestoreOperations.Count()); + } + } + if(executionMode.HasFlag(TaskExecutionModeFlag.Script)) { - Assert.Equal(selectedBackupSets.Count(), restoreDataObject.RestorePlanToExecute.RestoreOperations.Count()); + Assert.False(string.IsNullOrEmpty(restoreDataObject.ScriptContent)); } - } catch(Exception ex) { diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/TaskExecutionModeFlag.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/TaskExecutionModeFlag.cs new file mode 100644 index 00000000..49e7b3cb --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/TaskExecutionModeFlag.cs @@ -0,0 +1,32 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + + +using System; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery +{ + [Flags] + public enum TaskExecutionModeFlag + { + None = 0x00, + + /// + /// Execute task + /// + Execute = 0x01, + + /// + /// Script task + /// + Script = 0x02, + + /// + /// Execute and script task + /// Needed for tasks that will show the script when execution completes + /// + ExecuteAndScript = Execute | Script + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/LanguageServiceTests.cs index d6ef38e3..8304782f 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/LanguageServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/LanguageServiceTests.cs @@ -63,12 +63,15 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.LanguageServer [Fact] public void PrepopulateCommonMetadata() { - var result = LiveConnectionHelper.InitLiveConnectionInfo(); - var connInfo = result.ConnectionInfo; + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + var result = LiveConnectionHelper.InitLiveConnectionInfo("master", queryTempFile.FilePath); + var connInfo = result.ConnectionInfo; - ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; + ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; - LanguageService.Instance.PrepopulateCommonMetadata(connInfo, scriptInfo, null); + LanguageService.Instance.PrepopulateCommonMetadata(connInfo, scriptInfo, null); + } } // This test currently requires a live database connection to initialize diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/RequstParamStub.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/RequstParamStub.cs new file mode 100644 index 00000000..82f0c6d4 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/RequstParamStub.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.TaskServices; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TaskServices +{ + public class RequstParamStub : IScriptableRequestParams + { + public TaskExecutionMode TaskExecutionMode { get; set; } + public string OwnerUri { get; set; } + public string DatabaseName { get; set; } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/SmoScriptableTaskOperationStub.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/SmoScriptableTaskOperationStub.cs new file mode 100644 index 00000000..5a3c6457 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/SmoScriptableTaskOperationStub.cs @@ -0,0 +1,52 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.ServiceLayer.TaskServices; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TaskServices +{ + public class SmoScriptableTaskOperationStub : SmoScriptableTaskOperation + { + private Server server; + public string DatabaseName { get; set; } + public SmoScriptableTaskOperationStub(Server server) + { + this.server = server; + } + public override string ErrorMessage + { + get + { + return string.Empty; + } + } + + public override Server Server + { + get + { + return server; + } + } + + public override void Cancel() + { + } + + public string TableName { get; set; } + + public override void Execute() + { + var database = server.Databases[DatabaseName]; + Table table = new Table(database, TableName, "test"); + Column column = new Column(table, "c1"); + column.DataType = DataType.Int; + table.Columns.Add(column); + database.Tables.Add(table); + table.Create(); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/TaskServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/TaskServiceTests.cs new file mode 100644 index 00000000..82fc60eb --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/TaskServiceTests.cs @@ -0,0 +1,156 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Data.Common; +using System.Data.SqlClient; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.Extensibility; +using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; +using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; +using Microsoft.SqlTools.ServiceLayer.TaskServices; +using Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.Test.Common; +using Microsoft.SqlTools.ServiceLayer.UnitTests; +using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; +using Moq; +using Xunit; +using static Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility.LiveConnectionHelper; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TaskServices +{ + public class TaskServiceTests : ServiceTestBase + { + private TaskService service; + private Mock serviceHostMock; + + public TaskServiceTests() + { + serviceHostMock = new Mock(); + service = CreateService(); + service.InitializeService(serviceHostMock.Object); + } + + [Fact] + public async Task VerifyTaskExecuteTheQueryGivenExecutionModeExecute() + { + await VerifyTaskWithExecutionMode(TaskExecutionMode.Execute); + } + + [Fact] + public async Task VerifyTaskGenerateScriptOnlyGivenExecutionModeScript() + { + await VerifyTaskWithExecutionMode(TaskExecutionMode.Script); + } + + [Fact] + public async Task VerifyTaskNotExecuteAndGenerateScriptGivenExecutionModeExecuteAndScript() + { + await VerifyTaskWithExecutionMode(TaskExecutionMode.ExecuteAndScript); + } + + [Fact] + public async Task VerifyTaskSendsFailureNotificationGivenInvalidQuery() + { + await VerifyTaskWithExecutionMode(TaskExecutionMode.ExecuteAndScript, true); + } + + public async Task VerifyTaskWithExecutionMode(TaskExecutionMode executionMode, bool makeTaskFail = false) + { + serviceHostMock.AddEventHandling(TaskStatusChangedNotification.Type, null); + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + //To make the task fail don't create the schema so create table fails + string query = string.Empty; + if (!makeTaskFail) + { + query = $"CREATE SCHEMA [test]"; + } + SqlTestDb testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, query, "TaskService"); + try + { + TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, queryTempFile.FilePath); + string taskName = "task name"; + Server server = CreateServerObject(connectionResult.ConnectionInfo); + RequstParamStub requstParam = new RequstParamStub + { + TaskExecutionMode = executionMode, + OwnerUri = queryTempFile.FilePath + }; + SmoScriptableTaskOperationStub taskOperation = new SmoScriptableTaskOperationStub(server); + taskOperation.DatabaseName = testDb.DatabaseName; + taskOperation.TableName = "newTable"; + TaskMetadata taskMetadata = TaskMetadata.Create(requstParam, taskName, taskOperation, ConnectionService.Instance); + SqlTask sqlTask = service.TaskManager.CreateTask(taskMetadata); + Task taskToVerify = sqlTask.RunAsync().ContinueWith(task => + { + if (!makeTaskFail) + { + if (executionMode == TaskExecutionMode.Script || executionMode == TaskExecutionMode.ExecuteAndScript) + { + serviceHostMock.Verify(x => x.SendEvent(TaskStatusChangedNotification.Type, + It.Is(t => !string.IsNullOrEmpty(t.Script))), Times.AtLeastOnce()); + } + + //Verify if the table created if execution mode includes execute + bool expected = executionMode == TaskExecutionMode.Execute || executionMode == TaskExecutionMode.ExecuteAndScript; + Server serverToverfiy = CreateServerObject(connectionResult.ConnectionInfo); + bool actual = serverToverfiy.Databases[testDb.DatabaseName].Tables.Contains(taskOperation.TableName, "test"); + Assert.Equal(expected, actual); + } + else + { + serviceHostMock.Verify(x => x.SendEvent(TaskStatusChangedNotification.Type, + It.Is(t => t.Status == SqlTaskStatus.Failed)), Times.AtLeastOnce()); + } + }); + await taskToVerify; + } + finally + { + testDb.Cleanup(); + } + + } + } + protected TaskService CreateService() + { + CreateServiceProviderWithMinServices(); + + // Create the service using the service provider, which will initialize dependencies + return ServiceProvider.GetService(); + } + + protected override RegisteredServiceProvider CreateServiceProviderWithMinServices() + { + TaskService service = new TaskService(); + service.TaskManager = new SqlTaskManager(); + return CreateProvider() + .RegisterSingleService(service); + } + + private Server CreateServerObject(ConnectionInfo connInfo ) + { + SqlConnection connection = null; + DbConnection dbConnection = connInfo.AllConnections.First(); + ReliableSqlConnection reliableSqlConnection = dbConnection as ReliableSqlConnection; + SqlConnection sqlConnection = dbConnection as SqlConnection; + if (reliableSqlConnection != null) + { + connection = reliableSqlConnection.GetUnderlyingConnection(); + } + else if (sqlConnection != null) + { + connection = sqlConnection; + } + return new Server(new ServerConnection(connection)); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupOperationStub.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupOperationStub.cs index 8e80f853..a8c13a10 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupOperationStub.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupOperationStub.cs @@ -28,6 +28,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.DisasterRecovery public string ScriptContent { get; set; } + public string ErrorMessage + { + get + { + return string.Empty; + } + } + + public SqlTask SqlTask { get; set; } + /// /// Initialize /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupTests.cs index 1a3386aa..636ed87a 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupTests.cs @@ -201,15 +201,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.DisasterRecovery } } - private TaskMetadata CreateTaskMetaData(object data) + private TaskMetadata CreateTaskMetaData(IBackupOperation data) { TaskMetadata taskMetaData = new TaskMetadata { ServerName = "server name", DatabaseName = "database name", Name = "backup database", - IsCancelable = true, - Data = data + TaskOperation = data }; return taskMetaData; diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs index 6350c455..cead820d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs @@ -124,8 +124,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices protected override RegisteredServiceProvider CreateServiceProviderWithMinServices() { + TaskService service = new TaskService(); + service.TaskManager = new SqlTaskManager(); return CreateProvider() - .RegisterSingleService(new TaskService()); + .RegisterSingleService(service); } } }