From 52773bc26d5364fddcb8726e6578e002390b7df5 Mon Sep 17 00:00:00 2001 From: Cory Rivera Date: Mon, 28 Aug 2023 12:24:02 -0700 Subject: [PATCH] Add Attach Database functionality to Object Management Service (#2193) --- .../Admin/AdminService.cs | 73 +++++ .../Contracts/GetAssociatedFilesRequest.cs | 36 +++ .../Admin/Contracts/GetDataFolderRequest.cs | 32 +++ .../DisasterRecovery/CommonUtilities.cs | 261 +++++++++++------- .../Contracts/AttachDatabaseRequest.cs | 29 ++ .../ObjectManagementService.cs | 10 +- .../ObjectTypes/Database/DatabaseHandler.cs | 55 ++++ .../ObjectManagement/DatabaseHandlerTests.cs | 106 ++++++- .../ObjectManagement/UtilsTests.cs | 135 +++++++++ 9 files changed, 638 insertions(+), 99 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Admin/Contracts/GetAssociatedFilesRequest.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Admin/Contracts/GetDataFolderRequest.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/Contracts/AttachDatabaseRequest.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectManagement/UtilsTests.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Admin/AdminService.cs b/src/Microsoft.SqlTools.ServiceLayer/Admin/AdminService.cs index 092a467a..2b604c2d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Admin/AdminService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Admin/AdminService.cs @@ -8,10 +8,13 @@ using System; using System.Collections.Concurrent; using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.Admin.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.DisasterRecovery; using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Management; using Microsoft.SqlTools.ServiceLayer.Utility; @@ -71,6 +74,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin serviceHost.SetRequestHandler(CreateLoginRequest.Type, HandleCreateLoginRequest, true); serviceHost.SetRequestHandler(DefaultDatabaseInfoRequest.Type, HandleDefaultDatabaseInfoRequest, true); serviceHost.SetRequestHandler(GetDatabaseInfoRequest.Type, HandleGetDatabaseInfoRequest, true); + serviceHost.SetRequestHandler(GetDataFolderRequest.Type, HandleGetDataFolderRequest, true); + serviceHost.SetRequestHandler(GetAssociatedFilesRequest.Type, HandleGetAssociatedFilesRequest, true); } /// @@ -153,6 +158,74 @@ namespace Microsoft.SqlTools.ServiceLayer.Admin }); } + /// + /// Handle get database info request + /// + internal static async Task HandleGetDataFolderRequest( + GetDataFolderParams databaseParams, + RequestContext requestContext) + { + Func requestHandler = async () => + { + ConnectionInfo connInfo; + AdminService.ConnectionServiceInstance.TryFindConnection( + databaseParams.ConnectionUri, + out connInfo); + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo)) + { + // Connection gets disconnected when backup is done + ServerConnection serverConnection = new ServerConnection(sqlConn); + var dataFolder = CommonUtilities.GetDefaultDataFolder(serverConnection); + await requestContext.SendResult(dataFolder); + } + }; + + Task task = Task.Run(async () => await requestHandler()).ContinueWithOnFaulted(async t => + { + // Get innermost exception to get original error message + Exception ex = t.Exception; + while (ex.InnerException != null) + { + ex = ex.InnerException; + }; + await requestContext.SendError(ex.Message); + }); + } + + /// + /// Handle get associated database files request + /// + internal static async Task HandleGetAssociatedFilesRequest( + GetAssociatedFilesParams databaseParams, + RequestContext requestContext) + { + Func requestHandler = async () => + { + ConnectionInfo connInfo; + AdminService.ConnectionServiceInstance.TryFindConnection( + databaseParams.ConnectionUri, + out connInfo); + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo)) + { + // Connection gets disconnected when backup is done + ServerConnection serverConnection = new ServerConnection(sqlConn); + var files = CommonUtilities.GetAssociatedFilePaths(serverConnection, databaseParams.PrimaryFilePath); + await requestContext.SendResult(files); + } + }; + + Task task = Task.Run(async () => await requestHandler()).ContinueWithOnFaulted(async t => + { + // Get innermost exception to get original error message + Exception ex = t.Exception; + while (ex.InnerException != null) + { + ex = ex.InnerException; + }; + await requestContext.SendError(ex.Message); + }); + } + /// /// Return database info for a specific database /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Admin/Contracts/GetAssociatedFilesRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Admin/Contracts/GetAssociatedFilesRequest.cs new file mode 100644 index 00000000..fc8421ca --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Admin/Contracts/GetAssociatedFilesRequest.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +#nullable disable + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Admin.Contracts +{ + /// + /// Params for a get database info request + /// + public class GetAssociatedFilesParams + { + /// + /// URI identifier for the connection to the target server. + /// + public string ConnectionUri { get; set; } + /// + /// The file path for the primary file that we want to get the associated files for. + /// + public string PrimaryFilePath { get; set; } + } + + /// + /// Get database info request mapping + /// + public class GetAssociatedFilesRequest + { + public static readonly + RequestType Type = + RequestType.Create("admin/getassociatedfiles"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Admin/Contracts/GetDataFolderRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Admin/Contracts/GetDataFolderRequest.cs new file mode 100644 index 00000000..a2cb5125 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Admin/Contracts/GetDataFolderRequest.cs @@ -0,0 +1,32 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +#nullable disable + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Admin.Contracts +{ + /// + /// Params for a get database info request + /// + public class GetDataFolderParams + { + /// + /// URI identifier for the connection to get the server folder info for + /// + public string ConnectionUri { get; set; } + } + + /// + /// Get database info request mapping + /// + public class GetDataFolderRequest + { + public static readonly + RequestType Type = + RequestType.Create("admin/getdatafolder"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/CommonUtilities.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/CommonUtilities.cs index 63551789..bfafd08c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/CommonUtilities.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/CommonUtilities.cs @@ -14,6 +14,7 @@ using Microsoft.SqlServer.Management.Sdk.Sfc; using Microsoft.SqlServer.Management.Smo; using SMO = Microsoft.SqlServer.Management.Smo; using Microsoft.SqlTools.ServiceLayer.Management; +using System.IO; namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery { @@ -35,13 +36,13 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery Database, Files } - + /// /// Backup set type /// public enum BackupsetType { - BackupsetDatabase, + BackupsetDatabase, BackupsetLog, BackupsetDifferential, BackupsetFiles @@ -101,7 +102,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } else { - return (RestoreItemLocation+RestoreItemDeviceType.ToString() + IsLogicalDevice.ToString()).GetHashCode(); + return (RestoreItemLocation + RestoreItemDeviceType.ToString() + IsLogicalDevice.ToString()).GetHashCode(); } } } @@ -136,17 +137,17 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery /// /// public CommonUtilities(CDataContainer dataContainer, ServerConnection sqlConnection) - { + { this.dataContainer = dataContainer; this.sqlConnection = sqlConnection; this.excludedDatabases = new ArrayList(); this.excludedDatabases.Add("master"); this.excludedDatabases.Add("tempdb"); } - + public int GetServerVersion() { - return this.dataContainer.Server.Information.Version.Major; + return this.dataContainer.Server.Information.Version.Major; } /// @@ -162,7 +163,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery else if (String.Compare(stringDeviceType, RestoreConstants.Url, StringComparison.OrdinalIgnoreCase) == 0) { return DeviceType.Url; - } + } else { return DeviceType.LogicalDevice; @@ -188,10 +189,10 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return DeviceType.LogicalDevice; } } - + public BackupDeviceType GetPhisycalDeviceTypeOfLogicalDevice(string deviceName) { - Enumerator enumerator = new Enumerator(); + Enumerator enumerator = new Enumerator(); Request request = new Request(); DataSet dataset = new DataSet(); dataset.Locale = System.Globalization.CultureInfo.InvariantCulture; @@ -199,16 +200,16 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery dataset = enumerator.Process(this.sqlConnection, request); if (dataset.Tables[0].Rows.Count > 0) - { + { BackupDeviceType controllerType = (BackupDeviceType)(Convert.ToInt16(dataset.Tables[0].Rows[0]["BackupDeviceType"], System.Globalization.CultureInfo.InvariantCulture)); return controllerType; } else { - throw new Exception("Unexpected error"); + throw new Exception("Unexpected error"); } } - + public bool ServerHasTapes() { try @@ -221,12 +222,12 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery ds = en.Process(this.sqlConnection, req); if (ds.Tables[0].Rows.Count > 0) - { + { return true; } return false; } - catch(Exception) + catch (Exception) { return false; } @@ -241,18 +242,18 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery DataSet ds = new DataSet(); ds.Locale = System.Globalization.CultureInfo.InvariantCulture; req.Urn = "Server/BackupDevice"; - ds = en.Process(this.sqlConnection,req); - + ds = en.Process(this.sqlConnection, req); + if (ds.Tables[0].Rows.Count > 0) - { - return true; + { + return true; } return false; } - catch(Exception) + catch (Exception) { return false; - } + } } /// @@ -265,8 +266,8 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery char[] result = name.ToCharArray(); string illegalCharacters = "\\/:*?\"<>|"; - int resultLength = result.GetLength(0); - int illegalLength = illegalCharacters.Length; + int resultLength = result.GetLength(0); + int illegalLength = illegalCharacters.Length; for (int resultIndex = 0; resultIndex < resultLength; resultIndex++) { @@ -278,17 +279,17 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } } } - + return new string(result); } - + public RecoveryModel GetRecoveryModel(string databaseName) { Enumerator en = null; DataSet ds = new DataSet(); ds.Locale = System.Globalization.CultureInfo.InvariantCulture; Request req = new Request(); - RecoveryModel recoveryModel = RecoveryModel.Simple; + RecoveryModel recoveryModel = RecoveryModel.Simple; en = new Enumerator(); req.Urn = "Server/Database[@Name='" + Urn.EscapeString(databaseName) + "']/Option"; @@ -297,9 +298,9 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery ds = en.Process(this.sqlConnection, req); if (ds.Tables[0].Rows.Count > 0) - { - recoveryModel = (RecoveryModel)(ds.Tables[0].Rows[0]["RecoveryModel"]); - } + { + recoveryModel = (RecoveryModel)(ds.Tables[0].Rows[0]["RecoveryModel"]); + } return recoveryModel; } @@ -319,7 +320,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery { recoveryModelString = BackupConstants.RecoveryModelBulk; } - + return recoveryModelString; } @@ -342,6 +343,69 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return backupFolder; } + public static string GetDefaultDataFolder(ServerConnection connection) + { + string dataFolder = string.Empty; + + Enumerator en = null; + DataSet ds = new DataSet(); + ds.Locale = System.Globalization.CultureInfo.InvariantCulture; + Request req = new Request(); + en = new Enumerator(); + req.Urn = "Server/Setting"; + ds = en.Process(connection, req); + + if (ds.Tables[0].Rows.Count > 0) + { + dataFolder = Convert.ToString(ds.Tables[0].Rows[0]["DefaultFile"], System.Globalization.CultureInfo.InvariantCulture); + } + return dataFolder; + } + + public static string[] GetAssociatedFilePaths(ServerConnection connection, string primaryFilePath) + { + var databaseFiles = new List(); + + var enumerator = new Enumerator(); + var req = new Request() + { + Urn = $"Server/PrimaryFile[@Name='{Urn.EscapeString(primaryFilePath)}']/File", + Fields = new string[] { "IsFile", "FileName" } + }; + var dataTable = (DataTable)enumerator.Process(connection, req); + + foreach (DataRow currentRow in dataTable.Rows) + { + var primaryFolder = Path.GetDirectoryName(primaryFilePath); + var originalPath = (string)currentRow["FileName"]; + var originalFileName = Path.GetFileName(originalPath); + var filePath = Path.Join(primaryFolder, originalFileName); + + // Check if file exists with the constructed path. + // If it's an XI (XStore Integration) path, then assume it exists, otherwise retrieve info for the file to check if it exists. + var exists = true; + var isXIPath = PathWrapper.IsXIPath(primaryFilePath); + if (!isXIPath) + { + var request = new Request() + { + Urn = string.Format(System.Globalization.CultureInfo.CurrentCulture, "Server/File[@FullName='{0}']", Urn.EscapeString(filePath)), + Fields = new string[] { "IsFile" } + }; + + DataTable data = (new Enumerator()).Process(connection, request); + + // If the enumerator could find the file, then it exists + exists = data?.Rows.Count > 0; + } + if (exists) + { + databaseFiles.Add(filePath); + } + } + return databaseFiles.ToArray(); + } + public int GetMediaRetentionValue() { int afterDays = 0; @@ -353,29 +417,29 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery ds.Locale = System.Globalization.CultureInfo.InvariantCulture; req.Urn = "Server/Configuration"; - ds = en.Process(this.sqlConnection, req); - for (int i = 0 ; i < ds.Tables[0].Rows.Count; i++) + ds = en.Process(this.sqlConnection, req); + for (int i = 0; i < ds.Tables[0].Rows.Count; i++) { if (Convert.ToString(ds.Tables[0].Rows[i]["Name"], System.Globalization.CultureInfo.InvariantCulture) == "media retention") { afterDays = Convert.ToInt32(ds.Tables[0].Rows[i]["RunValue"], System.Globalization.CultureInfo.InvariantCulture); break; } - } + } return afterDays; } catch (Exception) - { + { return afterDays; - } + } } public string GetMediaNameFromBackupSetId(int backupSetId) { - Enumerator en = null; - DataSet ds = new DataSet(); + Enumerator en = null; + DataSet ds = new DataSet(); ds.Locale = System.Globalization.CultureInfo.InvariantCulture; - Request req = new Request(); + Request req = new Request(); int mediaId = -1; string mediaName = string.Empty; @@ -402,7 +466,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery catch (Exception) { } - return mediaName; + return mediaName; } public string GetFileType(string type) @@ -429,8 +493,8 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return result; } - - + + // TODO: This is implemented as internal property in SMO. public bool IsLocalPrimaryReplica(string databaseName) @@ -450,7 +514,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } return !string.IsNullOrEmpty(server.Databases[databaseName].AvailabilityGroupName); } - + /// /// Returns whether mirroring is enabled on a database or not /// > @@ -509,19 +573,19 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return result; } - + public bool IsDatabaseOnServer(string databaseName) { Enumerator en = new Enumerator(); DataSet ds = new DataSet(); ds.Locale = System.Globalization.CultureInfo.InvariantCulture; - Request req = new Request(); - + Request req = new Request(); + req.Urn = "Server/Database[@Name='" + Urn.EscapeString(databaseName) + "']"; req.Fields = new string[1]; req.Fields[0] = "Name"; - ds = en.Process(sqlConnection, req); + ds = en.Process(sqlConnection, req); return (ds.Tables[0].Rows.Count > 0) ? true : false; } @@ -617,7 +681,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery case BackupSetType.Differential: backupType = RestoreConstants.TypeDifferential; backupComponent = RestoreConstants.ComponentDatabase; - break; + break; case BackupSetType.FileOrFileGroup: backupType = RestoreConstants.TypeFilegroup; backupComponent = RestoreConstants.ComponentFile; @@ -636,9 +700,9 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery break; } } - + public void GetBackupSetTypeAndComponent(string strType, ref string backupType, ref string backupComponent) - { + { string type = strType.ToUpperInvariant(); if (type == "D") @@ -682,7 +746,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } } } - } + } } public void GetFileType(string backupType, string tempFileType, ref string fileType) @@ -694,20 +758,23 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery { switch (type) { - case "D": fileType = RestoreConstants.Data; + case "D": + fileType = RestoreConstants.Data; break; - case "S": fileType = RestoreConstants.FileStream; + case "S": + fileType = RestoreConstants.FileStream; break; - default: fileType = RestoreConstants.NotKnown; + default: + fileType = RestoreConstants.NotKnown; break; } } } - + public BackupsetType GetBackupsetTypeFromBackupTypesOnDevice(int type) - { + { BackupsetType Result = BackupsetType.BackupsetDatabase; - switch(type) + switch (type) { case 1: Result = BackupsetType.BackupsetDatabase; @@ -728,11 +795,11 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return Result; } - + public BackupsetType GetBackupsetTypeFromBackupTypesOnHistory(string type) { BackupsetType result = BackupsetType.BackupsetDatabase; - switch(type) + switch (type) { case "D": result = BackupsetType.BackupsetDatabase; @@ -752,7 +819,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } return result; } - + public DataSet GetBackupSetFiles(int backupsetId) { Enumerator en = new Enumerator(); @@ -760,7 +827,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery DataSet backupsetfiles = new DataSet(); backupsetfiles.Locale = System.Globalization.CultureInfo.InvariantCulture; - if(backupsetId > 0) + if (backupsetId > 0) { req.Urn = "Server/BackupSet[@ID='" + Urn.EscapeString(Convert.ToString(backupsetId, System.Globalization.CultureInfo.InvariantCulture)) + "']/File"; } @@ -772,9 +839,9 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return backupsetfiles; } - + public DataSet GetBackupSetById(int backupsetId) - { + { SqlExecutionModes executionMode = this.sqlConnection.SqlExecutionModes; this.sqlConnection.SqlExecutionModes = SqlExecutionModes.ExecuteSql; Enumerator en = new Enumerator(); @@ -788,47 +855,47 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery this.sqlConnection.SqlExecutionModes = executionMode; return backupset; } - + public ArrayList GetBackupSetPhysicalSources(int backupsetId) - { + { SqlExecutionModes executionMode = this.sqlConnection.SqlExecutionModes; this.sqlConnection.SqlExecutionModes = SqlExecutionModes.ExecuteSql; ArrayList sources = new ArrayList(); DataSet backupSet = GetBackupSetById(backupsetId); - if(backupSet.Tables[0].Rows.Count == 1) + if (backupSet.Tables[0].Rows.Count == 1) { string mediaSetID = Convert.ToString(backupSet.Tables[0].Rows[0]["MediaSetId"], System.Globalization.CultureInfo.InvariantCulture); Enumerator en = new Enumerator(); Request req = new Request(); DataSet mediafamily = new DataSet(); - mediafamily.Locale = System.Globalization.CultureInfo.InvariantCulture; + mediafamily.Locale = System.Globalization.CultureInfo.InvariantCulture; - req.Urn = "Server/BackupMediaSet[@ID='"+Urn.EscapeString(mediaSetID)+"']/MediaFamily"; + req.Urn = "Server/BackupMediaSet[@ID='" + Urn.EscapeString(mediaSetID) + "']/MediaFamily"; mediafamily = en.Process(this.sqlConnection, req); if (mediafamily.Tables[0].Rows.Count > 0) { - for (int j = 0 ; j < mediafamily.Tables[0].Rows.Count; j ++) + for (int j = 0; j < mediafamily.Tables[0].Rows.Count; j++) { RestoreItemSource itemSource = new RestoreItemSource(); itemSource.RestoreItemLocation = Convert.ToString(mediafamily.Tables[0].Rows[j]["PhysicalDeviceName"], System.Globalization.CultureInfo.InvariantCulture); BackupDeviceType backupDeviceType = (BackupDeviceType)Enum.Parse(typeof(BackupDeviceType), mediafamily.Tables[0].Rows[j]["BackupDeviceType"].ToString()); - + if (BackupDeviceType.Disk == backupDeviceType) { - itemSource.RestoreItemDeviceType = DeviceType.File; + itemSource.RestoreItemDeviceType = DeviceType.File; } else if (BackupDeviceType.Url == backupDeviceType) { itemSource.RestoreItemDeviceType = DeviceType.Url; - } + } else { - itemSource.RestoreItemDeviceType = DeviceType.Tape; - } - sources.Add(itemSource); + itemSource.RestoreItemDeviceType = DeviceType.Tape; + } + sources.Add(itemSource); } } } @@ -836,11 +903,11 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery this.sqlConnection.SqlExecutionModes = executionMode; return sources; } - + public RestoreActionType GetRestoreTaskFromBackupSetType(BackupsetType type) { RestoreActionType result = RestoreActionType.Database; - + switch (type) { case BackupsetType.BackupsetDatabase: @@ -861,16 +928,16 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } return result; } - + public int GetLatestBackup(string databaseName, string backupSetName) { Enumerator en = new Enumerator(); Request req = new Request(); DataSet backupSets = new DataSet(); - backupSets.Locale = System.Globalization.CultureInfo.InvariantCulture; + backupSets.Locale = System.Globalization.CultureInfo.InvariantCulture; OrderBy orderByBackupDate; - req.Urn = "Server/BackupSet[@Name='"+Urn.EscapeString(backupSetName)+"' and @DatabaseName='"+ Urn.EscapeString(databaseName)+"']"; + req.Urn = "Server/BackupSet[@Name='" + Urn.EscapeString(backupSetName) + "' and @DatabaseName='" + Urn.EscapeString(databaseName) + "']"; req.OrderByList = new OrderBy[1]; orderByBackupDate = new OrderBy("BackupFinishDate", OrderBy.Direction.Desc); req.OrderByList[0] = orderByBackupDate; @@ -885,7 +952,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return -1; } } - + public List GetLatestBackupLocations(string databaseName) { List latestLocations = new List(); @@ -944,20 +1011,20 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } /// LPU doesn't have rights to enumerate msdb.backupset catch (Exception) - { - } - return latestLocations; - } - + { + } + return latestLocations; + } + public string GetDefaultDatabaseForLogin(string loginName) - { - string defaultDatabase = string.Empty; + { + string defaultDatabase = string.Empty; Enumerator en = new Enumerator(); DataSet ds = new DataSet(); ds.Locale = System.Globalization.CultureInfo.InvariantCulture; - Request req = new Request(); - - req.Urn = "Server/Login[@Name='"+Urn.EscapeString(loginName)+"']"; + Request req = new Request(); + + req.Urn = "Server/Login[@Name='" + Urn.EscapeString(loginName) + "']"; req.Fields = new string[1]; req.Fields[0] = "DefaultDatabase"; ds = en.Process(this.sqlConnection, req); @@ -991,26 +1058,26 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery } public ArrayList IsPhysicalPathInLogicalDevice(string physicalPath) - { + { Enumerator en = new Enumerator(); DataSet ds = new DataSet(); ds.Locale = System.Globalization.CultureInfo.InvariantCulture; Request req = new Request(); ArrayList result = null; - int count = 0; - req.Urn = "Server/BackupDevice[@PhysicalLocation='" +Urn.EscapeString(physicalPath)+ "']"; + int count = 0; + req.Urn = "Server/BackupDevice[@PhysicalLocation='" + Urn.EscapeString(physicalPath) + "']"; - ds = en.Process(this.sqlConnection, req); + ds = en.Process(this.sqlConnection, req); count = ds.Tables[0].Rows.Count; - + if (count > 0) { result = new ArrayList(count); - for(int i = 0; i < count; i++) + for (int i = 0; i < count; i++) { result.Add(Convert.ToString(ds.Tables[0].Rows[0]["Name"], System.Globalization.CultureInfo.InvariantCulture)); } - } + } return result; } @@ -1025,7 +1092,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery return System.Environment.MachineName; } - string machineName = sqlServerName; + string machineName = sqlServerName; if (sqlServerName.Trim().Length != 0) { // [0] = machine, [1] = instance diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/Contracts/AttachDatabaseRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/Contracts/AttachDatabaseRequest.cs new file mode 100644 index 00000000..2ed7c120 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/Contracts/AttachDatabaseRequest.cs @@ -0,0 +1,29 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +#nullable disable +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement.Contracts +{ + public class DatabaseFileData + { + public string DatabaseName { get; set; } + public string[] DatabaseFilePaths { get; set; } + public string Owner { get; set; } + } + + public class AttachDatabaseRequestParams + { + public string ConnectionUri { get; set; } + public DatabaseFileData[] Databases { get; set; } + public bool GenerateScript { get; set; } + } + + public class AttachDatabaseRequest + { + public static readonly RequestType Type = RequestType.Create("objectManagement/attachDatabase"); + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/ObjectManagementService.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/ObjectManagementService.cs index 13fc075f..9228c7f5 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/ObjectManagementService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/ObjectManagementService.cs @@ -70,6 +70,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement this.serviceHost.SetRequestHandler(DisposeViewRequest.Type, HandleDisposeViewRequest, true); this.serviceHost.SetRequestHandler(SearchRequest.Type, HandleSearchRequest, true); this.serviceHost.SetRequestHandler(DetachDatabaseRequest.Type, HandleDetachDatabaseRequest, true); + this.serviceHost.SetRequestHandler(AttachDatabaseRequest.Type, HandleAttachDatabaseRequest, true); this.serviceHost.SetRequestHandler(DropDatabaseRequest.Type, HandleDropDatabaseRequest, true); } @@ -155,7 +156,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement } SearchableObjectTypeDescription desc = SearchableObjectTypeDescription.GetDescription(searchableObjectType); - + if (desc.IsDatabaseObject) { if (!string.IsNullOrEmpty(requestParams.Schema)) @@ -207,6 +208,13 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement await requestContext.SendResult(sqlScript); } + internal async Task HandleAttachDatabaseRequest(AttachDatabaseRequestParams requestParams, RequestContext requestContext) + { + var handler = this.GetObjectTypeHandler(SqlObjectType.Database) as DatabaseHandler; + var sqlScript = handler.Attach(requestParams); + await requestContext.SendResult(sqlScript); + } + internal async Task HandleDropDatabaseRequest(DropDatabaseRequestParams requestParams, RequestContext requestContext) { var handler = this.GetObjectTypeHandler(SqlObjectType.Database) as DatabaseHandler; diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/ObjectTypes/Database/DatabaseHandler.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/ObjectTypes/Database/DatabaseHandler.cs index 31b893e0..2cf96c7f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/ObjectTypes/Database/DatabaseHandler.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectManagement/ObjectTypes/Database/DatabaseHandler.cs @@ -20,6 +20,7 @@ using Microsoft.SqlTools.Utility; using System.Text; using System.IO; using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters; +using System.Collections.Specialized; using Microsoft.SqlTools.SqlCore.Utility; namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement @@ -393,6 +394,60 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectManagement return builder.ToString(); } + public string Attach(AttachDatabaseRequestParams attachParams) + { + var sqlScript = string.Empty; + ConnectionInfo connectionInfo = this.GetConnectionInfo(attachParams.ConnectionUri); + using (var dataContainer = CreateDatabaseDataContainer(attachParams.ConnectionUri, null, true, null)) + { + var server = dataContainer.Server!; + var originalExecuteMode = server.ConnectionContext.SqlExecutionModes; + if (attachParams.GenerateScript) + { + server.ConnectionContext.SqlExecutionModes = SqlExecutionModes.CaptureSql; + server.ConnectionContext.CapturedSql.Clear(); + } + try + { + foreach (var database in attachParams.Databases) + { + var fileCollection = new StringCollection(); + fileCollection.AddRange(database.DatabaseFilePaths); + if (database.Owner != SR.general_default) + { + server.AttachDatabase(database.DatabaseName, fileCollection, database.Owner); + } + else + { + server.AttachDatabase(database.DatabaseName, fileCollection); + } + } + if (attachParams.GenerateScript) + { + var builder = new StringBuilder(); + var capturedText = server.ConnectionContext.CapturedSql.Text; + foreach (var entry in capturedText) + { + if (entry != null) + { + builder.AppendLine(entry); + } + } + sqlScript = builder.ToString(); + } + } + finally + { + if (attachParams.GenerateScript) + { + server.ConnectionContext.SqlExecutionModes = originalExecuteMode; + } + dataContainer.ServerConnection.Disconnect(); + } + } + return sqlScript; + } + /// /// Used to drop the specified database /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectManagement/DatabaseHandlerTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectManagement/DatabaseHandlerTests.cs index 42c3a9c9..4ba271a4 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectManagement/DatabaseHandlerTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectManagement/DatabaseHandlerTests.cs @@ -3,9 +3,12 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +#nullable disable + using System.Collections.Generic; using System.Collections.Specialized; using System.Linq; +using System.Text; using System.Threading.Tasks; using Microsoft.Data.SqlClient; using Microsoft.SqlServer.Management.Common; @@ -356,7 +359,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement testDatabase.DatabaseScopedConfigurations[0].ValueForSecondary = "OFF"; } await ObjectManagementTestUtils.SaveObject(parametersForUpdate, testDatabase); - DatabaseViewInfo updatedDatabaseViewInfo = await ObjectManagementTestUtils.GetDatabaseObject(parametersForUpdate, testDatabase); + DatabaseViewInfo updatedDatabaseViewInfo = await ObjectManagementTestUtils.GetDatabaseObject(parametersForUpdate, testDatabase); // verify the modified properties Assert.That(((DatabaseInfo)updatedDatabaseViewInfo.ObjectInfo).DatabaseScopedConfigurations[0].ValueForPrimary, Is.EqualTo(testDatabase.DatabaseScopedConfigurations[0].ValueForPrimary), $"DSC updated primary value should match"); @@ -500,6 +503,107 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement } } + [Test] + [TestCase(true)] + [TestCase(false)] + public async Task AttachDatabaseTest(bool generateScript) + { + using (SqlTestDb testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, null, nameof(AttachDatabaseTest))) + { + var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", serverType: TestServerType.OnPrem); + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connectionResult.ConnectionInfo)) + { + var serverConn = new ServerConnection(sqlConn); + var server = new Server(serverConn); + var objUrn = ObjectManagementTestUtils.GetDatabaseURN(testDb.DatabaseName); + var database = server.GetSmoObject(objUrn) as Database; + + var originalOwner = database!.Owner; + var originalFilePaths = new List(); + foreach (FileGroup group in database.FileGroups) + { + foreach (DataFile file in group.Files) + { + originalFilePaths.Add(file.FileName); + } + } + foreach (LogFile file in database.LogFiles) + { + originalFilePaths.Add(file.FileName); + } + + // Detach database so that we can re-attach it with the database handler method. + // Have to set database to single user mode to close active connections before detaching it. + database.DatabaseOptions.UserAccess = SqlServer.Management.Smo.DatabaseUserAccess.Single; + database.Alter(TerminationClause.RollbackTransactionsImmediately); + server.DetachDatabase(testDb.DatabaseName, false); + var dbExists = this.DatabaseExists(testDb.DatabaseName, server); + Assert.That(dbExists, Is.False, "Database was not correctly detached before doing attach test."); + + try + { + var handler = new DatabaseHandler(ConnectionService.Instance); + var attachParams = new AttachDatabaseRequestParams() + { + ConnectionUri = connectionResult.ConnectionInfo.OwnerUri, + Databases = new DatabaseFileData[] + { + new DatabaseFileData() + { + Owner = originalOwner, + DatabaseName = testDb.DatabaseName, + DatabaseFilePaths = originalFilePaths.ToArray() + } + }, + GenerateScript = generateScript + }; + var script = handler.Attach(attachParams); + + if (generateScript) + { + dbExists = this.DatabaseExists(testDb.DatabaseName, server); + Assert.That(dbExists, Is.False, "Should not have attached DB when only generating a script."); + + var queryBuilder = new StringBuilder(); + queryBuilder.AppendLine("USE [master]"); + queryBuilder.AppendLine($"CREATE DATABASE [{testDb.DatabaseName}] ON "); + + for (int i = 0; i < originalFilePaths.Count - 1; i++) + { + var file = originalFilePaths[i]; + queryBuilder.AppendLine($"( FILENAME = N'{file}' ),"); + } + queryBuilder.AppendLine($"( FILENAME = N'{originalFilePaths[originalFilePaths.Count - 1]}' )"); + + queryBuilder.AppendLine(" FOR ATTACH"); + queryBuilder.AppendLine($"if exists (select name from master.sys.databases sd where name = N'{testDb.DatabaseName}' and SUSER_SNAME(sd.owner_sid) = SUSER_SNAME() ) EXEC [{testDb.DatabaseName}].dbo.sp_changedbowner @loginame=N'{originalOwner}', @map=false"); + + Assert.That(script, Is.EqualTo(queryBuilder.ToString()), "Did not get expected attach database script"); + } + else + { + Assert.That(script, Is.Empty, "Should not have generated a script for this Attach operation."); + + server.Databases.Refresh(); + dbExists = this.DatabaseExists(testDb.DatabaseName, server); + Assert.That(dbExists, "Database was not attached successfully"); + } + } + finally + { + dbExists = this.DatabaseExists(testDb.DatabaseName, server); + if (!dbExists) + { + // Reattach database so it can get dropped during cleanup + var fileCollection = new StringCollection(); + originalFilePaths.ForEach(file => fileCollection.Add(file)); + server.AttachDatabase(testDb.DatabaseName, fileCollection); + } + } + } + } + } + [Test] public async Task DeleteDatabaseTest() { diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectManagement/UtilsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectManagement/UtilsTests.cs new file mode 100644 index 00000000..24eb55a3 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectManagement/UtilsTests.cs @@ -0,0 +1,135 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +#nullable disable + +using System.Collections.Generic; +using System.Collections.Specialized; +using System.IO; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.DisasterRecovery; +using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; +using Microsoft.SqlTools.ServiceLayer.Test.Common; +using NUnit.Framework; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectManagement +{ + public class UtilsTests + { + [Test] + public async Task GetDataFolderTest() + { + using (SqlTestDb testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, null, nameof(GetDataFolderTest))) + { + var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, serverType: TestServerType.OnPrem); + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connectionResult.ConnectionInfo)) + { + var serverConn = new ServerConnection(sqlConn); + var server = new Server(serverConn); + var objUrn = ObjectManagementTestUtils.GetDatabaseURN(testDb.DatabaseName); + var database = server.GetSmoObject(objUrn) as Database; + + var dataFilePath = database.FileGroups[0].Files[0].FileName; + var expectedDataFolder = Path.GetDirectoryName(dataFilePath).ToString(); + + var actualDataFolder = CommonUtilities.GetDefaultDataFolder(serverConn); + actualDataFolder = Path.TrimEndingDirectorySeparator(actualDataFolder); + Assert.That(actualDataFolder, Is.EqualTo(expectedDataFolder).IgnoreCase, "Did not get expected data file folder path."); + } + } + } + + [Test] + public async Task GetAssociatedFilesTest() + { + using (SqlTestDb testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, null, nameof(GetAssociatedFilesTest))) + { + var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, serverType: TestServerType.OnPrem); + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connectionResult.ConnectionInfo)) + { + var serverConn = new ServerConnection(sqlConn); + var server = new Server(serverConn); + var objUrn = ObjectManagementTestUtils.GetDatabaseURN(testDb.DatabaseName); + var database = server.GetSmoObject(objUrn) as Database; + + var expectedFilePaths = new List(); + DataFile primaryFile = null; + foreach (FileGroup group in database.FileGroups) + { + foreach (DataFile file in group.Files) + { + expectedFilePaths.Add(file.FileName); + if (file.IsPrimaryFile) + { + primaryFile = file; + } + } + } + foreach (LogFile file in database.LogFiles) + { + expectedFilePaths.Add(file.FileName); + } + + // Detach database so that we don't throw an error when trying to access the primary data file + // Have to set database to single user mode to close active connections before detaching it. + database.DatabaseOptions.UserAccess = SqlServer.Management.Smo.DatabaseUserAccess.Single; + database.Alter(TerminationClause.RollbackTransactionsImmediately); + server.DetachDatabase(testDb.DatabaseName, false); + try + { + Assert.That(primaryFile, Is.Not.Null, "Could not find a primary file in the list of database files."); + var actualFilePaths = CommonUtilities.GetAssociatedFilePaths(serverConn, primaryFile.FileName); + Assert.That(actualFilePaths, Is.EqualTo(expectedFilePaths).IgnoreCase, "The list of associated files did not match the actual files for the database."); + } + finally + { + // Reattach database so it can get dropped during cleanup + var fileCollection = new StringCollection(); + expectedFilePaths.ForEach(file => fileCollection.Add(file)); + server.AttachDatabase(testDb.DatabaseName, fileCollection); + } + } + } + } + + [Test] + public async Task ThrowErrorWhenDatabaseExistsTest() + { + using (SqlTestDb testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, null, nameof(ThrowErrorWhenDatabaseExistsTest))) + { + var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, serverType: TestServerType.OnPrem); + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connectionResult.ConnectionInfo)) + { + var serverConn = new ServerConnection(sqlConn); + var server = new Server(serverConn); + var objUrn = ObjectManagementTestUtils.GetDatabaseURN(testDb.DatabaseName); + var database = server.GetSmoObject(objUrn) as Database; + + DataFile primaryFile = null; + foreach (FileGroup group in database.FileGroups) + { + foreach (DataFile file in group.Files) + { + if (file.IsPrimaryFile) + { + primaryFile = file; + } + } + } + + Assert.That( + () => CommonUtilities.GetAssociatedFilePaths(serverConn, primaryFile.FileName), + Throws.Exception, + "Should throw an error when trying to open a database file that's already in use." + ); + } + } + } + } +} \ No newline at end of file