diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs index aabbb7d9..ffa04f46 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs @@ -15,5 +15,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { public const string Default = "Default"; public const string Query = "Query"; + public const string Edit = "Edit"; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditCreateRowRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditCreateRowRequest.cs new file mode 100644 index 00000000..2e4d7619 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditCreateRowRequest.cs @@ -0,0 +1,35 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts +{ + /// + /// Parameters for the update cell request + /// + public class EditCreateRowParams : SessionOperationParams + { + } + + /// + /// Parameters to return upon successful addition of a row to the edit session + /// + public class EditCreateRowResult + { + /// + /// The internal ID of the newly created row + /// + public long NewRowId { get; set; } + } + + public class EditCreateRowRequest + { + public static readonly + RequestType Type = + RequestType.Create("edit/createRow"); + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditDeleteRowRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditDeleteRowRequest.cs new file mode 100644 index 00000000..248adcfb --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditDeleteRowRequest.cs @@ -0,0 +1,30 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts +{ + /// + /// Parameters for identifying a row to mark for deletion + /// + public class EditDeleteRowParams : RowOperationParams + { + } + + /// + /// Parameters to return upon successfully adding row delete to update cache + /// + public class EditDeleteRowResult + { + } + + public class EditDeleteRowRequest + { + public static readonly + RequestType Type = + RequestType.Create("edit/deleteRow"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditDisposeRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditDisposeRequest.cs new file mode 100644 index 00000000..2520eb2c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditDisposeRequest.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts +{ + /// + /// Parameters of the edit session dispose request + /// + public class EditDisposeParams : SessionOperationParams + { + } + + /// + /// Object to return upon successful disposal of an edit session + /// + public class EditDisposeResult { } + + public class EditDisposeRequest + { + public static readonly + RequestType Type = + RequestType.Create("edit/dispose"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditInitializeRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditInitializeRequest.cs new file mode 100644 index 00000000..72998823 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditInitializeRequest.cs @@ -0,0 +1,42 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts +{ + /// + /// Parameters of the edit session initialize request + /// + public class EditInitializeParams : SessionOperationParams + { + /// + /// The object to use for generating an edit script + /// + public string ObjectName { get; set; } + + /// + /// The type of the object to use for generating an edit script + /// + public string ObjectType { get; set; } + } + + /// + /// Object to return upon successful completion of an edit session initialize request + /// + /// + /// Empty for now, since there isn't anything special to return on success + /// + public class EditInitializeResult + { + } + + public class EditInitializeRequest + { + public static readonly + RequestType Type = + RequestType.Create("edit/initialize"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditRevertRowRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditRevertRowRequest.cs new file mode 100644 index 00000000..53049b67 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditRevertRowRequest.cs @@ -0,0 +1,30 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts +{ + /// + /// Parameters for the revert row request + /// + public class EditRevertRowParams : RowOperationParams + { + } + + /// + /// Parameters to return upon successful revert of a row + /// + public class EditRevertRowResult + { + } + + public class EditRevertRowRequest + { + public static readonly + RequestType Type = + RequestType.Create("edit/revertRow"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditSessionReadyEvent.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditSessionReadyEvent.cs new file mode 100644 index 00000000..150b53c7 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditSessionReadyEvent.cs @@ -0,0 +1,29 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts +{ + public class EditSessionReadyParams + { + /// + /// URI for the editor + /// + public string OwnerUri { get; set; } + + /// + /// Whether or not the session is ready + /// + public bool Success { get; set; } + } + + public class EditSessionReadyEvent + { + public static readonly + EventType Type = + EventType.Create("edit/sessionReady"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditUpdateCellRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditUpdateCellRequest.cs new file mode 100644 index 00000000..b6114713 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/EditUpdateCellRequest.cs @@ -0,0 +1,62 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts +{ + /// + /// Parameters for the update cell request + /// + public class EditUpdateCellParams : RowOperationParams + { + /// + /// Internal ID of the column to update + /// + public int ColumnId { get; set; } + + /// + /// String representation of the value to assign to the cell + /// + public string NewValue { get; set; } + } + + /// + /// Parameters to return upon successful update of the cell + /// + public class EditUpdateCellResult + { + /// + /// Whether or not the cell value was modified from the provided string. + /// If true, the client should replace the display value of the cell with the value + /// in + /// + public bool HasCorrections { get; set; } + + /// + /// Whether or not the cell was reverted with the change. + /// If true, the client should unmark the cell as having an update and replace the + /// display value of the cell with the value in + /// + public bool IsRevert { get; set; } + + /// + /// Whether or not the new value of the cell is null + /// + public bool IsNull { get; set; } + + /// + /// The new string value of the cell + /// + public string NewValue { get; set; } + } + + public class EditUpdateCellRequest + { + public static readonly + RequestType Type = + RequestType.Create("edit/updateCell"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/RowOperationParams.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/RowOperationParams.cs new file mode 100644 index 00000000..050d7676 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/RowOperationParams.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts +{ + /// + /// Abstract class for parameters that require an OwnerUri and a RowId + /// + public abstract class RowOperationParams : SessionOperationParams + { + /// + /// Internal ID of the row to revert + /// + public long RowId { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/SessionOperationParams.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/SessionOperationParams.cs new file mode 100644 index 00000000..e46ecf98 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Contracts/SessionOperationParams.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts +{ + /// + /// Abstract class for parameters that require an OwnerUri + /// + public abstract class SessionOperationParams + { + /// + /// Owner URI for the session to add new row to + /// + public string OwnerUri { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/EditColumnWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditColumnWrapper.cs new file mode 100644 index 00000000..28177766 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditColumnWrapper.cs @@ -0,0 +1,41 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData +{ + /// + /// Small class that stores information needed by the edit data service to properly process + /// edits into scripts. + /// + public class EditColumnWrapper + { + /// + /// The DB column + /// + public DbColumnWrapper DbColumn { get; set; } + + /// + /// Escaped identifier for the name of the column + /// + public string EscapedName { get; set; } + + /// + /// Whether or not the column is used in a key to uniquely identify a row + /// + public bool IsKey { get; set; } + + /// + /// Whether or not the column can be trusted for uniqueness + /// + public bool IsTrustworthyForUniqueness { get; set; } + + /// + /// The ordinal ID of the column + /// + public int Ordinal { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/EditDataService.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditDataService.cs new file mode 100644 index 00000000..3f3c9cf1 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditDataService.cs @@ -0,0 +1,286 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Concurrent; +using System.Data.Common; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests; +using Microsoft.SqlTools.ServiceLayer.Utility; +using ConnectionType = Microsoft.SqlTools.ServiceLayer.Connection.ConnectionType; + +namespace Microsoft.SqlTools.ServiceLayer.EditData +{ + /// + /// Service that handles edit data scenarios + /// + public class EditDataService + { + #region Singleton Instance Implementation + + private static readonly Lazy LazyInstance = new Lazy(() => new EditDataService()); + + public static EditDataService Instance => LazyInstance.Value; + + private EditDataService() + { + queryExecutionService = QueryExecutionService.Instance; + connectionService = ConnectionService.Instance; + metadataFactory = new SmoEditMetadataFactory(); + } + + internal EditDataService(QueryExecutionService qes, ConnectionService cs, IEditMetadataFactory factory) + { + queryExecutionService = qes; + connectionService = cs; + metadataFactory = factory; + } + + #endregion + + #region Member Variables + + private readonly ConnectionService connectionService; + + private readonly IEditMetadataFactory metadataFactory; + + private readonly QueryExecutionService queryExecutionService; + + private readonly Lazy> editSessions = new Lazy>( + () => new ConcurrentDictionary()); + + #endregion + + #region Properties + + /// + /// Dictionary mapping OwnerURIs to active sessions + /// + internal ConcurrentDictionary ActiveSessions => editSessions.Value; + + #endregion + + /// + /// Initializes the edit data service with the service host + /// + /// The service host to register commands/events with + public void InitializeService(ServiceHost serviceHost) + { + // Register handlers for requests + serviceHost.SetRequestHandler(EditCreateRowRequest.Type, HandleCreateRowRequest); + serviceHost.SetRequestHandler(EditDeleteRowRequest.Type, HandleDeleteRowRequest); + serviceHost.SetRequestHandler(EditDisposeRequest.Type, HandleDisposeRequest); + serviceHost.SetRequestHandler(EditInitializeRequest.Type, HandleInitializeRequest); + serviceHost.SetRequestHandler(EditRevertRowRequest.Type, HandleRevertRowRequest); + serviceHost.SetRequestHandler(EditUpdateCellRequest.Type, HandleUpdateCellRequest); + } + + #region Request Handlers + + internal async Task HandleSessionRequest(SessionOperationParams sessionParams, + RequestContext requestContext, Func sessionOperation) + { + try + { + Session session = GetActiveSessionOrThrow(sessionParams.OwnerUri); + + // Get the result from execution of the session operation + TResult result = sessionOperation(session); + await requestContext.SendResult(result); + } + catch (Exception e) + { + await requestContext.SendError(e.Message); + } + } + + internal Task HandleCreateRowRequest(EditCreateRowParams createParams, + RequestContext requestContext) + { + return HandleSessionRequest(createParams, requestContext, session => + { + // Create the row and get the ID of the new row + long newRowId = session.CreateRow(); + return new EditCreateRowResult + { + NewRowId = newRowId + }; + }); + } + + internal Task HandleDeleteRowRequest(EditDeleteRowParams deleteParams, + RequestContext requestContext) + { + return HandleSessionRequest(deleteParams, requestContext, session => + { + // Add the delete row to the edit cache + session.DeleteRow(deleteParams.RowId); + return new EditDeleteRowResult(); + }); + } + + internal async Task HandleDisposeRequest(EditDisposeParams disposeParams, + RequestContext requestContext) + { + try + { + // Sanity check the owner URI + Validate.IsNotNullOrWhitespaceString(nameof(disposeParams.OwnerUri), disposeParams.OwnerUri); + + // Attempt to remove the session + Session session; + if (!ActiveSessions.TryRemove(disposeParams.OwnerUri, out session)) + { + await requestContext.SendError(SR.EditDataSessionNotFound); + return; + } + + // Everything was successful, return success + await requestContext.SendResult(new EditDisposeResult()); + } + catch (Exception e) + { + await requestContext.SendError(e.Message); + } + } + + internal async Task HandleInitializeRequest(EditInitializeParams initParams, + RequestContext requestContext) + { + try + { + // Make sure we have info to process this request + Validate.IsNotNullOrWhitespaceString(nameof(initParams.OwnerUri), initParams.OwnerUri); + Validate.IsNotNullOrWhitespaceString(nameof(initParams.ObjectName), initParams.ObjectName); + + // Setup a callback for when the query has successfully created + Func> queryCreateSuccessCallback = async query => + { + await requestContext.SendResult(new EditInitializeResult()); + return true; + }; + + // Setup a callback for when the query failed to be created + Func queryCreateFailureCallback = requestContext.SendError; + + // Setup a callback for when the query completes execution successfully + Query.QueryAsyncEventHandler queryCompleteSuccessCallback = + q => QueryCompleteCallback(q, initParams, requestContext); + + // Setup a callback for when the query completes execution with failure + Query.QueryAsyncEventHandler queryCompleteFailureCallback = query => + { + EditSessionReadyParams readyParams = new EditSessionReadyParams + { + OwnerUri = initParams.OwnerUri, + Success = false + }; + return requestContext.SendEvent(EditSessionReadyEvent.Type, readyParams); + }; + + // Put together a query for the results and execute it + ExecuteStringParams executeParams = new ExecuteStringParams + { + Query = $"SELECT * FROM {SqlScriptFormatter.FormatMultipartIdentifier(initParams.ObjectName)}", + OwnerUri = initParams.OwnerUri + }; + await queryExecutionService.InterServiceExecuteQuery(executeParams, requestContext, + queryCreateSuccessCallback, queryCreateFailureCallback, + queryCompleteSuccessCallback, queryCompleteFailureCallback); + } + catch (Exception e) + { + await requestContext.SendError(e.Message); + } + } + + internal Task HandleRevertRowRequest(EditRevertRowParams revertParams, + RequestContext requestContext) + { + return HandleSessionRequest(revertParams, requestContext, session => + { + session.RevertRow(revertParams.RowId); + return new EditRevertRowResult(); + }); + } + + internal Task HandleUpdateCellRequest(EditUpdateCellParams updateParams, + RequestContext requestContext) + { + return HandleSessionRequest(updateParams, requestContext, + session => session.UpdateCell(updateParams.RowId, updateParams.ColumnId, updateParams.NewValue)); + } + + #endregion + + #region Private Helpers + + /// + /// Returns the session with the given owner URI or throws if it can't be found + /// + /// If the edit session doesn't exist + /// Owner URI for the edit session + /// The edit session that corresponds to the owner URI + private Session GetActiveSessionOrThrow(string ownerUri) + { + // Sanity check the owner URI is provided + Validate.IsNotNullOrWhitespaceString(nameof(ownerUri), ownerUri); + + // Attempt to get the session, throw if unable + Session session; + if (!ActiveSessions.TryGetValue(ownerUri, out session)) + { + throw new Exception(SR.EditDataSessionNotFound); + } + + return session; + } + + private async Task QueryCompleteCallback(Query query, EditInitializeParams initParams, + IEventSender requestContext) + { + EditSessionReadyParams readyParams = new EditSessionReadyParams + { + OwnerUri = initParams.OwnerUri + }; + + try + { + // Validate the query for a session + ResultSet resultSet = Session.ValidateQueryForSession(query); + + // Get a connection we'll use for SMO metadata lookup (and committing, later on) + DbConnection conn = await connectionService.GetOrOpenConnection(initParams.OwnerUri, ConnectionType.Edit); + var metadata = metadataFactory.GetObjectMetadata(conn, resultSet.Columns, + initParams.ObjectName, initParams.ObjectType); + + // Create the session and add it to the sessions list + Session session = new Session(resultSet, metadata); + if (!ActiveSessions.TryAdd(initParams.OwnerUri, session)) + { + throw new InvalidOperationException("Failed to create edit session, session already exists."); + } + readyParams.Success = true; + } + catch (Exception) + { + // Request that the query be disposed + await queryExecutionService.InterServiceDisposeQuery(initParams.OwnerUri, null, null); + readyParams.Success = false; + } + + // Send the edit session ready notification + await requestContext.SendEvent(EditSessionReadyEvent.Type, readyParams); + } + + #endregion + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/EditTableMetadata.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditTableMetadata.cs new file mode 100644 index 00000000..21831a13 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditTableMetadata.cs @@ -0,0 +1,99 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Linq; +using System.Diagnostics; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.EditData +{ + /// + /// Provides metadata about the table or view being edited + /// + public class EditTableMetadata : IEditTableMetadata + { + private readonly List columns; + private readonly List keyColumns; + + /// + /// Constructor that extracts useful metadata from the provided metadata objects + /// + /// DB columns from the ResultSet + /// SMO metadata object for the table/view being edited + public EditTableMetadata(IList dbColumns, TableViewTableTypeBase smoObject) + { + Validate.IsNotNull(nameof(dbColumns), dbColumns); + Validate.IsNotNull(nameof(smoObject), smoObject); + + // Make sure that we have equal columns on both metadata providers + Debug.Assert(dbColumns.Count == smoObject.Columns.Count); + + // Create the columns for edit usage + columns = new List(); + for (int i = 0; i < dbColumns.Count; i++) + { + Column smoColumn = smoObject.Columns[i]; + DbColumnWrapper dbColumn = dbColumns[i]; + + // A column is trustworthy for uniqueness if it can be updated or it has an identity + // property. If both of these are false (eg, timestamp) we can't trust it to uniquely + // identify a row in the table + bool isTrustworthyForUniqueness = dbColumn.IsUpdatable || smoColumn.Identity; + + EditColumnWrapper column = new EditColumnWrapper + { + DbColumn = dbColumn, + Ordinal = i, + EscapedName = SqlScriptFormatter.FormatIdentifier(dbColumn.ColumnName), + IsTrustworthyForUniqueness = isTrustworthyForUniqueness, + + // A key column is determined by whether it is in the primary key and trustworthy + IsKey = smoColumn.InPrimaryKey && isTrustworthyForUniqueness + }; + columns.Add(column); + } + + // Determine what the key columns are + keyColumns = columns.Where(c => c.IsKey).ToList(); + if (keyColumns.Count == 0) + { + // We didn't find any explicit key columns. Instead, we'll use all columns that are + // trustworthy for uniqueness (usually all the columns) + keyColumns = columns.Where(c => c.IsTrustworthyForUniqueness).ToList(); + } + + // If a table is memory optimized it is Hekaton. If it's a view, then it can't be Hekaton + Table smoTable = smoObject as Table; + IsMemoryOptimized = smoTable != null && smoTable.IsMemoryOptimized; + + // Escape the parts of the name + string[] objectNameParts = {smoObject.Schema, smoObject.Name}; + EscapedMultipartName = SqlScriptFormatter.FormatMultipartIdentifier(objectNameParts); + } + + /// + /// Read-only list of columns in the object being edited + /// + public IEnumerable Columns => columns.AsReadOnly(); + + /// + /// Full escaped multipart identifier for the object being edited + /// + public string EscapedMultipartName { get; } + + /// + /// Whether or not the object being edited is memory optimized + /// + public bool IsMemoryOptimized { get; } + + /// + /// Read-only list of columns that are used to uniquely identify a row + /// + public IEnumerable KeyColumns => keyColumns.AsReadOnly(); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/IEditMetadataFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/IEditMetadataFactory.cs new file mode 100644 index 00000000..633566fa --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/IEditMetadataFactory.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Data.Common; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData +{ + /// + /// Interface for a factory that generates metadata for an object to edit + /// + public interface IEditMetadataFactory + { + /// + /// Generates a edit-ready metadata object + /// + /// Connection to use for getting metadata + /// List of columns from a query against the object + /// Name of the object to return metadata for + /// Type of the object to return metadata for + /// Metadata about the object requested + IEditTableMetadata GetObjectMetadata(DbConnection connection, DbColumnWrapper[] columns, string objectName, string objectType); + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/IEditTableMetadata.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/IEditTableMetadata.cs new file mode 100644 index 00000000..67181ee8 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/IEditTableMetadata.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.EditData +{ + /// + /// An interface used in edit scenarios that defines properties for what columns are primary + /// keys, and other metadata of the table. + /// + public interface IEditTableMetadata + { + /// + /// All columns in the table that's being edited + /// + IEnumerable Columns { get; } + + /// + /// The escaped name of the table that's being edited + /// + string EscapedMultipartName { get; } + + /// + /// Whether or not this table is a memory optimized table + /// + bool IsMemoryOptimized { get; } + + /// + /// Columns that can be used to uniquely identify the a row + /// + IEnumerable KeyColumns { get; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/Session.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/Session.cs new file mode 100644 index 00000000..3099b4d5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/Session.cs @@ -0,0 +1,215 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Concurrent; +using System.IO; +using System.Linq; +using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; +using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.EditData +{ + /// + /// Represents an edit "session" bound to the results of a query, containing a cache of edits + /// that are pending. Provides logic for performing edit operations. + /// + public class Session + { + + #region Member Variables + + private readonly ResultSet associatedResultSet; + private readonly IEditTableMetadata objectMetadata; + + #endregion + + /// + /// Constructs a new edit session bound to the result set and metadat object provided + /// + /// The result set of the table to be edited + /// Metadata provider for the table to be edited + public Session(ResultSet resultSet, IEditTableMetadata objMetadata) + { + Validate.IsNotNull(nameof(resultSet), resultSet); + Validate.IsNotNull(nameof(objMetadata), objMetadata); + + // Setup the internal state + associatedResultSet = resultSet; + objectMetadata = objMetadata; + NextRowId = associatedResultSet.RowCount; + EditCache = new ConcurrentDictionary(); + } + + #region Properties + + /// + /// The internal ID for the next row in the table. Internal for unit testing purposes only. + /// + internal long NextRowId { get; private set; } + + /// + /// The cache of pending updates. Internal for unit test purposes only + /// + internal ConcurrentDictionary EditCache { get;} + + #endregion + + #region Public Methods + + /// + /// Validates that a query can be used for an edit session. The target result set is returned + /// + /// The query to validate + /// The result set to use + public static ResultSet ValidateQueryForSession(Query query) + { + Validate.IsNotNull(nameof(query), query); + + // Determine if the query is valid for editing + // Criterion 1) Query has finished executing + if (!query.HasExecuted) + { + throw new InvalidOperationException(SR.EditDataQueryNotCompleted); + } + + // Criterion 2) Query only has a single result set + ResultSet[] queryResultSets = query.Batches.SelectMany(b => b.ResultSets).ToArray(); + if (queryResultSets.Length != 1) + { + throw new InvalidOperationException(SR.EditDataQueryImproperResultSets); + } + + return query.Batches[0].ResultSets[0]; + } + + /// + /// Creates a new row update and adds it to the update cache + /// + /// If inserting into cache fails + /// The internal ID of the newly created row + public long CreateRow() + { + // Create a new row ID (atomically, since this could be accesses concurrently) + long newRowId = NextRowId++; + + // Create a new row create update and add to the update cache + RowCreate newRow = new RowCreate(newRowId, associatedResultSet, objectMetadata); + if (!EditCache.TryAdd(newRowId, newRow)) + { + // Revert the next row ID + NextRowId--; + throw new InvalidOperationException(SR.EditDataFailedAddRow); + } + + return newRowId; + } + + /// + /// Creates a delete row update and adds it to the update cache + /// + /// + /// If row requested to delete already has a pending change in the cache + /// + /// The internal ID of the row to delete + public void DeleteRow(long rowId) + { + // Sanity check the row ID + if (rowId >= NextRowId || rowId < 0) + { + throw new ArgumentOutOfRangeException(nameof(rowId), SR.EditDataRowOutOfRange); + } + + // Create a new row delete update and add to cache + RowDelete deleteRow = new RowDelete(rowId, associatedResultSet, objectMetadata); + if (!EditCache.TryAdd(rowId, deleteRow)) + { + throw new InvalidOperationException(SR.EditDataUpdatePending); + } + } + + /// + /// Removes a pending row update from the update cache. + /// + /// + /// If a pending row update with the given row ID does not exist. + /// + /// The internal ID of the row to reset + public void RevertRow(long rowId) + { + // Attempt to remove the row with the given ID + RowEditBase removedEdit; + if (!EditCache.TryRemove(rowId, out removedEdit)) + { + throw new ArgumentOutOfRangeException(nameof(rowId), SR.EditDataUpdateNotPending); + } + } + + public string ScriptEdits(string outputPath) + { + // Validate the output path + // @TODO: Reinstate this code once we have an interface around file generation + //if (outputPath == null) + //{ + // // If output path isn't provided, we'll use a temporary location + // outputPath = Path.GetTempFileName(); + //} + //else + if (outputPath == null || outputPath.Trim() == string.Empty) + { + // If output path is empty, that's an error + throw new ArgumentNullException(nameof(outputPath), SR.EditDataScriptFilePathNull); + } + + // Open a handle to the output file + using (FileStream outputStream = File.OpenWrite(outputPath)) + using (TextWriter outputWriter = new StreamWriter(outputStream)) + { + + // Convert each update in the cache into an insert/update/delete statement + foreach (RowEditBase rowEdit in EditCache.Values) + { + outputWriter.WriteLine(rowEdit.GetScript()); + } + } + + // Return the location of the generated script + return outputPath; + } + + /// + /// Performs an update to a specific cell in a row. If the row has not already been + /// initialized with a record in the update cache, one is created. + /// + /// If adding a new update row fails + /// + /// If the row that is requested to be edited is beyond the rows in the results and the + /// rows that are being added. + /// + /// The internal ID of the row to edit + /// The ordinal of the column to edit in the row + /// The new string value of the cell to update + public EditUpdateCellResult UpdateCell(long rowId, int columnId, string newValue) + { + // Sanity check to make sure that the row ID is in the range of possible values + if (rowId >= NextRowId || rowId < 0) + { + throw new ArgumentOutOfRangeException(nameof(rowId), SR.EditDataRowOutOfRange); + } + + // Attempt to get the row that is being edited, create a new update object if one + // doesn't exist + RowEditBase editRow = EditCache.GetOrAdd(rowId, new RowUpdate(rowId, associatedResultSet, objectMetadata)); + + // Pass the call to the row update + return editRow.SetCell(columnId, newValue); + } + + #endregion + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs new file mode 100644 index 00000000..696d1ceb --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs @@ -0,0 +1,68 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.Common; +using System.Data.SqlClient; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.EditData +{ + /// + /// Factory that generates metadata using a combination of SMO and SqlClient metadata + /// + public class SmoEditMetadataFactory : IEditMetadataFactory + { + /// + /// Generates a edit-ready metadata object using SMO + /// + /// Connection to use for getting metadata + /// List of columns from a query against the object + /// Name of the object to return metadata for + /// Type of the object to return metadata for + /// Metadata about the object requested + public IEditTableMetadata GetObjectMetadata(DbConnection connection, DbColumnWrapper[] columns, string objectName, string objectType) + { + // Get a connection to the database for SMO purposes + SqlConnection sqlConn = connection as SqlConnection; + if (sqlConn == null) + { + // It's not actually a SqlConnection, so let's try a reliable SQL connection + ReliableSqlConnection reliableConn = connection as ReliableSqlConnection; + if (reliableConn == null) + { + // If we don't have connection we can use with SMO, just give up on using SMO + return null; + } + + // We have a reliable connection, use the underlying connection + sqlConn = reliableConn.GetUnderlyingConnection(); + } + + Server server = new Server(new ServerConnection(sqlConn)); + TableViewTableTypeBase result; + switch (objectType.ToLowerInvariant()) + { + case "table": + result = server.Databases[sqlConn.Database].Tables[objectName]; + break; + case "view": + result = server.Databases[sqlConn.Database].Views[objectName]; + break; + default: + throw new ArgumentOutOfRangeException(nameof(objectType), SR.EditDataUnsupportedObjectType(objectType)); + } + if (result == null) + { + throw new ArgumentOutOfRangeException(nameof(objectName), SR.EditDataObjectMetadataNotFound); + } + + return new EditTableMetadata(columns, result); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/CellUpdate.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/CellUpdate.cs new file mode 100644 index 00000000..d818b6b8 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/CellUpdate.cs @@ -0,0 +1,202 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.Common; +using System.Globalization; +using System.Text.RegularExpressions; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement +{ + /// + /// Representation of a cell that should have a value inserted or updated + /// + public sealed class CellUpdate + { + private const string NullString = @"NULL"; + private const string TextNullString = @"'NULL'"; + private static readonly Regex HexRegex = new Regex("0x[0-9A-F]+", RegexOptions.Compiled | RegexOptions.IgnoreCase); + + /// + /// Constructs a new cell update based on the the string value provided and the column + /// for the cell. + /// + /// Column the cell will be under + /// The string from the client to convert to an object + public CellUpdate(DbColumn column, string valueAsString) + { + Validate.IsNotNull(nameof(column), column); + Validate.IsNotNull(nameof(valueAsString), valueAsString); + + // Store the state that won't be changed + Column = column; + Type columnType = column.DataType; + + // Check for null + if (valueAsString == NullString) + { + Value = DBNull.Value; + ValueAsString = valueAsString; + } + else if (columnType == typeof(byte[])) + { + // Binary columns need special attention + ProcessBinaryCell(valueAsString); + } + else if (columnType == typeof(string)) + { + // Special case for strings because the string value should stay the same as provided + // If user typed 'NULL' they mean NULL as text + Value = valueAsString == TextNullString ? NullString : valueAsString; + ValueAsString = valueAsString; + } + else if (columnType == typeof(Guid)) + { + Value = Guid.Parse(valueAsString); + ValueAsString = Value.ToString(); + } + else if (columnType == typeof(TimeSpan)) + { + Value = TimeSpan.Parse(valueAsString, CultureInfo.CurrentCulture); + ValueAsString = Value.ToString(); + } + else if (columnType == typeof(DateTimeOffset)) + { + Value = DateTimeOffset.Parse(valueAsString, CultureInfo.CurrentCulture); + ValueAsString = Value.ToString(); + } + else if (columnType == typeof(bool)) + { + ProcessBooleanCell(valueAsString); + } + // @TODO: Microsoft.SqlServer.Types.SqlHierarchyId + else + { + // Attempt to go straight to the destination type, if we know what it is, otherwise + // leave it as a string + Value = columnType != null + ? Convert.ChangeType(valueAsString, columnType, CultureInfo.CurrentCulture) + : valueAsString; + ValueAsString = Value.ToString(); + } + } + + #region Properties + + /// + /// The column that the cell will be placed in + /// + public DbColumn Column { get; } + + /// + /// The object representation of the cell provided by the client + /// + public object Value { get; private set; } + + /// + /// converted to a string + /// + public string ValueAsString { get; private set; } + + #endregion + + #region Private Helpers + + private void ProcessBinaryCell(string valueAsString) + { + string trimmedString = valueAsString.Trim(); + + byte[] byteArray; + uint uintVal; + if (uint.TryParse(trimmedString, NumberStyles.None, CultureInfo.InvariantCulture, out uintVal)) + { + // Get the bytes + byteArray = BitConverter.GetBytes(uintVal); + if (BitConverter.IsLittleEndian) + { + Array.Reverse(byteArray); + } + Value = byteArray; + + // User typed something numeric (may be hex or dec) + if ((uintVal & 0xFFFFFF00) == 0) + { + // Value can fit in a single byte + Value = new[] { byteArray[3] }; + } + else if ((uintVal & 0xFFFF0000) == 0) + { + // Value can fit in two bytes + Value = new[] { byteArray[2], byteArray[3] }; + } + else if ((uintVal & 0xFF000000) == 0) + { + // Value can fit in three bytes + Value = new[] { byteArray[1], byteArray[2], byteArray[3] }; + } + } + else if (HexRegex.IsMatch(valueAsString)) + { + // User typed something that starts with a hex identifier (0x) + // Strip off the 0x, pad with zero if necessary + trimmedString = trimmedString.Substring(2); + if (trimmedString.Length % 2 == 1) + { + trimmedString = "0" + trimmedString; + } + + // Convert to a byte array + byteArray = new byte[trimmedString.Length / 2]; + for (int i = 0; i < trimmedString.Length; i += 2) + { + string bString = $"{trimmedString[i]}{trimmedString[i + 1]}"; + byte bVal = byte.Parse(bString, NumberStyles.AllowHexSpecifier, CultureInfo.InvariantCulture); + byteArray[i / 2] = bVal; + } + Value = byteArray; + } + else + { + // Invalid format + throw new FormatException(SR.EditDataInvalidFormatBinary); + } + + // Generate the hex string as the return value + ValueAsString = "0x" + BitConverter.ToString((byte[])Value).Replace("-", string.Empty); + } + + private void ProcessBooleanCell(string valueAsString) + { + // Allow user to enter 1 or 0 + string trimmedString = valueAsString.Trim(); + int intVal; + if (int.TryParse(trimmedString, out intVal)) + { + switch (intVal) + { + case 1: + Value = true; + break; + case 0: + Value = false; + break; + default: + throw new ArgumentOutOfRangeException(nameof(valueAsString), + SR.EditDataInvalidFormatBoolean); + } + } + else + { + // Allow user to enter true or false + Value = bool.Parse(valueAsString); + } + + ValueAsString = Value.ToString(); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowCreate.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowCreate.cs new file mode 100644 index 00000000..7e6c47a8 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowCreate.cs @@ -0,0 +1,103 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement +{ + /// + /// Represents a row that should be added to the result set. Generates an INSERT statement. + /// + public sealed class RowCreate : RowEditBase + { + private const string InsertStatement = "INSERT INTO {0}({1}) VALUES ({2})"; + + private readonly CellUpdate[] newCells; + + /// + /// Creates a new Row Creation edit to the result set + /// + /// Internal ID of the row that is being created + /// The result set for the rows in the table we're editing + /// The metadata for table we're editing + public RowCreate(long rowId, ResultSet associatedResultSet, IEditTableMetadata associatedMetadata) + : base(rowId, associatedResultSet, associatedMetadata) + { + newCells = new CellUpdate[associatedResultSet.Columns.Length]; + } + + /// + /// Generates the INSERT INTO statement that will apply the row creation + /// + /// INSERT INTO statement + public override string GetScript() + { + List columnNames = new List(); + List columnValues = new List(); + + // Build the column list and value list + for (int i = 0; i < AssociatedResultSet.Columns.Length; i++) + { + DbColumnWrapper column = AssociatedResultSet.Columns[i]; + CellUpdate cell = newCells[i]; + + // If the column is not updatable, then skip it + if (!column.IsUpdatable) + { + continue; + } + + // If the cell doesn't have a value, but is updatable, don't try to create the script + if (cell == null) + { + throw new InvalidOperationException(SR.EditDataCreateScriptMissingValue); + } + + // Add the column and the data to their respective lists + columnNames.Add(SqlScriptFormatter.FormatIdentifier(column.ColumnName)); + columnValues.Add(SqlScriptFormatter.FormatValue(cell.Value, column)); + } + + // Put together the components of the statement + string joinedColumnNames = string.Join(", ", columnNames); + string joinedColumnValues = string.Join(", ", columnValues); + return string.Format(InsertStatement, AssociatedObjectMetadata.EscapedMultipartName, joinedColumnNames, + joinedColumnValues); + } + + /// + /// Sets the value of a cell in the row to be added + /// + /// Ordinal of the column to set in the row + /// String representation from the client of the value to add + /// + /// The updated value as a string of the object generated from + /// + public override EditUpdateCellResult SetCell(int columnId, string newValue) + { + // Validate the column and the value and convert to object + ValidateColumnIsUpdatable(columnId); + CellUpdate update = new CellUpdate(AssociatedResultSet.Columns[columnId], newValue); + + // Add the cell update to the + newCells[columnId] = update; + + // Put together a result of the change + EditUpdateCellResult eucr = new EditUpdateCellResult + { + HasCorrections = update.ValueAsString != newValue, + NewValue = update.ValueAsString != newValue ? update.ValueAsString : null, + IsNull = update.Value == DBNull.Value, + IsRevert = false // Editing cells of new rows cannot be reverts + }; + return eucr; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs new file mode 100644 index 00000000..f585691d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs @@ -0,0 +1,55 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Globalization; +using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement +{ + /// + /// Represents a row that should be deleted. This will generate a DELETE statement + /// + public sealed class RowDelete : RowEditBase + { + private const string DeleteStatement = "DELETE FROM {0} {1}"; + private const string DeleteMemoryOptimizedStatement = "DELETE FROM {0} WITH(SNAPSHOT) {1}"; + + /// + /// Constructs a new RowDelete object + /// + /// Internal ID of the row to be deleted + /// Result set that is being edited + /// Improved metadata of the object being edited + public RowDelete(long rowId, ResultSet associatedResultSet, IEditTableMetadata associatedMetadata) + : base(rowId, associatedResultSet, associatedMetadata) + { + } + + /// + /// Generates a DELETE statement to delete this row + /// + /// String of the DELETE statement + public override string GetScript() + { + string formatString = AssociatedObjectMetadata.IsMemoryOptimized ? DeleteMemoryOptimizedStatement : DeleteStatement; + return string.Format(CultureInfo.InvariantCulture, formatString, + AssociatedObjectMetadata.EscapedMultipartName, GetWhereClause(false).CommandText); + } + + /// + /// This method should not be called. A cell cannot be updated on a row that is pending + /// deletion. + /// + /// Always thrown + /// Ordinal of the column to update + /// New value for the cell + public override EditUpdateCellResult SetCell(int columnId, string newValue) + { + throw new InvalidOperationException(SR.EditDataDeleteSetCell); + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowEdit.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowEdit.cs new file mode 100644 index 00000000..e36c31e0 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowEdit.cs @@ -0,0 +1,197 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Data.SqlClient; +using System.Linq; +using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement +{ + /// + /// Base class for row edit operations. Provides basic information and helper functionality + /// that all RowEdit implementations can use. Defines functionality that must be implemented + /// in all child classes. + /// + public abstract class RowEditBase + { + /// + /// Internal parameterless constructor, required for mocking + /// + protected internal RowEditBase() { } + + /// + /// Base constructor for a row edit. Stores the state that should be available to all row + /// edit implementations. + /// + /// The internal ID of the row that is being edited + /// The result set that will be updated + /// Metadata provider for the object to edit + protected RowEditBase(long rowId, ResultSet associatedResultSet, IEditTableMetadata associatedMetadata) + { + RowId = rowId; + AssociatedResultSet = associatedResultSet; + AssociatedObjectMetadata = associatedMetadata; + } + + #region Properties + + /// + /// The internal ID of the row to which this edit applies, relative to the result set + /// + public long RowId { get; } + + /// + /// The result set that is associated with this row edit + /// + public ResultSet AssociatedResultSet { get; } + + /// + /// The metadata for the table this edit is associated to + /// + public IEditTableMetadata AssociatedObjectMetadata { get; } + + #endregion + + /// + /// Converts the row edit into a SQL statement + /// + /// A SQL statement + public abstract string GetScript(); + + /// + /// Changes the value a cell in the row. + /// + /// Ordinal of the column in the row to update + /// The new value for the cell + /// The value of the cell after applying validation logic + public abstract EditUpdateCellResult SetCell(int columnId, string newValue); + + /// + /// Performs validation of column ID and if column can be updated. + /// + /// + /// If is less than 0 or greater than the number of columns + /// in the row + /// + /// If the column is not updatable + /// Ordinal of the column to update + protected void ValidateColumnIsUpdatable(int columnId) + { + // Sanity check that the column ID is within the range of columns + if (columnId >= AssociatedResultSet.Columns.Length || columnId < 0) + { + throw new ArgumentOutOfRangeException(nameof(columnId), SR.EditDataColumnIdOutOfRange); + } + + DbColumnWrapper column = AssociatedResultSet.Columns[columnId]; + if (!column.IsUpdatable) + { + throw new InvalidOperationException(SR.EditDataColumnCannotBeEdited); + } + } + + /// + /// Generates a WHERE clause that uses the key columns of the table to uniquely identity + /// the row that will be updated. + /// + /// + /// Whether or not to generate a parameterized where clause. If true verbatim values + /// will be replaced with paremeters (like @Param12). The parameters must be added to the + /// SqlCommand used to execute the commit. + /// + /// A object + protected WhereClause GetWhereClause(bool parameterize) + { + WhereClause output = new WhereClause(); + + if (!AssociatedObjectMetadata.KeyColumns.Any()) + { + throw new InvalidOperationException(SR.EditDataColumnNoKeyColumns); + } + + IList row = AssociatedResultSet.GetRow(RowId); + foreach (EditColumnWrapper col in AssociatedObjectMetadata.KeyColumns) + { + // Put together a clause for the value of the cell + DbCellValue cellData = row[col.Ordinal]; + string cellDataClause; + if (cellData.IsNull) + { + cellDataClause = "IS NULL"; + } + else + { + if (cellData.RawObject is byte[] || + col.DbColumn.DataTypeName.Equals("TEXT", StringComparison.OrdinalIgnoreCase) || + col.DbColumn.DataTypeName.Equals("NTEXT", StringComparison.OrdinalIgnoreCase)) + { + // Special cases for byte[] and TEXT/NTEXT types + cellDataClause = "IS NOT NULL"; + } + else + { + // General case is to just use the value from the cell + if (parameterize) + { + // Add a parameter and parameterized clause component + // NOTE: We include the row ID to make sure the parameter is unique if + // we execute multiple row edits at once. + string paramName = $"@Param{RowId}{col.Ordinal}"; + cellDataClause = $"= {paramName}"; + output.Parameters.Add(new SqlParameter(paramName, col.DbColumn.SqlDbType)); + } + else + { + // Add the clause component with the formatted value + cellDataClause = $"= {SqlScriptFormatter.FormatValue(cellData, col.DbColumn)}"; + } + } + } + + string completeComponent = $"({col.EscapedName} {cellDataClause})"; + output.ClauseComponents.Add(completeComponent); + } + + return output; + } + + /// + /// Represents a WHERE clause that can be used for identifying a row in a table. + /// + protected class WhereClause + { + /// + /// Constructs and initializes a new where clause + /// + public WhereClause() + { + Parameters = new List(); + ClauseComponents = new List(); + } + + /// + /// SqlParameters used in a parameterized query. If this object was generated without + /// parameterization, this will be an empty list + /// + public List Parameters { get; } + + /// + /// Strings that make up the WHERE clause, such as "([col1] = 'something')" + /// + public List ClauseComponents { get; } + + /// + /// Total text of the WHERE clause that joins all the components with AND + /// + public string CommandText => $"WHERE {string.Join(" AND ", ClauseComponents)}"; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs new file mode 100644 index 00000000..482e3fa5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs @@ -0,0 +1,110 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement +{ + /// + /// An update to apply to a row of a result set. This will generate an UPDATE statement. + /// + public sealed class RowUpdate : RowEditBase + { + private const string UpdateStatement = "UPDATE {0} SET {1} {2}"; + private const string UpdateStatementMemoryOptimized = "UPDATE {0} WITH (SNAPSHOT) SET {1} {2}"; + + private readonly Dictionary cellUpdates; + private readonly IList associatedRow; + + /// + /// Constructs a new RowUpdate to be added to the cache. + /// + /// Internal ID of the row that will be updated with this object + /// Result set for the rows of the object to update + /// Metadata provider for the object to update + public RowUpdate(long rowId, ResultSet associatedResultSet, IEditTableMetadata associatedMetadata) + : base(rowId, associatedResultSet, associatedMetadata) + { + cellUpdates = new Dictionary(); + associatedRow = associatedResultSet.GetRow(rowId); + } + + /// + /// Constructs an update statement to change the associated row. + /// + /// An UPDATE statement + public override string GetScript() + { + // Build the "SET" portion of the statement + IEnumerable setComponents = cellUpdates.Values.Select(cellUpdate => + { + string formattedColumnName = SqlScriptFormatter.FormatIdentifier(cellUpdate.Column.ColumnName); + string formattedValue = SqlScriptFormatter.FormatValue(cellUpdate.Value, cellUpdate.Column); + return $"{formattedColumnName} = {formattedValue}"; + }); + string setClause = string.Join(", ", setComponents); + + // Get the where clause + string whereClause = GetWhereClause(false).CommandText; + + // Put it all together + string formatString = AssociatedObjectMetadata.IsMemoryOptimized ? UpdateStatementMemoryOptimized : UpdateStatement; + return string.Format(CultureInfo.InvariantCulture, formatString, + AssociatedObjectMetadata.EscapedMultipartName, setClause, whereClause); + } + + /// + /// Sets the value of the cell in the associated row. If is + /// identical to the original value, this will remove the cell update from the row update. + /// + /// Ordinal of the columns that will be set + /// String representation of the value the user input + /// + /// The string representation of the new value (after conversion to target object) if the + /// a change is made. null is returned if the cell is reverted to it's original value. + /// + public override EditUpdateCellResult SetCell(int columnId, string newValue) + { + // Validate the value and convert to object + ValidateColumnIsUpdatable(columnId); + CellUpdate update = new CellUpdate(AssociatedResultSet.Columns[columnId], newValue); + + // If the value is the same as the old value, we shouldn't make changes + // NOTE: We must use .Equals in order to ignore object to object comparisons + if (update.Value.Equals(associatedRow[columnId].RawObject)) + { + // Remove any pending change and stop processing this + if (cellUpdates.ContainsKey(columnId)) + { + cellUpdates.Remove(columnId); + } + return new EditUpdateCellResult + { + HasCorrections = false, + NewValue = associatedRow[columnId].DisplayValue, + IsRevert = true, + IsNull = associatedRow[columnId].IsNull + }; + } + + // The change is real, so set it + cellUpdates[columnId] = update; + return new EditUpdateCellResult + { + HasCorrections = update.ValueAsString != newValue, + NewValue = update.ValueAsString != newValue ? update.ValueAsString : null, + IsNull = update.Value == DBNull.Value, + IsRevert = false // If we're in this branch, it is not a revert + }; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs index 429e3732..cf65824e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Credentials; +using Microsoft.SqlTools.ServiceLayer.EditData; using Microsoft.SqlTools.ServiceLayer.Extensibility; using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; @@ -75,6 +76,9 @@ namespace Microsoft.SqlTools.ServiceLayer QueryExecutionService.Instance.InitializeService(serviceHost); serviceProvider.RegisterSingleService(QueryExecutionService.Instance); + EditDataService.Instance.InitializeService(serviceHost); + serviceProvider.RegisterSingleService(EditDataService.Instance); + InitializeHostedServices(serviceProvider, serviceHost); serviceHost.InitializeRequestHandlers(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.Designer.cs b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.Designer.cs index d7cf2923..65569efd 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.Designer.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.Designer.cs @@ -230,6 +230,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Localization { } } + /// + /// Looks up a localized string similar to Specified URI '{0}' does not have a default connection. + /// + public static string ConnectionServiceDbErrorDefaultNotConnected { + get { + return ResourceManager.GetString("ConnectionServiceDbErrorDefaultNotConnected", resourceCulture); + } + } + /// /// Looks up a localized string similar to SpecifiedUri '{0}' does not have existing connection. /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs index 3fe5880f..cd92d993 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs @@ -485,6 +485,134 @@ namespace Microsoft.SqlTools.ServiceLayer } } + public static string EditDataSessionNotFound + { + get + { + return Keys.GetString(Keys.EditDataSessionNotFound); + } + } + + public static string EditDataQueryNotCompleted + { + get + { + return Keys.GetString(Keys.EditDataQueryNotCompleted); + } + } + + public static string EditDataQueryImproperResultSets + { + get + { + return Keys.GetString(Keys.EditDataQueryImproperResultSets); + } + } + + public static string EditDataFailedAddRow + { + get + { + return Keys.GetString(Keys.EditDataFailedAddRow); + } + } + + public static string EditDataRowOutOfRange + { + get + { + return Keys.GetString(Keys.EditDataRowOutOfRange); + } + } + + public static string EditDataUpdatePending + { + get + { + return Keys.GetString(Keys.EditDataUpdatePending); + } + } + + public static string EditDataUpdateNotPending + { + get + { + return Keys.GetString(Keys.EditDataUpdateNotPending); + } + } + + public static string EditDataObjectMetadataNotFound + { + get + { + return Keys.GetString(Keys.EditDataObjectMetadataNotFound); + } + } + + public static string EditDataInvalidFormatBinary + { + get + { + return Keys.GetString(Keys.EditDataInvalidFormatBinary); + } + } + + public static string EditDataInvalidFormatBoolean + { + get + { + return Keys.GetString(Keys.EditDataInvalidFormatBoolean); + } + } + + public static string EditDataCreateScriptMissingValue + { + get + { + return Keys.GetString(Keys.EditDataCreateScriptMissingValue); + } + } + + public static string EditDataDeleteSetCell + { + get + { + return Keys.GetString(Keys.EditDataDeleteSetCell); + } + } + + public static string EditDataColumnIdOutOfRange + { + get + { + return Keys.GetString(Keys.EditDataColumnIdOutOfRange); + } + } + + public static string EditDataColumnCannotBeEdited + { + get + { + return Keys.GetString(Keys.EditDataColumnCannotBeEdited); + } + } + + public static string EditDataColumnNoKeyColumns + { + get + { + return Keys.GetString(Keys.EditDataColumnNoKeyColumns); + } + } + + public static string EditDataScriptFilePathNull + { + get + { + return Keys.GetString(Keys.EditDataScriptFilePathNull); + } + } + public static string EE_BatchSqlMessageNoProcedureInfo { get @@ -790,6 +918,11 @@ namespace Microsoft.SqlTools.ServiceLayer return Keys.GetString(Keys.WorkspaceServiceBufferPositionOutOfOrder, sLine, sCol, eLine, eCol); } + public static string EditDataUnsupportedObjectType(string typeName) + { + return Keys.GetString(Keys.EditDataUnsupportedObjectType, typeName); + } + [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] public class Keys { @@ -1008,6 +1141,57 @@ namespace Microsoft.SqlTools.ServiceLayer public const string WorkspaceServiceBufferPositionOutOfOrder = "WorkspaceServiceBufferPositionOutOfOrder"; + public const string EditDataSessionNotFound = "EditDataSessionNotFound"; + + + public const string EditDataUnsupportedObjectType = "EditDataUnsupportedObjectType"; + + + public const string EditDataQueryNotCompleted = "EditDataQueryNotCompleted"; + + + public const string EditDataQueryImproperResultSets = "EditDataQueryImproperResultSets"; + + + public const string EditDataFailedAddRow = "EditDataFailedAddRow"; + + + public const string EditDataRowOutOfRange = "EditDataRowOutOfRange"; + + + public const string EditDataUpdatePending = "EditDataUpdatePending"; + + + public const string EditDataUpdateNotPending = "EditDataUpdateNotPending"; + + + public const string EditDataObjectMetadataNotFound = "EditDataObjectMetadataNotFound"; + + + public const string EditDataInvalidFormatBinary = "EditDataInvalidFormatBinary"; + + + public const string EditDataInvalidFormatBoolean = "EditDataInvalidFormatBoolean"; + + + public const string EditDataCreateScriptMissingValue = "EditDataCreateScriptMissingValue"; + + + public const string EditDataDeleteSetCell = "EditDataDeleteSetCell"; + + + public const string EditDataColumnIdOutOfRange = "EditDataColumnIdOutOfRange"; + + + public const string EditDataColumnCannotBeEdited = "EditDataColumnCannotBeEdited"; + + + public const string EditDataColumnNoKeyColumns = "EditDataColumnNoKeyColumns"; + + + public const string EditDataScriptFilePathNull = "EditDataScriptFilePathNull"; + + public const string EE_BatchSqlMessageNoProcedureInfo = "EE_BatchSqlMessageNoProcedureInfo"; diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx index 3cdcb15a..551681e1 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx @@ -410,6 +410,75 @@ . Parameters: 0 - sLine (int), 1 - sCol (int), 2 - eLine (int), 3 - eCol (int) + + Edit session does not exist. + + + + Database object {0} cannot be used for editing. + . + Parameters: 0 - typeName (string) + + + Query has not completed execution + + + + Query did not generate exactly one result set + + + + Failed to add new row to update cache + + + + Given row ID is outside the range of rows in the edit cache + + + + An update is already pending for this row and must be reverted first + + + + Given row ID does not have pending updated + + + + Table or view metadata could not be found + + + + Invalid format for binary column + + + + Allowed values for boolean columns are 0, 1, "true", or "false" + + + + A required cell value is missing + + + + A delete is pending for this row, a cell update cannot be applied. + + + + Column ID must be in the range of columns for the query + + + + Column cannot be edited + + + + No key columns were found + + + + An output filename must be provided + + Msg {0}, Level {1}, State {2}, Line {3} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings index be49988a..f9c17744 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings @@ -201,6 +201,43 @@ WorkspaceServicePositionColumnOutOfRange(int line) = Position is outside of colu WorkspaceServiceBufferPositionOutOfOrder(int sLine, int sCol, int eLine, int eCol) = Start position ({0}, {1}) must come before or be equal to the end position ({2}, {3}) +############################################################################ +# Edit Data Service + +EditDataSessionNotFound = Edit session does not exist. + +EditDataUnsupportedObjectType(string typeName) = Database object {0} cannot be used for editing. + +EditDataQueryNotCompleted = Query has not completed execution + +EditDataQueryImproperResultSets = Query did not generate exactly one result set + +EditDataFailedAddRow = Failed to add new row to update cache + +EditDataRowOutOfRange = Given row ID is outside the range of rows in the edit cache + +EditDataUpdatePending = An update is already pending for this row and must be reverted first + +EditDataUpdateNotPending = Given row ID does not have pending updated + +EditDataObjectMetadataNotFound = Table or view metadata could not be found + +EditDataInvalidFormatBinary = Invalid format for binary column + +EditDataInvalidFormatBoolean = Allowed values for boolean columns are 0, 1, "true", or "false" + +EditDataCreateScriptMissingValue = A required cell value is missing + +EditDataDeleteSetCell = A delete is pending for this row, a cell update cannot be applied. + +EditDataColumnIdOutOfRange = Column ID must be in the range of columns for the query + +EditDataColumnCannotBeEdited = Column cannot be edited + +EditDataColumnNoKeyColumns = No key columns were found + +EditDataScriptFilePathNull = An output filename must be provided + ############################################################################ # DacFx Resources diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.xlf b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.xlf index fdf394be..ea569eb8 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.xlf +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.xlf @@ -509,6 +509,92 @@ Replacement of an empty string by an empty string. + + Edit session does not exist. + Edit session does not exist. + + + + Query has not completed execution + Query has not completed execution + + + + Query did not generate exactly one result set + Query did not generate exactly one result set + + + + Failed to add new row to update cache + Failed to add new row to update cache + + + + Given row ID is outside the range of rows in the edit cache + Given row ID is outside the range of rows in the edit cache + + + + An update is already pending for this row and must be reverted first + An update is already pending for this row and must be reverted first + + + + Given row ID does not have pending updated + Given row ID does not have pending updated + + + + Table or view metadata could not be found + Table or view metadata could not be found + + + + Invalid format for binary column + Invalid format for binary column + + + + Allowed values for boolean columns are 0, 1, "true", or "false" + Boolean columns must be numeric 1 or 0, or string true or false + + + + A required cell value is missing + A required cell value is missing + + + + A delete is pending for this row, a cell update cannot be applied. + A delete is pending for this row, a cell update cannot be applied. + + + + Column ID must be in the range of columns for the query + Column ID must be in the range of columns for the query + + + + Column cannot be edited + Column cannot be edited + + + + No key columns were found + No key columns were found + + + + An output filename must be provided + An output filename must be provided + + + + Database object {0} cannot be used for editing. + Database object {0} cannot be used for editing. + . + Parameters: 0 - typeName (string) + \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbCellValue.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbCellValue.cs index 6eabb4d3..cfd5d95a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbCellValue.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbCellValue.cs @@ -15,6 +15,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// public string DisplayValue { get; set; } + /// + /// Whether or not the cell is NULL + /// + public bool IsNull { get; set; } + /// /// The raw object for the cell, for use internally /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs index 4c5bd6e6..9cbf438c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs @@ -4,9 +4,11 @@ using System; using System.Collections.Generic; +using System.Data; using System.Data.Common; using System.Data.SqlTypes; using System.Diagnostics; +using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts { @@ -16,10 +18,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// public class DbColumnWrapper : DbColumn { + #region Constants + /// /// All types supported by the server, stored as a hash set to provide O(1) lookup /// - internal static readonly HashSet AllServerDataTypes = new HashSet + private static readonly HashSet AllServerDataTypes = new HashSet { "bigint", "binary", @@ -52,6 +56,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts "datetime2" }; + private const string SqlXmlDataTypeName = "xml"; + private const string DbTypeXmlDataTypeName = "DBTYPE_XML"; + private const string UnknownTypeName = "unknown"; + + #endregion + /// /// Constructor for a DbColumnWrapper /// @@ -81,21 +91,49 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts NumericScale = column.NumericScale; UdtAssemblyQualifiedName = column.UdtAssemblyQualifiedName; DataType = column.DataType; - DataTypeName = column.DataTypeName; + DataTypeName = column.DataTypeName.ToLowerInvariant(); + + // Determine the SqlDbType + SqlDbType type; + if (Enum.TryParse(DataTypeName, true, out type)) + { + SqlDbType = type; + } + else + { + switch (DataTypeName) + { + case "numeric": + SqlDbType = SqlDbType.Decimal; + break; + case "sql_variant": + SqlDbType = SqlDbType.Variant; + break; + case "timestamp": + SqlDbType = SqlDbType.VarBinary; + break; + case "sysname": + SqlDbType = SqlDbType.NVarChar; + break; + default: + SqlDbType = DataTypeName.EndsWith(".sys.hierarchyid") ? SqlDbType.NVarChar : SqlDbType.Udt; + break; + } + } // We want the display name for the column to always exist ColumnName = string.IsNullOrEmpty(column.ColumnName) ? SR.QueryServiceColumnNull : column.ColumnName; - switch (column.DataTypeName) + switch (DataTypeName) { case "varchar": case "nvarchar": IsChars = true; - Debug.Assert(column.ColumnSize.HasValue); - if (column.ColumnSize.Value == int.MaxValue) + Debug.Assert(ColumnSize.HasValue); + if (ColumnSize.Value == int.MaxValue) { //For Yukon, special case nvarchar(max) with column name == "Microsoft SQL Server 2005 XML Showplan" - //assume it is an XML showplan. @@ -131,8 +169,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts case "rowversion": IsBytes = true; - Debug.Assert(column.ColumnSize.HasValue); - if (column.ColumnSize.Value == int.MaxValue) + Debug.Assert(ColumnSize.HasValue); + if (ColumnSize.Value == int.MaxValue) { IsLong = true; } @@ -141,7 +179,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts IsSqlVariant = true; break; default: - if (!AllServerDataTypes.Contains(column.DataTypeName)) + if (!AllServerDataTypes.Contains(DataTypeName)) { // treat all UDT's as long/bytes data types to prevent the CLR from attempting // to load the UDT assembly into our process to call ToString() on the object. @@ -216,6 +254,43 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// public bool IsJson { get; set; } + /// + /// The SqlDbType of the column, for use in a SqlParameter + /// + public SqlDbType SqlDbType { get; private set; } + + /// + /// Whether or not the column is an XML Reader type. + /// + /// + /// Logic taken from SSDT determination of whether a column is a SQL XML type. It may not + /// be possible to have XML readers from .NET Core SqlClient. + /// + public bool IsSqlXmlType => DataTypeName.Equals(SqlXmlDataTypeName, StringComparison.OrdinalIgnoreCase) || + DataTypeName.Equals(DbTypeXmlDataTypeName, StringComparison.OrdinalIgnoreCase) || + DataType == typeof(System.Xml.XmlReader); + + /// + /// Whether or not the column is an unknown type + /// + /// + /// Logic taken from SSDT determination of unknown columns. It may not even be possible to + /// have "unknown" column types with the .NET Core SqlClient. + /// + public bool IsUnknownType => DataType == typeof(object) && + DataTypeName.Equals(UnknownTypeName, StringComparison.OrdinalIgnoreCase); + + /// + /// Whether or not the column can be updated, based on whether it's an auto increment + /// column, is an XML reader column, and if it's read only. + /// + /// + /// Logic taken from SSDT determination of updatable columns + /// + public bool IsUpdatable => !IsAutoIncrement.HasTrue() && + !IsReadOnly.HasTrue() && + !IsSqlXmlType; + #endregion } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs index f836e74d..78e43d3b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs @@ -199,6 +199,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage { result.RawObject = null; result.DisplayValue = null; + result.IsNull = true; } else { @@ -207,6 +208,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage T resultObject = convertFunc(length.ValueLength); result.RawObject = resultObject; result.DisplayValue = toStringFunc == null ? result.RawObject.ToString() : toStringFunc(resultObject); + result.IsNull = false; } return new FileStreamReadResult(result, length.TotalLength); diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 46bbc286..4af4eb20 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -151,11 +151,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution RequestContext requestContext) { // Setup actions to perform upon successful start and on failure to start - Func queryCreationAction = () => requestContext.SendResult(new ExecuteRequestResult()); - Func queryFailAction = requestContext.SendError; + Func> queryCreateSuccessAction = async q => { + await requestContext.SendResult(new ExecuteRequestResult()); + return true; + }; + Func queryCreateFailureAction = requestContext.SendError; // Use the internal handler to launch the query - return InterServiceExecuteQuery(executeParams, requestContext, queryCreationAction, queryFailAction); + return InterServiceExecuteQuery(executeParams, requestContext, queryCreateSuccessAction, queryCreateFailureAction, null, null); } /// @@ -328,26 +331,59 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// Query execution meant to be called from another service. Utilizes callbacks to allow /// custom actions to be taken upon creation of query and failure to create query. /// - /// Params for creating the new query - /// Object that can send events for query execution progress - /// - /// Action to perform when query has been successfully created, right before execution of - /// the query + /// Parameters for execution + /// Event sender that will send progressive events during execution of the query + /// + /// Callback for when query has been created successfully. If result is true, query + /// will be executed asynchronously. If result is false, query will be disposed. May + /// be null /// - /// Action to perform if query was not successfully created - public async Task InterServiceExecuteQuery(ExecuteRequestParamsBase executeParams, IEventSender eventSender, - Func queryCreatedAction, Func failureAction) + /// + /// Callback for when query failed to be created successfully. Error message is provided. + /// May be null. + /// + /// + /// Callback to call when query has completed execution successfully. May be null. + /// + /// + /// Callback to call when query has completed execution with errors. May be null. + /// + public async Task InterServiceExecuteQuery(ExecuteRequestParamsBase executeParams, + IEventSender queryEventSender, + Func> queryCreateSuccessFunc, + Func queryCreateFailFunc, + Query.QueryAsyncEventHandler querySuccessFunc, + Query.QueryAsyncEventHandler queryFailureFunc) { Validate.IsNotNull(nameof(executeParams), executeParams); - Validate.IsNotNull(nameof(eventSender), eventSender); - Validate.IsNotNull(nameof(queryCreatedAction), queryCreatedAction); - Validate.IsNotNull(nameof(failureAction), failureAction); - - // Get a new active query - Query newQuery = await CreateAndActivateNewQuery(executeParams, queryCreatedAction, failureAction); + Validate.IsNotNull(nameof(queryEventSender), queryEventSender); + + Query newQuery; + try + { + // Get a new active query + newQuery = CreateQuery(executeParams); + if (queryCreateSuccessFunc != null && !await queryCreateSuccessFunc(newQuery)) + { + // The callback doesn't want us to continue, for some reason + // It's ok if we leave the query behind in the active query list, the next call + // to execute will replace it. + newQuery.Dispose(); + return; + } + } + catch (Exception e) + { + // Call the failure callback if it was provided + if (queryCreateFailFunc != null) + { + await queryCreateFailFunc(e.Message); + } + return; + } // Execute the query asynchronously - ExecuteAndCompleteQuery(executeParams.OwnerUri, eventSender, newQuery); + ExecuteAndCompleteQuery(executeParams.OwnerUri, newQuery, queryEventSender, querySuccessFunc, queryFailureFunc); } /// @@ -390,63 +426,47 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #region Private Helpers - private async Task CreateAndActivateNewQuery(ExecuteRequestParamsBase executeParams, Func successAction, Func failureAction) + private Query CreateQuery(ExecuteRequestParamsBase executeParams) { - try + // Attempt to get the connection for the editor + ConnectionInfo connectionInfo; + if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connectionInfo)) { - // Attempt to get the connection for the editor - ConnectionInfo connectionInfo; - if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connectionInfo)) - { - await failureAction(SR.QueryServiceQueryInvalidOwnerUri); - return null; - } - - // Attempt to clean out any old query on the owner URI - Query oldQuery; - if (ActiveQueries.TryGetValue(executeParams.OwnerUri, out oldQuery) && oldQuery.HasExecuted) - { - oldQuery.Dispose(); - ActiveQueries.TryRemove(executeParams.OwnerUri, out oldQuery); - } - - // Retrieve the current settings for executing the query with - QueryExecutionSettings querySettings = Settings.QueryExecutionSettings; - - // Apply execution parameter settings - querySettings.ExecutionPlanOptions = executeParams.ExecutionPlanOptions; - - // If we can't add the query now, it's assumed the query is in progress - Query newQuery = new Query(GetSqlText(executeParams), connectionInfo, querySettings, BufferFileFactory); - if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery)) - { - await failureAction(SR.QueryServiceQueryInProgress); - newQuery.Dispose(); - return null; - } - - // Successfully created query - await successAction(); - - return newQuery; + throw new ArgumentOutOfRangeException(nameof(executeParams.OwnerUri), SR.QueryServiceQueryInvalidOwnerUri); } - catch (Exception e) + + // Attempt to clean out any old query on the owner URI + Query oldQuery; + if (ActiveQueries.TryGetValue(executeParams.OwnerUri, out oldQuery) && oldQuery.HasExecuted) { - await failureAction(e.Message); - return null; + oldQuery.Dispose(); + ActiveQueries.TryRemove(executeParams.OwnerUri, out oldQuery); } + + // Retrieve the current settings for executing the query with + QueryExecutionSettings settings = Settings.QueryExecutionSettings; + + // Apply execution parameter settings + settings.ExecutionPlanOptions = executeParams.ExecutionPlanOptions; + + // If we can't add the query now, it's assumed the query is in progress + Query newQuery = new Query(GetSqlText(executeParams), connectionInfo, settings, BufferFileFactory); + if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery)) + { + newQuery.Dispose(); + throw new InvalidOperationException(SR.QueryServiceQueryInProgress); + } + + return newQuery; } - private static void ExecuteAndCompleteQuery(string ownerUri, IEventSender eventSender, Query query) + private static void ExecuteAndCompleteQuery(string ownerUri, Query query, + IEventSender eventSender, + Query.QueryAsyncEventHandler querySuccessCallback, + Query.QueryAsyncEventHandler queryFailureCallback) { - // Skip processing if the query is null - if (query == null) - { - return; - } - - // Setup the query completion/failure callbacks - Query.QueryAsyncEventHandler callback = async q => + // Setup the callback to send the complete event + Query.QueryAsyncEventHandler completeCallback = async q => { // Send back the results QueryCompleteParams eventParams = new QueryCompleteParams @@ -457,9 +477,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution await eventSender.SendEvent(QueryCompleteEvent.Type, eventParams); }; + query.QueryCompleted += completeCallback; + query.QueryFailed += completeCallback; - query.QueryCompleted += callback; - query.QueryFailed += callback; + // Add the callbacks that were provided by the caller + // If they're null, that's no problem + query.QueryCompleted += querySuccessCallback; + query.QueryFailed += queryFailureCallback; // Setup the batch callbacks Batch.BatchAsyncEventHandler batchStartCallback = async b => diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs index b6adeb33..83055228 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -183,6 +183,26 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #region Public Methods + public IList GetRow(long rowId) + { + // Sanity check to make sure that results have been read beforehand + if (!hasBeenRead) + { + throw new InvalidOperationException(SR.QueryServiceResultSetNotRead); + } + + // Sanity check to make sure that the row exists + if (rowId >= RowCount) + { + throw new ArgumentOutOfRangeException(nameof(rowId), SR.QueryServiceResultSetStartRowOutOfRange); + } + + using (IFileStreamReader fileStreamReader = fileStreamFactory.GetReader(outputFileName)) + { + return fileStreamReader.ReadRow(fileOffsets[rowId], Columns); + } + } + /// /// Generates a subset of the rows from the result set /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/Extensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/Extensions.cs index 16da91e5..b19c359f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Utility/Extensions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/Extensions.cs @@ -39,4 +39,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility return isTrue ? "1" : "0"; } } + + internal static class NullableExtensions + { + /// + /// Extension method to evaluate a bool? and determine if it has the value and is true. + /// This way we avoid throwing if the bool? doesn't have a value. + /// + /// The bool? to process + /// + /// true if has a value and it is true + /// false otherwise. + /// + public static bool HasTrue(this bool? obj) + { + return obj.HasValue && obj.Value; + } + } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/SqlScriptFormatter.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/SqlScriptFormatter.cs new file mode 100644 index 00000000..25722ddc --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/SqlScriptFormatter.cs @@ -0,0 +1,263 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Globalization; +using System.Linq; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Utility +{ + /// + /// Provides utility for converting arbitrary objects into strings that are ready to be + /// inserted into SQL strings + /// + public class SqlScriptFormatter + { + #region Constants + + public const string NullString = "NULL"; + + private static readonly Dictionary> FormatFunctions = + new Dictionary> + { // CLR Type -------- + {"bigint", (val, col) => SimpleFormatter(val)}, // long + {"bit", (val, col) => FormatBool(val)}, // bool + {"int", (val, col) => SimpleFormatter(val)}, // int + {"smallint", (val, col) => SimpleFormatter(val)}, // short + {"tinyint", (val, col) => SimpleFormatter(val)}, // byte + {"money", (val, col) => FormatMoney(val, "MONEY")}, // Decimal + {"smallmoney", (val, col) => FormatMoney(val, "SMALLMONEY")}, // Decimal + {"decimal", (val, col) => FormatPreciseNumeric(val, col, "DECIMAL")}, // Decimal + {"numeric", (val, col) => FormatPreciseNumeric(val, col, "NUMERIC")}, // Decimal + {"real", (val, col) => FormatFloat(val)}, // float + {"float", (val, col) => FormatDouble(val)}, // double + {"smalldatetime", (val, col) => FormatDateTime(val, "yyyy-MM-dd HH:mm:ss")}, // DateTime + {"datetime", (val, col) => FormatDateTime(val, "yyyy-MM-dd HH:mm:ss.FFF") }, // DateTime + {"datetime2", (val, col) => FormatDateTime(val, "yyyy-MM-dd HH:mm:ss.FFFFFFF")}, // DateTime + {"date", (val, col) => FormatDateTime(val, "yyyy-MM-dd")}, // DateTime + {"datetimeoffset", (val, col) => FormatDateTimeOffset(val)}, // DateTimeOffset + {"time", (val, col) => FormatTimeSpan(val)}, // TimeSpan + {"char", (val, col) => SimpleStringFormatter(val)}, // string + {"nchar", (val, col) => SimpleStringFormatter(val)}, // string + {"varchar", (val, col) => SimpleStringFormatter(val)}, // string + {"nvarchar", (val, col) => SimpleStringFormatter(val)}, // string + {"text", (val, col) => SimpleStringFormatter(val)}, // string + {"ntext", (val, col) => SimpleStringFormatter(val)}, // string + {"xml", (val, col) => SimpleStringFormatter(val)}, // string + {"binary", (val, col) => FormatBinary(val)}, // byte[] + {"varbinary", (val, col) => FormatBinary(val)}, // byte[] + {"image", (val, col) => FormatBinary(val)}, // byte[] + {"uniqueidentifier", (val, col) => SimpleStringFormatter(val)}, // Guid + // Unsupported types: + // *.sys.hierarchyid - cannot cast byte string to hierarchyid + // geography - cannot cast byte string to geography + // geometry - cannot cast byte string to geometry + // timestamp - cannot insert/update timestamp columns + // sql_variant - casting logic isn't good enough + // sysname - it doesn't appear possible to insert a sysname column + }; + + #endregion + + /// + /// Converts an object into a string for SQL script + /// + /// The object to convert + /// The column metadata for the cell to insert + /// String version of the cell value for use in SQL scripts + public static string FormatValue(object value, DbColumn column) + { + Validate.IsNotNull(nameof(column), column); + + // Handle nulls firstly + if (value == null) + { + return NullString; + } + + // Determine how to format based on the column type + string dataType = column.DataTypeName.ToLowerInvariant(); + if (!FormatFunctions.ContainsKey(dataType)) + { + // Attempt to handle UDTs + + // @TODO: to constants file + throw new ArgumentOutOfRangeException(nameof(column.DataTypeName), "A converter for {column type} is not available"); + } + return FormatFunctions[dataType](value, column); + } + + /// + /// Converts a cell value into a string for SQL script + /// + /// The cell to convert + /// The column metadata for the cell to insert + /// String version of the cell value for use in SQL scripts + public static string FormatValue(DbCellValue value, DbColumn column) + { + Validate.IsNotNull(nameof(value), value); + + return FormatValue(value.RawObject, column); + } + + /// + /// Escapes an identifier such as a table name or column name by wrapping it in square brackets + /// + /// The identifier to format + /// Identifier formatted for use in a SQL script + public static string FormatIdentifier(string identifier) + { + return $"[{EscapeString(identifier, ']')}]"; + } + + /// + /// Escapes a multi-part identifier such as a table name or column name with multiple + /// parts split by '.' + /// + /// The identifier to escape + /// The escaped identifier + public static string FormatMultipartIdentifier(string identifier) + { + // If the object is a multi-part identifier (eg, dbo.tablename) split it, and escape as necessary + return FormatMultipartIdentifier(identifier.Split('.')); + } + + /// + /// Escapes a multipart identifier such as a table name, given an array of the parts of the + /// multipart identifier. + /// + /// The parts of the identifier to escape + /// An escaped version of the multipart identifier + public static string FormatMultipartIdentifier(string[] identifiers) + { + IEnumerable escapedParts = identifiers.Select(FormatIdentifier); + return string.Join(".", escapedParts); + } + + #region Private Helpers + + private static string SimpleFormatter(object value) + { + return value.ToString(); + } + + private static string SimpleStringFormatter(object value) + { + return EscapeQuotedSqlString(value.ToString()); + } + + private static string FormatMoney(object value, string type) + { + // we have to manually format the string by ToStringing the value first, and then converting + // the potential (European formatted) comma to a period. + string numericString = ((decimal)value).ToString(CultureInfo.InvariantCulture); + return $"CAST({numericString} AS {type})"; + } + + private static string FormatFloat(object value) + { + // The "R" formatting means "Round Trip", which preserves fidelity + return ((float)value).ToString("R"); + } + + private static string FormatDouble(object value) + { + // The "R" formatting means "Round Trip", which preserves fidelity + return ((double)value).ToString("R"); + } + + private static string FormatBool(object value) + { + // Attempt to cast to bool + bool boolValue = (bool)value; + return boolValue ? "1" : "0"; + } + + private static string FormatPreciseNumeric(object value, DbColumn column, string type) + { + // Make sure we have numeric precision and numeric scale + if (!column.NumericPrecision.HasValue || !column.NumericScale.HasValue) + { + // @TODO Move to constants + throw new InvalidOperationException("Decimal column is missing numeric precision or numeric scale"); + } + + // Convert the value to a decimal, then convert that to a string + string numericString = ((decimal)value).ToString(CultureInfo.InvariantCulture); + return string.Format(CultureInfo.InvariantCulture, "CAST({0} AS {1}({2}, {3}))", + numericString, type, column.NumericPrecision.Value, column.NumericScale.Value); + } + + private static string FormatTimeSpan(object value) + { + // "c" provides "HH:mm:ss.FFFFFFF", and time column accepts up to 7 precision + string timeSpanString = ((TimeSpan)value).ToString("c", CultureInfo.InvariantCulture); + return EscapeQuotedSqlString(timeSpanString); + } + + private static string FormatDateTime(object value, string format) + { + string dateTimeString = ((DateTime)value).ToString(format, CultureInfo.InvariantCulture); + return EscapeQuotedSqlString(dateTimeString); + } + + private static string FormatDateTimeOffset(object value) + { + string dateTimeString = ((DateTimeOffset)value).ToString(CultureInfo.InvariantCulture); + return EscapeQuotedSqlString(dateTimeString); + } + + private static string FormatBinary(object value) + { + byte[] bytes = value as byte[]; + if (bytes == null) + { + // Bypass processing if we can't turn this into a byte[] + return "NULL"; + } + + return "0x" + BitConverter.ToString(bytes).Replace("-", string.Empty); + } + + /// + /// Returns a valid SQL string packaged in single quotes with single quotes inside escaped + /// + /// String to be formatted + /// Formatted SQL string + private static string EscapeQuotedSqlString(string rawString) + { + return $"N'{EscapeString(rawString, '\'')}'"; + } + + /// + /// Replaces all instances of with a duplicate of + /// . For example "can't" becomes "can''t" + /// + /// The string to escape + /// The character to escape + /// The escaped string + private static string EscapeString(string value, char escapeCharacter) + { + Validate.IsNotNull(nameof(value), value); + + StringBuilder sb = new StringBuilder(); + foreach (char c in value) + { + sb.Append(c); + if (escapeCharacter == c) + { + sb.Append(c); + } + } + return sb.ToString(); + } + + #endregion + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/CellUpdateTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/CellUpdateTests.cs new file mode 100644 index 00000000..46ebb2b2 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/CellUpdateTests.cs @@ -0,0 +1,208 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.EditData +{ + public class CellUpdateTests + { + [Fact] + public void NullColumnTest() + { + // If: I attempt to create a CellUpdate with a null column + // Then: I should get an exception thrown + Assert.Throws(() => new CellUpdate(null, string.Empty)); + } + + [Fact] + public void NullStringValueTest() + { + // If: I attempt to create a CellUpdate with a null string value + // Then: I should get an exception thrown + Assert.Throws(() => new CellUpdate(new CellUpdateTestDbColumn(null), null)); + } + + [Fact] + public void NullStringTest() + { + // If: I attempt to create a CellUpdate to set it to NULL (with mixed cases) + const string nullString = "NULL"; + DbColumn col = new CellUpdateTestDbColumn(typeof(string)); + CellUpdate cu = new CellUpdate(col, nullString); + + // Then: The value should be a DBNull and the string value should be the same as what + // was given + Assert.IsType(cu.Value); + Assert.Equal(DBNull.Value, cu.Value); + Assert.Equal(nullString, cu.ValueAsString); + Assert.Equal(col, cu.Column); + } + + [Fact] + public void NullTextStringTest() + { + // If: I attempt to create a CellUpdate with the text 'NULL' (with mixed case) + DbColumn col = new CellUpdateTestDbColumn(typeof(string)); + CellUpdate cu = new CellUpdate(col, "'NULL'"); + + // Then: The value should be NULL + Assert.IsType(cu.Value); + Assert.Equal("NULL", cu.Value); + Assert.Equal("'NULL'", cu.ValueAsString); + Assert.Equal(col, cu.Column); + } + + [Theory] + [MemberData(nameof(ByteArrayTestParams))] + public void ByteArrayTest(string strValue, byte[] expectedValue, string expectedString) + { + // If: I attempt to create a CellUpdate for a binary column + DbColumn col = new CellUpdateTestDbColumn(typeof(byte[])); + CellUpdate cu = new CellUpdate(col, strValue); + + // Then: The value should be a binary and should match the expected data + Assert.IsType(cu.Value); + Assert.Equal(expectedValue, cu.Value); + Assert.Equal(expectedString, cu.ValueAsString); + Assert.Equal(col, cu.Column); + } + + public static IEnumerable ByteArrayTestParams + { + get + { + // All zero tests + yield return new object[] {"00000000", new byte[] {0x00}, "0x00"}; // Base10 + yield return new object[] {"0x000000", new byte[] {0x00, 0x00, 0x00}, "0x000000"}; // Base16 + yield return new object[] {"0x000", new byte[] {0x00, 0x00}, "0x0000"}; // Base16, odd + + // Single byte tests + yield return new object[] {"50", new byte[] {0x32}, "0x32"}; // Base10 + yield return new object[] {"050", new byte[] {0x32}, "0x32"}; // Base10, leading zeros + yield return new object[] {"0xF0", new byte[] {0xF0}, "0xF0"}; // Base16 + yield return new object[] {"0x0F", new byte[] {0x0F}, "0x0F"}; // Base16, leading zeros + yield return new object[] {"0xF", new byte[] {0x0F}, "0x0F"}; // Base16, odd + + // Two byte tests + yield return new object[] {"1000", new byte[] {0x03, 0xE8}, "0x03E8"}; // Base10 + yield return new object[] {"01000", new byte[] {0x03, 0xE8}, "0x03E8"}; // Base10, leading zeros + yield return new object[] {"0xF001", new byte[] {0xF0, 0x01}, "0xF001"}; // Base16 + yield return new object[] {"0x0F10", new byte[] {0x0F, 0x10}, "0x0F10"}; // Base16, leading zeros + yield return new object[] {"0xF10", new byte[] {0x0F, 0x10}, "0x0F10"}; // Base16, odd + + // Three byte tests + yield return new object[] {"100000", new byte[] {0x01, 0x86, 0xA0}, "0x0186A0"}; // Base10 + yield return new object[] {"0100000", new byte[] {0x01, 0x86, 0xA0}, "0x0186A0"}; // Base10, leading zeros + yield return new object[] {"0x101010", new byte[] {0x10, 0x10, 0x10}, "0x101010"}; // Base16 + yield return new object[] {"0x010101", new byte[] {0x01, 0x01, 0x01}, "0x010101"}; // Base16, leading zeros + yield return new object[] {"0x10101", new byte[] {0x01, 0x01, 0x01}, "0x010101"}; // Base16, odd + + // Four byte tests + yield return new object[] {"20000000", new byte[] {0x01, 0x31, 0x2D, 0x00}, "0x01312D00"}; // Base10 + yield return new object[] {"020000000", new byte[] {0x01, 0x31, 0x2D, 0x00}, "0x01312D00"}; // Base10, leading zeros + yield return new object[] {"0xF0F00101", new byte[] {0xF0, 0xF0, 0x01, 0x01}, "0xF0F00101"}; // Base16 + yield return new object[] {"0x0F0F1010", new byte[] {0x0F, 0x0F, 0x10, 0x10}, "0x0F0F1010"}; // Base16, leading zeros + yield return new object[] {"0xF0F1010", new byte[] {0x0F, 0x0F, 0x10, 0x10}, "0x0F0F1010"}; // Base16, odd + } + } + + [Fact] + public void ByteArrayInvalidFormatTest() + { + // If: I attempt to create a CellUpdate for a binary column + // Then: It should throw an exception + DbColumn col = new CellUpdateTestDbColumn(typeof(byte[])); + Assert.Throws(() => new CellUpdate(col, "this is totally invalid")); + } + + [Theory] + [MemberData(nameof(BoolTestParams))] + public void BoolTest(string input, bool output, string outputString) + { + // If: I attempt to create a CellUpdate for a boolean column + DbColumn col = new CellUpdateTestDbColumn(typeof(bool)); + CellUpdate cu = new CellUpdate(col, input); + + // Then: The value should match what was expected + Assert.IsType(cu.Value); + Assert.Equal(output, cu.Value); + Assert.Equal(outputString, cu.ValueAsString); + Assert.Equal(col, cu.Column); + } + + public static IEnumerable BoolTestParams + { + get + { + yield return new object[] {"1", true, bool.TrueString}; + yield return new object[] {"0", false, bool.FalseString}; + yield return new object[] {bool.TrueString, true, bool.TrueString}; + yield return new object[] {bool.FalseString, false, bool.FalseString}; + } + } + + [Fact] + public void BoolInvalidFormatTest() + { + // If: I create a CellUpdate for a bool column and provide an invalid numeric value + // Then: It should throw an exception + DbColumn col = new CellUpdateTestDbColumn(typeof(bool)); + Assert.Throws(() => new CellUpdate(col, "12345")); + } + + [Theory] + [MemberData(nameof(RoundTripTestParams))] + public void RoundTripTest(Type dbColType, object obj) + { + // Setup: Figure out the test string + string testString = obj.ToString(); + + // If: I attempt to create a CellUpdate for a GUID column + DbColumn col = new CellUpdateTestDbColumn(dbColType); + CellUpdate cu = new CellUpdate(col, testString); + + // Then: The value and type should match what we put in + Assert.IsType(dbColType, cu.Value); + Assert.Equal(obj, cu.Value); + Assert.Equal(testString, cu.ValueAsString); + Assert.Equal(col, cu.Column); + } + + public static IEnumerable RoundTripTestParams + { + get + { + yield return new object[] {typeof(Guid), Guid.NewGuid()}; + yield return new object[] {typeof(TimeSpan), new TimeSpan(0, 1, 20, 0, 123)}; + yield return new object[] {typeof(DateTime), new DateTime(2016, 04, 25, 9, 45, 0)}; + yield return new object[] + { + typeof(DateTimeOffset), + new DateTimeOffset(2016, 04, 25, 9, 45, 0, TimeSpan.FromHours(8)) + }; + yield return new object[] {typeof(long), 1000L}; + yield return new object[] {typeof(decimal), new decimal(3.14)}; + yield return new object[] {typeof(int), 1000}; + yield return new object[] {typeof(short), (short) 1000}; + yield return new object[] {typeof(byte), (byte) 5}; + yield return new object[] {typeof(double), 3.14d}; + yield return new object[] {typeof(float), 3.14f}; + } + } + + private class CellUpdateTestDbColumn : DbColumn + { + public CellUpdateTestDbColumn(Type dataType) + { + DataType = dataType; + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/Common.cs new file mode 100644 index 00000000..3340e7ae --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/Common.cs @@ -0,0 +1,93 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Data.Common; +using System.Linq; +using System.Threading; +using Microsoft.SqlTools.ServiceLayer.EditData; +using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Moq; + +namespace Microsoft.SqlTools.ServiceLayer.Test.EditData +{ + public class Common + { + public const string OwnerUri = "testFile"; + + public static IEditTableMetadata GetMetadata(DbColumn[] columns, bool allKeys = true, bool isMemoryOptimized = false) + { + // Create a Column Metadata Provider + var columnMetas = columns.Select((c, i) => + new EditColumnWrapper + { + DbColumn = new DbColumnWrapper(c), + EscapedName = c.ColumnName, + Ordinal = i, + IsKey = c.IsIdentity.HasTrue() + }).ToArray(); + + // Create a table metadata provider + var tableMetaMock = new Mock(); + if (allKeys) + { + // All columns should be returned as "keys" + tableMetaMock.Setup(m => m.KeyColumns).Returns(columnMetas); + } + else + { + // All identity columns should be returned as keys + tableMetaMock.Setup(m => m.KeyColumns).Returns(columnMetas.Where(c => c.DbColumn.IsIdentity.HasTrue())); + } + tableMetaMock.Setup(m => m.Columns).Returns(columnMetas); + tableMetaMock.Setup(m => m.IsMemoryOptimized).Returns(isMemoryOptimized); + tableMetaMock.Setup(m => m.EscapedMultipartName).Returns("tbl"); + + return tableMetaMock.Object; + } + + public static DbColumn[] GetColumns(bool includeIdentity) + { + List columns = new List(); + + if (includeIdentity) + { + columns.Add(new TestDbColumn("id", true)); + } + + for (int i = 0; i < 3; i++) + { + columns.Add(new TestDbColumn($"col{i}")); + } + return columns.ToArray(); + } + + public static ResultSet GetResultSet(DbColumn[] columns, bool includeIdentity) + { + object[][] rows = includeIdentity + ? new[] { new object[] { "id", "1", "2", "3" } } + : new[] { new object[] { "1", "2", "3" } }; + var testResultSet = new TestResultSet(columns, rows); + var reader = new TestDbDataReader(new[] { testResultSet }); + var resultSet = new ResultSet(reader, 0, 0, QueryExecution.Common.GetFileStreamFactory(new Dictionary())); + resultSet.ReadResultToEnd(CancellationToken.None).Wait(); + return resultSet; + } + + public static void AddCells(RowEditBase rc, bool includeIdentity) + { + // Skip the first column since if identity, since identity columns can't be updated + int start = includeIdentity ? 1 : 0; + for (int i = start; i < rc.AssociatedResultSet.Columns.Length; i++) + { + rc.SetCell(i, "123"); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowCreateTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowCreateTests.cs new file mode 100644 index 00000000..611b4787 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowCreateTests.cs @@ -0,0 +1,84 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.Common; +using System.Text.RegularExpressions; +using Microsoft.SqlTools.ServiceLayer.EditData; +using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.EditData +{ + public class RowCreateTests + { + [Fact] + public void RowCreateConstruction() + { + // Setup: Create the values to store + const long rowId = 100; + ResultSet rs = QueryExecution.Common.GetBasicExecutedBatch().ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + + // If: I create a RowCreate instance + RowCreate rc = new RowCreate(rowId, rs, etm); + + // Then: The values I provided should be available + Assert.Equal(rowId, rc.RowId); + Assert.Equal(rs, rc.AssociatedResultSet); + Assert.Equal(etm, rc.AssociatedObjectMetadata); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void GetScript(bool includeIdentity) + { + // Setup: Generate the parameters for the row create + const long rowId = 100; + DbColumn[] columns = Common.GetColumns(includeIdentity); + ResultSet rs = Common.GetResultSet(columns, includeIdentity); + IEditTableMetadata etm = Common.GetMetadata(columns); + + // If: I ask for a script to be generated without an identity column + RowCreate rc = new RowCreate(rowId, rs, etm); + Common.AddCells(rc, includeIdentity); + string script = rc.GetScript(); + + // Then: + // ... The script should not be null, + Assert.NotNull(script); + + // ... It should be formatted as an insert script + Regex r = new Regex(@"INSERT INTO (.+)\((.*)\) VALUES \((.*)\)"); + var m = r.Match(script); + Assert.True(m.Success); + + // ... It should have 3 columns and 3 values (regardless of the presence of an identity col) + string tbl = m.Groups[1].Value; + string cols = m.Groups[2].Value; + string vals = m.Groups[3].Value; + Assert.Equal(etm.EscapedMultipartName, tbl); + Assert.Equal(3, cols.Split(',').Length); + Assert.Equal(3, vals.Split(',').Length); + } + + [Fact] + public void GetScriptMissingCell() + { + // Setup: Generate the parameters for the row create + const long rowId = 100; + DbColumn[] columns = Common.GetColumns(false); + ResultSet rs = Common.GetResultSet(columns, false); + IEditTableMetadata etm = Common.GetMetadata(columns); + + // If: I ask for a script to be generated without setting any values + // Then: An exception should be thrown for missing cells + RowCreate rc = new RowCreate(rowId, rs, etm); + Assert.Throws(() => rc.GetScript()); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowDeleteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowDeleteTests.cs new file mode 100644 index 00000000..48de0499 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowDeleteTests.cs @@ -0,0 +1,73 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.Common; +using Microsoft.SqlTools.ServiceLayer.EditData; +using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.EditData +{ + public class RowDeleteTests + { + [Fact] + public void RowDeleteConstruction() + { + // Setup: Create the values to store + const long rowId = 100; + ResultSet rs = QueryExecution.Common.GetBasicExecutedBatch().ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + + // If: I create a RowCreate instance + RowCreate rc = new RowCreate(rowId, rs, etm); + + // Then: The values I provided should be available + Assert.Equal(rowId, rc.RowId); + Assert.Equal(rs, rc.AssociatedResultSet); + Assert.Equal(etm, rc.AssociatedObjectMetadata); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void GetScriptTest(bool isHekaton) + { + DbColumn[] columns = Common.GetColumns(true); + ResultSet rs = Common.GetResultSet(columns, true); + IEditTableMetadata etm = Common.GetMetadata(columns, false, isHekaton); + + // If: I ask for a script to be generated for delete + RowDelete rd = new RowDelete(0, rs, etm); + string script = rd.GetScript(); + + // Then: + // ... The script should not be null + Assert.NotNull(script); + + // ... It should be formatted as a delete script + string scriptStart = $"DELETE FROM {etm.EscapedMultipartName}"; + if (isHekaton) + { + scriptStart += " WITH(SNAPSHOT)"; + } + Assert.StartsWith(scriptStart, script); + } + + [Fact] + public void SetCell() + { + DbColumn[] columns = Common.GetColumns(true); + ResultSet rs = Common.GetResultSet(columns, true); + IEditTableMetadata etm = Common.GetMetadata(columns, false); + + // If: I set a cell on a delete row edit + // Then: It should throw as invalid operation + RowDelete rd = new RowDelete(0, rs, etm); + Assert.Throws(() => rd.SetCell(0, null)); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowEditBaseTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowEditBaseTests.cs new file mode 100644 index 00000000..6b8db71e --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowEditBaseTests.cs @@ -0,0 +1,183 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Linq; +using System.Text.RegularExpressions; +using System.Threading; +using Microsoft.SqlTools.ServiceLayer.EditData; +using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; +using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.EditData +{ + public class RowEditBaseTests + { + [Theory] + [InlineData(-1)] // Negative index + [InlineData(100)] // Index larger than number of columns + public void ValidateUpdatableColumnOutOfRange(int columnId) + { + // Setup: Create a result set + ResultSet rs = GetResultSet( + new DbColumn[] { new TestDbColumn("id", true), new TestDbColumn("col1")}, + new object[] { "id", "1" }); + + // If: I validate a column ID that is out of range + // Then: It should throw + RowEditTester tester = new RowEditTester(rs, null); + Assert.Throws(() => tester.ValidateColumn(columnId)); + } + + [Fact] + public void ValidateUpdatableColumnNotUpdatable() + { + // Setup: Create a result set with an identity column + ResultSet rs = GetResultSet( + new DbColumn[] { new TestDbColumn("id", true), new TestDbColumn("col1") }, + new object[] { "id", "1" }); + + // If: I validate a column ID that is not updatable + // Then: It should throw + RowEditTester tester = new RowEditTester(rs, null); + Assert.Throws(() => tester.ValidateColumn(0)); + } + + [Theory] + [MemberData(nameof(GetWhereClauseIsNotNullData))] + public void GetWhereClauseSimple(DbColumn col, object val, string nullClause) + { + // Setup: Create a result set and metadata provider with a single column + var cols = new[] {col}; + ResultSet rs = GetResultSet(cols, new[] {val}); + IEditTableMetadata etm = Common.GetMetadata(cols); + + RowEditTester rt = new RowEditTester(rs, etm); + rt.ValidateWhereClauseSingleKey(nullClause); + } + + public static IEnumerable GetWhereClauseIsNotNullData + { + get + { + yield return new object[] {new TestDbColumn("col"), DBNull.Value, "IS NULL"}; + yield return new object[] {new TestDbColumn("col", "VARBINARY", typeof(byte[])), new byte[5], "IS NOT NULL"}; + yield return new object[] {new TestDbColumn("col", "TEXT", typeof(string)), "abc", "IS NOT NULL"}; + yield return new object[] {new TestDbColumn("col", "NTEXT", typeof(string)), "abc", "IS NOT NULL"}; + } + } + + [Fact] + public void GetWhereClauseMultipleKeyColumns() + { + // Setup: Create a result set and metadata provider with multiple key columns + DbColumn[] cols = {new TestDbColumn("col1"), new TestDbColumn("col2")}; + ResultSet rs = GetResultSet(cols, new object[] {"abc", "def"}); + IEditTableMetadata etm = Common.GetMetadata(cols); + + RowEditTester rt = new RowEditTester(rs, etm); + rt.ValidateWhereClauseMultipleKeys(); + } + + [Fact] + public void GetWhereClauseNoKeyColumns() + { + // Setup: Create a result set and metadata provider with no key columns + DbColumn[] cols = {new TestDbColumn("col1"), new TestDbColumn("col2")}; + ResultSet rs = GetResultSet(cols, new object[] {"abc", "def"}); + IEditTableMetadata etm = Common.GetMetadata(new DbColumn[] {}); + + RowEditTester rt = new RowEditTester(rs, etm); + rt.ValidateWhereClauseNoKeys(); + } + + private static ResultSet GetResultSet(DbColumn[] columns, object[] row) + { + object[][] rows = {row}; + var testResultSet = new TestResultSet(columns, rows); + var testReader = new TestDbDataReader(new [] {testResultSet}); + var resultSet = new ResultSet(testReader, 0,0, QueryExecution.Common.GetFileStreamFactory(new Dictionary())); + resultSet.ReadResultToEnd(CancellationToken.None).Wait(); + return resultSet; + } + + private class RowEditTester : RowEditBase + { + public RowEditTester(ResultSet rs, IEditTableMetadata meta) : base(0, rs, meta) { } + + public void ValidateColumn(int columnId) + { + ValidateColumnIsUpdatable(columnId); + } + + // ReSharper disable once UnusedParameter.Local + public void ValidateWhereClauseSingleKey(string nullValue) + { + // If: I generate a where clause with one is null column value + WhereClause wc = GetWhereClause(false); + + // Then: + // ... There should only be one component + Assert.Equal(1, wc.ClauseComponents.Count); + + // ... Parameterization should be empty + Assert.Empty(wc.Parameters); + + // ... The component should contain the name of the column and be null + Assert.Equal( + $"({AssociatedObjectMetadata.Columns.First().EscapedName} {nullValue})", + wc.ClauseComponents[0]); + + // ... The complete clause should contain a single WHERE + Assert.Equal($"WHERE {wc.ClauseComponents[0]}", wc.CommandText); + } + + public void ValidateWhereClauseMultipleKeys() + { + // If: I generate a where clause with multiple key columns + WhereClause wc = GetWhereClause(false); + + // Then: + // ... There should two components + var keys = AssociatedObjectMetadata.KeyColumns.ToArray(); + Assert.Equal(keys.Length, wc.ClauseComponents.Count); + + // ... Parameterization should be empty + Assert.Empty(wc.Parameters); + + // ... The components should contain the name of the column and the value + Regex r = new Regex(@"\([0-9a-z]+ = .+\)"); + Assert.All(wc.ClauseComponents, s => Assert.True(r.IsMatch(s))); + + // ... The complete clause should contain multiple cause components joined + // with and + Assert.True(wc.CommandText.StartsWith("WHERE ")); + Assert.True(wc.CommandText.EndsWith(string.Join(" AND ", wc.ClauseComponents))); + } + + public void ValidateWhereClauseNoKeys() + { + // If: I generate a where clause from metadata that doesn't have keys + // Then: An exception should be thrown + Assert.Throws(() => GetWhereClause(false)); + } + + public override string GetScript() + { + throw new NotImplementedException(); + } + + public override EditUpdateCellResult SetCell(int columnId, string newValue) + { + throw new NotImplementedException(); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowUpdateTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowUpdateTests.cs new file mode 100644 index 00000000..821ba512 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/RowUpdateTests.cs @@ -0,0 +1,105 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Data.Common; +using System.Text.RegularExpressions; +using Microsoft.SqlTools.ServiceLayer.EditData; +using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.EditData +{ + public class RowUpdateTests + { + [Fact] + public void RowUpdateConstruction() + { + // Setup: Create the values to store + const long rowId = 0; + ResultSet rs = QueryExecution.Common.GetBasicExecutedBatch().ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + + // If: I create a RowUpdate instance + RowUpdate rc = new RowUpdate(rowId, rs, etm); + + // Then: The values I provided should be available + Assert.Equal(rowId, rc.RowId); + Assert.Equal(rs, rc.AssociatedResultSet); + Assert.Equal(etm, rc.AssociatedObjectMetadata); + } + + [Fact] + public void ImplicitRevertTest() + { + // Setup: Create a fake table to update + DbColumn[] columns = Common.GetColumns(true); + ResultSet rs = Common.GetResultSet(columns, true); + IEditTableMetadata etm = Common.GetMetadata(columns); + + // If: + // ... I add updates to all the cells in the row + RowUpdate ru = new RowUpdate(0, rs, etm); + Common.AddCells(ru, true); + + // ... Then I update a cell back to it's old value + var output = ru.SetCell(1, (string) rs.GetRow(0)[1].RawObject); + + // Then: + // ... The output should indicate a revert + Assert.NotNull(output); + Assert.True(output.IsRevert); + Assert.False(output.HasCorrections); + Assert.False(output.IsNull); + Assert.Equal(rs.GetRow(0)[1].DisplayValue, output.NewValue); + + // ... It should be formatted as an update script + Regex r = new Regex(@"UPDATE .+ SET (.*) WHERE"); + var m = r.Match(ru.GetScript()); + + // ... It should have 2 updates + string updates = m.Groups[1].Value; + string[] updateSplit = updates.Split(','); + Assert.Equal(2, updateSplit.Length); + Assert.All(updateSplit, s => Assert.Equal(2, s.Split('=').Length)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void GetScriptTest(bool isHekaton) + { + // Setup: Create a fake table to update + DbColumn[] columns = Common.GetColumns(true); + ResultSet rs = Common.GetResultSet(columns, true); + IEditTableMetadata etm = Common.GetMetadata(columns, false, isHekaton); + + // If: I ask for a script to be generated for update + RowUpdate ru = new RowUpdate(0, rs, etm); + Common.AddCells(ru, true); + string script = ru.GetScript(); + + // Then: + // ... The script should not be null + Assert.NotNull(script); + + // ... It should be formatted as an update script + string regexString = isHekaton + ? @"UPDATE (.+) WITH \(SNAPSHOT\) SET (.*) WHERE .+" + : @"UPDATE (.+) SET (.*) WHERE .+"; + Regex r = new Regex(regexString); + var m = r.Match(script); + Assert.True(m.Success); + + // ... It should have 3 updates + string tbl = m.Groups[1].Value; + string updates = m.Groups[2].Value; + string[] updateSplit = updates.Split(','); + Assert.Equal(etm.EscapedMultipartName, tbl); + Assert.Equal(3, updateSplit.Length); + Assert.All(updateSplit, s => Assert.Equal(2, s.Split('=').Length)); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/ServiceIntegrationTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/ServiceIntegrationTests.cs new file mode 100644 index 00000000..58e5446f --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/ServiceIntegrationTests.cs @@ -0,0 +1,239 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.EditData; +using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; +using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.EditData +{ + public class ServiceIntegrationTests + { + #region Session Operation Helper Tests + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" \t\n\r")] + [InlineData("Does not exist")] + public async Task NullOrMissingSessionId(string sessionId) + { + // Setup: + // ... Create a edit data service + var eds = new EditDataService(null, null, null); + + // ... Create a session params that returns the provided session ID + var mockParams = new EditCreateRowParams {OwnerUri = sessionId}; + + // If: I ask to perform an action that requires a session + // Then: I should get an error from it + var efv = new EventFlowValidator() + .AddStandardErrorValidation() + .Complete(); + await eds.HandleSessionRequest(mockParams, efv.Object, session => null); + efv.Validate(); + } + + [Fact] + public async Task OperationThrows() + { + // Setup: + // ... Create an edit data service with a session + var eds = new EditDataService(null, null, null); + eds.ActiveSessions[Common.OwnerUri] = GetDefaultSession(); + + // ... Create a session param that returns the common owner uri + var mockParams = new EditCreateRowParams { OwnerUri = Common.OwnerUri }; + + // If: I ask to perform an action that requires a session + // Then: I should get an error from it + var efv = new EventFlowValidator() + .AddStandardErrorValidation() + .Complete(); + await eds.HandleSessionRequest(mockParams, efv.Object, s => { throw new Exception(); }); + efv.Validate(); + } + + #endregion + + #region Dispose Tests + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" \t\n\r")] + [InlineData("Does not exist")] + public async Task DisposeNullOrMissingSessionId(string sessionId) + { + // Setup: Create a edit data service + var eds = new EditDataService(null, null, null); + + // If: I ask to perform an action that requires a session + // Then: I should get an error from it + var efv = new EventFlowValidator() + .AddStandardErrorValidation() + .Complete(); + await eds.HandleDisposeRequest(new EditDisposeParams {OwnerUri = sessionId}, efv.Object); + efv.Validate(); + } + + [Fact] + public async Task DisposeSuccess() + { + // Setup: Create an edit data service with a session + var eds = new EditDataService(null, null, null); + eds.ActiveSessions[Common.OwnerUri] = GetDefaultSession(); + + // If: I ask to dispose of an existing session + var efv = new EventFlowValidator() + .AddResultValidation(Assert.NotNull) + .Complete(); + await eds.HandleDisposeRequest(new EditDisposeParams {OwnerUri = Common.OwnerUri}, efv.Object); + + // Then: + // ... It should have completed successfully + efv.Validate(); + + // ... And the session should have been removed from the active session list + Assert.Empty(eds.ActiveSessions); + } + + #endregion + + [Fact] + public async Task DeleteSuccess() + { + // Setup: Create an edit data service with a session + var eds = new EditDataService(null, null, null); + eds.ActiveSessions[Common.OwnerUri] = GetDefaultSession(); + + // If: I validly ask to delete a row + var efv = new EventFlowValidator() + .AddResultValidation(Assert.NotNull) + .Complete(); + await eds.HandleDeleteRowRequest(new EditDeleteRowParams {OwnerUri = Common.OwnerUri, RowId = 0}, efv.Object); + + // Then: + // ... It should be successful + efv.Validate(); + + // ... There should be a delete in the session + Session s = eds.ActiveSessions[Common.OwnerUri]; + Assert.True(s.EditCache.Any(e => e.Value is RowDelete)); + } + + [Fact] + public async Task CreateSucceeds() + { + // Setup: Create an edit data service with a session + var eds = new EditDataService(null, null, null); + eds.ActiveSessions[Common.OwnerUri] = GetDefaultSession(); + + // If: I ask to create a row from a non existant session + var efv = new EventFlowValidator() + .AddResultValidation(ecrr => { Assert.True(ecrr.NewRowId > 0); }) + .Complete(); + await eds.HandleCreateRowRequest(new EditCreateRowParams { OwnerUri = Common.OwnerUri }, efv.Object); + + // Then: + // ... It should have been successful + efv.Validate(); + + // ... There should be a create in the session + Session s = eds.ActiveSessions[Common.OwnerUri]; + Assert.True(s.EditCache.Any(e => e.Value is RowCreate)); + } + + [Fact] + public async Task RevertSucceeds() + { + // Setup: Create an edit data service with a session that has an pending edit + var eds = new EditDataService(null, null, null); + var session = GetDefaultSession(); + session.EditCache[0] = new Mock().Object; + eds.ActiveSessions[Common.OwnerUri] = session; + + // If: I ask to revert a row that has a pending edit + var efv = new EventFlowValidator() + .AddResultValidation(Assert.NotNull) + .Complete(); + await eds.HandleRevertRowRequest(new EditRevertRowParams { OwnerUri = Common.OwnerUri, RowId = 0}, efv.Object); + + // Then: + // ... It should have succeeded + efv.Validate(); + + // ... The edit cache should be empty again + Session s = eds.ActiveSessions[Common.OwnerUri]; + Assert.Empty(s.EditCache); + } + + [Fact] + public async Task UpdateSuccess() + { + // Setup: Create an edit data service with a session + var eds = new EditDataService(null, null, null); + var session = GetDefaultSession(); + eds.ActiveSessions[Common.OwnerUri] = session; + var edit = new Mock(); + edit.Setup(e => e.SetCell(It.IsAny(), It.IsAny())).Returns(new EditUpdateCellResult + { + NewValue = string.Empty, + HasCorrections = true, + IsRevert = false, + IsNull = false + }); + session.EditCache[0] = edit.Object; + + // If: I validly ask to update a cell + var efv = new EventFlowValidator() + .AddResultValidation(eucr => + { + Assert.NotNull(eucr); + Assert.NotNull(eucr.NewValue); + Assert.False(eucr.IsRevert); + Assert.False(eucr.IsNull); + }) + .Complete(); + await eds.HandleUpdateCellRequest(new EditUpdateCellParams { OwnerUri = Common.OwnerUri, RowId = 0}, efv.Object); + + // Then: + // ... It should be successful + efv.Validate(); + + // ... Set cell should have been called once + edit.Verify(e => e.SetCell(It.IsAny(), It.IsAny()), Times.Once); + } + + private static Session GetDefaultSession() + { + // ... Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + return s; + } + } + + public static class EditServiceEventFlowValidatorExtensions + { + public static EventFlowValidator AddStandardErrorValidation(this EventFlowValidator evf) + { + return evf.AddErrorValidation(p => + { + Assert.NotNull(p); + Assert.NotEmpty(p); + }); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/SessionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/SessionTests.cs new file mode 100644 index 00000000..f6446a98 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/EditData/SessionTests.cs @@ -0,0 +1,381 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.IO; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.EditData; +using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.Common; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.EditData +{ + public class SessionTests + { + #region Construction Tests + + [Fact] + public void SessionConstructionNullQuery() + { + // If: I create a session object without a null query + // Then: It should throw an exception + Assert.Throws(() => new Session(null, Common.GetMetadata(new DbColumn[] {}))); + } + + [Fact] + public void SessionConstructionNullMetadataProvider() + { + // If: I create a session object without a null metadata provider + // Then: It should throw an exception + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + Assert.Throws(() => new Session(rs, null)); + } + + [Fact] + public void SessionConstructionValid() + { + // If: I create a session object with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // Then: + // ... The edit cache should exist and be empty + Assert.NotNull(s.EditCache); + Assert.Empty(s.EditCache); + + // ... The next row ID should be equivalent to the number of rows in the result set + Assert.Equal(q.Batches[0].ResultSets[0].RowCount, s.NextRowId); + } + + #endregion + + #region Validate Tests + + [Fact] + public void SessionValidateUnfinishedQuery() + { + // If: I create a session object with a query that hasn't finished execution + // Then: It should throw an exception + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + q.HasExecuted = false; + Assert.Throws(() => Session.ValidateQueryForSession(q)); + } + + [Fact] + public void SessionValidateIncorrectResultSet() + { + // Setup: Create a query that yields >1 result sets + TestResultSet[] results = + { + QueryExecution.Common.StandardTestResultSet, + QueryExecution.Common.StandardTestResultSet + }; + + // @TODO: Fix when the connection service is fixed + ConnectionInfo ci = QueryExecution.Common.CreateConnectedConnectionInfo(results, false); + ConnectionService.Instance.OwnerToConnectionMap[ci.OwnerUri] = ci; + + var fsf = QueryExecution.Common.GetFileStreamFactory(new Dictionary()); + Query query = new Query(QueryExecution.Common.StandardQuery, ci, new QueryExecutionSettings(), fsf); + query.Execute(); + query.ExecutionTask.Wait(); + + // If: I create a session object with a query that has !=1 result sets + // Then: It should throw an exception + Assert.Throws(() => Session.ValidateQueryForSession(query)); + } + + [Fact] + public void SessionValidateValidResultSet() + { + // If: I validate a query for a session with a valid query + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = Session.ValidateQueryForSession(q); + + // Then: I should get the only result set back + Assert.NotNull(rs); + } + + #endregion + + #region Create Row Tests + + [Fact] + public void CreateRowAddFailure() + { + // NOTE: This scenario should theoretically never occur, but is tested for completeness + // Setup: + // ... Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // ... Add a mock edit to the edit cache to cause the .TryAdd to fail + var mockEdit = new Mock().Object; + s.EditCache[rs.RowCount] = mockEdit; + + // If: I create a row in the session + // Then: + // ... An exception should be thrown + Assert.Throws(() => s.CreateRow()); + + // ... The mock edit should still exist + Assert.Equal(mockEdit, s.EditCache[rs.RowCount]); + + // ... The next row ID should not have changes + Assert.Equal(rs.RowCount, s.NextRowId); + } + + [Fact] + public void CreateRowSuccess() + { + // Setup: Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // If: I add a row to the session + long newId = s.CreateRow(); + + // Then: + // ... The new ID should be equal to the row count + Assert.Equal(rs.RowCount, newId); + + // ... The next row ID should have been incremented + Assert.Equal(rs.RowCount + 1, s.NextRowId); + + // ... There should be a new row create object in the cache + Assert.Contains(newId, s.EditCache.Keys); + Assert.IsType(s.EditCache[newId]); + } + + #endregion + + [Theory] + [MemberData(nameof(RowIdOutOfRangeData))] + public void RowIdOutOfRange(long rowId, Action testAction) + { + // Setup: Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // If: I delete a row that is out of range for the result set + // Then: I should get an exception + Assert.Throws(() => testAction(s, rowId)); + } + + public static IEnumerable RowIdOutOfRangeData + { + get + { + // Delete Row + Action delAction = (s, l) => s.DeleteRow(l); + yield return new object[] { -1L, delAction }; + yield return new object[] { 100L, delAction }; + + // Update Cell + Action upAction = (s, l) => s.UpdateCell(l, 0, null); + yield return new object[] { -1L, upAction }; + yield return new object[] { 100L, upAction }; + } + } + + #region Delete Row Tests + + [Fact] + public void DeleteRowAddFailure() + { + // Setup: + // ... Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // ... Add a mock edit to the edit cache to cause the .TryAdd to fail + var mockEdit = new Mock().Object; + s.EditCache[0] = mockEdit; + + // If: I delete a row in the session + // Then: + // ... An exception should be thrown + Assert.Throws(() => s.DeleteRow(0)); + + // ... The mock edit should still exist + Assert.Equal(mockEdit, s.EditCache[0]); + } + + [Fact] + public void DeleteRowSuccess() + { + // Setup: Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // If: I add a row to the session + s.DeleteRow(0); + + // Then: There should be a new row delete object in the cache + Assert.Contains(0, s.EditCache.Keys); + Assert.IsType(s.EditCache[0]); + } + + #endregion + + #region Revert Row Tests + + [Fact] + public void RevertRowOutOfRange() + { + // Setup: Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // If: I revert a row that doesn't have any pending changes + // Then: I should get an exception + Assert.Throws(() => s.RevertRow(0)); + } + + [Fact] + public void RevertRowSuccess() + { + // Setup: + // ... Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // ... Add a mock edit to the edit cache to cause the .TryAdd to fail + var mockEdit = new Mock().Object; + s.EditCache[0] = mockEdit; + + // If: I revert the row that has a pending update + s.RevertRow(0); + + // Then: + // ... The edit cache should not contain a pending edit for the row + Assert.DoesNotContain(0, s.EditCache.Keys); + } + + #endregion + + #region Update Cell Tests + + [Fact] + public void UpdateCellExisting() + { + // Setup: + // ... Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // ... Add a mock edit to the edit cache to cause the .TryAdd to fail + var mockEdit = new Mock(); + mockEdit.Setup(e => e.SetCell(It.IsAny(), It.IsAny())); + s.EditCache[0] = mockEdit.Object; + + // If: I update a cell on a row that already has a pending edit + s.UpdateCell(0, 0, null); + + // Then: + // ... The mock update should still be in the cache + // ... And it should have had set cell called on it + Assert.Contains(mockEdit.Object, s.EditCache.Values); + } + + [Fact] + public void UpdateCellNew() + { + // Setup: + // ... Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // If: I update a cell on a row that does not have a pending edit + s.UpdateCell(0, 0, ""); + + // Then: + // ... A new update row edit should have been added to the cache + Assert.Contains(0, s.EditCache.Keys); + Assert.IsType(s.EditCache[0]); + } + + #endregion + + #region Script Edits Tests + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" \t\r\n")] + public void ScriptNullOrEmptyOutput(string outputPath) + { + // Setup: Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // If: I try to script the edit cache with a null or whitespace output path + // Then: It should throw an exception + Assert.Throws(() => s.ScriptEdits(outputPath)); + } + + [Fact] + public void ScriptProvidedOutputPath() + { + // Setup: + // ... Create a session with a proper query and metadata + Query q = QueryExecution.Common.GetBasicExecutedQuery(); + ResultSet rs = q.Batches[0].ResultSets[0]; + IEditTableMetadata etm = Common.GetMetadata(rs.Columns); + Session s = new Session(rs, etm); + + // ... Add two mock edits that will generate a script + Mock edit = new Mock(); + edit.Setup(e => e.GetScript()).Returns("test"); + s.EditCache[0] = edit.Object; + s.EditCache[1] = edit.Object; + + using (SelfCleaningTempFile file = new SelfCleaningTempFile()) + { + // If: I script the edit cache to a local output path + string outputPath = s.ScriptEdits(file.FilePath); + + // Then: + // ... The output path used should be the same as the one we provided + Assert.Equal(file.FilePath, outputPath); + + // ... The written file should have two lines, one for each edit + Assert.Equal(2, File.ReadAllLines(outputPath).Length); + } + } + + #endregion + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index 0cd1b691..fef43fec 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -92,7 +92,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public static Query GetBasicExecutedQuery() { - ConnectionInfo ci = CreateTestConnectionInfo(StandardTestDataSet, false); + ConnectionInfo ci = CreateConnectedConnectionInfo(StandardTestDataSet, false); // Query won't be able to request a new query DbConnection unless the ConnectionService has a // ConnectionInfo with the same URI as the query, so we will manually set it @@ -106,7 +106,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public static Query GetBasicExecutedQuery(QueryExecutionSettings querySettings) { - ConnectionInfo ci = CreateTestConnectionInfo(StandardTestDataSet, false); + ConnectionInfo ci = CreateConnectedConnectionInfo(StandardTestDataSet, false); // Query won't be able to request a new query DbConnection unless the ConnectionService has a // ConnectionInfo with the same URI as the query, so we will manually set it diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/DbColumnWrapperTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/DbColumnWrapperTests.cs index ab165105..d85551df 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/DbColumnWrapperTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/DbColumnWrapperTests.cs @@ -17,7 +17,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution /// /// Test DbColumn derived class /// - class TestColumn : DbColumn + private class TestColumn : DbColumn { public TestColumn( string dataTypeName = null, @@ -53,15 +53,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution } /// - /// Basic data type and properites test + /// Basic data type and properties test /// [Fact] public void DataTypeAndPropertiesTest() - { - // check that data types array contains items - var serverDataTypes = DbColumnWrapper.AllServerDataTypes; - Assert.True(serverDataTypes.Count > 0); - + { // check default constructor doesn't throw Assert.NotNull(new DbColumnWrapper()); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs index 5610513a..1ff5f604 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs @@ -32,7 +32,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution var queryService = new QueryExecutionService(null, workspaceService); // If: I attempt to get query text from execute document params (entire document) - var queryParams = new ExecuteDocumentSelectionParams {OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; var queryText = queryService.GetSqlText(queryParams); // Then: The text should match the constructed query @@ -49,7 +49,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution var queryService = new QueryExecutionService(null, workspaceService); // If: I attempt to get query text from execute document params (partial document) - var queryParams = new ExecuteDocumentSelectionParams {OwnerUri = Common.OwnerUri, QuerySelection = Common.SubsectionDocument}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.SubsectionDocument }; var queryText = queryService.GetSqlText(queryParams); // Then: The text should be a subset of the constructed query @@ -65,7 +65,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution var queryService = new QueryExecutionService(null, null); // If: I attempt to get query text from execute string params - var queryParams = new ExecuteStringParams {OwnerUri = Common.OwnerUri, Query = Common.StandardQuery}; + var queryParams = new ExecuteStringParams { OwnerUri = Common.OwnerUri, Query = Common.StandardQuery }; var queryText = queryService.GetSqlText(queryParams); // Then: The text should match the standard query @@ -97,14 +97,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // Setup: Create a query service var qes = new QueryExecutionService(null, null); var eventSender = new EventFlowValidator().Complete().Object; - Func successFunc = () => Task.FromResult(0); - Func errorFunc = Task.FromResult; - - // If: I call the inter-service API to execute with a null execute params + + // If: I call the inter-service API to execute with a null execute params // Then: It should throw await Assert.ThrowsAsync( - () => qes.InterServiceExecuteQuery(null, eventSender, successFunc, errorFunc)); + () => qes.InterServiceExecuteQuery(null, eventSender, null, null, null, null)); } [Fact] @@ -113,43 +111,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // Setup: Create a query service, and execute params var qes = new QueryExecutionService(null, null); var executeParams = new ExecuteStringParams(); - Func successFunc = () => Task.FromResult(0); - Func errorFunc = Task.FromResult; // If: I call the inter-service API to execute a query with a a null event sender // Then: It should throw await Assert.ThrowsAsync( - () => qes.InterServiceExecuteQuery(executeParams, null, successFunc, errorFunc)); - } - - [Fact] - public async Task InterServiceExecuteNullSuccessFunc() - { - // Setup: Create a query service, and execute params - var qes = new QueryExecutionService(null, null); - var executeParams = new ExecuteStringParams(); - var eventSender = new EventFlowValidator().Complete().Object; - Func errorFunc = Task.FromResult; - - // If: I call the inter-service API to execute a query with a a null success function - // Then: It should throw - await Assert.ThrowsAsync( - () => qes.InterServiceExecuteQuery(executeParams, eventSender, null, errorFunc)); - } - - [Fact] - public async Task InterServiceExecuteNullFailureFunc() - { - // Setup: Create a query service, and execute params - var qes = new QueryExecutionService(null, null); - var executeParams = new ExecuteStringParams(); - var eventSender = new EventFlowValidator().Complete().Object; - Func successFunc = () => Task.FromResult(0); - - // If: I call the inter-service API to execute a query with a a null failure function - // Then: It should throw - await Assert.ThrowsAsync( - () => qes.InterServiceExecuteQuery(executeParams, eventSender, successFunc, null)); + () => qes.InterServiceExecuteQuery(executeParams, null, null, null, null, null)); } [Fact] @@ -184,6 +150,41 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // NOTE: In order to limit test duplication, we're running the ExecuteDocumentSelection // version of execute query. The code paths are almost identical. + [Fact] + private async Task QueryExecuteAllBatchesNoOp() + { + // If: + // ... I request to execute a valid query with all batches as no op + var workspaceService = GetDefaultWorkspaceService(string.Format("{0}\r\nGO\r\n{0}", Common.NoOpQuery)); + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); + var queryParams = new ExecuteDocumentSelectionParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; + + var efv = new EventFlowValidator() + .AddStandardQueryResultValidator() + .AddStandardBatchStartValidator() + .AddStandardMessageValidator() + .AddStandardBatchCompleteValidator() + .AddStandardBatchCompleteValidator() + .AddStandardMessageValidator() + .AddStandardBatchCompleteValidator() + .AddEventValidation(QueryCompleteEvent.Type, p => + { + // Validate OwnerURI matches + Assert.Equal(Common.OwnerUri, p.OwnerUri); + Assert.NotNull(p.BatchSummaries); + Assert.Equal(2, p.BatchSummaries.Length); + Assert.All(p.BatchSummaries, bs => Assert.Equal(0, bs.ResultSetSummaries.Length)); + }).Complete(); + await Common.AwaitExecution(queryService, queryParams, efv.Object); + + // Then: + // ... All events should have been called as per their flow validator + efv.Validate(); + + // ... There should be one active query + Assert.Equal(1, queryService.ActiveQueries.Count); + } + [Fact] public async Task QueryExecuteSingleBatchNoResultsTest() { @@ -191,7 +192,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // ... I request to execute a valid query with no results var workspaceService = GetDefaultWorkspaceService(Common.StandardQuery); var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); - var queryParams = new ExecuteDocumentSelectionParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri}; + var queryParams = new ExecuteDocumentSelectionParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var efv = new EventFlowValidator() .AddStandardQueryResultValidator() @@ -219,7 +220,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution var workspaceService = GetDefaultWorkspaceService(Common.StandardQuery); var queryService = Common.GetPrimedExecutionService(Common.StandardTestDataSet, true, false, workspaceService); - var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; var efv = new EventFlowValidator() .AddStandardQueryResultValidator() @@ -247,7 +248,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution var workspaceService = GetDefaultWorkspaceService(Common.StandardQuery); var dataset = new[] {Common.StandardTestResultSet, Common.StandardTestResultSet}; var queryService = Common.GetPrimedExecutionService(dataset, true, false, workspaceService); - var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; var efv = new EventFlowValidator() .AddStandardQueryResultValidator() @@ -274,7 +275,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // ... I request a to execute a valid query with multiple batches var workspaceService = GetDefaultWorkspaceService(string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery)); var queryService = Common.GetPrimedExecutionService(Common.StandardTestDataSet, true, false, workspaceService); - var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; var efv = new EventFlowValidator() .AddStandardQueryResultValidator() @@ -306,7 +307,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // ... I request to execute a query using a file URI that isn't connected var workspaceService = GetDefaultWorkspaceService(Common.StandardQuery); var queryService = Common.GetPrimedExecutionService(null, false, false, workspaceService); - var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = "notConnected", QuerySelection = Common.WholeDocument}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = "notConnected", QuerySelection = Common.WholeDocument }; var efv = new EventFlowValidator() .AddErrorValidation(Assert.NotEmpty) @@ -328,7 +329,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // ... I request to execute a query var workspaceService = GetDefaultWorkspaceService(Common.StandardQuery); var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); - var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; // Note, we don't care about the results of the first request var firstRequestContext = RequestContextMocks.Create(null); @@ -356,7 +357,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // ... I request to execute a query var workspaceService = GetDefaultWorkspaceService(Common.StandardQuery); var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); - var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; // Note, we don't care about the results of the first request var firstRequestContext = RequestContextMocks.Create(null); @@ -390,7 +391,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // If: // ... I request to execute a query with a missing query string var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); - var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = null}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = null }; var efv = new EventFlowValidator() .AddErrorValidation(Assert.NotEmpty) @@ -412,7 +413,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // ... I request to execute a query that is invalid var workspaceService = GetDefaultWorkspaceService(Common.StandardQuery); var queryService = Common.GetPrimedExecutionService(null, true, true, workspaceService); - var queryParams = new ExecuteDocumentSelectionParams {OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument}; + var queryParams = new ExecuteDocumentSelectionParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; var efv = new EventFlowValidator() .AddStandardQueryResultValidator() diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/SqlScriptFormatterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/SqlScriptFormatterTests.cs new file mode 100644 index 00000000..d220724c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/SqlScriptFormatterTests.cs @@ -0,0 +1,321 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Text.RegularExpressions; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Utility +{ + public class SqlScriptFormatterTests + { + #region Format Identifier Tests + + [Fact] + public void FormatIdentifierNull() + { + // If: I attempt to format null as an identifier + // Then: I should get an exception thrown + Assert.Throws(() => SqlScriptFormatter.FormatIdentifier(null)); + } + + [Theory] + [InlineData("test", "[test]")] // No escape characters + [InlineData("]test", "[]]test]")] // Escape character at beginning + [InlineData("te]st", "[te]]st]")] // Escape character in middle + [InlineData("test]", "[test]]]")] // Escape character at end + [InlineData("t]]est", "[t]]]]est]")] // Multiple escape characters + public void FormatIdentifierTest(string value, string expectedOutput) + { + // If: I attempt to format a value as an identifier + string output = SqlScriptFormatter.FormatIdentifier(value); + + // Then: The output should match the expected output + Assert.Equal(expectedOutput, output); + } + + [Theory] + [InlineData("test", "[test]")] // No splits, no escape characters + [InlineData("test.test", "[test].[test]")] // One split, no escape characters + [InlineData("test.te]st", "[test].[te]]st]")] // One split, one escape character + [InlineData("test.test.test", "[test].[test].[test]")] // Two splits, no escape characters + public void FormatMultipartIdentifierTest(string value, string expectedOutput) + { + // If: I attempt to format a value as a multipart identifier + string output = SqlScriptFormatter.FormatMultipartIdentifier(value); + + // Then: The output should match the expected output + Assert.Equal(expectedOutput, output); + } + + [Theory] + [MemberData(nameof(GetMultipartIdentifierArrays))] + public void FormatMultipartIdentifierArrayTest(string expectedOutput, string[] splits) + { + // If: I attempt to format a value as a multipart identifier + string output = SqlScriptFormatter.FormatMultipartIdentifier(splits); + + // Then: The output should match the expected output + Assert.Equal(expectedOutput, output); + } + + public static IEnumerable GetMultipartIdentifierArrays + { + get + { + yield return new object[] {"[test]", new[] {"test"}}; // No splits, no escape characters + yield return new object[] {"[test].[test]", new[] {"test", "test"}}; // One split, no escape characters + yield return new object[] {"[test].[te]]st]", new[] {"test", "te]st"}}; // One split, one escape character + yield return new object[] {"[test].[test].[test]", new[] {"test", "test", "test"}}; // Two splits, no escape characters + } + } + + #endregion + + #region FormatValue Tests + + [Fact] + public void NullDbCellTest() + { + // If: I attempt to format a null db cell + // Then: It should throw + Assert.Throws(() => SqlScriptFormatter.FormatValue(null, new FormatterTestDbColumn(null))); + } + + [Fact] + public void NullDbColumnTest() + { + // If: I attempt to format a null db column + // Then: It should throw + Assert.Throws(() => SqlScriptFormatter.FormatValue(new DbCellValue(), null)); + } + + public void UnsupportedColumnTest() + { + // If: I attempt to format an unsupported datatype + // Then: It should throw + DbColumn column = new FormatterTestDbColumn("unsupported"); + Assert.Throws(() => SqlScriptFormatter.FormatValue(new DbCellValue(), column)); + } + + [Fact] + public void NullTest() + { + // If: I attempt to format a db cell that contains null + // Then: I should get the null string back + string formattedString = SqlScriptFormatter.FormatValue(new DbCellValue(), new FormatterTestDbColumn(null)); + Assert.Equal(SqlScriptFormatter.NullString, formattedString); + } + + + [Theory] + [InlineData("BIGINT")] + [InlineData("INT")] + [InlineData("SMALLINT")] + [InlineData("TINYINT")] + public void IntegerNumericTest(string dataType) + { + // Setup: Build a column and cell for the integer type column + DbColumn column = new FormatterTestDbColumn(dataType); + DbCellValue cell = new DbCellValue { RawObject = (long)123 }; + + // If: I attempt to format an integer type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should be able to be converted back into a long + Assert.Equal(cell.RawObject, long.Parse(output)); + } + + [Theory] + [InlineData("MONEY", "MONEY", null, null)] + [InlineData("SMALLMONEY", "SMALLMONEY", null, null)] + [InlineData("NUMERIC", @"NUMERIC\(\d+, \d+\)", 18, 0)] + [InlineData("DECIMAL", @"DECIMAL\(\d+, \d+\)", 18, 0)] + public void DecimalTest(string dataType, string regex, int? precision, int? scale) + { + // Setup: Build a column and cell for the decimal type column + DbColumn column = new FormatterTestDbColumn(dataType, precision, scale); + DbCellValue cell = new DbCellValue { RawObject = 123.45m }; + + // If: I attempt to format a decimal type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: It should match a something like CAST(123.45 AS MONEY) + Regex castRegex = new Regex($@"CAST\([\d\.]+ AS {regex}", RegexOptions.IgnoreCase); + Assert.True(castRegex.IsMatch(output)); + } + + [Fact] + public void DoubleTest() + { + // Setup: Build a column and cell for the approx numeric type column + DbColumn column = new FormatterTestDbColumn("FLOAT"); + DbCellValue cell = new DbCellValue { RawObject = 3.14159d }; + + // If: I attempt to format a approx numeric type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should be able to be converted back into a double + Assert.Equal(cell.RawObject, double.Parse(output)); + } + + [Fact] + public void FloatTest() + { + // Setup: Build a column and cell for the approx numeric type column + DbColumn column = new FormatterTestDbColumn("REAL"); + DbCellValue cell = new DbCellValue { RawObject = (float)3.14159 }; + + // If: I attempt to format a approx numeric type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should be able to be converted back into a double + Assert.Equal(cell.RawObject, float.Parse(output)); + } + + [Theory] + [InlineData("SMALLDATETIME")] + [InlineData("DATETIME")] + [InlineData("DATETIME2")] + [InlineData("DATE")] + public void DateTimeTest(string dataType) + { + // Setup: Build a column and cell for the datetime type column + DbColumn column = new FormatterTestDbColumn(dataType); + DbCellValue cell = new DbCellValue { RawObject = DateTime.Now }; + + // If: I attempt to format a datetime type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should be able to be converted back into a datetime + Regex dateTimeRegex = new Regex("N'(.*)'"); + DateTime outputDateTime; + Assert.True(DateTime.TryParse(dateTimeRegex.Match(output).Groups[1].Value, out outputDateTime)); + } + + [Fact] + public void DateTimeOffsetTest() + { + // Setup: Build a column and cell for the datetime offset type column + DbColumn column = new FormatterTestDbColumn("DATETIMEOFFSET"); + DbCellValue cell = new DbCellValue { RawObject = DateTimeOffset.Now }; + + // If: I attempt to format a datetime offset type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should be able to be converted back into a datetime offset + Regex dateTimeRegex = new Regex("N'(.*)'"); + DateTimeOffset outputDateTime; + Assert.True(DateTimeOffset.TryParse(dateTimeRegex.Match(output).Groups[1].Value, out outputDateTime)); + } + + [Fact] + public void TimeTest() + { + // Setup: Build a column and cell for the time type column + DbColumn column = new FormatterTestDbColumn("TIME"); + DbCellValue cell = new DbCellValue { RawObject = TimeSpan.FromHours(12) }; + + // If: I attempt to format a time type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should be able to be converted back into a timespan + Regex dateTimeRegex = new Regex("N'(.*)'"); + TimeSpan outputDateTime; + Assert.True(TimeSpan.TryParse(dateTimeRegex.Match(output).Groups[1].Value, out outputDateTime)); + } + + [Theory] + [InlineData("", "N''")] // Make sure empty string works + [InlineData(" \t\r\n", "N' \t\r\n'")] // Test for whitespace + [InlineData("some text \x9152", "N'some text \x9152'")] // Test unicode (UTF-8 and UTF-16) + [InlineData("'", "N''''")] // Test with escaped character + public void StringFormattingTest(string input, string expectedOutput) + { + // Setup: Build a column and cell for the string type column + // NOTE: We're using VARCHAR because it's very general purpose. + DbColumn column = new FormatterTestDbColumn("VARCHAR"); + DbCellValue cell = new DbCellValue { RawObject = input }; + + // If: I attempt to format a string type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should be quoted and escaped properly + Assert.Equal(expectedOutput, output); + } + + [Theory] + [InlineData("CHAR")] + [InlineData("NCHAR")] + [InlineData("VARCHAR")] + [InlineData("TEXT")] + [InlineData("NTEXT")] + [InlineData("XML")] + public void StringTypeTest(string datatype) + { + // Setup: Build a column and cell for the string type column + DbColumn column = new FormatterTestDbColumn(datatype); + DbCellValue cell = new DbCellValue { RawObject = "test string" }; + + // If: I attempt to format a string type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should match the output string + Assert.Equal("N'test string'", output); + } + + [Theory] + [InlineData("BINARY")] + [InlineData("VARBINARY")] + [InlineData("IMAGE")] + public void BinaryTest(string datatype) + { + // Setup: Build a column and cell for the string type column + DbColumn column = new FormatterTestDbColumn(datatype); + DbCellValue cell = new DbCellValue + { + RawObject = new byte[] { 0x42, 0x45, 0x4e, 0x49, 0x53, 0x43, 0x4f, 0x4f, 0x4c } + }; + + // If: I attempt to format a string type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should match the output string + Regex regex = new Regex("0x[0-9A-F]+", RegexOptions.IgnoreCase); + Assert.True(regex.IsMatch(output)); + } + + [Fact] + public void GuidTest() + { + // Setup: Build a column and cell for the string type column + DbColumn column = new FormatterTestDbColumn("UNIQUEIDENTIFIER"); + DbCellValue cell = new DbCellValue { RawObject = Guid.NewGuid() }; + + // If: I attempt to format a string type column + string output = SqlScriptFormatter.FormatValue(cell, column); + + // Then: The output string should match the output string + Regex regex = new Regex(@"N'[0-9A-F]{8}(-[0-9A-F]{4}){3}-[0-9A-F]{12}'", RegexOptions.IgnoreCase); + Assert.True(regex.IsMatch(output)); + } + + #endregion + + private class FormatterTestDbColumn : DbColumn + { + public FormatterTestDbColumn(string dataType, int? precision = null, int? scale = null) + { + DataTypeName = dataType; + NumericPrecision = precision; + NumericScale = scale; + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs index cbe6d189..91a588bb 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs @@ -2,6 +2,7 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; using System.Data.Common; namespace Microsoft.SqlTools.ServiceLayer.Test.Utility @@ -24,11 +25,27 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility base.DataTypeName = columnType; } + public TestDbColumn(string columnName, string columnType, Type columnDataType) + { + base.IsLong = false; + base.ColumnName = columnName; + base.ColumnSize = 128; + base.AllowDBNull = true; + base.DataType = columnDataType; + base.DataTypeName = columnType; + } + public TestDbColumn(string columnName, string columnType, int scale) : this(columnName, columnType) { base.NumericScale = scale; } + public TestDbColumn(string columnName, bool isAutoIncrement) + : this(columnName) + { + base.IsAutoIncrement = isAutoIncrement; + base.IsIdentity = isAutoIncrement; + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs index fb331a88..7927ceb6 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs @@ -150,7 +150,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) { - throw new NotImplementedException(); + char[] allChars = ((string) RowEnumerator.Current[ordinal]).ToCharArray(); + int outLength = allChars.Length; + if (buffer != null) + { + Array.Copy(allChars, (int) dataOffset, buffer, bufferOffset, outLength); + } + return outLength; } public override string GetDataTypeName(int ordinal)