diff --git a/.gitignore b/.gitignore index 97e1cc41..9af4c745 100644 --- a/.gitignore +++ b/.gitignore @@ -284,4 +284,5 @@ Session.vim # Stuff from cake /artifacts/ -/.tools/ \ No newline at end of file +/.tools/ +/.dotnet/ \ No newline at end of file diff --git a/bin/nuget/Microsoft.DataTools.SrGen.1.0.2.nupkg b/bin/nuget/Microsoft.DataTools.SrGen.1.0.2.nupkg new file mode 100644 index 00000000..f410a0af Binary files /dev/null and b/bin/nuget/Microsoft.DataTools.SrGen.1.0.2.nupkg differ diff --git a/bin/nuget/Microsoft.SqlServer.Smo.140.1.11.nupkg b/bin/nuget/Microsoft.SqlServer.Smo.140.1.11.nupkg new file mode 100644 index 00000000..f0280385 Binary files /dev/null and b/bin/nuget/Microsoft.SqlServer.Smo.140.1.11.nupkg differ diff --git a/bin/nuget/System.Data.SqlClient.4.4.0-sqltools-24613-04.nupkg b/bin/nuget/System.Data.SqlClient.4.4.0-sqltools-24613-04.nupkg new file mode 100644 index 00000000..733f5362 Binary files /dev/null and b/bin/nuget/System.Data.SqlClient.4.4.0-sqltools-24613-04.nupkg differ diff --git a/build.cake b/build.cake index 20276ba5..1de9257f 100644 --- a/build.cake +++ b/build.cake @@ -257,6 +257,7 @@ Task("TestCore") /// Task("Test") .IsDependentOn("Setup") + .IsDependentOn("SRGen") .IsDependentOn("BuildTest") .Does(() => { @@ -304,6 +305,7 @@ Task("Test") /// Task("OnlyPublish") .IsDependentOn("Setup") + .IsDependentOn("SRGen") .Does(() => { var project = buildPlan.MainProject; @@ -322,7 +324,13 @@ Task("OnlyPublish") publishArguments = $"{publishArguments} --output \"{outputFolder}\" \"{projectFolder}\""; Run(dotnetcli, publishArguments) .ExceptionOnError($"Failed to publish {project} / {framework}"); - + //Setting the rpath for System.Security.Cryptography.Native.dylib library + //Only required for mac. We're assuming the openssl is installed in /usr/local/opt/openssl + //If that's not the case user has to run the command manually + if (!IsRunningOnWindows() && runtime.Contains("osx")) + { + Run("install_name_tool", "-add_rpath /usr/local/opt/openssl/lib " + outputFolder + "/System.Security.Cryptography.Native.dylib"); + } if (requireArchive) { Package(runtime, framework, outputFolder, packageFolder, buildPlan.MainProject.ToLower()); @@ -359,7 +367,6 @@ Task("RestrictToLocalRuntime") /// Task("LocalPublish") .IsDependentOn("Restore") - .IsDependentOn("SrGen") .IsDependentOn("RestrictToLocalRuntime") .IsDependentOn("OnlyPublish") .Does(() => @@ -529,7 +536,8 @@ Task("SRGen") var dotnetArgs = string.Format("{0} -or \"{1}\" -oc \"{2}\" -ns \"{3}\" -an \"{4}\" -cn SR -l CS -dnx \"{5}\"", srgenPath, outputResx, outputCs, projectName, projectName, projectStrings); Information("{0}", dotnetArgs); - Run(dotnetcli, dotnetArgs); + Run(dotnetcli, dotnetArgs) + .ExceptionOnError("Failed to run SRGen."); } }); diff --git a/build.sh b/build.sh old mode 100644 new mode 100755 diff --git a/nuget.config b/nuget.config index f5d41658..9bf3e13c 100644 --- a/nuget.config +++ b/nuget.config @@ -2,6 +2,7 @@ - + + diff --git a/scripts/packages.config b/scripts/packages.config index 54f9b711..ccbbd827 100644 --- a/scripts/packages.config +++ b/scripts/packages.config @@ -2,5 +2,5 @@ - + diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 36e86791..cb931a6e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -9,6 +9,7 @@ using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Data.SqlClient; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; @@ -205,7 +206,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection var cancellationTask = Task.Run(() => { source.Token.WaitHandle.WaitOne(); - source.Token.ThrowIfCancellationRequested(); + try + { + source.Token.ThrowIfCancellationRequested(); + } + catch (ObjectDisposedException) + { + // Ignore + } }); var openTask = Task.Run(async () => { @@ -367,7 +375,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // Success return true; } - + /// /// List all databases on the server specified /// @@ -393,18 +401,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails)); connection.Open(); - DbCommand command = connection.CreateCommand(); - command.CommandText = "SELECT name FROM sys.databases ORDER BY database_id ASC"; - command.CommandTimeout = 15; - command.CommandType = CommandType.Text; - var reader = command.ExecuteReader(); - List results = new List(); - while (reader.Read()) + var systemDatabases = new string[] {"master", "model", "msdb", "tempdb"}; + using (DbCommand command = connection.CreateCommand()) { - results.Add(reader[0].ToString()); + command.CommandText = "SELECT name FROM sys.databases ORDER BY name ASC"; + command.CommandTimeout = 15; + command.CommandType = CommandType.Text; + + using (var reader = command.ExecuteReader()) + { + while (reader.Read()) + { + results.Add(reader[0].ToString()); + } + } } + // Put system databases at the top of the list + results = + results.Where(s => systemDatabases.Any(s.Equals)).Concat( + results.Where(s => systemDatabases.All(x => !s.Equals(x)))).ToList(); + connection.Close(); ListDatabasesResponse response = new ListDatabasesResponse(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicyUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicyUtils.cs index 800951d4..c8b86091 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicyUtils.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicyUtils.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; using System.Data.SqlClient; +using System.Runtime.InteropServices; using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection @@ -215,6 +216,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection public static bool IsRetryableNetworkConnectivityError(int errorNumber) { + // .NET core has a bug on OSX/Linux that makes this error number always zero (issue 12472) + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return errorNumber != 0 && _retryableNetworkConnectivityErrors.Contains(errorNumber); + } return _retryableNetworkConnectivityErrors.Contains(errorNumber); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs index 92b097aa..f217e0d3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs @@ -150,14 +150,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting Capabilities = new ServerCapabilities { TextDocumentSync = TextDocumentSyncKind.Incremental, - DefinitionProvider = true, - ReferencesProvider = true, - DocumentHighlightProvider = true, - HoverProvider = true, + DefinitionProvider = false, + ReferencesProvider = false, + DocumentHighlightProvider = false, + HoverProvider = true, CompletionProvider = new CompletionOptions { ResolveProvider = true, - TriggerCharacters = new string[] { ".", "-", ":", "\\", ",", " " } + TriggerCharacters = new string[] { ".", "-", ":", "\\" } }, SignatureHelpProvider = new SignatureHelpOptions { diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs index 4eef4915..b92e541b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs @@ -3,7 +3,12 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using System; using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Text.RegularExpressions; +using System.Threading; using Microsoft.SqlServer.Management.SqlParser.Binder; using Microsoft.SqlServer.Management.SqlParser.Intellisense; using Microsoft.SqlServer.Management.SqlParser.Parser; @@ -25,89 +30,44 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices private static WorkspaceService workspaceServiceInstance; + private static Regex ValidSqlNameRegex = new Regex(@"^[\p{L}_@][\p{L}\p{N}@$#_]{0,127}$"); + + private static CompletionItem[] emptyCompletionList = new CompletionItem[0]; + private static readonly string[] DefaultCompletionText = new string[] - { - "absolute", - "accent_sensitivity", - "action", - "activation", - "add", - "address", - "admin", - "after", - "aggregate", - "algorithm", - "allow_page_locks", - "allow_row_locks", - "allow_snapshot_isolation", + { + "all", "alter", - "always", - "ansi_null_default", - "ansi_nulls", - "ansi_padding", - "ansi_warnings", - "application", - "arithabort", + "and", + "apply", "as", "asc", - "assembly", - "asymmetric", "at", - "atomic", - "audit", - "authentication", - "authorization", - "auto", - "auto_close", - "auto_shrink", - "auto_update_statistics", - "auto_update_statistics_async", - "availability", "backup", - "before", "begin", "binary", "bit", - "block", "break", - "browse", - "bucket_count", "bulk", "by", "call", - "caller", - "card", "cascade", "case", - "catalog", "catch", - "change_tracking", - "changes", "char", "character", "check", "checkpoint", "close", "clustered", - "collection", "column", - "column_encryption_key", "columnstore", "commit", - "compatibility_level", - "compress_all_row_groups", - "compression", - "compression_delay", - "compute", - "concat_null_yields_null", - "configuration", "connect", "constraint", - "containstable", "continue", "create", - "cube", - "current", + "cross", "current_date", "cursor", "cursor_close_on_commit", @@ -116,47 +76,34 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "data_compression", "database", "date", - "date_correlation_optimization", - "datefirst", "datetime", "datetime2", "days", - "db_chaining", "dbcc", - "deallocate", "dec", "decimal", "declare", "default", - "delayed_durability", "delete", "deny", "desc", "description", - "disable_broker", "disabled", "disk", "distinct", - "distributed", "double", "drop", "drop_existing", "dump", - "durability", "dynamic", "else", "enable", "encrypted", - "encryption_type", "end", "end-exec", - "entry", - "errlvl", - "escape", - "event", - "except", "exec", "execute", + "exists", "exit", "external", "fast_forward", @@ -165,20 +112,14 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "filegroup", "filename", "filestream", - "fillfactor", "filter", "first", "float", "for", "foreign", - "freetext", - "freetexttable", "from", "full", - "fullscan", - "fulltext", "function", - "generated", "geography", "get", "global", @@ -194,30 +135,26 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "holdlock", "hours", "identity", - "identity_insert", "identitycol", "if", - "ignore_dup_key", "image", "immediate", "include", "index", - "inflectional", - "insensitive", + "inner", "insert", "instead", "int", "integer", - "integrated", "intersect", "into", "isolation", + "join", "json", "key", - "kill", "language", "last", - "legacy_cardinality_estimation", + "left", "level", "lineno", "load", @@ -226,16 +163,12 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "location", "login", "masked", - "master", "maxdop", - "memory_optimized", "merge", "message", "modify", "move", - "multi_user", "namespace", - "national", "native_compilation", "nchar", "next", @@ -245,9 +178,10 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "nonclustered", "none", "norecompute", + "not", "now", + "null", "numeric", - "numeric_roundabort", "object", "of", "off", @@ -255,21 +189,16 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "on", "online", "open", - "opendatasource", - "openquery", "openrowset", "openxml", "option", + "or", "order", "out", + "outer", "output", "over", "owner", - "pad_index", - "page", - "page_verify", - "parameter_sniffing", - "parameterization", "partial", "partition", "password", @@ -280,7 +209,6 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "persisted", "plan", "policy", - "population", "precision", "predicate", "primary", @@ -289,7 +217,6 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "proc", "procedure", "public", - "query_optimizer_hotfixes", "query_store", "quoted_identifier", "raiserror", @@ -312,7 +239,6 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "relative", "remove", "reorganize", - "replication", "required", "restart", "restore", @@ -322,7 +248,6 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "returns", "revert", "revoke", - "role", "rollback", "rollup", "row", @@ -338,11 +263,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "scroll", "secondary", "security", - "securityaudit", "select", - "semantickeyphrasetable", - "semanticsimilaritydetailstable", - "semanticsimilaritytable", "send", "sent", "sequence", @@ -351,12 +272,10 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "set", "sets", "setuser", - "shutdown", "simple", "smallint", "smallmoney", "snapshot", - "sort_in_tempdb", "sql", "standard", "start", @@ -368,20 +287,13 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "statistics_norecompute", "status", "stopped", - "supported", - "symmetric", "sysname", "system", "system_time", - "system_versioning", "table", - "tablesample", "take", "target", - "textimage_on", - "textsize", "then", - "thesaurus", "throw", "time", "timestamp", @@ -392,14 +304,13 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "transaction", "trigger", "truncate", - "trustworthy", "try", "tsql", "type", + "uncommitted", "union", "unique", "uniqueidentifier", - "unlimited", "updatetext", "use", "user", @@ -407,24 +318,32 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices "value", "values", "varchar", - "varying", "version", "view", "waitfor", - "weight", "when", "where", "while", "with", "within", - "within group", "without", "writetext", "xact_abort", "xml", - "zone" }; + /// + /// Gets a static instance of an empty completion list to avoid + // unneeded memory allocations + /// + internal static CompletionItem[] EmptyCompletionList + { + get + { + return AutoCompleteHelper.emptyCompletionList; + } + } + /// /// Gets or sets the current workspace service instance /// Setter for internal testing purposes only @@ -443,7 +362,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { AutoCompleteHelper.workspaceServiceInstance = value; } - } + } /// /// Get the default completion list from hard-coded list @@ -456,17 +375,47 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices int row, int startColumn, int endColumn, - bool useLowerCase) + bool useLowerCase, + string tokenText = null) { - var completionItems = new CompletionItem[DefaultCompletionText.Length]; - for (int i = 0; i < DefaultCompletionText.Length; ++i) + // determine how many default completion items there will be + int listSize = DefaultCompletionText.Length; + if (!string.IsNullOrWhiteSpace(tokenText)) { - completionItems[i] = CreateDefaultCompletionItem( - useLowerCase ? DefaultCompletionText[i].ToLower() : DefaultCompletionText[i].ToUpper(), - row, - startColumn, - endColumn); + listSize = 0; + foreach (var completionText in DefaultCompletionText) + { + if (completionText.StartsWith(tokenText, StringComparison.OrdinalIgnoreCase)) + { + ++listSize; + } + } } + + // special case empty list to avoid unneed array allocations + if (listSize == 0) + { + return emptyCompletionList; + } + + // build the default completion list + var completionItems = new CompletionItem[listSize]; + int completionItemIndex = 0; + foreach (var completionText in DefaultCompletionText) + { + // add item to list if the tokenText is null (meaning return whole list) + // or if the completion item begins with the tokenText + if (string.IsNullOrWhiteSpace(tokenText) || completionText.StartsWith(tokenText, StringComparison.OrdinalIgnoreCase)) + { + completionItems[completionItemIndex] = CreateDefaultCompletionItem( + useLowerCase ? completionText.ToLower() : completionText.ToUpper(), + row, + startColumn, + endColumn); + ++completionItemIndex; + } + } + return completionItems; } @@ -483,14 +432,44 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices int startColumn, int endColumn) { - return new CompletionItem() + return CreateCompletionItem(label, label + " keyword", label, CompletionItemKind.Keyword, row, startColumn, endColumn); + } + + internal static CompletionItem[] AddTokenToItems(CompletionItem[] currentList, Token token, int row, + int startColumn, + int endColumn) + { + if (currentList != null && + token != null && !string.IsNullOrWhiteSpace(token.Text) && + token.Text.All(ch => char.IsLetter(ch)) && + currentList.All(x => string.Compare(x.Label, token.Text, true) != 0 + )) + { + var list = currentList.ToList(); + list.Insert(0, CreateCompletionItem(token.Text, token.Text, token.Text, CompletionItemKind.Text, row, startColumn, endColumn)); + return list.ToArray(); + } + return currentList; + } + + private static CompletionItem CreateCompletionItem( + string label, + string detail, + string insertText, + CompletionItemKind kind, + int row, + int startColumn, + int endColumn) + { + CompletionItem item = new CompletionItem() { Label = label, - Kind = CompletionItemKind.Keyword, - Detail = label + " keyword", + Kind = kind, + Detail = detail, + InsertText = insertText, TextEdit = new TextEdit { - NewText = label, + NewText = insertText, Range = new Range { Start = new Position @@ -506,6 +485,8 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } }; + + return item; } /// @@ -521,39 +502,55 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices int row, int startColumn, int endColumn) - { + { List completions = new List(); + foreach (var autoCompleteItem in suggestions) { - // convert the completion item candidates into CompletionItems - completions.Add(new CompletionItem() + string insertText = GetCompletionItemInsertName(autoCompleteItem); + CompletionItemKind kind = CompletionItemKind.Variable; + switch (autoCompleteItem.Type) { - Label = autoCompleteItem.Title, - Kind = CompletionItemKind.Variable, - Detail = autoCompleteItem.Title, - TextEdit = new TextEdit - { - NewText = autoCompleteItem.Title, - Range = new Range - { - Start = new Position - { - Line = row, - Character = startColumn - }, - End = new Position - { - Line = row, - Character = endColumn - } - } - } - }); + case DeclarationType.Schema: + kind = CompletionItemKind.Module; + break; + case DeclarationType.Column: + kind = CompletionItemKind.Field; + break; + case DeclarationType.Table: + case DeclarationType.View: + kind = CompletionItemKind.File; + break; + case DeclarationType.Database: + kind = CompletionItemKind.Method; + break; + case DeclarationType.ScalarValuedFunction: + case DeclarationType.TableValuedFunction: + case DeclarationType.BuiltInFunction: + kind = CompletionItemKind.Value; + break; + default: + kind = CompletionItemKind.Unit; + break; + } + + // convert the completion item candidates into CompletionItems + completions.Add(CreateCompletionItem(autoCompleteItem.Title, autoCompleteItem.Title, insertText, kind, row, startColumn, endColumn)); } return completions.ToArray(); } + private static string GetCompletionItemInsertName(Declaration autoCompleteItem) + { + string insertText = autoCompleteItem.Title; + if (!string.IsNullOrEmpty(autoCompleteItem.Title) && !ValidSqlNameRegex.IsMatch(autoCompleteItem.Title)) + { + insertText = string.Format(CultureInfo.InvariantCulture, "[{0}]", autoCompleteItem.Title); + } + return insertText; + } + /// /// Preinitialize the parser and binder with common metadata. /// This should front load the long binding wait to the time the @@ -572,15 +569,14 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices var scriptFile = AutoCompleteHelper.WorkspaceServiceInstance.Workspace.GetFile(info.OwnerUri); LanguageService.Instance.ParseAndBind(scriptFile, info); - if (scriptInfo.BuildingMetadataEvent.WaitOne(LanguageService.OnConnectionWaitTimeout)) + if (Monitor.TryEnter(scriptInfo.BuildingMetadataLock, LanguageService.OnConnectionWaitTimeout)) { try { - scriptInfo.BuildingMetadataEvent.Reset(); - QueueItem queueItem = bindingQueue.QueueBindingOperation( key: scriptInfo.ConnectionKey, bindingTimeout: AutoCompleteHelper.PrepopulateBindTimeout, + waitForLockTimeout: AutoCompleteHelper.PrepopulateBindTimeout, bindOperation: (bindingContext, cancelToken) => { // parse a simple statement that returns common metadata @@ -631,13 +627,61 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } finally { - scriptInfo.BuildingMetadataEvent.Set(); + Monitor.Exit(scriptInfo.BuildingMetadataLock); } } } } + /// + /// Converts a SQL Parser QuickInfo object into a VS Code Hover object + /// + /// + /// + /// + /// + internal static Hover ConvertQuickInfoToHover( + Babel.CodeObjectQuickInfo quickInfo, + int row, + int startColumn, + int endColumn) + { + // convert from the parser format to the VS Code wire format + var markedStrings = new MarkedString[1]; + if (quickInfo != null) + { + markedStrings[0] = new MarkedString() + { + Language = "SQL", + Value = quickInfo.Text + }; + + return new Hover() + { + Contents = markedStrings, + Range = new Range + { + Start = new Position + { + Line = row, + Character = startColumn + }, + End = new Position + { + Line = row, + Character = endColumn + } + } + }; + } + else + { + return null; + } + } + + /// /// Converts a SQL Parser QuickInfo object into a VS Code Hover object /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/BindingQueue.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/BindingQueue.cs index 2b165dc8..6058ec42 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/BindingQueue.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/BindingQueue.cs @@ -15,7 +15,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// Main class for the Binding Queue /// public class BindingQueue where T : IBindingContext, new() - { + { private CancellationTokenSource processQueueCancelToken = new CancellationTokenSource(); private ManualResetEvent itemQueuedEvent = new ManualResetEvent(initialState: false); @@ -61,7 +61,8 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices string key, Func bindOperation, Func timeoutOperation = null, - int? bindingTimeout = null) + int? bindingTimeout = null, + int? waitForLockTimeout = null) { // don't add null operations to the binding queue if (bindOperation == null) @@ -74,7 +75,8 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices Key = key, BindOperation = bindOperation, TimeoutOperation = timeoutOperation, - BindingTimeout = bindingTimeout + BindingTimeout = bindingTimeout, + WaitForLockTimeout = waitForLockTimeout }; lock (this.bindingQueueLock) @@ -98,7 +100,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { key = "disconnected_binding_context"; } - + lock (this.bindingContextLock) { if (!this.BindingContextMap.ContainsKey(key)) @@ -107,7 +109,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } return this.BindingContextMap[key]; - } + } } private bool HasPendingQueueItems @@ -191,19 +193,26 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices continue; } + bool lockTaken = false; try { // prefer the queue item binding item, otherwise use the context default timeout int bindTimeout = queueItem.BindingTimeout ?? bindingContext.BindingTimeout; - // handle the case a previous binding operation is still running - if (!bindingContext.BindingLocked.WaitOne(bindTimeout)) + // handle the case a previous binding operation is still running + if (!bindingContext.BindingLock.WaitOne(queueItem.WaitForLockTimeout ?? 0)) { - queueItem.Result = queueItem.TimeoutOperation(bindingContext); - queueItem.ItemProcessed.Set(); + queueItem.Result = queueItem.TimeoutOperation != null + ? queueItem.TimeoutOperation(bindingContext) + : null; + continue; } + bindingContext.BindingLock.Reset(); + + lockTaken = true; + // execute the binding operation object result = null; CancellationTokenSource cancelToken = new CancellationTokenSource(); @@ -220,13 +229,18 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices queueItem.Result = result; } else - { + { + cancelToken.Cancel(); + // if the task didn't complete then call the timeout callback if (queueItem.TimeoutOperation != null) - { - cancelToken.Cancel(); + { queueItem.Result = queueItem.TimeoutOperation(bindingContext); } + + lockTaken = false; + + bindTask.ContinueWith((a) => bindingContext.BindingLock.Set()); } } catch (Exception ex) @@ -237,7 +251,11 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } finally { - bindingContext.BindingLocked.Set(); + if (lockTaken) + { + bindingContext.BindingLock.Set(); + } + queueItem.ItemProcessed.Set(); } @@ -250,8 +268,15 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } finally { - // reset the item queued event since we've processed all the pending items - this.itemQueuedEvent.Reset(); + lock (this.bindingQueueLock) + { + // verify the binding queue is still empty + if (this.bindingQueue.Count == 0) + { + // reset the item queued event since we've processed all the pending items + this.itemQueuedEvent.Reset(); + } + } } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingContext.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingContext.cs index 8851abe1..0add6ec4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingContext.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingContext.cs @@ -21,6 +21,8 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { private ParseOptions parseOptions; + private ManualResetEvent bindingLock; + private ServerConnection serverConnection; /// @@ -28,7 +30,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// public ConnectedBindingContext() { - this.BindingLocked = new ManualResetEvent(initialState: true); + this.bindingLock = new ManualResetEvent(initialState: true); this.BindingTimeout = ConnectedBindingQueue.DefaultBindingTimeout; this.MetadataDisplayInfoProvider = new MetadataDisplayInfoProvider(); } @@ -72,9 +74,15 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices public IBinder Binder { get; set; } /// - /// Gets or sets an event to signal if a binding operation is in progress + /// Gets the binding lock object /// - public ManualResetEvent BindingLocked { get; set; } + public ManualResetEvent BindingLock + { + get + { + return this.bindingLock; + } + } /// /// Gets or sets the binding operation timeout in milliseconds diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs index c99f0cc6..965d94d2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs @@ -21,7 +21,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// public class ConnectedBindingQueue : BindingQueue { - internal const int DefaultBindingTimeout = 60000; + internal const int DefaultBindingTimeout = 500; internal const int DefaultMinimumConnectionTimeout = 30; @@ -63,22 +63,24 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices string connectionKey = GetConnectionContextKey(connInfo); IBindingContext bindingContext = this.GetOrCreateBindingContext(connectionKey); - try + if (bindingContext.BindingLock.WaitOne()) { - // increase the connection timeout to at least 30 seconds and and build connection string - // enable PersistSecurityInfo to handle issues in SMO where the connection context is lost in reconnections - int? originalTimeout = connInfo.ConnectionDetails.ConnectTimeout; - bool? originalPersistSecurityInfo = connInfo.ConnectionDetails.PersistSecurityInfo; - connInfo.ConnectionDetails.ConnectTimeout = Math.Max(DefaultMinimumConnectionTimeout, originalTimeout ?? 0); - connInfo.ConnectionDetails.PersistSecurityInfo = true; - string connectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); - connInfo.ConnectionDetails.ConnectTimeout = originalTimeout; - connInfo.ConnectionDetails.PersistSecurityInfo = originalPersistSecurityInfo; - - // open a dedicated binding server connection - SqlConnection sqlConn = new SqlConnection(connectionString); - if (sqlConn != null) + try { + bindingContext.BindingLock.Reset(); + + // increase the connection timeout to at least 30 seconds and and build connection string + // enable PersistSecurityInfo to handle issues in SMO where the connection context is lost in reconnections + int? originalTimeout = connInfo.ConnectionDetails.ConnectTimeout; + bool? originalPersistSecurityInfo = connInfo.ConnectionDetails.PersistSecurityInfo; + connInfo.ConnectionDetails.ConnectTimeout = Math.Max(DefaultMinimumConnectionTimeout, originalTimeout ?? 0); + connInfo.ConnectionDetails.PersistSecurityInfo = true; + string connectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); + connInfo.ConnectionDetails.ConnectTimeout = originalTimeout; + connInfo.ConnectionDetails.PersistSecurityInfo = originalPersistSecurityInfo; + + // open a dedicated binding server connection + SqlConnection sqlConn = new SqlConnection(connectionString); sqlConn.Open(); // populate the binding context to work with the SMO metadata provider @@ -91,16 +93,16 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices bindingContext.Binder = BinderProvider.CreateBinder(bindingContext.SmoMetadataProvider); bindingContext.ServerConnection = serverConn; bindingContext.BindingTimeout = ConnectedBindingQueue.DefaultBindingTimeout; - bindingContext.IsConnected = true; + bindingContext.IsConnected = true; } - } - catch (Exception) - { - bindingContext.IsConnected = false; - } - finally - { - bindingContext.BindingLocked.Set(); + catch (Exception) + { + bindingContext.IsConnected = false; + } + finally + { + bindingContext.BindingLock.Set(); + } } return connectionKey; diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IBindingContext.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IBindingContext.cs index c83a28d7..aa4637b3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IBindingContext.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IBindingContext.cs @@ -44,9 +44,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices IBinder Binder { get; set; } /// - /// Gets or sets an event to signal if a binding operation is in progress + /// Gets the binding lock object /// - ManualResetEvent BindingLocked { get; set; } + ManualResetEvent BindingLock { get; } /// /// Gets or sets the binding operation timeout in milliseconds diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index 889ad882..e03e748b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -37,11 +37,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices internal const int DiagnosticParseDelay = 750; - internal const int HoverTimeout = 3000; + internal const int HoverTimeout = 500; - internal const int BindingTimeout = 3000; - - internal const int FindCompletionStartTimeout = 50; + internal const int BindingTimeout = 500; internal const int OnConnectionWaitTimeout = 300000; @@ -264,10 +262,10 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices var completionItems = Instance.GetCompletionItems( textDocumentPosition, scriptFile, connInfo); - - await requestContext.SendResult(completionItems); + + await requestContext.SendResult(completionItems); + } } - } /// /// Handle the resolve completion request event to provide additional @@ -394,15 +392,15 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices SqlToolsSettings oldSettings, EventContext eventContext) { - bool oldEnableIntelliSense = oldSettings.SqlTools.EnableIntellisense; - bool? oldEnableDiagnostics = oldSettings.SqlTools.IntelliSense.EnableDiagnostics; + bool oldEnableIntelliSense = oldSettings.SqlTools.IntelliSense.EnableIntellisense; + bool? oldEnableDiagnostics = oldSettings.SqlTools.IntelliSense.EnableErrorChecking; // update the current settings to reflect any changes CurrentSettings.Update(newSettings); // if script analysis settings have changed we need to clear the current diagnostic markers - if (oldEnableIntelliSense != newSettings.SqlTools.EnableIntellisense - || oldEnableDiagnostics != newSettings.SqlTools.IntelliSense.EnableDiagnostics) + if (oldEnableIntelliSense != newSettings.SqlTools.IntelliSense.EnableIntellisense + || oldEnableDiagnostics != newSettings.SqlTools.IntelliSense.EnableErrorChecking) { // if the user just turned off diagnostics then send an event to clear the error markers if (!newSettings.IsDiagnositicsEnabled) @@ -452,12 +450,10 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices // get or create the current parse info object ScriptParseInfo parseInfo = GetScriptParseInfo(scriptFile.ClientFilePath, createIfNotExists: true); - if (parseInfo.BuildingMetadataEvent.WaitOne(LanguageService.BindingTimeout)) + if (Monitor.TryEnter(parseInfo.BuildingMetadataLock, LanguageService.BindingTimeout)) { try { - parseInfo.BuildingMetadataEvent.Reset(); - if (connInfo == null || !parseInfo.IsConnected) { // parse current SQL file contents to retrieve a list of errors @@ -518,7 +514,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } finally { - parseInfo.BuildingMetadataEvent.Set(); + Monitor.Exit(parseInfo.BuildingMetadataLock); } } else @@ -538,11 +534,10 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices await Task.Run(() => { ScriptParseInfo scriptInfo = GetScriptParseInfo(info.OwnerUri, createIfNotExists: true); - if (scriptInfo.BuildingMetadataEvent.WaitOne(LanguageService.OnConnectionWaitTimeout)) + if (Monitor.TryEnter(scriptInfo.BuildingMetadataLock, LanguageService.OnConnectionWaitTimeout)) { try - { - scriptInfo.BuildingMetadataEvent.Reset(); + { scriptInfo.ConnectionKey = this.BindingQueue.AddConnectionContext(info); scriptInfo.IsConnected = true; @@ -556,7 +551,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { // Set Metadata Build event to Signal state. // (Tell Language Service that I am ready with Metadata Provider Object) - scriptInfo.BuildingMetadataEvent.Set(); + Monitor.Exit(scriptInfo.BuildingMetadataLock); } } @@ -588,28 +583,45 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// internal CompletionItem ResolveCompletionItem(CompletionItem completionItem) { - try + var scriptParseInfo = LanguageService.Instance.currentCompletionParseInfo; + if (scriptParseInfo != null && scriptParseInfo.CurrentSuggestions != null) { - var scriptParseInfo = LanguageService.Instance.currentCompletionParseInfo; - if (scriptParseInfo != null && scriptParseInfo.CurrentSuggestions != null) + if (Monitor.TryEnter(scriptParseInfo.BuildingMetadataLock)) { - foreach (var suggestion in scriptParseInfo.CurrentSuggestions) + try { - if (string.Equals(suggestion.Title, completionItem.Label)) - { - completionItem.Detail = suggestion.DatabaseQualifiedName; - completionItem.Documentation = suggestion.Description; - break; - } + QueueItem queueItem = this.BindingQueue.QueueBindingOperation( + key: scriptParseInfo.ConnectionKey, + bindingTimeout: LanguageService.BindingTimeout, + bindOperation: (bindingContext, cancelToken) => + { + foreach (var suggestion in scriptParseInfo.CurrentSuggestions) + { + if (string.Equals(suggestion.Title, completionItem.Label)) + { + completionItem.Detail = suggestion.DatabaseQualifiedName; + completionItem.Documentation = suggestion.Description; + break; + } + } + return completionItem; + }); + + queueItem.ItemProcessed.WaitOne(); } + catch (Exception ex) + { + // if any exceptions are raised looking up extended completion metadata + // then just return the original completion item + Logger.Write(LogLevel.Error, "Exeception in ResolveCompletionItem " + ex.ToString()); + } + finally + { + Monitor.Exit(scriptParseInfo.BuildingMetadataLock); + } } } - catch (Exception ex) - { - // if any exceptions are raised looking up extended completion metadata - // then just return the original completion item - Logger.Write(LogLevel.Error, "Exeception in ResolveCompletionItem " + ex.ToString()); - } + return completionItem; } @@ -631,9 +643,8 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices ScriptParseInfo scriptParseInfo = GetScriptParseInfo(textDocumentPosition.TextDocument.Uri); if (scriptParseInfo != null && scriptParseInfo.ParseResult != null) { - if (scriptParseInfo.BuildingMetadataEvent.WaitOne(LanguageService.FindCompletionStartTimeout)) + if (Monitor.TryEnter(scriptParseInfo.BuildingMetadataLock)) { - scriptParseInfo.BuildingMetadataEvent.Reset(); try { QueueItem queueItem = this.BindingQueue.QueueBindingOperation( @@ -661,8 +672,8 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } finally { - scriptParseInfo.BuildingMetadataEvent.Set(); - } + Monitor.Exit(scriptParseInfo.BuildingMetadataLock); + } } } @@ -680,23 +691,32 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices ScriptFile scriptFile, ConnectionInfo connInfo) { + // initialize some state to parse and bind the current script file + this.currentCompletionParseInfo = null; + CompletionItem[] resultCompletionItems = null; string filePath = textDocumentPosition.TextDocument.Uri; int startLine = textDocumentPosition.Position.Line; + int parserLine = textDocumentPosition.Position.Line + 1; int startColumn = TextUtilities.PositionOfPrevDelimeter( scriptFile.Contents, textDocumentPosition.Position.Line, textDocumentPosition.Position.Character); - int endColumn = textDocumentPosition.Position.Character; + int endColumn = TextUtilities.PositionOfNextDelimeter( + scriptFile.Contents, + textDocumentPosition.Position.Line, + textDocumentPosition.Position.Character); + int parserColumn = textDocumentPosition.Position.Character + 1; bool useLowerCaseSuggestions = this.CurrentSettings.SqlTools.IntelliSense.LowerCaseSuggestions.Value; - this.currentCompletionParseInfo = null; - - // Take a reference to the list at a point in time in case we update and replace the list - + // get the current script parse info object ScriptParseInfo scriptParseInfo = GetScriptParseInfo(textDocumentPosition.TextDocument.Uri); - if (connInfo == null || scriptParseInfo == null) + if (scriptParseInfo == null) { - return AutoCompleteHelper.GetDefaultCompletionItems(startLine, startColumn, endColumn, useLowerCaseSuggestions); + return AutoCompleteHelper.GetDefaultCompletionItems( + startLine, + startColumn, + endColumn, + useLowerCaseSuggestions); } // reparse and bind the SQL statement if needed @@ -705,62 +725,128 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices ParseAndBind(scriptFile, connInfo); } + // if the parse failed then return the default list if (scriptParseInfo.ParseResult == null) { - return AutoCompleteHelper.GetDefaultCompletionItems(startLine, startColumn, endColumn, useLowerCaseSuggestions); + return AutoCompleteHelper.GetDefaultCompletionItems( + startLine, + startColumn, + endColumn, + useLowerCaseSuggestions); } + + // need to adjust line & column for base-1 parser indices + Token token = GetToken(scriptParseInfo, parserLine, parserColumn); + string tokenText = token != null ? token.Text : null; - if (scriptParseInfo.IsConnected - && scriptParseInfo.BuildingMetadataEvent.WaitOne(LanguageService.FindCompletionStartTimeout)) - { - scriptParseInfo.BuildingMetadataEvent.Reset(); - - QueueItem queueItem = this.BindingQueue.QueueBindingOperation( - key: scriptParseInfo.ConnectionKey, - bindingTimeout: LanguageService.BindingTimeout, - bindOperation: (bindingContext, cancelToken) => - { - CompletionItem[] completions = null; - try + // check if the file is connected and the file lock is available + if (scriptParseInfo.IsConnected && Monitor.TryEnter(scriptParseInfo.BuildingMetadataLock)) + { + try + { + // queue the completion task with the binding queue + QueueItem queueItem = this.BindingQueue.QueueBindingOperation( + key: scriptParseInfo.ConnectionKey, + bindingTimeout: LanguageService.BindingTimeout, + bindOperation: (bindingContext, cancelToken) => { // get the completion list from SQL Parser scriptParseInfo.CurrentSuggestions = Resolver.FindCompletions( scriptParseInfo.ParseResult, - textDocumentPosition.Position.Line + 1, - textDocumentPosition.Position.Character + 1, + parserLine, + parserColumn, bindingContext.MetadataDisplayInfoProvider); // cache the current script parse info object to resolve completions later this.currentCompletionParseInfo = scriptParseInfo; - + // convert the suggestion list to the VS Code format - completions = AutoCompleteHelper.ConvertDeclarationsToCompletionItems( + return AutoCompleteHelper.ConvertDeclarationsToCompletionItems( scriptParseInfo.CurrentSuggestions, startLine, startColumn, endColumn); - } - finally + }, + timeoutOperation: (bindingContext) => { - scriptParseInfo.BuildingMetadataEvent.Set(); - } + // return the default list if the connected bind fails + return AutoCompleteHelper.GetDefaultCompletionItems( + startLine, + startColumn, + endColumn, + useLowerCaseSuggestions, + tokenText); + }); - return completions; - }, - timeoutOperation: (bindingContext) => + // wait for the queue item + queueItem.ItemProcessed.WaitOne(); + + var completionItems = queueItem.GetResultAsT(); + if (completionItems != null && completionItems.Length > 0) { - return AutoCompleteHelper.GetDefaultCompletionItems(startLine, startColumn, endColumn, useLowerCaseSuggestions); - }); - - queueItem.ItemProcessed.WaitOne(); - var completionItems = queueItem.GetResultAsT(); - if (completionItems != null && completionItems.Length > 0) + resultCompletionItems = completionItems; + } + else if (!ShouldShowCompletionList(token)) + { + resultCompletionItems = AutoCompleteHelper.EmptyCompletionList; + } + } + finally { - return completionItems; + Monitor.Exit(scriptParseInfo.BuildingMetadataLock); } } - return AutoCompleteHelper.GetDefaultCompletionItems(startLine, startColumn, endColumn, useLowerCaseSuggestions); + // if there are no completions then provide the default list + if (resultCompletionItems == null) + { + resultCompletionItems = AutoCompleteHelper.GetDefaultCompletionItems( + startLine, + startColumn, + endColumn, + useLowerCaseSuggestions, + tokenText); + } + + return resultCompletionItems; + } + + private static Token GetToken(ScriptParseInfo scriptParseInfo, int startLine, int startColumn) + { + if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null) + { + var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); + if (tokenIndex >= 0) + { + // return the current token + int currentIndex = 0; + foreach (var token in scriptParseInfo.ParseResult.Script.Tokens) + { + if (currentIndex == tokenIndex) + { + return token; + } + ++currentIndex; + } + } + } + return null; + } + + private static bool ShouldShowCompletionList(Token token) + { + bool result = true; + if (token != null) + { + switch (token.Id) + { + case (int)Tokens.LEX_MULTILINE_COMMENT: + case (int)Tokens.LEX_END_OF_LINE_COMMENT: + result = false; + break; + } + } + return result; } #endregion @@ -897,6 +983,11 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices // Get the requested files foreach (ScriptFile scriptFile in filesToAnalyze) { + if (IsPreviewWindow(scriptFile)) + { + continue; + } + Logger.Write(LogLevel.Verbose, "Analyzing script file: " + scriptFile.FilePath); ScriptFileMarker[] semanticMarkers = GetSemanticMarkers(scriptFile); Logger.Write(LogLevel.Verbose, "Analysis complete."); diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/QueueItem.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/QueueItem.cs index adf5fa18..a320f842 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/QueueItem.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/QueueItem.cs @@ -5,7 +5,6 @@ using System; using System.Threading; -using System.Threading.Tasks; namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { @@ -52,6 +51,11 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// public int? BindingTimeout { get; set; } + /// + /// Gets or sets the timeout for how long to wait for the binding lock + /// + public int? WaitForLockTimeout { get; set; } + /// /// Converts the result of the execution to type T /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs index 2c56d497..ec87ae6f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs @@ -4,7 +4,6 @@ // using System.Collections.Generic; -using System.Threading; using Microsoft.SqlServer.Management.SqlParser.Intellisense; using Microsoft.SqlServer.Management.SqlParser.Parser; @@ -15,14 +14,14 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// internal class ScriptParseInfo { - private ManualResetEvent buildingMetadataEvent = new ManualResetEvent(initialState: true); + private object buildingMetadataLock = new object(); /// /// Event which tells if MetadataProvider is built fully or not /// - public ManualResetEvent BuildingMetadataEvent + public object BuildingMetadataLock { - get { return this.buildingMetadataEvent; } + get { return this.buildingMetadataLock; } } /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Program.cs b/src/Microsoft.SqlTools.ServiceLayer/Program.cs index 35a659fe..25bc92b1 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Program.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Program.cs @@ -23,9 +23,16 @@ namespace Microsoft.SqlTools.ServiceLayer /// internal static void Main(string[] args) { + // read command-line arguments + CommandOptions commandOptions = new CommandOptions(args); + if (commandOptions.ShouldExit) + { + return; + } + // turn on Verbose logging during early development // we need to switch to Normal when preparing for public preview - Logger.Initialize(minimumLogLevel: LogLevel.Verbose); + Logger.Initialize(minimumLogLevel: LogLevel.Verbose, isEnabled: commandOptions.EnableLogging); Logger.Write(LogLevel.Normal, "Starting SQL Tools Service Host"); // set up the host details and profile paths diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs index 51a35a7d..7b54484e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs @@ -31,7 +31,17 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution private bool disposed; /// - /// Factory for creating readers/writrs for the output of the batch + /// Local time when the execution and retrieval of files is finished + /// + private DateTime executionEndTime; + + /// + /// Local time when the execution starts, specifically when the object is created + /// + private DateTime executionStartTime; + + /// + /// Factory for creating readers/writers for the output of the batch /// private readonly IFileStreamFactory outputFileFactory; @@ -69,6 +79,30 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public string BatchText { get; set; } + /// + /// Localized timestamp for when the execution completed. + /// Stored in UTC ISO 8601 format; should be localized before displaying to any user + /// + public string ExecutionEndTimeStamp { get { return executionEndTime.ToString("o"); } } + + /// + /// Localized timestamp for how long it took for the execution to complete + /// + public string ExecutionElapsedTime + { + get + { + TimeSpan elapsedTime = executionEndTime - executionStartTime; + return elapsedTime.ToString(); + } + } + + /// + /// Localized timestamp for when the execution began. + /// Stored in UTC ISO 8601 format; should be localized before displaying to any user + /// + public string ExecutionStartTimeStamp { get { return executionStartTime.ToString("o"); } } + /// /// Whether or not this batch has an error /// @@ -90,7 +124,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// The result sets of the batch execution /// - public IEnumerable ResultSets + public IList ResultSets { get { return resultSets; } } @@ -136,14 +170,17 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution try { DbCommand command = null; - - // Register the message listener to *this instance* of the batch - // Note: This is being done to associate messages with batches ReliableSqlConnection sqlConn = conn as ReliableSqlConnection; if (sqlConn != null) { + // Register the message listener to *this instance* of the batch + // Note: This is being done to associate messages with batches sqlConn.GetUnderlyingConnection().InfoMessage += StoreDbMessage; command = sqlConn.GetUnderlyingConnection().CreateCommand(); + + // Add a handler for when the command completes + SqlCommand sqlCommand = (SqlCommand) command; + sqlCommand.StatementCompleted += StatementCompletedHandler; } else { @@ -151,7 +188,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } // Make sure we aren't using a ReliableCommad since we do not want automatic retry - Debug.Assert(!(command is ReliableSqlConnection.ReliableSqlCommand), "ReliableSqlCommand command should not be used to execute queries"); + Debug.Assert(!(command is ReliableSqlConnection.ReliableSqlCommand), + "ReliableSqlCommand command should not be used to execute queries"); // Create a command that we'll use for executing the query using (command) @@ -159,6 +197,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution command.CommandText = BatchText; command.CommandType = CommandType.Text; command.CommandTimeout = 0; + executionStartTime = DateTime.Now; // Execute the command to get back a reader using (DbDataReader reader = await command.ExecuteReaderAsync(cancellationToken)) @@ -168,24 +207,26 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // Skip this result set if there aren't any rows (ie, UPDATE/DELETE/etc queries) if (!reader.HasRows && reader.FieldCount == 0) { - // Create a message with the number of affected rows -- IF the query affects rows - resultMessages.Add(new ResultMessage(reader.RecordsAffected >= 0 - ? SR.QueryServiceAffectedRows(reader.RecordsAffected) - : SR.QueryServiceCompletedSuccessfully)); continue; } // This resultset has results (ie, SELECT/etc queries) - // Read until we hit the end of the result set ResultSet resultSet = new ResultSet(reader, outputFileFactory); - await resultSet.ReadResultToEnd(cancellationToken); - + // Add the result set to the results of the query resultSets.Add(resultSet); + + // Read until we hit the end of the result set + await resultSet.ReadResultToEnd(cancellationToken).ConfigureAwait(false); - // Add a message for the number of rows the query returned - resultMessages.Add(new ResultMessage(SR.QueryServiceAffectedRows(resultSet.RowCount))); } while (await reader.NextResultAsync(cancellationToken)); + + // If there were no messages, for whatever reason (NO COUNT set, messages + // were emitted, records returned), output a "successful" message + if (resultMessages.Count == 0) + { + resultMessages.Add(new ResultMessage(SR.QueryServiceCompletedSuccessfully)); + } } } } @@ -194,9 +235,15 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution HasError = true; UnwrapDbException(dbe); } - catch (Exception) + catch (TaskCanceledException) + { + resultMessages.Add(new ResultMessage(SR.QueryServiceQueryCancelled)); + throw; + } + catch (Exception e) { HasError = true; + resultMessages.Add(new ResultMessage(SR.QueryServiceQueryFailed(e.Message))); throw; } finally @@ -210,6 +257,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // Mark that we have executed HasExecuted = true; + executionEndTime = DateTime.Now; } } @@ -236,6 +284,29 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #region Private Helpers + /// + /// Handler for when the StatementCompleted event is fired for this batch's command. This + /// will be executed ONLY when there is a rowcount to report. If this event is not fired + /// either NOCOUNT has been set or the command doesn't affect records. + /// + /// Sender of the event + /// Arguments for the event + private void StatementCompletedHandler(object sender, StatementCompletedEventArgs args) + { + // Add a message for the number of rows the query returned + string message; + if (args.RecordCount == 1) + { + message = SR.QueryServiceAffectedOneRow; + } + else + { + message = SR.QueryServiceAffectedRows(args.RecordCount); + } + + resultMessages.Add(new ResultMessage(message)); + } + /// /// Delegate handler for storing messages that are returned from the server /// NOTE: Only messages that are below a certain severity will be returned via this @@ -260,15 +331,23 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution SqlException se = dbe as SqlException; if (se != null) { - foreach (var error in se.Errors) + var errors = se.Errors.Cast().ToList(); + // Detect user cancellation errors + if (errors.Any(error => error.Class == 11 && error.Number == 0)) { - SqlError sqlError = error as SqlError; - if (sqlError != null) + // User cancellation error, add the single message + HasError = false; + resultMessages.Add(new ResultMessage(SR.QueryServiceQueryCancelled)); + } + else + { + // Not a user cancellation error, add all + foreach (var error in errors) { - int lineNumber = sqlError.LineNumber + Selection.StartLine; + int lineNumber = error.LineNumber + Selection.StartLine; string message = string.Format("Msg {0}, Level {1}, State {2}, Line {3}{4}{5}", - sqlError.Number, sqlError.Class, sqlError.State, lineNumber, - Environment.NewLine, sqlError.Message); + error.Number, error.Class, error.State, lineNumber, + Environment.NewLine, error.Message); resultMessages.Add(new ResultMessage(message)); } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/BatchSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/BatchSummary.cs index 7e1b2837..884d76f6 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/BatchSummary.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/BatchSummary.cs @@ -10,6 +10,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// public class BatchSummary { + /// + /// Localized timestamp for how long it took for the execution to complete + /// + public string ExecutionElapsed { get; set; } + + /// + /// Localized timestamp for when the execution completed. + /// + public string ExecutionEnd { get; set; } + + /// + /// Localized timestamp for when the execution started. + /// + public string ExecutionStart { get; set; } + /// /// Whether or not the batch was successful. True indicates errors, false indicates success /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs index 90c8c7b3..8375235a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs @@ -21,6 +21,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// Summaries of the result sets that were returned with the query /// public BatchSummary[] BatchSummaries { get; set; } + + /// + /// Error message, if any + /// + public string Message { get; set; } } public class QueryExecuteCompleteEvent diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs index 1cf2390e..85f87b9d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs @@ -60,17 +60,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// Parameters to save results as CSV /// public class SaveResultsAsCsvRequestParams: SaveResultsRequestParams{ - - /// - /// CSV - Write values in quotes - /// - public Boolean ValueInQuotes { get; set; } - - /// - /// The encoding of the file to save results in - /// - public string FileEncoding { get; set; } - /// /// Include headers of columns in CSV /// @@ -95,6 +84,17 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts public string Messages { get; set; } } + /// + /// Error object for save result + /// + public class SaveResultRequestError + { + /// + /// Error message + /// + public string message { get; set; } + } + /// /// Request type to save results as CSV /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/FileUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/FileUtils.cs new file mode 100644 index 00000000..d795c2c2 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/FileUtils.cs @@ -0,0 +1,46 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.IO; +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + internal static class FileUtils + { + /// + /// Checks if file exists and swallows exceptions, if any + /// + /// path of the file + /// + internal static bool SafeFileExists(string path) + { + try + { + return File.Exists(path); + } + catch (Exception) + { + // Swallow exception + return false; + } + } + + /// + /// Deletes a file and swallows exceptions, if any + /// + /// + internal static void SafeFileDelete(string path) + { + try + { + File.Delete(path); + } + catch (Exception) + { + // Swallow exception, do nothing + } + } + + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index cf2df73c..d47fafca 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -96,6 +96,33 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #region Properties + /// + /// Delegate type for callback when a query completes or fails + /// + /// The query that completed + public delegate Task QueryAsyncEventHandler(Query q); + + /// + /// Delegate type for callback when a query connection fails + /// + /// The query that completed + public delegate Task QueryAsyncErrorEventHandler(string message); + + /// + /// Callback for when the query has completed successfully + /// + public event QueryAsyncEventHandler QueryCompleted; + + /// + /// Callback for when the query has failed + /// + public event QueryAsyncEventHandler QueryFailed; + + /// + /// Callback for when the query connection has failed + /// + public event QueryAsyncErrorEventHandler QueryConnectionException; + /// /// The batches underneath this query /// @@ -116,6 +143,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return Batches.Select((batch, index) => new BatchSummary { Id = index, + ExecutionStart = batch.ExecutionStartTimeStamp, + ExecutionEnd = batch.ExecutionEndTimeStamp, + ExecutionElapsed = batch.ExecutionElapsedTime, HasError = batch.HasError, Messages = batch.ResultMessages.ToArray(), ResultSetSummaries = batch.ResultSummaries, @@ -124,6 +154,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } + internal Task ExecutionTask { get; private set; } + /// /// Whether or not the query has completed executed, regardless of success or failure /// @@ -167,10 +199,44 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution cancellationSource.Cancel(); } + public void Execute() + { + ExecutionTask = Task.Run(ExecuteInternal); + } + + /// + /// Retrieves a subset of the result sets + /// + /// The index for selecting the batch item + /// The index for selecting the result set + /// The starting row of the results + /// How many rows to retrieve + /// A subset of results + public Task GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount) + { + // Sanity check that the results are available + if (!HasExecuted) + { + throw new InvalidOperationException(SR.QueryServiceSubsetNotCompleted); + } + + // Sanity check to make sure that the batch is within bounds + if (batchIndex < 0 || batchIndex >= Batches.Length) + { + throw new ArgumentOutOfRangeException(nameof(batchIndex), SR.QueryServiceSubsetBatchOutOfRange); + } + + return Batches[batchIndex].GetSubset(resultSetIndex, startRow, rowCount); + } + + #endregion + + #region Private Helpers + /// /// Executes this query asynchronously and collects all result sets /// - public async Task Execute() + private async Task ExecuteInternal() { // Mark that we've internally executed hasExecuteBeenCalled = true; @@ -186,7 +252,19 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // TODO: Don't create a new connection every time, see TFS #834978 using (DbConnection conn = editorConnection.Factory.CreateSqlConnection(connectionString)) { - await conn.OpenAsync(); + try + { + await conn.OpenAsync(); + } + catch(Exception exception) + { + this.HasExecuted = true; + if (QueryConnectionException != null) + { + await QueryConnectionException(exception.Message); + } + return; + } ReliableSqlConnection sqlConn = conn as ReliableSqlConnection; if (sqlConn != null) @@ -202,6 +280,20 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { await b.Execute(conn, cancellationSource.Token); } + + // Call the query execution callback + if (QueryCompleted != null) + { + await QueryCompleted(this); + } + } + catch (Exception) + { + // Call the query failure callback + if (QueryFailed != null) + { + await QueryFailed(this); + } } finally { @@ -227,7 +319,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution throw new InvalidOperationException(SR.QueryServiceMessageSenderNotSql); } - foreach(SqlError error in args.Errors) + foreach (SqlError error in args.Errors) { // Did the database context change (error code 5701)? if (error.Number == DatabaseContextChangeErrorNumber) @@ -237,31 +329,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } - /// - /// Retrieves a subset of the result sets - /// - /// The index for selecting the batch item - /// The index for selecting the result set - /// The starting row of the results - /// How many rows to retrieve - /// A subset of results - public Task GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount) - { - // Sanity check that the results are available - if (!HasExecuted) - { - throw new InvalidOperationException(SR.QueryServiceSubsetNotCompleted); - } - - // Sanity check to make sure that the batch is within bounds - if (batchIndex < 0 || batchIndex >= Batches.Length) - { - throw new ArgumentOutOfRangeException(nameof(batchIndex), SR.QueryServiceSubsetBatchOutOfRange); - } - - return Batches[batchIndex].GetSubset(resultSetIndex, startRow, rowCount); - } - #endregion #region IDisposable Implementation diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index d0ff8d1a..c7614596 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -4,7 +4,6 @@ // using System; using System.Collections.Concurrent; -using System.IO; using System.Linq; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; @@ -16,7 +15,6 @@ using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -using Newtonsoft.Json; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { @@ -129,19 +127,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution public async Task HandleExecuteRequest(QueryExecuteParams executeParams, RequestContext requestContext) { - try - { - // Get a query new active query - Query newQuery = await CreateAndActivateNewQuery(executeParams, requestContext); + // Get a query new active query + Query newQuery = await CreateAndActivateNewQuery(executeParams, requestContext); - // Execute the query - await ExecuteAndCompleteQuery(executeParams, requestContext, newQuery); - } - catch (Exception e) - { - // Dump any unexpected exceptions as errors - await requestContext.SendError(e.Message); - } + // Execute the query -- asynchronously + await ExecuteAndCompleteQuery(executeParams, requestContext, newQuery); } public async Task HandleResultSubsetRequest(QueryExecuteSubsetParams subsetParams, @@ -239,21 +229,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return; } - // Cancel the query + // Cancel the query and send a success message result.Cancel(); - result.Dispose(); - - // Attempt to dispose the query - if (!ActiveQueries.TryRemove(cancelParams.OwnerUri, out result)) - { - // It really shouldn't be possible to get to this scenario, but we'll cover it anyhow - await requestContext.SendResult(new QueryCancelResult - { - Messages = SR.QueryServiceCancelDisposeFailed - }); - return; - } - await requestContext.SendResult(new QueryCancelResult()); } catch (InvalidOperationException e) @@ -273,7 +250,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// Process request to save a resultSet to a file in CSV format /// - public async Task HandleSaveResultsAsCsvRequest(SaveResultsAsCsvRequestParams saveParams, + internal async Task HandleSaveResultsAsCsvRequest(SaveResultsAsCsvRequestParams saveParams, RequestContext requestContext) { // retrieve query for OwnerUri @@ -286,67 +263,39 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution }); return; } - try + + + ResultSet selectedResultSet = result.Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + if (!selectedResultSet.IsBeingDisposed) { - using (StreamWriter csvFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create))) + // Create SaveResults object and add success and error handlers to respective events + SaveResults saveAsCsv = new SaveResults(); + + SaveResults.AsyncSaveEventHandler successHandler = async message => { - // get the requested resultSet from query - Batch selectedBatch = result.Batches[saveParams.BatchIndex]; - ResultSet selectedResultSet = (selectedBatch.ResultSets.ToList())[saveParams.ResultSetIndex]; - int columnCount = 0; - int rowCount = 0; - int columnStartIndex = 0; - int rowStartIndex = 0; - - // set column, row counts depending on whether save request is for entire result set or a subset - if (SaveResults.isSaveSelection(saveParams)) - { - columnCount = saveParams.ColumnEndIndex.Value - saveParams.ColumnStartIndex.Value + 1; - rowCount = saveParams.RowEndIndex.Value - saveParams.RowStartIndex.Value + 1; - columnStartIndex = saveParams.ColumnStartIndex.Value; - rowStartIndex =saveParams.RowStartIndex.Value; - } - else - { - columnCount = selectedResultSet.Columns.Length; - rowCount = (int)selectedResultSet.RowCount; - } - - // write column names if include headers option is chosen - if (saveParams.IncludeHeaders) - { - await csvFile.WriteLineAsync( string.Join( ",", selectedResultSet.Columns.Skip(columnStartIndex).Take(columnCount).Select( column => - SaveResults.EncodeCsvField(column.ColumnName) ?? string.Empty))); - } - - // retrieve rows and write as csv - ResultSetSubset resultSubset = await result.GetSubset(saveParams.BatchIndex, saveParams.ResultSetIndex, rowStartIndex, rowCount); - foreach (var row in resultSubset.Rows) - { - await csvFile.WriteLineAsync( string.Join( ",", row.Skip(columnStartIndex).Take(columnCount).Select( field => - SaveResults.EncodeCsvField((field != null) ? field.ToString(): "NULL")))); - } - - } - - // Successfully wrote file, send success result - await requestContext.SendResult(new SaveResultRequestResult { Messages = null }); - } - catch(Exception ex) - { - // Delete file when exception occurs - if (File.Exists(saveParams.FilePath)) + selectedResultSet.RemoveSaveTask(saveParams.FilePath); + await requestContext.SendResult(new SaveResultRequestResult { Messages = message }); + }; + saveAsCsv.SaveCompleted += successHandler; + SaveResults.AsyncSaveEventHandler errorHandler = async message => { - File.Delete(saveParams.FilePath); - } - await requestContext.SendError(ex.Message); + selectedResultSet.RemoveSaveTask(saveParams.FilePath); + await requestContext.SendError(new SaveResultRequestError { message = message }); + }; + saveAsCsv.SaveFailed += errorHandler; + + saveAsCsv.SaveResultSetAsCsv(saveParams, requestContext, result); + + // Associate the ResultSet with the save task + selectedResultSet.AddSaveTask(saveParams.FilePath, saveAsCsv.SaveTask); + } } /// /// Process request to save a resultSet to a file in JSON format /// - public async Task HandleSaveResultsAsJsonRequest(SaveResultsAsJsonRequestParams saveParams, + internal async Task HandleSaveResultsAsJsonRequest(SaveResultsAsJsonRequestParams saveParams, RequestContext requestContext) { // retrieve query for OwnerUri @@ -359,73 +308,31 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution }); return; } - try + + ResultSet selectedResultSet = result.Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + if (!selectedResultSet.IsBeingDisposed) { - using (StreamWriter jsonFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create))) - using (JsonWriter jsonWriter = new JsonTextWriter(jsonFile) ) + // Create SaveResults object and add success and error handlers to respective events + SaveResults saveAsJson = new SaveResults(); + SaveResults.AsyncSaveEventHandler successHandler = async message => { - jsonWriter.Formatting = Formatting.Indented; - jsonWriter.WriteStartArray(); - - // get the requested resultSet from query - Batch selectedBatch = result.Batches[saveParams.BatchIndex]; - ResultSet selectedResultSet = selectedBatch.ResultSets.ToList()[saveParams.ResultSetIndex]; - int rowCount = 0; - int rowStartIndex = 0; - int columnStartIndex = 0; - int columnEndIndex = 0; - - // set column, row counts depending on whether save request is for entire result set or a subset - if (SaveResults.isSaveSelection(saveParams)) - { - - rowCount = saveParams.RowEndIndex.Value - saveParams.RowStartIndex.Value + 1; - rowStartIndex = saveParams.RowStartIndex.Value; - columnStartIndex = saveParams.ColumnStartIndex.Value; - columnEndIndex = saveParams.ColumnEndIndex.Value + 1 ; // include the last column - } - else - { - rowCount = (int)selectedResultSet.RowCount; - columnEndIndex = selectedResultSet.Columns.Length; - } - - // retrieve rows and write as json - ResultSetSubset resultSubset = await result.GetSubset(saveParams.BatchIndex, saveParams.ResultSetIndex, rowStartIndex, rowCount); - foreach (var row in resultSubset.Rows) - { - jsonWriter.WriteStartObject(); - for (int i = columnStartIndex ; i < columnEndIndex; i++) - { - //get column name - DbColumnWrapper col = selectedResultSet.Columns[i]; - string val = row[i]?.ToString(); - jsonWriter.WritePropertyName(col.ColumnName); - if (val == null) - { - jsonWriter.WriteNull(); - } - else - { - jsonWriter.WriteValue(val); - } - } - jsonWriter.WriteEndObject(); - } - jsonWriter.WriteEndArray(); - } - - await requestContext.SendResult(new SaveResultRequestResult { Messages = null }); - } - catch(Exception ex) - { - // Delete file when exception occurs - if (File.Exists(saveParams.FilePath)) + selectedResultSet.RemoveSaveTask(saveParams.FilePath); + await requestContext.SendResult(new SaveResultRequestResult { Messages = message }); + }; + saveAsJson.SaveCompleted += successHandler; + SaveResults.AsyncSaveEventHandler errorHandler = async message => { - File.Delete(saveParams.FilePath); - } - await requestContext.SendError(ex.Message); + selectedResultSet.RemoveSaveTask(saveParams.FilePath); + await requestContext.SendError(new SaveResultRequestError { message = message }); + }; + saveAsJson.SaveFailed += errorHandler; + + saveAsJson.SaveResultSetAsJson(saveParams, requestContext, result); + + // Associate the ResultSet with the save task + selectedResultSet.AddSaveTask(saveParams.FilePath, saveAsJson.SaveTask); } + } #endregion @@ -440,10 +347,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution ConnectionInfo connectionInfo; if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connectionInfo)) { - await requestContext.SendResult(new QueryExecuteResult - { - Messages = SR.QueryServiceQueryInvalidOwnerUri - }); + await requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri); return null; } @@ -463,49 +367,47 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution string queryText; - if (executeParams.QuerySelection != null) + if (executeParams.QuerySelection != null) { string[] queryTextArray = queryFile.GetLinesInRange( new BufferRange( new BufferPosition( - executeParams.QuerySelection.StartLine + 1, + executeParams.QuerySelection.StartLine + 1, executeParams.QuerySelection.StartColumn + 1 - ), + ), new BufferPosition( - executeParams.QuerySelection.EndLine + 1, + executeParams.QuerySelection.EndLine + 1, executeParams.QuerySelection.EndColumn + 1 ) ) ); queryText = queryTextArray.Aggregate((a, b) => a + '\r' + '\n' + b); - } - else + } + else { queryText = queryFile.Contents; } - + // If we can't add the query now, it's assumed the query is in progress Query newQuery = new Query(queryText, connectionInfo, settings, BufferFileFactory); if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery)) { - await requestContext.SendResult(new QueryExecuteResult - { - Messages = SR.QueryServiceQueryInProgress - }); + await requestContext.SendError(SR.QueryServiceQueryInProgress); + newQuery.Dispose(); return null; } return newQuery; } - catch (ArgumentException ane) + catch (Exception e) { - await requestContext.SendResult(new QueryExecuteResult { Messages = ane.Message }); + await requestContext.SendError(e.Message); return null; } // Any other exceptions will fall through here and be collected at the end } - private async Task ExecuteAndCompleteQuery(QueryExecuteParams executeParams, RequestContext requestContext, Query query) + private static async Task ExecuteAndCompleteQuery(QueryExecuteParams executeParams, RequestContext requestContext, Query query) { // Skip processing if the query is null if (query == null) @@ -513,21 +415,41 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return; } - // Launch the query and respond with successfully launching it - Task executeTask = query.Execute(); + // Setup the query completion/failure callbacks + Query.QueryAsyncEventHandler callback = async q => + { + // Send back the results + QueryExecuteCompleteParams eventParams = new QueryExecuteCompleteParams + { + OwnerUri = executeParams.OwnerUri, + BatchSummaries = q.BatchSummaries + }; + await requestContext.SendEvent(QueryExecuteCompleteEvent.Type, eventParams); + }; + + Query.QueryAsyncErrorEventHandler errorCallback = async errorMessage => + { + // Send back the error message + QueryExecuteCompleteParams eventParams = new QueryExecuteCompleteParams + { + OwnerUri = executeParams.OwnerUri, + Message = errorMessage + }; + await requestContext.SendEvent(QueryExecuteCompleteEvent.Type, eventParams); + }; + + query.QueryCompleted += callback; + query.QueryFailed += callback; + query.QueryConnectionException += errorCallback; + + // Launch this as an asynchronous task + query.Execute(); + + // Send back a result showing we were successful await requestContext.SendResult(new QueryExecuteResult { Messages = null }); - - // Wait for query execution and then send back the results - await Task.WhenAll(executeTask); - QueryExecuteCompleteParams eventParams = new QueryExecuteCompleteParams - { - OwnerUri = executeParams.OwnerUri, - BatchSummaries = query.BatchSummaries - }; - await requestContext.SendEvent(QueryExecuteCompleteEvent.Type, eventParams); } #endregion diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs index ad392cf7..a96de759 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -4,9 +4,11 @@ // using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Data.Common; using System.Linq; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; @@ -43,22 +45,31 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// private readonly IFileStreamFactory fileStreamFactory; - /// - /// File stream reader that will be reused to make rapid-fire retrieval of result subsets - /// quick and low perf impact. - /// - private IFileStreamReader fileStreamReader; - /// /// Whether or not the result set has been read in from the database /// private bool hasBeenRead; + /// + /// Whether resultSet is a 'for xml' or 'for json' result + /// + private bool isSingleColumnXmlJsonResultSet; + /// /// The name of the temporary file we're using to output these results in /// private readonly string outputFileName; + /// + /// Whether the resultSet is in the process of being disposed + /// + private bool isBeingDisposed; + + /// + /// All save tasks currently saving this ResultSet + /// + private ConcurrentDictionary saveTasks; + #endregion /// @@ -80,10 +91,23 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // Store the factory fileStreamFactory = factory; hasBeenRead = false; + saveTasks = new ConcurrentDictionary(); } #region Properties + /// + /// Whether the resultSet is in the process of being disposed + /// + /// + internal bool IsBeingDisposed + { + get + { + return isBeingDisposed; + } + } + /// /// The columns for this result set /// @@ -114,18 +138,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public long RowCount { get; private set; } - /// - /// The rows of this result set - /// - public IEnumerable Rows - { - get - { - return FileOffsets.Select( - offset => fileStreamReader.ReadRow(offset, Columns).Select(cell => cell.DisplayValue).ToArray()); - } - } - #endregion #region Public Methods @@ -139,7 +151,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution public Task GetSubset(int startRow, int rowCount) { // Sanity check to make sure that the results have been read beforehand - if (!hasBeenRead || fileStreamReader == null) + if (!hasBeenRead) { throw new InvalidOperationException(SR.QueryServiceResultSetNotRead); } @@ -156,14 +168,32 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return Task.Factory.StartNew(() => { - // Figure out which rows we need to read back - IEnumerable rowOffsets = FileOffsets.Skip(startRow).Take(rowCount); - // Iterate over the rows we need and process them into output - string[][] rows = rowOffsets.Select(rowOffset => - fileStreamReader.ReadRow(rowOffset, Columns).Select(cell => cell.DisplayValue).ToArray()) - .ToArray(); + string[][] rows; + using (IFileStreamReader fileStreamReader = fileStreamFactory.GetReader(outputFileName)) + { + // If result set is 'for xml' or 'for json', + // Concatenate all the rows together into one row + if (isSingleColumnXmlJsonResultSet) + { + // Iterate over all the rows and process them into a list of string builders + IEnumerable rowValues = FileOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns)[0].DisplayValue); + rows = new[] { new[] { string.Join(string.Empty, rowValues) } }; + + } + else + { + // Figure out which rows we need to read back + IEnumerable rowOffsets = FileOffsets.Skip(startRow).Take(rowCount); + + // Iterate over the rows we need and process them into output + rows = rowOffsets.Select(rowOffset => + fileStreamReader.ReadRow(rowOffset, Columns).Select(cell => cell.DisplayValue).ToArray()) + .ToArray(); + + } + } // Retrieve the subset of the results as per the request return new ResultSetSubset { @@ -179,6 +209,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// Cancellation token for cancelling the query public async Task ReadResultToEnd(CancellationToken cancellationToken) { + // Mark that result has been read + hasBeenRead = true; + // Open a writer for the file using (IFileStreamWriter fileWriter = fileStreamFactory.GetWriter(outputFileName, MaxCharsToStore, MaxXmlCharsToStore)) { @@ -199,10 +232,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } // Check if resultset is 'for xml/json'. If it is, set isJson/isXml value in column metadata SingleColumnXmlJsonResultSet(); - - // Mark that result has been read - hasBeenRead = true; - fileStreamReader = fileStreamFactory.GetReader(outputFileName); } #endregion @@ -222,13 +251,31 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return; } - if (disposing) + isBeingDisposed = true; + // Check if saveTasks are running for this ResultSet + if (!saveTasks.IsEmpty) { - fileStreamReader?.Dispose(); - fileStreamFactory.DisposeFile(outputFileName); + // Wait for tasks to finish before disposing ResultSet + Task.WhenAll(saveTasks.Values.ToArray()).ContinueWith((antecedent) => + { + if (disposing) + { + fileStreamFactory.DisposeFile(outputFileName); + } + disposed = true; + isBeingDisposed = false; + }); + } + else + { + // If saveTasks is empty, continue with dispose + if (disposing) + { + fileStreamFactory.DisposeFile(outputFileName); + } + disposed = true; + isBeingDisposed = false; } - - disposed = true; } #endregion @@ -243,19 +290,43 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// private void SingleColumnXmlJsonResultSet() { - if (Columns?.Length == 1) + if (Columns?.Length == 1 && RowCount != 0) { if (Columns[0].ColumnName.Equals(NameOfForXMLColumn, StringComparison.Ordinal)) { Columns[0].IsXml = true; + isSingleColumnXmlJsonResultSet = true; + RowCount = 1; } else if (Columns[0].ColumnName.Equals(NameOfForJSONColumn, StringComparison.Ordinal)) { Columns[0].IsJson = true; + isSingleColumnXmlJsonResultSet = true; + RowCount = 1; } } } #endregion + + #region Internal Methods to Add and Remove save tasks + internal void AddSaveTask(string key, Task saveTask) + { + saveTasks.TryAdd(key, saveTask); + } + + internal void RemoveSaveTask(string key) + { + Task completedTask; + saveTasks.TryRemove(key, out completedTask); + } + + internal Task GetSaveTask(string key) + { + Task completedTask; + saveTasks.TryRemove(key, out completedTask); + return completedTask; + } + #endregion } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs index 9188d5b0..713eb764 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs @@ -3,15 +3,47 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // using System; +using System.IO; +using System.Linq; using System.Text; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Newtonsoft.Json; + namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { - internal class SaveResults{ + internal class SaveResults + { + /// + /// Number of rows being read from the ResultSubset in one read + /// + private const int BatchSize = 1000; + + /// + /// Save Task that asynchronously writes ResultSet to file + /// + internal Task SaveTask { get; set; } + + /// + /// Event Handler for save events + /// + /// Message to be returned to client + /// + internal delegate Task AsyncSaveEventHandler(string message); + + /// + /// A successful save event + /// + internal event AsyncSaveEventHandler SaveCompleted; + + /// + /// A failed save event + /// + internal event AsyncSaveEventHandler SaveFailed; /// Method ported from SSMS - /// /// Encodes a single field for inserting into a CSV record. The following rules are applied: /// @@ -32,7 +64,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution internal static String EncodeCsvField(String field) { StringBuilder sbField = new StringBuilder(field); - + //Whether this field has special characters which require it to be embedded in quotes bool embedInQuotes = false; @@ -67,12 +99,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } } - + //Replace all quotes in the original field with double quotes sbField.Replace("\"", "\"\""); String ret = sbField.ToString(); - + if (embedInQuotes) { ret = "\"" + ret + "\""; @@ -81,11 +113,208 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return ret; } - internal static bool isSaveSelection(SaveResultsRequestParams saveParams) + /// + /// Check if request is a subset of result set or whole result set + /// + /// Parameters from the request + /// + internal static bool IsSaveSelection(SaveResultsRequestParams saveParams) { return (saveParams.ColumnStartIndex != null && saveParams.ColumnEndIndex != null && saveParams.RowEndIndex != null && saveParams.RowEndIndex != null); } + + /// + /// Save results as JSON format to the file specified in saveParams + /// + /// Parameters from the request + /// Request context for save results + /// Result query object + /// + internal void SaveResultSetAsJson(SaveResultsAsJsonRequestParams saveParams, RequestContext requestContext, Query result) + { + // Run in a separate thread + SaveTask = Task.Run(async () => + { + try + { + using (StreamWriter jsonFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create, FileAccess.ReadWrite, FileShare.Read))) + using (JsonWriter jsonWriter = new JsonTextWriter(jsonFile)) + { + + int rowCount = 0; + int rowStartIndex = 0; + int columnStartIndex = 0; + int columnEndIndex = 0; + + jsonWriter.Formatting = Formatting.Indented; + jsonWriter.WriteStartArray(); + + // Get the requested resultSet from query + Batch selectedBatch = result.Batches[saveParams.BatchIndex]; + ResultSet selectedResultSet = selectedBatch.ResultSets[saveParams.ResultSetIndex]; + + // Set column, row counts depending on whether save request is for entire result set or a subset + if (IsSaveSelection(saveParams)) + { + + rowCount = saveParams.RowEndIndex.Value - saveParams.RowStartIndex.Value + 1; + rowStartIndex = saveParams.RowStartIndex.Value; + columnStartIndex = saveParams.ColumnStartIndex.Value; + columnEndIndex = saveParams.ColumnEndIndex.Value + 1; // include the last column + } + else + { + rowCount = (int)selectedResultSet.RowCount; + columnEndIndex = selectedResultSet.Columns.Length; + } + + // Split rows into batches + for (int count = 0; count < (rowCount / BatchSize) + 1; count++) + { + int numberOfRows = (count < rowCount / BatchSize) ? BatchSize : (rowCount % BatchSize); + if (numberOfRows == 0) + { + break; + } + + // Retrieve rows and write as json + ResultSetSubset resultSubset = await result.GetSubset(saveParams.BatchIndex, saveParams.ResultSetIndex, rowStartIndex + count * BatchSize, numberOfRows); + foreach (var row in resultSubset.Rows) + { + jsonWriter.WriteStartObject(); + for (int i = columnStartIndex; i < columnEndIndex; i++) + { + // Write columnName, value pair + DbColumnWrapper col = selectedResultSet.Columns[i]; + string val = row[i]?.ToString(); + jsonWriter.WritePropertyName(col.ColumnName); + if (val == null) + { + jsonWriter.WriteNull(); + } + else + { + jsonWriter.WriteValue(val); + } + } + jsonWriter.WriteEndObject(); + } + + } + jsonWriter.WriteEndArray(); + } + + // Successfully wrote file, send success result + if (SaveCompleted != null) + { + await SaveCompleted(null); + } + + + } + catch (Exception ex) + { + // Delete file when exception occurs + if (FileUtils.SafeFileExists(saveParams.FilePath)) + { + FileUtils.SafeFileDelete(saveParams.FilePath); + } + if (SaveFailed != null) + { + await SaveFailed(ex.Message); + } + } + }); + } + + /// + /// Save results as CSV format to the file specified in saveParams + /// + /// Parameters from the request + /// Request context for save results + /// Result query object + /// + internal void SaveResultSetAsCsv(SaveResultsAsCsvRequestParams saveParams, RequestContext requestContext, Query result) + { + // Run in a separate thread + SaveTask = Task.Run(async () => + { + try + { + using (StreamWriter csvFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create, FileAccess.ReadWrite, FileShare.Read))) + { + ResultSetSubset resultSubset; + int columnCount = 0; + int rowCount = 0; + int columnStartIndex = 0; + int rowStartIndex = 0; + + // Get the requested resultSet from query + Batch selectedBatch = result.Batches[saveParams.BatchIndex]; + ResultSet selectedResultSet = (selectedBatch.ResultSets)[saveParams.ResultSetIndex]; + // Set column, row counts depending on whether save request is for entire result set or a subset + if (IsSaveSelection(saveParams)) + { + columnCount = saveParams.ColumnEndIndex.Value - saveParams.ColumnStartIndex.Value + 1; + rowCount = saveParams.RowEndIndex.Value - saveParams.RowStartIndex.Value + 1; + columnStartIndex = saveParams.ColumnStartIndex.Value; + rowStartIndex = saveParams.RowStartIndex.Value; + } + else + { + columnCount = selectedResultSet.Columns.Length; + rowCount = (int)selectedResultSet.RowCount; + } + + // Write column names if include headers option is chosen + if (saveParams.IncludeHeaders) + { + csvFile.WriteLine(string.Join(",", selectedResultSet.Columns.Skip(columnStartIndex).Take(columnCount).Select(column => + EncodeCsvField(column.ColumnName) ?? string.Empty))); + } + + for (int i = 0; i < (rowCount / BatchSize) + 1; i++) + { + int numberOfRows = (i < rowCount / BatchSize) ? BatchSize : (rowCount % BatchSize); + if (numberOfRows == 0) + { + break; + } + // Retrieve rows and write as csv + resultSubset = await result.GetSubset(saveParams.BatchIndex, saveParams.ResultSetIndex, rowStartIndex + i * BatchSize, numberOfRows); + + foreach (var row in resultSubset.Rows) + { + csvFile.WriteLine(string.Join(",", row.Skip(columnStartIndex).Take(columnCount).Select(field => + EncodeCsvField((field != null) ? field.ToString() : "NULL")))); + } + } + } + + // Successfully wrote file, send success result + if (SaveCompleted != null) + { + await SaveCompleted(null); + } + } + catch (Exception ex) + { + // Delete file when exception occurs + if (FileUtils.SafeFileExists(saveParams.FilePath)) + { + FileUtils.SafeFileDelete(saveParams.FilePath); + } + + if (SaveFailed != null) + { + await SaveFailed(ex.Message); + } + } + }); + } + + } } \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/IntelliSenseSettings.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/IntelliSenseSettings.cs index f46d3556..9d6caab3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/IntelliSenseSettings.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/IntelliSenseSettings.cs @@ -15,12 +15,19 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// public IntelliSenseSettings() { + this.EnableIntellisense = true; this.EnableSuggestions = true; this.LowerCaseSuggestions = false; - this.EnableDiagnostics = true; + this.EnableErrorChecking = true; this.EnableQuickInfo = true; } + /// + /// Gets or sets a flag determining if IntelliSense is enabled + /// + /// + public bool EnableIntellisense { get; set; } + /// /// Gets or sets a flag determining if suggestions are enabled /// @@ -35,7 +42,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// /// Gets or sets a flag determining if diagnostics are enabled /// - public bool? EnableDiagnostics { get; set; } + public bool? EnableErrorChecking { get; set; } /// /// Gets or sets a flag determining if quick info is enabled @@ -52,7 +59,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext { this.EnableSuggestions = settings.EnableSuggestions; this.LowerCaseSuggestions = settings.LowerCaseSuggestions; - this.EnableDiagnostics = settings.EnableDiagnostics; + this.EnableErrorChecking = settings.EnableErrorChecking; this.EnableQuickInfo = settings.EnableQuickInfo; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs index cfa438c0..37c35ebf 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs @@ -3,6 +3,8 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using Newtonsoft.Json; + namespace Microsoft.SqlTools.ServiceLayer.SqlContext { /// @@ -15,6 +17,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// /// Gets or sets the underlying settings value object /// + [JsonProperty("mssql")] public SqlToolsSettingsValues SqlTools { get @@ -47,7 +50,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext { if (settings != null) { - this.SqlTools.EnableIntellisense = settings.SqlTools.EnableIntellisense; + this.SqlTools.IntelliSense.EnableIntellisense = settings.SqlTools.IntelliSense.EnableIntellisense; this.SqlTools.IntelliSense.Update(settings.SqlTools.IntelliSense); } } @@ -59,8 +62,8 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext { get { - return this.SqlTools.EnableIntellisense - && this.SqlTools.IntelliSense.EnableDiagnostics.Value; + return this.SqlTools.IntelliSense.EnableIntellisense + && this.SqlTools.IntelliSense.EnableErrorChecking.Value; } } @@ -71,7 +74,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext { get { - return this.SqlTools.EnableIntellisense + return this.SqlTools.IntelliSense.EnableIntellisense && this.SqlTools.IntelliSense.EnableSuggestions.Value; } } @@ -83,7 +86,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext { get { - return this.SqlTools.EnableIntellisense + return this.SqlTools.IntelliSense.EnableIntellisense && this.SqlTools.IntelliSense.EnableQuickInfo.Value; } } @@ -99,17 +102,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// public SqlToolsSettingsValues() { - this.EnableIntellisense = true; + this.IntelliSense = new IntelliSenseSettings(); this.QueryExecutionSettings = new QueryExecutionSettings(); } - /// - /// Gets or sets a flag determining if IntelliSense is enabled - /// - /// - public bool EnableIntellisense { get; set; } - /// /// Gets or sets the detailed IntelliSense settings /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/CommandOptions.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/CommandOptions.cs new file mode 100644 index 00000000..0ac371c5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/CommandOptions.cs @@ -0,0 +1,91 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; + +namespace Microsoft.SqlTools.ServiceLayer.Utility +{ + /// + /// The command-line options helper class. + /// + internal class CommandOptions + { + /// + /// Construct and parse command line options from the arguments array + /// + public CommandOptions(string[] args) + { + ErrorMessage = string.Empty; + + try + { + for (int i = 0; i < args.Length; ++i) + { + string arg = args[i]; + if (arg.StartsWith("--") || arg.StartsWith("-")) + { + arg = arg.Substring(1).ToLowerInvariant(); + switch (arg) + { + case "-enable-logging": + EnableLogging = true; + break; + case "h": + case "-help": + ShouldExit = true; + return; + default: + ErrorMessage += String.Format("Unknown argument \"{0}\"" + Environment.NewLine, arg); + break; + } + } + } + } + catch (Exception ex) + { + ErrorMessage += ex.ToString(); + return; + } + finally + { + if (!string.IsNullOrEmpty(ErrorMessage) || ShouldExit) + { + Console.WriteLine(Usage); + ShouldExit = true; + } + } + } + + internal string ErrorMessage { get; private set; } + + + /// + /// Whether diagnostic logging is enabled + /// + public bool EnableLogging { get; private set; } + + /// + /// Whether the program should exit immediately. Set to true when the usage is printed. + /// + public bool ShouldExit { get; private set; } + + /// + /// Get the usage string describing command-line arguments for the program + /// + public string Usage + { + get + { + var str = string.Format("{0}" + Environment.NewLine + + "Microsoft.SqlTools.ServiceLayer.exe " + Environment.NewLine + + " Options:" + Environment.NewLine + + " [--enable-logging]" + Environment.NewLine + + " [--help]" + Environment.NewLine, + ErrorMessage); + return str; + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs index 0da84f43..29ca9e7f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs @@ -6,7 +6,34 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility { public static class TextUtilities - { + { + /// + /// Find the position of the cursor in the SQL script content buffer and return previous new line position + /// + /// + /// + /// + /// + public static int PositionOfCursor(string sql, int startRow, int startColumn, out int prevNewLine) + { + prevNewLine = 0; + if (string.IsNullOrWhiteSpace(sql)) + { + return 1; + } + + for (int i = 0; i < startRow; ++i) + { + while (prevNewLine < sql.Length && sql[prevNewLine] != '\n') + { + ++prevNewLine; + } + ++prevNewLine; + } + + return startColumn + prevNewLine; + } + /// /// Find the position of the previous delimeter for autocomplete token replacement. /// SQL Parser may have similar functionality in which case we'll delete this method. @@ -14,49 +41,73 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility /// /// /// - /// + /// public static int PositionOfPrevDelimeter(string sql, int startRow, int startColumn) - { - if (string.IsNullOrWhiteSpace(sql)) - { - return 1; - } + { + int prevNewLine; + int delimeterPos = PositionOfCursor(sql, startRow, startColumn, out prevNewLine); - int prevLineColumns = 0; - for (int i = 0; i < startRow; ++i) + if (delimeterPos - 1 < sql.Length) { - while (sql[prevLineColumns] != '\n' && prevLineColumns < sql.Length) + while (--delimeterPos >= prevNewLine) { - ++prevLineColumns; - } - ++prevLineColumns; - } - - startColumn += prevLineColumns; - - if (startColumn - 1 < sql.Length) - { - while (--startColumn >= prevLineColumns) - { - if (sql[startColumn] == ' ' - || sql[startColumn] == '\t' - || sql[startColumn] == '\n' - || sql[startColumn] == '.' - || sql[startColumn] == '+' - || sql[startColumn] == '-' - || sql[startColumn] == '*' - || sql[startColumn] == '>' - || sql[startColumn] == '<' - || sql[startColumn] == '=' - || sql[startColumn] == '/' - || sql[startColumn] == '%') + if (IsCharacterDelimeter(sql[delimeterPos])) { break; } } + + delimeterPos = delimeterPos + 1 - prevNewLine; } - return startColumn + 1 - prevLineColumns; + return delimeterPos; + } + + /// + /// Find the position of the next delimeter for autocomplete token replacement. + /// + /// + /// + /// + public static int PositionOfNextDelimeter(string sql, int startRow, int startColumn) + { + int prevNewLine; + int delimeterPos = PositionOfCursor(sql, startRow, startColumn, out prevNewLine); + + while (delimeterPos < sql.Length) + { + if (IsCharacterDelimeter(sql[delimeterPos])) + { + break; + } + ++delimeterPos; + } + + return delimeterPos - prevNewLine; + } + + /// + /// Determine if the character is a SQL token delimiter + /// + /// + private static bool IsCharacterDelimeter(char ch) + { + return ch == ' ' + || ch == '\t' + || ch == '\n' + || ch == '.' + || ch == '+' + || ch == '-' + || ch == '*' + || ch == '>' + || ch == '<' + || ch == '=' + || ch == '/' + || ch == '%' + || ch == ',' + || ch == ';' + || ch == '(' + || ch == ')'; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/project.json b/src/Microsoft.SqlTools.ServiceLayer/project.json index aa33787e..778734f0 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/project.json +++ b/src/Microsoft.SqlTools.ServiceLayer/project.json @@ -8,8 +8,8 @@ "dependencies": { "Newtonsoft.Json": "9.0.1", "System.Data.Common": "4.1.0", - "System.Data.SqlClient": "4.1.0", - "Microsoft.SqlServer.Smo": "140.1.8", + "System.Data.SqlClient": "4.4.0-sqltools-24613-04", + "Microsoft.SqlServer.Smo": "140.1.11", "System.Security.SecureString": "4.0.0", "System.Collections.Specialized": "4.0.1", "System.ComponentModel.TypeConverter": "4.1.0", diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.Designer.cs b/src/Microsoft.SqlTools.ServiceLayer/sr.Designer.cs deleted file mode 100644 index 50d8c09d..00000000 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.Designer.cs +++ /dev/null @@ -1,467 +0,0 @@ -//------------------------------------------------------------------------------ -// -// This code was generated by a tool. -// Runtime Version:4.0.30319.42000 -// -// Changes to this file may cause incorrect behavior and will be lost if -// the code is regenerated. -// -//------------------------------------------------------------------------------ - -namespace Microsoft.SqlTools.ServiceLayer { - using System; - using System.Reflection; - - - /// - /// A strongly-typed resource class, for looking up localized strings, etc. - /// - // This class was auto-generated by the StronglyTypedResourceBuilder - // class via a tool like ResGen or Visual Studio. - // To add or remove a member, edit your .ResX file then rerun ResGen - // with the /str option, or rebuild your VS project. - [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] - [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] - public class sr { - - private static global::System.Resources.ResourceManager resourceMan; - - private static global::System.Globalization.CultureInfo resourceCulture; - - internal sr() { - } - - /// - /// Returns the cached ResourceManager instance used by this class. - /// - [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] - public static global::System.Resources.ResourceManager ResourceManager { - get { - if (object.ReferenceEquals(resourceMan, null)) { - global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.SqlTools.ServiceLayer.sr", typeof(sr).GetTypeInfo().Assembly); - resourceMan = temp; - } - return resourceMan; - } - } - - /// - /// Overrides the current thread's CurrentUICulture property for all - /// resource lookups using this strongly typed resource class. - /// - [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] - public static global::System.Globalization.CultureInfo Culture { - get { - return resourceCulture; - } - set { - resourceCulture = value; - } - } - - /// - /// Looks up a localized string similar to Connection details object cannot be null. - /// - public static string ConnectionParamsValidateNullConnection { - get { - return ResourceManager.GetString("ConnectionParamsValidateNullConnection", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to OwnerUri cannot be null or empty. - /// - public static string ConnectionParamsValidateNullOwnerUri { - get { - return ResourceManager.GetString("ConnectionParamsValidateNullOwnerUri", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to ServerName cannot be null or empty. - /// - public static string ConnectionParamsValidateNullServerName { - get { - return ResourceManager.GetString("ConnectionParamsValidateNullServerName", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to {0} cannot be null or empty when using SqlLogin authentication. - /// - public static string ConnectionParamsValidateNullSqlAuth { - get { - return ResourceManager.GetString("ConnectionParamsValidateNullSqlAuth", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Connection parameters cannot be null. - /// - public static string ConnectionServiceConnectErrorNullParams { - get { - return ResourceManager.GetString("ConnectionServiceConnectErrorNullParams", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Invalid value '{0}' for AuthenticationType. Valid values are 'Integrated' and 'SqlLogin'.. - /// - public static string ConnectionServiceConnStringInvalidAuthType { - get { - return ResourceManager.GetString("ConnectionServiceConnStringInvalidAuthType", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Invalid value '{0}' for ApplicationIntent. Valid values are 'ReadWrite' and 'ReadOnly'.. - /// - public static string ConnectionServiceConnStringInvalidIntent { - get { - return ResourceManager.GetString("ConnectionServiceConnStringInvalidIntent", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to SpecifiedUri '{0}' does not have existing connection. - /// - public static string ConnectionServiceListDbErrorNotConnected { - get { - return ResourceManager.GetString("ConnectionServiceListDbErrorNotConnected", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to OwnerUri cannot be null or empty. - /// - public static string ConnectionServiceListDbErrorNullOwnerUri { - get { - return ResourceManager.GetString("ConnectionServiceListDbErrorNullOwnerUri", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Win32Credential object is already disposed. - /// - public static string CredentialServiceWin32CredentialDisposed { - get { - return ResourceManager.GetString("CredentialServiceWin32CredentialDisposed", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Invalid CriticalHandle!. - /// - public static string CredentialsServiceInvalidCriticalHandle { - get { - return ResourceManager.GetString("CredentialsServiceInvalidCriticalHandle", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to The password has exceeded 512 bytes. - /// - public static string CredentialsServicePasswordLengthExceeded { - get { - return ResourceManager.GetString("CredentialsServicePasswordLengthExceeded", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Target must be specified to delete a credential. - /// - public static string CredentialsServiceTargetForDelete { - get { - return ResourceManager.GetString("CredentialsServiceTargetForDelete", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Target must be specified to check existance of a credential. - /// - public static string CredentialsServiceTargetForLookup { - get { - return ResourceManager.GetString("CredentialsServiceTargetForLookup", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Message header must separate key and value using ':'. - /// - public static string HostingHeaderMissingColon { - get { - return ResourceManager.GetString("HostingHeaderMissingColon", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Fatal error: Content-Length header must be provided. - /// - public static string HostingHeaderMissingContentLengthHeader { - get { - return ResourceManager.GetString("HostingHeaderMissingContentLengthHeader", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Fatal error: Content-Length value is not an integer. - /// - public static string HostingHeaderMissingContentLengthValue { - get { - return ResourceManager.GetString("HostingHeaderMissingContentLengthValue", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to MessageReader's input stream ended unexpectedly, terminating. - /// - public static string HostingUnexpectedEndOfStream { - get { - return ResourceManager.GetString("HostingUnexpectedEndOfStream", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to ({0} row(s) affected). - /// - public static string QueryServiceAffectedRows { - get { - return ResourceManager.GetString("QueryServiceAffectedRows", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to The query has already completed, it cannot be cancelled. - /// - public static string QueryServiceCancelAlreadyCompleted { - get { - return ResourceManager.GetString("QueryServiceCancelAlreadyCompleted", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Query successfully cancelled, failed to dispose query. Owner URI not found.. - /// - public static string QueryServiceCancelDisposeFailed { - get { - return ResourceManager.GetString("QueryServiceCancelDisposeFailed", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to (No column name). - /// - public static string QueryServiceColumnNull { - get { - return ResourceManager.GetString("QueryServiceColumnNull", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Command(s) copleted successfully.. - /// - public static string QueryServiceCompletedSuccessfully { - get { - return ResourceManager.GetString("QueryServiceCompletedSuccessfully", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Maximum number of bytes to return must be greater than zero. - /// - public static string QueryServiceDataReaderByteCountInvalid { - get { - return ResourceManager.GetString("QueryServiceDataReaderByteCountInvalid", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Maximum number of chars to return must be greater than zero. - /// - public static string QueryServiceDataReaderCharCountInvalid { - get { - return ResourceManager.GetString("QueryServiceDataReaderCharCountInvalid", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Maximum number of XML bytes to return must be greater than zero. - /// - public static string QueryServiceDataReaderXmlCountInvalid { - get { - return ResourceManager.GetString("QueryServiceDataReaderXmlCountInvalid", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Msg {0}, Level {1}, State {2}, Line {3}{4}{5}. - /// - public static string QueryServiceErrorFormat { - get { - return ResourceManager.GetString("QueryServiceErrorFormat", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to FileStreamWrapper must be initialized before performing operations. - /// - public static string QueryServiceFileWrapperNotInitialized { - get { - return ResourceManager.GetString("QueryServiceFileWrapperNotInitialized", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to This FileStreamWrapper cannot be used for writing. - /// - public static string QueryServiceFileWrapperReadOnly { - get { - return ResourceManager.GetString("QueryServiceFileWrapperReadOnly", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Access method cannot be write-only. - /// - public static string QueryServiceFileWrapperWriteOnly { - get { - return ResourceManager.GetString("QueryServiceFileWrapperWriteOnly", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Sender for OnInfoMessage event must be a SqlConnection. - /// - public static string QueryServiceMessageSenderNotSql { - get { - return ResourceManager.GetString("QueryServiceMessageSenderNotSql", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to A query is already in progress for this editor session. Please cancel this query or wait for its completion.. - /// - public static string QueryServiceQueryInProgress { - get { - return ResourceManager.GetString("QueryServiceQueryInProgress", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to This editor is not connected to a database. - /// - public static string QueryServiceQueryInvalidOwnerUri { - get { - return ResourceManager.GetString("QueryServiceQueryInvalidOwnerUri", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to The requested query does not exist. - /// - public static string QueryServiceRequestsNoQuery { - get { - return ResourceManager.GetString("QueryServiceRequestsNoQuery", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Could not retrieve column schema for result set. - /// - public static string QueryServiceResultSetNoColumnSchema { - get { - return ResourceManager.GetString("QueryServiceResultSetNoColumnSchema", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Cannot read subset unless the results have been read from the server. - /// - public static string QueryServiceResultSetNotRead { - get { - return ResourceManager.GetString("QueryServiceResultSetNotRead", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Reader cannot be null. - /// - public static string QueryServiceResultSetReaderNull { - get { - return ResourceManager.GetString("QueryServiceResultSetReaderNull", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Row count must be a positive integer. - /// - public static string QueryServiceResultSetRowCountOutOfRange { - get { - return ResourceManager.GetString("QueryServiceResultSetRowCountOutOfRange", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Start row cannot be less than 0 or greater than the number of rows in the result set. - /// - public static string QueryServiceResultSetStartRowOutOfRange { - get { - return ResourceManager.GetString("QueryServiceResultSetStartRowOutOfRange", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Batch index cannot be less than 0 or greater than the number of batches. - /// - public static string QueryServiceSubsetBatchOutOfRange { - get { - return ResourceManager.GetString("QueryServiceSubsetBatchOutOfRange", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to The query has not completed, yet. - /// - public static string QueryServiceSubsetNotCompleted { - get { - return ResourceManager.GetString("QueryServiceSubsetNotCompleted", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Result set index cannot be less than 0 or greater than the number of result sets. - /// - public static string QueryServiceSubsetResultSetOutOfRange { - get { - return ResourceManager.GetString("QueryServiceSubsetResultSetOutOfRange", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Start position ({0}, {1}) must come before or be equal to the end position ({2}, {3}). - /// - public static string WorkspaceServiceBufferPositionOutOfOrder { - get { - return ResourceManager.GetString("WorkspaceServiceBufferPositionOutOfOrder", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Position is outside of column range for line {0}. - /// - public static string WorkspaceServicePositionColumnOutOfRange { - get { - return ResourceManager.GetString("WorkspaceServicePositionColumnOutOfRange", resourceCulture); - } - } - - /// - /// Looks up a localized string similar to Position is outside of file line range. - /// - public static string WorkspaceServicePositionLineOutOfRange { - get { - return ResourceManager.GetString("WorkspaceServicePositionLineOutOfRange", resourceCulture); - } - } - } -} diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.cs b/src/Microsoft.SqlTools.ServiceLayer/sr.cs index 811ab975..9635b0f7 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.cs @@ -165,6 +165,14 @@ namespace Microsoft.SqlTools.ServiceLayer } } + public static string QueryServiceQueryCancelled + { + get + { + return Keys.GetString(Keys.QueryServiceQueryCancelled); + } + } + public static string QueryServiceSubsetNotCompleted { get @@ -237,6 +245,14 @@ namespace Microsoft.SqlTools.ServiceLayer } } + public static string QueryServiceAffectedOneRow + { + get + { + return Keys.GetString(Keys.QueryServiceAffectedOneRow); + } + } + public static string QueryServiceCompletedSuccessfully { get @@ -363,6 +379,11 @@ namespace Microsoft.SqlTools.ServiceLayer return Keys.GetString(Keys.QueryServiceErrorFormat, msg, lvl, state, line, newLine, message); } + public static string QueryServiceQueryFailed(string message) + { + return Keys.GetString(Keys.QueryServiceQueryFailed, message); + } + public static string WorkspaceServicePositionColumnOutOfRange(int line) { return Keys.GetString(Keys.WorkspaceServicePositionColumnOutOfRange, line); @@ -444,6 +465,9 @@ namespace Microsoft.SqlTools.ServiceLayer public const string QueryServiceCancelDisposeFailed = "QueryServiceCancelDisposeFailed"; + public const string QueryServiceQueryCancelled = "QueryServiceQueryCancelled"; + + public const string QueryServiceSubsetNotCompleted = "QueryServiceSubsetNotCompleted"; @@ -471,6 +495,9 @@ namespace Microsoft.SqlTools.ServiceLayer public const string QueryServiceFileWrapperReadOnly = "QueryServiceFileWrapperReadOnly"; + public const string QueryServiceAffectedOneRow = "QueryServiceAffectedOneRow"; + + public const string QueryServiceAffectedRows = "QueryServiceAffectedRows"; @@ -480,6 +507,9 @@ namespace Microsoft.SqlTools.ServiceLayer public const string QueryServiceErrorFormat = "QueryServiceErrorFormat"; + public const string QueryServiceQueryFailed = "QueryServiceQueryFailed"; + + public const string QueryServiceColumnNull = "QueryServiceColumnNull"; diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.resx b/src/Microsoft.SqlTools.ServiceLayer/sr.resx index 63d7e71b..27f0993d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.resx +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.resx @@ -205,6 +205,10 @@ Query successfully cancelled, failed to dispose query. Owner URI not found. + + Query was canceled by user + + The query has not completed, yet @@ -241,19 +245,28 @@ This FileStreamWrapper cannot be used for writing + + (1 row affected) + + - ({0} row(s) affected) + ({0} rows affected) . Parameters: 0 - rows (long) - Command(s) copleted successfully. + Commands completed successfully. Msg {0}, Level {1}, State {2}, Line {3}{4}{5} . Parameters: 0 - msg (int), 1 - lvl (int), 2 - state (int), 3 - line (int), 4 - newLine (string), 5 - message (string) + + + Query failed: {0} + . + Parameters: 0 - message (string) (No column name) diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.strings b/src/Microsoft.SqlTools.ServiceLayer/sr.strings index a74a54d9..9b9a138e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.strings +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.strings @@ -79,6 +79,8 @@ QueryServiceCancelAlreadyCompleted = The query has already completed, it cannot QueryServiceCancelDisposeFailed = Query successfully cancelled, failed to dispose query. Owner URI not found. +QueryServiceQueryCancelled = Query was canceled by user + ### Subset Request QueryServiceSubsetNotCompleted = The query has not completed, yet @@ -105,12 +107,16 @@ QueryServiceFileWrapperReadOnly = This FileStreamWrapper cannot be used for writ ### Query Request -QueryServiceAffectedRows(long rows) = ({0} row(s) affected) +QueryServiceAffectedOneRow = (1 row affected) -QueryServiceCompletedSuccessfully = Command(s) copleted successfully. +QueryServiceAffectedRows(long rows) = ({0} rows affected) + +QueryServiceCompletedSuccessfully = Commands completed successfully. QueryServiceErrorFormat(int msg, int lvl, int state, int line, string newLine, string message) = Msg {0}, Level {1}, State {2}, Line {3}{4}{5} +QueryServiceQueryFailed(string message) = Query failed: {0} + QueryServiceColumnNull = (No column name) QueryServiceRequestsNoQuery = The requested query does not exist diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs index 8205cf21..46ccbb94 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs @@ -7,6 +7,7 @@ using System; using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Data.SqlClient; using System.Reflection; using System.Threading; using System.Threading.Tasks; @@ -35,7 +36,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection var commandMockSetup = commandMock.Protected() .Setup("ExecuteDbDataReader", It.IsAny()); - commandMockSetup.Returns(new TestDbDataReader(data)); + commandMockSetup.Returns(() => new TestDbDataReader(data)); return commandMock.Object; } @@ -830,5 +831,35 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection ConnectionInfo info; Assert.True(service.TryFindConnection(connectParams.OwnerUri, out info)); } + + /// + /// Verify that Linux/OSX SqlExceptions thrown do not contain an error code. + /// This is a bug in .NET core (see https://github.com/dotnet/corefx/issues/12472). + /// If this test ever fails, it means that this bug has been fixed. When this is + /// the case, look at RetryPolicyUtils.cs in IsRetryableNetworkConnectivityError(), + /// and remove the code block specific to Linux/OSX. + /// + [Fact] + public void TestThatLinuxAndOSXSqlExceptionHasNoErrorCode() + { + TestUtils.RunIfLinuxOrOSX(() => + { + try + { + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(); + builder.DataSource = "bad-server-name"; + builder.UserID = "sa"; + builder.Password = "bad password"; + + SqlConnection connection = new SqlConnection(builder.ConnectionString); + connection.Open(); // This should fail + } + catch (SqlException ex) + { + // Error code should be 0 due to bug + Assert.Equal(ex.Number, 0); + } + }); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ReliableConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ReliableConnectionTests.cs new file mode 100644 index 00000000..24a2f39c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ReliableConnectionTests.cs @@ -0,0 +1,343 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +#if LIVE_CONNECTION_TESTS + +using System; +using System.Data; +using System.Data.Common; +using System.Data.SqlClient; +using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Connection +{ + /// + /// Tests for the ReliableConnection module. + /// These tests all assume a live connection to a database on localhost using integrated auth. + /// + public class ReliableConnectionTests + { + /// + /// Environment variable that stores the name of the test server hosting the SQL Server instance. + /// + public static string TestServerEnvironmentVariable + { + get { return "TEST_SERVER"; } + } + + private static Lazy testServerName = new Lazy(() => Environment.GetEnvironmentVariable(TestServerEnvironmentVariable)); + + /// + /// Name of the test server hosting the SQL Server instance. + /// + public static string TestServerName + { + get { return testServerName.Value; } + } + + /// + /// Helper method to create an integrated auth connection builder for testing. + /// + private SqlConnectionStringBuilder CreateTestConnectionStringBuilder() + { + SqlConnectionStringBuilder csb = new SqlConnectionStringBuilder(); + csb.DataSource = TestServerName; + csb.IntegratedSecurity = true; + + return csb; + } + + /// + /// Helper method to create an integrated auth reliable connection for testing. + /// + private DbConnection CreateTestConnection() + { + SqlConnectionStringBuilder csb = CreateTestConnectionStringBuilder(); + + RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); + RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); + + ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy); + return connection; + } + + /// + /// Test ReliableConnectionHelper.GetDefaultDatabaseFilePath() + /// + [Fact] + public void TestGetDefaultDatabaseFilePath() + { + TestUtils.RunIfWindows(() => + { + var connectionBuilder = CreateTestConnectionStringBuilder(); + Assert.NotNull(connectionBuilder); + + string filePath = string.Empty; + string logPath = string.Empty; + + ReliableConnectionHelper.OpenConnection( + connectionBuilder, + usingConnection: (conn) => + { + filePath = ReliableConnectionHelper.GetDefaultDatabaseFilePath(conn); + logPath = ReliableConnectionHelper.GetDefaultDatabaseLogPath(conn); + }, + catchException: null, + useRetry: false); + + Assert.False(string.IsNullOrWhiteSpace(filePath)); + Assert.False(string.IsNullOrWhiteSpace(logPath)); + }); + } + + /// + /// Test ReliableConnectionHelper.GetServerVersion() + /// + [Fact] + public void TestGetServerVersion() + { + TestUtils.RunIfWindows(() => + { + using (var connection = CreateTestConnection()) + { + Assert.NotNull(connection); + connection.Open(); + + ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(connection); + ReliableConnectionHelper.ServerInfo serverInfo2; + using (var connection2 = CreateTestConnection()) + { + connection2.Open(); + serverInfo2 = ReliableConnectionHelper.GetServerVersion(connection); + } + + Assert.NotNull(serverInfo); + Assert.NotNull(serverInfo2); + Assert.True(serverInfo.ServerMajorVersion != 0); + Assert.True(serverInfo.ServerMajorVersion == serverInfo2.ServerMajorVersion); + Assert.True(serverInfo.ServerMinorVersion == serverInfo2.ServerMinorVersion); + Assert.True(serverInfo.ServerReleaseVersion == serverInfo2.ServerReleaseVersion); + Assert.True(serverInfo.ServerEdition == serverInfo2.ServerEdition); + Assert.True(serverInfo.IsCloud == serverInfo2.IsCloud); + Assert.True(serverInfo.AzureVersion == serverInfo2.AzureVersion); + } + }); + } + + /// + /// Tests ReliableConnectionHelper.GetCompleteServerName() + /// + [Fact] + public void TestGetCompleteServerName() + { + string name = ReliableConnectionHelper.GetCompleteServerName(@".\SQL2008"); + Assert.True(name.Contains(Environment.MachineName)); + + name = ReliableConnectionHelper.GetCompleteServerName(@"(local)"); + Assert.True(name.Contains(Environment.MachineName)); + } + + /// + /// Tests ReliableConnectionHelper.IsDatabaseReadonly() + /// + [Fact] + public void TestIsDatabaseReadonly() + { + var connectionBuilder = CreateTestConnectionStringBuilder(); + Assert.NotNull(connectionBuilder); + + bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder); + Assert.False(isReadOnly); + } + + /// + /// Verify ANSI_NULL and QUOTED_IDENTIFIER settings can be set and retrieved for a session + /// + [Fact] + public void VerifyAnsiNullAndQuotedIdentifierSettingsReplayed() + { + TestUtils.RunIfWindows(() => + { + using (ReliableSqlConnection conn = (ReliableSqlConnection)ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true)) + { + VerifySessionSettings(conn, true); + VerifySessionSettings(conn, false); + } + }); + } + + private void VerifySessionSettings(ReliableSqlConnection conn, bool expectedSessionValue) + { + Tuple[] settings = null; + using (IDbCommand cmd = conn.CreateCommand()) + { + if (expectedSessionValue) + { + cmd.CommandText = "SET ANSI_NULLS, QUOTED_IDENTIFIER ON"; + } + else + { + cmd.CommandText = "SET ANSI_NULLS, QUOTED_IDENTIFIER OFF"; + } + + cmd.ExecuteNonQuery(); + + //baseline assertion + AssertSessionValues(cmd, ansiNullsValue: expectedSessionValue, quotedIdentifersValue: expectedSessionValue); + + // verify the initial values are correct + settings = conn.CacheOrReplaySessionSettings(cmd, settings); + + // assert no change is session settings + AssertSessionValues(cmd, ansiNullsValue: expectedSessionValue, quotedIdentifersValue: expectedSessionValue); + + // assert cached settings are correct + Assert.Equal("ANSI_NULLS", settings[0].Item1); + Assert.Equal(expectedSessionValue, settings[0].Item2); + + Assert.Equal("QUOTED_IDENTIFIER", settings[1].Item1); + Assert.Equal(expectedSessionValue, settings[1].Item2); + + // invert session values and assert we reset them + + if (expectedSessionValue) + { + cmd.CommandText = "SET ANSI_NULLS, QUOTED_IDENTIFIER OFF"; + } + else + { + cmd.CommandText = "SET ANSI_NULLS, QUOTED_IDENTIFIER ON"; + } + cmd.ExecuteNonQuery(); + + // baseline assertion + AssertSessionValues(cmd, ansiNullsValue: !expectedSessionValue, quotedIdentifersValue: !expectedSessionValue); + + // replay cached value + settings = conn.CacheOrReplaySessionSettings(cmd, settings); + + // assert session settings correctly set + AssertSessionValues(cmd, ansiNullsValue: expectedSessionValue, quotedIdentifersValue: expectedSessionValue); + } + } + + private void AssertSessionValues(IDbCommand cmd, bool ansiNullsValue, bool quotedIdentifersValue) + { + // assert session was updated + cmd.CommandText = "SELECT SESSIONPROPERTY ('ANSI_NULLS'), SESSIONPROPERTY ('QUOTED_IDENTIFIER')"; + using (IDataReader reader = cmd.ExecuteReader()) + { + Assert.True(reader.Read(), "Missing session settings"); + bool actualAnsiNullsOnValue = ((int)reader[0] == 1); + bool actualQuotedIdentifierOnValue = ((int)reader[1] == 1); + Assert.Equal(ansiNullsValue, actualAnsiNullsOnValue); + Assert.Equal(quotedIdentifersValue, actualQuotedIdentifierOnValue); + } + + } + + /// + /// Test that the retry policy factory constructs all possible types of policies successfully. + /// + [Fact] + public void RetryPolicyFactoryConstructsPoliciesSuccessfully() + { + TestUtils.RunIfWindows(() => + { + Assert.NotNull(RetryPolicyFactory.CreateColumnEncryptionTransferRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateDatabaseCommandRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateDataScriptUpdateRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateDefaultConnectionRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateDefaultDataConnectionRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateDefaultDataSqlCommandRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateDefaultDataTransferRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateDefaultSchemaCommandRetryPolicy(true)); + Assert.NotNull(RetryPolicyFactory.CreateDefaultSchemaConnectionRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateElementCommandRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateFastDataRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateNoRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreatePrimaryKeyCommandRetryPolicy()); + Assert.NotNull(RetryPolicyFactory.CreateSchemaCommandRetryPolicy(6)); + Assert.NotNull(RetryPolicyFactory.CreateSchemaConnectionRetryPolicy(6)); + }); + } + + /// + /// ReliableConnectionHelper.IsCloud() should be false for a local server + /// + [Fact] + public void TestIsCloudIsFalseForLocalServer() + { + TestUtils.RunIfWindows(() => + { + using (var connection = CreateTestConnection()) + { + Assert.NotNull(connection); + + connection.Open(); + Assert.False(ReliableConnectionHelper.IsCloud(connection)); + } + }); + } + + /// + /// Tests that ReliableConnectionHelper.OpenConnection() opens a connection if it is closed + /// + [Fact] + public void TestOpenConnectionOpensConnection() + { + TestUtils.RunIfWindows(() => + { + using (var connection = CreateTestConnection()) + { + Assert.NotNull(connection); + + Assert.True(connection.State == ConnectionState.Closed); + ReliableConnectionHelper.OpenConnection(connection); + Assert.True(connection.State == ConnectionState.Open); + } + }); + } + + /// + /// Tests that ReliableConnectionHelper.ExecuteNonQuery() runs successfully + /// + [Fact] + public void TestExecuteNonQuery() + { + TestUtils.RunIfWindows(() => + { + var result = ReliableConnectionHelper.ExecuteNonQuery( + CreateTestConnectionStringBuilder(), + "SET NOCOUNT ON; SET NOCOUNT OFF;", + ReliableConnectionHelper.SetCommandTimeout, + null, + true + ); + Assert.NotNull(result); + }); + } + + /// + /// Test that TryGetServerVersion() gets server information + /// + [Fact] + public void TestTryGetServerVersion() + { + TestUtils.RunIfWindows(() => + { + ReliableConnectionHelper.ServerInfo info = null; + Assert.True(ReliableConnectionHelper.TryGetServerVersion(CreateTestConnectionStringBuilder().ConnectionString, out info)); + + Assert.NotNull(info); + Assert.NotNull(info.ServerVersion); + Assert.NotEmpty(info.ServerVersion); + }); + } + } +} +#endif // LIVE_CONNECTION_TESTS diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/BindingQueueTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/BindingQueueTests.cs index 63181f67..bf60278b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/BindingQueueTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/BindingQueueTests.cs @@ -4,7 +4,6 @@ // using System.Threading; -using System.Threading.Tasks; using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.SmoMetadataProvider; using Microsoft.SqlServer.Management.SqlParser.Binder; @@ -25,7 +24,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices { public TestBindingContext() { - this.BindingLocked = new ManualResetEvent(initialState: true); + this.BindingLock = new ManualResetEvent(true); this.BindingTimeout = 3000; } @@ -39,7 +38,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices public IBinder Binder { get; set; } - public ManualResetEvent BindingLocked { get; set; } + public ManualResetEvent BindingLock { get; set; } public int BindingTimeout { get; set; } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs index d27fe156..3899507c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs @@ -24,7 +24,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // Set up file for returning the query var fileMock = new Mock(); fileMock.Setup(file => file.GetLinesInRange(It.IsAny())) - .Returns(new string[] { Common.StandardQuery }); + .Returns(new[] { Common.StandardQuery }); // Set up workspace mock var workspaceService = new Mock>(); workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) @@ -36,7 +36,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var executeParams = new QueryExecuteParams { QuerySelection = Common.GetSubSectionDocument(), OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Fake that it hasn't completed execution // ... And then I request to cancel the query @@ -50,8 +51,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); Assert.Null(result.Messages); - // ... The query should have been disposed as well - Assert.Empty(queryService.ActiveQueries); + // ... The query should not have been disposed + Assert.Equal(1, queryService.ActiveQueries.Count); } [Fact] @@ -71,13 +72,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var executeParams = new QueryExecuteParams {QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And then I request to cancel the query var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; QueryCancelResult result = null; var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); - queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + await queryService.HandleCancelRequest(cancelParams, cancelRequest.Object); // Then: // ... I should have seen a result event with an error message diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index c3f6df75..bd5abe81 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -14,9 +14,6 @@ using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlServer.Management.Common; -using Microsoft.SqlServer.Management.SmoMetadataProvider; -using Microsoft.SqlServer.Management.SqlParser.Binder; -using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; using Microsoft.SqlTools.ServiceLayer.LanguageServices; using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; @@ -95,7 +92,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution { ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false); Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory()); - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); return query; } @@ -287,6 +285,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution return new QueryExecutionService(connectionService, workspaceService) {BufferFileStreamFactory = GetFileStreamFactory()}; } + public static WorkspaceService GetPrimedWorkspaceService() + { + // Set up file for returning the query + var fileMock = new Mock(); + fileMock.SetupGet(file => file.Contents).Returns(StandardQuery); + + // Set up workspace mock + var workspaceService = new Mock>(); + workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) + .Returns(fileMock.Object); + + return workspaceService.Object; + } + #endregion } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs index b3ff5efd..2f103ec0 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs @@ -4,7 +4,6 @@ // using System; -using System.Data.Common; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution; @@ -51,7 +50,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams {QuerySelection = null, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And then I dispose of the query var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; @@ -107,6 +107,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var requestContext = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(queryParams, requestContext.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And it sticks around as an active query Assert.Equal(1, queryService.ActiveQueries.Count); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs index 2484e233..72137474 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs @@ -19,7 +19,6 @@ using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -using Microsoft.SqlTools.Test.Utility; using Moq; using Xunit; @@ -74,9 +73,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // ... There should be a message for how many rows were affected Assert.Equal(1, batch.ResultMessages.Count()); - Assert.Contains("1 ", batch.ResultMessages.First().Message); - // NOTE: 1 is expected because this test simulates a 'update' statement where 1 row was affected. - // The 1 in quotes is to make sure the 1 isn't part of a larger number } [Fact] @@ -108,7 +104,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // ... There should be a message for how many rows were affected Assert.Equal(resultSets, batch.ResultMessages.Count()); - Assert.Contains(Common.StandardRows.ToString(), batch.ResultMessages.First().Message); } [Fact] @@ -150,13 +145,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // ... Inside each result summary, there should be 5 column definitions Assert.Equal(Common.StandardColumns, rs.ColumnInfo.Length); } - - // ... There should be a message for how many rows were affected - Assert.Equal(resultSets, batch.ResultMessages.Count()); - foreach (var rsm in batch.ResultMessages) - { - Assert.Contains(Common.StandardRows.ToString(), rsm.Message); - } } [Fact] @@ -295,7 +283,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // Then: // ... The query should have completed successfully with one batch summary returned @@ -321,7 +310,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I Then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // Then: // ... The query should have completed successfully with no batch summaries returned @@ -348,7 +338,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // Then: // ... The query should have completed successfully with two batch summaries returned @@ -376,7 +367,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // .. I then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // ... The query should have completed successfully with one batch summary returned Assert.True(query.HasExecuted); @@ -402,7 +394,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // Then: // ... There should be an error on the batch @@ -444,7 +437,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution expectedEvent: QueryExecuteCompleteEvent.Type, eventCallback: (et, cp) => completeParams = cp, errorCallback: null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, requestContext.Object); // Then: // ... No Errors should have been sent @@ -485,7 +478,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution expectedEvent: QueryExecuteCompleteEvent.Type, eventCallback: (et, cp) => completeParams = cp, errorCallback: null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, requestContext.Object); // Then: // ... No errors should have been sent @@ -512,18 +505,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = "notConnected", QuerySelection = Common.WholeDocument }; - QueryExecuteResult result = null; - var requestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + object error = null; + var requestContext = RequestContextMocks.Create(null) + .AddErrorHandling(e => error = e); + await queryService.HandleExecuteRequest(queryParams, requestContext.Object); // Then: - // ... An error message should have been returned via the result + // ... An error should have been returned + // ... No result should have been returned // ... No completion event should have been fired - // ... No error event should have been fired // ... There should be no active queries - VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Never(), Times.Never()); - Assert.NotNull(result.Messages); - Assert.NotEmpty(result.Messages); + VerifyQueryExecuteCallCount(requestContext, Times.Never(), Times.Never(), Times.Once()); + Assert.IsType(error); + Assert.NotEmpty((string)error); Assert.Empty(queryService.ActiveQueries); } @@ -545,24 +539,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; // Note, we don't care about the results of the first request - var firstRequestContext = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait(); + var firstRequestContext = RequestContextMocks.Create(null); + await AwaitExecution(queryService, queryParams, firstRequestContext.Object); // ... And then I request another query without waiting for the first to complete queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Simulate query hasn't finished - QueryExecuteResult result = null; - var secondRequestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait(); + object error = null; + var secondRequestContext = RequestContextMocks.Create(null) + .AddErrorHandling(e => error = e); + await AwaitExecution(queryService, queryParams, secondRequestContext.Object); // Then: - // ... No errors should have been sent - // ... A result should have been sent with an error message + // ... An error should have been sent + // ... A result should have not have been sent // ... No completion event should have been fired - // ... There should only be one active query - VerifyQueryExecuteCallCount(secondRequestContext, Times.Once(), Times.AtMostOnce(), Times.Never()); - Assert.NotNull(result.Messages); - Assert.NotEmpty(result.Messages); - Assert.Equal(1, queryService.ActiveQueries.Count); + // ... The original query should exist + VerifyQueryExecuteCallCount(secondRequestContext, Times.Never(), Times.Never(), Times.Once()); + Assert.IsType(error); + Assert.NotEmpty((string)error); + Assert.Contains(Common.OwnerUri, queryService.ActiveQueries.Keys); } [Fact] @@ -584,15 +579,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // Note, we don't care about the results of the first request var firstRequestContext = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - - queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, firstRequestContext.Object); // ... And then I request another query after waiting for the first to complete QueryExecuteResult result = null; QueryExecuteCompleteParams complete = null; var secondRequestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null); - queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, secondRequestContext.Object); // Then: // ... No errors should have been sent @@ -606,7 +600,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Theory] [InlineData(null)] - public async void QueryExecuteMissingSelectionTest(SelectionData selection) + public async Task QueryExecuteMissingSelectionTest(SelectionData selection) { // Set up file for returning the query @@ -621,18 +615,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = selection }; - QueryExecuteResult result = null; - var requestContext = - RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + object errorResult = null; + var requestContext = RequestContextMocks.Create(null) + .AddErrorHandling(error => errorResult = error); + await queryService.HandleExecuteRequest(queryParams, requestContext.Object); // Then: - // ... No errors should have been sent - // ... A result should have been sent with an error message + // ... Am error should have been sent + // ... No result should have been sent // ... No completion event should have been fired - VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Never(), Times.Never()); - Assert.NotNull(result.Messages); - Assert.NotEmpty(result.Messages); + // ... An active query should not have been added + VerifyQueryExecuteCallCount(requestContext, Times.Never(), Times.Never(), Times.Once()); + Assert.NotNull(errorResult); + Assert.IsType(errorResult); + Assert.DoesNotContain(Common.OwnerUri, queryService.ActiveQueries.Keys); // ... There should not be an active query Assert.Empty(queryService.ActiveQueries); @@ -657,7 +653,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution QueryExecuteCompleteParams complete = null; var requestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, requestContext.Object); // Then: // ... No errors should have been sent @@ -700,7 +696,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution #endregion - private void VerifyQueryExecuteCallCount(Mock> mock, Times sendResultCalls, Times sendEventCalls, Times sendErrorCalls) + private static void VerifyQueryExecuteCallCount(Mock> mock, Times sendResultCalls, Times sendEventCalls, Times sendErrorCalls) { mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); mock.Verify(rc => rc.SendEvent( @@ -709,9 +705,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); } - private DbConnection GetConnection(ConnectionInfo info) + private static DbConnection GetConnection(ConnectionInfo info) { return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); } + + private static async Task AwaitExecution(QueryExecutionService service, QueryExecuteParams qeParams, + RequestContext requestContext) + { + await service.HandleExecuteRequest(qeParams, requestContext); + await service.ActiveQueries[qeParams.OwnerUri].ExecutionTask; + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs index e3c38ab5..470476c2 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs @@ -3,15 +3,16 @@ // using System; +using System.Linq; using System.IO; using System.Threading.Tasks; using System.Runtime.InteropServices; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; -using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; using Moq; using Xunit; @@ -28,19 +29,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsCsvSuccessTest() { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as csv with correct parameters var saveParams = new SaveResultsAsCsvRequestParams @@ -54,7 +48,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see a file successfully created in filepath and a success message Assert.Null(result.Messages); @@ -74,20 +73,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsCsvWithSelectionSuccessTest() { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument , OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as csv with correct parameters var saveParams = new SaveResultsAsCsvRequestParams @@ -105,7 +96,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see a file successfully created in filepath and a success message Assert.Null(result.Messages); @@ -124,21 +120,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// [Fact] public async void SaveResultsAsCsvExceptionTest() - { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - + { // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as csv with incorrect filepath var saveParams = new SaveResultsAsCsvRequestParams @@ -148,11 +136,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution BatchIndex = 0, FilePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "G:\\test.csv" : "/test.csv" }; - // SaveResultRequestResult result = null; - string errMessage = null; - var saveRequest = GetSaveResultsContextMock( null, err => errMessage = (string) err); + + SaveResultRequestError errMessage = null; + var saveRequest = GetSaveResultsContextMock( null, err => errMessage = (SaveResultRequestError) err); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see error message Assert.NotNull(errMessage); @@ -166,13 +159,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsCsvQueryNotFoundTest() { - + // Create a query execution service var workspaceService = new Mock>(); - // Execute a query var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); - var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); // Request to save the results as csv with query that is no longer active var saveParams = new SaveResultsAsCsvRequestParams @@ -198,19 +187,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsJsonSuccessTest() { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as json with correct parameters var saveParams = new SaveResultsAsJsonRequestParams @@ -223,7 +205,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; + + // Expect to see a file successfully created in filepath and a success message Assert.Null(result.Messages); @@ -243,19 +232,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsJsonWithSelectionSuccessTest() { - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument , OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as json with correct parameters var saveParams = new SaveResultsAsJsonRequestParams @@ -265,14 +247,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution BatchIndex = 0, FilePath = "testwrite_5.json", RowStartIndex = 0, - RowEndIndex = 0, + RowEndIndex = 1, ColumnStartIndex = 0, - ColumnEndIndex = 0 + ColumnEndIndex = 1 }; SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see a file successfully created in filepath and a success message Assert.Null(result.Messages); @@ -292,18 +279,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsJsonExceptionTest() { - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as json with incorrect filepath var saveParams = new SaveResultsAsJsonRequestParams @@ -313,11 +294,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution BatchIndex = 0, FilePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "G:\\test.json" : "/test.json" }; - // SaveResultRequestResult result = null; - string errMessage = null; - var saveRequest = GetSaveResultsContextMock( null, err => errMessage = (string) err); + + + SaveResultRequestError errMessage = null; + var saveRequest = GetSaveResultsContextMock( null, err => errMessage = (SaveResultRequestError) err); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see error message Assert.NotNull(errMessage); @@ -331,12 +318,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsJsonQueryNotFoundTest() { + + // Create a query service var workspaceService = new Mock>(); - // Execute a query var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); - var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); // Request to save the results as json with query that is no longer active var saveParams = new SaveResultsAsJsonRequestParams @@ -404,52 +389,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); } - /// - /// Mock request context for executing a query - /// - /// - /// - /// - /// - /// - public static Mock> GetQueryExecuteResultContextMock( - Action resultCallback, - Action, QueryExecuteCompleteParams> eventCallback, - Action errorCallback) - { - var requestContext = new Mock>(); - - // Setup the mock for SendResult - var sendResultFlow = requestContext - .Setup(rc => rc.SendResult(It.IsAny())) - .Returns(Task.FromResult(0)); - if (resultCallback != null) - { - sendResultFlow.Callback(resultCallback); - } - - // Setup the mock for SendEvent - var sendEventFlow = requestContext.Setup(rc => rc.SendEvent( - It.Is>(m => m == QueryExecuteCompleteEvent.Type), - It.IsAny())) - .Returns(Task.FromResult(0)); - if (eventCallback != null) - { - sendEventFlow.Callback(eventCallback); - } - - // Setup the mock for SendError - var sendErrorFlow = requestContext.Setup(rc => rc.SendError(It.IsAny())) - .Returns(Task.FromResult(0)); - if (errorCallback != null) - { - sendErrorFlow.Callback(errorCallback); - } - - return requestContext; - } - - #endregion } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs index 7b57971b..4036c655 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs @@ -146,8 +146,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Common.CreateMockFactory(new[] {Common.StandardTestData}, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams {QuerySelection = null, OwnerUri = Common.OwnerUri}; - var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And I then ask for a valid set of results from it var subsetParams = new QueryExecuteSubsetParams {OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0}; @@ -203,8 +204,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = null, OwnerUri = Common.OwnerUri }; - var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // ... And I then ask for a valid set of results from it @@ -224,17 +226,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SubsetServiceOutOfRangeSubsetTest() - { - - var workspaceService = new Mock>(); + { // If: // ... I have a query that doesn't have any result sets var queryService = await Common.GetPrimedExecutionService( - Common.CreateMockFactory(null, false), true, - workspaceService.Object); + Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = null, OwnerUri = Common.OwnerUri }; - var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And I then ask for a set of results from it var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; @@ -259,27 +259,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Action resultCallback, Action errorCallback) { - var requestContext = new Mock>(); - - // Setup the mock for SendResult - var sendResultFlow = requestContext - .Setup(rc => rc.SendResult(It.IsAny())) - .Returns(Task.FromResult(0)); - if (resultCallback != null) - { - sendResultFlow.Callback(resultCallback); - } - - // Setup the mock for SendError - var sendErrorFlow = requestContext - .Setup(rc => rc.SendError(It.IsAny())) - .Returns(Task.FromResult(0)); - if (errorCallback != null) - { - sendErrorFlow.Callback(errorCallback); - } - - return requestContext; + return RequestContextMocks.Create(resultCallback) + .AddErrorHandling(errorCallback); } private static void VerifyQuerySubsetCallCount(Mock> mock, Times sendResultCalls, diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/SqlContext/SettingsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/SqlContext/SettingsTests.cs index fba65a29..e06764de 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/SqlContext/SettingsTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/SqlContext/SettingsTests.cs @@ -22,8 +22,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices var sqlToolsSettings = new SqlToolsSettings(); Assert.True(sqlToolsSettings.IsDiagnositicsEnabled); Assert.True(sqlToolsSettings.IsSuggestionsEnabled); - Assert.True(sqlToolsSettings.SqlTools.EnableIntellisense); - Assert.True(sqlToolsSettings.SqlTools.IntelliSense.EnableDiagnostics); + Assert.True(sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense); + Assert.True(sqlToolsSettings.SqlTools.IntelliSense.EnableErrorChecking); Assert.True(sqlToolsSettings.SqlTools.IntelliSense.EnableSuggestions); Assert.True(sqlToolsSettings.SqlTools.IntelliSense.EnableQuickInfo); Assert.False(sqlToolsSettings.SqlTools.IntelliSense.LowerCaseSuggestions); @@ -38,17 +38,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices var sqlToolsSettings = new SqlToolsSettings(); // diagnostics is enabled if IntelliSense and Diagnostics flags are set - sqlToolsSettings.SqlTools.EnableIntellisense = true; - sqlToolsSettings.SqlTools.IntelliSense.EnableDiagnostics = true; + sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense = true; + sqlToolsSettings.SqlTools.IntelliSense.EnableErrorChecking = true; Assert.True(sqlToolsSettings.IsDiagnositicsEnabled); // diagnostics is disabled if either IntelliSense and Diagnostics flags is not set - sqlToolsSettings.SqlTools.EnableIntellisense = false; - sqlToolsSettings.SqlTools.IntelliSense.EnableDiagnostics = true; + sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense = false; + sqlToolsSettings.SqlTools.IntelliSense.EnableErrorChecking = true; Assert.False(sqlToolsSettings.IsDiagnositicsEnabled); - sqlToolsSettings.SqlTools.EnableIntellisense = true; - sqlToolsSettings.SqlTools.IntelliSense.EnableDiagnostics = false; + sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense = true; + sqlToolsSettings.SqlTools.IntelliSense.EnableErrorChecking = false; Assert.False(sqlToolsSettings.IsDiagnositicsEnabled); } @@ -61,16 +61,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices var sqlToolsSettings = new SqlToolsSettings(); // suggestions is enabled if IntelliSense and Suggestions flags are set - sqlToolsSettings.SqlTools.EnableIntellisense = true; + sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense = true; sqlToolsSettings.SqlTools.IntelliSense.EnableSuggestions = true; Assert.True(sqlToolsSettings.IsSuggestionsEnabled); // suggestions is disabled if either IntelliSense and Suggestions flags is not set - sqlToolsSettings.SqlTools.EnableIntellisense = false; + sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense = false; sqlToolsSettings.SqlTools.IntelliSense.EnableSuggestions = true; Assert.False(sqlToolsSettings.IsSuggestionsEnabled); - sqlToolsSettings.SqlTools.EnableIntellisense = true; + sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense = true; sqlToolsSettings.SqlTools.IntelliSense.EnableSuggestions = false; Assert.False(sqlToolsSettings.IsSuggestionsEnabled); } @@ -84,16 +84,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices var sqlToolsSettings = new SqlToolsSettings(); // quick info is enabled if IntelliSense and quick info flags are set - sqlToolsSettings.SqlTools.EnableIntellisense = true; + sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense = true; sqlToolsSettings.SqlTools.IntelliSense.EnableQuickInfo = true; Assert.True(sqlToolsSettings.IsQuickInfoEnabled); // quick info is disabled if either IntelliSense and quick info flags is not set - sqlToolsSettings.SqlTools.EnableIntellisense = false; + sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense = false; sqlToolsSettings.SqlTools.IntelliSense.EnableQuickInfo = true; Assert.False(sqlToolsSettings.IsQuickInfoEnabled); - sqlToolsSettings.SqlTools.EnableIntellisense = true; + sqlToolsSettings.SqlTools.IntelliSense.EnableIntellisense = true; sqlToolsSettings.SqlTools.IntelliSense.EnableQuickInfo = false; Assert.False(sqlToolsSettings.IsQuickInfoEnabled); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/CommandOptionsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/CommandOptionsTests.cs new file mode 100644 index 00000000..0e2cdeec --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/CommandOptionsTests.cs @@ -0,0 +1,69 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Utility; +using Xunit; + +namespace Microsoft.SqlTools.Test.Utility +{ + /// + /// Tests for the CommandOptions class + /// + public class CommandOptionsTests + { + [Fact] + public void LoggingEnabledWhenFlagProvided() + { + var args = new string[] {"--enable-logging"}; + CommandOptions options = new CommandOptions(args); + Assert.NotNull(options); + + Assert.True(options.EnableLogging); + Assert.False(options.ShouldExit); + } + + [Fact] + public void LoggingDisabledWhenFlagNotProvided() + { + var args = new string[] {}; + CommandOptions options = new CommandOptions(args); + Assert.NotNull(options); + + Assert.False(options.EnableLogging); + Assert.False(options.ShouldExit); + } + + [Fact] + public void UsageIsShownWhenHelpFlagProvided() + { + var args = new string[] {"--help"}; + CommandOptions options = new CommandOptions(args); + Assert.NotNull(options); + + Assert.True(options.ShouldExit); + } + + [Fact] + public void UsageIsShownWhenBadArgumentsProvided() + { + var args = new string[] {"--unknown-argument", "/bad-argument"}; + CommandOptions options = new CommandOptions(args); + Assert.NotNull(options); + + Assert.True(options.ShouldExit); + } + + [Fact] + public void DefaultValuesAreUsedWhenNoArgumentsAreProvided() + { + var args = new string[] {}; + CommandOptions options = new CommandOptions(args); + Assert.NotNull(options); + + Assert.False(options.EnableLogging); + Assert.False(options.ShouldExit); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs index b2d52180..61887cc0 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs @@ -14,6 +14,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility test(); } } + + public static void RunIfLinuxOrOSX(Action test) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + test(); + } + } public static void RunIfWindows(Action test) { diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/project.json b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json index ae6e3ea6..f0779583 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/project.json +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json @@ -4,12 +4,21 @@ "buildOptions": { "debugType": "portable" }, + "configurations": { + "Integration": { + "buildOptions": { + "define": [ + "LIVE_CONNECTION_TESTS" + ] + } + } + }, "dependencies": { "Newtonsoft.Json": "9.0.1", "System.Runtime.Serialization.Primitives": "4.1.1", "System.Data.Common": "4.1.0", - "System.Data.SqlClient": "4.1.0", - "Microsoft.SqlServer.Smo": "140.1.8", + "System.Data.SqlClient": "4.4.0-sqltools-24613-04", + "Microsoft.SqlServer.Smo": "140.1.11", "System.Security.SecureString": "4.0.0", "System.Collections.Specialized": "4.0.1", "System.ComponentModel.TypeConverter": "4.1.0", diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/ServiceTestDriver.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/ServiceTestDriver.cs new file mode 100644 index 00000000..ac7e76c0 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/ServiceTestDriver.cs @@ -0,0 +1,39 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +// +// The following is based upon code from PowerShell Editor Services +// License: https://github.com/PowerShell/PowerShellEditorServices/blob/develop/LICENSE +// + +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Driver +{ + /// + /// Test driver for the service host + /// + public class ServiceTestDriver : TestDriverBase + { + public ServiceTestDriver(string serviceHostExecutable) + { + var clientChannel = new StdioClientChannel(serviceHostExecutable); + this.protocolClient = new ProtocolEndpoint(clientChannel, MessageProtocolType.LanguageServer); + } + + public async Task Start() + { + await this.protocolClient.Start(); + await Task.Delay(1000); // Wait for the service host to start + } + + public async Task Stop() + { + await this.protocolClient.Stop(); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/TestDriverBase.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/TestDriverBase.cs new file mode 100644 index 00000000..7f86764a --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Driver/TestDriverBase.cs @@ -0,0 +1,173 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +// +// The following is based upon code from PowerShell Editor Services +// License: https://github.com/PowerShell/PowerShellEditorServices/blob/develop/LICENSE +// + +using System; +using System.Collections.Concurrent; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Driver +{ + /// + /// Wraps the ProtocolEndpoint class with queues to handle events/requests + /// + public class TestDriverBase + { + protected ProtocolEndpoint protocolClient; + + private ConcurrentDictionary> eventQueuePerType = + new ConcurrentDictionary>(); + + private ConcurrentDictionary> requestQueuePerType = + new ConcurrentDictionary>(); + + public Task SendRequest( + RequestType requestType, + TParams requestParams) + { + return + this.protocolClient.SendRequest( + requestType, + requestParams); + } + + public Task SendEvent(EventType eventType, TParams eventParams) + { + return + this.protocolClient.SendEvent( + eventType, + eventParams); + } + + public void QueueEventsForType(EventType eventType) + { + var eventQueue = + this.eventQueuePerType.AddOrUpdate( + eventType.MethodName, + new AsyncQueue(), + (key, queue) => queue); + + this.protocolClient.SetEventHandler( + eventType, + (p, ctx) => + { + return eventQueue.EnqueueAsync(p); + }); + } + + public async Task WaitForEvent( + EventType eventType, + int timeoutMilliseconds = 5000) + { + Task eventTask = null; + + // Use the event queue if one has been registered + AsyncQueue eventQueue = null; + if (this.eventQueuePerType.TryGetValue(eventType.MethodName, out eventQueue)) + { + eventTask = + eventQueue + .DequeueAsync() + .ContinueWith( + task => (TParams)task.Result); + } + else + { + TaskCompletionSource eventTaskSource = new TaskCompletionSource(); + + this.protocolClient.SetEventHandler( + eventType, + (p, ctx) => + { + if (!eventTaskSource.Task.IsCompleted) + { + eventTaskSource.SetResult(p); + } + + return Task.FromResult(true); + }, + true); // Override any existing handler + + eventTask = eventTaskSource.Task; + } + + await + Task.WhenAny( + eventTask, + Task.Delay(timeoutMilliseconds)); + + if (!eventTask.IsCompleted) + { + throw new TimeoutException( + string.Format( + "Timed out waiting for '{0}' event!", + eventType.MethodName)); + } + + return await eventTask; + } + + public async Task>> WaitForRequest( + RequestType requestType, + int timeoutMilliseconds = 5000) + { + Task>> requestTask = null; + + // Use the request queue if one has been registered + AsyncQueue requestQueue = null; + if (this.requestQueuePerType.TryGetValue(requestType.MethodName, out requestQueue)) + { + requestTask = + requestQueue + .DequeueAsync() + .ContinueWith( + task => (Tuple>)task.Result); + } + else + { + var requestTaskSource = + new TaskCompletionSource>>(); + + this.protocolClient.SetRequestHandler( + requestType, + (p, ctx) => + { + if (!requestTaskSource.Task.IsCompleted) + { + requestTaskSource.SetResult( + new Tuple>(p, ctx)); + } + + return Task.FromResult(true); + }); + + requestTask = requestTaskSource.Task; + } + + await + Task.WhenAny( + requestTask, + Task.Delay(timeoutMilliseconds)); + + if (!requestTask.IsCompleted) + { + throw new TimeoutException( + string.Format( + "Timed out waiting for '{0}' request!", + requestType.MethodName)); + } + + return await requestTask; + } + } +} + diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Program.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Program.cs new file mode 100644 index 00000000..55b0a552 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Program.cs @@ -0,0 +1,66 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Linq; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Driver; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver +{ + internal class Program + { + internal static void Main(string[] args) + { + if (args.Length < 1) + { + Console.WriteLine( "Microsoft.SqlTools.ServiceLayer.TestDriver.exe [tests]" + Environment.NewLine + + " is the path to the Microsoft.SqlTools.ServiceLayer.exe executable" + Environment.NewLine + + " [tests] is a space-separated list of tests to run." + Environment.NewLine + + " They are qualified within the Microsoft.SqlTools.ServiceLayer.TestDriver.Tests namespace"); + Environment.Exit(0); + } + + Task.Run(async () => + { + var serviceHostExecutable = args[0]; + var tests = args.Skip(1); + + foreach (var test in tests) + { + ServiceTestDriver driver = null; + + try + { + driver = new ServiceTestDriver(serviceHostExecutable); + + var className = test.Substring(0, test.LastIndexOf('.')); + var methodName = test.Substring(test.LastIndexOf('.') + 1); + + var type = Type.GetType("Microsoft.SqlTools.ServiceLayer.TestDriver.Tests." + className); + var typeInstance = Activator.CreateInstance(type); + MethodInfo methodInfo = type.GetMethod(methodName); + + await driver.Start(); + Console.WriteLine("Running test " + test); + await (Task)methodInfo.Invoke(typeInstance, new object[] {driver}); + } + catch (Exception ex) + { + Console.WriteLine(ex.ToString()); + } + finally + { + if (driver != null) + { + await driver.Stop(); + } + } + } + }).Wait(); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/ExampleTests.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/ExampleTests.cs new file mode 100644 index 00000000..21fe7552 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/ExampleTests.cs @@ -0,0 +1,38 @@ +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.TestDriver.Driver; + +namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests +{ + public class ExampleTests + { + /// + /// Example test that performs a connect, then disconnect. + /// All tests must have the same signature of returning an async Task + /// and taking in a ServiceTestDriver as a parameter. + /// + public async Task ConnectDisconnectTest(ServiceTestDriver driver) + { + var connectParams = new ConnectParams(); + connectParams.OwnerUri = "file"; + connectParams.Connection = new ConnectionDetails(); + connectParams.Connection.ServerName = "localhost"; + connectParams.Connection.AuthenticationType = "Integrated"; + + var result = await driver.SendRequest(ConnectionRequest.Type, connectParams); + if (result) + { + await driver.WaitForEvent(ConnectionCompleteNotification.Type); + + var disconnectParams = new DisconnectParams(); + disconnectParams.OwnerUri = "file"; + var result2 = await driver.SendRequest(DisconnectRequest.Type, disconnectParams); + if (result2) + { + Console.WriteLine("success"); + } + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/project.json b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/project.json new file mode 100644 index 00000000..e2b39b6c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/project.json @@ -0,0 +1,29 @@ +{ + "name": "Microsoft.SqlTools.ServiceLayer.TestDriver", + "version": "1.0.0-*", + "buildOptions": { + "debugType": "portable", + "emitEntryPoint": true + }, + "dependencies": { + "Microsoft.SqlTools.ServiceLayer": { + "target": "project" + } + }, + "frameworks": { + "netcoreapp1.0": { + "dependencies": { + "Microsoft.NETCore.App": { + "version": "1.0.0" + } + }, + "imports": [ + "dotnet5.4", + "portable-net451+win8" + ], + } + }, + "runtimes": { + "win7-x64": {} + } +}