diff --git a/README.md b/README.md
index 396802d6..e915aca8 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
[](https://travis-ci.org/Microsoft/sqltoolsservice)
[](https://ci.appveyor.com/project/kburtram/sqltoolsservice)
-[](https://coveralls.io/github/Microsoft/sqltoolsservice?branch=dev)
+[](https://coveralls.io/github/Microsoft/sqltoolsservice?branch=master)
# Microsoft SQL Tools Service
The SQL Tools Service is an application that provides core functionality for various SQL Server tools. These features include the following:
diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/BackupOperation/BackupOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/BackupOperation/BackupOperation.cs
index dae69d45..ca3c5127 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/BackupOperation/BackupOperation.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/BackupOperation/BackupOperation.cs
@@ -52,7 +52,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
/// this is used when the backup dialog is launched in the context of a backup device
/// The InitialBackupDestination will be loaded in LoadData
private string initialBackupDestination = string.Empty;
-
+
// Helps in populating the properties of an Azure blob given its URI
private class BlobProperties
{
@@ -163,6 +163,19 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
}
}
+ ///
+ /// The error occurred during backup operation
+ ///
+ public string ErrorMessage
+ {
+ get
+ {
+ return string.Empty;
+ }
+ }
+
+ public SqlTask SqlTask { get; set; }
+
///
/// Execute backup
///
diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequestParams.cs
index b761ced1..de32a85d 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequestParams.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequestParams.cs
@@ -4,6 +4,7 @@
//
using System.Collections.Generic;
+using Microsoft.SqlTools.ServiceLayer.TaskServices;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Newtonsoft.Json.Linq;
@@ -12,7 +13,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts
///
/// Restore request parameters
///
- public class RestoreParams : GeneralRequestDetails
+ public class RestoreParams : GeneralRequestDetails, IScriptableRequestParams
{
///
/// Restore session id. The parameter is optional and if passed, an existing plan will be used
@@ -140,6 +141,26 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts
SetOptionValue(RestoreOptionsHelper.SelectedBackupSets, value);
}
}
+
+ ///
+ /// The executation mode for the operation. default is execution
+ ///
+ public TaskExecutionMode TaskExecutionMode { get; set; }
+
+ ///
+ /// Same as Target Database name. Used by task manager to create task info
+ ///
+ public string DatabaseName
+ {
+ get
+ {
+ return TargetDatabaseName;
+ }
+ set
+ {
+ TargetDatabaseName = value;
+ }
+ }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs
index 46c713f6..9d5d9403 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs
@@ -224,15 +224,10 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
if (restoreDataObject != null)
{
// create task metadata
- TaskMetadata metadata = new TaskMetadata();
- metadata.ServerName = connInfo.ConnectionDetails.ServerName;
- metadata.DatabaseName = restoreParams.TargetDatabaseName;
- metadata.Name = SR.RestoreTaskName;
- metadata.IsCancelable = true;
- metadata.Data = restoreDataObject;
+ TaskMetadata metadata = TaskMetadata.Create(restoreParams, SR.RestoreTaskName, restoreDataObject, ConnectionServiceInstance);
// create restore task and perform
- SqlTask sqlTask = SqlTaskManagerInstance.CreateAndRun(metadata, this.restoreDatabaseService.RestoreTaskAsync, restoreDatabaseService.CancelTaskAsync);
+ SqlTask sqlTask = SqlTaskManagerInstance.CreateAndRun(metadata);
response.TaskId = sqlTask.TaskId.ToString();
}
else
@@ -285,8 +280,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
TaskMetadata metadata = new TaskMetadata();
metadata.ServerName = connInfo.ConnectionDetails.ServerName;
metadata.DatabaseName = connInfo.ConnectionDetails.DatabaseName;
- metadata.Data = backupOperation;
- metadata.IsCancelable = true;
+ metadata.TaskOperation = backupOperation;
if (backupParams.IsScripting)
{
@@ -414,7 +408,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
///
internal async Task PerformBackupTaskAsync(SqlTask sqlTask)
{
- IBackupOperation backupOperation = sqlTask.TaskMetadata.Data as IBackupOperation;
+ IBackupOperation backupOperation = sqlTask.TaskMetadata.TaskOperation as IBackupOperation;
TaskResult result = new TaskResult();
// Create a task to perform backup
@@ -463,7 +457,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery
///
internal async Task CancelBackupTaskAsync(SqlTask sqlTask)
{
- IBackupOperation backupOperation = sqlTask.TaskMetadata.Data as IBackupOperation;
+ IBackupOperation backupOperation = sqlTask.TaskMetadata.TaskOperation as IBackupOperation;
TaskResult result = new TaskResult();
await Task.Factory.StartNew(() =>
diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs
index f2b934e3..370ff474 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs
@@ -7,13 +7,11 @@ using System;
using System.Linq;
using System.Data.Common;
using System.Data.SqlClient;
-using System.Threading.Tasks;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts;
-using Microsoft.SqlTools.ServiceLayer.TaskServices;
using Microsoft.SqlTools.Utility;
using System.Collections.Concurrent;
using Microsoft.SqlTools.ServiceLayer.Utility;
@@ -28,107 +26,6 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
public const string LastBackupTaken = "lastBackupTaken";
private ConcurrentDictionary sessions = new ConcurrentDictionary();
- ///
- /// Create a backup task for execution and cancellation
- ///
- ///
- ///
- internal async Task RestoreTaskAsync(SqlTask sqlTask)
- {
- sqlTask.AddMessage(SR.TaskInProgress, SqlTaskStatus.InProgress, true);
- RestoreDatabaseTaskDataObject restoreDataObject = sqlTask.TaskMetadata.Data as RestoreDatabaseTaskDataObject;
- TaskResult taskResult = null;
-
- if (restoreDataObject != null)
- {
- // Create a task to perform backup
- return await Task.Factory.StartNew(() =>
- {
- TaskResult result = new TaskResult();
- try
- {
- if (restoreDataObject.IsValid)
- {
- ExecuteRestore(restoreDataObject, sqlTask);
- result.TaskStatus = SqlTaskStatus.Succeeded;
- }
- else
- {
- result.TaskStatus = SqlTaskStatus.Failed;
- if (restoreDataObject.ActiveException != null)
- {
- result.ErrorMessage = restoreDataObject.ActiveException.Message;
- }
- else
- {
- result.ErrorMessage = SR.RestoreNotSupported;
- }
- }
- }
- catch (Exception ex)
- {
- result.TaskStatus = SqlTaskStatus.Failed;
- result.ErrorMessage = ex.Message;
- if (ex.InnerException != null)
- {
- result.ErrorMessage += Environment.NewLine + ex.InnerException.Message;
- }
- if (restoreDataObject != null && restoreDataObject.ActiveException != null)
- {
- result.ErrorMessage += Environment.NewLine + restoreDataObject.ActiveException.Message;
- }
- }
- return result;
- });
- }
- else
- {
- taskResult = new TaskResult();
- taskResult.TaskStatus = SqlTaskStatus.Failed;
- }
-
- return taskResult;
- }
-
-
-
- ///
- /// Async task to cancel restore
- ///
- public async Task CancelTaskAsync(SqlTask sqlTask)
- {
- RestoreDatabaseTaskDataObject restoreDataObject = sqlTask.TaskMetadata.Data as RestoreDatabaseTaskDataObject;
- TaskResult taskResult = null;
-
-
- if (restoreDataObject != null && restoreDataObject.IsValid)
- {
- // Create a task for backup cancellation request
- return await Task.Factory.StartNew(() =>
- {
-
- foreach (Restore restore in restoreDataObject.RestorePlan.RestoreOperations)
- {
- restore.Abort();
- }
-
-
- return new TaskResult
- {
- TaskStatus = SqlTaskStatus.Canceled
- };
-
- });
- }
- else
- {
- taskResult = new TaskResult();
- taskResult.TaskStatus = SqlTaskStatus.Failed;
- }
-
- return taskResult;
- }
-
///
/// Creates response which includes information about the server given to restore (default data location, db names with backupsets)
///
@@ -166,7 +63,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
{
if (restoreDataObject != null && restoreDataObject.IsValid)
{
- UpdateRestorePlan(restoreDataObject);
+ restoreDataObject.UpdateRestoreTaskObject();
if (restoreDataObject != null && restoreDataObject.IsValid)
{
@@ -227,6 +124,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
response.ErrorMessage += Environment.NewLine;
response.ErrorMessage += ex.InnerException.Message;
}
+ Logger.Write(LogLevel.Normal, $"Failed to create restore plan. error: { response.ErrorMessage}");
}
return response;
@@ -295,63 +193,11 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
return null;
}
- ///
- /// Create a restore data object that includes the plan to do the restore operation
- ///
- ///
- ///
- private void UpdateRestorePlan(RestoreDatabaseTaskDataObject restoreDataObject)
- {
- bool shouldCreateNewPlan = restoreDataObject.ShouldCreateNewPlan();
-
- if (!string.IsNullOrEmpty(restoreDataObject.RestoreParams.BackupFilePaths))
- {
- restoreDataObject.AddFiles(restoreDataObject.RestoreParams.BackupFilePaths);
- }
- restoreDataObject.RestorePlanner.ReadHeaderFromMedia = restoreDataObject.RestoreParams.ReadHeaderFromMedia;
-
- RestoreOptionFactory.Instance.SetAndValidate(RestoreOptionsHelper.SourceDatabaseName, restoreDataObject);
- RestoreOptionFactory.Instance.SetAndValidate(RestoreOptionsHelper.TargetDatabaseName, restoreDataObject);
-
- if (shouldCreateNewPlan)
- {
- restoreDataObject.CreateNewRestorePlan();
- }
-
- restoreDataObject.UpdateRestorePlan();
-
- }
+
private bool CanChangeTargetDatabase(RestoreDatabaseTaskDataObject restoreDataObject)
{
return DatabaseUtils.IsSystemDatabaseConnection(restoreDataObject.Server.ConnectionContext.DatabaseName);
}
-
- ///
- /// Executes the restore operation
- ///
- ///
- public void ExecuteRestore(RestoreDatabaseTaskDataObject restoreDataObject, SqlTask sqlTask = null)
- {
- // Restore Plan should be already created and updated at this point
- UpdateRestorePlan(restoreDataObject);
-
- if (restoreDataObject != null && CanRestore(restoreDataObject))
- {
- try
- {
- restoreDataObject.SqlTask = sqlTask;
- restoreDataObject.Execute();
- }
- catch(Exception ex)
- {
- throw ex;
- }
- }
- else
- {
- throw new InvalidOperationException(SR.RestoreNotSupported);
- }
- }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs
index 046db852..a12b6d6a 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs
@@ -13,6 +13,7 @@ using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts;
using Microsoft.SqlTools.ServiceLayer.TaskServices;
using Microsoft.SqlTools.ServiceLayer.Utility;
+using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
{
@@ -61,7 +62,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
///
/// Includes the plan with all the data required to do a restore operation on server
///
- public class RestoreDatabaseTaskDataObject : IRestoreDatabaseTaskDataObject
+ public class RestoreDatabaseTaskDataObject : SmoScriptableTaskOperation, IRestoreDatabaseTaskDataObject
{
private const char BackupMediaNameSeparator = ',';
@@ -71,11 +72,12 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
private bool? isTailLogBackupPossible = false;
private bool? isTailLogBackupWithNoRecoveryPossible = false;
private string backupMediaList = string.Empty;
+ private Server server;
public RestoreDatabaseTaskDataObject(Server server, String databaseName)
{
PlanUpdateRequired = true;
- this.Server = server;
+ this.server = server;
this.Util = new RestoreUtil(server);
restorePlanner = new DatabaseRestorePlanner(server);
@@ -105,11 +107,6 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
///
public string SessionId { get; set; }
- ///
- /// Sql task assigned to the restore object
- ///
- public SqlTask SqlTask { get; set; }
-
public string TargetDatabaseName
{
get
@@ -180,7 +177,13 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
///
/// Current sqlserver instance
///
- public Server Server;
+ public override Server Server
+ {
+ get
+ {
+ return server;
+ }
+ }
///
/// Recent exception that was thrown
@@ -254,23 +257,36 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
}
+ public override void Execute(TaskExecutionMode mode)
+ {
+ UpdateRestoreTaskObject();
+
+ base.Execute(mode);
+ }
+
///
/// Executes the restore operations
///
- public void Execute()
+ public override void Execute()
{
- RestorePlan restorePlan = GetRestorePlanForExecutionAndScript();
-
- if (restorePlan != null && restorePlan.RestoreOperations.Count > 0)
+ if (IsValid && RestorePlan.RestoreOperations != null && RestorePlan.RestoreOperations.Any())
{
- restorePlan.PercentComplete += (object sender, PercentCompleteEventArgs e) =>
+ // Restore Plan should be already created and updated at this point
+
+ RestorePlan restorePlan = GetRestorePlanForExecutionAndScript();
+
+ if (restorePlan != null && restorePlan.RestoreOperations.Count > 0)
{
- if (SqlTask != null)
+ restorePlan.PercentComplete += (object sender, PercentCompleteEventArgs e) =>
{
- SqlTask.AddMessage($"{e.Percent}%", SqlTaskStatus.InProgress);
- }
- };
- restorePlan.Execute();
+ OnMessageAdded(new TaskMessage { Description = $"{e.Percent}%", Status = SqlTaskStatus.InProgress });
+ };
+ restorePlan.Execute();
+ }
+ }
+ else
+ {
+ throw new InvalidOperationException(SR.RestoreNotSupported);
}
}
@@ -657,10 +673,6 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
{
get
{
- if (this.restorePlan == null)
- {
- this.UpdateRestorePlan();
- }
return this.restorePlan;
}
internal set
@@ -797,20 +809,27 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
{
Database db = null;
List ret = new List();
- if (!this.RestorePlanner.ReadHeaderFromMedia)
+ try
{
- db = this.Server.Databases[this.RestorePlanner.DatabaseName];
+ if (!this.RestorePlanner.ReadHeaderFromMedia)
+ {
+ db = this.Server.Databases[this.RestorePlanner.DatabaseName];
+ }
+ if (restorePlan != null && restorePlan.RestoreOperations.Count > 0)
+ {
+ if (db != null && db.Status == DatabaseStatus.Normal)
+ {
+ ret = this.Util.GetDbFiles(db);
+ }
+ else
+ {
+ ret = this.Util.GetDbFiles(restorePlan.RestoreOperations[0]);
+ }
+ }
}
- if (restorePlan != null && restorePlan.RestoreOperations.Count > 0)
+ catch(Exception ex )
{
- if (db != null && db.Status == DatabaseStatus.Normal)
- {
- ret = this.Util.GetDbFiles(db);
- }
- else
- {
- ret = this.Util.GetDbFiles(restorePlan.RestoreOperations[0]);
- }
+ Logger.Write(LogLevel.Normal, $"Failed to get restore db files. error: {ex.Message}");
}
return ret;
}
@@ -1033,10 +1052,22 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
}
}
+ ///
+ /// Returns the restore plan error message
+ ///
+ public override string ErrorMessage
+ {
+ get
+ {
+ if (ActiveException != null)
+ {
+ return ActiveException.Message;
+ }
+ return string.Empty;
+ }
+ }
-
-
private bool IsAnyFullBackupSetSelected()
{
bool isSelected = false;
@@ -1210,5 +1241,47 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
}
return true;
}
+
+ ///
+ /// Cancels the restore operations
+ ///
+ public override void Cancel()
+ {
+ foreach (Restore restore in RestorePlan.RestoreOperations)
+ {
+ restore.Abort();
+ }
+ }
+
+ ///
+ /// Create a restore data object that includes the plan to do the restore operation
+ ///
+ ///
+ ///
+ internal void UpdateRestoreTaskObject()
+ {
+ bool shouldCreateNewPlan = ShouldCreateNewPlan();
+
+ if (!string.IsNullOrEmpty(RestoreParams.BackupFilePaths) && RestoreParams.ReadHeaderFromMedia)
+ {
+ AddFiles(RestoreParams.BackupFilePaths);
+ }
+ else
+ {
+ RestorePlanner.BackupMediaList.Clear();
+ }
+ RestorePlanner.ReadHeaderFromMedia = RestoreParams.ReadHeaderFromMedia;
+
+ RestoreOptionFactory.Instance.SetAndValidate(RestoreOptionsHelper.SourceDatabaseName, this);
+ RestoreOptionFactory.Instance.SetAndValidate(RestoreOptionsHelper.TargetDatabaseName, this);
+
+ if (shouldCreateNewPlan)
+ {
+ CreateNewRestorePlan();
+ }
+
+ UpdateRestorePlan();
+
+ }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreOptionFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreOptionFactory.cs
index fc4117bc..2ac1ffb2 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreOptionFactory.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreOptionFactory.cs
@@ -538,10 +538,15 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
},
ValidateFunction = (IRestoreDatabaseTaskDataObject restoreDataObject, object currentValue, object defaultValue) =>
{
-
+ string errorMessage = string.Empty;
+ if (currentValue!= null && DatabaseUtils.IsSystemDatabaseConnection(currentValue.ToString()))
+ {
+ errorMessage = "Cannot restore to system database";
+ }
return new OptionValidationResult()
{
- IsReadOnly = !restoreDataObject.CanChangeTargetDatabase
+ IsReadOnly = !restoreDataObject.CanChangeTargetDatabase,
+ ErrorMessage = errorMessage
};
},
SetValueFunction = (IRestoreDatabaseTaskDataObject restoreDataObject, object value) =>
diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/ITaskOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/ITaskOperation.cs
index 1f8eb862..ed83f72b 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/ITaskOperation.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/ITaskOperation.cs
@@ -1,6 +1,7 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
namespace Microsoft.SqlTools.ServiceLayer.TaskServices
{
@@ -19,6 +20,16 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
/// Cancel a task
///
void Cancel();
+
+ ///
+ /// If an error occurred during task execution, this field contains the error message text
+ ///
+ string ErrorMessage { get; }
+
+ ///
+ /// The sql task that's executing the operation
+ ///
+ SqlTask SqlTask { get; set; }
}
///
diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SmoScriptableTaskOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SmoScriptableTaskOperation.cs
new file mode 100644
index 00000000..7c265525
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SmoScriptableTaskOperation.cs
@@ -0,0 +1,146 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Text;
+using Microsoft.SqlServer.Management.Common;
+using Microsoft.SqlServer.Management.Smo;
+
+namespace Microsoft.SqlTools.ServiceLayer.TaskServices
+{
+ ///
+ /// Any SMO operation that supports scripting should implement this class.
+ /// It provides all of the configuration needed to choose between scripting or execution mode,
+ /// hook into the Task manager framework, and send success / completion notifications to the caller.
+ ///
+ public abstract class SmoScriptableTaskOperation : IScriptableTaskOperation
+ {
+ ///
+ /// Script content
+ ///
+ public string ScriptContent
+ {
+ get; set;
+ }
+
+ ///
+ /// If an error occurred during task execution, this field contains the error message text
+ ///
+ public abstract string ErrorMessage { get; }
+
+ ///
+ /// SMO Server instance used for the operation
+ ///
+ public abstract Server Server { get; }
+
+ ///
+ /// Cancels the operation
+ ///
+ public abstract void Cancel();
+
+ ///
+ /// Updates messages in sql task given new progress message
+ ///
+ ///
+ public void OnMessageAdded(TaskMessage message)
+ {
+ if (this.SqlTask != null)
+ {
+ this.SqlTask.AddMessage(message.Description, message.Status);
+ }
+ }
+
+ ///
+ /// Updates scripts in sql task given new script
+ ///
+ ///
+ public void OnScriptAdded(TaskScript script)
+ {
+ this.SqlTask.AddScript(script.Status, script.Script, script.ErrorMessage);
+ }
+
+ ///
+ /// Executes the operations
+ ///
+ public abstract void Execute();
+
+
+ ///
+ /// Execute the operation for given execution mode
+ ///
+ ///
+ public virtual void Execute(TaskExecutionMode mode)
+ {
+ var currentExecutionMode = Server.ConnectionContext.SqlExecutionModes;
+ try
+ {
+
+ if (Server != null)
+ {
+ Server.ConnectionContext.CapturedSql.Clear();
+ switch (mode)
+ {
+ case TaskExecutionMode.Execute:
+ {
+ Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.ExecuteSql;
+ break;
+ }
+ case TaskExecutionMode.ExecuteAndScript:
+ {
+ Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.ExecuteAndCaptureSql;
+ break;
+ }
+ case TaskExecutionMode.Script:
+ {
+ Server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.CaptureSql;
+ break;
+ }
+ }
+ }
+
+ Execute();
+ if (mode == TaskExecutionMode.Script || mode == TaskExecutionMode.ExecuteAndScript)
+ {
+ this.ScriptContent = GetScriptContent();
+ if (SqlTask != null)
+ {
+ OnScriptAdded(new TaskScript
+ {
+ Status = SqlTaskStatus.Succeeded,
+ Script = this.ScriptContent
+ });
+ }
+ }
+ }
+ catch
+ {
+ throw;
+ }
+ finally
+ {
+ Server.ConnectionContext.CapturedSql.Clear();
+ Server.ConnectionContext.SqlExecutionModes = currentExecutionMode;
+ }
+
+ }
+
+ private string GetScriptContent()
+ {
+ StringBuilder sb = new StringBuilder();
+ foreach (String s in this.Server.ConnectionContext.CapturedSql.Text)
+ {
+ sb.Append(s);
+ sb.Append(Environment.NewLine);
+ }
+ return sb.ToString();
+ }
+
+ ///
+ /// The sql task to run the operations
+ ///
+ public SqlTask SqlTask { get; set; }
+
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs
index 6e4628b9..0e654b6e 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs
@@ -34,6 +34,14 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
public event EventHandler> MessageAdded;
public event EventHandler> StatusChanged;
+ ///
+ /// Default constructor to create the geenric type. calling Initialize method is required after creating
+ /// the insance
+ ///
+ public SqlTask()
+ {
+ }
+
///
/// Creates new instance of SQL task
///
@@ -41,10 +49,21 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
/// The function to run to start the task
public SqlTask(TaskMetadata taskMetdata, Func> taskToRun, Func> taskToCancel)
{
- Validate.IsNotNull(nameof(taskMetdata), taskMetdata);
+ Init(taskMetdata, taskToRun, taskToCancel);
+ }
+
+ ///
+ /// Initializes the Sql task
+ ///
+ /// Task metadata
+ /// The function to execute the operation
+ /// The function to cancel the operation
+ public virtual void Init(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel)
+ {
+ Validate.IsNotNull(nameof(taskMetadata), taskMetadata);
Validate.IsNotNull(nameof(taskToRun), taskToRun);
- TaskMetadata = taskMetdata;
+ TaskMetadata = taskMetadata;
TaskToRun = taskToRun;
TaskToCancel = taskToCancel;
StartTime = DateTime.UtcNow;
@@ -120,8 +139,9 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
///
internal async Task RunAndCancel()
{
+ TokenSource = new CancellationTokenSource();
AddMessage(SR.TaskInProgress, SqlTaskStatus.InProgress, true);
-
+
TaskResult taskResult = new TaskResult();
Task performTask = TaskToRun(this);
Task completedTask = null;
@@ -452,7 +472,7 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
Name = TaskMetadata.Name,
Description = TaskMetadata.Description,
TaskExecutionMode = TaskMetadata.TaskExecutionMode,
- IsCancelable = TaskMetadata.IsCancelable,
+ IsCancelable = this.TaskToCancel != null,
};
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs
index b1d2198d..9705608f 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs
@@ -8,6 +8,7 @@ using System.Collections.Concurrent;
using System.Collections.ObjectModel;
using System.Linq;
using System.Threading.Tasks;
+using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.TaskServices
{
@@ -83,19 +84,58 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
/// Task Metadata
/// The function to run the operation
/// The function to cancel the operation
- ///
- public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel)
+ /// The new sql task
+ public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel)
+ {
+ return CreateTask(taskMetadata, taskToRun, taskToCancel);
+ }
+
+ ///
+ /// Creates a new task
+ ///
+ /// Task Metadata
+ /// The new sql task
+ public SqlTask CreateTask(TaskMetadata taskMetadata) where T : SqlTask, new()
+ {
+ Validate.IsNotNull(nameof(taskMetadata), taskMetadata);
+ return CreateTask(taskMetadata, TaskOperationHelper.ExecuteTaskAsync, TaskOperationHelper.CancelTaskAsync);
+ }
+
+ ///
+ /// Creates a new task
+ ///
+ /// Task Metadata
+ /// The function to run the operation
+ /// The function to cancel the operation
+ /// The new sql task
+ public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) where T : SqlTask, new()
{
ValidateNotDisposed();
- var newtask = new SqlTask(taskMetadata, taskToRun, taskToCancel);
+ var newTask = new T();
+ newTask.Init(taskMetadata, taskToRun, taskToCancel);
+ if (taskMetadata != null && taskMetadata.TaskOperation != null)
+ {
+ taskMetadata.TaskOperation.SqlTask = newTask;
+ }
lock (lockObject)
{
- tasks.AddOrUpdate(newtask.TaskId, newtask, (key, oldValue) => newtask);
+ tasks.AddOrUpdate(newTask.TaskId, newTask, (key, oldValue) => newTask);
}
- OnTaskAdded(new TaskEventArgs(newtask));
- return newtask;
+ OnTaskAdded(new TaskEventArgs(newTask));
+ return newTask;
+ }
+
+ ///
+ /// Creates a new task
+ ///
+ /// Task Metadata
+ /// The function to run the operation
+ /// The new sql task
+ public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun)
+ {
+ return CreateTask(taskMetadata, taskToRun);
}
///
@@ -104,9 +144,9 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
/// Task Metadata
/// The function to run the operation
///
- public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun)
+ public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun) where T : SqlTask, new()
{
- return CreateTask(taskMetadata, taskToRun, null);
+ return CreateTask(taskMetadata, taskToRun, null);
}
///
@@ -118,7 +158,26 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
///
public SqlTask CreateAndRun(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel)
{
- var sqlTask = CreateTask(taskMetadata, taskToRun, taskToCancel);
+ return CreateAndRun(taskMetadata, taskToRun, taskToCancel);
+ }
+
+ public SqlTask CreateAndRun(TaskMetadata taskMetadata) where T : SqlTask, new()
+ {
+ var sqlTask = CreateTask(taskMetadata);
+ sqlTask.Run();
+ return sqlTask;
+ }
+
+ ///
+ /// Creates a new task and starts the task
+ ///
+ /// Task Metadata
+ /// The function to run the operation
+ /// The function to cancel the operation
+ ///
+ public SqlTask CreateAndRun(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) where T : SqlTask, new()
+ {
+ var sqlTask = CreateTask(taskMetadata, taskToRun, taskToCancel);
sqlTask.Run();
return sqlTask;
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs
index ca485524..26ebcb4b 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskMetadata.cs
@@ -3,6 +3,9 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
+using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.Utility;
+
namespace Microsoft.SqlTools.ServiceLayer.TaskServices
{
public class TaskMetadata
@@ -28,11 +31,6 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
///
public int Timeout { get; set; }
- ///
- /// Defines if the task can be canceled
- ///
- public bool IsCancelable { get; set; }
-
///
/// Database server name this task is created for
///
@@ -46,6 +44,52 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
///
/// Data required to perform the task
///
- public object Data { get; set; }
+ public ITaskOperation TaskOperation { get; set; }
+
+ ///
+ /// Creates task metadata given the request parameters
+ ///
+ /// Request parameters
+ /// Task name
+ /// Task operation
+ /// Connection Service
+ /// Task metadata
+ public static TaskMetadata Create(IRequestParams requestParam, string taskName, ITaskOperation taskOperation, ConnectionService connectionService)
+ {
+ TaskMetadata taskMetadata = new TaskMetadata();
+ ConnectionInfo connInfo;
+ connectionService.TryFindConnection(
+ requestParam.OwnerUri,
+ out connInfo);
+
+ if (connInfo != null)
+ {
+ taskMetadata.ServerName = connInfo.ConnectionDetails.ServerName;
+ }
+
+ if (!string.IsNullOrEmpty(requestParam.DatabaseName))
+ {
+ taskMetadata.DatabaseName = requestParam.DatabaseName;
+ }
+ else if (connInfo != null)
+ {
+ taskMetadata.DatabaseName = connInfo.ConnectionDetails.DatabaseName;
+ }
+
+ IScriptableRequestParams scriptableRequestParams = requestParam as IScriptableRequestParams;
+ if (scriptableRequestParams != null && scriptableRequestParams.TaskExecutionMode != TaskExecutionMode.Execute)
+ {
+ taskMetadata.Name = string.Format("{0} {1}", taskName, SR.ScriptTaskName);
+ }
+ else
+ {
+ taskMetadata.Name = taskName;
+ }
+ taskMetadata.TaskExecutionMode = scriptableRequestParams.TaskExecutionMode;
+
+ taskMetadata.TaskOperation = taskOperation;
+ return taskMetadata;
+ }
+
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskOperationHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskOperationHelper.cs
new file mode 100644
index 00000000..50b69630
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskOperationHelper.cs
@@ -0,0 +1,113 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Threading.Tasks;
+
+namespace Microsoft.SqlTools.ServiceLayer.TaskServices
+{
+ ///
+ /// Helper class for task operations
+ ///
+ public static class TaskOperationHelper
+ {
+ ///
+ /// Async method to execute the operation
+ ///
+ /// Sql Task
+ /// Task Result
+ public static async Task ExecuteTaskAsync(SqlTask sqlTask)
+ {
+ sqlTask.AddMessage(SR.TaskInProgress, SqlTaskStatus.InProgress, true);
+ ITaskOperation taskOperation = sqlTask.TaskMetadata.TaskOperation as ITaskOperation;
+ TaskResult taskResult = null;
+
+ if (taskOperation != null)
+ {
+ taskOperation.SqlTask = sqlTask;
+
+ return await Task.Factory.StartNew(() =>
+ {
+ TaskResult result = new TaskResult();
+ try
+ {
+ if (string.IsNullOrEmpty(taskOperation.ErrorMessage))
+ {
+ taskOperation.Execute(sqlTask.TaskMetadata.TaskExecutionMode);
+ result.TaskStatus = SqlTaskStatus.Succeeded;
+ }
+ else
+ {
+ result.TaskStatus = SqlTaskStatus.Failed;
+ result.ErrorMessage = taskOperation.ErrorMessage;
+ }
+ }
+ catch (Exception ex)
+ {
+ result.TaskStatus = SqlTaskStatus.Failed;
+ result.ErrorMessage = ex.Message;
+ if (ex.InnerException != null)
+ {
+ result.ErrorMessage += Environment.NewLine + ex.InnerException.Message;
+ }
+ if (taskOperation != null && taskOperation.ErrorMessage != null)
+ {
+ result.ErrorMessage += Environment.NewLine + taskOperation.ErrorMessage;
+ }
+ }
+ return result;
+ });
+ }
+ else
+ {
+ taskResult = new TaskResult();
+ taskResult.TaskStatus = SqlTaskStatus.Failed;
+ }
+
+ return taskResult;
+ }
+
+ ///
+ /// Async method to cancel the operations
+ ///
+ public static async Task CancelTaskAsync(SqlTask sqlTask)
+ {
+ ITaskOperation taskOperation = sqlTask.TaskMetadata.TaskOperation as ITaskOperation;
+ TaskResult taskResult = null;
+
+ if (taskOperation != null)
+ {
+
+ return await Task.Factory.StartNew(() =>
+ {
+ try
+ {
+ taskOperation.Cancel();
+
+ return new TaskResult
+ {
+ TaskStatus = SqlTaskStatus.Canceled
+ };
+ }
+ catch (Exception ex)
+ {
+ return new TaskResult
+ {
+ TaskStatus = SqlTaskStatus.Failed,
+ ErrorMessage = ex.Message
+ };
+ }
+ });
+ }
+ else
+ {
+ taskResult = new TaskResult();
+ taskResult.TaskStatus = SqlTaskStatus.Failed;
+ }
+
+ return taskResult;
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs
index 15f9b513..b0614444 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/TaskService.cs
@@ -17,18 +17,9 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
public class TaskService: HostedService, IComposableService
{
private static readonly Lazy instance = new Lazy(() => new TaskService());
- private SqlTaskManager taskManager = SqlTaskManager.Instance;
+ private SqlTaskManager taskManager = null;
private IProtocolEndpoint serviceHost;
-
- ///
- /// Default, parameterless constructor.
- ///
- public TaskService()
- {
- taskManager.TaskAdded += OnTaskAdded;
- }
-
///
/// Gets the singleton instance object
///
@@ -44,8 +35,16 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
{
get
{
+ if(taskManager == null)
+ {
+ taskManager = SqlTaskManager.Instance;
+ }
return taskManager;
}
+ set
+ {
+ taskManager = value;
+ }
}
///
@@ -57,6 +56,7 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
Logger.Write(LogLevel.Verbose, "TaskService initialized");
serviceHost.SetRequestHandler(ListTasksRequest.Type, HandleListTasksRequest);
serviceHost.SetRequestHandler(CancelTaskRequest.Type, HandleCancelTaskRequest);
+ TaskManager.TaskAdded += OnTaskAdded;
}
///
@@ -74,7 +74,7 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
return Task.Factory.StartNew(() =>
{
ListTasksResponse response = new ListTasksResponse();
- response.Tasks = taskManager.Tasks.Select(x => x.ToTaskInfo()).ToArray();
+ response.Tasks = TaskManager.Tasks.Select(x => x.ToTaskInfo()).ToArray();
return response;
});
@@ -96,7 +96,7 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
Guid taskId;
if (Guid.TryParse(cancelTaskParams.TaskId, out taskId))
{
- taskManager.CancelTask(taskId);
+ TaskManager.CancelTask(taskId);
return true;
}
else
@@ -176,8 +176,8 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices
public void Dispose()
{
- taskManager.TaskAdded -= OnTaskAdded;
- taskManager.Dispose();
+ TaskManager.TaskAdded -= OnTaskAdded;
+ TaskManager.Dispose();
}
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/IRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/IRequestParams.cs
new file mode 100644
index 00000000..183acb13
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/IRequestParams.cs
@@ -0,0 +1,21 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+
+namespace Microsoft.SqlTools.ServiceLayer.Utility
+{
+ public interface IRequestParams
+ {
+ ///
+ /// The Uri to find the connection to do the restore operations
+ ///
+ string OwnerUri { get; set; }
+
+ ///
+ /// Database name
+ ///
+ string DatabaseName { get; set; }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/IScriptableRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/IScriptableRequestParams.cs
new file mode 100644
index 00000000..338ae782
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/IScriptableRequestParams.cs
@@ -0,0 +1,18 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+
+using Microsoft.SqlTools.ServiceLayer.TaskServices;
+
+namespace Microsoft.SqlTools.ServiceLayer.Utility
+{
+ public interface IScriptableRequestParams : IRequestParams
+ {
+ ///
+ /// The executation mode for the operation. default is execution
+ ///
+ TaskExecutionMode TaskExecutionMode { get; set; }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs
index aa766be2..50552d4e 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs
@@ -82,7 +82,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
//Verify that all backupsets are restored
int[] expectedTable = new int[] { };
- await VerifyRestoreMultipleBackupSets(backupFiles, indexToDelete, expectedTable);
+ await VerifyRestoreMultipleBackupSets(backupFiles, indexToDelete, expectedTable, TaskExecutionModeFlag.Execute);
}
[Fact]
@@ -95,7 +95,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
{
Dictionary options = new Dictionary();
options.Add(RestoreOptionsHelper.ReplaceDatabase, true);
- await VerifyRestore(null, databaseNameToRestoreFrom, true, true, testDb.DatabaseName, null, options, (database) =>
+ await VerifyRestore(null, databaseNameToRestoreFrom, true, TaskExecutionModeFlag.ExecuteAndScript, testDb.DatabaseName, null, options, (database) =>
{
return database.Tables.Contains("tb1", "test");
});
@@ -129,14 +129,14 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
await VerifyRestoreMultipleBackupSets(backupFiles, indexToDelete, expectedTable);
}
- private async Task VerifyRestoreMultipleBackupSets(string[] backupFiles, int backupSetIndexToDelete, int[] expectedSelectedIndexes)
+ private async Task VerifyRestoreMultipleBackupSets(string[] backupFiles, int backupSetIndexToDelete, int[] expectedSelectedIndexes, TaskExecutionModeFlag executionMode = TaskExecutionModeFlag.ExecuteAndScript)
{
var testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, null, "RestoreTest");
try
{
string targetDbName = testDb.DatabaseName;
bool canRestore = true;
- var response = await VerifyRestore(backupFiles, null, canRestore, false, targetDbName, null, null);
+ var response = await VerifyRestore(backupFiles, null, canRestore, TaskExecutionModeFlag.None, targetDbName, null, null);
Assert.True(response.BackupSetsToRestore.Count() >= 2);
var allIds = response.BackupSetsToRestore.Select(x => x.Id).ToList();
if (backupSetIndexToDelete >= 0)
@@ -146,20 +146,24 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
string[] selectedIds = allIds.ToArray();
Dictionary options = new Dictionary();
options.Add(RestoreOptionsHelper.ReplaceDatabase, true);
- response = await VerifyRestore(backupFiles, null, canRestore, true, targetDbName, selectedIds, options, (database) =>
+ response = await VerifyRestore(backupFiles, null, canRestore, executionMode, targetDbName, selectedIds, options, (database) =>
{
- bool tablesFound = true;
- for (int i = 0; i < tableNames.Length; i++)
+ if (executionMode.HasFlag(TaskExecutionModeFlag.Execute))
{
- string tableName = tableNames[i];
- if (!database.Tables.Contains(tableName, "test") && expectedSelectedIndexes.Contains(i))
+ bool tablesFound = true;
+ for (int i = 0; i < tableNames.Length; i++)
{
- tablesFound = false;
- break;
+ string tableName = tableNames[i];
+ if (!database.Tables.Contains(tableName, "test") && expectedSelectedIndexes.Contains(i))
+ {
+ tablesFound = false;
+ break;
+ }
}
+ bool numberOfTableCreatedIsCorrect = database.Tables.Count == expectedSelectedIndexes.Length;
+ return numberOfTableCreatedIsCorrect && tablesFound;
}
- bool numberOfTableCreatedIsCorrect = database.Tables.Count == expectedSelectedIndexes.Length;
- return numberOfTableCreatedIsCorrect && tablesFound;
+ return true;
});
for (int i = 0; i < response.BackupSetsToRestore.Count(); i++)
@@ -190,7 +194,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
Dictionary options = new Dictionary();
options.Add(RestoreOptionsHelper.ReplaceDatabase, true);
- await VerifyRestore(new string[] { fullBackupFilePath }, null, canRestore, true, testDb.DatabaseName, null, options);
+ await VerifyRestore(new string[] { fullBackupFilePath }, null, canRestore, TaskExecutionModeFlag.ExecuteAndScript, testDb.DatabaseName, null, options);
}
finally
{
@@ -212,7 +216,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
await VerifyBackupFileCreated();
bool canRestore = true;
- await VerifyRestore(new string[] { fullBackupFilePath }, null, canRestore, false, testDb.DatabaseName, null, null);
+ await VerifyRestore(new string[] { fullBackupFilePath }, null, canRestore, TaskExecutionModeFlag.None, testDb.DatabaseName, null, null);
}
finally
{
@@ -229,7 +233,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
string[] backupFileNames = new string[] { "FullBackup.bak", "DiffBackup.bak" };
bool canRestore = true;
- var response = await VerifyRestore(backupFileNames, null, canRestore, false, "RestoredFromTwoBackupFile");
+ var response = await VerifyRestore(backupFileNames, null, canRestore, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile");
Assert.True(response.BackupSetsToRestore.Count() == 2);
}
@@ -239,13 +243,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
string[] backupFileNames = new string[] { "FullBackup.bak", "DiffBackup.bak" };
bool canRestore = true;
- var response = await VerifyRestore(backupFileNames, null, canRestore, false, "RestoredFromTwoBackupFile");
+ var response = await VerifyRestore(backupFileNames, null, canRestore, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile");
Assert.True(response.BackupSetsToRestore.Count() == 2);
var fileInfo = response.BackupSetsToRestore.FirstOrDefault(x => x.GetPropertyValueAsString(BackupSetInfo.BackupTypePropertyName) != RestoreConstants.TypeFull);
if(fileInfo != null)
{
var selectedBackupSets = new string[] { fileInfo.Id };
- await VerifyRestore(backupFileNames, null, true, false, "RestoredFromTwoBackupFile", selectedBackupSets);
+ await VerifyRestore(backupFileNames, null, true, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile", selectedBackupSets);
}
}
@@ -255,13 +259,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
string[] backupFileNames = new string[] { "FullBackup.bak", "DiffBackup.bak" };
bool canRestore = true;
- var response = await VerifyRestore(backupFileNames, null, canRestore, false, "RestoredFromTwoBackupFile");
+ var response = await VerifyRestore(backupFileNames, null, canRestore, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile");
Assert.True(response.BackupSetsToRestore.Count() == 2);
var fileInfo = response.BackupSetsToRestore.FirstOrDefault(x => x.GetPropertyValueAsString(BackupSetInfo.BackupTypePropertyName) == RestoreConstants.TypeFull);
if (fileInfo != null)
{
var selectedBackupSets = new string[] { fileInfo.Id };
- await VerifyRestore(backupFileNames, null, true, false, "RestoredFromTwoBackupFile2", selectedBackupSets);
+ await VerifyRestore(backupFileNames, null, true, TaskExecutionModeFlag.None, "RestoredFromTwoBackupFile2", selectedBackupSets);
}
}
@@ -272,7 +276,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
string backupFileName = fullBackupFilePath;
bool canRestore = true;
- var restorePlan = await VerifyRestore(backupFileName, canRestore, true);
+ var restorePlan = await VerifyRestore(backupFileName, canRestore, TaskExecutionModeFlag.Execute);
Assert.NotNull(restorePlan.BackupSetsToRestore);
}
@@ -283,7 +287,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
string backupFileName = fullBackupFilePath;
bool canRestore = true;
- var restorePlan = await VerifyRestore(backupFileName, canRestore, true, "NewRestoredDatabase");
+ var restorePlan = await VerifyRestore(backupFileName, canRestore, TaskExecutionModeFlag.ExecuteAndScript, "NewRestoredDatabase");
}
[Fact]
@@ -414,16 +418,16 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", dropDatabaseQuery);
}
- private async Task VerifyRestore(string backupFileName, bool canRestore, bool execute = false, string targetDatabase = null)
+ private async Task VerifyRestore(string backupFileName, bool canRestore, TaskExecutionModeFlag executionMode = TaskExecutionModeFlag.None, string targetDatabase = null)
{
- return await VerifyRestore(new string[] { backupFileName }, null, canRestore, execute, targetDatabase);
+ return await VerifyRestore(new string[] { backupFileName }, null, canRestore, executionMode, targetDatabase);
}
private async Task VerifyRestore(
string[] backupFileNames = null,
string sourceDbName = null,
- bool canRestore = true,
- bool execute = false,
+ bool canRestore = true,
+ TaskExecutionModeFlag executionMode = TaskExecutionModeFlag.None,
string targetDatabase = null,
string[] selectedBackupSets = null,
Dictionary options = null,
@@ -496,7 +500,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
Assert.NotNull(response.PlanDetails[RestoreOptionsHelper.StandbyFile]);
Assert.NotNull(response.PlanDetails[RestoreOptionsHelper.StandbyFile]);
- if(execute)
+ if(executionMode != TaskExecutionModeFlag.None)
{
try
{
@@ -504,21 +508,28 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
restoreDataObject = service.CreateRestoreDatabaseTaskDataObject(request);
Assert.Equal(response.SessionId, restoreDataObject.SessionId);
request.RelocateDbFiles = !restoreDataObject.DbFilesLocationAreValid();
- service.ExecuteRestore(restoreDataObject);
- Assert.True(restoreDataObject.Server.Databases.Contains(targetDatabase));
+ restoreDataObject.Execute((TaskExecutionMode)Enum.Parse(typeof(TaskExecutionMode), executionMode.ToString()));
- if (verifyDatabase != null)
+ if (executionMode.HasFlag(TaskExecutionModeFlag.Execute))
{
- Assert.True(verifyDatabase(restoreDataObject.Server.Databases[targetDatabase]));
- }
+ Assert.True(restoreDataObject.Server.Databases.Contains(targetDatabase));
- //To verify the backupset that are restored, verifying the database is a better options.
- //Some tests still verify the number of backup sets that are executed which in some cases can be less than the selected list
- if (verifyDatabase == null && selectedBackupSets != null)
+ if (verifyDatabase != null)
+ {
+ Assert.True(verifyDatabase(restoreDataObject.Server.Databases[targetDatabase]));
+ }
+
+ //To verify the backupset that are restored, verifying the database is a better options.
+ //Some tests still verify the number of backup sets that are executed which in some cases can be less than the selected list
+ if (verifyDatabase == null && selectedBackupSets != null)
+ {
+ Assert.Equal(selectedBackupSets.Count(), restoreDataObject.RestorePlanToExecute.RestoreOperations.Count());
+ }
+ }
+ if(executionMode.HasFlag(TaskExecutionModeFlag.Script))
{
- Assert.Equal(selectedBackupSets.Count(), restoreDataObject.RestorePlanToExecute.RestoreOperations.Count());
+ Assert.False(string.IsNullOrEmpty(restoreDataObject.ScriptContent));
}
-
}
catch(Exception ex)
{
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/TaskExecutionModeFlag.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/TaskExecutionModeFlag.cs
new file mode 100644
index 00000000..49e7b3cb
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/TaskExecutionModeFlag.cs
@@ -0,0 +1,32 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+
+using System;
+
+namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery
+{
+ [Flags]
+ public enum TaskExecutionModeFlag
+ {
+ None = 0x00,
+
+ ///
+ /// Execute task
+ ///
+ Execute = 0x01,
+
+ ///
+ /// Script task
+ ///
+ Script = 0x02,
+
+ ///
+ /// Execute and script task
+ /// Needed for tasks that will show the script when execution completes
+ ///
+ ExecuteAndScript = Execute | Script
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/LanguageServiceTests.cs
index d6ef38e3..8304782f 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/LanguageServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/LanguageServiceTests.cs
@@ -63,12 +63,15 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.LanguageServer
[Fact]
public void PrepopulateCommonMetadata()
{
- var result = LiveConnectionHelper.InitLiveConnectionInfo();
- var connInfo = result.ConnectionInfo;
+ using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
+ {
+ var result = LiveConnectionHelper.InitLiveConnectionInfo("master", queryTempFile.FilePath);
+ var connInfo = result.ConnectionInfo;
- ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true };
+ ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true };
- LanguageService.Instance.PrepopulateCommonMetadata(connInfo, scriptInfo, null);
+ LanguageService.Instance.PrepopulateCommonMetadata(connInfo, scriptInfo, null);
+ }
}
// This test currently requires a live database connection to initialize
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/RequstParamStub.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/RequstParamStub.cs
new file mode 100644
index 00000000..82f0c6d4
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/RequstParamStub.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.SqlTools.ServiceLayer.TaskServices;
+using Microsoft.SqlTools.ServiceLayer.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TaskServices
+{
+ public class RequstParamStub : IScriptableRequestParams
+ {
+ public TaskExecutionMode TaskExecutionMode { get; set; }
+ public string OwnerUri { get; set; }
+ public string DatabaseName { get; set; }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/SmoScriptableTaskOperationStub.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/SmoScriptableTaskOperationStub.cs
new file mode 100644
index 00000000..5a3c6457
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/SmoScriptableTaskOperationStub.cs
@@ -0,0 +1,52 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlServer.Management.Smo;
+using Microsoft.SqlTools.ServiceLayer.TaskServices;
+
+namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TaskServices
+{
+ public class SmoScriptableTaskOperationStub : SmoScriptableTaskOperation
+ {
+ private Server server;
+ public string DatabaseName { get; set; }
+ public SmoScriptableTaskOperationStub(Server server)
+ {
+ this.server = server;
+ }
+ public override string ErrorMessage
+ {
+ get
+ {
+ return string.Empty;
+ }
+ }
+
+ public override Server Server
+ {
+ get
+ {
+ return server;
+ }
+ }
+
+ public override void Cancel()
+ {
+ }
+
+ public string TableName { get; set; }
+
+ public override void Execute()
+ {
+ var database = server.Databases[DatabaseName];
+ Table table = new Table(database, TableName, "test");
+ Column column = new Column(table, "c1");
+ column.DataType = DataType.Int;
+ table.Columns.Add(column);
+ database.Tables.Add(table);
+ table.Create();
+ }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/TaskServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/TaskServiceTests.cs
new file mode 100644
index 00000000..82fc60eb
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/TaskServices/TaskServiceTests.cs
@@ -0,0 +1,156 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Data.Common;
+using System.Data.SqlClient;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.SqlServer.Management.Common;
+using Microsoft.SqlServer.Management.Smo;
+using Microsoft.SqlTools.Extensibility;
+using Microsoft.SqlTools.Hosting.Protocol;
+using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
+using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
+using Microsoft.SqlTools.ServiceLayer.TaskServices;
+using Microsoft.SqlTools.ServiceLayer.TaskServices.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Test.Common;
+using Microsoft.SqlTools.ServiceLayer.UnitTests;
+using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
+using Moq;
+using Xunit;
+using static Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility.LiveConnectionHelper;
+
+namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.TaskServices
+{
+ public class TaskServiceTests : ServiceTestBase
+ {
+ private TaskService service;
+ private Mock serviceHostMock;
+
+ public TaskServiceTests()
+ {
+ serviceHostMock = new Mock();
+ service = CreateService();
+ service.InitializeService(serviceHostMock.Object);
+ }
+
+ [Fact]
+ public async Task VerifyTaskExecuteTheQueryGivenExecutionModeExecute()
+ {
+ await VerifyTaskWithExecutionMode(TaskExecutionMode.Execute);
+ }
+
+ [Fact]
+ public async Task VerifyTaskGenerateScriptOnlyGivenExecutionModeScript()
+ {
+ await VerifyTaskWithExecutionMode(TaskExecutionMode.Script);
+ }
+
+ [Fact]
+ public async Task VerifyTaskNotExecuteAndGenerateScriptGivenExecutionModeExecuteAndScript()
+ {
+ await VerifyTaskWithExecutionMode(TaskExecutionMode.ExecuteAndScript);
+ }
+
+ [Fact]
+ public async Task VerifyTaskSendsFailureNotificationGivenInvalidQuery()
+ {
+ await VerifyTaskWithExecutionMode(TaskExecutionMode.ExecuteAndScript, true);
+ }
+
+ public async Task VerifyTaskWithExecutionMode(TaskExecutionMode executionMode, bool makeTaskFail = false)
+ {
+ serviceHostMock.AddEventHandling(TaskStatusChangedNotification.Type, null);
+
+ using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
+ {
+ //To make the task fail don't create the schema so create table fails
+ string query = string.Empty;
+ if (!makeTaskFail)
+ {
+ query = $"CREATE SCHEMA [test]";
+ }
+ SqlTestDb testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, query, "TaskService");
+ try
+ {
+ TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, queryTempFile.FilePath);
+ string taskName = "task name";
+ Server server = CreateServerObject(connectionResult.ConnectionInfo);
+ RequstParamStub requstParam = new RequstParamStub
+ {
+ TaskExecutionMode = executionMode,
+ OwnerUri = queryTempFile.FilePath
+ };
+ SmoScriptableTaskOperationStub taskOperation = new SmoScriptableTaskOperationStub(server);
+ taskOperation.DatabaseName = testDb.DatabaseName;
+ taskOperation.TableName = "newTable";
+ TaskMetadata taskMetadata = TaskMetadata.Create(requstParam, taskName, taskOperation, ConnectionService.Instance);
+ SqlTask sqlTask = service.TaskManager.CreateTask(taskMetadata);
+ Task taskToVerify = sqlTask.RunAsync().ContinueWith(task =>
+ {
+ if (!makeTaskFail)
+ {
+ if (executionMode == TaskExecutionMode.Script || executionMode == TaskExecutionMode.ExecuteAndScript)
+ {
+ serviceHostMock.Verify(x => x.SendEvent(TaskStatusChangedNotification.Type,
+ It.Is(t => !string.IsNullOrEmpty(t.Script))), Times.AtLeastOnce());
+ }
+
+ //Verify if the table created if execution mode includes execute
+ bool expected = executionMode == TaskExecutionMode.Execute || executionMode == TaskExecutionMode.ExecuteAndScript;
+ Server serverToverfiy = CreateServerObject(connectionResult.ConnectionInfo);
+ bool actual = serverToverfiy.Databases[testDb.DatabaseName].Tables.Contains(taskOperation.TableName, "test");
+ Assert.Equal(expected, actual);
+ }
+ else
+ {
+ serviceHostMock.Verify(x => x.SendEvent(TaskStatusChangedNotification.Type,
+ It.Is(t => t.Status == SqlTaskStatus.Failed)), Times.AtLeastOnce());
+ }
+ });
+ await taskToVerify;
+ }
+ finally
+ {
+ testDb.Cleanup();
+ }
+
+ }
+ }
+ protected TaskService CreateService()
+ {
+ CreateServiceProviderWithMinServices();
+
+ // Create the service using the service provider, which will initialize dependencies
+ return ServiceProvider.GetService();
+ }
+
+ protected override RegisteredServiceProvider CreateServiceProviderWithMinServices()
+ {
+ TaskService service = new TaskService();
+ service.TaskManager = new SqlTaskManager();
+ return CreateProvider()
+ .RegisterSingleService(service);
+ }
+
+ private Server CreateServerObject(ConnectionInfo connInfo )
+ {
+ SqlConnection connection = null;
+ DbConnection dbConnection = connInfo.AllConnections.First();
+ ReliableSqlConnection reliableSqlConnection = dbConnection as ReliableSqlConnection;
+ SqlConnection sqlConnection = dbConnection as SqlConnection;
+ if (reliableSqlConnection != null)
+ {
+ connection = reliableSqlConnection.GetUnderlyingConnection();
+ }
+ else if (sqlConnection != null)
+ {
+ connection = sqlConnection;
+ }
+ return new Server(new ServerConnection(connection));
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupOperationStub.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupOperationStub.cs
index 8e80f853..a8c13a10 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupOperationStub.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupOperationStub.cs
@@ -28,6 +28,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.DisasterRecovery
public string ScriptContent { get; set; }
+ public string ErrorMessage
+ {
+ get
+ {
+ return string.Empty;
+ }
+ }
+
+ public SqlTask SqlTask { get; set; }
+
///
/// Initialize
///
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupTests.cs
index 1a3386aa..636ed87a 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/DisasterRecovery/BackupTests.cs
@@ -201,15 +201,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.DisasterRecovery
}
}
- private TaskMetadata CreateTaskMetaData(object data)
+ private TaskMetadata CreateTaskMetaData(IBackupOperation data)
{
TaskMetadata taskMetaData = new TaskMetadata
{
ServerName = "server name",
DatabaseName = "database name",
Name = "backup database",
- IsCancelable = true,
- Data = data
+ TaskOperation = data
};
return taskMetaData;
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs
index 6350c455..cead820d 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs
@@ -124,8 +124,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices
protected override RegisteredServiceProvider CreateServiceProviderWithMinServices()
{
+ TaskService service = new TaskService();
+ service.TaskManager = new SqlTaskManager();
return CreateProvider()
- .RegisterSingleService(new TaskService());
+ .RegisterSingleService(service);
}
}
}