diff --git a/bin/nuget/Microsoft.SqlServer.Smo.140.2.1.nupkg b/bin/nuget/Microsoft.SqlServer.Smo.140.2.1.nupkg deleted file mode 100644 index 3cc8cb5f..00000000 Binary files a/bin/nuget/Microsoft.SqlServer.Smo.140.2.1.nupkg and /dev/null differ diff --git a/bin/nuget/Microsoft.SqlServer.Smo.140.2.2.nupkg b/bin/nuget/Microsoft.SqlServer.Smo.140.2.2.nupkg new file mode 100644 index 00000000..09dfe8f4 Binary files /dev/null and b/bin/nuget/Microsoft.SqlServer.Smo.140.2.2.nupkg differ diff --git a/src/Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj b/src/Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj index da97d5fa..94504624 100644 --- a/src/Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj +++ b/src/Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj @@ -19,7 +19,7 @@ - + diff --git a/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj b/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj index 247b20a3..38432c73 100644 --- a/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj +++ b/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj @@ -14,7 +14,7 @@ - + diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/CommonUtilities.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/CommonUtilities.cs index aa85258b..32377f0f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/CommonUtilities.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/CommonUtilities.cs @@ -340,28 +340,6 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } } - public bool IsDestinationPathValid(string path, ref bool isFolder) - { - Enumerator en = null; - DataTable dt; - Request req = new Request(); - - en = new Enumerator(); - req.Urn = "Server/File[@FullName='" + Urn.EscapeString(path) + "']"; - dt = en.Process(this.sqlConnection, req); - - if (dt.Rows.Count > 0) - { - isFolder = !(Convert.ToBoolean(dt.Rows[0]["IsFile"], System.Globalization.CultureInfo.InvariantCulture)); - return true; - } - else - { - isFolder = false; - return false; - } - } - public string GetMediaNameFromBackupSetId(int backupSetId) { Enumerator en = null; @@ -422,72 +400,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return result; } - public string GetNewPhysicalRestoredFileName(string filePathParam, string dbName, bool isNewDatabase, string type, ref int fileIndex) - { - if (string.IsNullOrEmpty(filePathParam)) - { - return string.Empty; - } - - string result = string.Empty; - string filePath = filePathParam; - int idx = filePath.LastIndexOf('\\'); - string folderPath = filePath.Substring(0,idx); - - string fileName = filePath.Substring(idx + 1); - idx = fileName.LastIndexOf('.'); - string fileExtension = fileName.Substring(idx + 1); - - bool isFolder = true; - bool isValidPath = IsDestinationPathValid(folderPath, ref isFolder); - - if (!isValidPath || !isFolder) - { - SMO.Server server = new SMO.Server(this.sqlConnection); - if (type != RestoreConstants.Log) - { - folderPath = server.Settings.DefaultFile; - if (folderPath.Length == 0) - { - folderPath = server.Information.MasterDBPath; - } - } - else - { - folderPath = server.Settings.DefaultLog; - if (folderPath.Length == 0) - { - folderPath = server.Information.MasterDBLogPath; - } - } - } - else - { - if (!isNewDatabase) - { - return filePathParam; - } - } - - if (!isNewDatabase) - { - result = folderPath + "\\" + dbName + "." + fileExtension; - } - else - { - if (0 != string.Compare(fileExtension, "mdf", StringComparison.OrdinalIgnoreCase)) - { - result = folderPath + "\\" + dbName + "_" + Convert.ToString(fileIndex, System.Globalization.CultureInfo.InvariantCulture) + "." + fileExtension; - fileIndex++; - } - else - { - result = folderPath + "\\" + dbName + "." + fileExtension; - } - } - - return result; - } + // TODO: This is implemented as internal property in SMO. public bool IsLocalPrimaryReplica(string databaseName) @@ -663,29 +576,29 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return dict; }*/ - public void GetBackupSetTypeAndComponent(int numType, ref string backupType, ref string backupComponent) + public static void GetBackupSetTypeAndComponent(BackupSetType backupSetType, out string backupType, out string backupComponent) { - switch (numType) + switch (backupSetType) { - case 1: + case BackupSetType.Database: backupType = RestoreConstants.TypeFull; backupComponent = RestoreConstants.ComponentDatabase; break; - case 2: + case BackupSetType.Differential: backupType = RestoreConstants.TypeTransactionLog; - backupComponent = ""; + backupComponent = RestoreConstants.ComponentDatabase; break; - case 4: + case BackupSetType.FileOrFileGroup: backupType = RestoreConstants.TypeFilegroup; backupComponent = RestoreConstants.ComponentFile; break; - case 5: + case BackupSetType.FileOrFileGroupDifferential: backupType = RestoreConstants.TypeDifferential; backupComponent = RestoreConstants.ComponentDatabase; break; - case 6: - backupType = RestoreConstants.TypeFilegroupDifferential; - backupComponent = RestoreConstants.ComponentFile; + case BackupSetType.Log: + backupType = RestoreConstants.Log; + backupComponent = RestoreConstants.ComponentLog; break; default: backupType = RestoreConstants.NotKnown; diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequest.cs new file mode 100644 index 00000000..33ce0d9e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/Contracts/RestoreRequest.cs @@ -0,0 +1,125 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts +{ + /// + /// Restore request parameters + /// + public class RestoreParams + { + /// + /// The Uri to find the connection to do the restore operations + /// + public string OwnerUri { get; set; } + + /// + /// The backup file path + /// + public string BackupFilePath { get; set; } + + /// + /// Database name to restore from (either the back file path or database name can be used for restore operation, + /// If the backup file is set, the database name will be ignored) + /// + public string DatabaseName { get; set; } + + /// + /// If set to true, the db files will be relocated to default data location in the server + /// + public bool RelocateDbFiles { get; set; } + } + + /// + /// Restore response + /// + public class RestoreResponse + { + /// + /// Indicates if the restore task created successfully + /// + public bool Result { get; set; } + + /// + /// The task id assosiated witht the restore operation + /// + public string TaskId { get; set; } + + + /// + /// Errors occurred while creating the restore operation task + /// + public string ErrorMessage { get; set; } + } + + /// + /// Restore Plan Response + /// + public class RestorePlanResponse + { + /// + /// The backup file path + /// + public string BackupFilePath { get; set; } + + /// + /// Indicates whether the restore operation is supported + /// + public bool CanRestore { get; set; } + + /// + /// Errors occurred while creating restore plan + /// + public string ErrorMessage { get; set; } + + /// + /// The db files included in the backup file + /// + public IEnumerable DbFiles { get; set; } + + /// + /// Server name + /// + public string ServerName { get; set; } + + /// + /// Database name to restore to + /// + public string DatabaseName { get; set; } + + /// + /// Indicates whether relocating the db files is required + /// because the original file paths are not valid in the target server + /// + public bool RelocateFilesNeeded { get; set; } + + /// + /// Default Data folder path in the target server + /// + public string DefaultDataFolder { get; set; } + + /// + /// Default log folder path in the target server + /// + public string DefaultLogFolder { get; set; } + } + + public class RestoreRequest + { + public static readonly + RequestType Type = + RequestType.Create("disasterrecovery/restore"); + } + + public class RestorePlanRequest + { + public static readonly + RequestType Type = + RequestType.Create("disasterrecovery/restoreplan"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryConstants.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryConstants.cs index 352a9ce1..4290a15f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryConstants.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryConstants.cs @@ -42,6 +42,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery public static string TypeFilegroupDifferential = "Filegroup Differential"; public static string ComponentDatabase = "Database"; public static string ComponentFile = "File"; + public static string ComponentLog = "Log"; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs index 15a4fd4d..61f9ab50 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs @@ -13,6 +13,9 @@ using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts; using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.TaskServices; using System.Threading; +using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlServer.Management.Common; namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery { @@ -23,6 +26,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery { private static readonly Lazy instance = new Lazy(() => new DisasterRecoveryService()); private static ConnectionService connectionService = null; + private RestoreDatabaseHelper restoreDatabaseService = new RestoreDatabaseHelper(); /// /// Default, parameterless constructor. @@ -61,12 +65,17 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery /// /// Initializes the service instance /// - public void InitializeService(ServiceHost serviceHost) + public void InitializeService(IProtocolEndpoint serviceHost) { // Get database info serviceHost.SetRequestHandler(BackupConfigInfoRequest.Type, HandleBackupConfigInfoRequest); // Create backup serviceHost.SetRequestHandler(BackupRequest.Type, HandleBackupRequest); + + // Create respore task + serviceHost.SetRequestHandler(RestoreRequest.Type, HandleRestoreRequest); + // Create respore plan + serviceHost.SetRequestHandler(RestorePlanRequest.Type, HandleRestorePlanRequest); } /// @@ -100,6 +109,81 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery await requestContext.SendResult(response); } + /// + /// Handles a restore request + /// + internal async Task HandleRestorePlanRequest( + RestoreParams restoreParams, + RequestContext requestContext) + { + RestorePlanResponse response = new RestorePlanResponse(); + ConnectionInfo connInfo; + bool supported = IsBackupRestoreOperationSupported(restoreParams, out connInfo); + + if (supported && connInfo != null) + { + RestoreDatabaseTaskDataObject restoreDataObject = this.restoreDatabaseService.CreateRestoreDatabaseTaskDataObject(restoreParams); + response = this.restoreDatabaseService.CreateRestorePlanResponse(restoreDataObject); + } + else + { + response.CanRestore = false; + response.ErrorMessage = "Restore is not supported"; //TOOD: have a better error message + } + await requestContext.SendResult(response); + + } + + /// + /// Handles a restore request + /// + internal async Task HandleRestoreRequest( + RestoreParams restoreParams, + RequestContext requestContext) + { + RestoreResponse response = new RestoreResponse(); + ConnectionInfo connInfo; + bool supported = IsBackupRestoreOperationSupported(restoreParams, out connInfo); + + if (supported && connInfo != null) + { + try + { + RestoreDatabaseTaskDataObject restoreDataObject = this.restoreDatabaseService.CreateRestoreDatabaseTaskDataObject(restoreParams); + + if (restoreDataObject != null) + { + // create task metadata + TaskMetadata metadata = new TaskMetadata(); + metadata.ServerName = connInfo.ConnectionDetails.ServerName; + metadata.DatabaseName = connInfo.ConnectionDetails.DatabaseName; + metadata.Name = SR.Backup_TaskName; + metadata.IsCancelable = true; + metadata.Data = restoreDataObject; + + + // create restore task and perform + SqlTask sqlTask = SqlTaskManager.Instance.CreateAndRun(metadata, this.restoreDatabaseService.RestoreTaskAsync, restoreDatabaseService.CancelTaskAsync); + response.TaskId = sqlTask.TaskId.ToString(); + } + else + { + response.ErrorMessage = "Failed to create restore task"; + } + } + catch (Exception ex) + { + response.ErrorMessage = ex.Message; + } + } + else + { + response.ErrorMessage = "Restore database is not supported"; //TOOD: have a better error message + } + + await requestContext.SendResult(response); + } + /// /// Handles a backup request /// @@ -163,6 +247,37 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return null; } + private bool IsBackupRestoreOperationSupported(RestoreParams restoreParams, out ConnectionInfo connectionInfo) + { + SqlConnection sqlConn = null; + try + { + ConnectionInfo connInfo; + DisasterRecoveryService.ConnectionServiceInstance.TryFindConnection( + restoreParams.OwnerUri, + out connInfo); + + if (connInfo != null) + { + sqlConn = GetSqlConnection(connInfo); + if ((sqlConn != null) && !connInfo.IsSqlDW && !connInfo.IsAzure) + { + connectionInfo = connInfo; + return true; + } + } + } + catch + { + if(sqlConn != null) + { + sqlConn.Close(); + } + } + connectionInfo = null; + return false; + } + internal BackupConfigInfo GetBackupConfigInfo(CDataContainer dataContainer, SqlConnection sqlConnection, string databaseName) { BackupOperation backupOperation = new BackupOperation(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/BackupSetInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/BackupSetInfo.cs new file mode 100644 index 00000000..89787946 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/BackupSetInfo.cs @@ -0,0 +1,23 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation +{ + /// + /// Backup set Information + /// + public class BackupSetInfo + { + /// + /// Backup type (Full, Transaction Log, Differential ...) + /// + public string BackupType { get; set; } + + /// + /// Backup component (Database, File, Log ...) + /// + public string BackupComponent { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseService.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseService.cs new file mode 100644 index 00000000..c7f9f329 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseService.cs @@ -0,0 +1,220 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts; +using Microsoft.SqlTools.ServiceLayer.TaskServices; + +namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation +{ + /// + /// Includes method to all restore operations + /// + public class RestoreDatabaseHelper + { + + /// + /// Create a backup task for execution and cancellation + /// + /// + /// + internal async Task RestoreTaskAsync(SqlTask sqlTask) + { + sqlTask.AddMessage(SR.Task_InProgress, SqlTaskStatus.InProgress, true); + RestoreDatabaseTaskDataObject restoreDataObject = sqlTask.TaskMetadata.Data as RestoreDatabaseTaskDataObject; + TaskResult taskResult = null; + + if (restoreDataObject != null) + { + // Create a task to perform backup + return await Task.Factory.StartNew(() => + { + TaskResult result = new TaskResult(); + try + { + ExecuteRestore(restoreDataObject); + result.TaskStatus = SqlTaskStatus.Succeeded; + } + catch (Exception ex) + { + result.TaskStatus = SqlTaskStatus.Failed; + result.ErrorMessage = ex.Message; + if (ex.InnerException != null) + { + result.ErrorMessage += System.Environment.NewLine + ex.InnerException.Message; + } + } + return result; + }); + } + else + { + taskResult = new TaskResult(); + taskResult.TaskStatus = SqlTaskStatus.Failed; + } + + return taskResult; + } + + + + /// + /// Async task to cancel restore + /// + public async Task CancelTaskAsync(SqlTask sqlTask) + { + RestoreDatabaseTaskDataObject restoreDataObject = sqlTask.TaskMetadata.Data as RestoreDatabaseTaskDataObject; + TaskResult taskResult = null; + + + if (restoreDataObject != null && restoreDataObject.IsValid) + { + // Create a task for backup cancellation request + return await Task.Factory.StartNew(() => + { + + foreach (Restore restore in restoreDataObject.RestorePlan.RestoreOperations) + { + restore.Abort(); + } + + + return new TaskResult + { + TaskStatus = SqlTaskStatus.Canceled + }; + + }); + } + else + { + taskResult = new TaskResult(); + taskResult.TaskStatus = SqlTaskStatus.Failed; + } + + return taskResult; + } + + /// + /// Creates a restore plan, The result includes the information about the backup set, + /// the files and the database to restore to + /// + /// Restore requests + /// Restore plan + public RestorePlanResponse CreateRestorePlanResponse(RestoreDatabaseTaskDataObject restoreDataObject) + { + RestorePlanResponse response = new RestorePlanResponse() + { + DatabaseName = restoreDataObject.RestoreParams.DatabaseName + }; + if (restoreDataObject != null && restoreDataObject.IsValid) + { + UpdateRestorePlan(restoreDataObject); + + if (restoreDataObject != null && restoreDataObject.IsValid) + { + response.DatabaseName = restoreDataObject.RestorePlanner.DatabaseName; + response.DbFiles = restoreDataObject.DbFiles.Select(x => x.PhysicalName); + response.CanRestore = CanRestore(restoreDataObject); + + if (!response.CanRestore) + { + response.ErrorMessage = "Backup not supported."; + } + + response.RelocateFilesNeeded = !restoreDataObject.DbFilesLocationAreValid(); + response.DefaultDataFolder = restoreDataObject.DefaultDataFileFolder; + response.DefaultLogFolder = restoreDataObject.DefaultLogFileFolder; + } + else + { + response.ErrorMessage = "Failed to create restore plan"; + response.CanRestore = false; + } + } + else + { + response.ErrorMessage = "Failed to create restore database plan"; + } + return response; + + } + + /// + /// Returns true if the restoring the restoreDataObject is supported in the service + /// + private static bool CanRestore(RestoreDatabaseTaskDataObject restoreDataObject) + { + if (restoreDataObject != null) + { + var backupTypes = restoreDataObject.GetBackupSetInfo(); + return backupTypes.Any(x => x.BackupType.StartsWith(RestoreConstants.TypeFull)); + } + return false; + } + + /// + /// Creates anew restore task object to do the restore operations + /// + /// Restore request parameters + /// Restore task object + public RestoreDatabaseTaskDataObject CreateRestoreDatabaseTaskDataObject(RestoreParams restoreParams) + { + ConnectionInfo connInfo; + DisasterRecoveryService.ConnectionServiceInstance.TryFindConnection( + restoreParams.OwnerUri, + out connInfo); + + if (connInfo != null) + { + Server server = new Server(new ServerConnection(connInfo.ConnectionDetails.ServerName)); + + RestoreDatabaseTaskDataObject restoreDataObject = new RestoreDatabaseTaskDataObject(server, restoreParams.DatabaseName); + restoreDataObject.RestoreParams = restoreParams; + return restoreDataObject; + } + return null; + } + + /// + /// Create a restore data object that includes the plan to do the restore operation + /// + /// + /// + private void UpdateRestorePlan(RestoreDatabaseTaskDataObject restoreDataObject) + { + // Server server = new Server(new ServerConnection(connInfo.ConnectionDetails.ServerName)); + //RestoreDatabaseTaskDataObject restoreDataObject = new RestoreDatabaseTaskDataObject(server, requestParam.DatabaseName); + if (!string.IsNullOrEmpty(restoreDataObject.RestoreParams.BackupFilePath)) + { + restoreDataObject.AddFile(restoreDataObject.RestoreParams.BackupFilePath); + } + restoreDataObject.RestorePlanner.ReadHeaderFromMedia = !string.IsNullOrEmpty(restoreDataObject.RestoreParams.BackupFilePath); + var dbNames = restoreDataObject.GetSourceDbNames(); + string dbName = dbNames.First(); + restoreDataObject.RestorePlanner.DatabaseName = dbName; + restoreDataObject.UpdateRestorePlan(restoreDataObject.RestoreParams.RelocateDbFiles); + } + + /// + /// Executes the restore operation + /// + /// + public void ExecuteRestore(RestoreDatabaseTaskDataObject restoreDataObject) + { + UpdateRestorePlan(restoreDataObject); + + if (restoreDataObject != null && CanRestore(restoreDataObject)) + { + restoreDataObject.RestorePlan.Execute(); + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs new file mode 100644 index 00000000..f1ea2067 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseTaskDataObject.cs @@ -0,0 +1,816 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation +{ + /// + /// Includes the plan with all the data required to do a restore operation on server + /// + public class RestoreDatabaseTaskDataObject + { + public RestoreDatabaseTaskDataObject(Server server, String databaseName) + { + this.Server = server; + this.Util = new RestoreUtil(server); + restorePlanner = new DatabaseRestorePlanner(server); + + if (String.IsNullOrEmpty(databaseName)) + { + this.restorePlanner = new DatabaseRestorePlanner(server); + } + else + { + this.restorePlanner = new DatabaseRestorePlanner(server, databaseName); + this.targetDbName = databaseName; + } + + this.restorePlanner.TailLogBackupFile = this.Util.GetDefaultTailLogbackupFile(databaseName); + this.restoreOptions = new RestoreOptions(); + //the server will send events in intervals of 5 percent + this.restoreOptions.PercentCompleteNotification = 5; + } + + public bool IsValid + { + get + { + return this.Server != null && this.RestorePlanner != null; + } + } + + public RestoreParams RestoreParams { get; set; } + + /// + /// Database names includes in the restore plan + /// + /// + public List GetSourceDbNames() + { + return Util.GetSourceDbNames(this.restorePlanner.BackupMediaList, this.CredentialName); + } + + /// + /// Current sqlserver instance + /// + public Server Server; + + /// + /// Recent exception that was thrown + /// Displayed at the top of the dialog + /// + public Exception ActiveException { get; set; } + + public Exception CreateOrUpdateRestorePlanException { get; set; } + + /// + /// Add a backup file to restore plan media list + /// + /// + public void AddFile(string filePath) + { + this.RestorePlanner.BackupMediaList.Add(new BackupDeviceItem + { + DeviceType = DeviceType.File, + Name = filePath + }); + } + + public RestoreUtil Util { get; set; } + + private DatabaseRestorePlanner restorePlanner; + + /// + /// SMO database restore planner used to create a restore plan + /// + public DatabaseRestorePlanner RestorePlanner + { + get { return restorePlanner; } + } + + private string tailLogBackupFile; + private bool planUpdateRequired = false; + + /// + /// File to backup tail log before doing the restore + /// + public string TailLogBackupFile + { + get { return tailLogBackupFile; } + set + { + if (tailLogBackupFile == null || !tailLogBackupFile.Equals(value)) + { + this.RestorePlanner.TailLogBackupFile = value; + this.planUpdateRequired = true; + this.tailLogBackupFile = value; + } + } + } + + private RestoreOptions restoreOptions; + + public RestoreOptions RestoreOptions + { + get { return restoreOptions; } + } + + private string dataFilesFolder = string.Empty; + + /// + /// Folder for all data files when relocate all files option is used + /// + public string DataFilesFolder + { + get { return this.dataFilesFolder; } + set + { + if (this.dataFilesFolder == null || !this.dataFilesFolder.Equals(value)) + { + try + { + Uri pathUri; + bool fUriCreated = Uri.TryCreate(value, UriKind.Absolute, out pathUri); + + if (fUriCreated && pathUri.Scheme == "https") + { + this.dataFilesFolder = value; + } + else + { + this.dataFilesFolder = PathWrapper.GetDirectoryName(value); + } + if (string.IsNullOrEmpty(this.dataFilesFolder)) + { + this.dataFilesFolder = this.Server.DefaultFile; + } + } + catch (Exception ex) + { + this.ActiveException = ex; + } + + this.RelocateDbFiles(); + } + } + } + + private string logFilesFolder = string.Empty; + + /// + /// Folder for all log files when relocate all files option is used + /// + public string LogFilesFolder + { + get { return this.logFilesFolder; } + set + { + if (this.logFilesFolder == null || !this.logFilesFolder.Equals(value)) + { + try + { + Uri pathUri; + bool fUriCreated = Uri.TryCreate(value, UriKind.Absolute, out pathUri); + + if (fUriCreated && pathUri.Scheme == "https") + { + this.logFilesFolder = value; + } + else + { + this.logFilesFolder = PathWrapper.GetDirectoryName(value); + } + if (string.IsNullOrEmpty(this.logFilesFolder)) + { + this.logFilesFolder = Server.DefaultLog; + } + } + catch (Exception ex) + { + this.ActiveException = ex; + } + this.RelocateDbFiles(); + } + } + } + + /// + /// Gets or sets a value indicating whether [prompt before each backup]. + /// + /// + /// true if [prompt before each backup]; otherwise, false. + /// + public bool PromptBeforeEachBackup { get; set; } + + private void RelocateDbFiles() + { + try + { + foreach (DbFile dbFile in this.DbFiles) + { + string fileName = this.GetTargetDbFilePhysicalName(dbFile.PhysicalName); + if (!dbFile.DbFileType.Equals("Log")) + { + if (!string.IsNullOrEmpty(this.dataFilesFolder)) + { + dbFile.PhysicalNameRelocate = PathWrapper.Combine(this.dataFilesFolder, fileName); + } + else + { + dbFile.PhysicalNameRelocate = fileName; + } + } + else + { + if (!string.IsNullOrEmpty(this.logFilesFolder)) + { + dbFile.PhysicalNameRelocate = PathWrapper.Combine(this.logFilesFolder, fileName); + } + else + { + dbFile.PhysicalNameRelocate = fileName; + } + } + } + } + catch (Exception ex) + { + this.ActiveException = ex; + } + } + + private List dbFiles = new List(); + + /// + /// List of files of the source database or in the backup file + /// + public List DbFiles + { + get { return dbFiles; } + } + + internal RestorePlan restorePlan; + + /// + /// Restore plan to do the restore + /// + public RestorePlan RestorePlan + { + get + { + if (this.restorePlan == null) + { + this.UpdateRestorePlan(false); + } + return this.restorePlan; + } + internal set + { + this.restorePlan = value; + } + } + + public bool[] RestoreSelected; + + /// + /// The database being restored + /// + public string targetDbName = string.Empty; + + /// + /// The database used to restore from + /// + public string sourceDbName = string.Empty; + + /// + /// Gets or sets a value indicating whether [close existing connections]. + /// + /// + /// true if [close existing connections]; otherwise, false. + /// + public bool CloseExistingConnections { get; set; } + + /* + private BackupTimeLine.TimeLineDuration timeLineDuration = BackupTimeLine.TimeLineDuration.Day; + + public BackupTimeLine.TimeLineDuration TimeLineDuration + { + get { return this.timeLineDuration; } + set { this.timeLineDuration = value; } + } + */ + + /// + /// Sql server credential name used to restore from Microsoft Azure url + /// + internal string CredentialName = string.Empty; + + /// + /// Azure container SAS policy + /// + internal string ContainerSharedAccessPolicy = string.Empty; + + /// + /// Gets RestorePlan to perform restore and to script + /// + public RestorePlan GetRestorePlanForExecutionAndScript() + { + this.ActiveException = null; //Clear any existing exceptions as the plan is getting recreated. + //Clear any existing exceptions as new plan is getting recreated. + this.CreateOrUpdateRestorePlanException = null; + bool tailLogBackup = this.RestorePlanner.BackupTailLog; + if (this.planUpdateRequired) + { + this.RestorePlan = this.RestorePlanner.CreateRestorePlan(this.RestoreOptions); + this.UpdateRestoreSelected(); + this.Util.AddCredentialNameForUrlBackupSet(this.RestorePlan, this.CredentialName); + } + RestorePlan rp = new RestorePlan(this.Server); + rp.RestoreAction = RestoreActionType.Database; + if (this.RestorePlan != null) + { + if (this.RestorePlan.TailLogBackupOperation != null && tailLogBackup) + { + rp.TailLogBackupOperation = this.RestorePlan.TailLogBackupOperation; + } + int i = 0; + foreach (Restore res in this.RestorePlan.RestoreOperations) + { + if (this.RestoreSelected[i] == true) + { + rp.RestoreOperations.Add(res); + } + i++; + } + } + this.SetRestorePlanProperties(rp); + return rp; + } + + /// + /// Updates the RestoreSelected Array to hold information about updated Restore Plan + /// + private void UpdateRestoreSelected() + { + int operationsCount = this.RestorePlan.RestoreOperations.Count; + // The given condition will return true only if new backup has been added on database during lifetime of restore dialog. + // This will happen when tail log backup is taken successfully and subsequent restores have failed. + if (operationsCount > this.RestoreSelected.Length) + { + bool[] tempRestoreSel = new bool[this.RestorePlan.RestoreOperations.Count]; + for (int i = 0; i < operationsCount; i++) + { + if (i < RestoreSelected.Length) + { + //Retain all the old values. + tempRestoreSel[i] = RestoreSelected[i]; + } + else + { + //Do not add the newly added backupset into Restore plan by default. + tempRestoreSel[i] = false; + } + } + this.RestoreSelected = tempRestoreSel; + } + } + + /// + /// Returns the physical name for the target Db file. + /// It is the sourceDbName replaced with targetDbName in sourceFilename. + /// If either sourceDbName or TargetDbName is empty, the source Db filename is returned. + /// + /// source DbFile physical location + /// + private string GetTargetDbFilePhysicalName(string sourceDbFilePhysicalLocation) + { + string fileName = Path.GetFileName(sourceDbFilePhysicalLocation); + if (!string.IsNullOrEmpty(this.sourceDbName) && !string.IsNullOrEmpty(this.targetDbName)) + { + string sourceFilename = fileName; + fileName = sourceFilename.Replace(this.sourceDbName, this.targetDbName); + } + return fileName; + } + + public IEnumerable GetBackupSetInfo() + { + List result = new List(); + foreach (Restore restore in RestorePlan.RestoreOperations) + { + BackupSet backupSet = restore.BackupSet; + + String bkSetComponent; + String bkSetType; + CommonUtilities.GetBackupSetTypeAndComponent(backupSet.BackupSetType, out bkSetType, out bkSetComponent); + + if (this.Server.Version.Major > 8 && backupSet.IsCopyOnly) + { + bkSetType += " (Copy Only)"; + } + result.Add(new BackupSetInfo { BackupComponent = bkSetComponent, BackupType = bkSetType }); + } + + return result; + } + + /// + /// Gets the files of the database + /// + public List GetDbFiles() + { + Database db = null; + List ret = new List(); + if (!this.RestorePlanner.ReadHeaderFromMedia) + { + db = this.Server.Databases[this.RestorePlanner.DatabaseName]; + } + if (restorePlan != null && restorePlan.RestoreOperations.Count > 0) + { + if (db != null && db.Status == DatabaseStatus.Normal) + { + ret = this.Util.GetDbFiles(db); + } + else + { + ret = this.Util.GetDbFiles(restorePlan.RestoreOperations[0]); + } + } + return ret; + } + + public string DefaultDataFileFolder + { + get + { + return Util.GetDefaultDataFileFolder(); + } + } + + public string DefaultLogFileFolder + { + get + { + return Util.GetDefaultLogFileFolder(); + } + } + + internal RestorePlan CreateRestorePlan(DatabaseRestorePlanner planner, RestoreOptions restoreOptions) + { + this.CreateOrUpdateRestorePlanException = null; + RestorePlan ret = null; + + try + { + ret = planner.CreateRestorePlan(restoreOptions); + if (ret == null || ret.RestoreOperations.Count == 0) + { + this.ActiveException = planner.GetBackupDeviceReadErrors(); + } + } + catch (Exception ex) + { + this.ActiveException = ex; + this.CreateOrUpdateRestorePlanException = this.ActiveException; + } + finally + { + } + + + return ret; + } + + /// + /// Updates restore plan + /// + public void UpdateRestorePlan(bool relocateAllFiles = false) + { + this.ActiveException = null; //Clear any existing exceptions as the plan is getting recreated. + //Clear any existing exceptions as new plan is getting recreated. + this.CreateOrUpdateRestorePlanException = null; + this.DbFiles.Clear(); + this.planUpdateRequired = false; + this.restorePlan = null; + if (String.IsNullOrEmpty(this.RestorePlanner.DatabaseName)) + { + this.RestorePlan = new RestorePlan(this.Server); + // this.LaunchAzureConnectToStorageDialog(); + this.Util.AddCredentialNameForUrlBackupSet(this.RestorePlan, this.CredentialName); + } + else + { + + this.RestorePlan = this.CreateRestorePlan(this.RestorePlanner, this.RestoreOptions); + this.Util.AddCredentialNameForUrlBackupSet(this.restorePlan, this.CredentialName); + if (this.ActiveException == null) + { + this.dbFiles = this.GetDbFiles(); + if(relocateAllFiles) + { + RelocateDbFiles(); + } + this.SetRestorePlanProperties(this.restorePlan); + } + } + if (this.restorePlan != null) + { + this.RestoreSelected = new bool[this.restorePlan.RestoreOperations.Count]; + for (int i = 0; i < this.restorePlan.RestoreOperations.Count; i++) + { + this.RestoreSelected[i] = true; + } + } + else + { + this.RestorePlan = new RestorePlan(this.Server); + this.Util.AddCredentialNameForUrlBackupSet(this.RestorePlan, this.CredentialName); + this.RestoreSelected = new bool[0]; + } + } + + + /// + /// Determine if restore plan of selected database does have Url + /// + private bool IfRestorePlanHasUrl() + { + return (restorePlan.RestoreOperations.Any( + res => res.BackupSet.BackupMediaSet.BackupMediaList.Any(t => t.MediaType == DeviceType.Url))); + } + + + /// + /// Sets restore plan properties + /// + private void SetRestorePlanProperties(RestorePlan rp) + { + if (rp == null || rp.RestoreOperations.Count < 1) + { + return; + } + rp.SetRestoreOptions(this.RestoreOptions); + rp.CloseExistingConnections = this.CloseExistingConnections; + if (this.targetDbName != null && !this.targetDbName.Equals(string.Empty)) + { + rp.DatabaseName = targetDbName; + } + rp.RestoreOperations[0].RelocateFiles.Clear(); + foreach (DbFile dbFile in this.DbFiles) + { + // For XStore path, we don't want to try the getFullPath. + string newPhysicalPath; + Uri pathUri; + bool fUriCreated = Uri.TryCreate(dbFile.PhysicalNameRelocate, UriKind.Absolute, out pathUri); + if (fUriCreated && pathUri.Scheme == "https") + { + newPhysicalPath = dbFile.PhysicalNameRelocate; + } + else + { + newPhysicalPath = Path.GetFullPath(dbFile.PhysicalNameRelocate); + } + if (!dbFile.PhysicalName.Equals(newPhysicalPath)) + { + RelocateFile relocFile = new RelocateFile(dbFile.LogicalName, dbFile.PhysicalNameRelocate); + rp.RestoreOperations[0].RelocateFiles.Add(relocFile); + } + } + } + + /// + /// Bool indicating whether a tail log backup will be taken + /// + public bool BackupTailLog + { + get + { + return this.RestorePlanner.BackupTailLog; + } + set + { + if (this.RestorePlanner.BackupTailLog != value) + { + this.RestorePlanner.BackupTailLog = value; + this.planUpdateRequired = true; + } + } + } + + /// + /// bool indicating whether the database will be left in restoring state + /// + public bool TailLogWithNoRecovery + { + get + { + return this.RestorePlanner.TailLogWithNoRecovery; + } + set + { + if (this.RestorePlanner.TailLogWithNoRecovery != value) + { + this.RestorePlanner.TailLogWithNoRecovery = value; + this.planUpdateRequired = true; + } + } + } + + public DateTime? CurrentRestorePointInTime + { + get + { + if (this.RestorePlan == null || this.RestorePlan.RestoreOperations.Count == 0 + || this.RestoreSelected.Length == 0 || !this.RestoreSelected[0]) + { + return null; + } + for (int i = this.RestorePlan.RestoreOperations.Count - 1; i >= 0; i--) + { + if (this.RestoreSelected[i]) + { + if (this.RestorePlan.RestoreOperations[i].BackupSet == null + || (this.RestorePlan.RestoreOperations[i].BackupSet.BackupSetType == BackupSetType.Log && + this.RestorePlan.RestoreOperations[i].ToPointInTime != null)) + { + return this.RestorePlanner.RestoreToPointInTime; + } + return this.RestorePlan.RestoreOperations[i].BackupSet.BackupStartDate; + } + } + return null; + } + } + + public void ToggleSelectRestore(int index) + { + RestorePlan rp = this.restorePlan; + if (rp == null || rp.RestoreOperations.Count <= index) + { + return; + } + //the last index - this will include tail-Log restore operation if present + if (index == rp.RestoreOperations.Count - 1) + { + if (this.RestoreSelected[index]) + { + this.RestoreSelected[index] = false; + } + else + { + for (int i = 0; i <= index; i++) + { + this.RestoreSelected[i] = true; + } + } + return; + } + if (index == 0) + { + if (!this.RestoreSelected[index]) + { + this.RestoreSelected[index] = true; + } + else + { + for (int i = index; i < rp.RestoreOperations.Count; i++) + { + this.RestoreSelected[i] = false; + } + } + return; + } + + if (index == 1 && rp.RestoreOperations[index].BackupSet.BackupSetType == BackupSetType.Differential) + { + if (!this.RestoreSelected[index]) + { + this.RestoreSelected[0] = true; + this.RestoreSelected[index] = true; + } + else if (rp.RestoreOperations[2].BackupSet == null) + { + this.RestoreSelected[index] = false; + this.RestoreSelected[2] = false; + } + else if (this.Server.Version.Major < 9 || BackupSet.IsBackupSetsInSequence(rp.RestoreOperations[0].BackupSet, rp.RestoreOperations[2].BackupSet)) + { + this.RestoreSelected[index] = false; + } + else + { + for (int i = index; i < rp.RestoreOperations.Count; i++) + { + this.RestoreSelected[i] = false; + } + } + return; + } + if (rp.RestoreOperations[index].BackupSet.BackupSetType == BackupSetType.Log) + { + if (this.RestoreSelected[index]) + { + for (int i = index; i < rp.RestoreOperations.Count; i++) + { + this.RestoreSelected[i] = false; + } + return; + } + else + { + for (int i = 0; i <= index; i++) + { + this.RestoreSelected[i] = true; + } + return; + } + } + } + + /// + /// Verifies the backup files location. + /// + internal void CheckBackupFilesLocation() + { + if (this.RestorePlan != null) + { + foreach (Restore restore in this.RestorePlan.RestoreOperations) + { + if (restore.BackupSet != null) + { + restore.BackupSet.CheckBackupFilesExistence(); + } + } + } + } + + internal bool DbFilesLocationAreValid() + { + foreach (DbFile dbFile in this.DbFiles) + { + string newPhysicalPath = Path.GetFullPath(dbFile.PhysicalNameRelocate); + if (string.Compare(dbFile.PhysicalName, dbFile.PhysicalNameRelocate, true) != 0) + { + bool isValidFolder = false; + bool isValidPath = Util.IsDestinationPathValid(Path.GetDirectoryName(newPhysicalPath), ref isValidFolder); + if (!(isValidFolder && isValidPath)) + { + return false; + } + } + } + return true; + } + } + + public class RestoreDatabaseRecoveryState + { + public RestoreDatabaseRecoveryState(DatabaseRecoveryState recoveryState) + { + this.RecoveryState = recoveryState; + } + + public DatabaseRecoveryState RecoveryState; + private static string RestoreWithRecovery = "RESTORE WITH RECOVERY"; + private static string RestoreWithNoRecovery = "RESTORE WITH NORECOVERY"; + private static string RestoreWithStandby = "RESTORE WITH STANDBY"; + + public override string ToString() + { + switch (this.RecoveryState) + { + case DatabaseRecoveryState.WithRecovery: + return RestoreDatabaseRecoveryState.RestoreWithRecovery; + case DatabaseRecoveryState.WithNoRecovery: + return RestoreDatabaseRecoveryState.RestoreWithNoRecovery; + case DatabaseRecoveryState.WithStandBy: + return RestoreDatabaseRecoveryState.RestoreWithStandby; + } + return RestoreDatabaseRecoveryState.RestoreWithRecovery; + } + + /* + public string Info() + { + switch (this.RecoveryState) + { + case DatabaseRecoveryState.WithRecovery: + return SR.RestoreWithRecoveryInfo; + case DatabaseRecoveryState.WithNoRecovery: + return SR.RestoreWithNoRecoveryInfo; + case DatabaseRecoveryState.WithStandBy: + return SR.RestoreWithStandbyInfo; + } + return SR.RestoreWithRecoveryInfo; + } + */ + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreUtil.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreUtil.cs new file mode 100644 index 00000000..681a7130 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreUtil.cs @@ -0,0 +1,740 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using Microsoft.Data.Tools.DataSets; +using Microsoft.SqlServer.Management.Sdk.Sfc; +using Microsoft.SqlServer.Management.Smo; + +namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation +{ + public class RestoreUtil + { + public RestoreUtil(Server server) + { + this.server = server; + this.excludedDB = new List { "master", "tempdb" }; + } + + /// + /// Current sql server instance + /// + private readonly Server server; + private readonly IList excludedDB; + + public List GetTargetDbNamesForPageRestore() + { + List databaseNames = new List(); + foreach (Database db in this.server.Databases) + { + if (!this.excludedDB.Contains(db.Name) && + (db.Status == DatabaseStatus.Normal || db.Status == DatabaseStatus.Suspect || db.Status == DatabaseStatus.EmergencyMode) && + db.RecoveryModel == RecoveryModel.Full) + { + databaseNames.Add(db.Name); + } + } + return databaseNames; + } + + public List GetTargetDbNames() + { + List databaseNames = new List(); + foreach (Database db in this.server.Databases) + { + if (!this.excludedDB.Contains(db.Name)) + { + databaseNames.Add(db.Name); + } + } + return databaseNames; + } + + internal DateTime GetServerCurrentDateTime() + { + DateTime dt = DateTime.MinValue; + + //TODO: the code is moved from ssms and used for restore differential backups + //Uncomment when restore operation for differential backups is supported + /* + string query = "SELECT GETDATE()"; + DataSet dataset = this.server.ExecutionManager.ExecuteWithResults(query); + if (dataset != null && dataset.Tables.Count > 0 && dataset.Tables[0].Rows.Count > 0) + { + dt = Convert.ToDateTime(dataset.Tables[0].Rows[0][0], SmoApplication.DefaultCulture); + } + */ + return dt; + } + + //TODO: the code is moved from ssms and used for restore differential backups + //Uncomment when restore operation for differential backups is supported + /* + /// + /// Queries msdb for source database names + /// + public List GetSourceDbNames() + { + List databaseNames = new List(); + Request req = new Request(); + req.Urn = "Server/BackupSet"; + req.Fields = new string[1]; + req.Fields[0] = "DatabaseName"; + req.OrderByList = new OrderBy[1]; + req.OrderByList[0] = new OrderBy(); + req.OrderByList[0].Field = "DatabaseName"; + req.OrderByList[0].Dir = OrderBy.Direction.Asc; + DataTable dt = server.ExecutionManager.GetEnumeratorData(req); + string last = ""; + foreach (DataRow row in dt.Rows) + { + string dbName = Convert.ToString(row["DatabaseName"], System.Globalization.CultureInfo.InvariantCulture); + if (!this.excludedDB.Contains(dbName) && !dbName.Equals(last)) + { + bool found = false; + foreach (string str in databaseNames) + { + if (StrEqual(str, dbName)) + { + found = true; + break; + } + } + if (found == false) + { + databaseNames.Add(dbName); + } + } + last = dbName; + } + return databaseNames; + } + */ + + /// + /// Reads backup file header to get source database names + /// If valid credential name is not provided for URL throws exception while executing T-sql statement + /// + /// List of backup device items + /// Optional Sqlserver credential name to read backup header from URL + public List GetSourceDbNames(ICollection bkdevList, string credential = null) + { + List databaseNames = new List(); + foreach (BackupDeviceItem bkdev in bkdevList) + { + // use the Restore public API to do the Restore Headeronly query + Restore res = new Restore(); + res.CredentialName = credential; + res.Devices.Add(bkdev); + + DataTable dt = res.ReadBackupHeader(this.server); + if (dt != null) + { + foreach (DataRow dr in dt.Rows) + { + if (dr != null && !(dr["DatabaseName"] is DBNull)) + { + string dbName = (string)dr["DatabaseName"]; + bool found = false; + foreach (string str in databaseNames) + { + if (StringComparer.OrdinalIgnoreCase.Compare(str, dbName) == 0) + { + found = true; + break; + } + } + if (found == false) + { + databaseNames.Add(dbName); + } + } + } + } + } + return databaseNames; + } + + public string GetNewPhysicalRestoredFileName(string filePathParam, string dbName, bool isNewDatabase, string type, ref int fileIndex) + { + if (string.IsNullOrEmpty(filePathParam)) + { + return string.Empty; + } + + string result = string.Empty; + string filePath = filePathParam; + int idx = filePath.LastIndexOf('\\'); + string folderPath = filePath.Substring(0, idx); + + string fileName = filePath.Substring(idx + 1); + idx = fileName.LastIndexOf('.'); + string fileExtension = fileName.Substring(idx + 1); + + bool isFolder = true; + bool isValidPath = IsDestinationPathValid(folderPath, ref isFolder); + + if (!isValidPath || !isFolder) + { + if (type != RestoreConstants.Log) + { + folderPath = server.Settings.DefaultFile; + if (folderPath.Length == 0) + { + folderPath = server.Information.MasterDBPath; + } + } + else + { + folderPath = server.Settings.DefaultLog; + if (folderPath.Length == 0) + { + folderPath = server.Information.MasterDBLogPath; + } + } + } + else + { + if (!isNewDatabase) + { + return filePathParam; + } + } + + if (!isNewDatabase) + { + result = folderPath + "\\" + dbName + "." + fileExtension; + } + else + { + if (0 != string.Compare(fileExtension, "mdf", StringComparison.OrdinalIgnoreCase)) + { + result = folderPath + "\\" + dbName + "_" + Convert.ToString(fileIndex, System.Globalization.CultureInfo.InvariantCulture) + "." + fileExtension; + fileIndex++; + } + else + { + result = folderPath + "\\" + dbName + "." + fileExtension; + } + } + + return result; + } + + public bool IsDestinationPathValid(string path, ref bool isFolder) + { + Enumerator en = null; + DataTable dt; + Request req = new Request(); + + en = new Enumerator(); + req.Urn = "Server/File[@FullName='" + Urn.EscapeString(path) + "']"; + dt = en.Process(this.server.ConnectionContext.SqlConnectionObject, req); + + if (dt.Rows.Count > 0) + { + isFolder = !(Convert.ToBoolean(dt.Rows[0]["IsFile"], System.Globalization.CultureInfo.InvariantCulture)); + return true; + } + else + { + isFolder = false; + return false; + } + } + + /// + /// Returns a list of database files + /// + /// SMO database + /// a list of database files + public List GetDbFiles(Database db) + { + List ret = new List(); + if (db == null) + { + return ret; + } + char fileType = '\0'; + foreach (FileGroup fg in db.FileGroups) + { + if ((fg.FileGroupType == FileGroupType.FileStreamDataFileGroup) || (fg.FileGroupType == FileGroupType.MemoryOptimizedDataFileGroup)) + { + fileType = DbFile.FileStreamFileType; + } + else + { + fileType = DbFile.RowFileType; + } + foreach (DataFile f in fg.Files) + { + DbFile dbFile = new DbFile(f.Name, fileType, f.FileName); + ret.Add(dbFile); + } + } + foreach (LogFile f in db.LogFiles) + { + DbFile dbFile = new DbFile(f.Name, DbFile.LogFileType, f.FileName); + ret.Add(dbFile); + } + return ret; + } + + //TODO: the code is moved from ssms and used for other typs of restore operation + //Uncomment when restore operation for those types are supported + /* + public List GetDbFiles(BackupSet bkSet) + { + List ret = new List(); + if (bkSet == null || bkSet.BackupMediaSet == null || bkSet.BackupMediaSet.BackupMediaList.Count() < 1) + { + return ret; + } + DataSet dataset = bkSet.FileList; + if (dataset != null && dataset.Tables.Count > 0) + { + string logicalName = null; + string physicalName = null; + char type = '\0'; + foreach (DataRow dr in dataset.Tables[0].Rows) + { + if (!(dr["LogicalName"] is DBNull)) + { + logicalName = (string)dr["LogicalName"]; + } + if (!(dr["PhysicalName"] is DBNull)) + { + physicalName = (string)dr["PhysicalName"]; + } + if (!(dr["Type"] is DBNull)) + { + // The data type of Type in a list obtained from RESTORE FILELISTONLY is char(1). + string temp = (string)dr["Type"]; + if (!String.IsNullOrEmpty(temp)) + { + type = temp[0]; + } + } + if (!String.IsNullOrEmpty(logicalName) && !String.IsNullOrEmpty(physicalName) && (type != '\0')) + { + DbFile dbFile = new DbFile(logicalName, type, physicalName); + ret.Add(dbFile); + } + } + } + return ret; + } + */ + + /// + /// Returns a list of database files in all the backup devices in the Restore object + /// + public List GetDbFiles(Restore restore) + { + List ret = new List(); + if (restore == null || restore.Devices == null || restore.Devices.Count < 1) + { + return ret; + } + // Using the Restore public API to do the Restore FilelistOnly + Restore res = new Restore(); + res.CredentialName = restore.CredentialName; + res.Devices.Add(restore.Devices[0]); + res.FileNumber = restore.FileNumber; + DataTable datatable = res.ReadFileList(this.server); + if (datatable != null && datatable.Rows.Count > 0) + { + string logicalName = null; + string physicalName = null; + char type = '\0'; + foreach (DataRow dr in datatable.Rows) + { + if (!(dr["LogicalName"] is DBNull)) + { + logicalName = (string)dr["LogicalName"]; + } + if (!(dr["PhysicalName"] is DBNull)) + { + physicalName = (string)dr["PhysicalName"]; + } + if (!(dr["Type"] is DBNull)) + { + // The data type of Type in a list obtained from RESTORE FILELISTONLY is char(1). + string temp = (string)dr["Type"]; + if (!String.IsNullOrEmpty(temp)) + { + type = temp[0]; + } + } + if (!String.IsNullOrEmpty(logicalName) && !String.IsNullOrEmpty(physicalName) && (type != '\0')) + { + DbFile dbFile = new DbFile(logicalName, type, physicalName); + ret.Add(dbFile); + } + } + } + return ret; + } + + /// + /// Set credential name in the restore objects which have a backup set in Microsoft Azure + /// From sql16, default credential is SAS credential so no explict credential needed for restore object. + /// + /// Restore plan created for the restore operation + /// Sql server credential name + public void AddCredentialNameForUrlBackupSet(RestorePlan restorePlan, string credentialName) + { + if (string.IsNullOrEmpty(credentialName) || restorePlan == null || restorePlan.RestoreOperations == null) + { + return; + } + if (restorePlan.Server.VersionMajor >= 13) // for sql16, default backup/restore URL will use SAS + { + return; + } + // If any of the backup media in the restore object is in URL, we assign the credential name to the CredentialName property of the Restore object + foreach (Restore res in restorePlan.RestoreOperations) + { + + if (res.BackupSet != null && res.BackupSet.BackupMediaSet != null && res.BackupSet.BackupMediaSet.BackupMediaList != null) + { + foreach (BackupMedia bkMedia in res.BackupSet.BackupMediaSet.BackupMediaList) + { + if (bkMedia != null && bkMedia.MediaType == DeviceType.Url) + { + res.CredentialName = credentialName; + break; + } + } + } + + if (res.Devices != null) + { + foreach (BackupDeviceItem bkDevice in res.Devices) + { + if (bkDevice.DeviceType == DeviceType.Url) + { + res.CredentialName = credentialName; + break; + } + } + } + } + // If the backup file to which the tail log is going to be backed up is a file in Microsoft Azure, + // we assign the credential name to the Credential Name property of the Backup object + if (restorePlan.TailLogBackupOperation != null && restorePlan.TailLogBackupOperation.Devices != null) + { + foreach (BackupDeviceItem bkdevItem in restorePlan.TailLogBackupOperation.Devices) + { + if (bkdevItem != null && bkdevItem.DeviceType == DeviceType.Url) + { + restorePlan.TailLogBackupOperation.CredentialName = credentialName; + break; + } + } + } + } + + internal string GetDefaultDataFileFolder() + { + string ret = this.server.Settings.DefaultFile; + if (string.IsNullOrEmpty(ret)) + { + ret = this.server.Information.MasterDBPath; + } + + ret = ret.TrimEnd(server.PathSeparator[0]); + return ret; + } + + internal string GetDefaultLogFileFolder() + { + string ret = this.server.Settings.DefaultLog; + if (string.IsNullOrEmpty(ret)) + { + ret = this.server.Information.MasterDBLogPath; + } + + ret = ret.TrimEnd(server.PathSeparator[0]); + return ret; + } + + internal string GetDefaultBackupFolder() + { + string ret = this.server.Settings.BackupDirectory; + ret = ret.TrimEnd(server.PathSeparator[0]); + return ret; + } + + internal string GetDefaultTailLogbackupFile(string databaseName) + { + if (string.IsNullOrEmpty(databaseName)) + { + return string.Empty; + } + var folderpath = GetDefaultBackupFolder(); + var filename = SanitizeFileName(databaseName) + "_LogBackup_" + GetServerCurrentDateTime().ToString("yyyy-MM-dd_HH-mm-ss") + ".bak"; + return PathWrapper.Combine(folderpath, filename); + } + + + /// + /// Returns a default location for tail log backup + /// If the first backup media is from Microsoft Azure, a Microsoft Azure url for the Tail log backup file is returned + /// + internal string GetDefaultTailLogbackupFile(string databaseName, RestorePlan restorePlan) + { + if (string.IsNullOrEmpty(databaseName) || restorePlan == null) + { + return string.Empty; + } + if (restorePlan.TailLogBackupOperation != null && restorePlan.TailLogBackupOperation.Devices != null) + { + restorePlan.TailLogBackupOperation.Devices.Clear(); + } + string folderpath = string.Empty; + BackupMedia firstBackupMedia = this.GetFirstBackupMedia(restorePlan); + string filename = this.SanitizeFileName(databaseName) + "_LogBackup_" + this.GetServerCurrentDateTime().ToString("yyyy-MM-dd_HH-mm-ss") + ".bak"; + if (firstBackupMedia != null && firstBackupMedia.MediaType == DeviceType.Url) + { + // the uri will use the same container as the container of the first backup media + Uri uri; + if (Uri.TryCreate(firstBackupMedia.MediaName, UriKind.Absolute, out uri)) + { + UriBuilder uriBuilder = new UriBuilder(); + uriBuilder.Scheme = uri.Scheme; + uriBuilder.Host = uri.Host; + if (uri.AbsolutePath.Length > 0) + { + string[] parts = uri.AbsolutePath.Split('/'); + string newPath = string.Join("/", parts, 0, parts.Length - 1); + if (newPath.EndsWith("/")) + { + newPath = newPath.Substring(0, newPath.Length - 1); + } + uriBuilder.Host = uriBuilder.Host + newPath; + } + uriBuilder.Path = filename; + string urlFilename = uriBuilder.Uri.AbsoluteUri; + if (restorePlan.TailLogBackupOperation != null && restorePlan.TailLogBackupOperation.Devices != null) + { + restorePlan.TailLogBackupOperation.Devices.Add(new BackupDeviceItem(urlFilename, DeviceType.Url)); + } + return urlFilename; + } + } + folderpath = this.GetDefaultBackupFolder(); + if (restorePlan.TailLogBackupOperation != null && restorePlan.TailLogBackupOperation.Devices != null) + { + restorePlan.TailLogBackupOperation.Devices.Add(new BackupDeviceItem(PathWrapper.Combine(folderpath, filename), DeviceType.File)); + } + return PathWrapper.Combine(folderpath, filename); + } + + internal string GetDefaultStandbyFile(string databaseName) + { + if (string.IsNullOrEmpty(databaseName)) + { + return string.Empty; + } + var folderpath = GetDefaultBackupFolder(); + var filename = SanitizeFileName(databaseName) + "_RollbackUndo_" + GetServerCurrentDateTime().ToString("yyyy-MM-dd_HH-mm-ss") + ".bak"; + return PathWrapper.Combine(folderpath, filename); + } + + //TODO: the code is moved from ssms and used for other typs of restore operation + //Uncomment when restore operation for those types are supported + /* + internal DateTime GetLastBackupDate(DatabaseRestorePlanner planner) + { + BackupSetCollection bkSetColl = planner.BackupSets; + if (bkSetColl.backupsetList.Count > 0) + { + return bkSetColl.backupsetList[bkSetColl.backupsetList.Count - 1].BackupStartDate; + } + return DateTime.MinValue; + } + */ + + /// + /// Sanitizes the name of the file. + /// + /// The name. + /// + internal string SanitizeFileName(string name) + { + char[] result = name.ToCharArray(); + string illegalCharacters = "\\/:*?\"<>|"; + + int resultLength = result.GetLength(0); + int illegalLength = illegalCharacters.Length; + + for (int resultIndex = 0; resultIndex < resultLength; resultIndex++) + { + for (int illegalIndex = 0; illegalIndex < illegalLength; illegalIndex++) + { + if (result[resultIndex] == illegalCharacters[illegalIndex]) + { + result[resultIndex] = '_'; + } + } + } + return new string(result); + } + + //TODO: the code is moved from ssms and used for other typs of restore operation + //Uncomment when restore operation for those types are supported + /* + internal void MarkDuplicateSuspectPages(List suspectPageObjList) + { + List newList = new List(suspectPageObjList); + newList.Sort(); + newList[0].IsDuplicate = false; + for (int i = 1; i < newList.Count; i++) + { + if (newList[i].CompareTo(newList[i - 1]) == 0) + { + newList[i].IsDuplicate = true; + newList[i - 1].IsDuplicate = true; + } + else + { + newList[i].IsDuplicate = false; + } + } + } + */ + + /* + internal void VerifyChecksumWorker(RestorePlan plan, IBackgroundOperationContext backgroundContext, EventHandler cancelEventHandler) + { + if (plan == null || plan.RestoreOperations.Count() == 0) + { + return; + } + backgroundContext.IsCancelable = true; + backgroundContext.CancelRequested += cancelEventHandler; + try + { + foreach (Restore res in plan.RestoreOperations) + { + if (!backgroundContext.IsCancelRequested && res.backupSet != null) + { + StringBuilder bkMediaNames = new StringBuilder(); + foreach (BackupDeviceItem item in res.Devices) + { + backgroundContext.Status = SR.Verifying + ":" + item.Name; + try + { + // Use the Restore public API to do the Restore VerifyOnly query + Restore restore = new Restore(); + restore.CredentialName = res.CredentialName; + restore.Devices.Add(item); + if (!res.SqlVerify(this.server)) + { + throw new Exception(SR.BackupDeviceItemVerificationFailed(item.Name)); + } + } + catch (Exception ex) + { + throw new Exception(SR.BackupDeviceItemVerificationFailed(item.Name), ex); + } + } + } + } + } + finally + { + backgroundContext.CancelRequested -= cancelEventHandler; + } + } + */ + + private BackupMedia GetFirstBackupMedia(RestorePlan restorePlan) + { + /* + if (restorePlan == null || restorePlan.RestoreOperations == null || restorePlan.RestoreOperations.Count == 0) + { + return null; + } + Restore res = restorePlan.RestoreOperations[0]; + if (res == null || res.backupSet == null || res.backupSet.backupMediaSet == null || res.backupSet.backupMediaSet.BackupMediaList == null || res.backupSet.backupMediaSet.BackupMediaList.ToList().Count == 0) + { + return null; + } + return res.backupSet.backupMediaSet.BackupMediaList.ToList()[0]; + */ + return null; + } + } + + /// + /// A class representing a database file + /// + public class DbFile + { + public DbFile(string logicalName, char type, string physicalName) + { + this.logicalName = logicalName; + this.physicalName = physicalName; + if (type != '\0') + { + this.dbFileType = type; + } + this.PhysicalNameRelocate = physicalName; + } + + // Database file types + // When restoring backup, the engine returns the following file type values. + public const char RowFileType = 'D'; + public const char LogFileType = 'L'; + public const char FullTextCatalogFileType = 'F'; + public const char FileStreamFileType = 'S'; + + private string logicalName; + public string LogicalName + { + get { return logicalName; } + } + + private string physicalName; + public string PhysicalName + { + get { return physicalName; } + } + + internal char dbFileType; + + /// + /// Returns the database file type string to be displayed in the dialog + /// + public string DbFileType + { + get + { + string value = string.Empty; + switch (dbFileType) + { + case DbFile.RowFileType: + value = "RowData";//TODO SR.RowData; + break; + case DbFile.LogFileType: + value = "Log";// SR.Log; + break; + case DbFile.FileStreamFileType: + value = "FileStream";// SR.FileStream; + break; + case DbFile.FullTextCatalogFileType: + value = "FullTextCatlog";// SR.FullTextCatlog; + break; + } + return value; + } + } + + public string PhysicalNameRelocate; + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings index 2fdd8f00..9fb6f4f9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings @@ -812,3 +812,11 @@ Backup_TaskName = Backup Database Task_InProgress = In progress Task_Completed = Completed + +########################################################################### +# Restore +ConflictWithNoRecovery = Specifying this option when restoring a backup with the NORECOVERY option is not permitted. +InvalidPathForDatabaseFile = Invalid path for database file: '{0}' +Log = Log +RestorePlanFailed = Failed to create restore plan +RestoreNotSupported = Restore database is not supported diff --git a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj index a87fd87f..07ea4635 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj +++ b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj @@ -19,7 +19,7 @@ - + diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs index e87b13d5..d84ab501 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTask.cs @@ -32,20 +32,20 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices public event EventHandler> MessageAdded; public event EventHandler> StatusChanged; - public event EventHandler> TaskCanceled; /// /// Creates new instance of SQL task /// /// Task Metadata /// The function to run to start the task - public SqlTask(TaskMetadata taskMetdata, Func> taskToRun) + public SqlTask(TaskMetadata taskMetdata, Func> taskToRun, Func> taskToCancel) { Validate.IsNotNull(nameof(taskMetdata), taskMetdata); Validate.IsNotNull(nameof(taskToRun), taskToRun); TaskMetadata = taskMetdata; TaskToRun = taskToRun; + TaskToCancel = taskToCancel; StartTime = DateTime.UtcNow; TaskId = Guid.NewGuid(); TokenSource = new CancellationTokenSource(); @@ -70,6 +70,15 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices set; } + /// + /// The function to cancel the operation + /// + private Func> TaskToCancel + { + get; + set; + } + /// /// Task unique id /// @@ -81,21 +90,21 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices public async Task RunAsync() { TaskStatus = SqlTaskStatus.InProgress; - await TaskToRun(this).ContinueWith(task => + await RunAndCancel().ContinueWith(task => { if (task.IsCompleted && !task.IsCanceled && !task.IsFaulted) { TaskResult taskResult = task.Result; TaskStatus = taskResult.TaskStatus; } - else if(task.IsCanceled) + else if (task.IsCanceled) { TaskStatus = SqlTaskStatus.Canceled; } - else if(task.IsFaulted) + else if (task.IsFaulted) { TaskStatus = SqlTaskStatus.Failed; - if(task.Exception != null) + if (task.Exception != null) { AddMessage(task.Exception.Message); } @@ -103,6 +112,95 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices }); } + /// + /// Create a backup task for execution and cancellation + /// + /// + /// + internal async Task RunAndCancel() + { + AddMessage(SR.Task_InProgress, SqlTaskStatus.InProgress, true); + + TaskResult taskResult = new TaskResult(); + Task performTask = TaskToRun(this); + Task completedTask = null; + + try + { + if (TaskToCancel != null) + { + AutoResetEvent backupCompletedEvent = new AutoResetEvent(initialState: false); + Task cancelTask = Task.Run(() => CancelTaskAsync(TokenSource.Token, backupCompletedEvent)); + + completedTask = await Task.WhenAny(performTask, cancelTask); + + // Release the cancelTask + if (completedTask == performTask) + { + backupCompletedEvent.Set(); + } + } + else + { + completedTask = await Task.WhenAny(performTask); + } + + AddMessage(completedTask.Result.TaskStatus == SqlTaskStatus.Failed ? completedTask.Result.ErrorMessage : SR.Task_Completed, + completedTask.Result.TaskStatus); + taskResult = completedTask.Result; + + } + catch (OperationCanceledException) + { + taskResult.TaskStatus = SqlTaskStatus.Canceled; + } + catch (Exception ex) + { + if (ex.InnerException != null && ex.InnerException is OperationCanceledException) + { + taskResult.TaskStatus = SqlTaskStatus.Canceled; + } + else + { + taskResult.TaskStatus = SqlTaskStatus.Failed; + AddMessage(ex.Message); + } + } + return taskResult; + } + + /// + /// Async task to cancel backup + /// + /// + /// + /// + /// + private async Task CancelTaskAsync(CancellationToken token, AutoResetEvent backupCompletedEvent) + { + // Create a task for backup cancellation request + + TaskResult result = new TaskResult(); + WaitHandle[] waitHandles = new WaitHandle[2] + { + backupCompletedEvent, + token.WaitHandle + }; + + WaitHandle.WaitAny(waitHandles); + try + { + await this.TaskToCancel(this); + result.TaskStatus = SqlTaskStatus.Canceled; + } + catch (Exception ex) + { + result.TaskStatus = SqlTaskStatus.Failed; + result.ErrorMessage = ex.Message; + } + + return result; + } //Run Task synchronously public void Run() { @@ -138,7 +236,10 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices if (isCancelRequested != value) { isCancelRequested = value; - OnTaskCancelRequested(); + if (isCancelRequested) + { + TokenSource.Cancel(); + } } } } @@ -379,16 +480,6 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices } } - private void OnTaskCancelRequested() - { - TokenSource.Cancel(); - var handler = TaskCanceled; - if (handler != null) - { - handler(this, new TaskEventArgs(TaskStatus, this)); - } - } - public void Dispose() { //Dispose diff --git a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs index 766cd95e..9aae187a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TaskServices/SqlTaskManager.cs @@ -82,12 +82,13 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices /// /// Task Metadata /// The function to run the operation + /// The function to cancel the operation /// - public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun) + public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) { ValidateNotDisposed(); - var newtask = new SqlTask(taskMetadata, taskToRun ); + var newtask = new SqlTask(taskMetadata, taskToRun, taskToCancel); lock (lockObject) { @@ -97,6 +98,31 @@ namespace Microsoft.SqlTools.ServiceLayer.TaskServices return newtask; } + /// + /// Creates a new task + /// + /// Task Metadata + /// The function to run the operation + /// + public SqlTask CreateTask(TaskMetadata taskMetadata, Func> taskToRun) + { + return CreateTask(taskMetadata, taskToRun, null); + } + + /// + /// Creates a new task and starts the task + /// + /// Task Metadata + /// The function to run the operation + /// The function to cancel the operation + /// + public SqlTask CreateAndRun(TaskMetadata taskMetadata, Func> taskToRun, Func> taskToCancel) + { + var sqlTask = CreateTask(taskMetadata, taskToRun, null); + sqlTask.Run(); + return sqlTask; + } + public void Dispose() { Dispose(true); diff --git a/test/CodeCoverage/codecoverage.bat b/test/CodeCoverage/codecoverage.bat index 669b56d7..531078ec 100644 --- a/test/CodeCoverage/codecoverage.bat +++ b/test/CodeCoverage/codecoverage.bat @@ -40,7 +40,7 @@ dotnet restore %REPOROOT%\test\Microsoft.SqlTools.ServiceLayer.TestDriver.Tests\ dotnet build %REPOROOT%\test\Microsoft.SqlTools.ServiceLayer.TestDriver.Tests\Microsoft.SqlTools.ServiceLayer.TestDriver.Tests.csproj %DOTNETCONFIG% SET TEST_SERVER=localhost -SET SQLTOOLSSERVICE_EXE=%REPOROOT%\src\Microsoft.SqlTools.ServiceLayer\bin\Integration\netcoreapp2.0\win7-x64\Microsoft.SqlTools.ServiceLayer.exe +SET SQLTOOLSSERVICE_EXE=%REPOROOT%\src\Microsoft.SqlTools.ServiceLayer\bin\Debug\netcoreapp2.0\win7-x64\MicrosoftSqlToolsServiceLayer.exe SET SERVICECODECOVERAGE=True SET CODECOVERAGETOOL="%WORKINGDIR%packages\OpenCover.4.6.684\tools\OpenCover.Console.exe" SET CODECOVERAGEOUTPUT=coverage.xml diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs new file mode 100644 index 00000000..074ace74 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs @@ -0,0 +1,213 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Globalization; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.Extensibility; +using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.DisasterRecovery; +using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.Contracts; +using Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation; +using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; +using Microsoft.SqlTools.ServiceLayer.TaskServices; +using Microsoft.SqlTools.ServiceLayer.Test.Common; +using Microsoft.SqlTools.ServiceLayer.UnitTests; +using Moq; +using Xunit; +using static Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility.LiveConnectionHelper; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery +{ + public class RestoreDatabaseServiceTests : ServiceTestBase + { + private ConnectionService _connectService = TestServiceProvider.Instance.ConnectionService; + private Mock serviceHostMock; + private DisasterRecoveryService service; + + public RestoreDatabaseServiceTests() + { + serviceHostMock = new Mock(); + service = CreateService(); + service.InitializeService(serviceHostMock.Object); + } + + [Fact] + public async void RestorePlanShouldCreatedSuccessfullyForFullBackup() + { + string backupFileName = "FullBackup.bak"; + bool canRestore = true; + await VerifyRestore(backupFileName, canRestore); + } + + [Fact] + public async void RestoreShouldExecuteSuccessfullyForFullBackup() + { + string backupFileName = "FullBackup.bak"; + bool canRestore = true; + var restorePlan = await VerifyRestore(backupFileName, canRestore, true); + } + + [Fact] + public async void RestorePlanShouldFailForDiffBackup() + { + string backupFileName = "DiffBackup.bak"; + bool canRestore = false; + await VerifyRestore(backupFileName, canRestore); + } + + [Fact] + public async void RestorePlanShouldFailForTransactionLogBackup() + { + string backupFileName = "TransactionLogBackup.bak"; + bool canRestore = false; + await VerifyRestore(backupFileName, canRestore); + } + + [Fact] + public async Task RestorePlanRequestShouldReturnResponseWithDbFiles() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); + + string filePath = GetBackupFilePath("FullBackup.bak"); + + RestoreParams restoreParams = new RestoreParams + { + BackupFilePath = filePath, + OwnerUri = queryTempFile.FilePath + }; + + await RunAndVerify( + test: (requestContext) => service.HandleRestorePlanRequest(restoreParams, requestContext), + verify: ((result) => + { + Assert.True(result.DbFiles.Any()); + Assert.Equal(result.DatabaseName, "BackupTestDb"); + })); + } + } + + [Fact] + public async Task RestoreDatabaseRequestShouldStartTheRestoreTask() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); + + string filePath = GetBackupFilePath("FullBackup.bak"); + + RestoreParams restoreParams = new RestoreParams + { + BackupFilePath = filePath, + OwnerUri = queryTempFile.FilePath + }; + + await RunAndVerify( + test: (requestContext) => service.HandleRestoreRequest(restoreParams, requestContext), + verify: ((result) => + { + string taskId = result.TaskId; + var task = SqlTaskManager.Instance.Tasks.FirstOrDefault(x => x.TaskId.ToString() == taskId); + Assert.NotNull(task); + + })); + } + } + + private async Task DropDatabase(string databaseName) + { + string dropDatabaseQuery = string.Format(CultureInfo.InvariantCulture, + Scripts.DropDatabaseIfExist, databaseName); + + await TestServiceProvider.Instance.RunQueryAsync(TestServerType.OnPrem, "master", dropDatabaseQuery); + } + + private async Task VerifyRestore(string backupFileName, bool canRestore, bool execute = false) + { + string filePath = GetBackupFilePath(backupFileName); + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); + + RestoreDatabaseHelper service = new RestoreDatabaseHelper(); + var request = new RestoreParams + { + BackupFilePath = filePath, + DatabaseName = string.Empty, + OwnerUri = queryTempFile.FilePath + }; + + var restoreDataObject = service.CreateRestoreDatabaseTaskDataObject(request); + var response = service.CreateRestorePlanResponse(restoreDataObject); + + Assert.NotNull(response); + Assert.Equal(response.CanRestore, canRestore); + if (canRestore) + { + Assert.True(response.DbFiles.Any()); + Assert.Equal(response.DatabaseName, "BackupTestDb"); + if(execute) + { + await DropDatabase(response.DatabaseName); + Thread.Sleep(2000); + request.RelocateDbFiles = response.RelocateFilesNeeded; + service.ExecuteRestore(restoreDataObject); + Assert.True(restoreDataObject.Server.Databases.Contains(response.DatabaseName)); + await DropDatabase(response.DatabaseName); + } + } + + return response; + } + } + + private static string TestLocationDirectory + { + get + { + return Path.Combine(RunEnvironmentInfo.GetTestDataLocation(), "DisasterRecovery"); + } + } + + public DirectoryInfo BackupFileDirectory + { + get + { + string d = Path.Combine(TestLocationDirectory, "Backups"); + return new DirectoryInfo(d); + } + } + + public FileInfo GetBackupFile(string fileName) + { + return new FileInfo(Path.Combine(BackupFileDirectory.FullName, fileName)); + } + + private string GetBackupFilePath(string fileName) + { + FileInfo inputFile = GetBackupFile(fileName); + return inputFile.FullName; + } + + protected DisasterRecoveryService CreateService() + { + CreateServiceProviderWithMinServices(); + + // Create the service using the service provider, which will initialize dependencies + return ServiceProvider.GetService(); + } + + protected override RegisteredServiceProvider CreateServiceProviderWithMinServices() + { + return CreateProvider() + .RegisterSingleService(new DisasterRecoveryService()); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Microsoft.SqlTools.ServiceLayer.IntegrationTests.csproj b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Microsoft.SqlTools.ServiceLayer.IntegrationTests.csproj index 476c1e0d..90a824b9 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Microsoft.SqlTools.ServiceLayer.IntegrationTests.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Microsoft.SqlTools.ServiceLayer.IntegrationTests.csproj @@ -24,13 +24,14 @@ + - + diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj index db44620c..9ebae3ee 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj @@ -12,7 +12,7 @@ - + diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestData/DisasterRecovery/Backups/DiffBackup.bak b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestData/DisasterRecovery/Backups/DiffBackup.bak new file mode 100644 index 00000000..d0b95795 Binary files /dev/null and b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestData/DisasterRecovery/Backups/DiffBackup.bak differ diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestData/DisasterRecovery/Backups/FullBackup.bak b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestData/DisasterRecovery/Backups/FullBackup.bak new file mode 100644 index 00000000..bcc7074f Binary files /dev/null and b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestData/DisasterRecovery/Backups/FullBackup.bak differ diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestData/DisasterRecovery/Backups/TransactionLogBackup.bak b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestData/DisasterRecovery/Backups/TransactionLogBackup.bak new file mode 100644 index 00000000..083582af Binary files /dev/null and b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestData/DisasterRecovery/Backups/TransactionLogBackup.bak differ diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs index c096ef70..e2a51dff 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs @@ -50,6 +50,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common } } + public ConnectionService ConnectionService + { + get + { + return ConnectionService.Instance; + } + } + public ObjectExplorerService ObjectExplorerService { get diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj index c1c7c279..fe9f27d7 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj @@ -12,7 +12,7 @@ - + diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj index 05655803..692dff6e 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj @@ -13,7 +13,7 @@ - + diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/DatabaseOperationStub.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/DatabaseOperationStub.cs index 6acb0926..6fbfc64c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/DatabaseOperationStub.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/DatabaseOperationStub.cs @@ -31,15 +31,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices public async Task FunctionToRun(SqlTask sqlTask) { - sqlTask.TaskCanceled += OnTaskCanceled; return await Task.Factory.StartNew(() => { while (!IsStopped) { //Just keep running - if (cancellationTokenSource.Token.IsCancellationRequested) + if (sqlTask.TaskStatus == SqlTaskStatus.Canceled) { - throw new OperationCanceledException(); + break; } if (Failed) { @@ -53,9 +52,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices }); } - private void OnTaskCanceled(object sender, TaskEventArgs e) + public async Task FunctionToCancel(SqlTask sqlTask) { - cancellationTokenSource.Cancel(); + return await Task.Factory.StartNew(() => + { + return new TaskResult + { + TaskStatus = SqlTaskStatus.Canceled + }; + }); } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/SqlTaskTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/SqlTaskTests.cs index 9f04dfc6..8b0aee6c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/SqlTaskTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/SqlTaskTests.cs @@ -16,18 +16,20 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices [Fact] public void CreateSqlTaskGivenInvalidArgumentShouldThrowException() { - Assert.Throws(() => new SqlTask(null, new DatabaseOperationStub().FunctionToRun)); - Assert.Throws(() => new SqlTask(new TaskMetadata(), null)); + DatabaseOperationStub operation = new DatabaseOperationStub(); + + Assert.Throws(() => new SqlTask(null, operation.FunctionToRun, operation.FunctionToCancel)); + Assert.Throws(() => new SqlTask(new TaskMetadata(), null, null)); } [Fact] public void CreateSqlTaskShouldGenerateANewId() { - SqlTask sqlTask = new SqlTask(new TaskMetadata(), new DatabaseOperationStub().FunctionToRun); + SqlTask sqlTask = new SqlTask(new TaskMetadata(), new DatabaseOperationStub().FunctionToRun, null); Assert.NotNull(sqlTask.TaskId); Assert.True(sqlTask.TaskId != Guid.Empty); - SqlTask sqlTask2 = new SqlTask(new TaskMetadata(), new DatabaseOperationStub().FunctionToRun); + SqlTask sqlTask2 = new SqlTask(new TaskMetadata(), new DatabaseOperationStub().FunctionToRun, null); Assert.False(sqlTask.TaskId.CompareTo(sqlTask2.TaskId) == 0); } @@ -40,7 +42,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices { TaskStatus = expectedStatus }; - SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun); + SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun, null); Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.NotStarted); Task taskToVerify = sqlTask.RunAsync().ContinueWith(task => { @@ -67,7 +69,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices { ServerName = "server name", DatabaseName = "database name" - }, operation.FunctionToRun); + }, operation.FunctionToRun, operation.FunctionToCancel); Task taskToVerify = sqlTask.RunAsync().ContinueWith(task => { @@ -89,7 +91,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices { TaskStatus = expectedStatus }; - SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun); + SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun, operation.FunctionToCancel); Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.NotStarted); Task taskToVerify = sqlTask.RunAsync().ContinueWith(task => { @@ -111,7 +113,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices operation.TaskResult = new TaskResult { }; - SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun); + SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun, operation.FunctionToCancel); Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.NotStarted); Task taskToVerify = sqlTask.RunAsync().ContinueWith(task => { @@ -133,7 +135,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices operation.TaskResult = new TaskResult { }; - SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun); + SqlTask sqlTask = new SqlTask(new TaskMetadata(), operation.FunctionToRun, operation.FunctionToCancel); Assert.Equal(sqlTask.TaskStatus, SqlTaskStatus.NotStarted); Task taskToVerify = sqlTask.RunAsync().ContinueWith(task => { diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskManagerTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskManagerTests.cs index 896cf38e..720252b4 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskManagerTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskManagerTests.cs @@ -71,12 +71,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices operation.TaskResult = new TaskResult { }; - SqlTask sqlTask = manager.CreateTask(taskMetaData, operation.FunctionToRun); + SqlTask sqlTask = manager.CreateTask(taskMetaData, operation.FunctionToRun, operation.FunctionToCancel); Assert.NotNull(sqlTask); Task taskToVerify = sqlTask.RunAsync().ContinueWith(task => { - Assert.Equal(sqlTask.TaskStatus, expectedStatus); + Assert.Equal(expectedStatus, sqlTask.TaskStatus); Assert.Equal(sqlTask.IsCancelRequested, true); manager.Reset(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs index df937fb6..6350c455 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/TaskServices/TaskServiceTests.cs @@ -69,11 +69,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.TaskServices serviceHostMock.AddEventHandling(TaskCreatedNotification.Type, null); serviceHostMock.AddEventHandling(TaskStatusChangedNotification.Type, null); DatabaseOperationStub operation = new DatabaseOperationStub(); - SqlTask sqlTask = service.TaskManager.CreateTask(taskMetaData, operation.FunctionToRun); + SqlTask sqlTask = service.TaskManager.CreateTask(taskMetaData, operation.FunctionToRun, operation.FunctionToCancel); Task taskToVerify = sqlTask.RunAsync().ContinueWith(task => { serviceHostMock.Verify(x => x.SendEvent(TaskStatusChangedNotification.Type, - It.Is(t => t.Status == SqlTaskStatus.Canceled)), Times.Once()); + It.Is(t => t.Status == SqlTaskStatus.Canceled)), Times.AtLeastOnce()); }); CancelTaskParams cancelParams = new CancelTaskParams {