diff --git a/.mention-bot b/.mention-bot new file mode 100644 index 00000000..c3bb3f47 --- /dev/null +++ b/.mention-bot @@ -0,0 +1,7 @@ +{ + "maxReviewers": 3, + "requiredOrgs": ["Microsoft"], + "skipAlreadyAssignedPR": true, + "skipAlreadyMentionedPR": true, + "skipCollaboratorPR": false +} \ No newline at end of file diff --git a/README.md b/README.md index 534081b2..8282a7ff 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ The SQL Tools Service is an application that provides core functionality for var * Connection management * Language Service support using VS Code protocol * Query execution and resultset management -* Schema discovery # Contribution Guidelines @@ -130,12 +129,8 @@ on this check so that our project will always have good generated documentation. - **The build and unit tests must run green** - When you submit your pull request, our automated build system on AppVeyor will attempt to run a - Release build of your changes and then run all unit tests against the build. If you notice that - any of your unit tests have failed, please fix them by creating a new commit and then pushing it - to your branch. If you see that some unrelated test has failed, try re-running the build for your - pull request. If you continue to see issues, write a comment on the pull request and we will - look into it. + Run all unit tests and code coverage tests to ensure all tests are passing and code coverage numbers + and not negatively impacted by your change. - **Respond to code review feedback** @@ -147,6 +142,5 @@ on this check so that our project will always have good generated documentation. Once your final changes have been accepted, we may ask you to do a final rebase to have your commits so that they follow our commit guidelines. If specific guidance is given, please follow it when - rebasing your commits. Once you do your final push and we see the AppVeyor build pass, we will - merge your changes! + rebasing your commits. Once you do your final push we will merge your changes! diff --git a/RefreshDllsForTestRun.cmd b/RefreshDllsForTestRun.cmd new file mode 100644 index 00000000..284f4aa1 --- /dev/null +++ b/RefreshDllsForTestRun.cmd @@ -0,0 +1,17 @@ +SET WORKINGDIR=%~dp0 +SET _TargetLocation=%1 +SET _BuildConfiguration=%2 +IF [%_BuildConfiguration%] NEQ [] GOTO Start +SET _BuildConfiguration=Debug + +:Start +SET _PerfTestSourceLocation="%WORKINGDIR%\test\Microsoft.SqlTools.ServiceLayer.PerfTests\bin\%_BuildConfiguration%\netcoreapp1.0\win7-x64\publish" +SET _ServiceSourceLocation="%WORKINGDIR%\src\Microsoft.SqlTools.ServiceLayer\bin\%_BuildConfiguration%\netcoreapp1.0\win7-x64\publish" + + + +dotnet publish %WORKINGDIR%\test\Microsoft.SqlTools.ServiceLayer.PerfTests -c %_BuildConfiguration% +dotnet publish %WORKINGDIR%\src\Microsoft.SqlTools.ServiceLayer -c %_BuildConfiguration% + +XCOPY /i /E /y %_PerfTestSourceLocation% "%_TargetLocation%\Tests" +XCOPY /i /E /y %_ServiceSourceLocation% "%_TargetLocation%\Microsoft.SqlTools.ServiceLayer" diff --git a/sqltoolsservice.sln b/sqltoolsservice.sln index 15ece77f..630d13a2 100644 --- a/sqltoolsservice.sln +++ b/sqltoolsservice.sln @@ -19,6 +19,8 @@ Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.SqlTools.ServiceL EndProject Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.SqlTools.ServiceLayer.Test", "test\Microsoft.SqlTools.ServiceLayer.Test\Microsoft.SqlTools.ServiceLayer.Test.xproj", "{2D771D16-9D85-4053-9F79-E2034737DEEF}" EndProject +Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.SqlTools.ServiceLayer.TestDriver", "test\Microsoft.SqlTools.ServiceLayer.TestDriver\Microsoft.SqlTools.ServiceLayer.TestDriver.xproj", "{CC785604-6277-4878-8DA9-360C47158E96}" +EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "scripts", "scripts", "{B7D21727-2926-452B-9610-3ADB0BB6D789}" ProjectSection(SolutionItems) = preProject scripts\archiving.cake = scripts\archiving.cake @@ -38,6 +40,19 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "build", "build", "{F9978D78 build.sh = build.sh EndProjectSection EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "CodeCoverage", "CodeCoverage", "{87D9C7D9-18F4-4AB9-B20D-66C02B6075E2}" + ProjectSection(SolutionItems) = preProject + test\CodeCoverage\codecoverage.bat = test\CodeCoverage\codecoverage.bat + test\CodeCoverage\gulpfile.js = test\CodeCoverage\gulpfile.js + test\CodeCoverage\nuget.config = test\CodeCoverage\nuget.config + test\CodeCoverage\package.json = test\CodeCoverage\package.json + test\CodeCoverage\packages.config = test\CodeCoverage\packages.config + test\CodeCoverage\ReplaceText.vbs = test\CodeCoverage\ReplaceText.vbs + test\CodeCoverage\runintegration.bat = test\CodeCoverage\runintegration.bat + EndProjectSection +EndProject +Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.SqlTools.ServiceLayer.PerfTests", "test\Microsoft.SqlTools.ServiceLayer.PerfTests\Microsoft.SqlTools.ServiceLayer.PerfTests.xproj", "{7E5968AB-83D7-4738-85A2-416A50F13D2F}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -52,6 +67,14 @@ Global {2D771D16-9D85-4053-9F79-E2034737DEEF}.Debug|Any CPU.Build.0 = Debug|Any CPU {2D771D16-9D85-4053-9F79-E2034737DEEF}.Release|Any CPU.ActiveCfg = Release|Any CPU {2D771D16-9D85-4053-9F79-E2034737DEEF}.Release|Any CPU.Build.0 = Release|Any CPU + {CC785604-6277-4878-8DA9-360C47158E96}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {CC785604-6277-4878-8DA9-360C47158E96}.Debug|Any CPU.Build.0 = Debug|Any CPU + {CC785604-6277-4878-8DA9-360C47158E96}.Release|Any CPU.ActiveCfg = Release|Any CPU + {CC785604-6277-4878-8DA9-360C47158E96}.Release|Any CPU.Build.0 = Release|Any CPU + {7E5968AB-83D7-4738-85A2-416A50F13D2F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7E5968AB-83D7-4738-85A2-416A50F13D2F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7E5968AB-83D7-4738-85A2-416A50F13D2F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7E5968AB-83D7-4738-85A2-416A50F13D2F}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -59,6 +82,9 @@ Global GlobalSection(NestedProjects) = preSolution {0D61DC2B-DA66-441D-B9D0-F76C98F780F9} = {2BBD7364-054F-4693-97CD-1C395E3E84A9} {2D771D16-9D85-4053-9F79-E2034737DEEF} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4} + {CC785604-6277-4878-8DA9-360C47158E96} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4} {B7D21727-2926-452B-9610-3ADB0BB6D789} = {F9978D78-78FE-4E92-A7D6-D436B7683EF6} + {87D9C7D9-18F4-4AB9-B20D-66C02B6075E2} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4} + {7E5968AB-83D7-4738-85A2-416A50F13D2F} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4} EndGlobalSection EndGlobal diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs index 31d0026d..506b043e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs @@ -23,6 +23,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection OwnerUri = ownerUri; ConnectionDetails = details; ConnectionId = Guid.NewGuid(); + IntellisenseMetrics = new InteractionMetrics(new int[] { 50, 100, 200, 500, 1000, 2000 }); } /// @@ -49,5 +50,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// The connection to the SQL database that commands will be run against. /// public DbConnection SqlConnection { get; set; } + + /// + /// Intellisense Metrics + /// + public InteractionMetrics IntellisenseMetrics { get; private set; } + + /// + /// Returns true is the db connection is to a SQL db + /// + public bool IsAzure { get; set; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index cb931a6e..0ace6c92 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -15,6 +15,7 @@ using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; @@ -29,7 +30,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// /// Singleton service instance /// - private static Lazy instance + private static readonly Lazy instance = new Lazy(() => new ConnectionService()); /// @@ -48,11 +49,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// private ISqlConnectionFactory connectionFactory; - private Dictionary ownerToConnectionMap = new Dictionary(); + private readonly Dictionary ownerToConnectionMap = new Dictionary(); - private ConcurrentDictionary ownerToCancellationTokenSourceMap = new ConcurrentDictionary(); + private readonly ConcurrentDictionary ownerToCancellationTokenSourceMap = new ConcurrentDictionary(); - private Object cancellationTokenSourceLock = new Object(); + private readonly object cancellationTokenSourceLock = new object(); /// /// Map from script URIs to ConnectionInfo objects @@ -77,9 +78,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } /// - /// Default constructor is private since it's a singleton class + /// Default constructor should be private since it's a singleton class, but we need a constructor + /// for use in unit test mocking. /// - private ConnectionService() + public ConnectionService() { } @@ -129,7 +131,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } // Attempts to link a URI to an actively used connection for this URI - public bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo) + public virtual bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo) { return this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo); } @@ -172,21 +174,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); // try to connect - var response = new ConnectionCompleteParams(); - response.OwnerUri = connectionParams.OwnerUri; + var response = new ConnectionCompleteParams {OwnerUri = connectionParams.OwnerUri}; CancellationTokenSource source = null; try { // build the connection string from the input parameters - string connectionString = ConnectionService.BuildConnectionString(connectionInfo.ConnectionDetails); + string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails); // create a sql connection instance connectionInfo.SqlConnection = connectionInfo.Factory.CreateSqlConnection(connectionString); - // turning on MARS to avoid break in LanguageService with multiple editors - // we'll remove this once ConnectionService is refactored to not own the LanguageService connection - connectionInfo.ConnectionDetails.MultipleActiveResultSets = true; - // Add a cancellation token source so that the connection OpenAsync() can be cancelled using (source = new CancellationTokenSource()) { @@ -264,7 +261,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // Update with the actual database name in connectionInfo and result // Doing this here as we know the connection is open - expect to do this only on connecting connectionInfo.ConnectionDetails.DatabaseName = connectionInfo.SqlConnection.Database; - response.ConnectionSummary = new ConnectionSummary() + response.ConnectionSummary = new ConnectionSummary { ServerName = connectionInfo.ConnectionDetails.ServerName, DatabaseName = connectionInfo.ConnectionDetails.DatabaseName, @@ -272,7 +269,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection }; // invoke callback notifications - invokeOnConnectionActivities(connectionInfo); + InvokeOnConnectionActivities(connectionInfo); // try to get information about the connected SQL Server instance try @@ -281,7 +278,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection DbConnection connection = reliableConnection != null ? reliableConnection.GetUnderlyingConnection() : connectionInfo.SqlConnection; ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(connection); - response.ServerInfo = new Contracts.ServerInfo() + response.ServerInfo = new ServerInfo { ServerMajorVersion = serverInfo.ServerMajorVersion, ServerMinorVersion = serverInfo.ServerMinorVersion, @@ -294,6 +291,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection AzureVersion = serverInfo.AzureVersion, OsVersion = serverInfo.OsVersion }; + connectionInfo.IsAzure = serverInfo.IsCloud; } catch(Exception ex) { @@ -360,6 +358,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return false; } + if (ServiceHost != null) + { + // Send a telemetry notification for intellisense performance metrics + ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams() + { + Params = new TelemetryProperties + { + Properties = new Dictionary + { + { "IsAzure", info.IsAzure ? "1" : "0" } + }, + EventName = TelemetryEventNames.IntellisenseQuantile, + Measures = info.IntellisenseMetrics.Quantile + } + }); + } + // Close the connection info.SqlConnection.Close(); @@ -402,7 +417,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection connection.Open(); List results = new List(); - var systemDatabases = new string[] {"master", "model", "msdb", "tempdb"}; + var systemDatabases = new[] {"master", "model", "msdb", "tempdb"}; using (DbCommand command = connection.CreateCommand()) { command.CommandText = "SELECT name FROM sys.databases ORDER BY name ASC"; @@ -476,7 +491,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection try { - RunConnectRequestHandlerTask(connectParams, requestContext); + RunConnectRequestHandlerTask(connectParams); await requestContext.SendResult(true); } catch @@ -485,7 +500,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } } - private void RunConnectRequestHandlerTask(ConnectParams connectParams, RequestContext requestContext) + private void RunConnectRequestHandlerTask(ConnectParams connectParams) { // create a task to connect asynchronously so that other requests are not blocked in the meantime Task.Run(async () => @@ -493,7 +508,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection try { // open connection based on request details - ConnectionCompleteParams result = await ConnectionService.Instance.Connect(connectParams); + ConnectionCompleteParams result = await Instance.Connect(connectParams); await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); } catch (Exception ex) @@ -518,7 +533,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection try { - bool result = ConnectionService.Instance.CancelConnect(cancelParams); + bool result = Instance.CancelConnect(cancelParams); await requestContext.SendResult(result); } catch(Exception ex) @@ -538,7 +553,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection try { - bool result = ConnectionService.Instance.Disconnect(disconnectParams); + bool result = Instance.Disconnect(disconnectParams); await requestContext.SendResult(result); } catch(Exception ex) @@ -559,7 +574,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection try { - ListDatabasesResponse result = ConnectionService.Instance.ListDatabases(listDatabasesParams); + ListDatabasesResponse result = Instance.ListDatabases(listDatabasesParams); await requestContext.SendResult(result); } catch(Exception ex) @@ -582,10 +597,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// public static string BuildConnectionString(ConnectionDetails connectionDetails) { - SqlConnectionStringBuilder connectionBuilder = new SqlConnectionStringBuilder(); - connectionBuilder["Data Source"] = connectionDetails.ServerName; - connectionBuilder["User Id"] = connectionDetails.UserName; - connectionBuilder["Password"] = connectionDetails.Password; + SqlConnectionStringBuilder connectionBuilder = new SqlConnectionStringBuilder + { + ["Data Source"] = connectionDetails.ServerName, + ["User Id"] = connectionDetails.UserName, + ["Password"] = connectionDetails.Password + }; // Check for any optional parameters if (!string.IsNullOrEmpty(connectionDetails.DatabaseName)) @@ -725,7 +742,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // Fire a connection changed event ConnectionChangedParams parameters = new ConnectionChangedParams(); - ConnectionSummary summary = (ConnectionSummary)(info.ConnectionDetails); + ConnectionSummary summary = info.ConnectionDetails; parameters.Connection = summary.Clone(); parameters.OwnerUri = ownerUri; ServiceHost.SendEvent(ConnectionChangedNotification.Type, parameters); @@ -744,7 +761,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } } - private void invokeOnConnectionActivities(ConnectionInfo connectionInfo) + private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo) { foreach (var activity in this.onConnectionActivities) { diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.DataTransferDetectionErrorStrategy.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.DataTransferDetectionErrorStrategy.cs index a52619d2..74be5faf 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.DataTransferDetectionErrorStrategy.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.DataTransferDetectionErrorStrategy.cs @@ -13,7 +13,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection /// /// Provides the error detection logic for temporary faults that are commonly found during data transfer. /// - internal sealed class DataTransferErrorDetectionStrategy : ErrorDetectionStrategyBase, IErrorDetectionStrategy + internal class DataTransferErrorDetectionStrategy : ErrorDetectionStrategyBase, IErrorDetectionStrategy { private static readonly DataTransferErrorDetectionStrategy instance = new DataTransferErrorDetectionStrategy(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryAndIgnorableErrorDetectionStrategy.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryAndIgnorableErrorDetectionStrategy.cs index 0cf26070..e538f88b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryAndIgnorableErrorDetectionStrategy.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryAndIgnorableErrorDetectionStrategy.cs @@ -18,7 +18,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection /// want to consider this as passing since the first execution that has timed out (or failed for some other temporary error) /// might have managed to create the object. /// - internal sealed class SqlAzureTemporaryAndIgnorableErrorDetectionStrategy : ErrorDetectionStrategyBase, IErrorDetectionStrategy + internal class SqlAzureTemporaryAndIgnorableErrorDetectionStrategy : ErrorDetectionStrategyBase, IErrorDetectionStrategy { /// /// Azure error that can be ignored diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.cs index a2eb03fe..23159d3e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.cs @@ -533,7 +533,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection return new RetryStateEx { TotalRetryTime = TimeSpan.Zero }; } - private sealed class RetryStateEx : RetryState + internal sealed class RetryStateEx : RetryState { public TimeSpan TotalRetryTime { get; set; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicyFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicyFactory.cs index 4ac025fb..8e42894a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicyFactory.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicyFactory.cs @@ -390,7 +390,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection return retryPolicy; } - private static void DataConnectionFailureRetry(RetryState retryState) + internal static void DataConnectionFailureRetry(RetryState retryState) { Logger.Write(LogLevel.Normal, string.Format(CultureInfo.InvariantCulture, "Connection retry number {0}. Delaying {1} ms before retry. Exception: {2}", @@ -401,7 +401,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection RetryPolicyUtils.RaiseAmbientRetryMessage(retryState, SqlSchemaModelErrorCodes.ServiceActions.ConnectionRetry); } - private static void CommandFailureRetry(RetryState retryState, string commandKeyword) + internal static void CommandFailureRetry(RetryState retryState, string commandKeyword) { Logger.Write(LogLevel.Normal, string.Format( CultureInfo.InvariantCulture, @@ -414,7 +414,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection RetryPolicyUtils.RaiseAmbientRetryMessage(retryState, SqlSchemaModelErrorCodes.ServiceActions.CommandRetry); } - private static void CommandFailureIgnore(RetryState retryState, string commandKeyword) + internal static void CommandFailureIgnore(RetryState retryState, string commandKeyword) { Logger.Write(LogLevel.Normal, string.Format( CultureInfo.InvariantCulture, @@ -426,32 +426,32 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection RetryPolicyUtils.RaiseAmbientIgnoreMessage(retryState, SqlSchemaModelErrorCodes.ServiceActions.CommandRetry); } - private static void CommandFailureRetry(RetryState retryState) + internal static void CommandFailureRetry(RetryState retryState) { CommandFailureRetry(retryState, "Command"); } - private static void CommandFailureIgnore(RetryState retryState) + internal static void CommandFailureIgnore(RetryState retryState) { CommandFailureIgnore(retryState, "Command"); } - private static void CreateDatabaseCommandFailureRetry(RetryState retryState) + internal static void CreateDatabaseCommandFailureRetry(RetryState retryState) { CommandFailureRetry(retryState, "Database Command"); } - private static void CreateDatabaseCommandFailureIgnore(RetryState retryState) + internal static void CreateDatabaseCommandFailureIgnore(RetryState retryState) { CommandFailureIgnore(retryState, "Database Command"); } - private static void ElementCommandFailureRetry(RetryState retryState) + internal static void ElementCommandFailureRetry(RetryState retryState) { CommandFailureRetry(retryState, "Element Command"); } - private static void ElementCommandFailureIgnore(RetryState retryState) + internal static void ElementCommandFailureIgnore(RetryState retryState) { CommandFailureIgnore(retryState, "Element Command"); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/SqlSchemaModelErrorCodes.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/SqlSchemaModelErrorCodes.cs index 4d7a65b7..e5f126cc 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/SqlSchemaModelErrorCodes.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/SqlSchemaModelErrorCodes.cs @@ -271,21 +271,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection public const int InvalidFileStreamOptions = ValidationBaseCode + 65; public const int StorageShouldNotSetOnDifferentInstance = ValidationBaseCode + 66; public const int TableShouldNotHaveStorage = ValidationBaseCode + 67; - public static int MemoryOptimizedObjectsValidation_NonMemoryOptimizedTableCannotBeAccessed = ValidationBaseCode + 68; - public static int MemoryOptimizedObjectsValidation_SyntaxNotSupportedOnHekatonElement = ValidationBaseCode + 69; - public static int MemoryOptimizedObjectsValidation_ValidatePrimaryKeyForSchemaAndDataTables = ValidationBaseCode + 70; - public static int MemoryOptimizedObjectsValidation_ValidatePrimaryKeyForSchemaOnlyTables = ValidationBaseCode + 71; - public static int MemoryOptimizedObjectsValidation_OnlyNotNullableColumnsOnIndexes = ValidationBaseCode + 72; - public static int MemoryOptimizedObjectsValidation_HashIndexesOnlyOnMemoryOptimizedObjects = ValidationBaseCode + 73; - public static int MemoryOptimizedObjectsValidation_OptionOnlyForHashIndexes = ValidationBaseCode + 74; - public static int IncrementalStatisticsValidation_FilterNotSupported = ValidationBaseCode + 75; - public static int IncrementalStatisticsValidation_ViewNotSupported = ValidationBaseCode + 76; - public static int IncrementalStatisticsValidation_IndexNotPartitionAligned = ValidationBaseCode + 77; - public static int AzureV12SurfaceAreaValidation = ValidationBaseCode + 78; - public static int DuplicatedTargetObjectReferencesInSecurityPolicy = ValidationBaseCode + 79; - public static int MultipleSecurityPoliciesOnTargetObject = ValidationBaseCode + 80; - public static int ExportedRowsMayBeIncomplete = ValidationBaseCode + 81; - public static int ExportedRowsMayContainSomeMaskedData = ValidationBaseCode + 82; + public const int MemoryOptimizedObjectsValidation_NonMemoryOptimizedTableCannotBeAccessed = ValidationBaseCode + 68; + public const int MemoryOptimizedObjectsValidation_SyntaxNotSupportedOnHekatonElement = ValidationBaseCode + 69; + public const int MemoryOptimizedObjectsValidation_ValidatePrimaryKeyForSchemaAndDataTables = ValidationBaseCode + 70; + public const int MemoryOptimizedObjectsValidation_ValidatePrimaryKeyForSchemaOnlyTables = ValidationBaseCode + 71; + public const int MemoryOptimizedObjectsValidation_OnlyNotNullableColumnsOnIndexes = ValidationBaseCode + 72; + public const int MemoryOptimizedObjectsValidation_HashIndexesOnlyOnMemoryOptimizedObjects = ValidationBaseCode + 73; + public const int MemoryOptimizedObjectsValidation_OptionOnlyForHashIndexes = ValidationBaseCode + 74; + public const int IncrementalStatisticsValidation_FilterNotSupported = ValidationBaseCode + 75; + public const int IncrementalStatisticsValidation_ViewNotSupported = ValidationBaseCode + 76; + public const int IncrementalStatisticsValidation_IndexNotPartitionAligned = ValidationBaseCode + 77; + public const int AzureV12SurfaceAreaValidation = ValidationBaseCode + 78; + public const int DuplicatedTargetObjectReferencesInSecurityPolicy = ValidationBaseCode + 79; + public const int MultipleSecurityPoliciesOnTargetObject = ValidationBaseCode + 80; + public const int ExportedRowsMayBeIncomplete = ValidationBaseCode + 81; + public const int ExportedRowsMayContainSomeMaskedData = ValidationBaseCode + 82; public const int EncryptedColumnValidation_EncryptedPrimaryKey = ValidationBaseCode + 83; public const int EncryptedColumnValidation_EncryptedUniqueColumn = ValidationBaseCode + 84; public const int EncryptedColumnValidation_EncryptedCheckConstraint = ValidationBaseCode + 85; @@ -315,8 +315,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection public const int TemporalValidation_SchemaMismatch = ValidationBaseCode + 109; public const int TemporalValidation_ComputedColumns = ValidationBaseCode + 110; public const int TemporalValidation_NoAlwaysEncryptedCols = ValidationBaseCode + 111; - public static int IndexesOnExternalTable = ValidationBaseCode + 112; - public static int TriggersOnExternalTable = ValidationBaseCode + 113; + public const int IndexesOnExternalTable = ValidationBaseCode + 112; + public const int TriggersOnExternalTable = ValidationBaseCode + 113; public const int StretchValidation_ExportBlocked = ValidationBaseCode + 114; public const int StretchValidation_ImportBlocked = ValidationBaseCode + 115; public const int DeploymentBlocked = ValidationBaseCode + 116; diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/CredentialService.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/CredentialService.cs index cf4182fc..6cb3644e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Credentials/CredentialService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/CredentialService.cs @@ -48,7 +48,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials /// Default constructor is private since it's a singleton class /// private CredentialService() - : this(null, new LinuxCredentialStore.StoreConfig() + : this(null, new StoreConfig() { CredentialFolder = DefaultSecretsFolder, CredentialFile = DefaultSecretsFile, IsRelativeToUserHomeDir = true}) { } @@ -56,7 +56,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials /// /// Internal for testing purposes only /// - internal CredentialService(ICredentialStore store, LinuxCredentialStore.StoreConfig config) + internal CredentialService(ICredentialStore store, StoreConfig config) { this.credStore = store != null ? store : GetStoreForOS(config); } @@ -64,12 +64,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials /// /// Internal for testing purposes only /// - internal static ICredentialStore GetStoreForOS(LinuxCredentialStore.StoreConfig config) + internal static ICredentialStore GetStoreForOS(StoreConfig config) { if(RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { return new Win32CredentialStore(); } +#if !WINDOWS_ONLY_BUILD else if(RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { return new OSXCredentialStore(); @@ -78,6 +79,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials { return new LinuxCredentialStore(config); } +#endif throw new InvalidOperationException("Platform not currently supported"); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/FileTokenStorage.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/FileTokenStorage.cs index 756421e4..94bbde62 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/FileTokenStorage.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/FileTokenStorage.cs @@ -13,6 +13,9 @@ using Newtonsoft.Json; namespace Microsoft.SqlTools.ServiceLayer.Credentials.Linux { + +#if !WINDOWS_ONLY_BUILD + public class FileTokenStorage { private const int OwnerAccessMode = 384; // Permission 0600 - owner read/write, nobody else has access @@ -84,4 +87,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials.Linux Interop.Sys.ChMod(filePath, OwnerAccessMode); } } + +#endif + } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Errors.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Errors.cs index f3b1d5f5..d5668068 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Errors.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Errors.cs @@ -8,6 +8,9 @@ using System.Runtime.InteropServices; namespace Microsoft.SqlTools.ServiceLayer.Credentials { + +#if !WINDOWS_ONLY_BUILD + internal static partial class Interop { /// Common Unix errno error codes. @@ -218,4 +221,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials } } +#endif + } \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Sys.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Sys.cs index 8777ab0c..95848568 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Sys.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Sys.cs @@ -8,6 +8,9 @@ using System.Runtime.InteropServices; namespace Microsoft.SqlTools.ServiceLayer.Credentials { + +#if !WINDOWS_ONLY_BUILD + internal static partial class Interop { internal static partial class Sys @@ -37,6 +40,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials internal const string SystemNative = "System.Native"; } } - } + +#endif + } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/LinuxCredentialStore.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/LinuxCredentialStore.cs index 93350e83..1009b401 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/LinuxCredentialStore.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/LinuxCredentialStore.cs @@ -3,7 +3,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // - using System; using System.Collections.Generic; using System.Diagnostics; @@ -15,6 +14,18 @@ using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.Credentials.Linux { + /// + /// Store configuration struct + /// + internal struct StoreConfig + { + public string CredentialFolder { get; set; } + public string CredentialFile { get; set; } + public bool IsRelativeToUserHomeDir { get; set; } + } + +#if !WINDOWS_ONLY_BUILD + /// /// Linux implementation of the credential store. /// @@ -25,13 +36,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials.Linux /// internal class LinuxCredentialStore : ICredentialStore { - internal struct StoreConfig - { - public string CredentialFolder { get; set; } - public string CredentialFile { get; set; } - public bool IsRelativeToUserHomeDir { get; set; } - } - private string credentialFolderPath; private string credentialFileName; private FileTokenStorage storage; @@ -228,4 +232,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials.Linux } } } -} \ No newline at end of file +#endif + +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/OSXCredentialStore.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/OSXCredentialStore.cs index 8ea262b3..ff76fd00 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/OSXCredentialStore.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/OSXCredentialStore.cs @@ -10,6 +10,9 @@ using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.Credentials.OSX { + +#if !WINDOWS_ONLY_BUILD + /// /// OSX implementation of the credential store /// @@ -155,4 +158,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Credentials.OSX } } } -} \ No newline at end of file + +#endif + +} + diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index ef7b842d..3b3d11c8 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -163,8 +163,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol TParams typedParams = default(TParams); if (eventMessage.Contents != null) { - // TODO: Catch parse errors! - typedParams = eventMessage.Contents.ToObject(); + try + { + typedParams = eventMessage.Contents.ToObject(); + } + catch (Exception ex) + { + Logger.Write(LogLevel.Verbose, ex.ToString()); + } } return eventHandler(typedParams, eventContext); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs index f217e0d3..32b8301e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs @@ -150,14 +150,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting Capabilities = new ServerCapabilities { TextDocumentSync = TextDocumentSyncKind.Incremental, - DefinitionProvider = false, + DefinitionProvider = true, ReferencesProvider = false, DocumentHighlightProvider = false, HoverProvider = true, CompletionProvider = new CompletionOptions { ResolveProvider = true, - TriggerCharacters = new string[] { ".", "-", ":", "\\" } + TriggerCharacters = new string[] { ".", "-", ":", "\\", "[" } }, SignatureHelpProvider = new SignatureHelpOptions { diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs index c15206d8..df5ea3c9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs @@ -5,16 +5,16 @@ using System; using System.Collections.Generic; -using System.Globalization; using System.Linq; -using System.Text.RegularExpressions; using System.Threading; using Microsoft.SqlServer.Management.SqlParser.Binder; using Microsoft.SqlServer.Management.SqlParser.Intellisense; using Microsoft.SqlServer.Management.SqlParser.Parser; using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion; using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; @@ -30,8 +30,6 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices private static WorkspaceService workspaceServiceInstance; - private static Regex ValidSqlNameRegex = new Regex(@"^[\p{L}_@][\p{L}\p{N}@$#_]{0,127}$"); - private static CompletionItem[] emptyCompletionList = new CompletionItem[0]; private static readonly string[] DefaultCompletionText = new string[] @@ -372,12 +370,13 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// /// internal static CompletionItem[] GetDefaultCompletionItems( - int row, - int startColumn, - int endColumn, - bool useLowerCase, - string tokenText = null) + ScriptDocumentInfo scriptDocumentInfo, + bool useLowerCase) { + int row = scriptDocumentInfo.StartLine; + int startColumn = scriptDocumentInfo.StartColumn; + int endColumn = scriptDocumentInfo.EndColumn; + string tokenText = scriptDocumentInfo.TokenText; // determine how many default completion items there will be int listSize = DefaultCompletionText.Length; if (!string.IsNullOrWhiteSpace(tokenText)) @@ -408,7 +407,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices if (string.IsNullOrWhiteSpace(tokenText) || completionText.StartsWith(tokenText, StringComparison.OrdinalIgnoreCase)) { completionItems[completionItemIndex] = CreateDefaultCompletionItem( - useLowerCase ? completionText.ToLower() : completionText.ToUpper(), + useLowerCase ? completionText.ToLowerInvariant() : completionText.ToUpperInvariant(), row, startColumn, endColumn); @@ -432,7 +431,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices int startColumn, int endColumn) { - return CreateCompletionItem(label, label + " keyword", label, CompletionItemKind.Keyword, row, startColumn, endColumn); + return SqlCompletionItem.CreateCompletionItem(label, label + " keyword", label, CompletionItemKind.Keyword, row, startColumn, endColumn); } internal static CompletionItem[] AddTokenToItems(CompletionItem[] currentList, Token token, int row, @@ -446,49 +445,12 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices )) { var list = currentList.ToList(); - list.Insert(0, CreateCompletionItem(token.Text, token.Text, token.Text, CompletionItemKind.Text, row, startColumn, endColumn)); + list.Insert(0, SqlCompletionItem.CreateCompletionItem(token.Text, token.Text, token.Text, CompletionItemKind.Text, row, startColumn, endColumn)); return list.ToArray(); } return currentList; } - private static CompletionItem CreateCompletionItem( - string label, - string detail, - string insertText, - CompletionItemKind kind, - int row, - int startColumn, - int endColumn) - { - CompletionItem item = new CompletionItem() - { - Label = label, - Kind = kind, - Detail = detail, - InsertText = insertText, - TextEdit = new TextEdit - { - NewText = insertText, - Range = new Range - { - Start = new Position - { - Line = row, - Character = startColumn - }, - End = new Position - { - Line = row, - Character = endColumn - } - } - } - }; - - return item; - } - /// /// Converts a list of Declaration objects to CompletionItem objects /// since VS Code expects CompletionItems but SQL Parser works with Declarations @@ -501,56 +463,22 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices IEnumerable suggestions, int row, int startColumn, - int endColumn) + int endColumn, + string tokenText = null) { List completions = new List(); foreach (var autoCompleteItem in suggestions) { - string insertText = GetCompletionItemInsertName(autoCompleteItem); - CompletionItemKind kind = CompletionItemKind.Variable; - switch (autoCompleteItem.Type) - { - case DeclarationType.Schema: - kind = CompletionItemKind.Module; - break; - case DeclarationType.Column: - kind = CompletionItemKind.Field; - break; - case DeclarationType.Table: - case DeclarationType.View: - kind = CompletionItemKind.File; - break; - case DeclarationType.Database: - kind = CompletionItemKind.Method; - break; - case DeclarationType.ScalarValuedFunction: - case DeclarationType.TableValuedFunction: - case DeclarationType.BuiltInFunction: - kind = CompletionItemKind.Value; - break; - default: - kind = CompletionItemKind.Unit; - break; - } + SqlCompletionItem sqlCompletionItem = new SqlCompletionItem(autoCompleteItem, tokenText); // convert the completion item candidates into CompletionItems - completions.Add(CreateCompletionItem(autoCompleteItem.Title, autoCompleteItem.Title, insertText, kind, row, startColumn, endColumn)); + completions.Add(sqlCompletionItem.CreateCompletionItem(row, startColumn, endColumn)); } return completions.ToArray(); } - private static string GetCompletionItemInsertName(Declaration autoCompleteItem) - { - string insertText = autoCompleteItem.Title; - if (!string.IsNullOrEmpty(autoCompleteItem.Title) && !ValidSqlNameRegex.IsMatch(autoCompleteItem.Title)) - { - insertText = string.Format(CultureInfo.InvariantCulture, "[{0}]", autoCompleteItem.Title); - } - return insertText; - } - /// /// Preinitialize the parser and binder with common metadata. /// This should front load the long binding wait to the time the @@ -566,7 +494,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { if (scriptInfo.IsConnected) { - var scriptFile = AutoCompleteHelper.WorkspaceServiceInstance.Workspace.GetFile(info.OwnerUri); + var scriptFile = AutoCompleteHelper.WorkspaceServiceInstance.Workspace.GetFile(info.OwnerUri); LanguageService.Instance.ParseAndBind(scriptFile, info); if (Monitor.TryEnter(scriptInfo.BuildingMetadataLock, LanguageService.OnConnectionWaitTimeout)) @@ -679,5 +607,77 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices return null; } } + + /// + /// Converts a SQL Parser List of MethodHelpText objects into a VS Code SignatureHelp object + /// + internal static SignatureHelp ConvertMethodHelpTextListToSignatureHelp(List methods, Babel.MethodNameAndParamLocations locations, int line, int column) + { + Validate.IsNotNull(nameof(methods), methods); + Validate.IsNotNull(nameof(locations), locations); + Validate.IsGreaterThan(nameof(line), line, 0); + Validate.IsGreaterThan(nameof(column), column, 0); + + SignatureHelp help = new SignatureHelp(); + + help.Signatures = methods.Select(method => + { + return new SignatureInformation() + { + // Signature label format: param1, param2, ..., paramn RETURNS + Label = method.Name + " " + method.Parameters.Select(parameter => parameter.Display).Aggregate((l, r) => l + "," + r) + " " + method.Type, + Documentation = method.Description, + Parameters = method.Parameters.Select(parameter => + { + return new ParameterInformation() + { + Label = parameter.Display, + Documentation = parameter.Description + }; + }).ToArray() + }; + }).Where(method => method.Label.Contains(locations.Name)).ToArray(); + + if (help.Signatures.Length == 0) + { + return null; + } + + // Find the matching method signature at the cursor's location + // For now, take the first match (since we've already filtered by name above) + help.ActiveSignature = 0; + + // Determine the current parameter at the cursor + int currentParameter = -1; // Default case: not on any particular parameter + if (locations.ParamStartLocation != null) + { + // Is the cursor past the function name? + var location = locations.ParamStartLocation.Value; + if (line > location.LineNumber || (line == location.LineNumber && line == location.LineNumber && column >= location.ColumnNumber)) + { + currentParameter = 0; + } + } + foreach (var location in locations.ParamSeperatorLocations) + { + // Is the cursor past a comma ',' and at least on the next parameter? + if (line > location.LineNumber || (line == location.LineNumber && column > location.ColumnNumber)) + { + currentParameter++; + } + } + if (locations.ParamEndLocation != null) + { + // Is the cursor past the end of the parameter list on a different token? + var location = locations.ParamEndLocation.Value; + if (line > location.LineNumber || (line == location.LineNumber && line == location.LineNumber && column > location.ColumnNumber)) + { + currentParameter = -1; + } + } + help.ActiveParameter = currentParameter; + + return help; + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/BindingQueue.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/BindingQueue.cs index 6058ec42..07623afb 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/BindingQueue.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/BindingQueue.cs @@ -112,6 +112,17 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } + /// + /// Checks if a binding context already exists for the provided context key + /// + protected bool BindingContextExists(string key) + { + lock (this.bindingContextLock) + { + return this.BindingContextMap.ContainsKey(key); + } + } + private bool HasPendingQueueItems { get diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/AutoCompletionResult.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/AutoCompletionResult.cs new file mode 100644 index 00000000..b0678781 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/AutoCompletionResult.cs @@ -0,0 +1,54 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion +{ + /// + /// Includes the objects created by auto completion service + /// + public class AutoCompletionResult + { + /// + /// Creates new instance + /// + public AutoCompletionResult() + { + Stopwatch = new Stopwatch(); + Stopwatch.Start(); + } + + private Stopwatch Stopwatch { get; set; } + + /// + /// Completes the results to calculate the duration + /// + public void CompleteResult(CompletionItem[] completionItems) + { + Stopwatch.Stop(); + CompletionItems = completionItems; + } + + /// + /// The number of milliseconds to process the result + /// + public double Duration + { + get + { + return Stopwatch.ElapsedMilliseconds; + } + } + + /// + /// Completion list + /// + public CompletionItem[] CompletionItems { get; private set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/CompletionService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/CompletionService.cs new file mode 100644 index 00000000..cf7eb6c3 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/CompletionService.cs @@ -0,0 +1,166 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Threading; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using Microsoft.SqlServer.Management.SqlParser.Parser; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion +{ + /// + /// A service to create auto complete list for given script document + /// + internal class CompletionService + { + private ConnectedBindingQueue BindingQueue { get; set; } + + /// + /// Created new instance given binding queue + /// + public CompletionService(ConnectedBindingQueue bindingQueue) + { + BindingQueue = bindingQueue; + } + + private ISqlParserWrapper sqlParserWrapper; + + /// + /// SQL parser wrapper to create the completion list + /// + public ISqlParserWrapper SqlParserWrapper + { + get + { + if(this.sqlParserWrapper == null) + { + this.sqlParserWrapper = new SqlParserWrapper(); + } + return this.sqlParserWrapper; + } + set + { + this.sqlParserWrapper = value; + } + } + + /// + /// Creates a completion list given connection and document info + /// + public AutoCompletionResult CreateCompletions( + ConnectionInfo connInfo, + ScriptDocumentInfo scriptDocumentInfo, + bool useLowerCaseSuggestions) + { + AutoCompletionResult result = new AutoCompletionResult(); + // check if the file is connected and the file lock is available + if (scriptDocumentInfo.ScriptParseInfo.IsConnected && Monitor.TryEnter(scriptDocumentInfo.ScriptParseInfo.BuildingMetadataLock)) + { + try + { + QueueItem queueItem = AddToQueue(connInfo, scriptDocumentInfo.ScriptParseInfo, scriptDocumentInfo, useLowerCaseSuggestions); + + // wait for the queue item + queueItem.ItemProcessed.WaitOne(); + var completionResult = queueItem.GetResultAsT(); + if (completionResult != null && completionResult.CompletionItems != null && completionResult.CompletionItems.Length > 0) + { + result = completionResult; + } + else if (!ShouldShowCompletionList(scriptDocumentInfo.Token)) + { + result.CompleteResult(AutoCompleteHelper.EmptyCompletionList); + } + } + finally + { + Monitor.Exit(scriptDocumentInfo.ScriptParseInfo.BuildingMetadataLock); + } + } + + return result; + } + + private QueueItem AddToQueue( + ConnectionInfo connInfo, + ScriptParseInfo scriptParseInfo, + ScriptDocumentInfo scriptDocumentInfo, + bool useLowerCaseSuggestions) + { + // queue the completion task with the binding queue + QueueItem queueItem = this.BindingQueue.QueueBindingOperation( + key: scriptParseInfo.ConnectionKey, + bindingTimeout: LanguageService.BindingTimeout, + bindOperation: (bindingContext, cancelToken) => + { + return CreateCompletionsFromSqlParser(connInfo, scriptParseInfo, scriptDocumentInfo, bindingContext.MetadataDisplayInfoProvider); + }, + timeoutOperation: (bindingContext) => + { + // return the default list if the connected bind fails + return CreateDefaultCompletionItems(scriptParseInfo, scriptDocumentInfo, useLowerCaseSuggestions); + }); + return queueItem; + } + + private static bool ShouldShowCompletionList(Token token) + { + bool result = true; + if (token != null) + { + switch (token.Id) + { + case (int)Tokens.LEX_MULTILINE_COMMENT: + case (int)Tokens.LEX_END_OF_LINE_COMMENT: + result = false; + break; + } + } + return result; + } + + private AutoCompletionResult CreateDefaultCompletionItems(ScriptParseInfo scriptParseInfo, ScriptDocumentInfo scriptDocumentInfo, bool useLowerCaseSuggestions) + { + AutoCompletionResult result = new AutoCompletionResult(); + CompletionItem[] completionList = AutoCompleteHelper.GetDefaultCompletionItems(scriptDocumentInfo, useLowerCaseSuggestions); + result.CompleteResult(completionList); + return result; + } + + private AutoCompletionResult CreateCompletionsFromSqlParser( + ConnectionInfo connInfo, + ScriptParseInfo scriptParseInfo, + ScriptDocumentInfo scriptDocumentInfo, + MetadataDisplayInfoProvider metadataDisplayInfoProvider) + { + AutoCompletionResult result = new AutoCompletionResult(); + IEnumerable suggestions = SqlParserWrapper.FindCompletions( + scriptParseInfo.ParseResult, + scriptDocumentInfo.ParserLine, + scriptDocumentInfo.ParserColumn, + metadataDisplayInfoProvider); + + // get the completion list from SQL Parser + scriptParseInfo.CurrentSuggestions = suggestions; + + // convert the suggestion list to the VS Code format + CompletionItem[] completionList = AutoCompleteHelper.ConvertDeclarationsToCompletionItems( + scriptParseInfo.CurrentSuggestions, + scriptDocumentInfo.StartLine, + scriptDocumentInfo.StartColumn, + scriptDocumentInfo.EndColumn, + scriptDocumentInfo.TokenText); + + result.CompleteResult(completionList); + + //The bucket for number of milliseconds will take to send back auto complete list + connInfo.IntellisenseMetrics.UpdateMetrics(result.Duration, 1, (k2, v2) => v2 + 1); + return result; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/SqlCompletionItem.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/SqlCompletionItem.cs new file mode 100644 index 00000000..39247f34 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/SqlCompletionItem.cs @@ -0,0 +1,206 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Globalization; +using System.Text.RegularExpressions; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion +{ + /// + /// Creates a completion item from SQL parser declaration item + /// + public class SqlCompletionItem + { + private static Regex ValidSqlNameRegex = new Regex(@"^[\p{L}_@][\p{L}\p{N}@$#_]{0,127}$"); + + /// + /// Create new instance given the SQL parser declaration + /// + public SqlCompletionItem(Declaration declaration, string tokenText) : + this(declaration == null ? null : declaration.Title, declaration == null ? DeclarationType.Table : declaration.Type, tokenText) + { + } + + /// + /// Creates new instance given declaration title and type + /// + public SqlCompletionItem(string declarationTitle, DeclarationType declarationType, string tokenText) + { + Validate.IsNotNullOrEmptyString("declarationTitle", declarationTitle); + + DeclarationTitle = declarationTitle; + DeclarationType = declarationType; + TokenText = tokenText; + + Init(); + } + + private void Init() + { + InsertText = GetCompletionItemInsertName(); + Label = DeclarationTitle; + if (StartsWithBracket(TokenText)) + { + Label = WithBracket(Label); + InsertText = WithBracket(InsertText); + } + Detail = Label; + Kind = CreateCompletionItemKind(); + } + + private CompletionItemKind CreateCompletionItemKind() + { + CompletionItemKind kind = CompletionItemKind.Variable; + switch (DeclarationType) + { + case DeclarationType.Schema: + kind = CompletionItemKind.Module; + break; + case DeclarationType.Column: + kind = CompletionItemKind.Field; + break; + case DeclarationType.Table: + case DeclarationType.View: + kind = CompletionItemKind.File; + break; + case DeclarationType.Database: + kind = CompletionItemKind.Method; + break; + case DeclarationType.ScalarValuedFunction: + case DeclarationType.TableValuedFunction: + case DeclarationType.BuiltInFunction: + kind = CompletionItemKind.Value; + break; + default: + kind = CompletionItemKind.Unit; + break; + } + + return kind; + } + + /// + /// Declaration Title + /// + public string DeclarationTitle { get; private set; } + + /// + /// Token text from the editor + /// + public string TokenText { get; private set; } + + /// + /// SQL declaration type + /// + public DeclarationType DeclarationType { get; private set; } + + /// + /// Completion item label + /// + public string Label { get; private set; } + + /// + /// Completion item kind + /// + public CompletionItemKind Kind { get; private set; } + + /// + /// Completion insert text + /// + public string InsertText { get; private set; } + + /// + /// Completion item detail + /// + public string Detail { get; private set; } + + /// + /// Creates a completion item given the editor info + /// + public CompletionItem CreateCompletionItem( + int row, + int startColumn, + int endColumn) + { + return CreateCompletionItem(Label, Detail, InsertText, Kind, row, startColumn, endColumn); + } + + /// + /// Creates a completion item + /// + public static CompletionItem CreateCompletionItem( + string label, + string detail, + string insertText, + CompletionItemKind kind, + int row, + int startColumn, + int endColumn) + { + CompletionItem item = new CompletionItem() + { + Label = label, + Kind = kind, + Detail = detail, + InsertText = insertText, + TextEdit = new TextEdit + { + NewText = insertText, + Range = new Range + { + Start = new Position + { + Line = row, + Character = startColumn + }, + End = new Position + { + Line = row, + Character = endColumn + } + } + } + }; + + return item; + } + + private string GetCompletionItemInsertName() + { + string insertText = DeclarationTitle; + if (!string.IsNullOrEmpty(DeclarationTitle) && !ValidSqlNameRegex.IsMatch(DeclarationTitle)) + { + insertText = WithBracket(DeclarationTitle); + } + return insertText; + } + + private bool HasBrackets(string text) + { + return text != null && text.StartsWith("[") && text.EndsWith("]"); + } + + private bool StartsWithBracket(string text) + { + return text != null && text.StartsWith("["); + } + + private string WithBracket(string text) + { + if (!HasBrackets(text)) + { + return string.Format(CultureInfo.InvariantCulture, "[{0}]", text); + } + else + { + return text; + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/SqlParserWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/SqlParserWrapper.cs new file mode 100644 index 00000000..7e949aa9 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Completion/SqlParserWrapper.cs @@ -0,0 +1,34 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using Microsoft.SqlServer.Management.SqlParser.Parser; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion +{ + /// + /// SqlParserWrapper interface + /// + public interface ISqlParserWrapper + { + IEnumerable FindCompletions(ParseResult parseResult, int line, int col, IMetadataDisplayInfoProvider displayInfoProvider); + } + + /// + /// A wrapper class around SQL parser methods to make the operations testable + /// + public class SqlParserWrapper : ISqlParserWrapper + { + /// + /// Creates completion list given SQL script info + /// + public IEnumerable FindCompletions(ParseResult parseResult, int line, int col, IMetadataDisplayInfoProvider displayInfoProvider) + { + return Resolver.FindCompletions(parseResult, line, col, displayInfoProvider); + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs index 965d94d2..bd767a7a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs @@ -61,6 +61,11 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices // lookup the current binding context string connectionKey = GetConnectionContextKey(connInfo); + if (BindingContextExists(connectionKey)) + { + // no need to populate the context again since the context already exists + return connectionKey; + } IBindingContext bindingContext = this.GetOrCreateBindingContext(connectionKey); if (bindingContext.BindingLock.WaitOne()) diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/IntelliSenseReady.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/IntelliSenseReady.cs new file mode 100644 index 00000000..4b61e4b7 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/IntelliSenseReady.cs @@ -0,0 +1,30 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts +{ + /// + /// Parameters sent back with an IntelliSense ready event + /// + public class IntelliSenseReadyParams + { + /// + /// URI identifying the text document + /// + public string OwnerUri { get; set; } + } + + /// + /// Event sent when the language service is finished updating after a connection + /// + public class IntelliSenseReadyNotification + { + public static readonly + EventType Type = + EventType.Create("textDocument/intelliSenseReady"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/TelemetryNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/TelemetryNotification.cs new file mode 100644 index 00000000..00bc8fe0 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/TelemetryNotification.cs @@ -0,0 +1,59 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts +{ + public class TelemetryProperties + { + public string EventName { get; set; } + + /// + /// Telemetry properties + /// + public Dictionary Properties { get; set; } + + /// + /// Telemetry measures + /// + public Dictionary Measures { get; set; } + } + + /// + /// Parameters sent back with an IntelliSense ready event + /// + public class TelemetryParams + { + public TelemetryProperties Params { get; set; } + } + + /// + /// Event sent when the language service needs to add a telemetry event + /// + public class TelemetryNotification + { + public static readonly + EventType Type = + EventType.Create("telemetry/sqlevent"); + } + + /// + /// List of telemetry events + /// + public static class TelemetryEventNames + { + /// + /// telemetry event name for auto complete response time + /// + public const string IntellisenseQuantile = "IntellisenseQuantile"; + + /// + /// telemetry even name for when definition is requested + /// + public const string PeekDefinitionRequested = "PeekDefinitionRequested"; + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/InteractionMetrics.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/InteractionMetrics.cs new file mode 100644 index 00000000..db04157f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/InteractionMetrics.cs @@ -0,0 +1,98 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer +{ + /// + /// A class to calculate the value for the metrics using the given bucket + /// + public class InteractionMetrics + { + /// + /// Creates new instance given a bucket of metrics + /// + public InteractionMetrics(int[] metrics) + { + Validate.IsNotNull("metrics", metrics); + if(metrics.Length == 0) + { + throw new ArgumentOutOfRangeException("metrics"); + } + + Counters = new ConcurrentDictionary(); + if (!IsSorted(metrics)) + { + Array.Sort(metrics); + } + Metrics = metrics; + } + + private ConcurrentDictionary Counters { get; } + + private object perfCountersLock = new object(); + + /// + /// The metrics bucket + /// + public int[] Metrics { get; private set; } + + /// + /// Returns true if the given list is sorted + /// + private bool IsSorted(int[] metrics) + { + if (metrics.Length > 1) + { + int previous = metrics[0]; + for (int i = 1; i < metrics.Length; i++) + { + if(metrics[i] < previous) + { + return false; + } + previous = metrics[i]; + } + } + return true; + } + + /// + /// Update metric value given new number + /// + public void UpdateMetrics(double duration, T newValue, Func updateValueFactory) + { + int metric = Metrics[Metrics.Length - 1]; + for (int i = 0; i < Metrics.Length; i++) + { + if (duration <= Metrics[i]) + { + metric = Metrics[i]; + break; + } + } + string key = metric.ToString(); + Counters.AddOrUpdate(key, newValue, updateValueFactory); + } + + /// + /// Returns the quantile + /// + public Dictionary Quantile + { + get + { + return Counters.ToDictionary(kvp => kvp.Key, kvp => kvp.Value); + } + } + } +} + diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index e03e748b..c5f3ae3b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -18,6 +18,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion; using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Utility; @@ -33,6 +34,8 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// public sealed class LanguageService { + private const int OneSecond = 1000; + internal const string DefaultBatchSeperator = "GO"; internal const int DiagnosticParseDelay = 750; @@ -41,7 +44,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices internal const int BindingTimeout = 500; - internal const int OnConnectionWaitTimeout = 300000; + internal const int OnConnectionWaitTimeout = 300 * OneSecond; + + internal const int PeekDefinitionTimeout = 10 * OneSecond; private static ConnectionService connectionService = null; @@ -196,13 +201,16 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices public void InitializeService(ServiceHost serviceHost, SqlToolsContext context) { // Register the requests that this service will handle - serviceHost.SetRequestHandler(DefinitionRequest.Type, HandleDefinitionRequest); - serviceHost.SetRequestHandler(ReferencesRequest.Type, HandleReferencesRequest); - serviceHost.SetRequestHandler(CompletionResolveRequest.Type, HandleCompletionResolveRequest); + + // turn off until needed (10/28/2016) + // serviceHost.SetRequestHandler(ReferencesRequest.Type, HandleReferencesRequest); + // serviceHost.SetRequestHandler(DocumentHighlightRequest.Type, HandleDocumentHighlightRequest); + serviceHost.SetRequestHandler(SignatureHelpRequest.Type, HandleSignatureHelpRequest); - serviceHost.SetRequestHandler(DocumentHighlightRequest.Type, HandleDocumentHighlightRequest); + serviceHost.SetRequestHandler(CompletionResolveRequest.Type, HandleCompletionResolveRequest); serviceHost.SetRequestHandler(HoverRequest.Type, HandleHoverRequest); serviceHost.SetRequestHandler(CompletionRequest.Type, HandleCompletionRequest); + serviceHost.SetRequestHandler(DefinitionRequest.Type, HandleDefinitionRequest); // Register a no-op shutdown task for validation of the shutdown logic serviceHost.RegisterShutdownTask(async (shutdownParams, shutdownRequestContext) => @@ -290,13 +298,34 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } - private static async Task HandleDefinitionRequest( - TextDocumentPosition textDocumentPosition, - RequestContext requestContext) + internal static async Task HandleDefinitionRequest(TextDocumentPosition textDocumentPosition, RequestContext requestContext) { - await Task.FromResult(true); + if (WorkspaceService.Instance.CurrentSettings.IsIntelliSenseEnabled) + { + // Retrieve document and connection + ConnectionInfo connInfo; + var scriptFile = LanguageService.WorkspaceServiceInstance.Workspace.GetFile(textDocumentPosition.TextDocument.Uri); + LanguageService.ConnectionServiceInstance.TryFindConnection(scriptFile.ClientFilePath, out connInfo); + + Location[] locations = LanguageService.Instance.GetDefinition(textDocumentPosition, scriptFile, connInfo); + if (locations != null) + { + await requestContext.SendResult(locations); + + // Send a notification to signal that definition is sent + await ServiceHost.Instance.SendEvent(TelemetryNotification.Type, new TelemetryParams() + { + Params = new TelemetryProperties + { + EventName = TelemetryEventNames.PeekDefinitionRequested + } + }); + } + } } +// turn off this code until needed (10/28/2016) +#if false private static async Task HandleReferencesRequest( ReferencesParams referencesParams, RequestContext requestContext) @@ -304,19 +333,39 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices await Task.FromResult(true); } - private static async Task HandleSignatureHelpRequest( - TextDocumentPosition textDocumentPosition, - RequestContext requestContext) - { - await Task.FromResult(true); - } - private static async Task HandleDocumentHighlightRequest( TextDocumentPosition textDocumentPosition, RequestContext requestContext) { await Task.FromResult(true); } +#endif + + private static async Task HandleSignatureHelpRequest( + TextDocumentPosition textDocumentPosition, + RequestContext requestContext) + { + // check if Intellisense suggestions are enabled + if (!WorkspaceService.Instance.CurrentSettings.IsSuggestionsEnabled) + { + await Task.FromResult(true); + } + else + { + ScriptFile scriptFile = WorkspaceService.Instance.Workspace.GetFile( + textDocumentPosition.TextDocument.Uri); + + SignatureHelp help = LanguageService.Instance.GetSignatureHelp(textDocumentPosition, scriptFile); + if (help != null) + { + await requestContext.SendResult(help); + } + else + { + await requestContext.SendResult(new SignatureHelp()); + } + } + } private static async Task HandleHoverRequest( TextDocumentPosition textDocumentPosition, @@ -361,7 +410,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices eventContext); } - await Task.FromResult(true); + await Task.FromResult(true); } /// @@ -555,7 +604,10 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } - AutoCompleteHelper.PrepopulateCommonMetadata(info, scriptInfo, this.BindingQueue); + AutoCompleteHelper.PrepopulateCommonMetadata(info, scriptInfo, this.BindingQueue); + + // Send a notification to signal that autocomplete is ready + ServiceHost.Instance.SendEvent(IntelliSenseReadyNotification.Type, new IntelliSenseReadyParams() {OwnerUri = info.OwnerUri}); }); } @@ -626,6 +678,106 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices return completionItem; } + /// + /// Get definition for a selected sql object using SMO Scripting + /// + /// + /// + /// + /// Location with the URI of the script file + internal Location[] GetDefinition(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile, ConnectionInfo connInfo) + { + // Parse sql + ScriptParseInfo scriptParseInfo = GetScriptParseInfo(textDocumentPosition.TextDocument.Uri); + if (scriptParseInfo == null) + { + return null; + } + + if (RequiresReparse(scriptParseInfo, scriptFile)) + { + scriptParseInfo.ParseResult = ParseAndBind(scriptFile, connInfo); + } + + // Get token from selected text + Token selectedToken = ScriptDocumentInfo.GetToken(scriptParseInfo, textDocumentPosition.Position.Line + 1, textDocumentPosition.Position.Character); + if (selectedToken == null) + { + return null; + } + // Strip "[" and "]"(if present) from the token text to enable matching with the suggestions. + // The suggestion title does not contain any sql punctuation + string tokenText = TextUtilities.RemoveSquareBracketSyntax(selectedToken.Text); + + if (scriptParseInfo.IsConnected && Monitor.TryEnter(scriptParseInfo.BuildingMetadataLock)) + { + try + { + // Queue the task with the binding queue + QueueItem queueItem = this.BindingQueue.QueueBindingOperation( + key: scriptParseInfo.ConnectionKey, + bindingTimeout: LanguageService.PeekDefinitionTimeout, + bindOperation: (bindingContext, cancelToken) => + { + // Get suggestions for the token + int parserLine = textDocumentPosition.Position.Line + 1; + int parserColumn = textDocumentPosition.Position.Character + 1; + IEnumerable declarationItems = Resolver.FindCompletions( + scriptParseInfo.ParseResult, + parserLine, parserColumn, + bindingContext.MetadataDisplayInfoProvider); + + // Match token with the suggestions(declaration items) returned + string schemaName = this.GetSchemaName(scriptParseInfo, textDocumentPosition.Position, scriptFile); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + return peekDefinition.GetScript(declarationItems, tokenText, schemaName); + + + }); + + // wait for the queue item + queueItem.ItemProcessed.WaitOne(); + return queueItem.GetResultAsT(); + } + finally + { + Monitor.Exit(scriptParseInfo.BuildingMetadataLock); + } + } + + return null; + } + + /// + /// Extract schema name for a token, if present + /// + /// + /// + /// + /// schema nama + private string GetSchemaName(ScriptParseInfo scriptParseInfo, Position position, ScriptFile scriptFile) + { + // Offset index by 1 for sql parser + int startLine = position.Line + 1; + int startColumn = position.Character + 1; + + // Get schema name + if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null) + { + var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); + var prevTokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.GetPreviousSignificantTokenIndex(tokenIndex); + var prevTokenText = scriptParseInfo.ParseResult.Script.TokenManager.GetText(prevTokenIndex); + if (prevTokenText != null && prevTokenText.Equals(".")) + { + var schemaTokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.GetPreviousSignificantTokenIndex(prevTokenIndex); + Token schemaToken = scriptParseInfo.ParseResult.Script.TokenManager.GetToken(schemaTokenIndex); + return TextUtilities.RemoveSquareBracketSyntax(schemaToken.Text); + } + } + // if no schema name, returns null + return null; + } + /// /// Get quick info hover tooltips for the current position /// @@ -682,171 +834,137 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } /// - /// Return the completion item list for the current text position. - /// This method does not await cache builds since it expects to return quickly + /// Get function signature help for the current position /// - /// - public CompletionItem[] GetCompletionItems( - TextDocumentPosition textDocumentPosition, - ScriptFile scriptFile, - ConnectionInfo connInfo) + internal SignatureHelp GetSignatureHelp(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile) { - // initialize some state to parse and bind the current script file - this.currentCompletionParseInfo = null; - CompletionItem[] resultCompletionItems = null; - string filePath = textDocumentPosition.TextDocument.Uri; int startLine = textDocumentPosition.Position.Line; - int parserLine = textDocumentPosition.Position.Line + 1; int startColumn = TextUtilities.PositionOfPrevDelimeter( scriptFile.Contents, textDocumentPosition.Position.Line, textDocumentPosition.Position.Character); - int endColumn = TextUtilities.PositionOfNextDelimeter( - scriptFile.Contents, - textDocumentPosition.Position.Line, - textDocumentPosition.Position.Character); - int parserColumn = textDocumentPosition.Position.Character + 1; - bool useLowerCaseSuggestions = this.CurrentSettings.SqlTools.IntelliSense.LowerCaseSuggestions.Value; + int endColumn = textDocumentPosition.Position.Character; - // get the current script parse info object ScriptParseInfo scriptParseInfo = GetScriptParseInfo(textDocumentPosition.TextDocument.Uri); + if (scriptParseInfo == null) { - return AutoCompleteHelper.GetDefaultCompletionItems( - startLine, - startColumn, - endColumn, - useLowerCaseSuggestions); + // Cache not set up yet - skip and wait until later + return null; } + ConnectionInfo connInfo; + LanguageService.ConnectionServiceInstance.TryFindConnection( + scriptFile.ClientFilePath, + out connInfo); + // reparse and bind the SQL statement if needed if (RequiresReparse(scriptParseInfo, scriptFile)) { ParseAndBind(scriptFile, connInfo); } - // if the parse failed then return the default list - if (scriptParseInfo.ParseResult == null) + if (scriptParseInfo.ParseResult != null) { - return AutoCompleteHelper.GetDefaultCompletionItems( - startLine, - startColumn, - endColumn, - useLowerCaseSuggestions); - } - - // need to adjust line & column for base-1 parser indices - Token token = GetToken(scriptParseInfo, parserLine, parserColumn); - string tokenText = token != null ? token.Text : null; - - // check if the file is connected and the file lock is available - if (scriptParseInfo.IsConnected && Monitor.TryEnter(scriptParseInfo.BuildingMetadataLock)) - { - try - { - // queue the completion task with the binding queue - QueueItem queueItem = this.BindingQueue.QueueBindingOperation( - key: scriptParseInfo.ConnectionKey, - bindingTimeout: LanguageService.BindingTimeout, - bindOperation: (bindingContext, cancelToken) => - { - // get the completion list from SQL Parser - scriptParseInfo.CurrentSuggestions = Resolver.FindCompletions( - scriptParseInfo.ParseResult, - parserLine, - parserColumn, - bindingContext.MetadataDisplayInfoProvider); - - // cache the current script parse info object to resolve completions later - this.currentCompletionParseInfo = scriptParseInfo; - - // convert the suggestion list to the VS Code format - return AutoCompleteHelper.ConvertDeclarationsToCompletionItems( - scriptParseInfo.CurrentSuggestions, - startLine, - startColumn, - endColumn); - }, - timeoutOperation: (bindingContext) => - { - // return the default list if the connected bind fails - return AutoCompleteHelper.GetDefaultCompletionItems( - startLine, - startColumn, - endColumn, - useLowerCaseSuggestions, - tokenText); - }); - - // wait for the queue item - queueItem.ItemProcessed.WaitOne(); - - var completionItems = queueItem.GetResultAsT(); - if (completionItems != null && completionItems.Length > 0) - { - resultCompletionItems = completionItems; - } - else if (!ShouldShowCompletionList(token)) - { - resultCompletionItems = AutoCompleteHelper.EmptyCompletionList; - } - } - finally + if (Monitor.TryEnter(scriptParseInfo.BuildingMetadataLock)) { - Monitor.Exit(scriptParseInfo.BuildingMetadataLock); - } - } - - // if there are no completions then provide the default list - if (resultCompletionItems == null) - { - resultCompletionItems = AutoCompleteHelper.GetDefaultCompletionItems( - startLine, - startColumn, - endColumn, - useLowerCaseSuggestions, - tokenText); - } - - return resultCompletionItems; - } - - private static Token GetToken(ScriptParseInfo scriptParseInfo, int startLine, int startColumn) - { - if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null) - { - var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); - if (tokenIndex >= 0) - { - // return the current token - int currentIndex = 0; - foreach (var token in scriptParseInfo.ParseResult.Script.Tokens) + try { - if (currentIndex == tokenIndex) - { - return token; - } - ++currentIndex; + QueueItem queueItem = this.BindingQueue.QueueBindingOperation( + key: scriptParseInfo.ConnectionKey, + bindingTimeout: LanguageService.BindingTimeout, + bindOperation: (bindingContext, cancelToken) => + { + // get the list of possible current methods for signature help + var methods = Resolver.FindMethods( + scriptParseInfo.ParseResult, + startLine + 1, + endColumn + 1, + bindingContext.MetadataDisplayInfoProvider); + + // get positional information on the current method + var methodLocations = Resolver.GetMethodNameAndParams(scriptParseInfo.ParseResult, + startLine + 1, + endColumn + 1, + bindingContext.MetadataDisplayInfoProvider); + + if (methodLocations != null) + { + // convert from the parser format to the VS Code wire format + return AutoCompleteHelper.ConvertMethodHelpTextListToSignatureHelp(methods, + methodLocations, + startLine + 1, + endColumn + 1); + } + else + { + return null; + } + }); + + queueItem.ItemProcessed.WaitOne(); + return queueItem.GetResultAsT(); } + finally + { + Monitor.Exit(scriptParseInfo.BuildingMetadataLock); + } } } + + // return null if there isn't a tooltip for the current location return null; } - private static bool ShouldShowCompletionList(Token token) + /// + /// Return the completion item list for the current text position. + /// This method does not await cache builds since it expects to return quickly + /// + /// + public CompletionItem[] GetCompletionItems( + TextDocumentPosition textDocumentPosition, + ScriptFile scriptFile, + ConnectionInfo connInfo) { - bool result = true; - if (token != null) + // initialize some state to parse and bind the current script file + this.currentCompletionParseInfo = null; + CompletionItem[] resultCompletionItems = null; + CompletionService completionService = new CompletionService(BindingQueue); + bool useLowerCaseSuggestions = this.CurrentSettings.SqlTools.IntelliSense.LowerCaseSuggestions.Value; + + // get the current script parse info object + ScriptParseInfo scriptParseInfo = GetScriptParseInfo(textDocumentPosition.TextDocument.Uri); + + if (scriptParseInfo == null) { - switch (token.Id) - { - case (int)Tokens.LEX_MULTILINE_COMMENT: - case (int)Tokens.LEX_END_OF_LINE_COMMENT: - result = false; - break; - } + return AutoCompleteHelper.GetDefaultCompletionItems(ScriptDocumentInfo.CreateDefaultDocumentInfo(textDocumentPosition, scriptFile), useLowerCaseSuggestions); } - return result; + + ScriptDocumentInfo scriptDocumentInfo = new ScriptDocumentInfo(textDocumentPosition, scriptFile, scriptParseInfo); + + // reparse and bind the SQL statement if needed + if (RequiresReparse(scriptParseInfo, scriptFile)) + { + ParseAndBind(scriptFile, connInfo); + } + + // if the parse failed then return the default list + if (scriptParseInfo.ParseResult == null) + { + return AutoCompleteHelper.GetDefaultCompletionItems(scriptDocumentInfo, useLowerCaseSuggestions); + } + AutoCompletionResult result = completionService.CreateCompletions(connInfo, scriptDocumentInfo, useLowerCaseSuggestions); + // cache the current script parse info object to resolve completions later + this.currentCompletionParseInfo = scriptParseInfo; + resultCompletionItems = result.CompletionItems; + + // if there are no completions then provide the default list + if (resultCompletionItems == null) + { + resultCompletionItems = AutoCompleteHelper.GetDefaultCompletionItems(scriptDocumentInfo, useLowerCaseSuggestions); + } + + return resultCompletionItems; } #endregion @@ -868,23 +986,26 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices // build a list of SQL script file markers from the errors List markers = new List(); - foreach (var error in parseResult.Errors) + if (parseResult != null && parseResult.Errors != null) { - markers.Add(new ScriptFileMarker() + foreach (var error in parseResult.Errors) { - Message = error.Message, - Level = ScriptFileMarkerLevel.Error, - ScriptRegion = new ScriptRegion() + markers.Add(new ScriptFileMarker() { - File = scriptFile.FilePath, - StartLineNumber = error.Start.LineNumber, - StartColumnNumber = error.Start.ColumnNumber, - StartOffset = 0, - EndLineNumber = error.End.LineNumber, - EndColumnNumber = error.End.ColumnNumber, - EndOffset = 0 - } - }); + Message = error.Message, + Level = ScriptFileMarkerLevel.Error, + ScriptRegion = new ScriptRegion() + { + File = scriptFile.FilePath, + StartLineNumber = error.Start.LineNumber, + StartColumnNumber = error.Start.ColumnNumber, + StartOffset = 0, + EndLineNumber = error.End.LineNumber, + EndColumnNumber = error.End.ColumnNumber, + EndOffset = 0 + } + }); + } } return markers.ToArray(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs new file mode 100644 index 00000000..07293ba6 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs @@ -0,0 +1,279 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.IO; +using System.Collections.Generic; +using System.Collections.Specialized; +using System.Data.SqlClient; +using System.Runtime.InteropServices; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices +{ + /// + /// Peek Definition/ Go to definition implementation + /// Script sql objects and write create scripts to file + /// + internal class PeekDefinition + { + private ConnectionInfo connectionInfo; + private string tempPath; + + internal delegate StringCollection ScriptGetter(string objectName, string schemaName); + + // Dictionary that holds the script getter for each type + private Dictionary sqlScriptGetters = + new Dictionary(); + + // Dictionary that holds the object name (as appears on the TSQL create statement) + private Dictionary sqlObjectTypes = new Dictionary(); + + private Database Database + { + get + { + if (this.connectionInfo.SqlConnection != null) + { + try + { + // Get server object from connection + string connectionString = ConnectionService.BuildConnectionString(this.connectionInfo.ConnectionDetails); + SqlConnection sqlConn = new SqlConnection(connectionString); + sqlConn.Open(); + ServerConnection serverConn = new ServerConnection(sqlConn); + Server server = new Server(serverConn); + return server.Databases[this.connectionInfo.SqlConnection.Database]; + } + catch(Exception ex) + { + Logger.Write(LogLevel.Error, "Exception at PeekDefinition Database.get() : " + ex.Message); + return null; + } + } + return null; + } + } + + internal PeekDefinition(ConnectionInfo connInfo) + { + this.connectionInfo = connInfo; + DirectoryInfo tempScriptDirectory = Directory.CreateDirectory(Path.GetTempPath() + "mssql_definition"); + this.tempPath = tempScriptDirectory.FullName; + Initialize(); + } + + /// + /// Add getters for each sql object supported by peek definition + /// + private void Initialize() + { + //Add script getters for each sql object + + //Add tables to supported types + AddSupportedType(DeclarationType.Table, GetTableScripts, "Table"); + + //Add views to supported types + AddSupportedType(DeclarationType.View, GetViewScripts, "view"); + + //Add stored procedures to supported types + AddSupportedType(DeclarationType.StoredProcedure, GetStoredProcedureScripts, "Procedure"); + } + + /// + /// Add the given type, scriptgetter and the typeName string to the respective dictionaries + /// + private void AddSupportedType(DeclarationType type, ScriptGetter scriptGetter, string typeName) + { + sqlScriptGetters.Add(type, scriptGetter); + sqlObjectTypes.Add(type, typeName); + + } + + /// + /// Convert a file to a location array containing a location object as expected by the extension + /// + internal Location[] GetLocationFromFile(string tempFileName, int lineNumber) + { + if (Path.DirectorySeparatorChar.Equals('/')) + { + tempFileName = "file:" + tempFileName; + } + else + { + tempFileName = new Uri(tempFileName).AbsoluteUri; + } + Location[] locations = new[] { + new Location { + Uri = tempFileName, + Range = new Range { + Start = new Position { Line = lineNumber, Character = 1}, + End = new Position { Line = lineNumber + 1, Character = 1} + } + } + }; + return locations; + } + + /// + /// Get line number for the create statement + /// + private int GetStartOfCreate(string script, string createString) + { + string[] lines = script.Split(new string[] { Environment.NewLine }, StringSplitOptions.None); + for (int lineNumber = 0; lineNumber < lines.Length; lineNumber++) + { + if (lines[lineNumber].IndexOf( createString, StringComparison.OrdinalIgnoreCase) >= 0) + { + return lineNumber; + } + } + return 0; + } + + /// + /// Get the script of the selected token based on the type of the token + /// + /// + /// + /// + /// Location object of the script file + internal Location[] GetScript(IEnumerable declarationItems, string tokenText, string schemaName) + { + foreach (Declaration declarationItem in declarationItems) + { + if (declarationItem.Title == null) + { + continue; + } + + if (declarationItem.Title.Equals(tokenText)) + { + // Script object using SMO based on type + DeclarationType type = declarationItem.Type; + if (sqlScriptGetters.ContainsKey(type) && sqlObjectTypes.ContainsKey(type)) + { + // On *nix and mac systems, the defaultSchema property throws an Exception when accessed. + // This workaround ensures that a schema name is present by attempting + // to get the schema name from the declaration item + // If all fails, the default schema name is assumed to be "dbo" + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && string.IsNullOrEmpty(schemaName)) + { + string fullObjectName = declarationItem.DatabaseQualifiedName; + schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText); + } + return GetSqlObjectDefinition( + sqlScriptGetters[type], + tokenText, + schemaName, + sqlObjectTypes[type] + ); + } + return null; + } + } + return null; + } + + /// + /// Return schema name from the full name of the database. If schema is missing return dbo as schema name. + /// + /// The full database qualified name(database.schema.object) + /// Object name + /// Schema name + internal string GetSchemaFromDatabaseQualifiedName(string fullObjectName, string objectName) + { + string[] tokens = fullObjectName.Split('.'); + for (int i = tokens.Length - 1; i > 0; i--) + { + if(tokens[i].Equals(objectName)) + { + return tokens[i-1]; + } + } + return "dbo"; + } + + /// + /// Script a table using SMO + /// + /// Table name + /// Schema name + /// String collection of scripts + internal StringCollection GetTableScripts(string tableName, string schemaName) + { + return (schemaName != null) ? Database?.Tables[tableName, schemaName]?.Script() + : Database?.Tables[tableName]?.Script(); + } + + /// + /// Script a view using SMO + /// + /// View name + /// Schema name + /// String collection of scripts + internal StringCollection GetViewScripts(string viewName, string schemaName) + { + return (schemaName != null) ? Database?.Views[viewName, schemaName]?.Script() + : Database?.Views[viewName]?.Script(); + } + + /// + /// Script a stored procedure using SMO + /// + /// Stored Procedure name + /// Schema Name + /// String collection of scripts + internal StringCollection GetStoredProcedureScripts(string viewName, string schemaName) + { + return (schemaName != null) ? Database?.StoredProcedures[viewName, schemaName]?.Script() + : Database?.StoredProcedures[viewName]?.Script(); + } + + /// + /// Script a object using SMO and write to a file. + /// + /// Function that returns the SMO scripts for an object + /// SQL object name + /// Schema name or null + /// Type of SQL object + /// Location object representing URI and range of the script file + internal Location[] GetSqlObjectDefinition( + ScriptGetter sqlScriptGetter, + string objectName, + string schemaName, + string objectType) + { + StringCollection scripts = sqlScriptGetter(objectName, schemaName); + string tempFileName = (schemaName != null) ? Path.Combine(this.tempPath, string.Format("{0}.{1}.sql", schemaName, objectName)) + : Path.Combine(this.tempPath, string.Format("{0}.sql", objectName)); + + if (scripts != null) + { + int lineNumber = 0; + using (StreamWriter scriptFile = new StreamWriter(File.Open(tempFileName, FileMode.Create, FileAccess.ReadWrite))) + { + + foreach (string script in scripts) + { + string createSyntax = string.Format("CREATE {0}", objectType); + if (script.IndexOf(createSyntax, StringComparison.OrdinalIgnoreCase) >= 0) + { + scriptFile.WriteLine(script); + lineNumber = GetStartOfCreate(script, createSyntax); + } + } + } + return GetLocationFromFile(tempFileName, lineNumber); + } + + return null; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptDocumentInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptDocumentInfo.cs new file mode 100644 index 00000000..fd15ad82 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptDocumentInfo.cs @@ -0,0 +1,133 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlServer.Management.SqlParser.Parser; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion +{ + /// + /// A class to calculate the numbers used by SQL parser using the text positions and content + /// + internal class ScriptDocumentInfo + { + /// + /// Create new instance + /// + public ScriptDocumentInfo(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile, ScriptParseInfo scriptParseInfo) + : this(textDocumentPosition, scriptFile) + { + Validate.IsNotNull(nameof(scriptParseInfo), scriptParseInfo); + + ScriptParseInfo = scriptParseInfo; + // need to adjust line & column for base-1 parser indices + Token = GetToken(scriptParseInfo, ParserLine, ParserColumn); + } + + private ScriptDocumentInfo(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile) + { + StartLine = textDocumentPosition.Position.Line; + ParserLine = textDocumentPosition.Position.Line + 1; + StartColumn = TextUtilities.PositionOfPrevDelimeter( + scriptFile.Contents, + textDocumentPosition.Position.Line, + textDocumentPosition.Position.Character); + EndColumn = TextUtilities.PositionOfNextDelimeter( + scriptFile.Contents, + textDocumentPosition.Position.Line, + textDocumentPosition.Position.Character); + ParserColumn = textDocumentPosition.Position.Character + 1; + Contents = scriptFile.Contents; + } + + /// + /// Creates a new with no backing defined + /// + /// A + /// A to process + /// + public static ScriptDocumentInfo CreateDefaultDocumentInfo(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile) + { + return new ScriptDocumentInfo(textDocumentPosition, scriptFile); + } + + /// + /// Gets a string containing the full contents of the file. + /// + public string Contents { get; private set; } + + /// + /// Script Parse Info Instance + /// + public ScriptParseInfo ScriptParseInfo { get; private set; } + + /// + /// Start Line + /// + public int StartLine { get; private set; } + + /// + /// Parser Line + /// + public int ParserLine { get; private set; } + + /// + /// Start Column + /// + public int StartColumn { get; private set; } + + /// + /// end Column + /// + public int EndColumn { get; private set; } + + /// + /// Parser Column + /// + public int ParserColumn { get; private set; } + + /// + /// The token text in the file content used for completion list + /// + public string TokenText + { + get + { + return Token != null ? Token.Text : null; + } + } + + /// + /// The token in the file content used for completion list + /// + public Token Token { get; private set; } + + /// + /// Returns the token that will be used by SQL parser for creating the completion list + /// + internal static Token GetToken(ScriptParseInfo scriptParseInfo, int startLine, int startColumn) + { + if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null) + { + var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); + if (tokenIndex >= 0) + { + // return the current token + int currentIndex = 0; + foreach (var token in scriptParseInfo.ParseResult.Script.Tokens) + { + if (currentIndex == tokenIndex) + { + return token; + } + ++currentIndex; + } + } + } + return null; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs index 7b54484e..7336cec2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs @@ -57,21 +57,50 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #endregion - internal Batch(string batchText, int startLine, int startColumn, int endLine, int endColumn, IFileStreamFactory outputFileFactory) + internal Batch(string batchText, SelectionData selection, int ordinalId, IFileStreamFactory outputFileFactory) { // Sanity check for input Validate.IsNotNullOrEmptyString(nameof(batchText), batchText); Validate.IsNotNull(nameof(outputFileFactory), outputFileFactory); + Validate.IsGreaterThan(nameof(ordinalId), ordinalId, 0); // Initialize the internal state BatchText = batchText; - Selection = new SelectionData(startLine, startColumn, endLine, endColumn); + Selection = selection; + executionStartTime = DateTime.Now; HasExecuted = false; + Id = ordinalId; resultSets = new List(); resultMessages = new List(); this.outputFileFactory = outputFileFactory; } + #region Events + + /// + /// Asynchronous handler for when batches are completed + /// + /// The batch that completed + public delegate Task BatchAsyncEventHandler(Batch batch); + + /// + /// Event that will be called when the batch has completed execution + /// + public event BatchAsyncEventHandler BatchCompletion; + + /// + /// Event to call when the batch has started execution + /// + public event BatchAsyncEventHandler BatchStart; + + /// + /// Event that will be called when the resultset has completed execution. It will not be + /// called from the Batch but from the ResultSet instance + /// + public event ResultSet.ResultSetAsyncEventHandler ResultSetCompletion; + + #endregion + #region Properties /// @@ -113,6 +142,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public bool HasExecuted { get; set; } + /// + /// Ordinal of the batch in the query + /// + public int Id { get; private set; } + /// /// Messages that have come back from the server /// @@ -136,12 +170,39 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { get { - return ResultSets.Select((set, index) => new ResultSetSummary() + lock (resultSets) { - ColumnInfo = set.Columns, - Id = index, - RowCount = set.RowCount - }).ToArray(); + return resultSets.Select(set => set.Summary).ToArray(); + } + } + } + + /// + /// Creates a based on the batch instance + /// + public BatchSummary Summary + { + get + { + // Batch summary with information available at start + BatchSummary summary = new BatchSummary + { + HasError = HasError, + Id = Id, + Selection = Selection, + ExecutionStart = ExecutionStartTimeStamp + }; + + // Add on extra details if we finished executing it + if (HasExecuted) + { + summary.ResultSetSummaries = ResultSummaries; + summary.Messages = ResultMessages.ToArray(); + summary.ExecutionEnd = ExecutionEndTimeStamp; + summary.ExecutionElapsed = ExecutionElapsedTime; + } + + return summary; } } @@ -167,10 +228,18 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution throw new InvalidOperationException("Batch has already executed."); } + // Notify that we've started execution + if (BatchStart != null) + { + await BatchStart(this); + } + try { - DbCommand command = null; + // Register the message listener to *this instance* of the batch + // Note: This is being done to associate messages with batches ReliableSqlConnection sqlConn = conn as ReliableSqlConnection; + DbCommand command; if (sqlConn != null) { // Register the message listener to *this instance* of the batch @@ -179,7 +248,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution command = sqlConn.GetUnderlyingConnection().CreateCommand(); // Add a handler for when the command completes - SqlCommand sqlCommand = (SqlCommand) command; + SqlCommand sqlCommand = (SqlCommand)command; sqlCommand.StatementCompleted += StatementCompletedHandler; } else @@ -202,6 +271,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // Execute the command to get back a reader using (DbDataReader reader = await command.ExecuteReaderAsync(cancellationToken)) { + int resultSetOrdinal = 0; do { // Skip this result set if there aren't any rows (ie, UPDATE/DELETE/etc queries) @@ -211,11 +281,16 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } // This resultset has results (ie, SELECT/etc queries) - ResultSet resultSet = new ResultSet(reader, outputFileFactory); - + ResultSet resultSet = new ResultSet(reader, resultSetOrdinal, Id, outputFileFactory); + resultSet.ResultCompletion += ResultSetCompletion; + // Add the result set to the results of the query - resultSets.Add(resultSet); - + lock (resultSets) + { + resultSets.Add(resultSet); + resultSetOrdinal++; + } + // Read until we hit the end of the result set await resultSet.ReadResultToEnd(cancellationToken).ConfigureAwait(false); @@ -258,6 +333,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // Mark that we have executed HasExecuted = true; executionEndTime = DateTime.Now; + + // Fire an event to signify that the batch has completed + if (BatchCompletion != null) + { + await BatchCompletion(this); + } } } @@ -270,14 +351,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// A subset of results public Task GetSubset(int resultSetIndex, int startRow, int rowCount) { - // Sanity check to make sure we have valid numbers - if (resultSetIndex < 0 || resultSetIndex >= resultSets.Count) + ResultSet targetResultSet; + lock (resultSets) { - throw new ArgumentOutOfRangeException(nameof(resultSetIndex), SR.QueryServiceSubsetResultSetOutOfRange); + // Sanity check to make sure we have valid numbers + if (resultSetIndex < 0 || resultSetIndex >= resultSets.Count) + { + throw new ArgumentOutOfRangeException(nameof(resultSetIndex), + SR.QueryServiceSubsetResultSetOutOfRange); + } + + targetResultSet = resultSets[resultSetIndex]; } // Retrieve the result set - return resultSets[resultSetIndex].GetSubset(startRow, rowCount); + return targetResultSet.GetSubset(startRow, rowCount); } #endregion @@ -377,9 +465,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution if (disposing) { - foreach (ResultSet r in ResultSets) + lock (resultSets) { - r.Dispose(); + foreach (ResultSet r in resultSets) + { + r.Dispose(); + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs index 9e387f8c..09beae29 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs @@ -152,6 +152,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts } } + /// + /// Default constructor, used for deserializing JSON RPC only + /// + public DbColumnWrapper() + { + } + #region Properties /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteBatchNotifications.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteBatchNotifications.cs new file mode 100644 index 00000000..42877b6d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteBatchNotifications.cs @@ -0,0 +1,39 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters to be sent back as part of a QueryExecuteBatchCompleteEvent to indicate that a + /// batch of a query completed. + /// + public class QueryExecuteBatchNotificationParams + { + /// + /// Summary of the batch that just completed + /// + public BatchSummary BatchSummary { get; set; } + + /// + /// URI for the editor that owns the query + /// + public string OwnerUri { get; set; } + } + + public class QueryExecuteBatchCompleteEvent + { + public static readonly + EventType Type = + EventType.Create("query/batchComplete"); + } + + public class QueryExecuteBatchStartEvent + { + public static readonly + EventType Type = + EventType.Create("query/batchStart"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs index 6079bf51..7630b712 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs @@ -7,21 +7,6 @@ using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts { - /// - /// Container class for a selection range from file - /// - public class SelectionData { - public int StartLine { get; set; } - public int StartColumn { get; set; } - public int EndLine { get; set; } - public int EndColumn { get; set; } - public SelectionData(int startLine, int startColumn, int endLine, int endColumn) { - StartLine = startLine; - StartColumn = startColumn; - EndLine = endLine; - EndColumn = endColumn; - } - } /// /// Parameters for the query execute request /// @@ -44,7 +29,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts public class QueryExecuteResult { /// - /// Connection error messages. Optional, can be set to null to indicate no errors + /// Informational messages from the query runner. Optional, can be set to null. /// public string Messages { get; set; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteResultSetCompleteNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteResultSetCompleteNotification.cs new file mode 100644 index 00000000..c8eeb00b --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteResultSetCompleteNotification.cs @@ -0,0 +1,22 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + public class QueryExecuteResultSetCompleteParams + { + public ResultSetSummary ResultSetSummary { get; set; } + + public string OwnerUri { get; set; } + } + + public class QueryExecuteResultSetCompleteEvent + { + public static readonly + EventType Type = + EventType.Create("query/resultSetComplete"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultMessage.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultMessage.cs index 27e6713b..90f66015 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultMessage.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultMessage.cs @@ -40,5 +40,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts Time = DateTime.Now.ToString("o"); Message = message; } + + /// + /// Default constructor, used for deserializing JSON RPC only + /// + public ResultMessage() + { + } } } \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs index c8705d8b..e6fc8691 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs @@ -15,6 +15,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// public int Id { get; set; } + /// + /// The ID of the batch set within the query + /// + public int BatchId { get; set; } + /// /// The number of rows that was returned with the resultset /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SelectionData.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SelectionData.cs new file mode 100644 index 00000000..183f2b3c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SelectionData.cs @@ -0,0 +1,52 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Container class for a selection range from file + /// + /// TODO: Remove this in favor of buffer range end-to-end + public class SelectionData + { + public SelectionData() { } + + public SelectionData(int startLine, int startColumn, int endLine, int endColumn) + { + StartLine = startLine; + StartColumn = startColumn; + EndLine = endLine; + EndColumn = endColumn; + } + + #region Properties + + public int EndColumn { get; set; } + + public int EndLine { get; set; } + + public int StartColumn { get; set; } + public int StartLine { get; set; } + + #endregion + + public BufferRange ToBufferRange() + { + return new BufferRange(StartLine, StartColumn, EndLine, EndColumn); + } + + public static SelectionData FromBufferRange(BufferRange range) + { + return new SelectionData + { + StartLine = range.Start.Line, + StartColumn = range.Start.Column, + EndLine = range.End.Line, + EndColumn = range.End.Column + }; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs deleted file mode 100644 index 74297a42..00000000 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs +++ /dev/null @@ -1,273 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; -using System.Diagnostics; -using System.IO; -using Microsoft.SqlTools.ServiceLayer.Utility; - -namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage -{ - /// - /// Wrapper for a file stream, providing simplified creation, deletion, read, and write - /// functionality. - /// - public class FileStreamWrapper : IFileStreamWrapper - { - #region Member Variables - - private byte[] buffer; - private int bufferDataSize; - private FileStream fileStream; - private long startOffset; - private long currentOffset; - - #endregion - - /// - /// Constructs a new FileStreamWrapper and initializes its state. - /// - public FileStreamWrapper() - { - // Initialize the internal state - bufferDataSize = 0; - startOffset = 0; - currentOffset = 0; - } - - #region IFileStreamWrapper Implementation - - /// - /// Initializes the wrapper by creating the internal buffer and opening the requested file. - /// If the file does not already exist, it will be created. - /// - /// Name of the file to open/create - /// The length of the internal buffer - /// - /// Whether or not the wrapper will be used for reading. If true, any calls to a - /// method that writes will cause an InvalidOperationException - /// - public void Init(string fileName, int bufferLength, FileAccess accessMethod) - { - // Sanity check for valid buffer length, fileName, and accessMethod - Validate.IsGreaterThan(nameof(bufferLength), bufferLength, 0); - Validate.IsNotNullOrWhitespaceString(nameof(fileName), fileName); - if (accessMethod == FileAccess.Write) - { - throw new ArgumentException(SR.QueryServiceFileWrapperWriteOnly, nameof(fileName)); - } - - // Setup the buffer - buffer = new byte[bufferLength]; - - // Open the requested file for reading/writing, creating one if it doesn't exist - fileStream = new FileStream(fileName, FileMode.OpenOrCreate, accessMethod, FileShare.ReadWrite, - bufferLength, false /*don't use asyncio*/); - } - - /// - /// Reads data into a buffer from the current offset into the file - /// - /// The buffer to output the read data to - /// The number of bytes to read into the buffer - /// The number of bytes read - public int ReadData(byte[] buf, int bytes) - { - return ReadData(buf, bytes, currentOffset); - } - - /// - /// Reads data into a buffer from the specified offset into the file - /// - /// The buffer to output the read data to - /// The number of bytes to read into the buffer - /// The offset into the file to start reading bytes from - /// The number of bytes read - public int ReadData(byte[] buf, int bytes, long offset) - { - // Make sure that we're initialized before performing operations - if (buffer == null) - { - throw new InvalidOperationException(SR.QueryServiceFileWrapperNotInitialized); - } - - MoveTo(offset); - - int bytesCopied = 0; - while (bytesCopied < bytes) - { - int bufferOffset, bytesToCopy; - GetByteCounts(bytes, bytesCopied, out bufferOffset, out bytesToCopy); - Buffer.BlockCopy(buffer, bufferOffset, buf, bytesCopied, bytesToCopy); - bytesCopied += bytesToCopy; - - if (bytesCopied < bytes && // did not get all the bytes yet - bufferDataSize == buffer.Length) // since current data buffer is full we should continue reading the file - { - // move forward one full length of the buffer - MoveTo(startOffset + buffer.Length); - } - else - { - // copied all the bytes requested or possible, adjust the current buffer pointer - currentOffset += bytesToCopy; - break; - } - } - return bytesCopied; - } - - /// - /// Writes data to the underlying filestream, with buffering. - /// - /// The buffer of bytes to write to the filestream - /// The number of bytes to write - /// The number of bytes written - public int WriteData(byte[] buf, int bytes) - { - // Make sure that we're initialized before performing operations - if (buffer == null) - { - throw new InvalidOperationException(SR.QueryServiceFileWrapperNotInitialized); - } - if (!fileStream.CanWrite) - { - throw new InvalidOperationException(SR.QueryServiceFileWrapperReadOnly); - } - - int bytesCopied = 0; - while (bytesCopied < bytes) - { - int bufferOffset, bytesToCopy; - GetByteCounts(bytes, bytesCopied, out bufferOffset, out bytesToCopy); - Buffer.BlockCopy(buf, bytesCopied, buffer, bufferOffset, bytesToCopy); - bytesCopied += bytesToCopy; - - // adjust the current buffer pointer - currentOffset += bytesToCopy; - - if (bytesCopied < bytes) // did not get all the bytes yet - { - Debug.Assert((int)(currentOffset - startOffset) == buffer.Length); - // flush buffer - Flush(); - } - } - Debug.Assert(bytesCopied == bytes); - return bytesCopied; - } - - /// - /// Flushes the internal buffer to the filestream - /// - public void Flush() - { - // Make sure that we're initialized before performing operations - if (buffer == null) - { - throw new InvalidOperationException(SR.QueryServiceFileWrapperNotInitialized); - } - if (!fileStream.CanWrite) - { - throw new InvalidOperationException(SR.QueryServiceFileWrapperReadOnly); - } - - // Make sure we are at the right place in the file - Debug.Assert(fileStream.Position == startOffset); - - int bytesToWrite = (int)(currentOffset - startOffset); - fileStream.Write(buffer, 0, bytesToWrite); - startOffset += bytesToWrite; - fileStream.Flush(); - - Debug.Assert(startOffset == currentOffset); - } - - /// - /// Deletes the given file (ideally, created with this wrapper) from the filesystem - /// - /// The path to the file to delete - public static void DeleteFile(string fileName) - { - File.Delete(fileName); - } - - #endregion - - /// - /// Perform calculations to determine how many bytes to copy and what the new buffer offset - /// will be for copying. - /// - /// Number of bytes requested to copy - /// Number of bytes copied so far - /// New offset to start copying from/to - /// Number of bytes to copy in this iteration - private void GetByteCounts(int bytes, int bytesCopied, out int bufferOffset, out int bytesToCopy) - { - bufferOffset = (int) (currentOffset - startOffset); - bytesToCopy = bytes - bytesCopied; - if (bytesToCopy > buffer.Length - bufferOffset) - { - bytesToCopy = buffer.Length - bufferOffset; - } - } - - /// - /// Moves the internal buffer to the specified offset into the file - /// - /// Offset into the file to move to - private void MoveTo(long offset) - { - if (buffer.Length > bufferDataSize || // buffer is not completely filled - offset < startOffset || // before current buffer start - offset >= (startOffset + buffer.Length)) // beyond current buffer end - { - // init the offset - startOffset = offset; - - // position file pointer - fileStream.Seek(startOffset, SeekOrigin.Begin); - - // fill in the buffer - bufferDataSize = fileStream.Read(buffer, 0, buffer.Length); - } - // make sure to record where we are - currentOffset = offset; - } - - #region IDisposable Implementation - - private bool disposed; - - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - protected virtual void Dispose(bool disposing) - { - if (disposed) - { - return; - } - - if (disposing && fileStream != null) - { - if(fileStream.CanWrite) { Flush(); } - fileStream.Dispose(); - } - - disposed = true; - } - - ~FileStreamWrapper() - { - Dispose(false); - } - - #endregion - } -} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs deleted file mode 100644 index 38c283c5..00000000 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs +++ /dev/null @@ -1,22 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; -using System.IO; - -namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage -{ - /// - /// Interface for a wrapper around a filesystem reader/writer, mainly for unit testing purposes - /// - public interface IFileStreamWrapper : IDisposable - { - void Init(string fileName, int bufferSize, FileAccess fileAccessMode); - int ReadData(byte[] buffer, int bytes); - int ReadData(byte[] buffer, int bytes, long fileOffset); - int WriteData(byte[] buffer, int bytes); - void Flush(); - } -} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs index 7cfffee8..951bd89c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs @@ -5,6 +5,7 @@ using System; using System.Data.SqlTypes; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage { @@ -25,7 +26,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage int WriteDouble(double val); int WriteDecimal(decimal val); int WriteSqlDecimal(SqlDecimal val); - int WriteDateTime(DateTime val); + int WriteDateTime(DbColumnWrapper column, DateTime val); int WriteDateTimeOffset(DateTimeOffset dtoVal); int WriteTimeSpan(TimeSpan val); int WriteString(string val); diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs index c06a13ac..573a62d4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs @@ -29,7 +29,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// A public IFileStreamReader GetReader(string fileName) { - return new ServiceBufferFileStreamReader(new FileStreamWrapper(), fileName); + return new ServiceBufferFileStreamReader(new FileStream(fileName, FileMode.Open, FileAccess.Read)); } /// @@ -42,7 +42,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// A public IFileStreamWriter GetWriter(string fileName, int maxCharsToStore, int maxXmlCharsToStore) { - return new ServiceBufferFileStreamWriter(new FileStreamWrapper(), fileName, maxCharsToStore, maxXmlCharsToStore); + return new ServiceBufferFileStreamWriter(new FileStream(fileName, FileMode.OpenOrCreate, FileAccess.ReadWrite), maxCharsToStore, maxXmlCharsToStore); } /// @@ -51,14 +51,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// The file to dispose of public void DisposeFile(string fileName) { - try - { - FileStreamWrapper.DeleteFile(fileName); - } - catch - { - // If we have problems deleting the file from a temp location, we don't really care - } + FileUtils.SafeFileDelete(fileName); } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs index 8547999b..487f827b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs @@ -23,22 +23,24 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage private byte[] buffer; - private readonly IFileStreamWrapper fileStream; + private readonly Stream fileStream; - private Dictionary> readMethods; + private readonly Dictionary> readMethods; #endregion /// /// Constructs a new ServiceBufferFileStreamReader and initializes its state /// - /// The filestream wrapper to read from - /// The name of the file to read from - public ServiceBufferFileStreamReader(IFileStreamWrapper fileWrapper, string fileName) + /// The filestream to read from + public ServiceBufferFileStreamReader(Stream stream) { // Open file for reading/writing - fileStream = fileWrapper; - fileStream.Init(fileName, DefaultBufferSize, FileAccess.Read); + if (!stream.CanRead || !stream.CanSeek) + { + throw new InvalidOperationException("Stream must be readable and seekable"); + } + fileStream = stream; // Create internal buffer buffer = new byte[DefaultBufferSize]; @@ -258,11 +260,26 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// A DateTime public FileStreamReadResult ReadDateTime(long offset) { - return ReadCellHelper(offset, length => - { - long ticks = BitConverter.ToInt64(buffer, 0); - return new DateTime(ticks); - }); + int precision = 0; + + return ReadCellHelper(offset, + length => + { + precision = BitConverter.ToInt32(buffer, 0); + long ticks = BitConverter.ToInt64(buffer, 4); + return new DateTime(ticks); + }, null, + time => + { + string format = "yyyy-MM-dd HH:mm:ss"; + if (precision > 0) + { + // Output the number milliseconds equivalent to the precision + // NOTE: string('f', precision) will output ffff for precision=4 + format += "." + new string('f', precision); + } + return time.ToString(format); + }); } /// @@ -372,7 +389,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage { // read in length information int lengthValue; - int lengthLength = fileStream.ReadData(buffer, 1, offset); + fileStream.Seek(offset, SeekOrigin.Begin); + int lengthLength = fileStream.Read(buffer, 0, 1); if (buffer[0] != 0xFF) { // one byte is enough @@ -381,7 +399,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage else { // read in next 4 bytes - lengthLength += fileStream.ReadData(buffer, 4); + lengthLength += fileStream.Read(buffer, 0, 4); // reconstruct the length lengthValue = BitConverter.ToInt32(buffer, 0); @@ -433,7 +451,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage else { AssureBufferLength(length.ValueLength); - fileStream.ReadData(buffer, length.ValueLength); + fileStream.Read(buffer, 0, length.ValueLength); T resultObject = convertFunc(length.ValueLength); result.RawObject = resultObject; result.DisplayValue = toStringFunc == null ? result.RawObject.ToString() : toStringFunc(resultObject); diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs index 2e4360d2..75045aff 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs @@ -23,7 +23,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage #region Member Variables - private readonly IFileStreamWrapper fileStream; + private readonly Stream fileStream; private readonly int maxCharsToStore; private readonly int maxXmlCharsToStore; @@ -38,22 +38,24 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Functions to use for writing various types to a file /// - private readonly Dictionary> writeMethods; + private readonly Dictionary> writeMethods; #endregion /// /// Constructs a new writer /// - /// The file wrapper to use as the underlying file stream - /// Name of the file to write to + /// The file wrapper to use as the underlying file stream /// Maximum number of characters to store for long text fields /// Maximum number of characters to store for XML fields - public ServiceBufferFileStreamWriter(IFileStreamWrapper fileWrapper, string fileName, int maxCharsToStore, int maxXmlCharsToStore) + public ServiceBufferFileStreamWriter(Stream stream, int maxCharsToStore, int maxXmlCharsToStore) { // open file for reading/writing - fileStream = fileWrapper; - fileStream.Init(fileName, DefaultBufferLength, FileAccess.ReadWrite); + if (!stream.CanWrite || !stream.CanSeek) + { + throw new InvalidOperationException("Stream must be writable and seekable."); + } + fileStream = stream; // create internal buffer byteBuffer = new byte[DefaultBufferLength]; @@ -72,37 +74,78 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage this.maxXmlCharsToStore = maxXmlCharsToStore; // Define what methods to use to write a type to the file - writeMethods = new Dictionary> + writeMethods = new Dictionary> { - {typeof(string), val => WriteString((string) val)}, - {typeof(short), val => WriteInt16((short) val)}, - {typeof(int), val => WriteInt32((int) val)}, - {typeof(long), val => WriteInt64((long) val)}, - {typeof(byte), val => WriteByte((byte) val)}, - {typeof(char), val => WriteChar((char) val)}, - {typeof(bool), val => WriteBoolean((bool) val)}, - {typeof(double), val => WriteDouble((double) val) }, - {typeof(float), val => WriteSingle((float) val) }, - {typeof(decimal), val => WriteDecimal((decimal) val) }, - {typeof(DateTime), val => WriteDateTime((DateTime) val) }, - {typeof(DateTimeOffset), val => WriteDateTimeOffset((DateTimeOffset) val) }, - {typeof(TimeSpan), val => WriteTimeSpan((TimeSpan) val) }, - {typeof(byte[]), val => WriteBytes((byte[]) val)}, + {typeof(string), (val, col) => WriteString((string) val)}, + {typeof(short), (val, col) => WriteInt16((short) val)}, + {typeof(int), (val, col) => WriteInt32((int) val)}, + {typeof(long), (val, col) => WriteInt64((long) val)}, + {typeof(byte), (val, col) => WriteByte((byte) val)}, + {typeof(char), (val, col) => WriteChar((char) val)}, + {typeof(bool), (val, col) => WriteBoolean((bool) val)}, + {typeof(double), (val, col) => WriteDouble((double) val) }, + {typeof(float), (val, col) => WriteSingle((float) val) }, + {typeof(decimal), (val, col) => WriteDecimal((decimal) val) }, + {typeof(DateTime), (val, col) => WriteDateTime(col, (DateTime) val) }, + {typeof(DateTimeOffset), (val, col) => WriteDateTimeOffset((DateTimeOffset) val) }, + {typeof(TimeSpan), (val, col) => WriteTimeSpan((TimeSpan) val) }, + {typeof(byte[]), (val, col) => WriteBytes((byte[]) val)}, - {typeof(SqlString), val => WriteNullable((SqlString) val, obj => WriteString((string) obj))}, - {typeof(SqlInt16), val => WriteNullable((SqlInt16) val, obj => WriteInt16((short) obj))}, - {typeof(SqlInt32), val => WriteNullable((SqlInt32) val, obj => WriteInt32((int) obj))}, - {typeof(SqlInt64), val => WriteNullable((SqlInt64) val, obj => WriteInt64((long) obj)) }, - {typeof(SqlByte), val => WriteNullable((SqlByte) val, obj => WriteByte((byte) obj)) }, - {typeof(SqlBoolean), val => WriteNullable((SqlBoolean) val, obj => WriteBoolean((bool) obj)) }, - {typeof(SqlDouble), val => WriteNullable((SqlDouble) val, obj => WriteDouble((double) obj)) }, - {typeof(SqlSingle), val => WriteNullable((SqlSingle) val, obj => WriteSingle((float) obj)) }, - {typeof(SqlDecimal), val => WriteNullable((SqlDecimal) val, obj => WriteSqlDecimal((SqlDecimal) obj)) }, - {typeof(SqlDateTime), val => WriteNullable((SqlDateTime) val, obj => WriteDateTime((DateTime) obj)) }, - {typeof(SqlBytes), val => WriteNullable((SqlBytes) val, obj => WriteBytes((byte[]) obj)) }, - {typeof(SqlBinary), val => WriteNullable((SqlBinary) val, obj => WriteBytes((byte[]) obj)) }, - {typeof(SqlGuid), val => WriteNullable((SqlGuid) val, obj => WriteGuid((Guid) obj)) }, - {typeof(SqlMoney), val => WriteNullable((SqlMoney) val, obj => WriteMoney((SqlMoney) obj)) } + { + typeof(SqlString), + (val, col) => WriteNullable((SqlString) val, obj => WriteString((string) obj)) + }, + { + typeof(SqlInt16), + (val, col) => WriteNullable((SqlInt16) val, obj => WriteInt16((short) obj)) + }, + { + typeof(SqlInt32), + (val, col) => WriteNullable((SqlInt32) val, obj => WriteInt32((int) obj)) + }, + { + typeof(SqlInt64), + (val, col) => WriteNullable((SqlInt64) val, obj => WriteInt64((long) obj)) + }, + { + typeof(SqlByte), + (val, col) => WriteNullable((SqlByte) val, obj => WriteByte((byte) obj)) + }, + { + typeof(SqlBoolean), + (val, col) => WriteNullable((SqlBoolean) val, obj => WriteBoolean((bool) obj)) }, + { + typeof(SqlDouble), + (val, col) => WriteNullable((SqlDouble) val, obj => WriteDouble((double) obj)) + }, + { + typeof(SqlSingle), + (val, col) => WriteNullable((SqlSingle) val, obj => WriteSingle((float) obj)) + }, + { + typeof(SqlDecimal), + (val, col) => WriteNullable((SqlDecimal) val, obj => WriteSqlDecimal((SqlDecimal) obj)) + }, + { + typeof(SqlDateTime), + (val, col) => WriteNullable((SqlDateTime) val, obj => WriteDateTime(col, (DateTime) obj)) + }, + { + typeof(SqlBytes), + (val, col) => WriteNullable((SqlBytes) val, obj => WriteBytes((byte[]) obj)) + }, + { + typeof(SqlBinary), + (val, col) => WriteNullable((SqlBinary) val, obj => WriteBytes((byte[]) obj)) + }, + { + typeof(SqlGuid), + (val, col) => WriteNullable((SqlGuid) val, obj => WriteGuid((Guid) obj)) + }, + { + typeof(SqlMoney), + (val, col) => WriteNullable((SqlMoney) val, obj => WriteMoney((SqlMoney) obj)) + } }; } @@ -188,10 +231,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage } // Use the appropriate writing method for the type - Func writeMethod; + Func writeMethod; if (writeMethods.TryGetValue(tVal, out writeMethod)) { - rowBytes += writeMethod(values[i]); + rowBytes += writeMethod(values[i], ci); } else { @@ -212,7 +255,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage public int WriteNull() { byteBuffer[0] = 0x00; - return fileStream.WriteData(byteBuffer, 1); + return WriteHelper(byteBuffer, 1); } /// @@ -224,7 +267,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x02; // length shortBuffer[0] = val; Buffer.BlockCopy(shortBuffer, 0, byteBuffer, 1, 2); - return fileStream.WriteData(byteBuffer, 3); + return WriteHelper(byteBuffer, 3); } /// @@ -236,7 +279,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x04; // length intBuffer[0] = val; Buffer.BlockCopy(intBuffer, 0, byteBuffer, 1, 4); - return fileStream.WriteData(byteBuffer, 5); + return WriteHelper(byteBuffer, 5); } /// @@ -248,7 +291,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x08; // length longBuffer[0] = val; Buffer.BlockCopy(longBuffer, 0, byteBuffer, 1, 8); - return fileStream.WriteData(byteBuffer, 9); + return WriteHelper(byteBuffer, 9); } /// @@ -260,7 +303,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x02; // length charBuffer[0] = val; Buffer.BlockCopy(charBuffer, 0, byteBuffer, 1, 2); - return fileStream.WriteData(byteBuffer, 3); + return WriteHelper(byteBuffer, 3); } /// @@ -271,7 +314,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage { byteBuffer[0] = 0x01; // length byteBuffer[1] = (byte) (val ? 0x01 : 0x00); - return fileStream.WriteData(byteBuffer, 2); + return WriteHelper(byteBuffer, 2); } /// @@ -282,7 +325,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage { byteBuffer[0] = 0x01; // length byteBuffer[1] = val; - return fileStream.WriteData(byteBuffer, 2); + return WriteHelper(byteBuffer, 2); } /// @@ -294,7 +337,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x04; // length floatBuffer[0] = val; Buffer.BlockCopy(floatBuffer, 0, byteBuffer, 1, 4); - return fileStream.WriteData(byteBuffer, 5); + return WriteHelper(byteBuffer, 5); } /// @@ -306,7 +349,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x08; // length doubleBuffer[0] = val; Buffer.BlockCopy(doubleBuffer, 0, byteBuffer, 1, 8); - return fileStream.WriteData(byteBuffer, 9); + return WriteHelper(byteBuffer, 9); } /// @@ -330,7 +373,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage // data value Buffer.BlockCopy(arrInt32, 0, byteBuffer, 3, iLen - 3); - iTotalLen += fileStream.WriteData(byteBuffer, iLen); + iTotalLen += WriteHelper(byteBuffer, iLen); return iTotalLen; // len+data } @@ -346,18 +389,31 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage int iTotalLen = WriteLength(iLen); // length Buffer.BlockCopy(arrInt32, 0, byteBuffer, 0, iLen); - iTotalLen += fileStream.WriteData(byteBuffer, iLen); + iTotalLen += WriteHelper(byteBuffer, iLen); return iTotalLen; // len+data } /// - /// Writes a DateTime to the file + /// Writes a DateTime to the file as precision and ticks /// /// Number of bytes used to store the DateTime - public int WriteDateTime(DateTime dtVal) + public int WriteDateTime(DbColumnWrapper col, DateTime dtVal) { - return WriteInt64(dtVal.Ticks); + // Length + var length = WriteLength(12); + + // Precision + intBuffer[0] = col.NumericScale ?? 3; + Buffer.BlockCopy(intBuffer, 0, byteBuffer, 0, 4); + + // Ticks + longBuffer[0] = dtVal.Ticks; + Buffer.BlockCopy(longBuffer, 0, byteBuffer, 4, 8); + + length += WriteHelper(byteBuffer, 12); + + return length; } /// @@ -374,7 +430,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage longBufferOffset[0] = dtoVal.Ticks; longBufferOffset[1] = dtoVal.Offset.Ticks; Buffer.BlockCopy(longBufferOffset, 0, byteBuffer, 1, 16); - return fileStream.WriteData(byteBuffer, 17); + return WriteHelper(byteBuffer, 17); } /// @@ -406,7 +462,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[3] = 0x00; byteBuffer[4] = 0x00; - iTotalLen = fileStream.WriteData(byteBuffer, 5); + iTotalLen = WriteHelper(byteBuffer, 5); } else { @@ -415,7 +471,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage // convert char array into byte array and write it out iTotalLen = WriteLength(bytes.Length); - iTotalLen += fileStream.WriteData(bytes, bytes.Length); + iTotalLen += WriteHelper(bytes, bytes.Length); } return iTotalLen; // len+data } @@ -438,12 +494,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[3] = 0x00; byteBuffer[4] = 0x00; - iTotalLen = fileStream.WriteData(byteBuffer, 5); + iTotalLen = WriteHelper(byteBuffer, 5); } else { iTotalLen = WriteLength(bytesVal.Length); - iTotalLen += fileStream.WriteData(bytesVal, bytesVal.Length); + iTotalLen += WriteHelper(bytesVal, bytesVal.Length); } return iTotalLen; // len+data } @@ -507,7 +563,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage int iTmp = iLen & 0x000000FF; byteBuffer[0] = Convert.ToByte(iTmp); - return fileStream.WriteData(byteBuffer, 1); + return WriteHelper(byteBuffer, 1); } // The length won't fit in 1 byte, so we need to use 1 byte to signify that the length // is a full 4 bytes. @@ -516,7 +572,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage // convert int32 into array of bytes intBuffer[0] = iLen; Buffer.BlockCopy(intBuffer, 0, byteBuffer, 1, 4); - return fileStream.WriteData(byteBuffer, 5); + return WriteHelper(byteBuffer, 5); } /// @@ -532,6 +588,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage return val.IsNull ? WriteNull() : valueWriteFunc(val); } + private int WriteHelper(byte[] buffer, int length) + { + fileStream.Write(buffer, 0, length); + return length; + } + #endregion #region IDisposable Implementation diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs index cc5d1443..33fdf6d3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs @@ -304,7 +304,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// This code is take almost verbatim from Microsoft.SqlServer.Management.UI.Grid, SSMS /// DataStorage, StorageDataReader class. /// - private class StringWriterWithMaxCapacity : StringWriter + internal class StringWriterWithMaxCapacity : StringWriter { private bool stopWriting; diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index d47fafca..3c5bae8c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -51,11 +51,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// private bool hasExecuteBeenCalled; - /// - /// The factory to use for outputting the results of this query - /// - private readonly IFileStreamFactory outputFileFactory; - #endregion /// @@ -77,7 +72,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution QueryText = queryText; editorConnection = connection; cancellationSource = new CancellationTokenSource(); - outputFileFactory = outputFactory; // Process the query into batches ParseResult parseResult = Parser.Parse(queryText, new ParseOptions @@ -85,27 +79,35 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution BatchSeparator = settings.BatchSeparator }); // NOTE: We only want to process batches that have statements (ie, ignore comments and empty lines) - Batches = parseResult.Script.Batches.Where(b => b.Statements.Count > 0) - .Select(b => new Batch(b.Sql, - b.StartLocation.LineNumber - 1, - b.StartLocation.ColumnNumber - 1, - b.EndLocation.LineNumber - 1, - b.EndLocation.ColumnNumber - 1, - outputFileFactory)).ToArray(); + var batchSelection = parseResult.Script.Batches + .Where(batch => batch.Statements.Count > 0) + .Select((batch, index) => + new Batch(batch.Sql, + new SelectionData( + batch.StartLocation.LineNumber - 1, + batch.StartLocation.ColumnNumber - 1, + batch.EndLocation.LineNumber - 1, + batch.EndLocation.ColumnNumber - 1), + index, outputFactory)); + Batches = batchSelection.ToArray(); } - #region Properties + #region Events /// - /// Delegate type for callback when a query completes or fails + /// Event to be called when a batch is completed. /// - /// The query that completed - public delegate Task QueryAsyncEventHandler(Query q); + public event Batch.BatchAsyncEventHandler BatchCompleted; + + /// + /// Event to be called when a batch starts execution. + /// + public event Batch.BatchAsyncEventHandler BatchStarted; /// /// Delegate type for callback when a query connection fails /// - /// The query that completed + /// Error message for the failing query public delegate Task QueryAsyncErrorEventHandler(string message); /// @@ -123,6 +125,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public event QueryAsyncErrorEventHandler QueryConnectionException; + /// + /// Event to be called when a resultset has completed. + /// + public event ResultSet.ResultSetAsyncEventHandler ResultSetCompleted; + + #endregion + + #region Properties + + /// + /// Delegate type for callback when a query completes or fails + /// + /// The query that completed + public delegate Task QueryAsyncEventHandler(Query q); + /// /// The batches underneath this query /// @@ -139,21 +156,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { throw new InvalidOperationException("Query has not been executed."); } - - return Batches.Select((batch, index) => new BatchSummary - { - Id = index, - ExecutionStart = batch.ExecutionStartTimeStamp, - ExecutionEnd = batch.ExecutionEndTimeStamp, - ExecutionElapsed = batch.ExecutionElapsedTime, - HasError = batch.HasError, - Messages = batch.ResultMessages.ToArray(), - ResultSetSummaries = batch.ResultSummaries, - Selection = batch.Selection - }).ToArray(); + return Batches.Select(b => b.Summary).ToArray(); } } + /// + /// Storage for the async task for execution. Set as internal in order to await completion + /// in unit tests. + /// internal Task ExecutionTask { get; private set; } /// @@ -214,12 +224,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// A subset of results public Task GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount) { - // Sanity check that the results are available - if (!HasExecuted) - { - throw new InvalidOperationException(SR.QueryServiceSubsetNotCompleted); - } - // Sanity check to make sure that the batch is within bounds if (batchIndex < 0 || batchIndex >= Batches.Length) { @@ -256,11 +260,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { await conn.OpenAsync(); } - catch(Exception exception) + catch (Exception exception) { - this.HasExecuted = true; + this.HasExecuted = true; if (QueryConnectionException != null) - { + { await QueryConnectionException(exception.Message); } return; @@ -278,6 +282,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // We need these to execute synchronously, otherwise the user will be very unhappy foreach (Batch b in Batches) { + b.BatchStart += BatchStarted; + b.BatchCompletion += BatchCompleted; + b.ResultSetCompletion += ResultSetCompleted; await b.Execute(conn, cancellationSource.Token); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index c7614596..4bd9f1fa 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -424,6 +424,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution OwnerUri = executeParams.OwnerUri, BatchSummaries = q.BatchSummaries }; + await requestContext.SendEvent(QueryExecuteCompleteEvent.Type, eventParams); }; @@ -442,13 +443,54 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution query.QueryFailed += callback; query.QueryConnectionException += errorCallback; + // Setup the batch callbacks + Batch.BatchAsyncEventHandler batchStartCallback = async b => + { + QueryExecuteBatchNotificationParams eventParams = new QueryExecuteBatchNotificationParams + { + BatchSummary = b.Summary, + OwnerUri = executeParams.OwnerUri + }; + await requestContext.SendEvent(QueryExecuteBatchStartEvent.Type, eventParams); + }; + query.BatchStarted += batchStartCallback; + + Batch.BatchAsyncEventHandler batchCompleteCallback = async b => + { + QueryExecuteBatchNotificationParams eventParams = new QueryExecuteBatchNotificationParams + { + BatchSummary = b.Summary, + OwnerUri = executeParams.OwnerUri + }; + await requestContext.SendEvent(QueryExecuteBatchCompleteEvent.Type, eventParams); + }; + query.BatchCompleted += batchCompleteCallback; + + // Setup the ResultSet completion callback + ResultSet.ResultSetAsyncEventHandler resultCallback = async r => + { + QueryExecuteResultSetCompleteParams eventParams = new QueryExecuteResultSetCompleteParams + { + ResultSetSummary = r.Summary, + OwnerUri = executeParams.OwnerUri + }; + await requestContext.SendEvent(QueryExecuteResultSetCompleteEvent.Type, eventParams); + }; + query.ResultSetCompleted += resultCallback; + // Launch this as an asynchronous task query.Execute(); // Send back a result showing we were successful + string messages = null; + if (query.Batches.Length == 0) + { + // If there were no batches to execute, send back an informational message that the commands were completed successfully + messages = SR.QueryServiceCompletedSuccessfully; + } await requestContext.SendResult(new QueryExecuteResult { - Messages = null + Messages = messages }); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs index a96de759..e3ecb890 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -8,7 +8,6 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Data.Common; using System.Linq; -using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; @@ -17,6 +16,10 @@ using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { + /// + /// Class that represents a resultset the was generated from a query. Contains logic for + /// storing and retrieving results. Is contained by a Batch class. + /// public class ResultSet : IDisposable { #region Constants @@ -35,20 +38,31 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #region Member Variables + /// + /// The reader to use for this resultset + /// + private readonly StorageDataReader dataReader; + /// /// For IDisposable pattern, whether or not object has been disposed /// private bool disposed; + /// + /// A list of offsets into the buffer file that correspond to where rows start + /// + private readonly LongList fileOffsets; + /// /// The factory to use to get reading/writing handlers /// private readonly IFileStreamFactory fileStreamFactory; /// - /// Whether or not the result set has been read in from the database + /// Whether or not the result set has been read in from the database, + /// set as internal in order to fake value in unit tests /// - private bool hasBeenRead; + internal bool hasBeenRead; /// /// Whether resultSet is a 'for xml' or 'for json' result @@ -60,15 +74,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// private readonly string outputFileName; - /// - /// Whether the resultSet is in the process of being disposed - /// - private bool isBeingDisposed; - /// /// All save tasks currently saving this ResultSet /// - private ConcurrentDictionary saveTasks; + private readonly ConcurrentDictionary saveTasks; #endregion @@ -76,17 +85,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// Creates a new result set and initializes its state /// /// The reader from executing a query + /// The ID of the resultset, the ordinal of the result within the batch + /// The ID of the batch, the ordinal of the batch within the query /// Factory for creating a reader/writer - public ResultSet(DbDataReader reader, IFileStreamFactory factory) + public ResultSet(DbDataReader reader, int ordinal, int batchOrdinal, IFileStreamFactory factory) { // Sanity check to make sure we got a reader Validate.IsNotNull(nameof(reader), SR.QueryServiceResultSetReaderNull); - DataReader = new StorageDataReader(reader); + dataReader = new StorageDataReader(reader); + Id = ordinal; + BatchId = batchOrdinal; // Initialize the storage outputFileName = factory.CreateFile(); - FileOffsets = new LongList(); + fileOffsets = new LongList(); // Store the factory fileStreamFactory = factory; @@ -96,17 +109,22 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #region Properties + /// + /// Asynchronous handler for when a resultset has completed + /// + /// The result set that completed + public delegate Task ResultSetAsyncEventHandler(ResultSet resultSet); + + /// + /// Event that will be called when the result set has completed execution + /// + public event ResultSetAsyncEventHandler ResultCompletion; + /// /// Whether the resultSet is in the process of being disposed /// /// - internal bool IsBeingDisposed - { - get - { - return isBeingDisposed; - } - } + internal bool IsBeingDisposed { get; private set; } /// /// The columns for this result set @@ -114,14 +132,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution public DbColumnWrapper[] Columns { get; private set; } /// - /// The reader to use for this resultset + /// ID of the result set, relative to the batch /// - private StorageDataReader DataReader { get; set; } + public int Id { get; private set; } /// - /// A list of offsets into the buffer file that correspond to where rows start + /// ID of the batch set, relative to the query /// - private LongList FileOffsets { get; set; } + public int BatchId { get; private set; } /// /// Maximum number of characters to store for a field @@ -138,6 +156,23 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public long RowCount { get; private set; } + /// + /// Generates a summary of this result set + /// + public ResultSetSummary Summary + { + get + { + return new ResultSetSummary + { + ColumnInfo = Columns, + Id = Id, + BatchId = BatchId, + RowCount = RowCount + }; + } + } + #endregion #region Public Methods @@ -178,18 +213,19 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution if (isSingleColumnXmlJsonResultSet) { // Iterate over all the rows and process them into a list of string builders - IEnumerable rowValues = FileOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns)[0].DisplayValue); + // ReSharper disable once AccessToDisposedClosure The lambda is used immediately in string.Join call + IEnumerable rowValues = fileOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns)[0].DisplayValue); rows = new[] { new[] { string.Join(string.Empty, rowValues) } }; - } else { // Figure out which rows we need to read back - IEnumerable rowOffsets = FileOffsets.Skip(startRow).Take(rowCount); + IEnumerable rowOffsets = fileOffsets.Skip(startRow).Take(rowCount); // Iterate over the rows we need and process them into output - rows = rowOffsets.Select(rowOffset => - fileStreamReader.ReadRow(rowOffset, Columns).Select(cell => cell.DisplayValue).ToArray()) + // ReSharper disable once AccessToDisposedClosure The lambda is used immediately in .ToArray call + rows = rowOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns) + .Select(cell => cell.DisplayValue).ToArray()) .ToArray(); } @@ -209,29 +245,41 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// Cancellation token for cancelling the query public async Task ReadResultToEnd(CancellationToken cancellationToken) { - // Mark that result has been read - hasBeenRead = true; - - // Open a writer for the file - using (IFileStreamWriter fileWriter = fileStreamFactory.GetWriter(outputFileName, MaxCharsToStore, MaxXmlCharsToStore)) + try { - // If we can initialize the columns using the column schema, use that - if (!DataReader.DbDataReader.CanGetColumnSchema()) - { - throw new InvalidOperationException(SR.QueryServiceResultSetNoColumnSchema); - } - Columns = DataReader.Columns; - long currentFileOffset = 0; + // Mark that result has been read + hasBeenRead = true; - while (await DataReader.ReadAsync(cancellationToken)) + // Open a writer for the file + var fileWriter = fileStreamFactory.GetWriter(outputFileName, MaxCharsToStore, MaxCharsToStore); + using (fileWriter) { - RowCount++; - FileOffsets.Add(currentFileOffset); - currentFileOffset += fileWriter.WriteRow(DataReader); + // If we can initialize the columns using the column schema, use that + if (!dataReader.DbDataReader.CanGetColumnSchema()) + { + throw new InvalidOperationException(SR.QueryServiceResultSetNoColumnSchema); + } + Columns = dataReader.Columns; + long currentFileOffset = 0; + + while (await dataReader.ReadAsync(cancellationToken)) + { + RowCount++; + fileOffsets.Add(currentFileOffset); + currentFileOffset += fileWriter.WriteRow(dataReader); + } + } + // Check if resultset is 'for xml/json'. If it is, set isJson/isXml value in column metadata + SingleColumnXmlJsonResultSet(); + } + finally + { + // Fire off a result set completion event if we have one + if (ResultCompletion != null) + { + await ResultCompletion(this); } } - // Check if resultset is 'for xml/json'. If it is, set isJson/isXml value in column metadata - SingleColumnXmlJsonResultSet(); } #endregion @@ -251,7 +299,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return; } - isBeingDisposed = true; + IsBeingDisposed = true; // Check if saveTasks are running for this ResultSet if (!saveTasks.IsEmpty) { @@ -263,7 +311,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution fileStreamFactory.DisposeFile(outputFileName); } disposed = true; - isBeingDisposed = false; + IsBeingDisposed = false; }); } else @@ -274,7 +322,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution fileStreamFactory.DisposeFile(outputFileName); } disposed = true; - isBeingDisposed = false; + IsBeingDisposed = false; } } @@ -288,10 +336,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// If the result set represented by this class corresponds to a single JSON /// column that contains results of "for json" query, set isJson = true /// - private void SingleColumnXmlJsonResultSet() { + private void SingleColumnXmlJsonResultSet() + { if (Columns?.Length == 1 && RowCount != 0) - { + { if (Columns[0].ColumnName.Equals(NameOfForXMLColumn, StringComparison.Ordinal)) { Columns[0].IsXml = true; @@ -303,7 +352,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution Columns[0].IsJson = true; isSingleColumnXmlJsonResultSet = true; RowCount = 1; - } + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs index 713eb764..0490f6d4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs @@ -61,7 +61,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// The field to encode /// The CSV encoded version of the original field - internal static String EncodeCsvField(String field) + internal static string EncodeCsvField(string field) { StringBuilder sbField = new StringBuilder(field); @@ -102,9 +102,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution //Replace all quotes in the original field with double quotes sbField.Replace("\"", "\"\""); - - String ret = sbField.ToString(); - + string ret = sbField.ToString(); + if (embedInQuotes) { ret = "\"" + ret + "\""; @@ -121,7 +120,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution internal static bool IsSaveSelection(SaveResultsRequestParams saveParams) { return (saveParams.ColumnStartIndex != null && saveParams.ColumnEndIndex != null - && saveParams.RowEndIndex != null && saveParams.RowEndIndex != null); + && saveParams.RowStartIndex != null && saveParams.RowEndIndex != null); } /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs index 37c35ebf..0359df16 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs @@ -68,7 +68,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext } /// - /// Gets a flag determining if suggestons are enabled + /// Gets a flag determining if suggestions are enabled /// public bool IsSuggestionsEnabled { @@ -90,6 +90,17 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext && this.SqlTools.IntelliSense.EnableQuickInfo.Value; } } + + /// + /// Gets a flag determining if IntelliSense is enabled + /// + public bool IsIntelliSenseEnabled + { + get + { + return this.SqlTools.IntelliSense.EnableIntellisense; + } + } } /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/FileUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/FileUtils.cs similarity index 100% rename from src/Microsoft.SqlTools.ServiceLayer/QueryExecution/FileUtils.cs rename to src/Microsoft.SqlTools.ServiceLayer/Utility/FileUtils.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs index 29ca9e7f..979d697b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs @@ -109,5 +109,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility || ch == '(' || ch == ')'; } + + /// + /// Remove square bracket syntax from a token string + /// + /// + /// string with outer brackets removed + public static string RemoveSquareBracketSyntax(string tokenText) + { + if(tokenText.StartsWith("[") && tokenText.EndsWith("]")) + { + return tokenText.Substring(1, tokenText.Length - 2); + } + return tokenText; + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFile.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFile.cs index 279f9b7f..4cb5948c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFile.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFile.cs @@ -94,15 +94,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts private set; } - /// - /// Gets the array of filepaths dot sourced in this ScriptFile - /// - public string[] ReferencedFiles - { - get; - private set; - } - #endregion #region Constructors @@ -299,9 +290,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts this.FileLines.Insert(currentLineNumber - 1, finalLine); currentLineNumber++; } - - // Parse the script again to be up-to-date - this.ParseFileContents(); } /// @@ -447,97 +435,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts .Split('\n') .Select(line => line.TrimEnd('\r')) .ToList(); - - // Parse the contents to get syntax tree and errors - this.ParseFileContents(); } #endregion - - #region Private Methods - - /// - /// Parses the current file contents to get the AST, tokens, - /// and parse errors. - /// - private void ParseFileContents() - { -#if false - ParseError[] parseErrors = null; - - // First, get the updated file range - int lineCount = this.FileLines.Count; - if (lineCount > 0) - { - this.FileRange = - new BufferRange( - new BufferPosition(1, 1), - new BufferPosition( - lineCount + 1, - this.FileLines[lineCount - 1].Length + 1)); - } - else - { - this.FileRange = BufferRange.None; - } - - try - { -#if SqlToolsv5r2 - // This overload appeared with Windows 10 Update 1 - if (this.SqlToolsVersion.Major >= 5 && - this.SqlToolsVersion.Build >= 10586) - { - // Include the file path so that module relative - // paths are evaluated correctly - this.ScriptAst = - Parser.ParseInput( - this.Contents, - this.FilePath, - out this.scriptTokens, - out parseErrors); - } - else - { - this.ScriptAst = - Parser.ParseInput( - this.Contents, - out this.scriptTokens, - out parseErrors); - } -#else - this.ScriptAst = - Parser.ParseInput( - this.Contents, - out this.scriptTokens, - out parseErrors); -#endif - } - catch (RuntimeException ex) - { - var parseError = - new ParseError( - null, - ex.ErrorRecord.FullyQualifiedErrorId, - ex.Message); - - parseErrors = new[] { parseError }; - this.scriptTokens = new Token[0]; - this.ScriptAst = null; - } - - // Translate parse errors into syntax markers - this.SyntaxMarkers = - parseErrors - .Select(ScriptFileMarker.FromParseError) - .ToArray(); - - //Get all dot sourced referenced files and store them - this.ReferencedFiles = - AstOperations.FindDotSourcedIncludes(this.ScriptAst); -#endif - } - -#endregion } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs index 6161c5a2..dad14f90 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs @@ -202,30 +202,39 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace DidChangeTextDocumentParams textChangeParams, EventContext eventContext) { - StringBuilder msg = new StringBuilder(); - msg.Append("HandleDidChangeTextDocumentNotification"); - List changedFiles = new List(); - - // A text change notification can batch multiple change requests - foreach (var textChange in textChangeParams.ContentChanges) + try { - string fileUri = textChangeParams.TextDocument.Uri ?? textChangeParams.TextDocument.Uri; - msg.AppendLine(string.Format(" File: {0}", fileUri)); + StringBuilder msg = new StringBuilder(); + msg.Append("HandleDidChangeTextDocumentNotification"); + List changedFiles = new List(); - ScriptFile changedFile = Workspace.GetFile(fileUri); + // A text change notification can batch multiple change requests + foreach (var textChange in textChangeParams.ContentChanges) + { + string fileUri = textChangeParams.TextDocument.Uri ?? textChangeParams.TextDocument.Uri; + msg.AppendLine(string.Format(" File: {0}", fileUri)); - changedFile.ApplyChange( - GetFileChangeDetails( - textChange.Range.Value, - textChange.Text)); + ScriptFile changedFile = Workspace.GetFile(fileUri); - changedFiles.Add(changedFile); + changedFile.ApplyChange( + GetFileChangeDetails( + textChange.Range.Value, + textChange.Text)); + + changedFiles.Add(changedFile); + } + + Logger.Write(LogLevel.Verbose, msg.ToString()); + + var handlers = TextDocChangeCallbacks.Select(t => t(changedFiles.ToArray(), eventContext)); + return Task.WhenAll(handlers); + } + catch + { + // Swallow exceptions here to prevent us from crashing + // TODO: this probably means the ScriptFile model is in a bad state or out of sync with the actual file; we should recover here + return Task.FromResult(true); } - - Logger.Write(LogLevel.Verbose, msg.ToString()); - - var handlers = TextDocChangeCallbacks.Select(t => t(changedFiles.ToArray(), eventContext)); - return Task.WhenAll(handlers); } internal async Task HandleDidOpenTextDocumentNotification( diff --git a/src/Microsoft.SqlTools.ServiceLayer/project.json b/src/Microsoft.SqlTools.ServiceLayer/project.json index 778734f0..c8e55c35 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/project.json +++ b/src/Microsoft.SqlTools.ServiceLayer/project.json @@ -5,6 +5,16 @@ "debugType": "portable", "emitEntryPoint": true }, + "configurations": { + "Integration": { + "buildOptions": { + "define": [ + "WINDOWS_ONLY_BUILD" + ], + "emitEntryPoint": true + } + } + }, "dependencies": { "Newtonsoft.Json": "9.0.1", "System.Data.Common": "4.1.0", diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.cs b/src/Microsoft.SqlTools.ServiceLayer/sr.cs index 9635b0f7..c8abeed3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.cs @@ -173,11 +173,11 @@ namespace Microsoft.SqlTools.ServiceLayer } } - public static string QueryServiceSubsetNotCompleted + public static string QueryServiceSubsetBatchNotCompleted { get { - return Keys.GetString(Keys.QueryServiceSubsetNotCompleted); + return Keys.GetString(Keys.QueryServiceSubsetBatchNotCompleted); } } @@ -468,7 +468,7 @@ namespace Microsoft.SqlTools.ServiceLayer public const string QueryServiceQueryCancelled = "QueryServiceQueryCancelled"; - public const string QueryServiceSubsetNotCompleted = "QueryServiceSubsetNotCompleted"; + public const string QueryServiceSubsetBatchNotCompleted = "QueryServiceSubsetBatchNotCompleted"; public const string QueryServiceSubsetBatchOutOfRange = "QueryServiceSubsetBatchOutOfRange"; diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.resx b/src/Microsoft.SqlTools.ServiceLayer/sr.resx index 27f0993d..5494a742 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.resx +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.resx @@ -209,8 +209,8 @@ Query was canceled by user - - The query has not completed, yet + + The batch has not completed, yet diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.strings b/src/Microsoft.SqlTools.ServiceLayer/sr.strings index 9b9a138e..67534aed 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.strings +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.strings @@ -83,7 +83,7 @@ QueryServiceQueryCancelled = Query was canceled by user ### Subset Request -QueryServiceSubsetNotCompleted = The query has not completed, yet +QueryServiceSubsetBatchNotCompleted = The batch has not completed, yet QueryServiceSubsetBatchOutOfRange = Batch index cannot be less than 0 or greater than the number of batches diff --git a/test/CodeCoverage/codecoverage.bat b/test/CodeCoverage/codecoverage.bat index bebb0c6f..4e0920e2 100644 --- a/test/CodeCoverage/codecoverage.bat +++ b/test/CodeCoverage/codecoverage.bat @@ -13,10 +13,26 @@ REM we should remove this step on OpenCover supports portable PDB cscript /nologo ReplaceText.vbs %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json portable full REM rebuild the SqlToolsService project -dotnet build %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json +dotnet build %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json %DOTNETCONFIG% REM run the tests through OpenCover and generate a report -"%WORKINGDIR%packages\OpenCover.4.6.519\tools\OpenCover.Console.exe" -register:user -target:dotnet.exe -targetargs:"test %WORKINGDIR%..\Microsoft.SqlTools.ServiceLayer.Test\project.json" -oldstyle -filter:"+[Microsoft.SqlTools.*]* -[xunit*]*" -output:coverage.xml -searchdirs:%WORKINGDIR%..\Microsoft.SqlTools.ServiceLayer.Test\bin\Debug\netcoreapp1.0 +dotnet build %WORKINGDIR%..\..\test\Microsoft.SqlTools.ServiceLayer.Test\project.json %DOTNETCONFIG% +dotnet build %WORKINGDIR%..\..\test\Microsoft.SqlTools.ServiceLayer.TestDriver\project.json %DOTNETCONFIG% + +SET TEST_SERVER=localhost +SET SQLTOOLSSERVICE_EXE=%WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\bin\Integration\netcoreapp1.0\win7-x64\Microsoft.SqlTools.ServiceLayer.exe +SET SERVICECODECOVERAGE=True +SET CODECOVERAGETOOL="%WORKINGDIR%packages\OpenCover.4.6.519\tools\OpenCover.Console.exe" +SET CODECOVERAGEOUTPUT=coverage.xml + +dotnet.exe test %WORKINGDIR%..\Microsoft.SqlTools.ServiceLayer.TestDriver\project.json %DOTNETCONFIG%" + +SET SERVICECODECOVERAGE=FALSE + +"%WORKINGDIR%packages\OpenCover.4.6.519\tools\OpenCover.Console.exe" -mergeoutput -register:user -target:dotnet.exe -targetargs:"test %WORKINGDIR%..\Microsoft.SqlTools.ServiceLayer.TestDriver\project.json %DOTNETCONFIG%" -oldstyle -filter:"+[Microsoft.SqlTools.*]* -[xunit*]*" -output:coverage.xml -searchdirs:%WORKINGDIR%..\Microsoft.SqlTools.ServiceLayer.TestDriver\bin\Debug\netcoreapp1.0 + +"%WORKINGDIR%packages\OpenCover.4.6.519\tools\OpenCover.Console.exe" -mergeoutput -register:user -target:dotnet.exe -targetargs:"test %WORKINGDIR%..\Microsoft.SqlTools.ServiceLayer.Test\project.json %DOTNETCONFIG%" -oldstyle -filter:"+[Microsoft.SqlTools.*]* -[xunit*]*" -output:coverage.xml -searchdirs:%WORKINGDIR%..\Microsoft.SqlTools.ServiceLayer.Test\bin\Debug\netcoreapp1.0 + "%WORKINGDIR%packages\OpenCoverToCoberturaConverter.0.2.4.0\tools\OpenCoverToCoberturaConverter.exe" -input:coverage.xml -output:outputCobertura.xml -sources:%WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer "%WORKINGDIR%packages\ReportGenerator.2.4.5.0\tools\ReportGenerator.exe" "-reports:coverage.xml" "-targetdir:%WORKINGDIR%\reports" diff --git a/test/CodeCoverage/runintegration.bat b/test/CodeCoverage/runintegration.bat new file mode 100644 index 00000000..e52a5bb5 --- /dev/null +++ b/test/CodeCoverage/runintegration.bat @@ -0,0 +1,4 @@ +set DOTNETCONFIG=-c Integration + +cmd /c npm install +gulp diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/CreateTestDbAttribute.cs b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/CreateTestDbAttribute.cs new file mode 100644 index 00000000..2b705918 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/CreateTestDbAttribute.cs @@ -0,0 +1,39 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Xunit.Sdk; + +namespace Microsoft.SqlTools.ServiceLayer.PerfTests +{ + /// + /// The attribute for each test to create the test db before the test starts + /// + public class CreateTestDbAttribute : BeforeAfterTestAttribute + { + public CreateTestDbAttribute(TestServerType serverType) + { + ServerType = serverType; + } + + public CreateTestDbAttribute(int serverType) + { + ServerType = (TestServerType)serverType; + } + + public TestServerType ServerType { get; set; } + public override void Before(MethodInfo methodUnderTest) + { + Task task = Common.CreateTestDatabase(ServerType); + task.Wait(); + } + + public override void After(MethodInfo methodUnderTest) + { + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Microsoft.SqlTools.ServiceLayer.PerfTests.xproj b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Microsoft.SqlTools.ServiceLayer.PerfTests.xproj new file mode 100644 index 00000000..bfdf13be --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Microsoft.SqlTools.ServiceLayer.PerfTests.xproj @@ -0,0 +1,22 @@ + + + + 14.0 + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) + + + + 7e5968ab-83d7-4738-85a2-416a50f13d2f + Microsoft.SqlTools.ServiceLayer.PerfTests + .\obj + .\bin\ + v4.5.2 + + + 2.0 + + + + + + \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Program.cs b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Program.cs new file mode 100644 index 00000000..0d920084 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Program.cs @@ -0,0 +1,31 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Driver; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.PerfTests +{ + public class Program + { + internal static int Main(string[] args) + { + if (args.Length < 1) + { + Console.WriteLine("Microsoft.SqlTools.ServiceLayer.PerfTests.exe [tests]" + Environment.NewLine + + " [tests] is a space-separated list of tests to run." + Environment.NewLine + + " They are qualified within the Microsoft.SqlTools.ServiceLayer.TestDriver.PerfTests namespace" + Environment.NewLine + + $"Be sure to set the environment variable {ServiceTestDriver.ServiceHostEnvironmentVariable} to the full path of the sqltoolsservice executable."); + return 0; + } + + Logger.Initialize("testdriver", LogLevel.Verbose); + + return TestRunner.RunTests(args, "Microsoft.SqlTools.ServiceLayer.PerfTests.").Result; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Properties/AssemblyInfo.cs b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Properties/AssemblyInfo.cs new file mode 100644 index 00000000..825d07ad --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Properties/AssemblyInfo.cs @@ -0,0 +1,19 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.SqlTools.ServiceLayer.PerfTests")] +[assembly: AssemblyTrademark("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("7e5968ab-83d7-4738-85a2-416a50f13d2f")] diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/Common.cs new file mode 100644 index 00000000..7b8c8985 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/Common.cs @@ -0,0 +1,111 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Scripts; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Tests; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.PerfTests +{ + public class Common + { + public const string PerfTestDatabaseName = "SQLToolsCrossPlatPerfTestDb"; + public const string MasterDatabaseName = "master"; + + + internal static async Task ExecuteWithTimeout(TestTimer timer, int timeout, Func> repeatedCode, + TimeSpan? delay = null, [CallerMemberName] string testName = "") + { + while (true) + { + if (await repeatedCode()) + { + timer.EndAndPrint(testName); + break; + } + if (timer.TotalMilliSecondsUntilNow >= timeout) + { + Assert.True(false, $"{testName} timed out after {timeout} milliseconds"); + break; + } + if (delay.HasValue) + { + await Task.Delay(delay.Value); + } + } + } + + internal static async Task ConnectAsync(TestHelper testHelper, TestServerType serverType, string query, string ownerUri, string databaseName) + { + testHelper.WriteToFile(ownerUri, query); + + DidOpenTextDocumentNotification openParams = new DidOpenTextDocumentNotification + { + TextDocument = new TextDocumentItem + { + Uri = ownerUri, + LanguageId = "enu", + Version = 1, + Text = query + } + }; + + await testHelper.RequestOpenDocumentNotification(openParams); + + Thread.Sleep(500); + var connectParams = await testHelper.GetDatabaseConnectionAsync(serverType, databaseName); + + bool connected = await testHelper.Connect(ownerUri, connectParams); + Assert.True(connected, "Connection is successful"); + Console.WriteLine($"Connection to {connectParams.Connection.ServerName} is successful"); + + return connected; + } + + internal static async Task CalculateRunTime(Func> testToRun, bool printResult, [CallerMemberName] string testName = "") + { + TestTimer timer = new TestTimer() { PrintResult = printResult }; + T result = await testToRun(); + timer.EndAndPrint(testName); + + return result; + } + + /// + /// Create the test db if not already exists + /// + internal static async Task CreateTestDatabase(TestServerType serverType) + { + using (TestHelper testHelper = new TestHelper()) + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + string databaseName = Common.PerfTestDatabaseName; + string createDatabaseQuery = Scripts.CreateDatabaseQuery.Replace("#DatabaseName#", databaseName); + await RunQuery(testHelper, serverType, Common.MasterDatabaseName, createDatabaseQuery); + Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "Verified test database '{0}' is created", databaseName)); + await RunQuery(testHelper, serverType, databaseName, Scripts.CreateDatabaseObjectsQuery); + Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "Verified test database '{0}' SQL types are created", databaseName)); + } + } + + internal static async Task RunQuery(TestHelper testHelper, TestServerType serverType, string databaseName, string query) + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + await Common.ConnectAsync(testHelper, serverType, query, queryTempFile.FilePath, databaseName); + var queryResult = await Common.CalculateRunTime(() => testHelper.RunQuery(queryTempFile.FilePath, query, 50000), false); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/ConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/ConnectionTests.cs new file mode 100644 index 00000000..44cc0be7 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/ConnectionTests.cs @@ -0,0 +1,105 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Scripts; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Tests; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.PerfTests +{ + public class ConnectionTests + { + + [Fact] + [CreateTestDb(TestServerType.Azure)] + public async Task ConnectAzureTest() + { + TestServerType serverType = TestServerType.Azure; + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + const string query = Scripts.TestDbSimpleSelectQuery; + testHelper.WriteToFile(queryTempFile.FilePath, query); + + DidOpenTextDocumentNotification openParams = new DidOpenTextDocumentNotification + { + TextDocument = new TextDocumentItem + { + Uri = queryTempFile.FilePath, + LanguageId = "enu", + Version = 1, + Text = query + } + }; + + await testHelper.RequestOpenDocumentNotification(openParams); + + Thread.Sleep(500); + var connected = await Common.CalculateRunTime(async () => + { + var connectParams = await testHelper.GetDatabaseConnectionAsync(serverType, Common.PerfTestDatabaseName); + return await testHelper.Connect(queryTempFile.FilePath, connectParams); + }, true); + Assert.True(connected, "Connection was not successful"); + } + } + + [Fact] + [CreateTestDb(TestServerType.OnPrem)] + public async Task ConnectOnPremTest() + { + TestServerType serverType = TestServerType.OnPrem; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + const string query = Scripts.TestDbSimpleSelectQuery; + testHelper.WriteToFile(queryTempFile.FilePath, query); + + DidOpenTextDocumentNotification openParams = new DidOpenTextDocumentNotification + { + TextDocument = new TextDocumentItem + { + Uri = queryTempFile.FilePath, + LanguageId = "enu", + Version = 1, + Text = query + } + }; + + await testHelper.RequestOpenDocumentNotification(openParams); + + Thread.Sleep(500); + var connected = await Common.CalculateRunTime(async () => + { + var connectParams = await testHelper.GetDatabaseConnectionAsync(serverType, Common.PerfTestDatabaseName); + return await testHelper.Connect(queryTempFile.FilePath, connectParams); + }, true); + Assert.True(connected, "Connection was not successful"); + } + } + + [Fact] + [CreateTestDb(TestServerType.OnPrem)] + public async Task DisconnectTest() + { + TestServerType serverType = TestServerType.OnPrem; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + await Common.ConnectAsync(testHelper, serverType, Scripts.TestDbSimpleSelectQuery, queryTempFile.FilePath, Common.PerfTestDatabaseName); + Thread.Sleep(1000); + var connected = await Common.CalculateRunTime(() => testHelper.Disconnect(queryTempFile.FilePath), true); + Assert.True(connected); + } + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/IntellisenseTests.cs b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/IntellisenseTests.cs new file mode 100644 index 00000000..cffe764c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/IntellisenseTests.cs @@ -0,0 +1,277 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Scripts; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Tests; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.PerfTests +{ + public class IntellisenseTests + { + [Fact] + [CreateTestDb(TestServerType.OnPrem)] + public async Task HoverTestOnPrem() + { + TestServerType serverType = TestServerType.OnPrem; + using (TestHelper testHelper = new TestHelper()) + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + const string query = Scripts.TestDbSimpleSelectQuery; + await Common.ConnectAsync(testHelper, serverType, query, queryTempFile.FilePath, Common.PerfTestDatabaseName); + Hover hover = await Common.CalculateRunTime(() => testHelper.RequestHover(queryTempFile.FilePath, query, 0, Scripts.TestDbComplexSelectQueries.Length + 1), true); + Assert.NotNull(hover); + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + [CreateTestDb(TestServerType.OnPrem)] + public async Task SuggestionsTest() + { + TestServerType serverType = TestServerType.OnPrem; + using (TestHelper testHelper = new TestHelper()) + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + const string query = Scripts.TestDbSimpleSelectQuery; + await Common.ConnectAsync(testHelper, serverType, query, queryTempFile.FilePath, Common.PerfTestDatabaseName); + await ValidateCompletionResponse(testHelper, queryTempFile.FilePath, false, Common.PerfTestDatabaseName, true); + await ValidateCompletionResponse(testHelper, queryTempFile.FilePath, true, Common.PerfTestDatabaseName, false); + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + [CreateTestDb(TestServerType.OnPrem)] + public async Task DiagnosticsTests() + { + TestServerType serverType = TestServerType.OnPrem; + await Common.CreateTestDatabase(serverType); + + using (TestHelper testHelper = new TestHelper()) + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + await Common.ConnectAsync(testHelper, serverType, Scripts.TestDbSimpleSelectQuery, queryTempFile.FilePath, Common.PerfTestDatabaseName); + + Thread.Sleep(500); + var contentChanges = new TextDocumentChangeEvent[1]; + contentChanges[0] = new TextDocumentChangeEvent() + { + Range = new Range + { + Start = new Position + { + Line = 0, + Character = 5 + }, + End = new Position + { + Line = 0, + Character = 6 + } + }, + RangeLength = 1, + Text = "z" + }; + DidChangeTextDocumentParams changeParams = new DidChangeTextDocumentParams + { + ContentChanges = contentChanges, + TextDocument = new VersionedTextDocumentIdentifier + { + Version = 2, + Uri = queryTempFile.FilePath + } + }; + + TestTimer timer = new TestTimer() { PrintResult = true }; + await testHelper.RequestChangeTextDocumentNotification(changeParams); + await Common.ExecuteWithTimeout(timer, 60000, async () => + { + var completeEvent = await testHelper.Driver.WaitForEvent(PublishDiagnosticsNotification.Type, 15000); + return completeEvent?.Diagnostics != null && completeEvent.Diagnostics.Length > 0; + }); + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + [CreateTestDb(TestServerType.Azure)] + public async Task BindingCacheColdAzureSimpleQuery() + { + TestServerType serverType = TestServerType.Azure; + using (TestHelper testHelper = new TestHelper()) + { + await VerifyBindingLoadScenario(testHelper, serverType, Scripts.TestDbSimpleSelectQuery, false); + } + } + + [Fact] + [CreateTestDb(TestServerType.OnPrem)] + public async Task BindingCacheColdOnPremSimpleQuery() + { + TestServerType serverType = TestServerType.OnPrem; + using (TestHelper testHelper = new TestHelper()) + { + await VerifyBindingLoadScenario(testHelper, TestServerType.OnPrem, Scripts.TestDbSimpleSelectQuery, false); + } + + } + + [Fact] + [CreateTestDb(TestServerType.Azure)] + public async Task BindingCacheWarmAzureSimpleQuery() + { + TestServerType serverType = TestServerType.Azure; + using (TestHelper testHelper = new TestHelper()) + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + const string query = Scripts.TestDbSimpleSelectQuery; + await VerifyBindingLoadScenario(testHelper, serverType, query, true); + } + } + + [Fact] + [CreateTestDb(TestServerType.OnPrem)] + public async Task BindingCacheWarmOnPremSimpleQuery() + { + TestServerType serverType = TestServerType.OnPrem; + + using (TestHelper testHelper = new TestHelper()) + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + const string query = Scripts.TestDbSimpleSelectQuery; + await VerifyBindingLoadScenario(testHelper, serverType, query, true); + } + } + + [Fact] + [CreateTestDb(TestServerType.Azure)] + public async Task BindingCacheColdAzureComplexQuery() + { + TestServerType serverType = TestServerType.Azure; + + using (TestHelper testHelper = new TestHelper()) + { + await VerifyBindingLoadScenario(testHelper, serverType, Scripts.TestDbComplexSelectQueries,false); + } + } + + [Fact] + [CreateTestDb(TestServerType.Azure)] + public async Task BindingCacheColdOnPremComplexQuery() + { + TestServerType serverType = TestServerType.Azure; + using (TestHelper testHelper = new TestHelper()) + { + await VerifyBindingLoadScenario(testHelper, TestServerType.OnPrem, Scripts.TestDbComplexSelectQueries, false); + } + } + + [Fact] + [CreateTestDb(TestServerType.Azure)] + public async Task BindingCacheWarmAzureComplexQuery() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + string query = Scripts.TestDbComplexSelectQueries; + const TestServerType serverType = TestServerType.Azure; + await VerifyBindingLoadScenario(testHelper, serverType, query, true); + } + } + + [Fact] + [CreateTestDb(TestServerType.OnPrem)] + public async Task BindingCacheWarmOnPremComplexQuery() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + string query = Scripts.TestDbComplexSelectQueries; + const TestServerType serverType = TestServerType.OnPrem; + await VerifyBindingLoadScenario(testHelper, serverType, query, true); + } + } + + #region Private Helper Methods + + private async Task VerifyBindingLoadScenario( + TestHelper testHelper, + TestServerType serverType, + string query, + bool preLoad, + [CallerMemberName] string testName = "") + { + string databaseName = Common.PerfTestDatabaseName; + if (preLoad) + { + await VerifyCompletationLoaded(testHelper, serverType, Scripts.TestDbSimpleSelectQuery, + databaseName, printResult: false, testName: testName); + Console.WriteLine("Intellisense cache loaded."); + } + await VerifyCompletationLoaded(testHelper, serverType, query, databaseName, + printResult: true, testName: testName); + } + + private async Task VerifyCompletationLoaded( + TestHelper testHelper, + TestServerType serverType, + string query, + string databaseName, + bool printResult, + string testName) + { + using (SelfCleaningTempFile testTempFile = new SelfCleaningTempFile()) + { + testHelper.WriteToFile(testTempFile.FilePath, query); + await Common.ConnectAsync(testHelper, serverType, query, testTempFile.FilePath, databaseName); + await ValidateCompletionResponse(testHelper, testTempFile.FilePath, printResult, databaseName, + waitForIntelliSense: true, testName: testName); + await testHelper.Disconnect(testTempFile.FilePath); + } + } + + private static async Task ValidateCompletionResponse( + TestHelper testHelper, + string ownerUri, + bool printResult, + string databaseName, + bool waitForIntelliSense, + [CallerMemberName] string testName = "") + { + TestTimer timer = new TestTimer() { PrintResult = printResult }; + bool isReady = !waitForIntelliSense; + await Common.ExecuteWithTimeout(timer, 150000, async () => + { + if (isReady) + { + string query = Scripts.SelectQuery; + CompletionItem[] completions = await testHelper.RequestCompletion(ownerUri, query, 0, query.Length + 1); + return completions != null && completions.Any(x => x.Label == databaseName); + } + else + { + var completeEvent = await testHelper.Driver.WaitForEvent(IntelliSenseReadyNotification.Type, 100000); + isReady = completeEvent.OwnerUri == ownerUri; + if (isReady) + { + Console.WriteLine("IntelliSense cache is loaded."); + } + return false; + } + }, testName: testName); + } + + #endregion + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/QueryExecutionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/QueryExecutionTests.cs new file mode 100644 index 00000000..8d41cd1e --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/QueryExecutionTests.cs @@ -0,0 +1,100 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Scripts; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Tests; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.PerfTests +{ + public class QueryExecutionTests + { + [Fact] + public async Task QueryResultSummaryOnPremTest() + { + TestServerType serverType = TestServerType.OnPrem; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + const string query = Scripts.MasterBasicQuery; + + await Common.ConnectAsync(testHelper, serverType, query, queryTempFile.FilePath, Common.MasterDatabaseName); + var queryResult = await Common.CalculateRunTime(() => testHelper.RunQuery(queryTempFile.FilePath, query), true); + + Assert.NotNull(queryResult); + Assert.True(queryResult.BatchSummaries.Any(x => x.ResultSetSummaries.Any(r => r.RowCount > 0))); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task QueryResultFirstOnPremTest() + { + TestServerType serverType = TestServerType.OnPrem; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + const string query = Scripts.MasterBasicQuery; + + await Common.ConnectAsync(testHelper, serverType, query, queryTempFile.FilePath, Common.MasterDatabaseName); + + var queryResult = await Common.CalculateRunTime(async () => + { + await testHelper.RunQuery(queryTempFile.FilePath, query); + return await testHelper.ExecuteSubset(queryTempFile.FilePath, 0, 0, 0, 100); + }, true); + + Assert.NotNull(queryResult); + Assert.NotNull(queryResult.ResultSubset); + Assert.True(queryResult.ResultSubset.Rows.Any()); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + [CreateTestDb(TestServerType.OnPrem)] + public async Task CancelQueryOnPremTest() + { + TestServerType serverType = TestServerType.OnPrem; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + await Common.ConnectAsync(testHelper, serverType, Scripts.DelayQuery, queryTempFile.FilePath, Common.PerfTestDatabaseName); + var queryParams = new QueryExecuteParams + { + OwnerUri = queryTempFile.FilePath, + QuerySelection = null + }; + + var result = await testHelper.Driver.SendRequest(QueryExecuteRequest.Type, queryParams); + if (result != null && string.IsNullOrEmpty(result.Messages)) + { + TestTimer timer = new TestTimer() { PrintResult = true }; + await Common.ExecuteWithTimeout(timer, 100000, async () => + { + var cancelQueryResult = await testHelper.CancelQuery(queryTempFile.FilePath); + return true; + }, TimeSpan.FromMilliseconds(10)); + } + else + { + Assert.True(false, "Failed to run the query"); + } + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/SaveResultsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/SaveResultsTests.cs new file mode 100644 index 00000000..298d325e --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/Tests/SaveResultsTests.cs @@ -0,0 +1,55 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Scripts; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Tests; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.PerfTests +{ + public class SaveResultsTests + { + [Fact] + public async Task TestSaveResultsToCsvTest() + { + TestServerType serverType = TestServerType.OnPrem; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (SelfCleaningTempFile outputTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + const string query = Scripts.MasterBasicQuery; + + // Execute a query + await Common.ConnectAsync(testHelper, serverType, query, queryTempFile.FilePath, Common.MasterDatabaseName); + await testHelper.RunQuery(queryTempFile.FilePath, query); + await Common.CalculateRunTime(() => testHelper.SaveAsCsv(queryTempFile.FilePath, outputTempFile.FilePath, 0, 0), true); + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task TestSaveResultsToJsonTest() + { + TestServerType serverType = TestServerType.OnPrem; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (SelfCleaningTempFile outputTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + const string query = Scripts.MasterBasicQuery; + + // Execute a query + await Common.ConnectAsync(testHelper, serverType, query, queryTempFile.FilePath, Common.MasterDatabaseName); + await testHelper.RunQuery(queryTempFile.FilePath, query); + await Common.CalculateRunTime(() => testHelper.SaveAsJson(queryTempFile.FilePath, outputTempFile.FilePath, 0, 0), true); + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.PerfTests/project.json b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/project.json new file mode 100644 index 00000000..4eb7c6fd --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.PerfTests/project.json @@ -0,0 +1,35 @@ +{ + "name": "Microsoft.SqlTools.ServiceLayer.PerfTests", + "version": "1.0.0-*", + "buildOptions": { + "debugType": "portable", + "emitEntryPoint": true + }, + "dependencies": { + "xunit": "2.1.0", + "dotnet-test-xunit": "1.0.0-rc2-192208-24", + "Microsoft.SqlTools.ServiceLayer": { + "target": "project" + }, + "Microsoft.SqlTools.ServiceLayer.TestDriver": "1.0.0-*" + }, + + "testRunner": "xunit", + + "frameworks": { + "netcoreapp1.0": { + "dependencies": { + "Microsoft.NETCore.App": { + "version": "1.0.0" + } + }, + "imports": [ + "dotnet5.4", + "portable-net451+win8" + ] + } + }, + "runtimes": { + "win7-x64": {} + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/AssemblyInfo.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/AssemblyInfo.cs new file mode 100644 index 00000000..70421c18 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/AssemblyInfo.cs @@ -0,0 +1,8 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Xunit; + +[assembly: CollectionBehavior(DisableTestParallelization = true)] diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Completion/AutoCompletionResultTest.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Completion/AutoCompletionResultTest.cs new file mode 100644 index 00000000..0fde1329 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Completion/AutoCompletionResultTest.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + + +using System.Threading; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Completion +{ + public class AutoCompletionResultTest + { + [Fact] + public void MetricsShouldGetSortedGivenUnSortedArray() + { + AutoCompletionResult result = new AutoCompletionResult(); + int duration = 2000; + Thread.Sleep(duration); + result.CompleteResult(new CompletionItem[] { }); + Assert.True(result.Duration >= duration); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Completion/ScriptDocumentInfoTest.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Completion/ScriptDocumentInfoTest.cs new file mode 100644 index 00000000..cc35edfb --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Completion/ScriptDocumentInfoTest.cs @@ -0,0 +1,45 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Completion +{ + public class ScriptDocumentInfoTest + { + [Fact] + public void MetricsShouldGetSortedGivenUnSortedArray() + { + TextDocumentPosition doc = new TextDocumentPosition() + { + TextDocument = new TextDocumentIdentifier + { + Uri = "script file" + }, + Position = new Position() + { + Line = 1, + Character = 14 + } + }; + ScriptFile scriptFile = new ScriptFile() + { + Contents = "Select * from sys.all_objects" + }; + + ScriptParseInfo scriptParseInfo = new ScriptParseInfo(); + ScriptDocumentInfo docInfo = new ScriptDocumentInfo(doc, scriptFile, scriptParseInfo); + + Assert.Equal(docInfo.StartLine, 1); + Assert.Equal(docInfo.ParserLine, 2); + Assert.Equal(docInfo.StartColumn, 44); + Assert.Equal(docInfo.EndColumn, 14); + Assert.Equal(docInfo.ParserColumn, 15); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs index 46ccbb94..3bc7b14c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs @@ -809,6 +809,36 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection Assert.True(callbackInvoked); } + /// + /// Test ConnectionSummaryComparer + /// + [Fact] + public void TestConnectionSummaryComparer() + { + var summary1 = new ConnectionSummary() + { + ServerName = "localhost", + DatabaseName = "master", + UserName = "user" + }; + + var summary2 = new ConnectionSummary() + { + ServerName = "localhost", + DatabaseName = "master", + UserName = "user" + }; + + var comparer = new ConnectionSummaryComparer(); + Assert.True(comparer.Equals(summary1, summary2)); + + summary2.DatabaseName = "tempdb"; + Assert.False(comparer.Equals(summary1, summary2)); + Assert.False(comparer.Equals(null, summary2)); + + Assert.False(summary1.GetHashCode() == summary2.GetHashCode()); + } + /// /// Verify when a connection is created that the URI -> Connection mapping is created in the connection service. /// @@ -861,5 +891,74 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection } }); } + + /// + /// Test that the connection complete notification type can be created. + /// + [Fact] + public void TestConnectionCompleteNotificationIsCreated() + { + Assert.NotNull(ConnectionCompleteNotification.Type); + } + + /// + /// Test that the connection summary comparer creates a hash code correctly + /// + [Theory] + [InlineData(true, null, null ,null)] + [InlineData(false, null, null, null)] + [InlineData(false, null, null, "sa")] + [InlineData(false, null, "test", null)] + [InlineData(false, null, "test", "sa")] + [InlineData(false, "server", null, null)] + [InlineData(false, "server", null, "sa")] + [InlineData(false, "server", "test", null)] + [InlineData(false, "server", "test", "sa")] + public void TestConnectionSummaryComparerHashCode(bool objectNull, string serverName, string databaseName, string userName) + { + // Given a connection summary and comparer object + ConnectionSummary summary = null; + if (!objectNull) + { + summary = new ConnectionSummary() + { + ServerName = serverName, + DatabaseName = databaseName, + UserName = userName + }; + } + ConnectionSummaryComparer comparer = new ConnectionSummaryComparer(); + + // If I compute a hash code + int hashCode = comparer.GetHashCode(summary); + if (summary == null || (serverName == null && databaseName == null && userName == null)) + { + // Then I expect it to be 31 for a null summary + Assert.Equal(31, hashCode); + } + else + { + // And not 31 otherwise + Assert.NotEqual(31, hashCode); + } + } + + [Fact] + public void ConnectParamsAreInvalidIfConnectionIsNull() + { + // Given connection parameters where the connection property is null + ConnectParams parameters = new ConnectParams(); + parameters.OwnerUri = "my/sql/file.sql"; + parameters.Connection = null; + + string errorMessage; + + // If I check if the parameters are valid + Assert.False(parameters.IsValid(out errorMessage)); + + // Then I expect an error message + Assert.NotNull(errorMessage); + Assert.NotEmpty(errorMessage); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ReliableConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ReliableConnectionTests.cs index 24a2f39c..95194a55 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ReliableConnectionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ReliableConnectionTests.cs @@ -6,12 +6,20 @@ #if LIVE_CONNECTION_TESTS using System; +using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Data.SqlClient; +using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Microsoft.SqlTools.Test.Utility; using Xunit; +using static Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection.ReliableConnectionHelper; +using static Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection.RetryPolicy; +using static Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection.RetryPolicy.TimeBasedRetryPolicy; +using static Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection.SqlSchemaModelErrorCodes; namespace Microsoft.SqlTools.ServiceLayer.Test.Connection { @@ -21,6 +29,134 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// public class ReliableConnectionTests { + internal class TestDataTransferErrorDetectionStrategy : DataTransferErrorDetectionStrategy + { + public bool InvokeCanRetrySqlException(SqlException exception) + { + return CanRetrySqlException(exception); + } + } + internal class TestSqlAzureTemporaryAndIgnorableErrorDetectionStrategy : SqlAzureTemporaryAndIgnorableErrorDetectionStrategy + { + public TestSqlAzureTemporaryAndIgnorableErrorDetectionStrategy() + : base (new int[] { 100 }) + { + } + + public bool InvokeCanRetrySqlException(SqlException exception) + { + return CanRetrySqlException(exception); + } + + public bool InvokeShouldIgnoreSqlException(SqlException exception) + { + return ShouldIgnoreSqlException(exception); + } + } + + internal class TestFixedDelayPolicy : FixedDelayPolicy + { + public TestFixedDelayPolicy( + IErrorDetectionStrategy strategy, + int maxRetryCount, + TimeSpan intervalBetweenRetries) + : base(strategy, + maxRetryCount, + intervalBetweenRetries) + { + } + + public bool InvokeShouldRetryImpl(RetryState retryStateObj) + { + return ShouldRetryImpl(retryStateObj); + } + } + + internal class TestProgressiveRetryPolicy : ProgressiveRetryPolicy + { + public TestProgressiveRetryPolicy( + IErrorDetectionStrategy strategy, + int maxRetryCount, + TimeSpan initialInterval, + TimeSpan increment) + : base(strategy, + maxRetryCount, + initialInterval, + increment) + { + } + + public bool InvokeShouldRetryImpl(RetryState retryStateObj) + { + return ShouldRetryImpl(retryStateObj); + } + } + + internal class TestTimeBasedRetryPolicy : TimeBasedRetryPolicy + { + public TestTimeBasedRetryPolicy( + IErrorDetectionStrategy strategy, + TimeSpan minTotalRetryTimeLimit, + TimeSpan maxTotalRetryTimeLimit, + double totalRetryTimeLimitRate, + TimeSpan minInterval, + TimeSpan maxInterval, + double intervalFactor) + : base( + strategy, + minTotalRetryTimeLimit, + maxTotalRetryTimeLimit, + totalRetryTimeLimitRate, + minInterval, + maxInterval, + intervalFactor) + { + } + + public bool InvokeShouldRetryImpl(RetryState retryStateObj) + { + return ShouldRetryImpl(retryStateObj); + } + } + + [Fact] + public void FixedDelayPolicyTest() + { + TestFixedDelayPolicy policy = new TestFixedDelayPolicy( + strategy: new NetworkConnectivityErrorDetectionStrategy(), + maxRetryCount: 3, + intervalBetweenRetries: TimeSpan.FromMilliseconds(100)); + bool shouldRety = policy.InvokeShouldRetryImpl(new RetryStateEx()); + Assert.True(shouldRety); + } + + [Fact] + public void ProgressiveRetryPolicyTest() + { + TestProgressiveRetryPolicy policy = new TestProgressiveRetryPolicy( + strategy: new NetworkConnectivityErrorDetectionStrategy(), + maxRetryCount: 3, + initialInterval: TimeSpan.FromMilliseconds(100), + increment: TimeSpan.FromMilliseconds(100)); + bool shouldRety = policy.InvokeShouldRetryImpl(new RetryStateEx()); + Assert.True(shouldRety); + } + + [Fact] + public void TimeBasedRetryPolicyTest() + { + TestTimeBasedRetryPolicy policy = new TestTimeBasedRetryPolicy( + strategy: new NetworkConnectivityErrorDetectionStrategy(), + minTotalRetryTimeLimit: TimeSpan.FromMilliseconds(100), + maxTotalRetryTimeLimit: TimeSpan.FromMilliseconds(100), + totalRetryTimeLimitRate: 100, + minInterval: TimeSpan.FromMilliseconds(100), + maxInterval: TimeSpan.FromMilliseconds(100), + intervalFactor: 1); + bool shouldRety = policy.InvokeShouldRetryImpl(new RetryStateEx()); + Assert.True(shouldRety); + } + /// /// Environment variable that stores the name of the test server hosting the SQL Server instance. /// @@ -338,6 +474,269 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection Assert.NotEmpty(info.ServerVersion); }); } + + + /// + /// Validate ambient static settings + /// + [Fact] + public void AmbientSettingsStaticPropertiesTest() + { + var defaultSettings = AmbientSettings.DefaultSettings; + Assert.NotNull(defaultSettings); + var masterReferenceFilePath = AmbientSettings.MasterReferenceFilePath; + var maxDataReaderDegreeOfParallelism = AmbientSettings.MaxDataReaderDegreeOfParallelism; + var tableProgressUpdateInterval = AmbientSettings.TableProgressUpdateInterval; + var traceRowCountFailure = AmbientSettings.TraceRowCountFailure; + var useOfflineDataReader = AmbientSettings.UseOfflineDataReader; + var streamBackingStoreForOfflineDataReading = AmbientSettings.StreamBackingStoreForOfflineDataReading; + var disableIndexesForDataPhase = AmbientSettings.DisableIndexesForDataPhase; + var reliableDdlEnabled = AmbientSettings.ReliableDdlEnabled; + var importModelDatabase = AmbientSettings.ImportModelDatabase; + var supportAlwaysEncrypted = AmbientSettings.SupportAlwaysEncrypted; + var alwaysEncryptedWizardMigration = AmbientSettings.AlwaysEncryptedWizardMigration; + var skipObjectTypeBlocking =AmbientSettings.SkipObjectTypeBlocking; + var doNotSerializeQueryStoreSettings = AmbientSettings.DoNotSerializeQueryStoreSettings; + var lockTimeoutMilliSeconds = AmbientSettings.LockTimeoutMilliSeconds; + var queryTimeoutSeconds = AmbientSettings.QueryTimeoutSeconds; + var longRunningQueryTimeoutSeconds = AmbientSettings.LongRunningQueryTimeoutSeconds; + var alwaysRetryOnTransientFailure = AmbientSettings.AlwaysRetryOnTransientFailure; + var connectionRetryMessageHandler = AmbientSettings.ConnectionRetryMessageHandler; + + using (var settingsContext = AmbientSettings.CreateSettingsContext()) + { + var settings = settingsContext.Settings; + Assert.NotNull(settings); + } + } + + /// + /// Validate ambient settings populate + /// + [Fact] + public void AmbientSettingsPopulateTest() + { + var data = new AmbientSettings.AmbientData(); + + var masterReferenceFilePath = data.MasterReferenceFilePath; + data.MasterReferenceFilePath = masterReferenceFilePath; + var lockTimeoutMilliSeconds = data.LockTimeoutMilliSeconds; + data.LockTimeoutMilliSeconds = lockTimeoutMilliSeconds; + var queryTimeoutSeconds = data.QueryTimeoutSeconds; + data.QueryTimeoutSeconds = queryTimeoutSeconds; + var longRunningQueryTimeoutSeconds = data.LongRunningQueryTimeoutSeconds; + data.LongRunningQueryTimeoutSeconds = longRunningQueryTimeoutSeconds; + var alwaysRetryOnTransientFailure = data.AlwaysRetryOnTransientFailure; + data.AlwaysRetryOnTransientFailure = alwaysRetryOnTransientFailure; + var connectionRetryMessageHandler = data.ConnectionRetryMessageHandler; + data.ConnectionRetryMessageHandler = connectionRetryMessageHandler; + var traceRowCountFailure = data.TraceRowCountFailure; + data.TraceRowCountFailure = traceRowCountFailure; + var tableProgressUpdateInterval = data.TableProgressUpdateInterval; + data.TableProgressUpdateInterval = tableProgressUpdateInterval; + var useOfflineDataReader = data.UseOfflineDataReader; + data.UseOfflineDataReader = useOfflineDataReader; + var streamBackingStoreForOfflineDataReading = data.StreamBackingStoreForOfflineDataReading; + data.StreamBackingStoreForOfflineDataReading = streamBackingStoreForOfflineDataReading; + var disableIndexesForDataPhase = data.DisableIndexesForDataPhase; + data.DisableIndexesForDataPhase = disableIndexesForDataPhase; + var reliableDdlEnabled = data.ReliableDdlEnabled; + data.ReliableDdlEnabled = reliableDdlEnabled; + var importModelDatabase = data.ImportModelDatabase; + data.ImportModelDatabase = importModelDatabase; + var supportAlwaysEncrypted = data.SupportAlwaysEncrypted; + data.SupportAlwaysEncrypted = supportAlwaysEncrypted; + var alwaysEncryptedWizardMigration = data.AlwaysEncryptedWizardMigration; + data.AlwaysEncryptedWizardMigration = alwaysEncryptedWizardMigration; + var skipObjectTypeBlocking = data.SkipObjectTypeBlocking; + data.SkipObjectTypeBlocking = skipObjectTypeBlocking; + var doNotSerializeQueryStoreSettings = data.DoNotSerializeQueryStoreSettings; + data.DoNotSerializeQueryStoreSettings = doNotSerializeQueryStoreSettings; + + Dictionary settings = new Dictionary(); + settings.Add("LockTimeoutMilliSeconds", 10000); + data.PopulateSettings(settings); + settings["LockTimeoutMilliSeconds"] = 15000; + data.PopulateSettings(settings); + data.TraceSettings(); + } + + [Fact] + public void RetryPolicyFactoryTest() + { + Assert.NotNull(RetryPolicyFactory.NoRetryPolicy); + Assert.NotNull(RetryPolicyFactory.PrimaryKeyViolationRetryPolicy); + + RetryPolicy noRetyPolicy = RetryPolicyFactory.CreateDefaultSchemaCommandRetryPolicy(useRetry: false); + + var retryState = new RetryStateEx(); + retryState.LastError = new Exception(); + RetryPolicyFactory.DataConnectionFailureRetry(retryState); + RetryPolicyFactory.CommandFailureRetry(retryState, "command"); + RetryPolicyFactory.CommandFailureIgnore(retryState, "command"); + RetryPolicyFactory.ElementCommandFailureIgnore(retryState); + RetryPolicyFactory.ElementCommandFailureRetry(retryState); + RetryPolicyFactory.CreateDatabaseCommandFailureIgnore(retryState); + RetryPolicyFactory.CreateDatabaseCommandFailureRetry(retryState); + RetryPolicyFactory.CommandFailureIgnore(retryState); + RetryPolicyFactory.CommandFailureRetry(retryState); + + var transientPolicy = new RetryPolicyFactory.TransientErrorIgnoreStrategy(); + Assert.False(transientPolicy.CanRetry(new Exception())); + Assert.False(transientPolicy.ShouldIgnoreError(new Exception())); + } + + [Fact] + public void ReliableConnectionHelperTest() + { + ScriptFile scriptFile; + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfo(out scriptFile); + + Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connInfo.SqlConnection)); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(); + Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder)); + ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder()); + + Assert.NotNull(ReliableConnectionHelper.GetServerName(connInfo.SqlConnection)); + Assert.NotNull(ReliableConnectionHelper.ReadServerVersion(connInfo.SqlConnection)); + + Assert.NotNull(ReliableConnectionHelper.GetAsSqlConnection(connInfo.SqlConnection)); + + ServerInfo info = ReliableConnectionHelper.GetServerVersion(connInfo.SqlConnection); + Assert.NotNull(ReliableConnectionHelper.IsVersionGreaterThan2012RTM(info)); + } + + [Fact] + public void DataSchemaErrorTests() + { + var error = new DataSchemaError(); + Assert.NotNull(error); + var isOnDisplay = error.IsOnDisplay; + var isBuildErrorCodeDefined = error.IsBuildErrorCodeDefined; + var buildErrorCode = error.BuildErrorCode; + var isPriorityEditable = error.IsPriorityEditable; + var message = error.Message; + var exception = error.Exception; + var prefix = error.Prefix; + var column = error.Column; + var line =error.Line; + var errorCode =error.ErrorCode; + var severity = error.Severity; + var document = error.Document; + + Assert.NotNull(error.ToString()); + Assert.NotNull(DataSchemaError.FormatErrorCode("ex", 1)); + } + + [Fact] + public void InitReliableSqlConnectionTest() + { + ScriptFile scriptFile; + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfo(out scriptFile); + + var connection = connInfo.SqlConnection as ReliableSqlConnection; + var command = new ReliableSqlConnection.ReliableSqlCommand(connection); + Assert.NotNull(command.Connection); + + var retryPolicy = connection.CommandRetryPolicy; + connection.CommandRetryPolicy = retryPolicy; + Assert.True(connection.CommandRetryPolicy == retryPolicy); + connection.ChangeDatabase("master"); + Assert.True(connection.ConnectionTimeout > 0); + connection.ClearPool(); + } + + [Fact] + public void ThrottlingReasonTests() + { + var reason = RetryPolicy.ThrottlingReason.Unknown; + Assert.NotNull(reason.ThrottlingMode); + Assert.NotNull(reason.ThrottledResources); + + try + { + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(); + builder.InitialCatalog = "master"; + builder.IntegratedSecurity = false; + builder.DataSource = "localhost"; + builder.UserID = "invalid"; + builder.Password = ".."; + SqlConnection conn = new SqlConnection(builder.ToString()); + conn.Open(); + } + catch (SqlException sqlException) + { + var exceptionReason = RetryPolicy.ThrottlingReason.FromException(sqlException); + Assert.NotNull(exceptionReason); + + var errorReason = RetryPolicy.ThrottlingReason.FromError(sqlException.Errors[0]); + Assert.NotNull(errorReason); + + var detectionStrategy = new TestDataTransferErrorDetectionStrategy(); + Assert.True(detectionStrategy.InvokeCanRetrySqlException(sqlException)); + Assert.True(detectionStrategy.CanRetry(new InvalidOperationException())); + Assert.False(detectionStrategy.ShouldIgnoreError(new InvalidOperationException())); + + var detectionStrategy2 = new TestSqlAzureTemporaryAndIgnorableErrorDetectionStrategy(); + Assert.NotNull(detectionStrategy2.InvokeCanRetrySqlException(sqlException)); + Assert.NotNull(detectionStrategy2.InvokeShouldIgnoreSqlException(sqlException)); + } + + var unknownCodeReason = RetryPolicy.ThrottlingReason.FromReasonCode(-1); + var codeReason = RetryPolicy.ThrottlingReason.FromReasonCode(2601); + Assert.NotNull(codeReason); + + Assert.NotNull(codeReason.IsThrottledOnDataSpace); + Assert.NotNull(codeReason.IsThrottledOnLogSpace); + Assert.NotNull(codeReason.IsThrottledOnLogWrite); + Assert.NotNull(codeReason.IsThrottledOnDataRead); + Assert.NotNull(codeReason.IsThrottledOnCPU); + Assert.NotNull(codeReason.IsThrottledOnDatabaseSize); + Assert.NotNull(codeReason.IsThrottledOnWorkerThreads); + Assert.NotNull(codeReason.IsUnknown); + Assert.NotNull(codeReason.ToString()); + } + + [Fact] + public void RetryErrorsTest() + { + var sqlServerRetryError = new SqlServerRetryError( + "test message", new Exception(), + 1, 200, ErrorSeverity.Warning); + Assert.True(sqlServerRetryError.RetryCount == 1); + Assert.NotNull(SqlServerRetryError.FormatRetryMessage(1, TimeSpan.FromSeconds(15), new Exception())); + Assert.NotNull(SqlServerRetryError.FormatIgnoreMessage(1, new Exception())); + + var sqlServerError1 = new SqlServerError("test message", "document", ErrorSeverity.Warning); + var sqlServerError2 = new SqlServerError("test message", "document", 1, ErrorSeverity.Warning); + var sqlServerError3 = new SqlServerError(new Exception(), "document",1, ErrorSeverity.Warning); + var sqlServerError4 = new SqlServerError("test message", "document", 1, 2, ErrorSeverity.Warning); + var sqlServerError5 = new SqlServerError(new Exception(), "document", 1, 2, 3, ErrorSeverity.Warning); + var sqlServerError6 = new SqlServerError("test message", "document", 1, 2, 3, ErrorSeverity.Warning); + var sqlServerError7 = new SqlServerError("test message", new Exception(), "document", 1, 2, 3, ErrorSeverity.Warning); + + Assert.True(SqlSchemaModelErrorCodes.IsParseErrorCode(46010)); + Assert.True(SqlSchemaModelErrorCodes.IsInterpretationErrorCode(Interpretation.InterpretationBaseCode+ 1)); + Assert.True(SqlSchemaModelErrorCodes.IsStatementFilterError(StatementFilter.StatementFilterBaseCode + 1)); + } + + [Fact] + public void RetryCallbackEventArgsTest() + { + var exception = new Exception(); + var timespan = TimeSpan.FromMinutes(1); + + // Given a RetryCallbackEventArgs object with certain parameters + var args = new RetryCallbackEventArgs(5, exception, timespan); + + // If I check the properties on the object + // Then I expect the values to be the same as the values I passed into the constructor + Assert.Equal(5, args.RetryCount); + Assert.Equal(exception, args.Exception); + Assert.Equal(timespan, args.Delay); + } } } + #endif // LIVE_CONNECTION_TESTS diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/CredentialServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/CredentialServiceTests.cs index 7adbdebe..d174c48e 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/CredentialServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/CredentialServiceTests.cs @@ -24,7 +24,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// public class CredentialServiceTests : IDisposable { - private static readonly LinuxCredentialStore.StoreConfig config = new LinuxCredentialStore.StoreConfig() + private static readonly StoreConfig config = new StoreConfig() { CredentialFolder = ".testsecrets", CredentialFile = "sqltestsecrets.json", @@ -61,6 +61,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection { credStore.DeletePassword(credentialId); credStore.DeletePassword(otherCredId); + +#if !WINDOWS_ONLY_BUILD if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { string credsFolder = ((LinuxCredentialStore)credStore).CredentialFolderPath; @@ -69,6 +71,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection Directory.Delete(credsFolder, true); } } +#endif } [Fact] diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Linux/LinuxInteropTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Linux/LinuxInteropTests.cs index 1dcff8e6..6d4a234c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Linux/LinuxInteropTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Linux/LinuxInteropTests.cs @@ -15,23 +15,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Credentials [Fact] public void GetEUidReturnsInt() { +#if !WINDOWS_ONLY_BUILD TestUtils.RunIfLinux(() => { Assert.NotNull(Interop.Sys.GetEUid()); }); +#endif } [Fact] public void GetHomeDirectoryFromPwFindsHomeDir() { - +#if !WINDOWS_ONLY_BUILD TestUtils.RunIfLinux(() => { string userDir = LinuxCredentialStore.GetHomeDirectoryFromPw(); Assert.StartsWith("/", userDir); }); +#endif } - } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/AutocompleteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/AutocompleteTests.cs index a44f8994..a9e346cd 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/AutocompleteTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/AutocompleteTests.cs @@ -120,20 +120,5 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices // verify that send result was called with a completion array requestContext.Verify(m => m.SendResult(It.IsAny()), Times.Once()); } - - /// - /// Test the service initialization code path and verify nothing throws - /// - [Fact] - public async void UpdateLanguageServiceOnConnection() - { - InitializeTestObjects(); - - AutoCompleteHelper.WorkspaceServiceInstance = workspaceService.Object; - - ConnectionInfo connInfo = TestObjects.GetTestConnectionInfo(); - - await LanguageService.Instance.UpdateLanguageServiceOnConnection(connInfo); - } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/CompletionServiceTest.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/CompletionServiceTest.cs new file mode 100644 index 00000000..50c2bb06 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/CompletionServiceTest.cs @@ -0,0 +1,93 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServer +{ + public class CompletionServiceTest + { + [Fact] + public void CompletionItemsShouldCreatedUsingSqlParserIfTheProcessDoesNotTimeout() + { + ConnectedBindingQueue bindingQueue = new ConnectedBindingQueue(); + ScriptDocumentInfo docInfo = CreateScriptDocumentInfo(); + CompletionService completionService = new CompletionService(bindingQueue); + ConnectionInfo connectionInfo = new ConnectionInfo(null, null, null); + bool useLowerCaseSuggestions = true; + CompletionItem[] defaultCompletionList = AutoCompleteHelper.GetDefaultCompletionItems(docInfo, useLowerCaseSuggestions); + + List declarations = new List(); + + var sqlParserWrapper = new Mock(); + sqlParserWrapper.Setup(x => x.FindCompletions(docInfo.ScriptParseInfo.ParseResult, docInfo.ParserLine, docInfo.ParserColumn, + It.IsAny())).Returns(declarations); + completionService.SqlParserWrapper = sqlParserWrapper.Object; + + AutoCompletionResult result = completionService.CreateCompletions(connectionInfo, docInfo, useLowerCaseSuggestions); + Assert.NotNull(result); + Assert.NotEqual(result.CompletionItems == null ? 0 : result.CompletionItems.Count(), defaultCompletionList.Count()); + } + + [Fact] + public void CompletionItemsShouldCreatedUsingDefaultListIfTheSqlParserProcessTimesout() + { + ConnectedBindingQueue bindingQueue = new ConnectedBindingQueue(); + ScriptDocumentInfo docInfo = CreateScriptDocumentInfo(); + CompletionService completionService = new CompletionService(bindingQueue); + ConnectionInfo connectionInfo = new ConnectionInfo(null, null, null); + bool useLowerCaseSuggestions = true; + List declarations = new List(); + CompletionItem[] defaultCompletionList = AutoCompleteHelper.GetDefaultCompletionItems(docInfo, useLowerCaseSuggestions); + + var sqlParserWrapper = new Mock(); + sqlParserWrapper.Setup(x => x.FindCompletions(docInfo.ScriptParseInfo.ParseResult, docInfo.ParserLine, docInfo.ParserColumn, + It.IsAny())).Callback(() => Thread.Sleep(LanguageService.BindingTimeout + 100)).Returns(declarations); + completionService.SqlParserWrapper = sqlParserWrapper.Object; + + AutoCompletionResult result = completionService.CreateCompletions(connectionInfo, docInfo, useLowerCaseSuggestions); + Assert.NotNull(result); + Assert.Equal(result.CompletionItems.Count(), defaultCompletionList.Count()); + Thread.Sleep(3000); + Assert.True(connectionInfo.IntellisenseMetrics.Quantile.Any()); + } + + private ScriptDocumentInfo CreateScriptDocumentInfo() + { + TextDocumentPosition doc = new TextDocumentPosition() + { + TextDocument = new TextDocumentIdentifier + { + Uri = "script file" + }, + Position = new Position() + { + Line = 1, + Character = 14 + } + }; + ScriptFile scriptFile = new ScriptFile() + { + Contents = "Select * from sys.all_objects" + }; + + ScriptParseInfo scriptParseInfo = new ScriptParseInfo() { IsConnected = true }; + ScriptDocumentInfo docInfo = new ScriptDocumentInfo(doc, scriptFile, scriptParseInfo); + + return docInfo; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/InteractionMetricsTest.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/InteractionMetricsTest.cs new file mode 100644 index 00000000..0db21b14 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/InteractionMetricsTest.cs @@ -0,0 +1,84 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServer +{ + public class InteractionMetricsTest + { + [Fact] + public void MetricsShouldGetSortedGivenUnSortedArray() + { + int[] metrics = new int[] { 4, 8, 1, 11, 3 }; + int[] expected = new int[] { 1, 3, 4, 8, 11 }; + InteractionMetrics interactionMetrics = new InteractionMetrics(metrics); + + Assert.Equal(interactionMetrics.Metrics, expected); + } + + [Fact] + public void MetricsShouldThrowExceptionGivenNullInput() + { + int[] metrics = null; + Assert.Throws(() => new InteractionMetrics(metrics)); + } + + [Fact] + public void MetricsShouldThrowExceptionGivenEmptyInput() + { + int[] metrics = new int[] { }; + Assert.Throws(() => new InteractionMetrics(metrics)); + } + + [Fact] + public void MetricsShouldNotChangeGivenSortedArray() + { + int[] metrics = new int[] { 1, 3, 4, 8, 11 }; + int[] expected = new int[] { 1, 3, 4, 8, 11 }; + InteractionMetrics interactionMetrics = new InteractionMetrics(metrics); + + Assert.Equal(interactionMetrics.Metrics, expected); + } + + [Fact] + public void MetricsShouldNotChangeGivenArrayWithOneItem() + { + int[] metrics = new int[] { 11 }; + int[] expected = new int[] { 11 }; + InteractionMetrics interactionMetrics = new InteractionMetrics(metrics); + + Assert.Equal(interactionMetrics.Metrics, expected); + } + + [Fact] + public void MetricsCalculateQuantileCorrectlyGivenSeveralUpdates() + { + int[] metrics = new int[] { 50, 100, 300, 500, 1000, 2000 }; + Func updateValueFactory = (k, current) => current + 1; + InteractionMetrics interactionMetrics = new InteractionMetrics(metrics); + interactionMetrics.UpdateMetrics(54.4, 1, updateValueFactory); + interactionMetrics.UpdateMetrics(345, 1, updateValueFactory); + interactionMetrics.UpdateMetrics(23, 1, updateValueFactory); + interactionMetrics.UpdateMetrics(51, 1, updateValueFactory); + interactionMetrics.UpdateMetrics(500, 1, updateValueFactory); + interactionMetrics.UpdateMetrics(4005, 1, updateValueFactory); + interactionMetrics.UpdateMetrics(2500, 1, updateValueFactory); + interactionMetrics.UpdateMetrics(123, 1, updateValueFactory); + + Dictionary quantile = interactionMetrics.Quantile; + Assert.NotNull(quantile); + Assert.Equal(quantile.Count, 5); + Assert.Equal(quantile["50"], 1); + Assert.Equal(quantile["100"], 2); + Assert.Equal(quantile["300"], 1); + Assert.Equal(quantile["500"], 2); + Assert.Equal(quantile["2000"], 2); + + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs index 9ce596e5..a445fb49 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs @@ -3,28 +3,14 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using System; -using System.Collections.Generic; -using System.Data; -using System.Data.Common; -using System.IO; -using System.Reflection; using Microsoft.SqlTools.ServiceLayer.Connection; -using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; -using Microsoft.SqlTools.ServiceLayer.Credentials; using Microsoft.SqlTools.ServiceLayer.LanguageServices; -using Microsoft.SqlTools.ServiceLayer.QueryExecution; -using Microsoft.SqlTools.ServiceLayer.SqlContext; -using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution; -using Microsoft.SqlTools.ServiceLayer.Test.Utility; -using Microsoft.SqlTools.ServiceLayer.Workspace; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; using Microsoft.SqlTools.Test.Utility; -using Moq; -using Moq.Protected; using Xunit; -namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices +namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServer { /// /// Tests for the ServiceHost Language Service tests @@ -140,131 +126,107 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices Assert.Equal(10, fileMarkers[1].ScriptRegion.EndColumnNumber); Assert.Equal(3, fileMarkers[1].ScriptRegion.EndLineNumber); } + + [Fact] + public void GetSignatureHelpReturnsNullIfParseInfoNotInitialized() + { + // Given service doesn't have parseinfo intialized for a document + const string docContent = "SELECT * FROM sys.objects"; + LanguageService service = TestObjects.GetTestLanguageService(); + var scriptFile = new ScriptFile(); + scriptFile.SetFileContents(docContent); + + // When requesting SignatureHelp + SignatureHelp signatureHelp = service.GetSignatureHelp(CreateDummyDocPosition(), scriptFile); + + // Then null is returned as no parse info can be used to find the signature + Assert.Null(signatureHelp); + } + + private TextDocumentPosition CreateDummyDocPosition() + { + return new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier { Uri = TestObjects.ScriptUri }, + Position = new Position + { + Line = 0, + Character = 0 + } + }; + } + #endregion #region "General Language Service tests" +#if LIVE_CONNECTION_TESTS + + private static void GetLiveAutoCompleteTestObjects( + out TextDocumentPosition textDocument, + out ScriptFile scriptFile, + out ConnectionInfo connInfo) + { + textDocument = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier {Uri = TestObjects.ScriptUri}, + Position = new Position + { + Line = 0, + Character = 0 + } + }; + + connInfo = TestObjects.InitLiveConnectionInfo(out scriptFile); + } + /// /// Test the service initialization code path and verify nothing throws /// // Test is causing failures in build lab..investigating to reenable - //[Fact] - public void ServiceInitiailzation() + [Fact] + public void ServiceInitialization() { - InitializeTestServices(); + try + { + TestObjects.InitializeTestServices(); + } + catch (System.ArgumentException) + { + } Assert.True(LanguageService.Instance.Context != null); Assert.True(LanguageService.ConnectionServiceInstance != null); Assert.True(LanguageService.Instance.CurrentSettings != null); Assert.True(LanguageService.Instance.CurrentWorkspace != null); - - LanguageService.ConnectionServiceInstance = null; - Assert.True(LanguageService.ConnectionServiceInstance == null); } /// /// Test the service initialization code path and verify nothing throws /// // Test is causing failures in build lab..investigating to reenable - //[Fact] + [Fact] public void PrepopulateCommonMetadata() { - InitializeTestServices(); + ScriptFile scriptFile; + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfo(out scriptFile); - string sqlFilePath = GetTestSqlFile(); - ScriptFile scriptFile = WorkspaceService.Instance.Workspace.GetFile(sqlFilePath); - - string ownerUri = scriptFile.ClientFilePath; - var connectionService = TestObjects.GetTestConnectionService(); - var connectionResult = - connectionService - .Connect(new ConnectParams() - { - OwnerUri = ownerUri, - Connection = TestObjects.GetTestConnectionDetails() - }); - - ConnectionInfo connInfo = null; - connectionService.TryFindConnection(ownerUri, out connInfo); - - ScriptParseInfo scriptInfo = new ScriptParseInfo(); - scriptInfo.IsConnected = true; + ScriptParseInfo scriptInfo = new ScriptParseInfo {IsConnected = true}; AutoCompleteHelper.PrepopulateCommonMetadata(connInfo, scriptInfo, null); } - private string GetTestSqlFile() - { - string filePath = Path.Combine( - Path.GetDirectoryName(Assembly.GetEntryAssembly().Location), - "sqltest.sql"); - - if (File.Exists(filePath)) - { - File.Delete(filePath); - } - - File.WriteAllText(filePath, "SELECT * FROM sys.objects\n"); - - return filePath; - } - - private void InitializeTestServices() - { - const string hostName = "SQL Tools Service Host"; - const string hostProfileId = "SQLToolsService"; - Version hostVersion = new Version(1,0); - - // set up the host details and profile paths - var hostDetails = new HostDetails(hostName, hostProfileId, hostVersion); - SqlToolsContext sqlToolsContext = new SqlToolsContext(hostDetails); - - // Grab the instance of the service host - Hosting.ServiceHost serviceHost = Hosting.ServiceHost.Instance; - - // Start the service - serviceHost.Start().Wait(); - - // Initialize the services that will be hosted here - WorkspaceService.Instance.InitializeService(serviceHost); - LanguageService.Instance.InitializeService(serviceHost, sqlToolsContext); - ConnectionService.Instance.InitializeService(serviceHost); - CredentialService.Instance.InitializeService(serviceHost); - QueryExecutionService.Instance.InitializeService(serviceHost); - - serviceHost.Initialize(); - } - - private Hosting.ServiceHost GetTestServiceHost() - { - // set up the host details and profile paths - var hostDetails = new HostDetails("Test Service Host", "SQLToolsService", new Version(1,0)); - SqlToolsContext context = new SqlToolsContext(hostDetails); - - // Grab the instance of the service host - Hosting.ServiceHost host = Hosting.ServiceHost.Instance; - - // Start the service - host.Start().Wait(); - - return host; - } - - #endregion - - #region "Autocomplete Tests" - // This test currently requires a live database connection to initialize // SMO connected metadata provider. Since we don't want a live DB dependency // in the CI unit tests this scenario is currently disabled. - //[Fact] + [Fact] public void AutoCompleteFindCompletions() { TextDocumentPosition textDocument; ConnectionInfo connInfo; ScriptFile scriptFile; - Common.GetAutoCompleteTestObjects(out textDocument, out scriptFile, out connInfo); + GetLiveAutoCompleteTestObjects(out textDocument, out scriptFile, out connInfo); textDocument.Position.Character = 7; scriptFile.Contents = "select "; @@ -278,32 +240,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices Assert.True(completions.Length > 0); } - /// - /// Creates a mock db command that returns a predefined result set - /// - public static DbCommand CreateTestCommand(Dictionary[][] data) - { - var commandMock = new Mock { CallBase = true }; - var commandMockSetup = commandMock.Protected() - .Setup("ExecuteDbDataReader", It.IsAny()); - - commandMockSetup.Returns(new TestDbDataReader(data)); - - return commandMock.Object; - } - - /// - /// Creates a mock db connection that returns predefined data when queried for a result set - /// - public DbConnection CreateMockDbConnection(Dictionary[][] data) - { - var connectionMock = new Mock { CallBase = true }; - connectionMock.Protected() - .Setup("CreateDbCommand") - .Returns(CreateTestCommand(data)); - - return connectionMock.Object; - } +#endif #endregion } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs new file mode 100644 index 00000000..b70c6c54 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs @@ -0,0 +1,376 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.IO; +using System.Collections.Generic; +using System.Threading.Tasks; +using System.Runtime.InteropServices; +using Microsoft.SqlServer.Management.SqlParser.Binder; +using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using Microsoft.SqlServer.Management.SqlParser.Parser; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Workspace; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Microsoft.SqlTools.Test.Utility; +using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices +{ + /// + /// Tests for the language service peek definition/ go to definition feature + /// + public class PeekDefinitionTests + { + private const int TaskTimeout = 30000; + + private readonly string testScriptUri = TestObjects.ScriptUri; + + private readonly string testConnectionKey = "testdbcontextkey"; + + private Mock bindingQueue; + + private Mock> workspaceService; + + private Mock> requestContext; + + private Mock binder; + + private TextDocumentPosition textDocument; + + private const string OwnerUri = "testFile1"; + + private void InitializeTestObjects() + { + // initial cursor position in the script file + textDocument = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier {Uri = this.testScriptUri}, + Position = new Position + { + Line = 0, + Character = 23 + } + }; + + // default settings are stored in the workspace service + WorkspaceService.Instance.CurrentSettings = new SqlToolsSettings(); + + // set up file for returning the query + var fileMock = new Mock(); + fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); + fileMock.SetupGet(file => file.ClientFilePath).Returns(this.testScriptUri); + + // set up workspace mock + workspaceService = new Mock>(); + workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) + .Returns(fileMock.Object); + + // setup binding queue mock + bindingQueue = new Mock(); + bindingQueue.Setup(q => q.AddConnectionContext(It.IsAny())) + .Returns(this.testConnectionKey); + + // inject mock instances into the Language Service + LanguageService.WorkspaceServiceInstance = workspaceService.Object; + LanguageService.ConnectionServiceInstance = TestObjects.GetTestConnectionService(); + ConnectionInfo connectionInfo = TestObjects.GetTestConnectionInfo(); + LanguageService.ConnectionServiceInstance.OwnerToConnectionMap.Add(this.testScriptUri, connectionInfo); + LanguageService.Instance.BindingQueue = bindingQueue.Object; + + // setup the mock for SendResult + requestContext = new Mock>(); + requestContext.Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + + // setup the IBinder mock + binder = new Mock(); + binder.Setup(b => b.Bind( + It.IsAny>(), + It.IsAny(), + It.IsAny())); + + var testScriptParseInfo = new ScriptParseInfo(); + LanguageService.Instance.AddOrUpdateScriptParseInfo(this.testScriptUri, testScriptParseInfo); + testScriptParseInfo.IsConnected = true; + testScriptParseInfo.ConnectionKey = LanguageService.Instance.BindingQueue.AddConnectionContext(connectionInfo); + + // setup the binding context object + ConnectedBindingContext bindingContext = new ConnectedBindingContext(); + bindingContext.Binder = binder.Object; + bindingContext.MetadataDisplayInfoProvider = new MetadataDisplayInfoProvider(); + LanguageService.Instance.BindingQueue.BindingContextMap.Add(testScriptParseInfo.ConnectionKey, bindingContext); + } + + + /// + /// Tests the definition event handler. When called with no active connection, no definition is sent + /// + [Fact] + public void DefinitionsHandlerWithNoConnectionTest() + { + InitializeTestObjects(); + // request the completion list + Task handleCompletion = LanguageService.HandleDefinitionRequest(textDocument, requestContext.Object); + handleCompletion.Wait(TaskTimeout); + + // verify that send result was not called + requestContext.Verify(m => m.SendResult(It.IsAny()), Times.Never()); + } + + /// + /// Tests creating location objects on windows and non-windows systems + /// + [Fact] + public void GetLocationFromFileForValidFilePathTest() + { + String filePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "C:\\test\\script.sql" : "/test/script.sql"; + PeekDefinition peekDefinition = new PeekDefinition(null); + Location[] locations = peekDefinition.GetLocationFromFile(filePath, 0); + + String expectedFilePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "file:///C:/test/script.sql" : "file:/test/script.sql"; + Assert.Equal(locations[0].Uri, expectedFilePath); + } + + /// + /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a valid database name + /// + [Fact] + public void GetSchemaFromDatabaseQualifiedNameWithValidNameTest() + { + PeekDefinition peekDefinition = new PeekDefinition(null); + string validDatabaseQualifiedName = "master.test.test_table"; + string objectName = "test_table"; + string expectedSchemaName = "test"; + + string actualSchemaName = peekDefinition.GetSchemaFromDatabaseQualifiedName(validDatabaseQualifiedName, objectName); + Assert.Equal(actualSchemaName, expectedSchemaName); + } + + /// + /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a valid object name and no schema + /// + + [Fact] + public void GetSchemaFromDatabaseQualifiedNameWithNoSchemaTest() + { + PeekDefinition peekDefinition = new PeekDefinition(null); + string validDatabaseQualifiedName = "test_table"; + string objectName = "test_table"; + string expectedSchemaName = "dbo"; + + string actualSchemaName = peekDefinition.GetSchemaFromDatabaseQualifiedName(validDatabaseQualifiedName, objectName); + Assert.Equal(actualSchemaName, expectedSchemaName); + } + + /// + /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a invalid database name + /// + [Fact] + public void GetSchemaFromDatabaseQualifiedNameWithInvalidNameTest() + { + PeekDefinition peekDefinition = new PeekDefinition(null); + string validDatabaseQualifiedName = "x.y.z"; + string objectName = "test_table"; + string expectedSchemaName = "dbo"; + + string actualSchemaName = peekDefinition.GetSchemaFromDatabaseQualifiedName(validDatabaseQualifiedName, objectName); + Assert.Equal(actualSchemaName, expectedSchemaName); + } + +#if LIVE_CONNECTION_TESTS + /// + /// Test get definition for a table object with active connection + /// + [Fact] + public void GetValidTableDefinitionTest() + { + // Get live connectionInfo + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "test_table"; + string schemaName = null; + string objectType = "TABLE"; + + // Get locations for valid table object + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test get definition for a invalid table object with active connection + /// + [Fact] + public void GetTableDefinitionInvalidObjectTest() + { + // Get live connectionInfo + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "test_invalid"; + string schemaName = null; + string objectType = "TABLE"; + + // Get locations for invalid table object + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType); + Assert.Null(locations); + } + + /// + /// Test get definition for a valid table object with schema and active connection + /// + [Fact] + public void GetTableDefinitionWithSchemaTest() + { + // Get live connectionInfo + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "test_table"; + string schemaName = "dbo"; + string objectType = "TABLE"; + + // Get locations for valid table object with schema name + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test GetDefinition with an unsupported type(function) + /// + [Fact] + public void GetUnsupportedDefinitionForFullScript() + { + + ScriptFile scriptFile; + TextDocumentPosition textDocument = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier { Uri = OwnerUri }, + Position = new Position + { + Line = 0, + Character = 20 + } + }; + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfo(out scriptFile); + scriptFile.Contents = "select * from dbo.func ()"; + + var languageService = LanguageService.Instance; + ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; + languageService.ScriptParseInfoMap.Add(OwnerUri, scriptInfo); + + var locations = languageService.GetDefinition(textDocument, scriptFile, connInfo); + Assert.Null(locations); + } + + /// + /// Test get definition for a view object with active connection + /// + [Fact] + public void GetValidViewDefinitionTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "objects"; + string schemaName = "sys"; + string objectType = "VIEW"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetViewScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test get definition for an invalid view object with no schema name and with active connection + /// + [Fact] + public void GetViewDefinitionInvalidObjectTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "objects"; + string schemaName = null; + string objectType = "VIEW"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetViewScripts, objectName, schemaName, objectType); + Assert.Null(locations); + } + + /// + /// Test get definition for a stored procedure object with active connection + /// + [Fact] + public void GetStoredProcedureDefinitionTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "SP1"; + string schemaName = "dbo"; + string objectType = "PROCEDURE"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetStoredProcedureScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test get definition for a stored procedure object that does not exist with active connection + /// + [Fact] + public void GetStoredProcedureDefinitionFailureTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "SP2"; + string schemaName = "dbo"; + string objectType = "PROCEDURE"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetStoredProcedureScripts, objectName, schemaName, objectType); + Assert.Null(locations); + } + + /// + /// Test get definition for a stored procedure object with active connection and no schema + /// + [Fact] + public void GetStoredProcedureDefinitionWithoutSchemaTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "SP1"; + string schemaName = null; + string objectType = "PROCEDURE"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetStoredProcedureScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Helper method to clean up script files + /// + private void Cleanup(Location[] locations) + { + Uri fileUri = new Uri(locations[0].Uri); + if (File.Exists(fileUri.LocalPath)) + { + try + { + File.Delete(fileUri.LocalPath); + } + catch(Exception) + { + + } + } + } +#endif + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/SqlCompletionItemTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/SqlCompletionItemTests.cs new file mode 100644 index 00000000..727e379d --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/SqlCompletionItemTests.cs @@ -0,0 +1,210 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServer +{ + public class SqlCompletionItemTests + { + [Fact] + public void InsertTextShouldIncludeBracketGivenNameWithSpace() + { + string declarationTitle = "name with space"; + string expected = "[" + declarationTitle + "]"; + + DeclarationType declarationType = DeclarationType.Table; + string tokenText = ""; + SqlCompletionItem item = new SqlCompletionItem(declarationTitle, declarationType, tokenText); + CompletionItem completionItem = item.CreateCompletionItem(0, 1, 2); + + Assert.True(completionItem.InsertText.StartsWith("[") && completionItem.InsertText.EndsWith("]")); + } + + [Fact] + public void InsertTextShouldIncludeBracketGivenNameWithSpecialCharacter() + { + string declarationTitle = "name @"; + string expected = "[" + declarationTitle + "]"; + DeclarationType declarationType = DeclarationType.Table; + string tokenText = ""; + SqlCompletionItem item = new SqlCompletionItem(declarationTitle, declarationType, tokenText); + CompletionItem completionItem = item.CreateCompletionItem(0, 1, 2); + + Assert.Equal(completionItem.InsertText, expected); + Assert.Equal(completionItem.Detail, declarationTitle); + Assert.Equal(completionItem.Label, declarationTitle); + } + + [Fact] + public void LabelShouldIncludeBracketGivenTokenWithBracket() + { + string declarationTitle = "name"; + string expected = "[" + declarationTitle + "]"; + DeclarationType declarationType = DeclarationType.Table; + string tokenText = "["; + SqlCompletionItem item = new SqlCompletionItem(declarationTitle, declarationType, tokenText); + CompletionItem completionItem = item.CreateCompletionItem(0, 1, 2); + + Assert.Equal(completionItem.Label, expected); + Assert.Equal(completionItem.InsertText, expected); + Assert.Equal(completionItem.Detail, expected); + } + + [Fact] + public void LabelShouldIncludeBracketGivenTokenWithBrackets() + { + string declarationTitle = "name"; + string expected = "[" + declarationTitle + "]"; + DeclarationType declarationType = DeclarationType.Table; + string tokenText = "[]"; + SqlCompletionItem item = new SqlCompletionItem(declarationTitle, declarationType, tokenText); + CompletionItem completionItem = item.CreateCompletionItem(0, 1, 2); + + Assert.Equal(completionItem.Label, expected); + Assert.Equal(completionItem.InsertText, expected); + Assert.Equal(completionItem.Detail, expected); + } + + [Fact] + public void LabelShouldIncludeBracketGivenSqlObjectNameWithBracket() + { + string declarationTitle = @"Bracket\["; + string expected = "[" + declarationTitle + "]"; + DeclarationType declarationType = DeclarationType.Table; + string tokenText = ""; + SqlCompletionItem item = new SqlCompletionItem(declarationTitle, declarationType, tokenText); + CompletionItem completionItem = item.CreateCompletionItem(0, 1, 2); + + Assert.Equal(completionItem.Label, declarationTitle); + Assert.Equal(completionItem.InsertText, expected); + Assert.Equal(completionItem.Detail, declarationTitle); + } + + [Fact] + public void LabelShouldIncludeBracketGivenSqlObjectNameWithBracketAndTokenWithBracket() + { + string declarationTitle = @"Bracket\["; + string expected = "[" + declarationTitle + "]"; + DeclarationType declarationType = DeclarationType.Table; + string tokenText = "[]"; + SqlCompletionItem item = new SqlCompletionItem(declarationTitle, declarationType, tokenText); + CompletionItem completionItem = item.CreateCompletionItem(0, 1, 2); + + Assert.Equal(completionItem.Label, expected); + Assert.Equal(completionItem.InsertText, expected); + Assert.Equal(completionItem.Detail, expected); + } + + [Fact] + public void LabelShouldNotIncludeBracketGivenNameWithBrackets() + { + string declarationTitle = "[name]"; + string expected = declarationTitle; + DeclarationType declarationType = DeclarationType.Table; + string tokenText = "[]"; + SqlCompletionItem item = new SqlCompletionItem(declarationTitle, declarationType, tokenText); + CompletionItem completionItem = item.CreateCompletionItem(0, 1, 2); + + Assert.Equal(completionItem.Label, expected); + Assert.Equal(completionItem.InsertText, expected); + Assert.Equal(completionItem.Detail, expected); + } + + [Fact] + public void LabelShouldIncludeBracketGivenNameWithOneBracket() + { + string declarationTitle = "[name"; + string expected = "[" + declarationTitle + "]"; + DeclarationType declarationType = DeclarationType.Table; + string tokenText = "[]"; + SqlCompletionItem item = new SqlCompletionItem(declarationTitle, declarationType, tokenText); + CompletionItem completionItem = item.CreateCompletionItem(0, 1, 2); + + Assert.Equal(completionItem.Label, expected); + Assert.Equal(completionItem.InsertText, expected); + Assert.Equal(completionItem.Detail, expected); + } + + [Fact] + public void KindShouldBeModuleGivenSchemaDeclarationType() + { + CompletionItemKind expectedType = CompletionItemKind.Module; + DeclarationType declarationType = DeclarationType.Schema; + ValidateDeclarationType(declarationType, expectedType); + } + + [Fact] + public void KindShouldBeFieldGivenColumnDeclarationType() + { + CompletionItemKind expectedType = CompletionItemKind.Field; + DeclarationType declarationType = DeclarationType.Column; + ValidateDeclarationType(declarationType, expectedType); + } + + [Fact] + public void KindShouldBeFileGivenTableDeclarationType() + { + CompletionItemKind expectedType = CompletionItemKind.File; + DeclarationType declarationType = DeclarationType.Table; + ValidateDeclarationType(declarationType, expectedType); + } + + [Fact] + public void KindShouldBeFileGivenViewDeclarationType() + { + CompletionItemKind expectedType = CompletionItemKind.File; + DeclarationType declarationType = DeclarationType.View; + ValidateDeclarationType(declarationType, expectedType); + } + + [Fact] + public void KindShouldBeMethodGivenDatabaseDeclarationType() + { + CompletionItemKind expectedType = CompletionItemKind.Method; + DeclarationType declarationType = DeclarationType.Database; + ValidateDeclarationType(declarationType, expectedType); + } + + [Fact] + public void KindShouldBeValueGivenScalarValuedFunctionDeclarationType() + { + CompletionItemKind expectedType = CompletionItemKind.Value; + DeclarationType declarationType = DeclarationType.ScalarValuedFunction; + ValidateDeclarationType(declarationType, expectedType); + } + + [Fact] + public void KindShouldBeValueGivenTableValuedFunctionDeclarationType() + { + CompletionItemKind expectedType = CompletionItemKind.Value; + DeclarationType declarationType = DeclarationType.TableValuedFunction; + ValidateDeclarationType(declarationType, expectedType); + } + + [Fact] + public void KindShouldBeUnitGivenUnknownDeclarationType() + { + CompletionItemKind expectedType = CompletionItemKind.Unit; + DeclarationType declarationType = DeclarationType.XmlIndex; + ValidateDeclarationType(declarationType, expectedType); + } + + private void ValidateDeclarationType(DeclarationType declarationType, CompletionItemKind expectedType) + { + string declarationTitle = "name"; + string tokenText = ""; + SqlCompletionItem item = new SqlCompletionItem(declarationTitle, declarationType, tokenText); + CompletionItem completionItem = item.CreateCompletionItem(0, 1, 2); + + + Assert.Equal(completionItem.Kind, expectedType); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs index 3899507c..9bc04a98 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs @@ -32,8 +32,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I request a query (doesn't matter what kind) and execute it - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); - var executeParams = new QueryExecuteParams { QuerySelection = Common.GetSubSectionDocument(), OwnerUri = Common.OwnerUri }; + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService.Object); + var executeParams = new QueryExecuteParams { QuerySelection = Common.SubsectionDocument, OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); @@ -44,7 +44,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; QueryCancelResult result = null; var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); - queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + await queryService.HandleCancelRequest(cancelParams, cancelRequest.Object); // Then: // ... I should have seen a successful event (no messages) @@ -68,7 +68,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I request a query (doesn't matter what kind) and wait for execution - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService.Object); var executeParams = new QueryExecuteParams {QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); @@ -91,17 +91,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public async void CancelNonExistantTest() + public async Task CancelNonExistantTest() { var workspaceService = new Mock>(); // If: // ... I request to cancel a query that doesn't exist - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); + var queryService = Common.GetPrimedExecutionService(null, false, false, workspaceService.Object); var cancelParams = new QueryCancelParams {OwnerUri = "Doesn't Exist"}; QueryCancelResult result = null; var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); - queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + await queryService.HandleCancelRequest(cancelParams, cancelRequest.Object); // Then: // ... I should have seen a result event with an error message diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index bd5abe81..7923a554 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -1,20 +1,17 @@ -// +// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -// using System; using System.Collections.Generic; using System.Data; using System.Data.Common; using System.IO; -using System.Data.SqlClient; using System.Threading; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; -using Microsoft.SqlServer.Management.Common; -using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; @@ -29,37 +26,61 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution { public class Common { - public const SelectionData WholeDocument = null; - - public const string StandardQuery = "SELECT * FROM sys.objects"; + #region Constants public const string InvalidQuery = "SELECT *** FROM sys.objects"; public const string NoOpQuery = "-- No ops here, just us chickens."; - public const string UdtQuery = "SELECT hierarchyid::Parse('/')"; + public const int Ordinal = 100; // We'll pick something other than default(int) public const string OwnerUri = "testFile"; - public const int StandardRows = 5; - public const int StandardColumns = 5; - public static string TestServer { get; set; } + public const string StandardQuery = "SELECT * FROM sys.objects"; - public static string TestDatabase { get; set; } + public const int StandardRows = 5; - static Common() + public const string UdtQuery = "SELECT hierarchyid::Parse('/')"; + + public const SelectionData WholeDocument = null; + + public static readonly ConnectionDetails StandardConnectionDetails = new ConnectionDetails { - TestServer = "sqltools11"; - TestDatabase = "master"; - } + DatabaseName = "123", + Password = "456", + ServerName = "789", + UserName = "012" + }; + + public static readonly SelectionData SubsectionDocument = new SelectionData(0, 0, 2, 2); + + #endregion public static Dictionary[] StandardTestData { get { return GetTestData(StandardRows, StandardColumns); } } + #region Public Methods + + public static Batch GetBasicExecutedBatch() + { + Batch batch = new Batch(StandardQuery, SubsectionDocument, 1, GetFileStreamFactory(new Dictionary())); + batch.Execute(CreateTestConnection(new[] {StandardTestData}, false), CancellationToken.None).Wait(); + return batch; + } + + public static Query GetBasicExecutedQuery() + { + ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false); + Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory(new Dictionary())); + query.Execute(); + query.ExecutionTask.Wait(); + return query; + } + public static Dictionary[] GetTestData(int columns, int rows) { Dictionary[] output = new Dictionary[rows]; @@ -76,92 +97,39 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution return output; } - public static SelectionData GetSubSectionDocument() + public static async Task AwaitExecution(QueryExecutionService service, QueryExecuteParams qeParams, + RequestContext requestContext) { - return new SelectionData(0, 0, 2, 2); + await service.HandleExecuteRequest(qeParams, requestContext); + if (service.ActiveQueries.ContainsKey(qeParams.OwnerUri) && service.ActiveQueries[qeParams.OwnerUri].ExecutionTask != null) + { + await service.ActiveQueries[qeParams.OwnerUri].ExecutionTask; + } } - public static Batch GetBasicExecutedBatch() - { - Batch batch = new Batch(StandardQuery, 0, 0, 2, 2, GetFileStreamFactory()); - batch.Execute(CreateTestConnection(new[] {StandardTestData}, false), CancellationToken.None).Wait(); - return batch; - } - - public static Query GetBasicExecutedQuery() - { - ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false); - Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory()); - query.Execute(); - query.ExecutionTask.Wait(); - return query; - } + #endregion #region FileStreamWriteMocking - public static IFileStreamFactory GetFileStreamFactory() + public static IFileStreamFactory GetFileStreamFactory(Dictionary storage) { Mock mock = new Mock(); + mock.Setup(fsf => fsf.CreateFile()) + .Returns(() => + { + string fileName = Guid.NewGuid().ToString(); + storage.Add(fileName, new byte[8192]); + return fileName; + }); mock.Setup(fsf => fsf.GetReader(It.IsAny())) - .Returns(new ServiceBufferFileStreamReader(new InMemoryWrapper(), It.IsAny())); + .Returns(output => new ServiceBufferFileStreamReader(new MemoryStream(storage[output]))); mock.Setup(fsf => fsf.GetWriter(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(new ServiceBufferFileStreamWriter(new InMemoryWrapper(), It.IsAny(), 1024, - 1024)); + .Returns((output, chars, xml) => new ServiceBufferFileStreamWriter( + new MemoryStream(storage[output]), chars, xml)); return mock.Object; } - public class InMemoryWrapper : IFileStreamWrapper - { - private readonly byte[] storage = new byte[8192]; - private readonly MemoryStream memoryStream; - private bool readingOnly; - - public InMemoryWrapper() - { - memoryStream = new MemoryStream(storage); - } - - public void Dispose() - { - // We'll dispose this via a special method - } - - public void Init(string fileName, int bufferSize, FileAccess fAccess) - { - readingOnly = fAccess == FileAccess.Read; - } - - public int ReadData(byte[] buffer, int bytes) - { - return ReadData(buffer, bytes, memoryStream.Position); - } - - public int ReadData(byte[] buffer, int bytes, long fileOffset) - { - memoryStream.Seek(fileOffset, SeekOrigin.Begin); - return memoryStream.Read(buffer, 0, bytes); - } - - public int WriteData(byte[] buffer, int bytes) - { - if (readingOnly) { throw new InvalidOperationException(); } - memoryStream.Write(buffer, 0, bytes); - memoryStream.Flush(); - return bytes; - } - - public void Flush() - { - if (readingOnly) { throw new InvalidOperationException(); } - } - - public void Close() - { - memoryStream.Dispose(); - } - } - #endregion #region DbConnection Mocking @@ -193,7 +161,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var connectionMock = new Mock { CallBase = true }; connectionMock.Protected() .Setup("CreateDbCommand") - .Returns(CreateTestCommand(data, throwOnRead)); + .Returns(() => CreateTestCommand(data, throwOnRead)); connectionMock.Setup(dbc => dbc.Open()) .Callback(() => connectionMock.SetupGet(dbc => dbc.State).Returns(ConnectionState.Open)); connectionMock.Setup(dbc => dbc.Close()) @@ -206,90 +174,45 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution { var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) - .Returns(CreateTestConnection(data, throwOnRead)); + .Returns(() => CreateTestConnection(data, throwOnRead)); return mockFactory.Object; } public static ConnectionInfo CreateTestConnectionInfo(Dictionary[][] data, bool throwOnRead) { - // Create connection info - ConnectionDetails connDetails = new ConnectionDetails - { - UserName = "sa", - Password = "Yukon900", - DatabaseName = Common.TestDatabase, - ServerName = Common.TestServer - }; - - return new ConnectionInfo(CreateMockFactory(data, throwOnRead), OwnerUri, connDetails); + return new ConnectionInfo(CreateMockFactory(data, throwOnRead), OwnerUri, StandardConnectionDetails); } #endregion #region Service Mocking - - public static void GetAutoCompleteTestObjects( - out TextDocumentPosition textDocument, - out ScriptFile scriptFile, - out ConnectionInfo connInfo - ) + + public static QueryExecutionService GetPrimedExecutionService(Dictionary[][] data, bool isConnected, bool throwOnRead, WorkspaceService workspaceService) { - textDocument = new TextDocumentPosition - { - TextDocument = new TextDocumentIdentifier {Uri = OwnerUri}, - Position = new Position - { - Line = 0, - Character = 0 - } - }; + // Create a place for the temp "files" to be written + Dictionary storage = new Dictionary(); - connInfo = Common.CreateTestConnectionInfo(null, false); + // Create the connection factory with the dataset + var factory = CreateTestConnectionInfo(data, throwOnRead).Factory; - LanguageService.Instance.ScriptParseInfoMap.Add(textDocument.TextDocument.Uri, new ScriptParseInfo()); - - scriptFile = new ScriptFile {ClientFilePath = textDocument.TextDocument.Uri}; + // Mock the connection service + var connectionService = new Mock(); + ConnectionInfo ci = new ConnectionInfo(factory, OwnerUri, StandardConnectionDetails); + ConnectionInfo outValMock; + connectionService + .Setup(service => service.TryFindConnection(It.IsAny(), out outValMock)) + .OutCallback((string owner, out ConnectionInfo connInfo) => connInfo = isConnected ? ci : null) + .Returns(isConnected); + return new QueryExecutionService(connectionService.Object, workspaceService) {BufferFileStreamFactory = GetFileStreamFactory(storage)}; } - public static ServerConnection GetServerConnection(ConnectionInfo connection) - { - string connectionString = ConnectionService.BuildConnectionString(connection.ConnectionDetails); - var sqlConnection = new SqlConnection(connectionString); - return new ServerConnection(sqlConnection); - } - - public static ConnectionDetails GetTestConnectionDetails() - { - return new ConnectionDetails - { - DatabaseName = "123", - Password = "456", - ServerName = "789", - UserName = "012" - }; - } - - public static async Task GetPrimedExecutionService(ISqlConnectionFactory factory, bool isConnected, WorkspaceService workspaceService) - { - var connectionService = new ConnectionService(factory); - if (isConnected) - { - await connectionService.Connect(new ConnectParams - { - Connection = GetTestConnectionDetails(), - OwnerUri = OwnerUri - }); - } - return new QueryExecutionService(connectionService, workspaceService) {BufferFileStreamFactory = GetFileStreamFactory()}; - } - - public static WorkspaceService GetPrimedWorkspaceService() + public static WorkspaceService GetPrimedWorkspaceService(string query) { // Set up file for returning the query var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(StandardQuery); + fileMock.SetupGet(file => file.Contents).Returns(query); // Set up workspace mock var workspaceService = new Mock>(); @@ -300,6 +223,5 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } #endregion - } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs deleted file mode 100644 index 1ee471fd..00000000 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs +++ /dev/null @@ -1,218 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; -using System.IO; -using System.Linq; -using System.Text; -using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; -using Xunit; - -namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage -{ - public class FileStreamWrapperTests - { - [Theory] - [InlineData(null)] - [InlineData("")] - [InlineData(" ")] - public void InitInvalidFilenameParameter(string fileName) - { - // If: - // ... I have a file stream wrapper that is initialized with invalid fileName - // Then: - // ... It should throw an argument null exception - using (FileStreamWrapper fsw = new FileStreamWrapper()) - { - Assert.Throws(() => fsw.Init(fileName, 8192, FileAccess.Read)); - } - } - - [Theory] - [InlineData(0)] - [InlineData(-1)] - public void InitInvalidBufferLength(int bufferLength) - { - // If: - // ... I have a file stream wrapper that is initialized with an invalid buffer length - // Then: - // ... I should throw an argument out of range exception - using (FileStreamWrapper fsw = new FileStreamWrapper()) - { - Assert.Throws(() => fsw.Init("validFileName", bufferLength, FileAccess.Read)); - } - } - - [Fact] - public void InitInvalidFileAccessMode() - { - // If: - // ... I attempt to open a file stream wrapper that is initialized with an invalid file - // access mode - // Then: - // ... I should get an invalid argument exception - using (FileStreamWrapper fsw = new FileStreamWrapper()) - { - Assert.Throws(() => fsw.Init("validFileName", 8192, FileAccess.Write)); - } - } - - [Fact] - public void InitSuccessful() - { - string fileName = Path.GetTempFileName(); - - try - { - using (FileStreamWrapper fsw = new FileStreamWrapper()) - { - // If: - // ... I have a file stream wrapper that is initialized with valid parameters - fsw.Init(fileName, 8192, FileAccess.ReadWrite); - - // Then: - // ... The file should exist - FileInfo fileInfo = new FileInfo(fileName); - Assert.True(fileInfo.Exists); - } - } - finally - { - // Cleanup: - // ... Delete the file that was created - try { File.Delete(fileName); } catch { /* Don't care */ } - } - } - - [Fact] - public void PerformOpWithoutInit() - { - byte[] buf = new byte[10]; - - using (FileStreamWrapper fsw = new FileStreamWrapper()) - { - // If: - // ... I have a file stream wrapper that hasn't been initialized - // Then: - // ... Attempting to perform any operation will result in an exception - Assert.Throws(() => fsw.ReadData(buf, 1)); - Assert.Throws(() => fsw.ReadData(buf, 1, 0)); - Assert.Throws(() => fsw.WriteData(buf, 1)); - Assert.Throws(() => fsw.Flush()); - } - } - - [Fact] - public void PerformWriteOpOnReadOnlyWrapper() - { - byte[] buf = new byte[10]; - - using (FileStreamWrapper fsw = new FileStreamWrapper()) - { - // If: - // ... I have a readonly file stream wrapper - // Then: - // ... Attempting to perform any write operation should result in an exception - Assert.Throws(() => fsw.WriteData(buf, 1)); - Assert.Throws(() => fsw.Flush()); - } - } - - [Theory] - [InlineData(1024, 20, 10)] // Standard scenario - [InlineData(1024, 100, 100)] // Requested more bytes than there are - [InlineData(5, 20, 10)] // Internal buffer too small, force a move-to operation - public void ReadData(int internalBufferLength, int outBufferLength, int requestedBytes) - { - // Setup: - // ... I have a file that has a handful of bytes in it - string fileName = Path.GetTempFileName(); - const string stringToWrite = "hello"; - CreateTestFile(fileName, stringToWrite); - byte[] targetBytes = Encoding.Unicode.GetBytes(stringToWrite); - - try - { - // If: - // ... I have a file stream wrapper that has been initialized to an existing file - // ... And I read some bytes from it - int bytesRead; - byte[] buf = new byte[outBufferLength]; - using (FileStreamWrapper fsw = new FileStreamWrapper()) - { - fsw.Init(fileName, internalBufferLength, FileAccess.Read); - bytesRead = fsw.ReadData(buf, targetBytes.Length); - } - - // Then: - // ... I should get those bytes back - Assert.Equal(targetBytes.Length, bytesRead); - Assert.True(targetBytes.Take(targetBytes.Length).SequenceEqual(buf.Take(targetBytes.Length))); - - } - finally - { - // Cleanup: - // ... Delete the test file - CleanupTestFile(fileName); - } - } - - [Theory] - [InlineData(1024)] // Standard scenario - [InlineData(10)] // Internal buffer too small, forces a flush - public void WriteData(int internalBufferLength) - { - string fileName = Path.GetTempFileName(); - byte[] bytesToWrite = Encoding.Unicode.GetBytes("hello"); - - try - { - // If: - // ... I have a file stream that has been initialized - // ... And I write some bytes to it - using (FileStreamWrapper fsw = new FileStreamWrapper()) - { - fsw.Init(fileName, internalBufferLength, FileAccess.ReadWrite); - int bytesWritten = fsw.WriteData(bytesToWrite, bytesToWrite.Length); - - Assert.Equal(bytesToWrite.Length, bytesWritten); - } - - // Then: - // ... The file I wrote to should contain only the bytes I wrote out - using (FileStream fs = File.OpenRead(fileName)) - { - byte[] readBackBytes = new byte[1024]; - int bytesRead = fs.Read(readBackBytes, 0, readBackBytes.Length); - - Assert.Equal(bytesToWrite.Length, bytesRead); // If bytes read is not equal, then more or less of the original string was written to the file - Assert.True(bytesToWrite.SequenceEqual(readBackBytes.Take(bytesRead))); - } - } - finally - { - // Cleanup: - // ... Delete the test file - CleanupTestFile(fileName); - } - } - - private static void CreateTestFile(string fileName, string value) - { - using (FileStream fs = new FileStream(fileName, FileMode.OpenOrCreate, FileAccess.ReadWrite)) - { - byte[] bytesToWrite = Encoding.Unicode.GetBytes(value); - fs.Write(bytesToWrite, 0, bytesToWrite.Length); - fs.Flush(); - } - } - - private static void CleanupTestFile(string fileName) - { - try { File.Delete(fileName); } catch { /* Don't Care */} - } - } -} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs index 87076204..40496d63 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs @@ -6,45 +6,102 @@ using System; using System.Collections.Generic; using System.Data.SqlTypes; +using System.IO; using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Moq; using Xunit; namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage { public class ReaderWriterPairTest { + [Fact] + public void ReaderInvalidStreamCannotRead() + { + // If: I create a service buffer file stream reader with a stream that cannot be read + // Then: I should get an exception + var invalidStream = new Mock(); + invalidStream.SetupGet(s => s.CanRead).Returns(false); + invalidStream.SetupGet(s => s.CanSeek).Returns(true); + Assert.Throws(() => + { + ServiceBufferFileStreamReader obj = new ServiceBufferFileStreamReader(invalidStream.Object); + obj.Dispose(); + }); + } + + [Fact] + public void ReaderInvalidStreamCannotSeek() + { + // If: I create a service buffer file stream reader with a stream that cannot seek + // Then: I should get an exception + var invalidStream = new Mock(); + invalidStream.SetupGet(s => s.CanRead).Returns(true); + invalidStream.SetupGet(s => s.CanSeek).Returns(false); + Assert.Throws(() => + { + ServiceBufferFileStreamReader obj = new ServiceBufferFileStreamReader(invalidStream.Object); + obj.Dispose(); + }); + } + + [Fact] + public void WriterInvalidStreamCannotWrite() + { + // If: I create a service buffer file stream writer with a stream that cannot be read + // Then: I should get an exception + var invalidStream = new Mock(); + invalidStream.SetupGet(s => s.CanWrite).Returns(false); + invalidStream.SetupGet(s => s.CanSeek).Returns(true); + Assert.Throws(() => + { + ServiceBufferFileStreamWriter obj = new ServiceBufferFileStreamWriter(invalidStream.Object, 1024, 1024); + obj.Dispose(); + }); + } + + [Fact] + public void WriterInvalidStreamCannotSeek() + { + // If: I create a service buffer file stream writer with a stream that cannot seek + // Then: I should get an exception + var invalidStream = new Mock(); + invalidStream.SetupGet(s => s.CanWrite).Returns(true); + invalidStream.SetupGet(s => s.CanSeek).Returns(false); + Assert.Throws(() => + { + ServiceBufferFileStreamWriter obj = new ServiceBufferFileStreamWriter(invalidStream.Object, 1024, 1024); + obj.Dispose(); + }); + } + private static void VerifyReadWrite(int valueLength, T value, Func writeFunc, Func readFunc) { - // Setup: Create a mock file stream wrapper - Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper(); - try + // Setup: Create a mock file stream + byte[] storage = new byte[8192]; + + // If: + // ... I write a type T to the writer + using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(new MemoryStream(storage), 10, 10)) { - // If: - // ... I write a type T to the writer - using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10)) - { - int writtenBytes = writeFunc(writer, value); - Assert.Equal(valueLength, writtenBytes); - } - - // ... And read the type T back - FileStreamReadResult outValue; - using (ServiceBufferFileStreamReader reader = new ServiceBufferFileStreamReader(mockWrapper, "abc")) - { - outValue = readFunc(reader); - } - - // Then: - Assert.Equal(value, outValue.Value.RawObject); - Assert.Equal(valueLength, outValue.TotalLength); - Assert.NotNull(outValue.Value); + int writtenBytes = writeFunc(writer, value); + Assert.Equal(valueLength, writtenBytes); } - finally + + // ... And read the type T back + FileStreamReadResult outValue; + using (ServiceBufferFileStreamReader reader = new ServiceBufferFileStreamReader(new MemoryStream(storage))) { - // Cleanup: Close the wrapper - mockWrapper.Close(); + outValue = readFunc(reader); } + + // Then: + Assert.Equal(value, outValue.Value.RawObject); + Assert.Equal(valueLength, outValue.TotalLength); + Assert.NotNull(outValue.Value); } [Theory] @@ -174,18 +231,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage } } - [Fact] - public void DateTimeTest() + [Theory] + [InlineData(3)] // Scale 3 = DATETIME + [InlineData(7)] // Scale 7 = DATETIME2 + public void DateTimeTest(int scale) { - // Setup: Create some test values + // Setup: Create some test values and a column with scale set // NOTE: We are doing these here instead of InlineData because DateTime values can't be written as constant expressions + DbColumnWrapper col = new DbColumnWrapper(new TestDbColumn("dbcol", scale)); DateTime[] testValues = { DateTime.Now, DateTime.UtcNow, DateTime.MinValue, DateTime.MaxValue }; foreach (DateTime value in testValues) { - VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteDateTime(val), reader => reader.ReadDateTime(0)); + VerifyReadWrite(sizeof(long) + sizeof(int) + 1, value, (writer, val) => writer.WriteDateTime(col, val), reader => reader.ReadDateTime(0)); } } @@ -222,16 +282,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage [Fact] public void StringNullTest() { - // Setup: Create a mock file stream wrapper - Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper(); - - // If: - // ... I write null as a string to the writer - using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10)) + // Setup: Create a mock file stream + using (MemoryStream stream = new MemoryStream(new byte[8192])) { - // Then: - // ... I should get an argument null exception - Assert.Throws(() => writer.WriteString(null)); + // If: + // ... I write null as a string to the writer + using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(stream, 10, 10)) + { + // Then: + // ... I should get an argument null exception + Assert.Throws(() => writer.WriteString(null)); + } } } @@ -259,15 +320,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage public void BytesNullTest() { // Setup: Create a mock file stream wrapper - Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper(); - - // If: - // ... I write null as a string to the writer - using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10)) + using (MemoryStream stream = new MemoryStream(new byte[8192])) { - // Then: - // ... I should get an argument null exception - Assert.Throws(() => writer.WriteBytes(null)); + // If: + // ... I write null as a string to the writer + using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(stream, 10, 10)) + { + // Then: + // ... I should get an argument null exception + Assert.Throws(() => writer.WriteBytes(null)); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/StorageDataReaderTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/StorageDataReaderTests.cs new file mode 100644 index 00000000..e29c8394 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/StorageDataReaderTests.cs @@ -0,0 +1,99 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +#if LIVE_CONNECTION_TESTS + +using System.Data.Common; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Microsoft.SqlTools.Test.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage +{ + public class StorageDataReaderTests + { + private StorageDataReader GetTestStorageDataReader(out DbDataReader reader, string query) + { + ScriptFile scriptFile; + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfo(out scriptFile); + + var command = connInfo.SqlConnection.CreateCommand(); + command.CommandText = query; + reader = command.ExecuteReader(); + + return new StorageDataReader(reader); + } + + /// + /// Validate GetBytesWithMaxCapacity + /// + [Fact] + public void GetBytesWithMaxCapacityTest() + { + DbDataReader reader; + var storageReader = GetTestStorageDataReader( + out reader, + "SELECT CAST([name] as TEXT) As TextName FROM sys.all_columns"); + + reader.Read(); + Assert.False(storageReader.IsDBNull(0)); + + byte[] bytes = storageReader.GetBytesWithMaxCapacity(0, 100); + Assert.NotNull(bytes); + } + + /// + /// Validate GetCharsWithMaxCapacity + /// + [Fact] + public void GetCharsWithMaxCapacityTest() + { + DbDataReader reader; + var storageReader = GetTestStorageDataReader( + out reader, + "SELECT name FROM sys.all_columns"); + + reader.Read(); + Assert.False(storageReader.IsDBNull(0)); + + string shortName = storageReader.GetCharsWithMaxCapacity(0, 2); + Assert.True(shortName.Length == 2); + } + + /// + /// Validate GetXmlWithMaxCapacity + /// + [Fact] + public void GetXmlWithMaxCapacityTest() + { + DbDataReader reader; + var storageReader = GetTestStorageDataReader( + out reader, + "SELECT CAST('Test XML context' AS XML) As XmlColumn"); + + reader.Read(); + Assert.False(storageReader.IsDBNull(0)); + + string shortXml = storageReader.GetXmlWithMaxCapacity(0, 2); + Assert.True(shortXml.Length == 3); + } + + /// + /// Validate StringWriterWithMaxCapacity Write test + /// + [Fact] + public void StringWriterWithMaxCapacityTest() + { + var writer = new StorageDataReader.StringWriterWithMaxCapacity(null, 100); + string output = "..."; + writer.Write(output); + Assert.True(writer.ToString().Equals(output)); + } + } +} + +#endif \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs index 2f103ec0..61aee8b8 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs @@ -28,7 +28,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var mockDataReader = Common.CreateTestConnection(null, false).CreateCommand().ExecuteReaderAsync().Result; // If: I setup a single resultset and then dispose it - ResultSet rs = new ResultSet(mockDataReader, mockFileStreamFactory.Object); + ResultSet rs = new ResultSet(mockDataReader, Common.Ordinal, Common.Ordinal, mockFileStreamFactory.Object); rs.Dispose(); // Then: The file that was created should have been deleted @@ -47,7 +47,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I request a query (doesn't matter what kind) - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService.Object); var executeParams = new QueryExecuteParams {QuerySelection = null, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); @@ -59,7 +59,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var disposeRequest = GetQueryDisposeResultContextMock(qdr => { result = qdr; }, null); - queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object).Wait(); + await queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object); // Then: // ... I should have seen a successful result @@ -75,11 +75,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var workspaceService = new Mock>(); // If: // ... I attempt to dispose a query that doesn't exist - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); + var queryService = Common.GetPrimedExecutionService(null, false, false, workspaceService.Object); var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; QueryDisposeResult result = null; var disposeRequest = GetQueryDisposeResultContextMock(qdr => result = qdr, null); - queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object).Wait(); + await queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object); // Then: // ... I should have gotten an error result @@ -99,8 +99,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) .Returns(fileMock.Object); // ... We need a query service - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, - workspaceService.Object); + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService.Object); + // If: // ... I execute some bogus query diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs index 72137474..c61237fa 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs @@ -5,666 +5,17 @@ //#define USE_LIVE_CONNECTION -using System; using System.Data.Common; -using System.Linq; -using System.Threading; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; -using Microsoft.SqlTools.ServiceLayer.SqlContext; -using Microsoft.SqlTools.ServiceLayer.Test.Utility; -using Microsoft.SqlTools.ServiceLayer.Workspace; -using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -using Moq; -using Xunit; namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution { public class ExecuteTests { - #region Batch Class Tests - - [Fact] - public void BatchCreationTest() - { - // If I create a new batch... - Batch batch = new Batch(Common.StandardQuery, 0, 0, 2, 2, Common.GetFileStreamFactory()); - - // Then: - // ... The text of the batch should be stored - Assert.NotEmpty(batch.BatchText); - - // ... It should not have executed and no error - Assert.False(batch.HasExecuted, "The query should not have executed."); - Assert.False(batch.HasError, "The batch should not have an error"); - - // ... The results should be empty - Assert.Empty(batch.ResultSets); - Assert.Empty(batch.ResultSummaries); - Assert.Empty(batch.ResultMessages); - - // ... The start line of the batch should be 0 - Assert.Equal(0, batch.Selection.StartLine); - } - - [Fact] - public void BatchExecuteNoResultSets() - { - // If I execute a query that should get no result sets - Batch batch = new Batch(Common.StandardQuery, 0, 0, 2, 2, Common.GetFileStreamFactory()); - batch.Execute(GetConnection(Common.CreateTestConnectionInfo(null, false)), CancellationToken.None).Wait(); - - // Then: - // ... It should have executed without error - Assert.True(batch.HasExecuted, "The query should have been marked executed."); - Assert.False(batch.HasError, "The batch should not have an error"); - - // ... The results should be empty - Assert.Empty(batch.ResultSets); - Assert.Empty(batch.ResultSummaries); - - // ... The results should not be null - Assert.NotNull(batch.ResultSets); - Assert.NotNull(batch.ResultSummaries); - - // ... There should be a message for how many rows were affected - Assert.Equal(1, batch.ResultMessages.Count()); - } - - [Fact] - public void BatchExecuteOneResultSet() - { - int resultSets = 1; - ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false); - - // If I execute a query that should get one result set - Batch batch = new Batch(Common.StandardQuery, 0, 0, 2, 2, Common.GetFileStreamFactory()); - batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); - - // Then: - // ... It should have executed without error - Assert.True(batch.HasExecuted, "The batch should have been marked executed."); - Assert.False(batch.HasError, "The batch should not have an error"); - - // ... There should be exactly one result set - Assert.Equal(resultSets, batch.ResultSets.Count()); - Assert.Equal(resultSets, batch.ResultSummaries.Length); - - // ... Inside the result set should be with 5 rows - Assert.Equal(Common.StandardRows, batch.ResultSets.First().RowCount); - Assert.Equal(Common.StandardRows, batch.ResultSummaries[0].RowCount); - - // ... Inside the result set should have 5 columns - Assert.Equal(Common.StandardColumns, batch.ResultSets.First().Columns.Length); - Assert.Equal(Common.StandardColumns, batch.ResultSummaries[0].ColumnInfo.Length); - - // ... There should be a message for how many rows were affected - Assert.Equal(resultSets, batch.ResultMessages.Count()); - } - - [Fact] - public void BatchExecuteTwoResultSets() - { - var dataset = new[] { Common.StandardTestData, Common.StandardTestData }; - int resultSets = dataset.Length; - ConnectionInfo ci = Common.CreateTestConnectionInfo(dataset, false); - - // If I execute a query that should get two result sets - Batch batch = new Batch(Common.StandardQuery, 0, 0, 1, 1, Common.GetFileStreamFactory()); - batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); - - // Then: - // ... It should have executed without error - Assert.True(batch.HasExecuted, "The batch should have been marked executed."); - Assert.False(batch.HasError, "The batch should not have an error"); - - // ... There should be exactly two result sets - Assert.Equal(resultSets, batch.ResultSets.Count()); - - foreach (ResultSet rs in batch.ResultSets) - { - // ... Each result set should have 5 rows - Assert.Equal(Common.StandardRows, rs.RowCount); - - // ... Inside each result set should be 5 columns - Assert.Equal(Common.StandardColumns, rs.Columns.Length); - } - - // ... There should be exactly two result set summaries - Assert.Equal(resultSets, batch.ResultSummaries.Length); - - foreach (ResultSetSummary rs in batch.ResultSummaries) - { - // ... Inside each result summary, there should be 5 rows - Assert.Equal(Common.StandardRows, rs.RowCount); - - // ... Inside each result summary, there should be 5 column definitions - Assert.Equal(Common.StandardColumns, rs.ColumnInfo.Length); - } - } - - [Fact] - public void BatchExecuteInvalidQuery() - { - ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true); - - // If I execute a batch that is invalid - Batch batch = new Batch(Common.StandardQuery, 0, 0, 2, 2, Common.GetFileStreamFactory()); - batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); - - // Then: - // ... It should have executed with error - Assert.True(batch.HasExecuted); - Assert.True(batch.HasError); - - // ... There should be no result sets - Assert.Empty(batch.ResultSets); - Assert.Empty(batch.ResultSummaries); - - // ... There should be plenty of messages for the error - Assert.NotEmpty(batch.ResultMessages); - } - - [Fact] - public async Task BatchExecuteExecuted() - { - ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false); - - // If I execute a batch - Batch batch = new Batch(Common.StandardQuery, 0, 0, 2, 2, Common.GetFileStreamFactory()); - batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); - - // Then: - // ... It should have executed without error - Assert.True(batch.HasExecuted, "The batch should have been marked executed."); - Assert.False(batch.HasError, "The batch should not have an error"); - - // If I execute it again - // Then: - // ... It should throw an invalid operation exception - await Assert.ThrowsAsync(() => - batch.Execute(GetConnection(ci), CancellationToken.None)); - - // ... The data should still be available without error - Assert.False(batch.HasError, "The batch should not be in an error condition"); - Assert.True(batch.HasExecuted, "The batch should still be marked executed."); - Assert.NotEmpty(batch.ResultSets); - Assert.NotEmpty(batch.ResultSummaries); - } - - [Theory] - [InlineData("")] - [InlineData(null)] - public void BatchExecuteNoSql(string query) - { - // If: - // ... I create a batch that has an empty query - // Then: - // ... It should throw an exception - Assert.Throws(() => new Batch(query, 0, 0, 2, 2, Common.GetFileStreamFactory())); - } - - [Fact] - public void BatchNoBufferFactory() - { - // If: - // ... I create a batch that has no file stream factory - // Then: - // ... It should throw an exception - Assert.Throws(() => new Batch("stuff", 0, 0, 2, 2, null)); - } - - #endregion - - #region Query Class Tests - - [Fact] - public void QueryExecuteNoQueryText() - { - // If: - // ... I create a query that has a null query text - // Then: - // ... It should throw an exception - Assert.Throws(() => - new Query(null, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), Common.GetFileStreamFactory())); - } - - [Fact] - public void QueryExecuteNoConnectionInfo() - { - // If: - // ... I create a query that has a null connection info - // Then: - // ... It should throw an exception - Assert.Throws(() => new Query("Some Query", null, new QueryExecutionSettings(), Common.GetFileStreamFactory())); - } - - [Fact] - public void QueryExecuteNoSettings() - { - // If: - // ... I create a query that has a null settings - // Then: - // ... It should throw an exception - Assert.Throws(() => - new Query("Some query", Common.CreateTestConnectionInfo(null, false), null, Common.GetFileStreamFactory())); - } - - [Fact] - public void QueryExecuteNoBufferFactory() - { - // If: - // ... I create a query that has a null file stream factory - // Then: - // ... It should throw an exception - Assert.Throws(() => - new Query("Some query", Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(),null)); - } - - [Fact] - public void QueryExecuteSingleBatch() - { - // If: - // ... I create a query from a single batch (without separator) - ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); - Query query = new Query(Common.StandardQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); - - // Then: - // ... I should get a single batch to execute that hasn't been executed - Assert.NotEmpty(query.QueryText); - Assert.NotEmpty(query.Batches); - Assert.Equal(1, query.Batches.Length); - Assert.False(query.HasExecuted); - Assert.Throws(() => query.BatchSummaries); - - // If: - // ... I then execute the query - query.Execute(); - query.ExecutionTask.Wait(); - - // Then: - // ... The query should have completed successfully with one batch summary returned - Assert.True(query.HasExecuted); - Assert.NotEmpty(query.BatchSummaries); - Assert.Equal(1, query.BatchSummaries.Length); - } - - [Fact] - public void QueryExecuteNoOpBatch() - { - // If: - // ... I create a query from a single batch that does nothing - ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); - Query query = new Query(Common.NoOpQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); - - // Then: - // ... I should get no batches back - Assert.NotEmpty(query.QueryText); - Assert.Empty(query.Batches); - Assert.False(query.HasExecuted); - Assert.Throws(() => query.BatchSummaries); - - // If: - // ... I Then execute the query - query.Execute(); - query.ExecutionTask.Wait(); - - // Then: - // ... The query should have completed successfully with no batch summaries returned - Assert.True(query.HasExecuted); - Assert.Empty(query.BatchSummaries); - } - - [Fact] - public void QueryExecuteMultipleBatches() - { - // If: - // ... I create a query from two batches (with separator) - ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); - string queryText = string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery); - Query query = new Query(queryText, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); - - // Then: - // ... I should get back two batches to execute that haven't been executed - Assert.NotEmpty(query.QueryText); - Assert.NotEmpty(query.Batches); - Assert.Equal(2, query.Batches.Length); - Assert.False(query.HasExecuted); - Assert.Throws(() => query.BatchSummaries); - - // If: - // ... I then execute the query - query.Execute(); - query.ExecutionTask.Wait(); - - // Then: - // ... The query should have completed successfully with two batch summaries returned - Assert.True(query.HasExecuted); - Assert.NotEmpty(query.BatchSummaries); - Assert.Equal(2, query.BatchSummaries.Length); - } - - [Fact] - public void QueryExecuteMultipleBatchesWithNoOp() - { - // If: - // ... I create a query from a two batches (with separator) - ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); - string queryText = string.Format("{0}\r\nGO\r\n{1}", Common.StandardQuery, Common.NoOpQuery); - Query query = new Query(queryText, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); - - // Then: - // ... I should get back one batch to execute that hasn't been executed - Assert.NotEmpty(query.QueryText); - Assert.NotEmpty(query.Batches); - Assert.Equal(1, query.Batches.Length); - Assert.False(query.HasExecuted); - Assert.Throws(() => query.BatchSummaries); - - // If: - // .. I then execute the query - query.Execute(); - query.ExecutionTask.Wait(); - - // ... The query should have completed successfully with one batch summary returned - Assert.True(query.HasExecuted); - Assert.NotEmpty(query.BatchSummaries); - Assert.Equal(1, query.BatchSummaries.Length); - } - - [Fact] - public void QueryExecuteInvalidBatch() - { - // If: - // ... I create a query from an invalid batch - ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true); - Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); - - // Then: - // ... I should get back a query with one batch not executed - Assert.NotEmpty(query.QueryText); - Assert.NotEmpty(query.Batches); - Assert.Equal(1, query.Batches.Length); - Assert.False(query.HasExecuted); - Assert.Throws(() => query.BatchSummaries); - - // If: - // ... I then execute the query - query.Execute(); - query.ExecutionTask.Wait(); - - // Then: - // ... There should be an error on the batch - Assert.True(query.HasExecuted); - Assert.NotEmpty(query.BatchSummaries); - Assert.Equal(1, query.BatchSummaries.Length); - Assert.True(query.BatchSummaries[0].HasError); - Assert.NotEmpty(query.BatchSummaries[0].Messages); - } - - #endregion - - #region Service Tests - - [Fact] - public async void QueryExecuteValidNoResultsTest() - { - // Given: - // ... Default settings are stored in the workspace service - WorkspaceService.Instance.CurrentSettings = new SqlToolsSettings(); - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - // If: - // ... I request to execute a valid query with no results - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); - var queryParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - - QueryExecuteResult result = null; - QueryExecuteCompleteParams completeParams = null; - var requestContext = - RequestContextMocks.SetupRequestContextMock( - resultCallback: qer => result = qer, - expectedEvent: QueryExecuteCompleteEvent.Type, - eventCallback: (et, cp) => completeParams = cp, - errorCallback: null); - await AwaitExecution(queryService, queryParams, requestContext.Object); - - // Then: - // ... No Errors should have been sent - // ... A successful result should have been sent with messages on the first batch - // ... A completion event should have been fired with empty results - VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Never()); - Assert.Null(result.Messages); - Assert.Equal(1, completeParams.BatchSummaries.Length); - Assert.Empty(completeParams.BatchSummaries[0].ResultSetSummaries); - Assert.NotEmpty(completeParams.BatchSummaries[0].Messages); - - // ... There should be one active query - Assert.Equal(1, queryService.ActiveQueries.Count); - } - - [Fact] - public async void QueryExecuteValidResultsTest() - { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - // If: - // ... I request to execute a valid query with results - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true, - workspaceService.Object); - var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; - - QueryExecuteResult result = null; - QueryExecuteCompleteParams completeParams = null; - var requestContext = - RequestContextMocks.SetupRequestContextMock( - resultCallback: qer => result = qer, - expectedEvent: QueryExecuteCompleteEvent.Type, - eventCallback: (et, cp) => completeParams = cp, - errorCallback: null); - await AwaitExecution(queryService, queryParams, requestContext.Object); - - // Then: - // ... No errors should have been sent - // ... A successful result should have been sent with messages - // ... A completion event should have been fired with one result - VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Never()); - Assert.Null(result.Messages); - Assert.Equal(1, completeParams.BatchSummaries.Length); - Assert.NotEmpty(completeParams.BatchSummaries[0].ResultSetSummaries); - Assert.NotEmpty(completeParams.BatchSummaries[0].Messages); - Assert.False(completeParams.BatchSummaries[0].HasError); - - // ... There should be one active query - Assert.Equal(1, queryService.ActiveQueries.Count); - } - - [Fact] - public async void QueryExecuteUnconnectedUriTest() - { - - var workspaceService = new Mock>(); - // If: - // ... I request to execute a query using a file URI that isn't connected - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); - var queryParams = new QueryExecuteParams { OwnerUri = "notConnected", QuerySelection = Common.WholeDocument }; - - object error = null; - var requestContext = RequestContextMocks.Create(null) - .AddErrorHandling(e => error = e); - await queryService.HandleExecuteRequest(queryParams, requestContext.Object); - - // Then: - // ... An error should have been returned - // ... No result should have been returned - // ... No completion event should have been fired - // ... There should be no active queries - VerifyQueryExecuteCallCount(requestContext, Times.Never(), Times.Never(), Times.Once()); - Assert.IsType(error); - Assert.NotEmpty((string)error); - Assert.Empty(queryService.ActiveQueries); - } - - [Fact] - public async void QueryExecuteInProgressTest() - { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - - // If: - // ... I request to execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); - var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; - - // Note, we don't care about the results of the first request - var firstRequestContext = RequestContextMocks.Create(null); - await AwaitExecution(queryService, queryParams, firstRequestContext.Object); - - // ... And then I request another query without waiting for the first to complete - queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Simulate query hasn't finished - object error = null; - var secondRequestContext = RequestContextMocks.Create(null) - .AddErrorHandling(e => error = e); - await AwaitExecution(queryService, queryParams, secondRequestContext.Object); - - // Then: - // ... An error should have been sent - // ... A result should have not have been sent - // ... No completion event should have been fired - // ... The original query should exist - VerifyQueryExecuteCallCount(secondRequestContext, Times.Never(), Times.Never(), Times.Once()); - Assert.IsType(error); - Assert.NotEmpty((string)error); - Assert.Contains(Common.OwnerUri, queryService.ActiveQueries.Keys); - } - - [Fact] - public async void QueryExecuteCompletedTest() - { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - - // If: - // ... I request to execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); - var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; - - // Note, we don't care about the results of the first request - var firstRequestContext = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - await AwaitExecution(queryService, queryParams, firstRequestContext.Object); - - // ... And then I request another query after waiting for the first to complete - QueryExecuteResult result = null; - QueryExecuteCompleteParams complete = null; - var secondRequestContext = - RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null); - await AwaitExecution(queryService, queryParams, secondRequestContext.Object); - - // Then: - // ... No errors should have been sent - // ... A result should have been sent with no errors - // ... There should only be one active query - VerifyQueryExecuteCallCount(secondRequestContext, Times.Once(), Times.Once(), Times.Never()); - Assert.Null(result.Messages); - Assert.False(complete.BatchSummaries.Any(b => b.HasError)); - Assert.Equal(1, queryService.ActiveQueries.Count); - } - - [Theory] - [InlineData(null)] - public async Task QueryExecuteMissingSelectionTest(SelectionData selection) - { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(""); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - // If: - // ... I request to execute a query with a missing query string - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); - var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = selection }; - - object errorResult = null; - var requestContext = RequestContextMocks.Create(null) - .AddErrorHandling(error => errorResult = error); - await queryService.HandleExecuteRequest(queryParams, requestContext.Object); - - // Then: - // ... Am error should have been sent - // ... No result should have been sent - // ... No completion event should have been fired - // ... An active query should not have been added - VerifyQueryExecuteCallCount(requestContext, Times.Never(), Times.Never(), Times.Once()); - Assert.NotNull(errorResult); - Assert.IsType(errorResult); - Assert.DoesNotContain(Common.OwnerUri, queryService.ActiveQueries.Keys); - - // ... There should not be an active query - Assert.Empty(queryService.ActiveQueries); - } - - [Fact] - public async void QueryExecuteInvalidQueryTest() - { - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - // If: - // ... I request to execute a query that is invalid - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, true), true, workspaceService.Object); - var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; - - QueryExecuteResult result = null; - QueryExecuteCompleteParams complete = null; - var requestContext = - RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null); - await AwaitExecution(queryService, queryParams, requestContext.Object); - - // Then: - // ... No errors should have been sent - // ... A result should have been sent with success (we successfully started the query) - // ... A completion event should have been sent with error - VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Never()); - Assert.Null(result.Messages); - Assert.Equal(1, complete.BatchSummaries.Length); - Assert.True(complete.BatchSummaries[0].HasError); - Assert.NotEmpty(complete.BatchSummaries[0].Messages); - } #if USE_LIVE_CONNECTION [Fact] @@ -693,28 +44,5 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Assert.NotEmpty(query.BatchSummaries[0].Messages); } #endif - -#endregion - - private static void VerifyQueryExecuteCallCount(Mock> mock, Times sendResultCalls, Times sendEventCalls, Times sendErrorCalls) - { - mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); - mock.Verify(rc => rc.SendEvent( - It.Is>(m => m == QueryExecuteCompleteEvent.Type), - It.IsAny()), sendEventCalls); - mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); - } - - private static DbConnection GetConnection(ConnectionInfo info) - { - return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); - } - - private static async Task AwaitExecution(QueryExecutionService service, QueryExecuteParams qeParams, - RequestContext requestContext) - { - await service.HandleExecuteRequest(qeParams, requestContext); - await service.ActiveQueries[qeParams.OwnerUri].ExecutionTask; - } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/BatchTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/BatchTests.cs new file mode 100644 index 00000000..12b99a3b --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/BatchTests.cs @@ -0,0 +1,399 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution +{ + public class BatchTests + { + [Fact] + public void BatchCreationTest() + { + // If I create a new batch... + Batch batch = new Batch(Common.StandardQuery, Common.SubsectionDocument, Common.Ordinal, Common.GetFileStreamFactory(null)); + + // Then: + // ... The text of the batch should be stored + Assert.NotEmpty(batch.BatchText); + + // ... It should not have executed and no error + Assert.False(batch.HasExecuted, "The query should not have executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // ... The results should be empty + Assert.Empty(batch.ResultSets); + Assert.Empty(batch.ResultSummaries); + Assert.Empty(batch.ResultMessages); + + // ... The start line of the batch should be 0 + Assert.Equal(0, batch.Selection.StartLine); + + // ... It's ordinal ID should be what I set it to + Assert.Equal(Common.Ordinal, batch.Id); + + // ... The summary should have the same info + Assert.False(batch.Summary.HasError); + Assert.Equal(Common.Ordinal, batch.Summary.Id); + Assert.Null(batch.Summary.ResultSetSummaries); + Assert.Null(batch.Summary.Messages); + Assert.Equal(0, batch.Summary.Selection.StartLine); + Assert.NotEqual(default(DateTime).ToString("o"), batch.Summary.ExecutionStart); // Should have been set at construction + Assert.Null(batch.Summary.ExecutionEnd); + Assert.Null(batch.Summary.ExecutionElapsed); + } + + /// + /// Note: This test also tests the start notification feature + /// + [Fact] + public void BatchExecuteNoResultSets() + { + // Setup: + // ... Create a callback for batch start + BatchSummary batchSummaryFromStart = null; + Batch.BatchAsyncEventHandler batchStartCallback = b => + { + batchSummaryFromStart = b.Summary; + return Task.FromResult(0); + }; + + // ... Create a callback for batch completion + BatchSummary batchSummaryFromCompletion = null; + Batch.BatchAsyncEventHandler batchCompleteCallback = b => + { + batchSummaryFromCompletion = b.Summary; + return Task.FromResult(0); + }; + + // ... Create a callback for result completion + bool resultCallbackFired = false; + ResultSet.ResultSetAsyncEventHandler resultSetCallback = r => + { + resultCallbackFired = true; + return Task.FromResult(0); + }; + + // If I execute a query that should get no result sets + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Batch batch = new Batch(Common.StandardQuery, Common.SubsectionDocument, Common.Ordinal, fileStreamFactory); + batch.BatchStart += batchStartCallback; + batch.BatchCompletion += batchCompleteCallback; + batch.ResultSetCompletion += resultSetCallback; + batch.Execute(GetConnection(Common.CreateTestConnectionInfo(null, false)), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed without error + Assert.True(batch.HasExecuted, "The query should have been marked executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // ... The results should be empty + Assert.Empty(batch.ResultSets); + Assert.Empty(batch.ResultSummaries); + + // ... The results should not be null + Assert.NotNull(batch.ResultSets); + Assert.NotNull(batch.ResultSummaries); + + // ... There should be a message for how many rows were affected + Assert.Equal(1, batch.ResultMessages.Count()); + + // ... The callback for batch start should have been called + // ... The info from it should have been basic + Assert.NotNull(batchSummaryFromStart); + Assert.False(batchSummaryFromStart.HasError); + Assert.Equal(Common.Ordinal, batchSummaryFromStart.Id); + Assert.Equal(Common.SubsectionDocument, batchSummaryFromStart.Selection); + Assert.True(DateTime.Parse(batchSummaryFromStart.ExecutionStart) > default(DateTime)); + Assert.Null(batchSummaryFromStart.ResultSetSummaries); + Assert.Null(batchSummaryFromStart.Messages); + Assert.Null(batchSummaryFromStart.ExecutionElapsed); + Assert.Null(batchSummaryFromStart.ExecutionEnd); + + // ... The callback for batch completion should have been fired + // ... The summary should match the expected info + Assert.NotNull(batchSummaryFromCompletion); + Assert.False(batchSummaryFromCompletion.HasError); + Assert.Equal(Common.Ordinal, batchSummaryFromCompletion.Id); + Assert.Equal(0, batchSummaryFromCompletion.ResultSetSummaries.Length); + Assert.Equal(1, batchSummaryFromCompletion.Messages.Length); + Assert.Equal(Common.SubsectionDocument, batchSummaryFromCompletion.Selection); + Assert.True(DateTime.Parse(batchSummaryFromCompletion.ExecutionStart) > default(DateTime)); + Assert.True(DateTime.Parse(batchSummaryFromCompletion.ExecutionEnd) > default(DateTime)); + Assert.NotNull(batchSummaryFromCompletion.ExecutionElapsed); + + // ... The callback for the result set should NOT have been fired + Assert.False(resultCallbackFired); + } + + [Fact] + public void BatchExecuteOneResultSet() + { + const int resultSets = 1; + ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false); + + // Setup: Create a callback for batch completion + BatchSummary batchSummaryFromCallback = null; + Batch.BatchAsyncEventHandler batchCallback = b => + { + batchSummaryFromCallback = b.Summary; + return Task.FromResult(0); + }; + + // ... Create a callback for result set completion + bool resultCallbackFired = false; + ResultSet.ResultSetAsyncEventHandler resultSetCallback = r => + { + resultCallbackFired = true; + return Task.FromResult(0); + }; + + // If I execute a query that should get one result set + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Batch batch = new Batch(Common.StandardQuery, Common.SubsectionDocument, Common.Ordinal, fileStreamFactory); + batch.BatchCompletion += batchCallback; + batch.ResultSetCompletion += resultSetCallback; + batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed without error + Assert.True(batch.HasExecuted, "The batch should have been marked executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // ... There should be exactly one result set + Assert.Equal(resultSets, batch.ResultSets.Count); + Assert.Equal(resultSets, batch.ResultSummaries.Length); + + // ... Inside the result set should be with 5 rows + Assert.Equal(Common.StandardRows, batch.ResultSets.First().RowCount); + Assert.Equal(Common.StandardRows, batch.ResultSummaries[0].RowCount); + + // ... Inside the result set should have 5 columns + Assert.Equal(Common.StandardColumns, batch.ResultSets.First().Columns.Length); + Assert.Equal(Common.StandardColumns, batch.ResultSummaries[0].ColumnInfo.Length); + + // ... There should be a message for how many rows were affected + Assert.Equal(resultSets, batch.ResultMessages.Count()); + + // ... The callback for batch completion should have been fired + Assert.NotNull(batchSummaryFromCallback); + + // ... The callback for resultset completion should have been fired + Assert.True(resultCallbackFired); // We only want to validate that it happened, validation of the + // summary is done in result set tests + } + + [Fact] + public void BatchExecuteTwoResultSets() + { + var dataset = new[] { Common.StandardTestData, Common.StandardTestData }; + int resultSets = dataset.Length; + ConnectionInfo ci = Common.CreateTestConnectionInfo(dataset, false); + + // Setup: Create a callback for batch completion + BatchSummary batchSummaryFromCallback = null; + Batch.BatchAsyncEventHandler batchCallback = b => + { + batchSummaryFromCallback = b.Summary; + return Task.FromResult(0); + }; + + // ... Create a callback for resultset completion + int resultSummaryCount = 0; + ResultSet.ResultSetAsyncEventHandler resultSetCallback = r => + { + resultSummaryCount++; + return Task.FromResult(0); + }; + + // If I execute a query that should get two result sets + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Batch batch = new Batch(Common.StandardQuery, Common.SubsectionDocument, Common.Ordinal, fileStreamFactory); + batch.BatchCompletion += batchCallback; + batch.ResultSetCompletion += resultSetCallback; + batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed without error + Assert.True(batch.HasExecuted, "The batch should have been marked executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // ... There should be exactly two result sets + Assert.Equal(resultSets, batch.ResultSets.Count()); + + foreach (ResultSet rs in batch.ResultSets) + { + // ... Each result set should have 5 rows + Assert.Equal(Common.StandardRows, rs.RowCount); + + // ... Inside each result set should be 5 columns + Assert.Equal(Common.StandardColumns, rs.Columns.Length); + } + + // ... There should be exactly two result set summaries + Assert.Equal(resultSets, batch.ResultSummaries.Length); + + foreach (ResultSetSummary rs in batch.ResultSummaries) + { + // ... Inside each result summary, there should be 5 rows + Assert.Equal(Common.StandardRows, rs.RowCount); + + // ... Inside each result summary, there should be 5 column definitions + Assert.Equal(Common.StandardColumns, rs.ColumnInfo.Length); + } + + // ... The callback for batch completion should have been fired + Assert.NotNull(batchSummaryFromCallback); + + // ... The callback for result set completion should have been fired + Assert.Equal(2, resultSummaryCount); + } + + [Fact] + public void BatchExecuteInvalidQuery() + { + // Setup: + // ... Create a callback for batch start + bool batchStartCalled = false; + Batch.BatchAsyncEventHandler batchStartCallback = b => + { + batchStartCalled = true; + return Task.FromResult(0); + }; + + // ... Create a callback for batch completion + BatchSummary batchSummaryFromCallback = null; + Batch.BatchAsyncEventHandler batchCompleteCallback = b => + { + batchSummaryFromCallback = b.Summary; + return Task.FromResult(0); + }; + + // ... Create a callback that will fail the test if it's called + ResultSet.ResultSetAsyncEventHandler resultSetCallback = r => + { + throw new Exception("ResultSet callback was called when it should not have been."); + }; + + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true); + + // If I execute a batch that is invalid + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Batch batch = new Batch(Common.StandardQuery, Common.SubsectionDocument, Common.Ordinal, fileStreamFactory); + batch.BatchStart += batchStartCallback; + batch.BatchCompletion += batchCompleteCallback; + batch.ResultSetCompletion += resultSetCallback; + batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed with error + Assert.True(batch.HasExecuted); + Assert.True(batch.HasError); + + // ... There should be no result sets + Assert.Empty(batch.ResultSets); + Assert.Empty(batch.ResultSummaries); + + // ... There should be plenty of messages for the error + Assert.NotEmpty(batch.ResultMessages); + + // ... The callback for batch completion should have been fired + Assert.NotNull(batchSummaryFromCallback); + + // ... The callback for batch start should have been fired + Assert.True(batchStartCalled); + } + + [Fact] + public async Task BatchExecuteExecuted() + { + ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false); + + // If I execute a batch + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Batch batch = new Batch(Common.StandardQuery, Common.SubsectionDocument, Common.Ordinal, fileStreamFactory); + batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed without error + Assert.True(batch.HasExecuted, "The batch should have been marked executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // Setup for part 2: + // ... Create a callback for batch completion + Batch.BatchAsyncEventHandler completeCallback = b => + { + throw new Exception("Batch completion callback should not have been called"); + }; + + // ... Create a callback for batch start + Batch.BatchAsyncEventHandler startCallback = b => + { + throw new Exception("Batch start callback should not have been called"); + }; + + // If I execute it again + // Then: + // ... It should throw an invalid operation exception + batch.BatchStart += startCallback; + batch.BatchCompletion += completeCallback; + await Assert.ThrowsAsync(() => + batch.Execute(GetConnection(ci), CancellationToken.None)); + + // ... The data should still be available without error + Assert.False(batch.HasError, "The batch should not be in an error condition"); + Assert.True(batch.HasExecuted, "The batch should still be marked executed."); + Assert.NotEmpty(batch.ResultSets); + Assert.NotEmpty(batch.ResultSummaries); + } + + [Theory] + [InlineData("")] + [InlineData(null)] + public void BatchExecuteNoSql(string query) + { + // If: + // ... I create a batch that has an empty query + // Then: + // ... It should throw an exception + Assert.Throws(() => new Batch(query, Common.SubsectionDocument, Common.Ordinal, Common.GetFileStreamFactory(null))); + } + + [Fact] + public void BatchNoBufferFactory() + { + // If: + // ... I create a batch that has no file stream factory + // Then: + // ... It should throw an exception + Assert.Throws(() => new Batch("stuff", Common.SubsectionDocument, Common.Ordinal, null)); + } + + [Fact] + public void BatchInvalidOrdinal() + { + // If: + // ... I create a batch has has an ordinal less than 0 + // Then: + // ... It should throw an exception + Assert.Throws(() => new Batch("stuff", Common.SubsectionDocument, -1, Common.GetFileStreamFactory(null))); + } + + private static DbConnection GetConnection(ConnectionInfo info) + { + return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs new file mode 100644 index 00000000..3df09e13 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs @@ -0,0 +1,316 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution +{ + public class QueryTests + { + + [Fact] + public void QueryExecuteNoQueryText() + { + // If: + // ... I create a query that has a null query text + // Then: + // ... It should throw an exception + Assert.Throws(() => + new Query(null, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), Common.GetFileStreamFactory(null))); + } + + [Fact] + public void QueryExecuteNoConnectionInfo() + { + // If: + // ... I create a query that has a null connection info + // Then: + // ... It should throw an exception + Assert.Throws(() => new Query("Some Query", null, new QueryExecutionSettings(), Common.GetFileStreamFactory(null))); + } + + [Fact] + public void QueryExecuteNoSettings() + { + // If: + // ... I create a query that has a null settings + // Then: + // ... It should throw an exception + Assert.Throws(() => + new Query("Some query", Common.CreateTestConnectionInfo(null, false), null, Common.GetFileStreamFactory(null))); + } + + [Fact] + public void QueryExecuteNoBufferFactory() + { + // If: + // ... I create a query that has a null file stream factory + // Then: + // ... It should throw an exception + Assert.Throws(() => + new Query("Some query", Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), null)); + } + + [Fact] + public void QueryExecuteSingleBatch() + { + // Setup: + // ... Create a callback for atch start + int batchStartCallbacksReceived = 0; + Batch.BatchAsyncEventHandler batchStartCallback = b => + { + batchStartCallbacksReceived++; + return Task.FromResult(0); + }; + + // ... Create a callback for batch completion + int batchCompleteCallbacksReceived = 0; + Batch.BatchAsyncEventHandler batchCompleteCallback = summary => + { + batchCompleteCallbacksReceived++; + return Task.CompletedTask; + }; + + // If: + // ... I create a query from a single batch (without separator) + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Query query = new Query(Common.StandardQuery, ci, new QueryExecutionSettings(), fileStreamFactory); + query.BatchStarted += batchStartCallback; + query.BatchCompleted += batchCompleteCallback; + + // Then: + // ... I should get a single batch to execute that hasn't been executed + Assert.NotEmpty(query.QueryText); + Assert.NotEmpty(query.Batches); + Assert.Equal(1, query.Batches.Length); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // ... I then execute the query + query.Execute(); + query.ExecutionTask.Wait(); + + // Then: + // ... The query should have completed successfully with one batch summary returned + Assert.True(query.HasExecuted); + Assert.NotEmpty(query.BatchSummaries); + Assert.Equal(1, query.BatchSummaries.Length); + + // ... The batch callbacks should have been called precisely 1 time + Assert.Equal(1, batchStartCallbacksReceived); + Assert.Equal(1, batchCompleteCallbacksReceived); + } + + [Fact] + public void QueryExecuteNoOpBatch() + { + // Setup: + // ... Create a callback for batch startup + Batch.BatchAsyncEventHandler batchStartCallback = b => + { + throw new Exception("Batch startup callback should not have been called."); + }; + + // ... Create a callback for batch completion + Batch.BatchAsyncEventHandler batchCompletionCallback = summary => + { + throw new Exception("Batch completion callback was called"); + }; + + // If: + // ... I create a query from a single batch that does nothing + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Query query = new Query(Common.NoOpQuery, ci, new QueryExecutionSettings(), fileStreamFactory); + query.BatchStarted += batchStartCallback; + query.BatchCompleted += batchCompletionCallback; + + // Then: + // ... I should get no batches back + Assert.NotEmpty(query.QueryText); + Assert.Empty(query.Batches); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // ... I Then execute the query + query.Execute(); + query.ExecutionTask.Wait(); + + // Then: + // ... The query should have completed successfully with no batch summaries returned + Assert.True(query.HasExecuted); + Assert.Empty(query.BatchSummaries); + } + + [Fact] + public void QueryExecuteMultipleBatches() + { + // Setup: + // ... Create a callback for batch start + int batchStartCallbacksReceived = 0; + Batch.BatchAsyncEventHandler batchStartCallback = b => + { + batchStartCallbacksReceived++; + return Task.FromResult(0); + }; + + // ... Create a callback for batch completion + int batchCompletedCallbacksReceived = 0; + Batch.BatchAsyncEventHandler batchCompletedCallback = summary => + { + batchCompletedCallbacksReceived++; + return Task.FromResult(0); + }; + + // If: + // ... I create a query from two batches (with separator) + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); + string queryText = string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Query query = new Query(queryText, ci, new QueryExecutionSettings(), fileStreamFactory); + query.BatchStarted += batchStartCallback; + query.BatchCompleted += batchCompletedCallback; + + // Then: + // ... I should get back two batches to execute that haven't been executed + Assert.NotEmpty(query.QueryText); + Assert.NotEmpty(query.Batches); + Assert.Equal(2, query.Batches.Length); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // ... I then execute the query + query.Execute(); + query.ExecutionTask.Wait(); + + // Then: + // ... The query should have completed successfully with two batch summaries returned + Assert.True(query.HasExecuted); + Assert.NotEmpty(query.BatchSummaries); + Assert.Equal(2, query.BatchSummaries.Length); + + // ... The batch start and completion callbacks should have been called precisely 2 times + Assert.Equal(2, batchStartCallbacksReceived); + Assert.Equal(2, batchCompletedCallbacksReceived); + } + + [Fact] + public void QueryExecuteMultipleBatchesWithNoOp() + { + // Setup: + // ... Create a callback for batch start + int batchStartCallbacksReceived = 0; + Batch.BatchAsyncEventHandler batchStartCallback = b => + { + batchStartCallbacksReceived++; + return Task.FromResult(0); + }; + + // ... Create a callback for batch completion + int batchCompletionCallbacksReceived = 0; + Batch.BatchAsyncEventHandler batchCompletionCallback = summary => + { + batchCompletionCallbacksReceived++; + return Task.CompletedTask; + }; + + // If: + // ... I create a query from a two batches (with separator) + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); + string queryText = string.Format("{0}\r\nGO\r\n{1}", Common.StandardQuery, Common.NoOpQuery); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Query query = new Query(queryText, ci, new QueryExecutionSettings(), fileStreamFactory); + query.BatchStarted += batchStartCallback; + query.BatchCompleted += batchCompletionCallback; + + // Then: + // ... I should get back one batch to execute that hasn't been executed + Assert.NotEmpty(query.QueryText); + Assert.NotEmpty(query.Batches); + Assert.Equal(1, query.Batches.Length); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // .. I then execute the query + query.Execute(); + query.ExecutionTask.Wait(); + + // ... The query should have completed successfully with one batch summary returned + Assert.True(query.HasExecuted); + Assert.NotEmpty(query.BatchSummaries); + Assert.Equal(1, query.BatchSummaries.Length); + + // ... The batch callbacks should have been called precisely 1 time + Assert.Equal(1, batchStartCallbacksReceived); + Assert.Equal(1, batchCompletionCallbacksReceived); + } + + [Fact] + public void QueryExecuteInvalidBatch() + { + // Setup: + // ... Create a callback for batch start + int batchStartCallbacksReceived = 0; + Batch.BatchAsyncEventHandler batchStartCallback = b => + { + batchStartCallbacksReceived++; + return Task.FromResult(0); + }; + + // ... Create a callback for batch completion + int batchCompletionCallbacksReceived = 0; + Batch.BatchAsyncEventHandler batchCompltionCallback = summary => + { + batchCompletionCallbacksReceived++; + return Task.CompletedTask; + }; + + // If: + // ... I create a query from an invalid batch + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings(), fileStreamFactory); + query.BatchStarted += batchStartCallback; + query.BatchCompleted += batchCompltionCallback; + + // Then: + // ... I should get back a query with one batch not executed + Assert.NotEmpty(query.QueryText); + Assert.NotEmpty(query.Batches); + Assert.Equal(1, query.Batches.Length); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // ... I then execute the query + query.Execute(); + query.ExecutionTask.Wait(); + + // Then: + // ... There should be an error on the batch + Assert.True(query.HasExecuted); + Assert.NotEmpty(query.BatchSummaries); + Assert.Equal(1, query.BatchSummaries.Length); + Assert.True(query.BatchSummaries[0].HasError); + Assert.NotEmpty(query.BatchSummaries[0].Messages); + + // ... The batch callbacks should have been called once + Assert.Equal(1, batchStartCallbacksReceived); + Assert.Equal(1, batchCompletionCallbacksReceived); + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ResultSetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ResultSetTests.cs new file mode 100644 index 00000000..3aff6f13 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ResultSetTests.cs @@ -0,0 +1,209 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution +{ + public class ResultSetTests + { + [Fact] + public void ResultCreation() + { + // If: + // ... I create a new result set with a valid db data reader + DbDataReader mockReader = GetReader(null, false, string.Empty); + ResultSet resultSet = new ResultSet(mockReader, Common.Ordinal, Common.Ordinal, Common.GetFileStreamFactory(new Dictionary())); + + // Then: + // ... There should not be any data read yet + Assert.Null(resultSet.Columns); + Assert.Equal(0, resultSet.RowCount); + Assert.Equal(Common.Ordinal, resultSet.Id); + + // ... The summary should include the same info + Assert.Null(resultSet.Summary.ColumnInfo); + Assert.Equal(0, resultSet.Summary.RowCount); + Assert.Equal(Common.Ordinal, resultSet.Summary.Id); + Assert.Equal(Common.Ordinal, resultSet.Summary.BatchId); + } + + [Fact] + public void ResultCreationInvalidReader() + { + // If: + // ... I create a new result set without a reader + // Then: + // ... It should throw an exception + Assert.Throws(() => new ResultSet(null, Common.Ordinal, Common.Ordinal, null)); + + } + + [Fact] + public async Task ReadToEndSuccess() + { + // Setup: Create a callback for resultset completion + ResultSetSummary resultSummaryFromCallback = null; + ResultSet.ResultSetAsyncEventHandler callback = r => + { + resultSummaryFromCallback = r.Summary; + return Task.FromResult(0); + }; + + // If: + // ... I create a new resultset with a valid db data reader that has data + // ... and I read it to the end + DbDataReader mockReader = GetReader(new [] {Common.StandardTestData}, false, Common.StandardQuery); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + ResultSet resultSet = new ResultSet(mockReader, Common.Ordinal, Common.Ordinal, fileStreamFactory); + resultSet.ResultCompletion += callback; + await resultSet.ReadResultToEnd(CancellationToken.None); + + // Then: + // ... The columns should be set + // ... There should be rows to read back + Assert.NotNull(resultSet.Columns); + Assert.Equal(Common.StandardColumns, resultSet.Columns.Length); + Assert.Equal(Common.StandardRows, resultSet.RowCount); + + // ... The summary should have the same info + Assert.NotNull(resultSet.Summary.ColumnInfo); + Assert.Equal(Common.StandardColumns, resultSet.Summary.ColumnInfo.Length); + Assert.Equal(Common.StandardRows, resultSet.Summary.RowCount); + + // ... The callback for result set completion should have been fired + Assert.NotNull(resultSummaryFromCallback); + } + + [Theory] + [InlineData("JSON")] + [InlineData("XML")] + public async Task ReadToEndForXmlJson(string forType) + { + // Setup: + // ... Build a FOR XML or FOR JSON data set + string columnName = string.Format("{0}_F52E2B61-18A1-11d1-B105-00805F49916B", forType); + List> data = new List>(); + for(int i = 0; i < Common.StandardRows; i++) + { + data.Add(new Dictionary { { columnName, "test data"} }); + } + Dictionary[][] dataSets = {data.ToArray()}; + + // ... Create a callback for resultset completion + ResultSetSummary resultSummary = null; + ResultSet.ResultSetAsyncEventHandler callback = r => + { + resultSummary = r.Summary; + return Task.FromResult(0); + }; + + // If: + // ... I create a new resultset with a valid db data reader that is FOR XML/JSON + // ... and I read it to the end + DbDataReader mockReader = GetReader(dataSets, false, Common.StandardQuery); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + ResultSet resultSet = new ResultSet(mockReader, Common.Ordinal, Common.Ordinal, fileStreamFactory); + resultSet.ResultCompletion += callback; + await resultSet.ReadResultToEnd(CancellationToken.None); + + // Then: + // ... There should only be one column + // ... There should only be one row + // ... The result should be marked as complete + Assert.Equal(1, resultSet.Columns.Length); + Assert.Equal(1, resultSet.RowCount); + + // ... The callback should have been called + Assert.NotNull(resultSummary); + + // If: + // ... I attempt to read back the results + // Then: + // ... I should only get one row + var subset = await resultSet.GetSubset(0, 10); + Assert.Equal(1, subset.RowCount); + } + + [Fact] + public async Task GetSubsetWithoutExecution() + { + // If: + // ... I create a new result set with a valid db data reader without executing it + DbDataReader mockReader = GetReader(null, false, string.Empty); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + ResultSet resultSet = new ResultSet(mockReader, Common.Ordinal, Common.Ordinal, fileStreamFactory); + + // Then: + // ... Attempting to read a subset should fail miserably + await Assert.ThrowsAsync(() => resultSet.GetSubset(0, 0)); + } + + [Theory] + [InlineData(-1, 0)] // Too small start row + [InlineData(20, 0)] // Too large start row + [InlineData(0, -1)] // Negative row count + public async Task GetSubsetInvalidParameters(int startRow, int rowCount) + { + // If: + // ... I create a new result set with a valid db data reader + // ... And execute the result + DbDataReader mockReader = GetReader(new[] {Common.StandardTestData}, false, Common.StandardQuery); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + ResultSet resultSet = new ResultSet(mockReader, Common.Ordinal, Common.Ordinal, fileStreamFactory); + await resultSet.ReadResultToEnd(CancellationToken.None); + + // ... And attempt to get a subset with invalid parameters + // Then: + // ... It should throw an exception for an invalid parameter + await Assert.ThrowsAsync(() => resultSet.GetSubset(startRow, rowCount)); + } + + [Theory] + [InlineData(0, 3)] // Standard scenario, 3 rows should come back + [InlineData(0, 20)] // Asking for too many rows, 5 rows should come back + [InlineData(1, 3)] // Standard scenario from non-zero start + [InlineData(1, 20)] // Asking for too many rows at a non-zero start + public async Task GetSubsetSuccess(int startRow, int rowCount) + { + // If: + // ... I create a new result set with a valid db data reader + // ... And execute the result set + DbDataReader mockReader = GetReader(new[] { Common.StandardTestData }, false, Common.StandardQuery); + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + ResultSet resultSet = new ResultSet(mockReader, Common.Ordinal, Common.Ordinal, fileStreamFactory); + await resultSet.ReadResultToEnd(CancellationToken.None); + + // ... And attempt to get a subset with valid number of rows + ResultSetSubset subset = await resultSet.GetSubset(startRow, rowCount); + + // Then: + // ... There should be rows in the subset, either the number of rows or the number of + // rows requested or the number of rows in the result set, whichever is lower + long availableRowsFromStart = resultSet.RowCount - startRow; + Assert.Equal(Math.Min(availableRowsFromStart, rowCount), subset.RowCount); + + // ... The rows should have the same number of columns as the resultset + Assert.Equal(resultSet.Columns.Length, subset.Rows[0].Length); + } + + private static DbDataReader GetReader(Dictionary[][] dataSet, bool throwOnRead, string query) + { + var info = Common.CreateTestConnectionInfo(dataSet, throwOnRead); + var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); + var command = connection.CreateCommand(); + command.CommandText = query; + return command.ExecuteReader(); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs new file mode 100644 index 00000000..c62211b0 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs @@ -0,0 +1,481 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution +{ + public class ServiceIntegrationTests + { + + [Fact] + public async void QueryExecuteSingleBatchNoResultsTest() + { + // Given: + // ... Default settings are stored in the workspace service + // ... A workspace with a standard query is configured + WorkspaceService.Instance.CurrentSettings = new SqlToolsSettings(); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + + // If: + // ... I request to execute a valid query with no results + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); + var queryParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams completeParams = null; + QueryExecuteBatchNotificationParams batchStartParams = null; + QueryExecuteBatchNotificationParams batchCompleteParams = null; + var requestContext = RequestContextMocks.Create(qer => result = qer) + .AddEventHandling(QueryExecuteCompleteEvent.Type, (et, p) => completeParams = p) + .AddEventHandling(QueryExecuteBatchStartEvent.Type, (et, p) => batchStartParams = p) + .AddEventHandling(QueryExecuteBatchCompleteEvent.Type, (et, p) => batchCompleteParams = p) + .AddEventHandling(QueryExecuteResultSetCompleteEvent.Type, null); + await Common.AwaitExecution(queryService, queryParams, requestContext.Object); + + // Then: + // ... No Errors should have been sent + // ... A successful result should have been sent with messages on the first batch + // ... A completion event should have been fired with empty results + // ... A batch completion event should have been fired with empty results + // ... A result set completion event should not have been fired + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Once(), Times.Once(), Times.Never(), Times.Never()); + Assert.Null(result.Messages); + + Assert.Equal(1, completeParams.BatchSummaries.Length); + Assert.Empty(completeParams.BatchSummaries[0].ResultSetSummaries); + Assert.NotEmpty(completeParams.BatchSummaries[0].Messages); + + // ... Batch start summary should not contain result sets, messages, but should contain owner URI + Assert.NotNull(batchStartParams); + Assert.NotNull(batchStartParams.BatchSummary); + Assert.Null(batchStartParams.BatchSummary.Messages); + Assert.Null(batchStartParams.BatchSummary.ResultSetSummaries); + Assert.Equal(Common.OwnerUri, batchStartParams.OwnerUri); + + // ... Batch completion summary should contain result sets, messages, and the owner URI + Assert.NotNull(batchCompleteParams); + Assert.NotNull(batchCompleteParams.BatchSummary); + Assert.Empty(batchCompleteParams.BatchSummary.ResultSetSummaries); + Assert.NotEmpty(batchCompleteParams.BatchSummary.Messages); + Assert.Equal(Common.OwnerUri, batchCompleteParams.OwnerUri); + + // ... There should be one active query + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + + [Fact] + public async void QueryExecuteSingleBatchSingleResultTest() + { + // Given: + // ... A workspace with a standard query is configured + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + + // If: + // ... I request to execute a valid query with results + var queryService = Common.GetPrimedExecutionService(new[] { Common.StandardTestData }, true, false, workspaceService); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams completeParams = null; + QueryExecuteBatchNotificationParams batchStartParams = null; + QueryExecuteBatchNotificationParams batchCompleteParams = null; + QueryExecuteResultSetCompleteParams resultCompleteParams = null; + var requestContext = RequestContextMocks.Create(qer => result = qer) + .AddEventHandling(QueryExecuteCompleteEvent.Type, (et, p) => completeParams = p) + .AddEventHandling(QueryExecuteBatchStartEvent.Type, (et, p) => batchStartParams = p) + .AddEventHandling(QueryExecuteBatchCompleteEvent.Type, (et, p) => batchCompleteParams = p) + .AddEventHandling(QueryExecuteResultSetCompleteEvent.Type, (et, p) => resultCompleteParams = p); + await Common.AwaitExecution(queryService, queryParams, requestContext.Object); + + // Then: + // ... No errors should have been sent + // ... A successful result should have been sent without messages + // ... A completion event should have been fired with one result + // ... A batch completion event should have been fired + // ... A resultset completion event should have been fired + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Once(), Times.Once(), Times.Once(), Times.Never()); + Assert.Null(result.Messages); + + Assert.Equal(1, completeParams.BatchSummaries.Length); + Assert.NotEmpty(completeParams.BatchSummaries[0].ResultSetSummaries); + Assert.NotEmpty(completeParams.BatchSummaries[0].Messages); + Assert.False(completeParams.BatchSummaries[0].HasError); + + // ... Batch start summary should not contain result sets, messages, but should contain owner URI + Assert.NotNull(batchStartParams); + Assert.NotNull(batchStartParams.BatchSummary); + Assert.Null(batchStartParams.BatchSummary.Messages); + Assert.Null(batchStartParams.BatchSummary.ResultSetSummaries); + Assert.Equal(Common.OwnerUri, batchStartParams.OwnerUri); + + Assert.NotNull(batchCompleteParams); + Assert.NotEmpty(batchCompleteParams.BatchSummary.ResultSetSummaries); + Assert.NotEmpty(batchCompleteParams.BatchSummary.Messages); + Assert.Equal(Common.OwnerUri, batchCompleteParams.OwnerUri); + + Assert.NotNull(resultCompleteParams); + Assert.Equal(Common.StandardColumns, resultCompleteParams.ResultSetSummary.ColumnInfo.Length); + Assert.Equal(Common.StandardRows, resultCompleteParams.ResultSetSummary.RowCount); + Assert.Equal(Common.OwnerUri, resultCompleteParams.OwnerUri); + + // ... There should be one active query + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Fact] + public async Task QueryExecuteSingleBatchMultipleResultTest() + { + // Given: + // ... A workspace with a standard query is configured + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + + // If: + // ... I request to execute a valid query with one batch and multiple result sets + var dataset = new[] { Common.StandardTestData, Common.StandardTestData }; + var queryService = Common.GetPrimedExecutionService(dataset, true, false, workspaceService); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams completeParams = null; + QueryExecuteBatchNotificationParams batchStartParams = null; + QueryExecuteBatchNotificationParams batchCompleteParams = null; + List resultCompleteParams = new List(); + var requestContext = RequestContextMocks.Create(qer => result = qer) + .AddEventHandling(QueryExecuteCompleteEvent.Type, (et, p) => completeParams = p) + .AddEventHandling(QueryExecuteBatchStartEvent.Type, (et, p) => batchStartParams = p) + .AddEventHandling(QueryExecuteBatchCompleteEvent.Type, (et, p) => batchCompleteParams = p) + .AddEventHandling(QueryExecuteResultSetCompleteEvent.Type, (et, p) => resultCompleteParams.Add(p)); + await Common.AwaitExecution(queryService, queryParams, requestContext.Object); + + // Then: + // ... No errors should have been sent + // ... A successful result should have been sent without messages + // ... A completion event should have been fired with one result + // ... A batch completion event should have been fired + // ... Two resultset completion events should have been fired + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Once(), Times.Once(), Times.Exactly(2), Times.Never()); + Assert.Null(result.Messages); + + Assert.Equal(1, completeParams.BatchSummaries.Length); + Assert.NotEmpty(completeParams.BatchSummaries[0].ResultSetSummaries); + Assert.NotEmpty(completeParams.BatchSummaries[0].Messages); + Assert.False(completeParams.BatchSummaries[0].HasError); + + // ... Batch start summary should not contain result sets, messages, but should contain owner URI + Assert.NotNull(batchStartParams); + Assert.NotNull(batchStartParams.BatchSummary); + Assert.Null(batchStartParams.BatchSummary.Messages); + Assert.Null(batchStartParams.BatchSummary.ResultSetSummaries); + Assert.Equal(Common.OwnerUri, batchStartParams.OwnerUri); + + Assert.NotNull(batchCompleteParams); + Assert.NotEmpty(batchCompleteParams.BatchSummary.ResultSetSummaries); + Assert.NotEmpty(batchCompleteParams.BatchSummary.Messages); + Assert.Equal(Common.OwnerUri, batchCompleteParams.OwnerUri); + + Assert.Equal(2, resultCompleteParams.Count); + foreach (var resultParam in resultCompleteParams) + { + Assert.NotNull(resultCompleteParams); + Assert.Equal(Common.StandardColumns, resultParam.ResultSetSummary.ColumnInfo.Length); + Assert.Equal(Common.StandardRows, resultParam.ResultSetSummary.RowCount); + Assert.Equal(Common.OwnerUri, resultParam.OwnerUri); + } + } + + [Fact] + public async Task QueryExecuteMultipleBatchSingleResultTest() + { + // Given: + // ... A workspace with a standard query is configured + var workspaceService = Common.GetPrimedWorkspaceService(string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery)); + + // If: + // ... I request a to execute a valid query with multiple batches + var dataSet = new[] { Common.StandardTestData }; + var queryService = Common.GetPrimedExecutionService(dataSet, true, false, workspaceService); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams completeParams = null; + List batchStartParams = new List(); + List batchCompleteParams = new List(); + List resultCompleteParams = new List(); + var requestContext = RequestContextMocks.Create(qer => result = qer) + .AddEventHandling(QueryExecuteCompleteEvent.Type, (et, p) => completeParams = p) + .AddEventHandling(QueryExecuteBatchStartEvent.Type, (et, p) => batchStartParams.Add(p)) + .AddEventHandling(QueryExecuteBatchCompleteEvent.Type, (et, p) => batchCompleteParams.Add(p)) + .AddEventHandling(QueryExecuteResultSetCompleteEvent.Type, (et, p) => resultCompleteParams.Add(p)); + await Common.AwaitExecution(queryService, queryParams, requestContext.Object); + + // Then: + // ... No errors should have been sent + // ... A successful result should have been sent without messages + + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Exactly(2), Times.Exactly(2), Times.Exactly(2), Times.Never()); + Assert.Null(result.Messages); + + // ... A completion event should have been fired with one two batch summaries, one result each + Assert.Equal(2, completeParams.BatchSummaries.Length); + Assert.Equal(1, completeParams.BatchSummaries[0].ResultSetSummaries.Length); + Assert.Equal(1, completeParams.BatchSummaries[1].ResultSetSummaries.Length); + Assert.NotEmpty(completeParams.BatchSummaries[0].Messages); + Assert.NotEmpty(completeParams.BatchSummaries[1].Messages); + + // ... Two batch start events should have been fired + Assert.Equal(2, batchStartParams.Count); + foreach (var batch in batchStartParams) + { + Assert.Null(batch.BatchSummary.Messages); + Assert.Null(batch.BatchSummary.ResultSetSummaries); + Assert.Equal(Common.OwnerUri, batch.OwnerUri); + } + + // ... Two batch completion events should have been fired + Assert.Equal(2, batchCompleteParams.Count); + foreach (var batch in batchCompleteParams) + { + Assert.NotEmpty(batch.BatchSummary.ResultSetSummaries); + Assert.NotEmpty(batch.BatchSummary.Messages); + Assert.Equal(Common.OwnerUri, batch.OwnerUri); + } + + // ... Two resultset completion events should have been fired + Assert.Equal(2, resultCompleteParams.Count); + foreach (var resultParam in resultCompleteParams) + { + Assert.NotNull(resultParam.ResultSetSummary); + Assert.Equal(Common.StandardColumns, resultParam.ResultSetSummary.ColumnInfo.Length); + Assert.Equal(Common.StandardRows, resultParam.ResultSetSummary.RowCount); + Assert.Equal(Common.OwnerUri, resultParam.OwnerUri); + } + + // ... There should be one active query + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Fact] + public async void QueryExecuteUnconnectedUriTest() + { + // Given: + // ... A workspace with a standard query is configured + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + + // If: + // ... I request to execute a query using a file URI that isn't connected + var queryService = Common.GetPrimedExecutionService(null, false, false, workspaceService); + var queryParams = new QueryExecuteParams { OwnerUri = "notConnected", QuerySelection = Common.WholeDocument }; + + object error = null; + var requestContext = RequestContextMocks.Create(null) + .AddErrorHandling(e => error = e); + await Common.AwaitExecution(queryService, queryParams, requestContext.Object); + + // Then: + // ... An error should have been returned + // ... No result should have been returned + // ... No completion event should have been fired + // ... There should be no active queries + VerifyQueryExecuteCallCount(requestContext, Times.Never(), Times.Never(), Times.Never(), Times.Never(), Times.Never(), Times.Once()); + Assert.IsType(error); + Assert.NotEmpty((string)error); + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public async void QueryExecuteInProgressTest() + { + // Given: + // ... A workspace with a standard query is configured + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + + // If: + // ... I request to execute a query + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; + + // Note, we don't care about the results of the first request + var firstRequestContext = RequestContextMocks.Create(null); + await Common.AwaitExecution(queryService, queryParams, firstRequestContext.Object); + + // ... And then I request another query without waiting for the first to complete + queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Simulate query hasn't finished + object error = null; + var secondRequestContext = RequestContextMocks.Create(null) + .AddErrorHandling(e => error = e); + await Common.AwaitExecution(queryService, queryParams, secondRequestContext.Object); + + // Then: + // ... An error should have been sent + // ... A result should have not have been sent + // ... No completion event should have been fired + // ... A batch completion event should have fired, but not a resultset event + // ... There should only be one active query + VerifyQueryExecuteCallCount(secondRequestContext, Times.Never(), Times.AtMostOnce(), Times.AtMostOnce(), Times.AtMostOnce(), Times.Never(), Times.Once()); + Assert.IsType(error); + Assert.NotEmpty((string)error); + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + + [Fact] + public async void QueryExecuteCompletedTest() + { + // Given: + // ... A workspace with a standard query is configured + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + + // If: + // ... I request to execute a query + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; + + // Note, we don't care about the results of the first request + var firstRequestContext = RequestContextMocks.Create(null); + await Common.AwaitExecution(queryService, queryParams, firstRequestContext.Object); + + // ... And then I request another query after waiting for the first to complete + QueryExecuteResult result = null; + QueryExecuteCompleteParams complete = null; + QueryExecuteBatchNotificationParams batchStart = null; + QueryExecuteBatchNotificationParams batchComplete = null; + var secondRequestContext = RequestContextMocks.Create(qer => result = qer) + .AddEventHandling(QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp) + .AddEventHandling(QueryExecuteBatchStartEvent.Type, (et, p) => batchStart = p) + .AddEventHandling(QueryExecuteBatchCompleteEvent.Type, (et, p) => batchComplete = p); + await Common.AwaitExecution(queryService, queryParams, secondRequestContext.Object); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with no errors + // ... There should only be one active query + // ... A batch completion event should have fired, but not a result set completion event + VerifyQueryExecuteCallCount(secondRequestContext, Times.Once(), Times.Once(), Times.Once(), Times.Once(), Times.Never(), Times.Never()); + Assert.Null(result.Messages); + + Assert.False(complete.BatchSummaries.Any(b => b.HasError)); + Assert.Equal(1, queryService.ActiveQueries.Count); + + Assert.NotNull(batchStart); + Assert.NotNull(batchComplete); + Assert.False(batchComplete.BatchSummary.HasError); + Assert.Equal(complete.OwnerUri, batchComplete.OwnerUri); + } + + [Fact] + public async Task QueryExecuteMissingSelectionTest() + { + // Given: + // ... A workspace with a standard query is configured + var workspaceService = Common.GetPrimedWorkspaceService(string.Empty); + + // If: + // ... I request to execute a query with a missing query string + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = null }; + + object errorResult = null; + var requestContext = RequestContextMocks.Create(null) + .AddErrorHandling(error => errorResult = error); + await queryService.HandleExecuteRequest(queryParams, requestContext.Object); + + + // Then: + // ... Am error should have been sent + // ... No result should have been sent + // ... No completion events should have been fired + // ... An active query should not have been added + VerifyQueryExecuteCallCount(requestContext, Times.Never(), Times.Never(), Times.Never(), Times.Never(), Times.Never(), Times.Once()); + Assert.NotNull(errorResult); + Assert.IsType(errorResult); + Assert.DoesNotContain(Common.OwnerUri, queryService.ActiveQueries.Keys); + + // ... There should not be an active query + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public async void QueryExecuteInvalidQueryTest() + { + // Given: + // ... A workspace with a standard query is configured + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + + // If: + // ... I request to execute a query that is invalid + var queryService = Common.GetPrimedExecutionService(null, true, true, workspaceService); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams complete = null; + QueryExecuteBatchNotificationParams batchStart = null; + QueryExecuteBatchNotificationParams batchComplete = null; + var requestContext = RequestContextMocks.Create(qer => result = qer) + .AddEventHandling(QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp) + .AddEventHandling(QueryExecuteBatchStartEvent.Type, (et, p) => batchStart = p) + .AddEventHandling(QueryExecuteBatchCompleteEvent.Type, (et, p) => batchComplete = p); + await Common.AwaitExecution(queryService, queryParams, requestContext.Object); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with success (we successfully started the query) + // ... A completion event (query, batch, not resultset) should have been sent with error + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Once(), Times.Once(), Times.Never(), Times.Never()); + Assert.Null(result.Messages); + + Assert.Equal(1, complete.BatchSummaries.Length); + Assert.True(complete.BatchSummaries[0].HasError); + Assert.NotEmpty(complete.BatchSummaries[0].Messages); + + Assert.NotNull(batchStart); + Assert.False(batchStart.BatchSummary.HasError); + Assert.Null(batchStart.BatchSummary.Messages); + Assert.Null(batchStart.BatchSummary.ResultSetSummaries); + Assert.Equal(Common.OwnerUri, batchStart.OwnerUri); + + Assert.NotNull(batchComplete); + Assert.True(batchComplete.BatchSummary.HasError); + Assert.NotEmpty(batchComplete.BatchSummary.Messages); + Assert.Equal(Common.OwnerUri, batchComplete.OwnerUri); + } + + private static void VerifyQueryExecuteCallCount(Mock> mock, + Times sendResultCalls, + Times sendCompletionEventCalls, + Times sendBatchStartEvent, + Times sendBatchCompletionEvent, + Times sendResultCompleteEvent, + Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendEvent( + It.Is>(m => m == QueryExecuteCompleteEvent.Type), + It.IsAny()), sendCompletionEventCalls); + mock.Verify(rc => rc.SendEvent( + It.Is>(m => m == QueryExecuteBatchCompleteEvent.Type), + It.IsAny()), sendBatchCompletionEvent); + mock.Verify(rc => rc.SendEvent( + It.Is>(m => m== QueryExecuteBatchStartEvent.Type), + It.IsAny()), sendBatchStartEvent); + mock.Verify(rc => rc.SendEvent( + It.Is>(m => m == QueryExecuteResultSetCompleteEvent.Type), + It.IsAny()), sendResultCompleteEvent); + + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs index 470476c2..5317c5d4 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs @@ -3,16 +3,13 @@ // using System; -using System.Linq; using System.IO; using System.Threading.Tasks; using System.Runtime.InteropServices; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; -using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Test.Utility; -using Microsoft.SqlTools.ServiceLayer.Workspace; using Moq; using Xunit; @@ -30,11 +27,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public async void SaveResultsAsCsvSuccessTest() { // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); + var workplaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(new [] {Common.StandardTestData}, true, false, workplaceService); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); - await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); - await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; + await Common.AwaitExecution(queryService, executeParams, executeRequest.Object); // Request to save the results as csv with correct parameters var saveParams = new SaveResultsAsCsvRequestParams @@ -47,18 +44,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution }; SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); - queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); // Call save results and wait on the save task await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; - Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); - await saveTask; + await selectedResultSet.GetSaveTask(saveParams.FilePath); // Expect to see a file successfully created in filepath and a success message + VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); Assert.Null(result.Messages); Assert.True(File.Exists(saveParams.FilePath)); - VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); // Delete temp file after test if (File.Exists(saveParams.FilePath)) @@ -74,11 +69,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public async void SaveResultsAsCsvWithSelectionSuccessTest() { // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(new []{Common.StandardTestData}, true, false, workspaceService); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument , OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); - await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); - await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; + await Common.AwaitExecution(queryService, executeParams, executeRequest.Object); // Request to save the results as csv with correct parameters var saveParams = new SaveResultsAsCsvRequestParams @@ -95,7 +90,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution }; SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); - queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); // Call save results and wait on the save task await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); @@ -104,9 +98,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution await saveTask; // Expect to see a file successfully created in filepath and a success message + VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); Assert.Null(result.Messages); Assert.True(File.Exists(saveParams.FilePath)); - VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); // Delete temp file after test if (File.Exists(saveParams.FilePath)) @@ -120,13 +114,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// [Fact] public async void SaveResultsAsCsvExceptionTest() - { + { // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(new[] {Common.StandardTestData}, true, false, workspaceService); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); - await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); - await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; + await Common.AwaitExecution(queryService, executeParams, executeRequest.Object); // Request to save the results as csv with incorrect filepath var saveParams = new SaveResultsAsCsvRequestParams @@ -139,17 +133,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution SaveResultRequestError errMessage = null; var saveRequest = GetSaveResultsContextMock( null, err => errMessage = (SaveResultRequestError) err); - queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); // Call save results and wait on the save task await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; - Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); - await saveTask; + await selectedResultSet.GetSaveTask(saveParams.FilePath); // Expect to see error message - Assert.NotNull(errMessage); VerifySaveResultsCallCount(saveRequest, Times.Never(), Times.Once()); + Assert.NotNull(errMessage); Assert.False(File.Exists(saveParams.FilePath)); } @@ -157,11 +149,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test saving results to CSV file when the requested result set is no longer active /// [Fact] - public async void SaveResultsAsCsvQueryNotFoundTest() + public async Task SaveResultsAsCsvQueryNotFoundTest() { // Create a query execution service - var workspaceService = new Mock>(); - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); // Request to save the results as csv with query that is no longer active var saveParams = new SaveResultsAsCsvRequestParams @@ -173,12 +165,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution }; SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); - queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); // Expect message that save failed + VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); Assert.NotNull(result.Messages); Assert.False(File.Exists(saveParams.FilePath)); - VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); } /// @@ -188,11 +180,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public async void SaveResultsAsJsonSuccessTest() { // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(new[] {Common.StandardTestData}, true, false, workspaceService); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); - await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); - await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; + await Common.AwaitExecution(queryService, executeParams, executeRequest.Object); // Request to save the results as json with correct parameters var saveParams = new SaveResultsAsJsonRequestParams @@ -200,19 +192,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution OwnerUri = Common.OwnerUri, ResultSetIndex = 0, BatchIndex = 0, - FilePath = "testwrite_4.json" + FilePath = "testwrite_4.json" }; SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); - queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); // Call save results and wait on the save task await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); await saveTask; - - // Expect to see a file successfully created in filepath and a success message Assert.Null(result.Messages); @@ -233,11 +222,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public async void SaveResultsAsJsonWithSelectionSuccessTest() { // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(new[] { Common.StandardTestData }, true, false, workspaceService); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument , OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); - await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); - await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; + await Common.AwaitExecution(queryService, executeParams, executeRequest.Object); // Request to save the results as json with correct parameters var saveParams = new SaveResultsAsJsonRequestParams @@ -253,18 +242,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution }; SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); - queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); // Call save results and wait on the save task await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; - Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); - await saveTask; - + await selectedResultSet.GetSaveTask(saveParams.FilePath); + // Expect to see a file successfully created in filepath and a success message + VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); Assert.Null(result.Messages); Assert.True(File.Exists(saveParams.FilePath)); - VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); // Delete temp file after test if (File.Exists(saveParams.FilePath)) @@ -280,11 +267,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public async void SaveResultsAsJsonExceptionTest() { // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(new [] {Common.StandardTestData}, true, false, workspaceService); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); - await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); - await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; + await Common.AwaitExecution(queryService, executeParams, executeRequest.Object); // Request to save the results as json with incorrect filepath var saveParams = new SaveResultsAsJsonRequestParams @@ -303,8 +290,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // Call save results and wait on the save task await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; - Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); - await saveTask; + await selectedResultSet.GetSaveTask(saveParams.FilePath); // Expect to see error message Assert.NotNull(errMessage); @@ -316,12 +302,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test saving results to JSON file when the requested result set is no longer active /// [Fact] - public async void SaveResultsAsJsonQueryNotFoundTest() + public async Task SaveResultsAsJsonQueryNotFoundTest() { - // Create a query service - var workspaceService = new Mock>(); - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); // Request to save the results as json with query that is no longer active var saveParams = new SaveResultsAsJsonRequestParams @@ -333,7 +318,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution }; SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); - queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object).Wait(); + await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); // Expect message that save failed Assert.Equal("Failed to save results, ID not found.", result.Messages); @@ -353,25 +338,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Action resultCallback, Action errorCallback) { - var requestContext = new Mock>(); - - // Setup the mock for SendResult - var sendResultFlow = requestContext - .Setup(rc => rc.SendResult(It.IsAny ())) - .Returns(Task.FromResult(0)); - if (resultCallback != null) - { - sendResultFlow.Callback(resultCallback); - } - - // Setup the mock for SendError - var sendErrorFlow = requestContext - .Setup(rc => rc.SendError(It.IsAny())) - .Returns(Task.FromResult(0)); - if (errorCallback != null) - { - sendErrorFlow.Callback(errorCallback); - } + var requestContext = RequestContextMocks.Create(resultCallback) + .AddErrorHandling(errorCallback); return requestContext; } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs index 4036c655..d7a5916a 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs @@ -4,15 +4,13 @@ // using System; +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; -using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Test.Utility; -using Microsoft.SqlTools.ServiceLayer.Workspace; -using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; using Moq; using Xunit; @@ -23,9 +21,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution #region ResultSet Class Tests [Theory] - [InlineData(0,2)] - [InlineData(0,20)] - [InlineData(1,2)] + [InlineData(0, 2)] + [InlineData(0, 20)] + [InlineData(1, 2)] public void ResultSetValidTest(int startRow, int rowCount) { // Setup: @@ -60,6 +58,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Assert.ThrowsAsync(() => rs.GetSubset(rowStartIndex, rowCount)).Wait(); } + [Fact] + public async Task ResultSetNotReadTest() + { + // If: + // ... I have a resultset that hasn't been executed and I request a valid result set from it + // Then: + // ... It should throw an exception for having not been read + ResultSet rs = new ResultSet(new TestDbDataReader(null), Common.Ordinal, Common.Ordinal, Common.GetFileStreamFactory(new Dictionary())); + await Assert.ThrowsAsync(() => rs.GetSubset(0, 1)); + } + #endregion #region Batch Class Tests @@ -99,18 +108,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution #region Query Class Tests - [Fact] - public void SubsetUnexecutedQueryTest() - { - // If I have a query that has *not* been executed - Query q = new Query(Common.StandardQuery, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), Common.GetFileStreamFactory()); - - // ... And I ask for a subset with valid arguments - // Then: - // ... It should throw an exception - Assert.ThrowsAsync(() => q.GetSubset(0, 0, 0, 2)).Wait(); - } - [Theory] [InlineData(-1)] // Invalid batch, too low [InlineData(2)] // Invalid batch, too high @@ -132,26 +129,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async Task SubsetServiceValidTest() { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); // If: // ... I have a query that has results (doesn't matter what) - var queryService = await Common.GetPrimedExecutionService( - Common.CreateMockFactory(new[] {Common.StandardTestData}, false), true, - workspaceService.Object); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(new[] {Common.StandardTestData}, true, false, workspaceService); var executeParams = new QueryExecuteParams {QuerySelection = null, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And I then ask for a valid set of results from it - var subsetParams = new QueryExecuteSubsetParams {OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0}; + var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; QueryExecuteSubsetResult result = null; var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); await queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object); @@ -166,17 +154,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public async void SubsetServiceMissingQueryTest() + public async Task SubsetServiceMissingQueryTest() { - - var workspaceService = new Mock>(); // If: // ... I ask for a set of results for a file that hasn't executed a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; QueryExecuteSubsetResult result = null; var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); - queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + await queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object); // Then: // ... I should have an error result @@ -188,32 +175,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public async void SubsetServiceUnexecutedQueryTest() + public async Task SubsetServiceUnexecutedQueryTest() { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); // If: // ... I have a query that hasn't finished executing (doesn't matter what) - var queryService = await Common.GetPrimedExecutionService( - Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true, - workspaceService.Object); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(new[] { Common.StandardTestData }, true, false, workspaceService); var executeParams = new QueryExecuteParams { QuerySelection = null, OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; - queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; + queryService.ActiveQueries[Common.OwnerUri].Batches[0].ResultSets[0].hasBeenRead = false; // ... And I then ask for a valid set of results from it var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; QueryExecuteSubsetResult result = null; var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); - queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + await queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object); // Then: // ... I should get an error result @@ -226,11 +204,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SubsetServiceOutOfRangeSubsetTest() - { + { // If: // ... I have a query that doesn't have any result sets - var queryService = await Common.GetPrimedExecutionService( - Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); + var workspaceService = Common.GetPrimedWorkspaceService(Common.StandardQuery); + var queryService = Common.GetPrimedExecutionService(null, true, false, workspaceService); var executeParams = new QueryExecuteParams { QuerySelection = null, OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/AsyncLockTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/AsyncLockTests.cs new file mode 100644 index 00000000..fdf966c9 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/AsyncLockTests.cs @@ -0,0 +1,49 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.ServiceHost +{ + public class AsyncLockTests + { + [Fact] + public async Task AsyncLockSynchronizesAccess() + { + AsyncLock asyncLock = new AsyncLock(); + + Task lockOne = asyncLock.LockAsync(); + Task lockTwo = asyncLock.LockAsync(); + + Assert.Equal(TaskStatus.RanToCompletion, lockOne.Status); + Assert.Equal(TaskStatus.WaitingForActivation, lockTwo.Status); + lockOne.Result.Dispose(); + + await lockTwo; + Assert.Equal(TaskStatus.RanToCompletion, lockTwo.Status); + } + + [Fact] + public void AsyncLockCancelsWhenRequested() + { + CancellationTokenSource cts = new CancellationTokenSource(); + AsyncLock asyncLock = new AsyncLock(); + + Task lockOne = asyncLock.LockAsync(); + Task lockTwo = asyncLock.LockAsync(cts.Token); + + // Cancel the second lock before the first is released + cts.Cancel(); + lockOne.Result.Dispose(); + + Assert.Equal(TaskStatus.RanToCompletion, lockOne.Status); + Assert.Equal(TaskStatus.Canceled, lockTwo.Status); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/AsyncQueueTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/AsyncQueueTests.cs new file mode 100644 index 00000000..df268417 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/AsyncQueueTests.cs @@ -0,0 +1,91 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.ServiceHost +{ + public class AsyncQueueTests + { + [Fact] + public async Task AsyncQueueSynchronizesAccess() + { + ConcurrentBag outputItems = new ConcurrentBag(); + AsyncQueue inputQueue = new AsyncQueue(Enumerable.Range(0, 100)); + CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + + try + { + // Start 5 consumers + await Task.WhenAll( + Task.Run(() => ConsumeItems(inputQueue, outputItems, cancellationTokenSource.Token)), + Task.Run(() => ConsumeItems(inputQueue, outputItems, cancellationTokenSource.Token)), + Task.Run(() => ConsumeItems(inputQueue, outputItems, cancellationTokenSource.Token)), + Task.Run(() => ConsumeItems(inputQueue, outputItems, cancellationTokenSource.Token)), + Task.Run(() => ConsumeItems(inputQueue, outputItems, cancellationTokenSource.Token)), + Task.Run( + async () => + { + // Wait for a bit and then add more items to the queue + await Task.Delay(250); + + foreach (var i in Enumerable.Range(100, 200)) + { + await inputQueue.EnqueueAsync(i); + } + + // Cancel the waiters + cancellationTokenSource.Cancel(); + })); + } + catch (TaskCanceledException) + { + // Do nothing, this is expected. + } + + // At this point, numbers 0 through 299 should be in the outputItems + IEnumerable expectedItems = Enumerable.Range(0, 300); + Assert.Equal(0, expectedItems.Except(outputItems).Count()); + } + + [Fact] + public async Task AsyncQueueSkipsCancelledTasks() + { + AsyncQueue inputQueue = new AsyncQueue(); + + // Queue up a couple of tasks to wait for input + CancellationTokenSource cancellationSource = new CancellationTokenSource(); + Task taskOne = inputQueue.DequeueAsync(cancellationSource.Token); + Task taskTwo = inputQueue.DequeueAsync(); + + // Cancel the first task and then enqueue a number + cancellationSource.Cancel(); + await inputQueue.EnqueueAsync(1); + + // Did the second task get the number? + Assert.Equal(TaskStatus.Canceled, taskOne.Status); + Assert.Equal(TaskStatus.RanToCompletion, taskTwo.Status); + Assert.Equal(1, taskTwo.Result); + } + + private async Task ConsumeItems( + AsyncQueue inputQueue, + ConcurrentBag outputItems, + CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + int consumedItem = await inputQueue.DequeueAsync(cancellationToken); + outputItems.Add(consumedItem); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/ScriptFileTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/ScriptFileTests.cs new file mode 100644 index 00000000..817f3328 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/ScriptFileTests.cs @@ -0,0 +1,503 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.IO; +using System.Linq; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.ServiceHost +{ + /// + /// ScriptFile test case + /// + public class ScriptFileTests + { + internal static object fileLock = new object(); + + private static readonly string query = + "SELECT * FROM sys.objects as o1" + Environment.NewLine + + "SELECT * FROM sys.objects as o2" + Environment.NewLine + + "SELECT * FROM sys.objects as o3" + Environment.NewLine; + + internal static ScriptFile GetTestScriptFile(string initialText = null) + { + if (initialText == null) + { + initialText = ScriptFileTests.query; + } + + string ownerUri = System.IO.Path.GetTempFileName(); + + // Write the query text to a backing file + lock (fileLock) + { + System.IO.File.WriteAllText(ownerUri, initialText); + } + + return new ScriptFile(ownerUri, ownerUri, initialText); + } + + /// + /// Validate GetLinesInRange with invalid range + /// + [Fact] + public void GetLinesInRangeWithInvalidRangeTest() + { + ScriptFile scriptFile = GetTestScriptFile(); + + bool exceptionRaised = false; + try + { + scriptFile.GetLinesInRange( + new BufferRange( + new BufferPosition(1, 0), + new BufferPosition(2, 0))); + } + catch (ArgumentOutOfRangeException) + { + exceptionRaised = true; + } + + Assert.True(exceptionRaised, "ArgumentOutOfRangeException raised for invalid index"); + + } + + /// + /// Validate GetLinesInRange + /// + [Fact] + public void GetLinesInRangeTest() + { + ScriptFile scriptFile = GetTestScriptFile(); + + string id = scriptFile.Id; + Assert.True(!string.IsNullOrWhiteSpace(id)); + + BufferRange range =scriptFile.FileRange; + Assert.Null(range); + + string[] lines = scriptFile.GetLinesInRange( + new BufferRange( + new BufferPosition(2, 1), + new BufferPosition(2, 7))); + + Assert.True(lines.Length == 1, "One line in range"); + Assert.True(lines[0].Equals("SELECT"), "Range text is correct"); + + string[] queryLines = query.Split('\n'); + + string line = scriptFile.GetLine(2); + Assert.True(queryLines[1].StartsWith(line), "GetLine text is correct"); + } + + [Fact] + public void GetOffsetAtPositionTest() + { + ScriptFile scriptFile = GetTestScriptFile(); + int offset = scriptFile.GetOffsetAtPosition(2, 5); + int expected = 35 + Environment.NewLine.Length; + Assert.True(offset == expected, "Offset is at expected location"); + + BufferPosition position = scriptFile.GetPositionAtOffset(offset); + Assert.True(position.Line == 2 && position.Column == 5, "Position is at expected location"); + } + + [Fact] + public void GetRangeBetweenOffsetsTest() + { + ScriptFile scriptFile = GetTestScriptFile(); + BufferRange range = scriptFile.GetRangeBetweenOffsets( + scriptFile.GetOffsetAtPosition(2, 1), + scriptFile.GetOffsetAtPosition(2, 7)); + Assert.NotNull(range); + } + + [Fact] + public void CanApplySingleLineInsert() + { + this.AssertFileChange( + "This is a test.", + "This is a working test.", + new FileChange + { + Line = 1, + EndLine = 1, + Offset = 10, + EndOffset = 10, + InsertString = " working" + }); + } + + [Fact] + public void CanApplySingleLineReplace() + { + this.AssertFileChange( + "This is a potentially broken test.", + "This is a working test.", + new FileChange + { + Line = 1, + EndLine = 1, + Offset = 11, + EndOffset = 29, + InsertString = "working" + }); + } + + [Fact] + public void CanApplySingleLineDelete() + { + this.AssertFileChange( + "This is a test of the emergency broadcasting system.", + "This is a test.", + new FileChange + { + Line = 1, + EndLine = 1, + Offset = 15, + EndOffset = 52, + InsertString = "" + }); + } + + [Fact] + public void CanApplyMultiLineInsert() + { + this.AssertFileChange( + "first\r\nsecond\r\nfifth", + "first\r\nsecond\r\nthird\r\nfourth\r\nfifth", + new FileChange + { + Line = 3, + EndLine = 3, + Offset = 1, + EndOffset = 1, + InsertString = "third\r\nfourth\r\n" + }); + } + + [Fact] + public void CanApplyMultiLineReplace() + { + this.AssertFileChange( + "first\r\nsecoXX\r\nXXfth", + "first\r\nsecond\r\nthird\r\nfourth\r\nfifth", + new FileChange + { + Line = 2, + EndLine = 3, + Offset = 5, + EndOffset = 3, + InsertString = "nd\r\nthird\r\nfourth\r\nfi" + }); + } + + [Fact] + public void CanApplyMultiLineReplaceWithRemovedLines() + { + this.AssertFileChange( + "first\r\nsecoXX\r\nREMOVE\r\nTHESE\r\nLINES\r\nXXfth", + "first\r\nsecond\r\nthird\r\nfourth\r\nfifth", + new FileChange + { + Line = 2, + EndLine = 6, + Offset = 5, + EndOffset = 3, + InsertString = "nd\r\nthird\r\nfourth\r\nfi" + }); + } + + [Fact] + public void CanApplyMultiLineDelete() + { + this.AssertFileChange( + "first\r\nsecond\r\nREMOVE\r\nTHESE\r\nLINES\r\nthird", + "first\r\nsecond\r\nthird", + new FileChange + { + Line = 3, + EndLine = 6, + Offset = 1, + EndOffset = 1, + InsertString = "" + }); + } + + [Fact] + public void ThrowsExceptionWithEditOutsideOfRange() + { + Assert.Throws( + typeof(ArgumentOutOfRangeException), + () => + { + this.AssertFileChange( + "first\r\nsecond\r\nREMOVE\r\nTHESE\r\nLINES\r\nthird", + "first\r\nsecond\r\nthird", + new FileChange + { + Line = 3, + EndLine = 7, + Offset = 1, + EndOffset = 1, + InsertString = "" + }); + }); + } + + private void AssertFileChange( + string initialString, + string expectedString, + FileChange fileChange) + { + // Create an in-memory file from the StringReader + ScriptFile fileToChange = GetTestScriptFile(initialString); + + // Apply the FileChange and assert the resulting contents + fileToChange.ApplyChange(fileChange); + Assert.Equal(expectedString, fileToChange.Contents); + } + } + + public class ScriptFileGetLinesTests + { + private ScriptFile scriptFile; + + private const string TestString = "Line One\r\nLine Two\r\nLine Three\r\nLine Four\r\nLine Five"; + private readonly string[] TestStringLines = + TestString.Split( + new string[] { "\r\n" }, + StringSplitOptions.None); + + public ScriptFileGetLinesTests() + { + this.scriptFile = + ScriptFileTests.GetTestScriptFile( + "Line One\r\nLine Two\r\nLine Three\r\nLine Four\r\nLine Five\r\n"); + } + + [Fact] + public void CanGetWholeLine() + { + string[] lines = + this.scriptFile.GetLinesInRange( + new BufferRange(5, 1, 5, 10)); + + Assert.Equal(1, lines.Length); + Assert.Equal("Line Five", lines[0]); + } + + [Fact] + public void CanGetMultipleWholeLines() + { + string[] lines = + this.scriptFile.GetLinesInRange( + new BufferRange(2, 1, 4, 10)); + + Assert.Equal(TestStringLines.Skip(1).Take(3), lines); + } + + [Fact] + public void CanGetSubstringInSingleLine() + { + string[] lines = + this.scriptFile.GetLinesInRange( + new BufferRange(4, 3, 4, 8)); + + Assert.Equal(1, lines.Length); + Assert.Equal("ne Fo", lines[0]); + } + + [Fact] + public void CanGetEmptySubstringRange() + { + string[] lines = + this.scriptFile.GetLinesInRange( + new BufferRange(4, 3, 4, 3)); + + Assert.Equal(1, lines.Length); + Assert.Equal("", lines[0]); + } + + [Fact] + public void CanGetSubstringInMultipleLines() + { + string[] expectedLines = new string[] + { + "Two", + "Line Three", + "Line Fou" + }; + + string[] lines = + this.scriptFile.GetLinesInRange( + new BufferRange(2, 6, 4, 9)); + + Assert.Equal(expectedLines, lines); + } + + [Fact] + public void CanGetRangeAtLineBoundaries() + { + string[] expectedLines = new string[] + { + "", + "Line Three", + "" + }; + + string[] lines = + this.scriptFile.GetLinesInRange( + new BufferRange(2, 9, 4, 1)); + + Assert.Equal(expectedLines, lines); + } + } + + public class ScriptFilePositionTests + { + private ScriptFile scriptFile; + + public ScriptFilePositionTests() + { + this.scriptFile = + ScriptFileTests.GetTestScriptFile(@" +First line + Second line is longer + Third line +"); + } + + [Fact] + public void CanOffsetByLine() + { + this.AssertNewPosition( + 1, 1, + 2, 0, + 3, 1); + + this.AssertNewPosition( + 3, 1, + -2, 0, + 1, 1); + } + + [Fact] + public void CanOffsetByColumn() + { + this.AssertNewPosition( + 2, 1, + 0, 2, + 2, 3); + + this.AssertNewPosition( + 2, 5, + 0, -3, + 2, 2); + } + + [Fact] + public void ThrowsWhenPositionOutOfRange() + { + // Less than line range + Assert.Throws( + typeof(ArgumentOutOfRangeException), + () => + { + scriptFile.CalculatePosition( + new BufferPosition(1, 1), + -10, 0); + }); + + // Greater than line range + Assert.Throws( + typeof(ArgumentOutOfRangeException), + () => + { + scriptFile.CalculatePosition( + new BufferPosition(1, 1), + 10, 0); + }); + + // Less than column range + Assert.Throws( + typeof(ArgumentOutOfRangeException), + () => + { + scriptFile.CalculatePosition( + new BufferPosition(1, 1), + 0, -10); + }); + + // Greater than column range + Assert.Throws( + typeof(ArgumentOutOfRangeException), + () => + { + scriptFile.CalculatePosition( + new BufferPosition(1, 1), + 0, 10); + }); + } + + [Fact] + public void CanFindBeginningOfLine() + { + this.AssertNewPosition( + 4, 12, + pos => pos.GetLineStart(), + 4, 5); + } + + [Fact] + public void CanFindEndOfLine() + { + this.AssertNewPosition( + 4, 12, + pos => pos.GetLineEnd(), + 4, 15); + } + + [Fact] + public void CanComposePositionOperations() + { + this.AssertNewPosition( + 4, 12, + pos => pos.AddOffset(-1, 1).GetLineStart(), + 3, 3); + } + + private void AssertNewPosition( + int originalLine, int originalColumn, + int lineOffset, int columnOffset, + int expectedLine, int expectedColumn) + { + this.AssertNewPosition( + originalLine, originalColumn, + pos => pos.AddOffset(lineOffset, columnOffset), + expectedLine, expectedColumn); + } + + private void AssertNewPosition( + int originalLine, int originalColumn, + Func positionOperation, + int expectedLine, int expectedColumn) + { + var newPosition = + positionOperation( + new FilePosition( + this.scriptFile, + originalLine, + originalColumn)); + + Assert.Equal(expectedLine, newPosition.Line); + Assert.Equal(expectedColumn, newPosition.Column); + } + + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/SrTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/SrTests.cs new file mode 100644 index 00000000..38f46228 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/SrTests.cs @@ -0,0 +1,51 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.ServiceHost +{ + /// + /// ScriptFile test case + /// + public class SrTests + { + /// + /// Simple "test" to access string resources + /// The purpose of this test is for code coverage. It's probably better to just + /// exclude string resources in the code coverage report than maintain this test. + /// + [Fact] + public void SrStringsTest() + { + var culture = SR.Culture; + SR.Culture = culture; + Assert.True(SR.Culture == culture); + + var connectionServiceListDbErrorNullOwnerUri = SR.ConnectionServiceListDbErrorNullOwnerUri; + var connectionParamsValidateNullConnection = SR.ConnectionParamsValidateNullConnection; + var credentialsServiceInvalidCriticalHandle = SR.CredentialsServiceInvalidCriticalHandle; + var credentialsServicePasswordLengthExceeded = SR.CredentialsServicePasswordLengthExceeded; + var credentialsServiceTargetForDelete = SR.CredentialsServiceTargetForDelete; + var credentialsServiceTargetForLookup = SR.CredentialsServiceTargetForLookup; + var queryServiceCancelDisposeFailed = SR.QueryServiceCancelDisposeFailed; + var queryServiceQueryCancelled = SR.QueryServiceQueryCancelled; + var queryServiceDataReaderByteCountInvalid = SR.QueryServiceDataReaderByteCountInvalid; + var queryServiceDataReaderCharCountInvalid = SR.QueryServiceDataReaderCharCountInvalid; + var queryServiceDataReaderXmlCountInvalid = SR.QueryServiceDataReaderXmlCountInvalid; + var queryServiceFileWrapperReadOnly = SR.QueryServiceFileWrapperReadOnly; + var queryServiceAffectedOneRow = SR.QueryServiceAffectedOneRow; + var queryServiceMessageSenderNotSql = SR.QueryServiceMessageSenderNotSql; + var queryServiceResultSetNotRead = SR.QueryServiceResultSetNotRead; + var queryServiceResultSetNoColumnSchema = SR.QueryServiceResultSetNoColumnSchema; + var connectionServiceListDbErrorNotConnected = SR.ConnectionServiceListDbErrorNotConnected(".."); + var connectionServiceConnStringInvalidAuthType = SR.ConnectionServiceConnStringInvalidAuthType(".."); + var connectionServiceConnStringInvalidIntent = SR.ConnectionServiceConnStringInvalidIntent(".."); + var queryServiceAffectedRows = SR.QueryServiceAffectedRows(10); + var queryServiceErrorFormat = SR.QueryServiceErrorFormat(1, 1, 1, 1, "\n", ".."); + var queryServiceQueryFailed = SR.QueryServiceQueryFailed(".."); + var workspaceServiceBufferPositionOutOfOrder = SR.WorkspaceServiceBufferPositionOutOfOrder(1, 2, 3, 4); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/LongListTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/LongListTests.cs new file mode 100644 index 00000000..a28714bc --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/LongListTests.cs @@ -0,0 +1,29 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Utility; +using Xunit; + +namespace Microsoft.SqlTools.Test.Utility +{ + /// + /// Tests for the LongList class + /// + public class LongListTests + { + /// + /// Add and remove and item in a LongList + /// + [Fact] + public void LongListTest() + { + var longList = new LongList(); + longList.Add('.'); + Assert.True(longList.Count == 1); + longList.RemoveAt(0); + Assert.True(longList.Count == 0); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/MoqExtensions.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/MoqExtensions.cs new file mode 100644 index 00000000..aaace1b7 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/MoqExtensions.cs @@ -0,0 +1,42 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Reflection; +using Moq.Language; +using Moq.Language.Flow; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Utility +{ + public static class MoqExtensions + { + public delegate void OutAction(out TOut outVal); + + public delegate void OutAction(T1 arg1, out TOut outVal); + + public static IReturnsThrows OutCallback( + this ICallback mock, OutAction action) where TMock : class + { + return OutCallbackInternal(mock, action); + } + + public static IReturnsThrows OutCallback( + this ICallback mock, OutAction action) where TMock : class + { + return OutCallbackInternal(mock, action); + } + + private static IReturnsThrows OutCallbackInternal( + ICallback mock, object action) where TMock : class + { + typeof(ICallback).GetTypeInfo() + .Assembly.GetType("Moq.MethodCall") + .GetMethod("SetCallbackWithArguments", + BindingFlags.InvokeMethod | BindingFlags.NonPublic | BindingFlags.Instance) + .Invoke(mock, new[] { action }); + return mock as IReturnsThrows; + + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/RequestContextMocks.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/RequestContextMocks.cs index 91e05a76..798e3bcd 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/RequestContextMocks.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/RequestContextMocks.cs @@ -50,11 +50,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility this Mock> mock, Action errorCallback) { - // Setup the mock for SendError var sendErrorFlow = mock.Setup(rc => rc.SendError(It.IsAny())) .Returns(Task.FromResult(0)); - if (mock != null && errorCallback != null) + if (errorCallback != null) { sendErrorFlow.Callback(errorCallback); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs index c2765783..00e88637 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs @@ -8,14 +8,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility { public class TestDbColumn : DbColumn { - public TestDbColumn() + public TestDbColumn(string columnName) { base.IsLong = false; - base.ColumnName = "Test Column"; + base.ColumnName = columnName; base.ColumnSize = 128; base.AllowDBNull = true; base.DataType = typeof(string); base.DataTypeName = "nvarchar"; } + + public TestDbColumn(string columnName, int numericScale) + : this(columnName) + { + base.NumericScale = numericScale; + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs index 0330cda0..a9233d0b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs @@ -8,7 +8,6 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Data.Common; using System.Linq; -using Moq; namespace Microsoft.SqlTools.ServiceLayer.Test.Utility { @@ -93,7 +92,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility List columns = new List(); for (int i = 0; i < ResultSet.Current[0].Count; i++) { - columns.Add(new TestDbColumn()); + columns.Add(new TestDbColumn(ResultSet.Current[0].Keys.ToArray()[i])); } return new ReadOnlyCollection(columns); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs index 67e48786..0f861433 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs @@ -3,16 +3,22 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -//#define USE_LIVE_CONNECTION - using System; using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.IO; +using System.Reflection; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Credentials; +using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; namespace Microsoft.SqlTools.Test.Utility { @@ -21,6 +27,7 @@ namespace Microsoft.SqlTools.Test.Utility /// public class TestObjects { + private static bool hasInitServices = false; public const string ScriptUri = "file://some/file.sql"; /// @@ -28,13 +35,14 @@ namespace Microsoft.SqlTools.Test.Utility /// public static ConnectionService GetTestConnectionService() { -#if !USE_LIVE_CONNECTION // use mock database connection return new ConnectionService(new TestSqlConnectionFactory()); -#else + } + + public static ConnectionService GetLiveTestConnectionService() + { // connect to a real server instance return ConnectionService.Instance; -#endif } /// @@ -65,9 +73,22 @@ namespace Microsoft.SqlTools.Test.Utility return new ConnectionDetails() { UserName = "sa", - Password = "Yukon900", - DatabaseName = "AdventureWorks2016CTP3_2", - ServerName = "sqltools11" + Password = "...", + DatabaseName = "master", + ServerName = "localhost" + }; + } + + /// + /// Gets a ConnectionDetails for connecting to localhost with integrated auth + /// + public static ConnectionDetails GetIntegratedTestConnectionDetails() + { + return new ConnectionDetails() + { + DatabaseName = "master", + ServerName = "localhost", + AuthenticationType = "Integrated" }; } @@ -85,14 +106,111 @@ namespace Microsoft.SqlTools.Test.Utility /// public static ISqlConnectionFactory GetTestSqlConnectionFactory() { -#if !USE_LIVE_CONNECTION // use mock database connection return new TestSqlConnectionFactory(); -#else + } + + /// + /// Creates a test sql connection factory instance + /// + public static ISqlConnectionFactory GetLiveTestSqlConnectionFactory() + { // connect to a real server instance - return ConnectionService.Instance.ConnectionFactory; -#endif + return ConnectionService.Instance.ConnectionFactory; + } + + public static void InitializeTestServices() + { + if (TestObjects.hasInitServices) + { + return; + } + + TestObjects.hasInitServices = true; + + const string hostName = "SQ Tools Test Service Host"; + const string hostProfileId = "SQLToolsTestService"; + Version hostVersion = new Version(1,0); + + // set up the host details and profile paths + var hostDetails = new HostDetails(hostName, hostProfileId, hostVersion); + SqlToolsContext sqlToolsContext = new SqlToolsContext(hostDetails); + + // Grab the instance of the service host + ServiceHost serviceHost = ServiceHost.Instance; + + // Start the service + serviceHost.Start().Wait(); + + // Initialize the services that will be hosted here + WorkspaceService.Instance.InitializeService(serviceHost); + LanguageService.Instance.InitializeService(serviceHost, sqlToolsContext); + ConnectionService.Instance.InitializeService(serviceHost); + CredentialService.Instance.InitializeService(serviceHost); + QueryExecutionService.Instance.InitializeService(serviceHost); + + serviceHost.Initialize(); + } + + public static string GetTestSqlFile() + { + string filePath = Path.Combine( + Path.GetDirectoryName(Assembly.GetEntryAssembly().Location), + "sqltest.sql"); + if (File.Exists(filePath)) + { + File.Delete(filePath); + } + + File.WriteAllText(filePath, "SELECT * FROM sys.objects\n"); + + return filePath; + } + + public static ConnectionInfo InitLiveConnectionInfo(out ScriptFile scriptFile) + { + TestObjects.InitializeTestServices(); + + string sqlFilePath = GetTestSqlFile(); + scriptFile = WorkspaceService.Instance.Workspace.GetFile(sqlFilePath); + + string ownerUri = scriptFile.ClientFilePath; + var connectionService = TestObjects.GetLiveTestConnectionService(); + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetIntegratedTestConnectionDetails() + }); + + connectionResult.Wait(); + + ConnectionInfo connInfo = null; + connectionService.TryFindConnection(ownerUri, out connInfo); + return connInfo; + } + + public static ConnectionInfo InitLiveConnectionInfoForDefinition() + { + TestObjects.InitializeTestServices(); + + string ownerUri = ScriptUri; + var connectionService = TestObjects.GetLiveTestConnectionService(); + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetIntegratedTestConnectionDetails() + }); + + connectionResult.Wait(); + + ConnectionInfo connInfo = null; + connectionService.TryFindConnection(ownerUri, out connInfo); + return connInfo; } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/project.json b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json index f0779583..977de2ab 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/project.json +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json @@ -8,7 +8,8 @@ "Integration": { "buildOptions": { "define": [ - "LIVE_CONNECTION_TESTS" + "LIVE_CONNECTION_TESTS", + "WINDOWS_ONLY_BUILD" ] } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/ServiceTestDriver.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/ServiceTestDriver.cs index ac7e76c0..be5d2e39 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/ServiceTestDriver.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/ServiceTestDriver.cs @@ -8,9 +8,17 @@ // License: https://github.com/PowerShell/PowerShellEditorServices/blob/develop/LICENSE // +using System; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Driver { @@ -19,21 +27,141 @@ namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Driver /// public class ServiceTestDriver : TestDriverBase { - public ServiceTestDriver(string serviceHostExecutable) + + public const string ServiceCodeCoverageEnvironmentVariable = "SERVICECODECOVERAGE"; + + public const string CodeCoverageToolEnvironmentVariable = "CODECOVERAGETOOL"; + + public const string CodeCoverageOutputEnvironmentVariable = "CODECOVERAGEOUTPUT"; + + public const string ServiceHostEnvironmentVariable = "SQLTOOLSSERVICE_EXE"; + + public bool IsCoverageRun { get; set; } + + private Process[] serviceProcesses; + + private DateTime startTime; + + public ServiceTestDriver() { - var clientChannel = new StdioClientChannel(serviceHostExecutable); + string serviceHostExecutable = Environment.GetEnvironmentVariable(ServiceHostEnvironmentVariable); + string serviceHostArguments = "--enable-logging"; + if (string.IsNullOrWhiteSpace(serviceHostExecutable)) + { + // Include a fallback value to for running tests within visual studio + serviceHostExecutable = + @"..\..\src\Microsoft.SqlTools.ServiceLayer\bin\Debug\netcoreapp1.0\win7-x64\Microsoft.SqlTools.ServiceLayer.exe"; + } + + // Make sure it exists before continuing + if (!File.Exists(serviceHostExecutable)) + { + throw new FileNotFoundException($"Failed to find Microsoft.SqlTools.ServiceLayer.exe at provided location '{serviceHostExecutable}'. " + + "Please set SQLTOOLSERVICE_EXE environment variable to location of exe"); + } + + //setup the service host for code coverage if the envvar is enabled + if (Environment.GetEnvironmentVariable(ServiceCodeCoverageEnvironmentVariable) == "True") + { + string coverageToolPath = Environment.GetEnvironmentVariable(CodeCoverageToolEnvironmentVariable); + if (!string.IsNullOrWhiteSpace(coverageToolPath)) + { + string serviceHostDirectory = Path.GetDirectoryName(serviceHostExecutable); + if (string.IsNullOrWhiteSpace(serviceHostDirectory)) + { + serviceHostDirectory = "."; + } + + string coverageOutput = Environment.GetEnvironmentVariable(CodeCoverageOutputEnvironmentVariable); + if (string.IsNullOrWhiteSpace(coverageOutput)) + { + coverageOutput = "coverage.xml"; + } + + serviceHostArguments = $"-mergeoutput -target:{serviceHostExecutable} -targetargs:{serviceHostArguments} " + + $"-register:user -oldstyle -filter:\"+[Microsoft.SqlTools.*]* -[xunit*]*\" -output:{coverageOutput} " + + $"-searchdirs:{serviceHostDirectory};"; + serviceHostExecutable = coverageToolPath; + + this.IsCoverageRun = true; + } + } + + this.clientChannel = new StdioClientChannel(serviceHostExecutable, serviceHostArguments); this.protocolClient = new ProtocolEndpoint(clientChannel, MessageProtocolType.LanguageServer); } + /// + /// Start the test driver, and launch the sqltoolsservice executable + /// public async Task Start() { + // Store the time we started + startTime = DateTime.Now; + + // Launch the process await this.protocolClient.Start(); await Task.Delay(1000); // Wait for the service host to start + + // If this is a code coverage run, we need access to the service layer separate from open cover + if (IsCoverageRun) + { + CancellationTokenSource cancelSource = new CancellationTokenSource(); + Task getServiceProcess = GetServiceProcess(cancelSource.Token); + Task timeoutTask = Task.Delay(TimeSpan.FromSeconds(15), cancelSource.Token); + if (await Task.WhenAny(getServiceProcess, timeoutTask) == timeoutTask) + { + cancelSource.Cancel(); + throw new Exception("Failed to capture service process"); + } + } + + Console.WriteLine("Successfully launched service"); + + // Setup events to queue for testing + this.QueueEventsForType(ConnectionCompleteNotification.Type); + this.QueueEventsForType(IntelliSenseReadyNotification.Type); + this.QueueEventsForType(QueryExecuteCompleteEvent.Type); + this.QueueEventsForType(PublishDiagnosticsNotification.Type); } + /// + /// Stop the test driver, and shutdown the sqltoolsservice executable + /// public async Task Stop() { - await this.protocolClient.Stop(); + if (IsCoverageRun) + { + // Kill all the processes in the list + foreach (Process p in serviceProcesses.Where(p => !p.HasExited)) + { + p.Kill(); + } + ServiceProcess?.WaitForExit(); + } + else + { + await this.protocolClient.Stop(); + } + } + + private async Task GetServiceProcess(CancellationToken token) + { + while (serviceProcesses == null && !token.IsCancellationRequested) + { + var processes = Process.GetProcessesByName("Microsoft.SqlTools.ServiceLayer") + .Where(p => p.StartTime >= startTime).ToArray(); + + // Wait a second if we can't find the process + if (processes.Any()) + { + serviceProcesses = processes; + } + else + { + await Task.Delay(TimeSpan.FromSeconds(1), token); + } + } } } } \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/TestDriverBase.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/TestDriverBase.cs index 7f86764a..dcc352aa 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/TestDriverBase.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/TestDriverBase.cs @@ -10,8 +10,10 @@ using System; using System.Collections.Concurrent; +using System.Diagnostics; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Microsoft.SqlTools.ServiceLayer.Utility; @@ -24,12 +26,29 @@ namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Driver { protected ProtocolEndpoint protocolClient; + protected StdioClientChannel clientChannel; + private ConcurrentDictionary> eventQueuePerType = new ConcurrentDictionary>(); private ConcurrentDictionary> requestQueuePerType = new ConcurrentDictionary>(); + public Process ServiceProcess + { + get + { + try + { + return Process.GetProcessById(clientChannel.ProcessId); + } + catch + { + return null; + } + } + } + public Task SendRequest( RequestType requestType, TParams requestParams) diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.xproj b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.xproj new file mode 100644 index 00000000..6fc6a045 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.xproj @@ -0,0 +1,22 @@ + + + + 14.0 + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) + + + + CC785604-6277-4878-8DA9-360C47158E96 + Microsoft.SqlTools.ServiceLayer.TestDriver + .\obj + .\bin\ + v4.5.2 + + + 2.0 + + + + + + \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Program.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Program.cs index 55b0a552..823b2efe 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Program.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Program.cs @@ -8,59 +8,89 @@ using System.Linq; using System.Reflection; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.TestDriver.Driver; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Xunit; + +[assembly: CollectionBehavior(DisableTestParallelization = true)] namespace Microsoft.SqlTools.ServiceLayer.TestDriver { internal class Program { - internal static void Main(string[] args) + internal static int Main(string[] args) { if (args.Length < 1) { - Console.WriteLine( "Microsoft.SqlTools.ServiceLayer.TestDriver.exe [tests]" + Environment.NewLine + - " is the path to the Microsoft.SqlTools.ServiceLayer.exe executable" + Environment.NewLine + + Console.WriteLine( "Microsoft.SqlTools.ServiceLayer.TestDriver.exe [tests]" + Environment.NewLine + " [tests] is a space-separated list of tests to run." + Environment.NewLine + - " They are qualified within the Microsoft.SqlTools.ServiceLayer.TestDriver.Tests namespace"); - Environment.Exit(0); + " They are qualified within the Microsoft.SqlTools.ServiceLayer.TestDriver.Tests namespace" + Environment.NewLine + + "Be sure to set the environment variable " + ServiceTestDriver.ServiceHostEnvironmentVariable + " to the full path of the sqltoolsservice executable."); + return 0; } + Logger.Initialize("testdriver", LogLevel.Verbose); + + int returnCode = 0; Task.Run(async () => { - var serviceHostExecutable = args[0]; - var tests = args.Skip(1); - - foreach (var test in tests) + string testNamespace = "Microsoft.SqlTools.ServiceLayer.TestDriver.Tests."; + foreach (var test in args) { - ServiceTestDriver driver = null; - try { - driver = new ServiceTestDriver(serviceHostExecutable); - - var className = test.Substring(0, test.LastIndexOf('.')); - var methodName = test.Substring(test.LastIndexOf('.') + 1); + var testName = test.Contains(testNamespace) ? test.Replace(testNamespace, "") : test; + bool containsTestName = testName.Contains("."); + var className = containsTestName ? testName.Substring(0, testName.LastIndexOf('.')) : testName; + var methodName = containsTestName ? testName.Substring(testName.LastIndexOf('.') + 1) : null; - var type = Type.GetType("Microsoft.SqlTools.ServiceLayer.TestDriver.Tests." + className); - var typeInstance = Activator.CreateInstance(type); - MethodInfo methodInfo = type.GetMethod(methodName); - - await driver.Start(); - Console.WriteLine("Running test " + test); - await (Task)methodInfo.Invoke(typeInstance, new object[] {driver}); + var type = Type.GetType(testNamespace + className); + if (type == null) + { + Console.WriteLine("Invalid class name"); + } + else + { + if (string.IsNullOrEmpty(methodName)) + { + var methods = type.GetMethods().Where(x => x.CustomAttributes.Any(a => a.AttributeType == typeof(FactAttribute))); + foreach (var method in methods) + { + await RunTest(type, method, method.Name); + } + } + else + { + MethodInfo methodInfo = type.GetMethod(methodName); + await RunTest(type, methodInfo, test); + } + } } catch (Exception ex) { Console.WriteLine(ex.ToString()); - } - finally - { - if (driver != null) - { - await driver.Stop(); - } + returnCode = -1; } } }).Wait(); + + return returnCode; + } + + private static async Task RunTest(Type type, MethodInfo methodInfo, string testName) + { + if (methodInfo == null) + { + Console.WriteLine("Invalid method name"); + } + else + { + using (var typeInstance = (IDisposable)Activator.CreateInstance(type)) + { + Console.WriteLine("Running test " + testName); + await (Task)methodInfo.Invoke(typeInstance, null); + Console.WriteLine("Test ran successfully: " + testName); + } + } } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/CreateTestDatabase.sql b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/CreateTestDatabase.sql new file mode 100644 index 00000000..58e4cd19 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/CreateTestDatabase.sql @@ -0,0 +1,13 @@ +USE master +GO + +DECLARE @dbname nvarchar(128) +SET @dbname = N'#DatabaseName#' + +IF NOT(EXISTS (SELECT name +FROM master.dbo.sysdatabases +WHERE ('[' + name + ']' = @dbname +OR name = @dbname))) +BEGIN + CREATE DATABASE #DatabaseName# +END \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/CreateTestDatabaseObjects.sql b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/CreateTestDatabaseObjects.sql new file mode 100644 index 00000000..f815c4f7 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/CreateTestDatabaseObjects.sql @@ -0,0 +1,231 @@ + +IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'Demo') +BEGIN + EXEC('CREATE SCHEMA [Demo]') +END + +GO +IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'HumanResources') +BEGIN + EXEC('CREATE SCHEMA [HumanResources]') +END +GO +IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'Person') +BEGIN + EXEC('CREATE SCHEMA [Person]') +END +GO +IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'Production') +BEGIN + EXEC('CREATE SCHEMA [Production]') +END +GO +IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'Purchasing') +BEGIN + EXEC('CREATE SCHEMA [Purchasing]') +END +GO +IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'Sales') +BEGIN + EXEC('CREATE SCHEMA [Sales]') +END +GO +IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'Security') +BEGIN + EXEC('CREATE SCHEMA [Security]') +END +GO +IF NOT EXISTS (SELECT * FROM sys.types WHERE name = 'AccountNumber') +BEGIN + CREATE TYPE [dbo].[AccountNumber] FROM [nvarchar](15) NULL +END +GO +IF NOT EXISTS (SELECT * FROM sys.types WHERE name = 'Flag') +BEGIN + CREATE TYPE [dbo].[Flag] FROM [bit] NOT NULL +END +GO + +IF NOT EXISTS (SELECT * FROM sys.types WHERE name = 'Name') +BEGIN + CREATE TYPE [dbo].[Name] FROM [nvarchar](50) NULL +END +GO +IF NOT EXISTS (SELECT * FROM sys.types WHERE name = 'NameStyle') +BEGIN + CREATE TYPE [dbo].[NameStyle] FROM [bit] NOT NULL +END +GO +IF NOT EXISTS (SELECT * FROM sys.types WHERE name = 'OrderNumber') +BEGIN + CREATE TYPE [dbo].[OrderNumber] FROM [nvarchar](25) NULL +END +GO + +GO +IF NOT EXISTS (SELECT * FROM sys.types WHERE name = 'Phone') +BEGIN + CREATE TYPE [dbo].[Phone] FROM [nvarchar](25) NULL +END +GO + +SET ANSI_NULLS ON +GO +SET QUOTED_IDENTIFIER ON +GO +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'Address') +BEGIN +CREATE TABLE [Person].[Address]( + [AddressID] [int] IDENTITY(1,1) NOT FOR REPLICATION NOT NULL, + [AddressLine1] [nvarchar](60) NOT NULL, + [AddressLine2] [nvarchar](60) NULL, + [City] [nvarchar](30) NOT NULL, + [StateProvinceID] [int] NOT NULL, + [PostalCode] [nvarchar](15) NOT NULL, + [SpatialLocation] [geography] NULL, + [rowguid] [uniqueidentifier] ROWGUIDCOL NOT NULL, + [ModifiedDate] [datetime] NOT NULL, + CONSTRAINT [PK_Address_AddressID] PRIMARY KEY CLUSTERED +( + [AddressID] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] +END +GO + +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'AddressType') +BEGIN +CREATE TABLE [Person].[AddressType]( + [AddressTypeID] [int] IDENTITY(1,1) NOT NULL, + [Name] [dbo].[Name] NOT NULL, + [rowguid] [uniqueidentifier] ROWGUIDCOL NOT NULL, + [ModifiedDate] [datetime] NOT NULL, + CONSTRAINT [PK_AddressType_AddressTypeID] PRIMARY KEY CLUSTERED +( + [AddressTypeID] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY] +END +GO +/****** Object: Table [Person].[ContactType] Script Date: 11/22/2016 9:25:52 AM ******/ +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'ContactType') +BEGIN +CREATE TABLE [Person].[ContactType]( + [ContactTypeID] [int] IDENTITY(1,1) NOT NULL, + [Name] [dbo].[Name] NOT NULL, + [ModifiedDate] [datetime] NOT NULL, + CONSTRAINT [PK_ContactType_ContactTypeID] PRIMARY KEY CLUSTERED +( + [ContactTypeID] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY] +END +GO + +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'EmailAddress') +BEGIN +CREATE TABLE [Person].[EmailAddress]( + [BusinessEntityID] [int] NOT NULL, + [EmailAddressID] [int] IDENTITY(1,1) NOT NULL, + [EmailAddress] [nvarchar](50) NULL, + [rowguid] [uniqueidentifier] ROWGUIDCOL NOT NULL, + [ModifiedDate] [datetime] NOT NULL, + CONSTRAINT [PK_EmailAddress_BusinessEntityID_EmailAddressID] PRIMARY KEY CLUSTERED +( + [BusinessEntityID] ASC, + [EmailAddressID] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY] +END +GO + +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'Person') +BEGIN +CREATE TABLE [Person].[Person]( + [BusinessEntityID] [int] NOT NULL, + [PersonType] [nchar](2) NOT NULL, + [NameStyle] [dbo].[NameStyle] NOT NULL, + [Title] [nvarchar](8) NULL, + [FirstName] [dbo].[Name] NOT NULL, + [MiddleName] [dbo].[Name] NULL, + [LastName] [dbo].[Name] NOT NULL, + [Suffix] [nvarchar](10) NULL, + [EmailPromotion] [int] NOT NULL, + [rowguid] [uniqueidentifier] ROWGUIDCOL NOT NULL, + [ModifiedDate] [datetime] NOT NULL, + CONSTRAINT [PK_Person_BusinessEntityID] PRIMARY KEY CLUSTERED +( + [BusinessEntityID] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY] +END +GO + +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'PersonPhone') +BEGIN +CREATE TABLE [Person].[PersonPhone]( + [BusinessEntityID] [int] NOT NULL, + [PhoneNumber] [dbo].[Phone] NOT NULL, + [PhoneNumberTypeID] [int] NOT NULL, + [ModifiedDate] [datetime] NOT NULL, + CONSTRAINT [PK_PersonPhone_BusinessEntityID_PhoneNumber_PhoneNumberTypeID] PRIMARY KEY CLUSTERED +( + [BusinessEntityID] ASC, + [PhoneNumber] ASC, + [PhoneNumberTypeID] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY] +END +GO + +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'Location') +BEGIN +CREATE TABLE [Production].[Location]( + [LocationID] [smallint] IDENTITY(1,1) NOT NULL, + [Name] [dbo].[Name] NOT NULL, + [CostRate] [smallmoney] NOT NULL, + [Availability] [decimal](8, 2) NOT NULL, + [ModifiedDate] [datetime] NOT NULL, + CONSTRAINT [PK_Location_LocationID] PRIMARY KEY CLUSTERED +( + [LocationID] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY] +END +GO + +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = 'Product') +BEGIN +CREATE TABLE [Production].[Product]( + [ProductID] [int] IDENTITY(1,1) NOT NULL, + [Name] [dbo].[Name] NOT NULL, + [ProductNumber] [nvarchar](25) NOT NULL, + [MakeFlag] [dbo].[Flag] NOT NULL, + [FinishedGoodsFlag] [dbo].[Flag] NOT NULL, + [Color] [nvarchar](15) NULL, + [SafetyStockLevel] [smallint] NOT NULL, + [ReorderPoint] [smallint] NOT NULL, + [StandardCost] [money] NOT NULL, + [ListPrice] [money] NOT NULL, + [Size] [nvarchar](5) NULL, + [SizeUnitMeasureCode] [nchar](3) NULL, + [WeightUnitMeasureCode] [nchar](3) NULL, + [Weight] [decimal](8, 2) NULL, + [DaysToManufacture] [int] NOT NULL, + [ProductLine] [nchar](2) NULL, + [Class] [nchar](2) NULL, + [Style] [nchar](2) NULL, + [ProductSubcategoryID] [int] NULL, + [ProductModelID] [int] NULL, + [SellStartDate] [datetime] NOT NULL, + [SellEndDate] [datetime] NULL, + [DiscontinuedDate] [datetime] NULL, + [rowguid] [uniqueidentifier] ROWGUIDCOL NOT NULL, + [ModifiedDate] [datetime] NOT NULL, + CONSTRAINT [PK_Product_ProductID] PRIMARY KEY CLUSTERED +( + [ProductID] ASC +)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] +) ON [PRIMARY] +END +GO diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/Scripts.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/Scripts.cs new file mode 100644 index 00000000..02aa695e --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/Scripts.cs @@ -0,0 +1,65 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Reflection; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Scripts +{ + public class Scripts + { + public const string MasterBasicQuery = "SELECT * FROM sys.all_columns"; //basic queries should return at least 10000 rows + + public const string DelayQuery = "WAITFOR DELAY '00:01:00'"; + + public const string TestDbSimpleSelectQuery = "SELECT * FROM [Person].[Address]"; + + public const string SelectQuery = "SELECT * FROM "; + + public static string CreateDatabaseObjectsQuery { get { return CreateDatabaseObjectsQueryInstance.Value; } } + + public static string CreateDatabaseQuery { get { return CreateDatabaseQueryInstance.Value; } } + + public static string TestDbComplexSelectQueries { get { return TestDbSelectQueriesInstance.Value; } } + + private static readonly Lazy CreateDatabaseObjectsQueryInstance = new Lazy(() => + { + return GetScriptFileContent("Microsoft.SqlTools.ServiceLayer.TestDriver.Scripts.CreateTestDatabaseObjects.sql"); + }); + + private static readonly Lazy CreateDatabaseQueryInstance = new Lazy(() => + { + return GetScriptFileContent("Microsoft.SqlTools.ServiceLayer.TestDriver.Scripts.CreateTestDatabase.sql"); + }); + + private static readonly Lazy TestDbSelectQueriesInstance = new Lazy(() => + { + return GetScriptFileContent("Microsoft.SqlTools.ServiceLayer.TestDriver.Scripts.TestDbTableQueries.sql"); + }); + + private static string GetScriptFileContent(string fileName) + { + string fileContent = string.Empty; + try + { + using (Stream stream = typeof(Scripts).GetTypeInfo().Assembly.GetManifestResourceStream(fileName)) + { + using (StreamReader reader = new StreamReader(stream)) + { + fileContent = reader.ReadToEnd(); + } + } + } + catch (Exception ex) + { + Console.WriteLine($"Failed to load the sql script. error: {ex.Message}"); + } + return fileContent; + } + + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/TestDbTableQueries.sql b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/TestDbTableQueries.sql new file mode 100644 index 00000000..2ce30a44 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Scripts/TestDbTableQueries.sql @@ -0,0 +1,84 @@ +SELECT * FROM [Person].[Address] + +SELECT [AddressID] + ,[AddressLine1] + ,[AddressLine2] + ,[City] + ,[StateProvinceID] + ,[PostalCode] + ,[SpatialLocation] + ,[rowguid] + ,[ModifiedDate] + FROM [Person].[Address] +GO +SELECT [AddressTypeID] + ,[Name] + ,[rowguid] + ,[ModifiedDate] + FROM [Person].[AddressType] +GO +SELECT [ContactTypeID] + ,[Name] + ,[ModifiedDate] + FROM [Person].[ContactType] +GO +SELECT [BusinessEntityID] + ,[EmailAddressID] + ,[EmailAddress] + ,[rowguid] + ,[ModifiedDate] + FROM [Person].[EmailAddress] +GO +SELECT [BusinessEntityID] + ,[PersonType] + ,[NameStyle] + ,[Title] + ,[FirstName] + ,[MiddleName] + ,[LastName] + ,[Suffix] + ,[EmailPromotion] + ,[rowguid] + ,[ModifiedDate] + FROM [Person].[Person] +GO +SELECT [BusinessEntityID] + ,[PhoneNumber] + ,[PhoneNumberTypeID] + ,[ModifiedDate] + FROM [Person].[PersonPhone] +GO +SELECT [LocationID] + ,[Name] + ,[CostRate] + ,[Availability] + ,[ModifiedDate] + FROM [Production].[Location] +GO +SELECT [ProductID] + ,[Name] + ,[ProductNumber] + ,[MakeFlag] + ,[FinishedGoodsFlag] + ,[Color] + ,[SafetyStockLevel] + ,[ReorderPoint] + ,[StandardCost] + ,[ListPrice] + ,[Size] + ,[SizeUnitMeasureCode] + ,[WeightUnitMeasureCode] + ,[Weight] + ,[DaysToManufacture] + ,[ProductLine] + ,[Class] + ,[Style] + ,[ProductSubcategoryID] + ,[ProductModelID] + ,[SellStartDate] + ,[SellEndDate] + ,[DiscontinuedDate] + ,[rowguid] + ,[ModifiedDate] + FROM [Production].[Product] +GO \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/ConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/ConnectionTests.cs new file mode 100644 index 00000000..8aefb01b --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/ConnectionTests.cs @@ -0,0 +1,59 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests +{ + /// + /// Language Service end-to-end integration tests + /// + public class ConnectionTest + { + /// + /// Try to connect with invalid credentials + /// + [Fact] + public async Task InvalidConnection() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.InvalidConnection, 300000); + Assert.False(connected, "Invalid connection is failed to connect"); + + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.InvalidConnection, 300000); + + Thread.Sleep(1000); + + await testHelper.CancelConnect(queryTempFile.FilePath); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + /// + /// Validate list databases request + /// + [Fact] + public async Task ListDatabasesTest() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection successful"); + + var listDatabaseResult = await testHelper.ListDatabases(queryTempFile.FilePath); + Assert.True(listDatabaseResult.DatabaseNames.Length > 0); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/ExampleTests.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/ExampleTests.cs deleted file mode 100644 index 21fe7552..00000000 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/ExampleTests.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System; -using System.Threading.Tasks; -using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; -using Microsoft.SqlTools.ServiceLayer.TestDriver.Driver; - -namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests -{ - public class ExampleTests - { - /// - /// Example test that performs a connect, then disconnect. - /// All tests must have the same signature of returning an async Task - /// and taking in a ServiceTestDriver as a parameter. - /// - public async Task ConnectDisconnectTest(ServiceTestDriver driver) - { - var connectParams = new ConnectParams(); - connectParams.OwnerUri = "file"; - connectParams.Connection = new ConnectionDetails(); - connectParams.Connection.ServerName = "localhost"; - connectParams.Connection.AuthenticationType = "Integrated"; - - var result = await driver.SendRequest(ConnectionRequest.Type, connectParams); - if (result) - { - await driver.WaitForEvent(ConnectionCompleteNotification.Type); - - var disconnectParams = new DisconnectParams(); - disconnectParams.OwnerUri = "file"; - var result2 = await driver.SendRequest(DisconnectRequest.Type, disconnectParams); - if (result2) - { - Console.WriteLine("success"); - } - } - } - } -} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/LanguageServiceTests.cs new file mode 100644 index 00000000..dccee6c8 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/LanguageServiceTests.cs @@ -0,0 +1,471 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests +{ + /// + /// Language Service end-to-end integration tests + /// + public class LanguageServiceTests + { + + /// + /// Validate hover tooltip scenarios + /// + [Fact] + public async Task HoverTest() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + string query = "SELECT * FROM sys.objects"; + + testHelper.WriteToFile(queryTempFile.FilePath, query); + + DidOpenTextDocumentNotification openParams = new DidOpenTextDocumentNotification + { + TextDocument = new TextDocumentItem + { + Uri = queryTempFile.FilePath, + LanguageId = "enu", + Version = 1, + Text = query + } + }; + + await testHelper.RequestOpenDocumentNotification(openParams); + + Thread.Sleep(500); + + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection was not successful"); + + Thread.Sleep(10000); + + Hover hover = await testHelper.RequestHover(queryTempFile.FilePath, query, 0, 15); + + Assert.True(hover != null, "Hover tooltop is null"); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + /// + /// Validation autocompletion suggestions scenarios + /// + [Fact] + public async Task CompletionTest() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + string query = "SELECT * FROM sys.objects"; + + testHelper.WriteToFile(queryTempFile.FilePath, query); + + DidOpenTextDocumentNotification openParams = new DidOpenTextDocumentNotification + { + TextDocument = new TextDocumentItem + { + Uri = queryTempFile.FilePath, + LanguageId = "enu", + Version = 1, + Text = query + } + }; + + await testHelper.RequestOpenDocumentNotification(openParams); + + Thread.Sleep(500); + + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection is successful"); + + Thread.Sleep(10000); + + CompletionItem[] completions = await testHelper.RequestCompletion(queryTempFile.FilePath, query, 0, 15); + + Assert.True(completions != null && completions.Length > 0, "Completion items list is null or empty"); + + Thread.Sleep(50); + + await testHelper.RequestResolveCompletion(completions[0]); + + Assert.True(completions != null && completions.Length > 0, "Completion items list is null or empty"); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + /// + /// Validate diagnostic scenarios + /// + [Fact] + public async Task DiagnosticsTests() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection was not successful"); + + Thread.Sleep(500); + + string query = "SELECT *** FROM sys.objects"; + + DidOpenTextDocumentNotification openParams = new DidOpenTextDocumentNotification + { + TextDocument = new TextDocumentItem + { + Uri = queryTempFile.FilePath, + LanguageId = "enu", + Version = 1, + Text = query + } + }; + + await testHelper.RequestOpenDocumentNotification(openParams); + + Thread.Sleep(100); + + var contentChanges = new TextDocumentChangeEvent[1]; + contentChanges[0] = new TextDocumentChangeEvent + { + Range = new Range + { + Start = new Position + { + Line = 0, + Character = 5 + }, + End = new Position + { + Line = 0, + Character = 6 + } + }, + RangeLength = 1, + Text = "z" + }; + + DidChangeTextDocumentParams changeParams = new DidChangeTextDocumentParams() + { + ContentChanges = contentChanges, + TextDocument = new VersionedTextDocumentIdentifier() + { + Version = 2, + Uri = queryTempFile.FilePath + } + }; + + await testHelper.RequestChangeTextDocumentNotification(changeParams); + + Thread.Sleep(100); + + contentChanges[0] = new TextDocumentChangeEvent + { + Range = new Range + { + Start = new Position + { + Line = 0, + Character = 5 + }, + End = new Position + { + Line = 0, + Character = 6 + } + }, + RangeLength = 1, + Text = "t" + }; + + changeParams = new DidChangeTextDocumentParams + { + ContentChanges = contentChanges, + TextDocument = new VersionedTextDocumentIdentifier + { + Version = 3, + Uri = queryTempFile.FilePath + } + }; + + await testHelper.RequestChangeTextDocumentNotification(changeParams); + + Thread.Sleep(2500); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + /// + /// Peek Definition/ Go to definition + /// + /// + [Fact] + public async Task DefinitionTest() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + string query = "SELECT * FROM sys.objects"; + int lineNumber = 0; + int position = 23; + + testHelper.WriteToFile(queryTempFile.FilePath, query); + + DidOpenTextDocumentNotification openParams = new DidOpenTextDocumentNotification + { + TextDocument = new TextDocumentItem + { + Uri = queryTempFile.FilePath, + LanguageId = "enu", + Version = 1, + Text = query + } + }; + + await testHelper.RequestOpenDocumentNotification(openParams); + + Thread.Sleep(500); + + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection is successful"); + + Thread.Sleep(10000); + // Request definition for "objects" + Location[] locations = await testHelper.RequestDefinition(queryTempFile.FilePath, query, lineNumber, position); + + Assert.True(locations != null, "Location is not null and not empty"); + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + /// + /// Validate the configuration change event + /// + [Fact] + public async Task ChangeConfigurationTest() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection was not successful"); + + Thread.Sleep(500); + + var settings = new SqlToolsSettings(); + settings.SqlTools.IntelliSense.EnableIntellisense = false; + DidChangeConfigurationParams configParams = new DidChangeConfigurationParams() + { + Settings = settings + }; + + await testHelper.RequestChangeConfigurationNotification(configParams); + + Thread.Sleep(2000); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task NotificationIsSentAfterOnConnectionAutoCompleteUpdate() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + // Connect + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + + // An event signalling that IntelliSense is ready should be sent shortly thereafter + var readyParams = await testHelper.Driver.WaitForEvent(IntelliSenseReadyNotification.Type, 30000); + Assert.NotNull(readyParams); + Assert.Equal(queryTempFile.FilePath, readyParams.OwnerUri); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task FunctionSignatureCompletionReturnsEmptySignatureHelpObjectWhenThereAreNoMatches() + { + string sqlText = "EXEC sys.fn_not_a_real_function "; + + using (SelfCleaningTempFile tempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + string ownerUri = tempFile.FilePath; + File.WriteAllText(ownerUri, sqlText); + + // Connect + await testHelper.Connect(ownerUri, ConnectionTestUtils.LocalhostConnection); + + // Wait for intellisense to be ready + var readyParams = await testHelper.Driver.WaitForEvent(IntelliSenseReadyNotification.Type, 30000); + Assert.NotNull(readyParams); + Assert.Equal(ownerUri, readyParams.OwnerUri); + + // Send a function signature help Request + var position = new TextDocumentPosition() + { + TextDocument = new TextDocumentIdentifier() + { + Uri = ownerUri + }, + Position = new Position() + { + Line = 0, + Character = sqlText.Length + } + }; + var signatureHelp = await testHelper.Driver.SendRequest(SignatureHelpRequest.Type, position); + + Assert.NotNull(signatureHelp); + Assert.False(signatureHelp.ActiveSignature.HasValue); + Assert.Null(signatureHelp.Signatures); + + await testHelper.Disconnect(ownerUri); + } + } + + [Fact] + public async Task FunctionSignatureCompletionReturnsCorrectFunction() + { + string sqlText = "EXEC sys.fn_isrolemember "; + + using (SelfCleaningTempFile tempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + string ownerUri = tempFile.FilePath; + File.WriteAllText(ownerUri, sqlText); + + // Connect + await testHelper.Connect(ownerUri, ConnectionTestUtils.LocalhostConnection); + + // Wait for intellisense to be ready + var readyParams = await testHelper.Driver.WaitForEvent(IntelliSenseReadyNotification.Type, 30000); + Assert.NotNull(readyParams); + Assert.Equal(ownerUri, readyParams.OwnerUri); + + // Send a function signature help Request + var position = new TextDocumentPosition() + { + TextDocument = new TextDocumentIdentifier() + { + Uri = ownerUri + }, + Position = new Position() + { + Line = 0, + Character = sqlText.Length + } + }; + var signatureHelp = await testHelper.Driver.SendRequest(SignatureHelpRequest.Type, position); + + Assert.NotNull(signatureHelp); + Assert.True(signatureHelp.ActiveSignature.HasValue); + Assert.NotEmpty(signatureHelp.Signatures); + + var label = signatureHelp.Signatures[signatureHelp.ActiveSignature.Value].Label; + Assert.NotNull(label); + Assert.NotEmpty(label); + Assert.True(label.Contains("fn_isrolemember")); + + await testHelper.Disconnect(ownerUri); + } + } + + [Fact] + public async Task FunctionSignatureCompletionReturnsCorrectParametersAtEachPosition() + { + string sqlText = "EXEC sys.fn_isrolemember 1, 'testing', 2"; + + using (SelfCleaningTempFile tempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + string ownerUri = tempFile.FilePath; + File.WriteAllText(ownerUri, sqlText); + + // Connect + await testHelper.Connect(ownerUri, ConnectionTestUtils.LocalhostConnection); + + // Wait for intellisense to be ready + var readyParams = await testHelper.Driver.WaitForEvent(IntelliSenseReadyNotification.Type, 30000); + Assert.NotNull(readyParams); + Assert.Equal(ownerUri, readyParams.OwnerUri); + + // Verify all parameters when the cursor is inside of parameters and at separator boundaries (,) + await VerifyFunctionSignatureHelpParameter(testHelper, ownerUri, 25, "fn_isrolemember", 0, "@mode int"); + await VerifyFunctionSignatureHelpParameter(testHelper, ownerUri, 26, "fn_isrolemember", 0, "@mode int"); + await VerifyFunctionSignatureHelpParameter(testHelper, ownerUri, 27, "fn_isrolemember", 1, "@login sysname"); + await VerifyFunctionSignatureHelpParameter(testHelper, ownerUri, 30, "fn_isrolemember", 1, "@login sysname"); + await VerifyFunctionSignatureHelpParameter(testHelper, ownerUri, 37, "fn_isrolemember", 1, "@login sysname"); + await VerifyFunctionSignatureHelpParameter(testHelper, ownerUri, 38, "fn_isrolemember", 2, "@tranpubid int"); + await VerifyFunctionSignatureHelpParameter(testHelper, ownerUri, 39, "fn_isrolemember", 2, "@tranpubid int"); + + await testHelper.Disconnect(ownerUri); + } + } + + public async Task VerifyFunctionSignatureHelpParameter( + TestHelper testHelper, + string ownerUri, + int character, + string expectedFunctionName, + int expectedParameterIndex, + string expectedParameterName) + { + var position = new TextDocumentPosition() + { + TextDocument = new TextDocumentIdentifier() + { + Uri = ownerUri + }, + Position = new Position() + { + Line = 0, + Character = character + } + }; + var signatureHelp = await testHelper.Driver.SendRequest(SignatureHelpRequest.Type, position); + + Assert.NotNull(signatureHelp); + Assert.NotNull(signatureHelp.ActiveSignature); + Assert.True(signatureHelp.ActiveSignature.HasValue); + Assert.NotEmpty(signatureHelp.Signatures); + + var activeSignature = signatureHelp.Signatures[signatureHelp.ActiveSignature.Value]; + Assert.NotNull(activeSignature); + + var label = activeSignature.Label; + Assert.NotNull(label); + Assert.NotEmpty(label); + Assert.True(label.Contains(expectedFunctionName)); + + Assert.NotNull(signatureHelp.ActiveParameter); + Assert.True(signatureHelp.ActiveParameter.HasValue); + Assert.Equal(expectedParameterIndex, signatureHelp.ActiveParameter.Value); + + var parameter = activeSignature.Parameters[signatureHelp.ActiveParameter.Value]; + Assert.NotNull(parameter); + Assert.NotNull(parameter.Label); + Assert.NotEmpty(parameter.Label); + Assert.Equal(expectedParameterName, parameter.Label); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/QueryExecutionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/QueryExecutionTests.cs new file mode 100644 index 00000000..75df010c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/QueryExecutionTests.cs @@ -0,0 +1,361 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests +{ + public class QueryExecutionTests + { + /* Commenting out these tests until they are fixed (12/1/16) + [Fact] + public async Task TestQueryCancelReliability() + { + const string query = "SELECT * FROM sys.objects a CROSS JOIN sys.objects b CROSS JOIN sys.objects c"; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.AzureTestServerConnection); + + // Run and cancel 100 queries + for (int i = 0; i < 100; i++) + { + var queryTask = testHelper.RunQuery(queryTempFile.FilePath, query); + + var cancelResult = await testHelper.CancelQuery(queryTempFile.FilePath); + Assert.NotNull(cancelResult); + Assert.True(string.IsNullOrEmpty(cancelResult.Messages)); + + await queryTask; + } + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task TestQueryDoesNotBlockOtherRequests() + { + const string query = "SELECT * FROM sys.objects a CROSS JOIN sys.objects b CROSS JOIN sys.objects c"; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.AzureTestServerConnection); + + // Start a long-running query + var queryTask = testHelper.RunQuery(queryTempFile.FilePath, query, 60000); + + // Interact with the service. None of these requests should time out while waiting for the query to finish + for (int i = 0; i < 10; i++) + { + using (SelfCleaningTempFile queryFile2 = new SelfCleaningTempFile()) + { + await testHelper.Connect(queryFile2.FilePath, ConnectionTestUtils.AzureTestServerConnection); + Assert.NotNull(await testHelper.RequestCompletion(queryFile2.FilePath, "SELECT * FROM sys.objects", 0, 10)); + await testHelper.Disconnect(queryFile2.FilePath); + } + } + + await testHelper.CancelQuery(queryTempFile.FilePath); + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task TestParallelQueryExecution() + { + const int queryCount = 10; + const string query = "SELECT * FROM sys.objects"; + + using (TestHelper testHelper = new TestHelper()) + { + // Create n connections + SelfCleaningTempFile[] ownerUris = new SelfCleaningTempFile[queryCount]; + for (int i = 0; i < queryCount; i++) + { + ownerUris[i] = new SelfCleaningTempFile(); + Assert.NotNull(await testHelper.Connect(ownerUris[i].FilePath, ConnectionTestUtils.AzureTestServerConnection)); + } + + // Run n queries at once + var queryTasks = new Task[queryCount]; + for (int i = 0; i < queryCount; i++) + { + queryTasks[i] = testHelper.RunQuery(ownerUris[i].FilePath, query); + } + await Task.WhenAll(queryTasks); + + // Verify that they all completed with results and Disconnect + for (int i = 0; i < queryCount; i++) + { + Assert.NotNull(queryTasks[i].Result); + Assert.NotNull(queryTasks[i].Result.BatchSummaries); + await testHelper.Disconnect(ownerUris[i].FilePath); + ownerUris[i].Dispose(); + } + } + } + + [Fact] + public async Task TestSaveResultsDoesNotBlockOtherRequests() + { + const string query = "SELECT * FROM sys.objects"; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.AzureTestServerConnection); + + // Execute a query + await testHelper.RunQuery(queryTempFile.FilePath, query); + + // Spawn several tasks to save results + var saveTasks = new Task[100]; + for (int i = 0; i < 100; i++) + { + if (i % 2 == 0) + { + saveTasks[i] = testHelper.SaveAsCsv(queryTempFile.FilePath, System.IO.Path.GetTempFileName(), 0, 0); + } + else + { + saveTasks[i] = testHelper.SaveAsJson(queryTempFile.FilePath, System.IO.Path.GetTempFileName(), 0, 0); + } + } + + // Interact with the service. None of these requests should time out while waiting for the save results tasks to finish + for (int i = 0; i < 10; i++) + { + using(SelfCleaningTempFile queryFile2 = new SelfCleaningTempFile()) + { + await testHelper.Connect(queryFile2.FilePath, ConnectionTestUtils.AzureTestServerConnection); + Assert.NotNull(await testHelper.RequestCompletion(queryFile2.FilePath, "SELECT * FROM sys.objects", 0, 10)); + await testHelper.Disconnect(queryFile2.FilePath); + } + } + + await Task.WhenAll(saveTasks); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task TestQueryingSubsetDoesNotBlockOtherRequests() + { + const string query = "SELECT * FROM sys.objects"; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.AzureTestServerConnection); + + // Execute a query + await testHelper.RunQuery(queryTempFile.FilePath, query); + + // Spawn several tasks for subset requests + var subsetTasks = new Task[100]; + for (int i = 0; i < 100; i++) + { + subsetTasks[i] = testHelper.ExecuteSubset(queryTempFile.FilePath, 0, 0, 0, 100); + } + + // Interact with the service. None of these requests should time out while waiting for the subset tasks to finish + for (int i = 0; i < 10; i++) + { + using (SelfCleaningTempFile queryFile2 = new SelfCleaningTempFile()) + { + await testHelper.Connect(queryFile2.FilePath, ConnectionTestUtils.AzureTestServerConnection); + Assert.NotNull(await testHelper.RequestCompletion(queryFile2.FilePath, "SELECT * FROM sys.objects", 0, 10)); + await testHelper.Disconnect(queryFile2.FilePath); + } + } + + await Task.WhenAll(subsetTasks); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task TestCancelQueryWhileOtherOperationsAreInProgress() + { + const string query = "SELECT * FROM sys.objects a CROSS JOIN sys.objects b"; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + List tasks = new List(); + + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.AzureTestServerConnection); + + // Execute a long-running query + var queryTask = testHelper.RunQuery(queryTempFile.FilePath, query, 60000); + + // Queue up some tasks that interact with the service + for (int i = 0; i < 10; i++) + { + using (SelfCleaningTempFile queryFile2 = new SelfCleaningTempFile()) + { + tasks.Add(Task.Run(async () => + { + await testHelper.Connect(queryFile2.FilePath, ConnectionTestUtils.AzureTestServerConnection); + await testHelper.RequestCompletion(queryFile2.FilePath, "SELECT * FROM sys.objects", 0, 10); + await testHelper.RunQuery(queryFile2.FilePath, "SELECT * FROM sys.objects"); + await testHelper.Disconnect(queryFile2.FilePath); + })); + } + } + + // Cancel the long-running query + await testHelper.CancelQuery(queryTempFile.FilePath); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task ExecuteBasicQueryTest() + { + const string query = "SELECT * FROM sys.all_columns c"; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection is successful"); + + Thread.Sleep(500); + + DidOpenTextDocumentNotification openParams = new DidOpenTextDocumentNotification() + { + TextDocument = new TextDocumentItem() + { + Uri = queryTempFile.FilePath, + LanguageId = "enu", + Version = 1, + Text = query + } + }; + + await testHelper.RequestOpenDocumentNotification(openParams); + + var queryResult = await testHelper.RunQuery(queryTempFile.FilePath, query, 10000); + + Assert.NotNull(queryResult); + Assert.NotNull(queryResult.BatchSummaries); + + foreach (var batchSummary in queryResult.BatchSummaries) + { + foreach (var resultSetSummary in batchSummary.ResultSetSummaries) + { + Assert.True(resultSetSummary.RowCount > 0); + } + } + + var subsetRequest = new QueryExecuteSubsetParams() + { + OwnerUri = queryTempFile.FilePath, + BatchIndex = 0, + ResultSetIndex = 0, + RowsStartIndex = 0, + RowsCount = 100, + }; + + var querySubset = await testHelper.RequestQueryExecuteSubset(subsetRequest); + Assert.NotNull(querySubset); + Assert.True(querySubset.ResultSubset.RowCount == 100); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + [Fact] + public async Task TestQueryingAfterCompletionRequests() + { + const string query = "SELECT * FROM sys.objects"; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + List tasks = new List(); + + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.AzureTestServerConnection); + + Enumerable.Range(0, 10).ToList().ForEach(arg => tasks.Add(testHelper.RequestCompletion(queryTempFile.FilePath, query, 0, 10))); + var queryTask = testHelper.RunQuery(queryTempFile.FilePath, query); + tasks.Add(queryTask); + await Task.WhenAll(tasks); + + Assert.NotNull(queryTask.Result); + Assert.NotNull(queryTask.Result.BatchSummaries); + + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.DataToolsTelemetryAzureConnection); + tasks.Clear(); + Enumerable.Range(0, 10).ToList().ForEach(arg => tasks.Add(testHelper.RequestCompletion(queryTempFile.FilePath, query, 0, 10))); + queryTask = testHelper.RunQuery(queryTempFile.FilePath, query); + tasks.Add(queryTask); + await Task.WhenAll(tasks); + + Assert.NotNull(queryTask.Result); + Assert.NotNull(queryTask.Result.BatchSummaries); + + await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.SqlDataToolsAzureConnection); + tasks.Clear(); + Enumerable.Range(0, 10).ToList().ForEach(arg => tasks.Add(testHelper.RequestCompletion(queryTempFile.FilePath, query, 0, 10))); + queryTask = testHelper.RunQuery(queryTempFile.FilePath, query); + tasks.Add(queryTask); + await Task.WhenAll(tasks); + + Assert.NotNull(queryTask.Result); + Assert.NotNull(queryTask.Result.BatchSummaries); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + */ + + [Fact] + public async Task NoOpQueryReturnsMessage() + { + // Given queries that do nothing (no-ops)... + var queries = new string[] + { + "-- no-op", + "GO", + "GO -- no-op" + }; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + foreach (var query in queries) + { + Assert.True(await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection)); + + // If the queries are executed... + var queryResult = await testHelper.RunQueryAsync(queryTempFile.FilePath, query); + + // Then I expect messages that the commands were completed successfully to be in the result + Assert.NotNull(queryResult); + Assert.NotNull(queryResult.Messages); + Assert.Equal("Commands completed successfully.", queryResult.Messages); + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/StressTests.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/StressTests.cs new file mode 100644 index 00000000..bd39ee4a --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/StressTests.cs @@ -0,0 +1,217 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests +{ + public class StressTests + { + /// + /// Simulate typing by a user to stress test the language service + /// + //[Fact] + public async Task TestLanguageService() + { + const string textToType = "SELECT * FROM sys.objects GO " + + "CREATE TABLE MyTable(" + + "FirstName CHAR," + + "LastName CHAR," + + "DateOfBirth DATETIME," + + "CONSTRAINT MyTableConstraint UNIQUE (FirstName, LastName, DateOfBirth)) GO " + + "INSERT INTO MyTable (FirstName, LastName, DateOfBirth) VALUES ('John', 'Doe', '19800101') GO " + + "SELECT * FROM MyTable GO " + + "ALTER TABLE MyTable DROP CONSTRAINT MyTableConstraint GO " + + "DROP TABLE MyTable GO "; + + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + // Connect + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection was not successful"); + + Thread.Sleep(10000); // Wait for intellisense to warm up + + // Simulate typing + Stopwatch stopwatch = new Stopwatch(); + stopwatch.Start(); + int version = 1; + while (stopwatch.Elapsed < TimeSpan.FromMinutes(60)) + { + for (int i = 0; i < textToType.Length; i++) + { + System.IO.File.WriteAllText(queryTempFile.FilePath, textToType.Substring(0, i + 1)); + + var contentChanges = new TextDocumentChangeEvent[1]; + contentChanges[0] = new TextDocumentChangeEvent() + { + Range = new Range() + { + Start = new Position() + { + Line = 0, + Character = i + }, + End = new Position() + { + Line = 0, + Character = i + } + }, + RangeLength = 1, + Text = textToType.Substring(i, 1) + }; + + DidChangeTextDocumentParams changeParams = new DidChangeTextDocumentParams() + { + ContentChanges = contentChanges, + TextDocument = new VersionedTextDocumentIdentifier() + { + Version = ++version, + Uri = queryTempFile.FilePath + } + }; + + await testHelper.RequestChangeTextDocumentNotification(changeParams); + + Thread.Sleep(50); + + // If we just typed a space, request/resolve completion + if (textToType[i] == ' ') + { + var completions = await testHelper.RequestCompletion(queryTempFile.FilePath, textToType.Substring(0, i + 1), 0, i + 1); + Assert.True(completions != null && completions.Length > 0, "Completion items list was null or empty"); + + Thread.Sleep(50); + + var item = await testHelper.RequestResolveCompletion(completions[0]); + + Assert.NotNull(item); + } + } + + // Clear the text document + System.IO.File.WriteAllText(queryTempFile.FilePath, ""); + + var contentChanges2 = new TextDocumentChangeEvent[1]; + contentChanges2[0] = new TextDocumentChangeEvent() + { + Range = new Range() + { + Start = new Position() + { + Line = 0, + Character = 0 + }, + End = new Position() + { + Line = 0, + Character = textToType.Length - 1 + } + }, + RangeLength = textToType.Length, + Text = "" + }; + + DidChangeTextDocumentParams changeParams2 = new DidChangeTextDocumentParams() + { + ContentChanges = contentChanges2, + TextDocument = new VersionedTextDocumentIdentifier() + { + Version = ++version, + Uri = queryTempFile.FilePath + } + }; + + await testHelper.RequestChangeTextDocumentNotification(changeParams2); + } + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + /// + /// Repeatedly execute queries to stress test the query execution service. + /// + //[Fact] + public async Task TestQueryExecutionService() + { + const string queryToRun = "SELECT * FROM sys.all_objects GO " + + "SELECT * FROM sys.objects GO " + + "SELECT * FROM sys.tables GO " + + "SELECT COUNT(*) FROM sys.objects"; + + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + // Connect + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection is successful"); + + // Run queries repeatedly + Stopwatch stopwatch = new Stopwatch(); + stopwatch.Start(); + while (stopwatch.Elapsed < TimeSpan.FromMinutes(60)) + { + var queryResult = await testHelper.RunQuery(queryTempFile.FilePath, queryToRun, 10000); + + Assert.NotNull(queryResult); + Assert.NotNull(queryResult.BatchSummaries); + Assert.NotEmpty(queryResult.BatchSummaries); + Assert.NotNull(queryResult.BatchSummaries[0].ResultSetSummaries); + Assert.NotNull(queryResult.BatchSummaries[1].ResultSetSummaries); + Assert.NotNull(queryResult.BatchSummaries[2].ResultSetSummaries); + Assert.NotNull(queryResult.BatchSummaries[3].ResultSetSummaries); + + Assert.NotNull(await testHelper.ExecuteSubset(queryTempFile.FilePath, 0, 0, 0, 7)); + Assert.NotNull(await testHelper.ExecuteSubset(queryTempFile.FilePath, 1, 0, 0, 7)); + Assert.NotNull(await testHelper.ExecuteSubset(queryTempFile.FilePath, 2, 0, 0, 7)); + Assert.NotNull(await testHelper.ExecuteSubset(queryTempFile.FilePath, 3, 0, 0, 1)); + + Thread.Sleep(500); + } + + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + + /// + /// Repeatedly connect and disconnect to stress test the connection service. + /// + //[Fact] + public async Task TestConnectionService() + { + string ownerUri = "file:///my/test/file.sql"; + + var connection = ConnectionTestUtils.LocalhostConnection; + connection.Connection.Pooling = false; + + using (TestHelper testHelper = new TestHelper()) + { + // Connect/disconnect repeatedly + Stopwatch stopwatch = new Stopwatch(); + stopwatch.Start(); + while (stopwatch.Elapsed < TimeSpan.FromMinutes(60)) + { + // Connect + bool connected = await testHelper.Connect(ownerUri, connection); + Assert.True(connected, "Connection is successful"); + + // Disconnect + bool disconnected = await testHelper.Disconnect(ownerUri); + Assert.True(disconnected, "Disconnect is successful"); + } + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/TestHelper.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/TestHelper.cs new file mode 100644 index 00000000..d248ffff --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/TestHelper.cs @@ -0,0 +1,395 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Driver; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests +{ + /// + /// Base class for all test suites run by the test driver + /// + public sealed class TestHelper : IDisposable + { + private bool isRunning = false; + + public TestHelper() + { + Driver = new ServiceTestDriver(); + Driver.Start().Wait(); + this.isRunning = true; + } + + public void Dispose() + { + if (this.isRunning) + { + WaitForExit(); + } + } + + public void WaitForExit() + { + try + { + this.isRunning = false; + Driver.Stop().Wait(); + Console.WriteLine("Successfully killed process."); + } + catch(Exception e) + { + Console.WriteLine($"Exception while waiting for service exit: {e.Message}"); + } + } + + /// + /// The driver object used to read/write data to the service + /// + public ServiceTestDriver Driver + { + get; + private set; + } + + private object fileLock = new Object(); + + /// + /// Request a new connection to be created + /// + /// True if the connection completed successfully + public async Task Connect(string ownerUri, ConnectParams connectParams, int timeout = 15000) + { + connectParams.OwnerUri = ownerUri; + var connectResult = await Driver.SendRequest(ConnectionRequest.Type, connectParams); + if (connectResult) + { + var completeEvent = await Driver.WaitForEvent(ConnectionCompleteNotification.Type, timeout); + return !string.IsNullOrEmpty(completeEvent.ConnectionId); + } + else + { + return false; + } + } + + /// + /// Request a disconnect + /// + public async Task Disconnect(string ownerUri) + { + var disconnectParams = new DisconnectParams(); + disconnectParams.OwnerUri = ownerUri; + + var disconnectResult = await Driver.SendRequest(DisconnectRequest.Type, disconnectParams); + return disconnectResult; + } + + /// + /// Request a cancel connect + /// + public async Task CancelConnect(string ownerUri) + { + var cancelParams = new CancelConnectParams(); + cancelParams.OwnerUri = ownerUri; + + return await Driver.SendRequest(CancelConnectRequest.Type, cancelParams); + } + + /// + /// Request a cancel connect + /// + public async Task ListDatabases(string ownerUri) + { + var listParams = new ListDatabasesParams(); + listParams.OwnerUri = ownerUri; + + return await Driver.SendRequest(ListDatabasesRequest.Type, listParams); + } + + /// + /// Request the active SQL script is parsed for errors + /// + public async Task RequestQueryExecuteSubset(QueryExecuteSubsetParams subsetParams) + { + return await Driver.SendRequest(QueryExecuteSubsetRequest.Type, subsetParams); + } + + /// + /// Request the active SQL script is parsed for errors + /// + public async Task RequestOpenDocumentNotification(DidOpenTextDocumentNotification openParams) + { + await Driver.SendEvent(DidOpenTextDocumentNotification.Type, openParams); + } + + /// + /// Request a configuration change notification + /// + public async Task RequestChangeConfigurationNotification(DidChangeConfigurationParams configParams) + { + await Driver.SendEvent(DidChangeConfigurationNotification.Type, configParams); + } + + /// + /// /// Request the active SQL script is parsed for errors + /// + public async Task RequestChangeTextDocumentNotification(DidChangeTextDocumentParams changeParams) + { + await Driver.SendEvent(DidChangeTextDocumentNotification.Type, changeParams); + } + + /// + /// Request completion item resolve to look-up additional info + /// + public async Task RequestResolveCompletion(CompletionItem item) + { + var result = await Driver.SendRequest(CompletionResolveRequest.Type, item); + return result; + } + + /// + /// Request a Read Credential for given credential id + /// + public async Task ReadCredential(string credentialId) + { + var credentialParams = new Credential(); + credentialParams.CredentialId = credentialId; + + return await Driver.SendRequest(ReadCredentialRequest.Type, credentialParams); + } + + /// + /// Returns database connection parameters for given server type + /// + public async Task GetDatabaseConnectionAsync(TestServerType serverType, string databaseName) + { + ConnectionProfile connectionProfile = null; + TestServerIdentity serverIdentiry = ConnectionTestUtils.TestServers.FirstOrDefault(x => x.ServerType == serverType); + if (serverIdentiry == null) + { + connectionProfile = ConnectionTestUtils.Setting.Connections.FirstOrDefault(x => x.ServerType == serverType); + } + else + { + connectionProfile = ConnectionTestUtils.Setting.GetConnentProfile(serverIdentiry.ProfileName, serverIdentiry.ServerName); + } + + if (connectionProfile != null) + { + string password = connectionProfile.Password; + if (string.IsNullOrEmpty(password)) + { + Credential credential = await ReadCredential(connectionProfile.formatCredentialId()); + password = credential.Password; + } + ConnectParams conenctParam = ConnectionTestUtils.CreateConnectParams(connectionProfile.ServerName, connectionProfile.Database, + connectionProfile.User, password); + if (!string.IsNullOrEmpty(databaseName)) + { + conenctParam.Connection.DatabaseName = databaseName; + } + if (serverType == TestServerType.Azure) + { + conenctParam.Connection.ConnectTimeout = 30; + conenctParam.Connection.Encrypt = true; + conenctParam.Connection.TrustServerCertificate = false; + } + return conenctParam; + } + return null; + } + + /// + /// Request a list of completion items for a position in a block of text + /// + public async Task RequestCompletion(string ownerUri, string text, int line, int character) + { + // Write the text to a backing file + lock (fileLock) + { + System.IO.File.WriteAllText(ownerUri, text); + } + + var completionParams = new TextDocumentPosition(); + completionParams.TextDocument = new TextDocumentIdentifier(); + completionParams.TextDocument.Uri = ownerUri; + completionParams.Position = new Position(); + completionParams.Position.Line = line; + completionParams.Position.Character = character; + + var result = await Driver.SendRequest(CompletionRequest.Type, completionParams); + return result; + } + + /// + /// Request a a hover tooltop + /// + public async Task RequestHover(string ownerUri, string text, int line, int character) + { + // Write the text to a backing file + lock (fileLock) + { + System.IO.File.WriteAllText(ownerUri, text); + } + + var completionParams = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier {Uri = ownerUri}, + Position = new Position + { + Line = line, + Character = character + } + }; + + var result = await Driver.SendRequest(HoverRequest.Type, completionParams); + return result; + } + + /// + /// Request definition( peek definition/go to definition) for a sql object in a sql string + /// + public async Task RequestDefinition(string ownerUri, string text, int line, int character) + { + // Write the text to a backing file + lock (fileLock) + { + System.IO.File.WriteAllText(ownerUri, text); + } + + var definitionParams = new TextDocumentPosition(); + definitionParams.TextDocument = new TextDocumentIdentifier(); + definitionParams.TextDocument.Uri = ownerUri; + definitionParams.Position = new Position(); + definitionParams.Position.Line = line; + definitionParams.Position.Character = character; + + // Send definition request + var result = await Driver.SendRequest(DefinitionRequest.Type, definitionParams); + return result; + } + + /// + /// Run a query using a given connection bound to a URI + /// + public async Task RunQuery(string ownerUri, string query, int timeoutMilliseconds = 5000) + { + // Write the query text to a backing file + WriteToFile(ownerUri, query); + + var queryParams = new QueryExecuteParams + { + OwnerUri = ownerUri, + QuerySelection = null + }; + + var result = await Driver.SendRequest(QueryExecuteRequest.Type, queryParams); + if (result != null && string.IsNullOrEmpty(result.Messages)) + { + var eventResult = await Driver.WaitForEvent(QueryExecuteCompleteEvent.Type, timeoutMilliseconds); + return eventResult; + } + else + { + return null; + } + } + + /// + /// Run a query using a given connection bound to a URI. This method only waits for the initial response from query + /// execution (QueryExecuteResult). It is up to the caller to wait for the QueryExecuteCompleteEvent if they are interested. + /// + public async Task RunQueryAsync(string ownerUri, string query, int timeoutMilliseconds = 5000) + { + WriteToFile(ownerUri, query); + + var queryParams = new QueryExecuteParams + { + OwnerUri = ownerUri, + QuerySelection = null + }; + + return await Driver.SendRequest(QueryExecuteRequest.Type, queryParams); + } + + /// + /// Request to cancel an executing query + /// + public async Task CancelQuery(string ownerUri) + { + var cancelParams = new QueryCancelParams {OwnerUri = ownerUri}; + + var result = await Driver.SendRequest(QueryCancelRequest.Type, cancelParams); + return result; + } + + /// + /// Request to save query results as CSV + /// + public async Task SaveAsCsv(string ownerUri, string filename, int batchIndex, int resultSetIndex) + { + var saveParams = new SaveResultsAsCsvRequestParams + { + OwnerUri = ownerUri, + BatchIndex = batchIndex, + ResultSetIndex = resultSetIndex, + FilePath = filename + }; + + var result = await Driver.SendRequest(SaveResultsAsCsvRequest.Type, saveParams); + return result; + } + + /// + /// Request to save query results as JSON + /// + public async Task SaveAsJson(string ownerUri, string filename, int batchIndex, int resultSetIndex) + { + var saveParams = new SaveResultsAsJsonRequestParams + { + OwnerUri = ownerUri, + BatchIndex = batchIndex, + ResultSetIndex = resultSetIndex, + FilePath = filename + }; + + var result = await Driver.SendRequest(SaveResultsAsJsonRequest.Type, saveParams); + return result; + } + + /// + /// Request a subset of results from a query + /// + public async Task ExecuteSubset(string ownerUri, int batchIndex, int resultSetIndex, int rowStartIndex, int rowCount) + { + var subsetParams = new QueryExecuteSubsetParams(); + subsetParams.OwnerUri = ownerUri; + subsetParams.BatchIndex = batchIndex; + subsetParams.ResultSetIndex = resultSetIndex; + subsetParams.RowsStartIndex = rowStartIndex; + subsetParams.RowsCount = rowCount; + + var result = await Driver.SendRequest(QueryExecuteSubsetRequest.Type, subsetParams); + return result; + } + + public void WriteToFile(string ownerUri, string query) + { + lock (fileLock) + { + System.IO.File.WriteAllText(ownerUri, query); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/WorkspaceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/WorkspaceTests.cs new file mode 100644 index 00000000..59814534 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/WorkspaceTests.cs @@ -0,0 +1,37 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.IO; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests +{ + /// + /// Language Service end-to-end integration tests + /// + public class WorkspaceTests + { + /// + /// Validate workspace lifecycle events + /// + [Fact] + public async Task InitializeRequestTest() + { + using (TestHelper testHelper = new TestHelper()) + { + InitializeRequest initializeRequest = new InitializeRequest() + { + RootPath = Path.GetTempPath(), + Capabilities = new ClientCapabilities() + }; + + InitializeResult result = await testHelper.Driver.SendRequest(InitializeRequest.Type, initializeRequest); + Assert.NotNull(result); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/ConnectionTestUtils.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/ConnectionTestUtils.cs new file mode 100644 index 00000000..273f8f22 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/ConnectionTestUtils.cs @@ -0,0 +1,184 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Newtonsoft.Json.Linq; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Utility +{ + /// + /// Contains useful utility methods for testing connections + /// + public class ConnectionTestUtils + { + public static IEnumerable TestServers = InitTestServerNames(); + public static Setting Setting = InitSetting(); + + private static readonly Lazy azureTestServerConnection = + new Lazy(() => GetConnectionFromVsCodeSettings("***REMOVED***")); + + private static IEnumerable InitTestServerNames() + { + try + { + string testServerNamesFilePath = Environment.GetEnvironmentVariable("TestServerNamesFile"); + if (!string.IsNullOrEmpty(testServerNamesFilePath)) + { + string jsonFileContent = File.ReadAllText(testServerNamesFilePath); + return Newtonsoft.Json.JsonConvert.DeserializeObject>(jsonFileContent); + } + else + { + return Enumerable.Empty(); + } + } + catch (Exception ex) + { + Console.WriteLine("Failed to load the database connection server name settings. error: " + ex.Message); + return null; + } + } + + private static Setting InitSetting() + { + try + { + string settingsFileContents = GetSettingFileContent(); + Setting setting = Newtonsoft.Json.JsonConvert.DeserializeObject(settingsFileContents); + + return setting; + } + catch (Exception ex) + { + Console.WriteLine("Failed to load the connection settings. error: " + ex.Message); + return null; + } + } + + public static ConnectParams AzureTestServerConnection + { + get { return azureTestServerConnection.Value; } + } + + public static ConnectParams LocalhostConnection + { + get + { + return new ConnectParams() + { + Connection = new ConnectionDetails() + { + DatabaseName = "master", + ServerName = "localhost", + AuthenticationType = "Integrated" + } + }; + } + } + + public static ConnectParams InvalidConnection + { + get + { + return new ConnectParams() + { + Connection = new ConnectionDetails() + { + DatabaseName = "master", + ServerName = "localhost", + AuthenticationType = "SqlLogin", + UserName = "invalid", + Password = ".." + } + }; + } + } + + private static readonly Lazy sqlDataToolsAzureConnection = + new Lazy(() => GetConnectionFromVsCodeSettings("***REMOVED***")); + + public static ConnectParams SqlDataToolsAzureConnection + { + get { return sqlDataToolsAzureConnection.Value; } + } + + private static readonly Lazy dataToolsTelemetryAzureConnection = + new Lazy(() => GetConnectionFromVsCodeSettings("***REMOVED***")); + + private static string GetSettingFileContent() + { + string settingsFilename; + settingsFilename = Environment.GetEnvironmentVariable("SettingsFileName"); + if (string.IsNullOrEmpty(settingsFilename)) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + settingsFilename = Environment.GetEnvironmentVariable("APPDATA") + @"\Code\User\settings.json"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + settingsFilename = Environment.GetEnvironmentVariable("HOME") + @"/Library/Application Support/Code/User/settings.json"; + } + else + { + settingsFilename = Environment.GetEnvironmentVariable("HOME") + @"/.config/Code/User/settings.json"; + } + } + string settingsFileContents = File.ReadAllText(settingsFilename); + + return settingsFileContents; + } + + public static ConnectParams DataToolsTelemetryAzureConnection + { + get { return dataToolsTelemetryAzureConnection.Value; } + } + + /// + /// Create a connection parameters object + /// + public static ConnectParams CreateConnectParams(string server, string database, string username, string password) + { + ConnectParams connectParams = new ConnectParams(); + connectParams.Connection = new ConnectionDetails(); + connectParams.Connection.ServerName = server; + connectParams.Connection.DatabaseName = database; + connectParams.Connection.UserName = username; + connectParams.Connection.Password = password; + connectParams.Connection.AuthenticationType = "SqlLogin"; + return connectParams; + } + + /// + /// Retrieve connection parameters from the vscode settings file + /// + public static ConnectParams GetConnectionFromVsCodeSettings(string serverName) + { + try + { + string settingsFileContents = GetSettingFileContent(); + + JObject root = JObject.Parse(settingsFileContents); + JArray connections = (JArray)root["mssql.connections"]; + + var connectionObject = connections.Where(x => x["server"].ToString() == serverName).First(); + + return CreateConnectParams( connectionObject["server"].ToString(), + connectionObject["database"].ToString(), + connectionObject["user"].ToString(), + connectionObject["password"].ToString()); + } + catch (Exception ex) + { + throw new Exception("Unable to load connection " + serverName + " from the vscode settings.json. Ensure the file is formatted correctly.", ex); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/SelfCleaningTempFile.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/SelfCleaningTempFile.cs new file mode 100644 index 00000000..c991e3ff --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/SelfCleaningTempFile.cs @@ -0,0 +1,53 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Utility +{ + public class SelfCleaningTempFile : IDisposable + { + private bool disposed; + + public SelfCleaningTempFile() + { + FilePath = Path.GetTempFileName(); + } + + public string FilePath { get; private set; } + + #region IDisposable Implementation + + public void Dispose() + { + if (!disposed) + { + Dispose(true); + GC.SuppressFinalize(this); + } + } + + public void Dispose(bool disposing) + { + if (!disposed && disposing) + { + try + { + File.Delete(FilePath); + } + catch + { + Console.WriteLine($"Failed to cleanup {FilePath}"); + } + } + + disposed = true; + } + + #endregion + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/Setting.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/Setting.cs new file mode 100644 index 00000000..7da7eecd --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/Setting.cs @@ -0,0 +1,82 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Linq; +using System.Globalization; +using Newtonsoft.Json; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Utility +{ + /// + /// The model for deserializing settings.json + /// + public class Setting + { + [JsonProperty("mssql.connections")] + public List Connections { get; set; } + + public ConnectionProfile GetConnentProfile(string profilerName, string serverName) + { + if (!string.IsNullOrEmpty(profilerName)) + { + var byPrfileName = Connections.FirstOrDefault(x => x.ProfileName == profilerName); + if (byPrfileName != null) + { + return byPrfileName; + } + } + return Connections.FirstOrDefault(x => x.ServerName == serverName); + } + } + + /// + /// The model to deserializing the connections inside settings.json + /// + public class ConnectionProfile + { + public const string CRED_PREFIX = "Microsoft.SqlTools"; + public const string CRED_SEPARATOR = "|"; + public const string CRED_SERVER_PREFIX = "server:"; + public const string CRED_DB_PREFIX = "db:"; + public const string CRED_USER_PREFIX = "user:"; + public const string CRED_ITEMTYPE_PREFIX = "itemtype:"; + + [JsonProperty("server")] + public string ServerName { get; set; } + public string Database { get; set; } + + public string User { get; set; } + + public string Password { get; set; } + + public string ProfileName { get; set; } + + public TestServerType ServerType { get; set; } + + + public string formatCredentialId(string itemType = "Profile") + { + if (!string.IsNullOrEmpty(ServerName)) + { + List cred = new List(); + cred.Add(CRED_PREFIX); + AddToList(itemType, CRED_ITEMTYPE_PREFIX, cred); + AddToList(ServerName, CRED_SERVER_PREFIX, cred); + AddToList(Database, CRED_DB_PREFIX, cred); + AddToList(User, CRED_USER_PREFIX, cred); + return string.Join(CRED_SEPARATOR, cred.ToArray()); + } + return null; + } + private void AddToList(string item, string prefix, List list) + { + if (!string.IsNullOrEmpty(item)) + { + list.Add(string.Format(CultureInfo.InvariantCulture, "{0}{1}", prefix, item)); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestResult.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestResult.cs new file mode 100644 index 00000000..ac81c5d5 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestResult.cs @@ -0,0 +1,12 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Utility +{ + public class TestResult + { + public double ElapsedTime { get; set; } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestRunner.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestRunner.cs new file mode 100644 index 00000000..e0f443b1 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestRunner.cs @@ -0,0 +1,106 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Globalization; +using System.Linq; +using System.Reflection; +using System.Threading.Tasks; +using Xunit; +using Xunit.Sdk; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Utility +{ + public class TestRunner + { + public static async Task RunTests(string[] tests, string testNamespace) + { + foreach (var test in tests) + { + try + { + var testName = test.Contains(testNamespace) ? test.Replace(testNamespace, "") : test; + bool containsTestName = testName.Contains("."); + var className = containsTestName ? testName.Substring(0, testName.LastIndexOf('.')) : testName; + var methodName = containsTestName ? testName.Substring(testName.LastIndexOf('.') + 1) : null; + Assembly assembly = Assembly.GetEntryAssembly(); + Type type = assembly.GetType(testNamespace + className); + if (type == null) + { + Console.WriteLine("Invalid class name"); + } + else + { + var typeInstance = Activator.CreateInstance(type); + if (string.IsNullOrEmpty(methodName)) + { + var methods = type.GetMethods().Where(x => x.CustomAttributes.Any(a => a.AttributeType == typeof(FactAttribute))); + foreach (var method in methods) + { + await RunTest(typeInstance, method, method.Name); + } + } + else + { + MethodInfo methodInfo = type.GetMethod(methodName); + await RunTest(typeInstance, methodInfo, test); + } + + IDisposable disposable = typeInstance as IDisposable; + if (disposable != null) + { + disposable.Dispose(); + } + } + } + catch (Exception ex) + { + Console.WriteLine(ex.ToString()); + return -1; + } + } + return 0; + } + + private static async Task RunTest(object typeInstance, MethodInfo methodInfo, string testName) + { + try + { + if (methodInfo == null) + { + Console.WriteLine("Invalid method name"); + } + else + { + var testAttributes = methodInfo.CustomAttributes; + BeforeAfterTestAttribute beforeAfterTestAttribute = null; + foreach (var attribute in testAttributes) + { + var args = attribute.ConstructorArguments.Select(x => x.Value as object).ToArray(); + var objAttribute = Activator.CreateInstance(attribute.AttributeType, args); + + beforeAfterTestAttribute = objAttribute as BeforeAfterTestAttribute; + if (beforeAfterTestAttribute != null) + { + beforeAfterTestAttribute.Before(methodInfo); + } + } + Console.WriteLine("Running test " + testName); + await (Task)methodInfo.Invoke(typeInstance, null); + if (beforeAfterTestAttribute != null) + { + beforeAfterTestAttribute.After(methodInfo); + } + Console.WriteLine("Test ran successfully: " + testName); + } + } + catch(Exception ex) + { + Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "Test Failed: {0} error: {1}", testName, ex.Message)); + + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestServerIdentity.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestServerIdentity.cs new file mode 100644 index 00000000..ce9b55fe --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestServerIdentity.cs @@ -0,0 +1,25 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Utility +{ + /// + /// The model to deserialize the server names json + /// + public class TestServerIdentity + { + public string ServerName { get; set; } + public string ProfileName { get; set; } + + public TestServerType ServerType { get; set; } + } + + public enum TestServerType + { + None, + Azure, + OnPrem + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestTimer.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestTimer.cs new file mode 100644 index 00000000..05e4c00f --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Utility/TestTimer.cs @@ -0,0 +1,84 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Utility +{ + /// + /// Timer to calculate the test run time + /// + public class TestTimer + { + private static string resultFolder = InitResultFolder(); + + private static string InitResultFolder() + { + string resultFodler = Environment.GetEnvironmentVariable("ResultFolder"); + if (string.IsNullOrEmpty(resultFodler)) + { + string assemblyLocation = System.Reflection.Assembly.GetEntryAssembly().Location; + resultFodler = Path.GetDirectoryName(assemblyLocation); + } + return resultFodler; + } + + public TestTimer() + { + Start(); + } + + public bool PrintResult { get; set; } + + public void Start() + { + StartDateTime = DateTime.UtcNow; + } + + public void End() + { + EndDateTime = DateTime.UtcNow; + } + + public void EndAndPrint([CallerMemberName] string testName = "") + { + End(); + if (PrintResult) + { + var currentColor = Console.ForegroundColor; + Console.ForegroundColor = ConsoleColor.Green; + Console.WriteLine(string.Format(CultureInfo.InvariantCulture, "Test Name: {0} Run time in milliSeconds: {1}", testName, TotalMilliSeconds)); + Console.ForegroundColor = currentColor; + string resultContent = Newtonsoft.Json.JsonConvert.SerializeObject(new TestResult { ElapsedTime = TotalMilliSeconds }); + string fileName = testName + ".json"; + string resultFilePath = string.IsNullOrEmpty(resultFolder) ? fileName : Path.Combine(resultFolder, fileName); + File.WriteAllText(resultFilePath, resultContent); + Console.WriteLine("Result file: " + resultFilePath); + } + } + + public double TotalMilliSeconds + { + get + { + return (EndDateTime - StartDateTime).TotalMilliseconds; + } + } + + public double TotalMilliSecondsUntilNow + { + get + { + return (DateTime.UtcNow - StartDateTime).TotalMilliseconds; + } + } + + public DateTime StartDateTime { get; private set; } + public DateTime EndDateTime { get; private set; } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/project.json b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/project.json index e2b39b6c..71f0ce7a 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/project.json +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/project.json @@ -3,24 +3,40 @@ "version": "1.0.0-*", "buildOptions": { "debugType": "portable", - "emitEntryPoint": true + "emitEntryPoint": true, + "embed": { + "includeFiles": [ + "Scripts/CreateTestDatabaseObjects.sql", + "Scripts/CreateTestDatabase.sql", + "Scripts/TestDbTableQueries.sql" + ] + }, + "publishOptions": { + "include": [ + "Scripts/AdventureWorks.sql" + ] + } }, "dependencies": { + "xunit": "2.1.0", + "dotnet-test-xunit": "1.0.0-rc2-192208-24", "Microsoft.SqlTools.ServiceLayer": { "target": "project" } }, + "testRunner": "xunit", "frameworks": { "netcoreapp1.0": { "dependencies": { "Microsoft.NETCore.App": { + "type": "platform", "version": "1.0.0" } }, "imports": [ "dotnet5.4", "portable-net451+win8" - ], + ] } }, "runtimes": {