From 9e8e1df95ec8e04058fd04466fa60af70681c89f Mon Sep 17 00:00:00 2001 From: Aditya Bist Date: Thu, 23 Mar 2017 13:19:59 -0700 Subject: [PATCH] fixed peek definition tokenizing (#281) * fixed peek definition tokenizing * fixed signature help test * added new heuristic for PeekDefinition behaviour * added tests for new heuristic * fixed code according to Kevin's CR * fixed failing test due to shared connection * changed uri for procedure test --- .../LanguageServices/LanguageService.cs | 211 ++- .../LanguageServices/ScriptDocumentInfo.cs | 325 ++-- .../Scripting/ScripterCore.cs | 4 +- .../Workspace/Contracts/TextDocument.cs | 91 +- .../LanguageServer/PeekDefinitionTests.cs | 1579 +++++++++-------- 5 files changed, 1293 insertions(+), 917 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index c1d6d85d..05599bf9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -774,6 +774,100 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices return completionItem; } + + /// + /// Queue a task to the binding queue + /// + /// + /// + /// + /// + /// + /// Returns the result of the task as a DefinitionResult + private DefinitionResult QueueTask(TextDocumentPosition textDocumentPosition, ScriptParseInfo scriptParseInfo, + ConnectionInfo connInfo, ScriptFile scriptFile, string tokenText) + { + // Queue the task with the binding queue + QueueItem queueItem = this.BindingQueue.QueueBindingOperation( + key: scriptParseInfo.ConnectionKey, + bindingTimeout: LanguageService.PeekDefinitionTimeout, + bindOperation: (bindingContext, cancelToken) => + { + string schemaName = this.GetSchemaName(scriptParseInfo, textDocumentPosition.Position, scriptFile); + // Script object using SMO + Scripter scripter = new Scripter(bindingContext.ServerConnection, connInfo); + return scripter.GetScript( + scriptParseInfo.ParseResult, + textDocumentPosition.Position, + bindingContext.MetadataDisplayInfoProvider, + tokenText, + schemaName); + }, + timeoutOperation: (bindingContext) => + { + // return error result + return new DefinitionResult + { + IsErrorResult = true, + Message = SR.PeekDefinitionTimedoutError, + Locations = null + }; + }); + + // wait for the queue item + queueItem.ItemProcessed.WaitOne(); + var result = queueItem.GetResultAsT(); + return result; + } + + private DefinitionResult GetDefinitionFromTokenList(TextDocumentPosition textDocumentPosition, List tokenList, + ScriptParseInfo scriptParseInfo, ScriptFile scriptFile, ConnectionInfo connInfo) + { + + DefinitionResult lastResult = null; + foreach (var token in tokenList) + { + + // Strip "[" and "]"(if present) from the token text to enable matching with the suggestions. + // The suggestion title does not contain any sql punctuation + string tokenText = TextUtilities.RemoveSquareBracketSyntax(token.Text); + textDocumentPosition.Position.Line = token.StartLocation.LineNumber; + textDocumentPosition.Position.Character = token.StartLocation.ColumnNumber; + if (Monitor.TryEnter(scriptParseInfo.BuildingMetadataLock)) + { + try + { + var result = QueueTask(textDocumentPosition, scriptParseInfo, connInfo, scriptFile, tokenText); + lastResult = result; + if (!result.IsErrorResult) + { + return result; + } + } + catch (Exception ex) + { + // if any exceptions are raised return error result with message + Logger.Write(LogLevel.Error, "Exception in GetDefinition " + ex.ToString()); + return new DefinitionResult + { + IsErrorResult = true, + Message = SR.PeekDefinitionError(ex.Message), + Locations = null + }; + } + finally + { + Monitor.Exit(scriptParseInfo.BuildingMetadataLock); + } + } + else + { + Logger.Write(LogLevel.Error, "Timeout waiting to query metadata from server"); + } + } + return (lastResult != null) ? lastResult : null; + } + /// /// Get definition for a selected sql object using SMO Scripting /// @@ -796,76 +890,63 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } // Get token from selected text - Token selectedToken = ScriptDocumentInfo.GetToken(scriptParseInfo, textDocumentPosition.Position.Line + 1, textDocumentPosition.Position.Character); + Tuple, Queue> selectedToken = ScriptDocumentInfo.GetPeekDefinitionTokens(scriptParseInfo, + textDocumentPosition.Position.Line + 1, textDocumentPosition.Position.Character + 1); + if (selectedToken == null) { return null; } - // Strip "[" and "]"(if present) from the token text to enable matching with the suggestions. - // The suggestion title does not contain any sql punctuation - string tokenText = TextUtilities.RemoveSquareBracketSyntax(selectedToken.Text); - if (scriptParseInfo.IsConnected && Monitor.TryEnter(scriptParseInfo.BuildingMetadataLock)) - { - try - { - // Queue the task with the binding queue - QueueItem queueItem = this.BindingQueue.QueueBindingOperation( - key: scriptParseInfo.ConnectionKey, - bindingTimeout: LanguageService.PeekDefinitionTimeout, - bindOperation: (bindingContext, cancelToken) => - { - string schemaName = this.GetSchemaName(scriptParseInfo, textDocumentPosition.Position, scriptFile); - // Script object using SMO - Scripter scripter = new Scripter(bindingContext.ServerConnection, connInfo); - return scripter.GetScript( - scriptParseInfo.ParseResult, - textDocumentPosition.Position, - bindingContext.MetadataDisplayInfoProvider, - tokenText, - schemaName); - }, - timeoutOperation: (bindingContext) => - { - // return error result - return new DefinitionResult - { - IsErrorResult = true, - Message = SR.PeekDefinitionTimedoutError, - Locations = null - }; - }); - - // wait for the queue item - queueItem.ItemProcessed.WaitOne(); - return queueItem.GetResultAsT(); - } - catch (Exception ex) - { - // if any exceptions are raised return error result with message - Logger.Write(LogLevel.Error, "Exception in GetDefinition " + ex.ToString()); - return new DefinitionResult - { - IsErrorResult = true, - Message = SR.PeekDefinitionError(ex.Message), - Locations = null - }; - } - finally - { - Monitor.Exit(scriptParseInfo.BuildingMetadataLock); - } + if (scriptParseInfo.IsConnected) + { + //try children tokens first + Stack childrenTokens = selectedToken.Item1; + List tokenList = childrenTokens.ToList(); + DefinitionResult childrenResult = GetDefinitionFromTokenList(textDocumentPosition, tokenList, scriptParseInfo, scriptFile, connInfo); + + // if the children peak definition returned null then + // try the parents + if (childrenResult == null || childrenResult.IsErrorResult) + { + Queue parentTokens = selectedToken.Item2; + tokenList = parentTokens.ToList(); + DefinitionResult parentResult = GetDefinitionFromTokenList(textDocumentPosition, tokenList, scriptParseInfo, scriptFile, connInfo); + return (parentResult == null) ? null : parentResult; + } + else + { + return childrenResult; + } + } + else + { + // User is not connected. + return new DefinitionResult + { + IsErrorResult = true, + Message = SR.PeekDefinitionNotConnectedError, + Locations = null + }; } - else + } + + /// + /// Wrapper around find token method + /// + /// + /// + /// + /// token index + private int FindTokenWithCorrectOffset(ScriptParseInfo scriptParseInfo, int startLine, int startColumn) + { + var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); + var end = scriptParseInfo.ParseResult.Script.TokenManager.GetToken(tokenIndex).EndLocation; + if (end.LineNumber == startLine && end.ColumnNumber == startColumn) { - // User is not connected. - return new DefinitionResult - { - IsErrorResult = true, - Message = SR.PeekDefinitionNotConnectedError, - Locations = null - }; + return tokenIndex + 1; } + return tokenIndex; } /// @@ -878,13 +959,13 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices private string GetSchemaName(ScriptParseInfo scriptParseInfo, Position position, ScriptFile scriptFile) { // Offset index by 1 for sql parser - int startLine = position.Line + 1; - int startColumn = position.Character + 1; + int startLine = position.Line; + int startColumn = position.Character; // Get schema name if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null) { - var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); + var tokenIndex = FindTokenWithCorrectOffset(scriptParseInfo, startLine, startColumn); var prevTokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.GetPreviousSignificantTokenIndex(tokenIndex); var prevTokenText = scriptParseInfo.ParseResult.Script.TokenManager.GetText(prevTokenIndex); if (prevTokenText != null && prevTokenText.Equals(".")) diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptDocumentInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptDocumentInfo.cs index 9a1ae4a7..f46480e2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptDocumentInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptDocumentInfo.cs @@ -1,133 +1,192 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using Microsoft.SqlServer.Management.SqlParser.Parser; -using Microsoft.SqlTools.Utility; -using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; - -namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion -{ - /// - /// A class to calculate the numbers used by SQL parser using the text positions and content - /// - internal class ScriptDocumentInfo - { - /// - /// Create new instance - /// - public ScriptDocumentInfo(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile, ScriptParseInfo scriptParseInfo) - : this(textDocumentPosition, scriptFile) - { - Validate.IsNotNull(nameof(scriptParseInfo), scriptParseInfo); - - ScriptParseInfo = scriptParseInfo; - // need to adjust line & column for base-1 parser indices - Token = GetToken(scriptParseInfo, ParserLine, ParserColumn); - } - - private ScriptDocumentInfo(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile) - { - StartLine = textDocumentPosition.Position.Line; - ParserLine = textDocumentPosition.Position.Line + 1; - StartColumn = TextUtilities.PositionOfPrevDelimeter( - scriptFile.Contents, - textDocumentPosition.Position.Line, - textDocumentPosition.Position.Character); - EndColumn = TextUtilities.PositionOfNextDelimeter( - scriptFile.Contents, - textDocumentPosition.Position.Line, - textDocumentPosition.Position.Character); - ParserColumn = textDocumentPosition.Position.Character + 1; - Contents = scriptFile.Contents; - } - - /// - /// Creates a new with no backing defined - /// - /// A - /// A to process - /// - public static ScriptDocumentInfo CreateDefaultDocumentInfo(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile) - { - return new ScriptDocumentInfo(textDocumentPosition, scriptFile); - } - - /// - /// Gets a string containing the full contents of the file. - /// - public string Contents { get; private set; } - - /// - /// Script Parse Info Instance - /// - public ScriptParseInfo ScriptParseInfo { get; private set; } - - /// - /// Start Line - /// - public int StartLine { get; private set; } - - /// - /// Parser Line - /// - public int ParserLine { get; private set; } - - /// - /// Start Column - /// - public int StartColumn { get; private set; } - - /// - /// end Column - /// - public int EndColumn { get; private set; } - - /// - /// Parser Column - /// - public int ParserColumn { get; private set; } - - /// - /// The token text in the file content used for completion list - /// - public virtual string TokenText - { - get - { - return Token != null ? Token.Text : null; - } - } - - /// - /// The token in the file content used for completion list - /// - public Token Token { get; private set; } - - /// - /// Returns the token that will be used by SQL parser for creating the completion list - /// - internal static Token GetToken(ScriptParseInfo scriptParseInfo, int startLine, int startColumn) - { - if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null) - { - var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); - if (tokenIndex >= 0) - { - // return the current token - int currentIndex = 0; - foreach (var token in scriptParseInfo.ParseResult.Script.Tokens) - { - if (currentIndex == tokenIndex) - { - return token; - } - ++currentIndex; - } - } - } - return null; - } - } -} +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using Microsoft.SqlServer.Management.SqlParser.Parser; +using Microsoft.SqlTools.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Completion +{ + /// + /// A class to calculate the numbers used by SQL parser using the text positions and content + /// + internal class ScriptDocumentInfo + { + /// + /// Create new instance + /// + public ScriptDocumentInfo(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile, ScriptParseInfo scriptParseInfo) + : this(textDocumentPosition, scriptFile) + { + Validate.IsNotNull(nameof(scriptParseInfo), scriptParseInfo); + + ScriptParseInfo = scriptParseInfo; + // need to adjust line & column for base-1 parser indices + Token = GetToken(scriptParseInfo, ParserLine, ParserColumn); + } + + private ScriptDocumentInfo(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile) + { + StartLine = textDocumentPosition.Position.Line; + ParserLine = textDocumentPosition.Position.Line + 1; + StartColumn = TextUtilities.PositionOfPrevDelimeter( + scriptFile.Contents, + textDocumentPosition.Position.Line, + textDocumentPosition.Position.Character); + EndColumn = TextUtilities.PositionOfNextDelimeter( + scriptFile.Contents, + textDocumentPosition.Position.Line, + textDocumentPosition.Position.Character); + ParserColumn = textDocumentPosition.Position.Character + 1; + Contents = scriptFile.Contents; + } + + /// + /// Creates a new with no backing defined + /// + /// A + /// A to process + /// + public static ScriptDocumentInfo CreateDefaultDocumentInfo(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile) + { + return new ScriptDocumentInfo(textDocumentPosition, scriptFile); + } + + /// + /// Gets a string containing the full contents of the file. + /// + public string Contents { get; private set; } + + /// + /// Script Parse Info Instance + /// + public ScriptParseInfo ScriptParseInfo { get; private set; } + + /// + /// Start Line + /// + public int StartLine { get; private set; } + + /// + /// Parser Line + /// + public int ParserLine { get; private set; } + + /// + /// Start Column + /// + public int StartColumn { get; private set; } + + /// + /// end Column + /// + public int EndColumn { get; private set; } + + /// + /// Parser Column + /// + public int ParserColumn { get; private set; } + + /// + /// The token text in the file content used for completion list + /// + public virtual string TokenText + { + get + { + return Token != null ? Token.Text : null; + } + } + + /// + /// The token in the file content used for completion list + /// + public Token Token { get; private set; } + + /// + /// Returns the token that will be used by SQL parser for creating the completion list + /// + internal static Token GetToken(ScriptParseInfo scriptParseInfo, int startLine, int startColumn) + { + if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null) + { + var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); + if (tokenIndex >= 0) + { + // return the current token + int currentIndex = 0; + foreach (var token in scriptParseInfo.ParseResult.Script.Tokens) + { + if (currentIndex == tokenIndex) + { + return token; + } + ++currentIndex; + } + } + } + return null; + } + + /// + /// Returns the token that is used for Peek Definition objects + /// + internal static Tuple, Queue> GetPeekDefinitionTokens(ScriptParseInfo scriptParseInfo, int startLine, int startColumn) + { + Stack childrenTokens = new Stack(); + Queue parentTokens = new Queue(); + if (scriptParseInfo != null + && scriptParseInfo.ParseResult != null + && scriptParseInfo.ParseResult.Script != null + && scriptParseInfo.ParseResult.Script.Tokens != null) + { + var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); + if (tokenIndex >= 0) + { + // return the current token and the ones to its right + // until we hit a whitespace token + int currentIndex = 0; + foreach (var token in scriptParseInfo.ParseResult.Script.Tokens) + { + if (currentIndex == tokenIndex) + { + // push all parent tokens until we hit whitespace + int parentIndex = currentIndex; + while (true) + { + if (scriptParseInfo.ParseResult.Script.TokenManager.GetToken(parentIndex).Type != "LEX_WHITE") + { + parentTokens.Enqueue(scriptParseInfo.ParseResult.Script.TokenManager.GetToken(parentIndex)); + parentIndex--; + } + else + { + break; + } + } + } + else if (currentIndex > tokenIndex) + { + // push all children tokens until we hit whitespace + if (scriptParseInfo.ParseResult.Script.TokenManager.GetToken(currentIndex).Type != "LEX_WHITE") + { + childrenTokens.Push(token); + } + else + { + break; + } + } + ++currentIndex; + } + return Tuple.Create(childrenTokens, parentTokens); + } + } + return null; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs index 10fa5c87..4faec400 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs @@ -133,8 +133,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting /// Location object of the script file internal DefinitionResult GetScript(ParseResult parseResult, Position position, IMetadataDisplayInfoProvider metadataDisplayInfoProvider, string tokenText, string schemaName) { - int parserLine = position.Line + 1; - int parserColumn = position.Character + 1; + int parserLine = position.Line; + int parserColumn = position.Character; // Get DeclarationItems from The Intellisense Resolver for the selected token. The type of the selected token is extracted from the declarationItem. IEnumerable declarationItems = GetCompletionsForToken(parseResult, parserLine, parserColumn, metadataDisplayInfoProvider); if (declarationItems != null && declarationItems.Count() > 0) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/TextDocument.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/TextDocument.cs index 362b044a..875fa8c6 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/TextDocument.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/TextDocument.cs @@ -156,7 +156,36 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts /// /// Gets or sets the zero-based column number. /// - public int Character { get; set; } + public int Character { get; set; } + + /// + /// Overrides the base equality method + /// + /// + /// + public override bool Equals(object obj) + { + if (obj == null || (obj as Position == null)) + { + return false; + } + Position p = (Position) obj; + bool result = (Line == p.Line) && (Character == p.Character); + return result; + } + + + /// + /// Overrides the base GetHashCode method + /// + /// + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + Line.GetHashCode(); + hash = hash * 23 + Character.GetHashCode(); + return hash; + } } [DebuggerDisplay("Start = {Start.Line}:{Start.Character}, End = {End.Line}:{End.Character}")] @@ -171,6 +200,37 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts /// Gets or sets the ending position of the range. /// public Position End { get; set; } + + /// + /// Overrides the base equality method + /// + /// + /// + public override bool Equals(object obj) + { + + + if (obj == null || !(obj is Range)) + { + return false; + } + Range range = (Range) obj; + bool sameStart = range.Start.Equals(Start); + bool sameEnd = range.End.Equals(End); + return (sameStart && sameEnd); + } + + /// + /// Overrides the base GetHashCode method + /// + /// + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + Start.GetHashCode(); + hash = hash * 23 + End.GetHashCode(); + return hash; + } } [DebuggerDisplay("Range = {Range.Start.Line}:{Range.Start.Character} - {Range.End.Line}:{Range.End.Character}, Uri = {Uri}")] @@ -185,6 +245,35 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts /// Gets or sets the Range indicating the range in which location refers. /// public Range Range { get; set; } + + /// + /// Overrides the base equality method + /// + /// + /// + public override bool Equals(object obj) + { + if (obj == null || (obj as Location == null)) + { + return false; + } + Location loc = (Location)obj; + bool sameUri = string.Equals(loc.Uri, Uri); + bool sameRange = loc.Range.Equals(Range); + return (sameUri && sameRange); + } + + /// + /// Overrides the base GetHashCode method + /// + /// + public override int GetHashCode() + { + int hash = 17; + hash = hash * 23 + Uri.GetHashCode(); + hash = hash * 23 + Range.GetHashCode(); + return hash; + } } public enum FileChangeType diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs index 1dfca799..34a01c69 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/LanguageServer/PeekDefinitionTests.cs @@ -1,716 +1,863 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// -using Microsoft.SqlServer.Management.Common; -using Microsoft.SqlServer.Management.SqlParser.Intellisense; -using Microsoft.SqlTools.ServiceLayer.Connection; -using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; -using Microsoft.SqlTools.ServiceLayer.LanguageServices; -using Microsoft.SqlTools.ServiceLayer.Scripting; -using Microsoft.SqlTools.ServiceLayer.Test.Common; -using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -using Moq; -using System; -using System.Data.Common; -using System.IO; -using System.Threading; -using Xunit; -using ConnectionType = Microsoft.SqlTools.ServiceLayer.Connection.ConnectionType; -using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location; - -namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.LanguageServices -{ - /// - /// Tests for the language service peek definition/ go to definition feature - /// - public class PeekDefinitionTests - { - private const string OwnerUri = "testFile1"; - private const string ReturnTableFunctionName = "pd_returnTable"; - private const string ReturnTableTableFunctionQuery = @" -CREATE FUNCTION [dbo].[" + ReturnTableFunctionName + @"] () -RETURNS TABLE -AS -RETURN -( - select * from master.dbo.spt_monitor -); - -GO"; - - private const string AddTwoFunctionName = "pd_addTwo"; - private const string AddTwoFunctionQuery = @" -CREATE FUNCTION[dbo].[" + AddTwoFunctionName + @"](@number int) -RETURNS int -AS -BEGIN - RETURN @number + 2; - END; - -GO"; - - - private const string SsnTypeName = "pd_ssn"; - private const string SsnTypeQuery = @" -CREATE TYPE [dbo].[" + SsnTypeName + @"] FROM [varchar](11) NOT NULL -GO"; - - private const string LocationTableTypeName = "pd_locationTableType"; - - private const string LocationTableTypeQuery = @" -CREATE TYPE [dbo].[" + LocationTableTypeName + @"] AS TABLE( - [LocationName] [varchar](50) NULL, - [CostRate] [int] NULL -) -GO"; - - private const string TestTableSynonymName = "pd_testTable"; - private const string TestTableSynonymQuery = @" -CREATE SYNONYM [dbo].[pd_testTable] FOR master.dbo.spt_monitor -GO"; - - private const string TableValuedFunctionTypeName = "TableValuedFunction"; - private const string ScalarValuedFunctionTypeName = "ScalarValuedFunction"; - private const string UserDefinedDataTypeTypeName = "UserDefinedDataType"; - private const string UserDefinedTableTypeTypeName = "UserDefinedTableType"; - private const string SynonymTypeName = "Synonym"; - private const string StoredProcedureTypeName = "StoredProcedure"; - private const string ViewTypeName = "View"; - private const string TableTypeName = "Table"; - - /// - /// Test get definition for a table object with active connection - /// - [Fact] - public void GetValidTableDefinitionTest() - { - // Get live connectionInfo and serverConnection - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "spt_monitor"; - - string schemaName = null; - string objectType = "TABLE"; - - // Get locations for valid table object - Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetTableScripts, objectName, schemaName, objectType); - Assert.NotNull(locations); - Cleanup(locations); - } - - /// - /// Test get definition for a invalid table object with active connection - /// - [Fact] - public void GetTableDefinitionInvalidObjectTest() - { - // Get live connectionInfo and serverConnection - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "test_invalid"; - string schemaName = null; - string objectType = "TABLE"; - - // Get locations for invalid table object - Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetTableScripts, objectName, schemaName, objectType); - Assert.Null(locations); - } - - /// - /// Test get definition for a valid table object with schema and active connection - /// - [Fact] - public void GetTableDefinitionWithSchemaTest() - { - // Get live connectionInfo and serverConnection - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "spt_monitor"; - - string schemaName = "dbo"; - string objectType = "TABLE"; - - // Get locations for valid table object with schema name - Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetTableScripts, objectName, schemaName, objectType); - Assert.NotNull(locations); - Cleanup(locations); - } - - /// - /// Test GetDefinition with an unsupported type(schema - dbo). Expect a error result. - /// - [Fact] - public void GetUnsupportedDefinitionErrorTest() - { - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "objects"; - string schemaName = "sys"; - // When I try to get definition for 'Collation' - DefinitionResult result = scripter.GetDefinitionUsingDeclarationType(DeclarationType.Collation, "master.sys.objects", objectName, schemaName); - // Then I expect non null result with error flag set - Assert.NotNull(result); - Assert.True(result.IsErrorResult); - } - - /// - /// Get Definition for a object with no definition. Expect a error result - /// - [Fact] - public void GetDefinitionWithNoResultsFoundError() - { - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "from"; - Position position = new Position() - { - Line = 1, - Character = 14 - }; - ScriptParseInfo scriptParseInfo = new ScriptParseInfo() { IsConnected = true }; - Mock bindingContextMock = new Mock(); - DefinitionResult result = scripter.GetScript(scriptParseInfo.ParseResult, position, bindingContextMock.Object.MetadataDisplayInfoProvider, objectName, null); - - Assert.NotNull(result); - Assert.True(result.IsErrorResult); - Assert.Equal(SR.PeekDefinitionNoResultsError, result.Message); - } - - /// - /// Test GetDefinition with a forced timeout. Expect a error result. - /// - [Fact] - public void GetDefinitionTimeoutTest() - { - // Given a binding queue that will automatically time out - var languageService = new LanguageService(); - Mock queueMock = new Mock(); - languageService.BindingQueue = queueMock.Object; - ManualResetEvent mre = new ManualResetEvent(true); // Do not block - Mock itemMock = new Mock(); - itemMock.Setup(i => i.ItemProcessed).Returns(mre); - - DefinitionResult timeoutResult = null; - - queueMock.Setup(q => q.QueueBindingOperation( - It.IsAny(), - It.IsAny>(), - It.IsAny>(), - It.IsAny(), - It.IsAny())) - .Callback, Func, int?, int?>( - (key, bindOperation, timeoutOperation, t1, t2) => - { - timeoutResult = (DefinitionResult)timeoutOperation((IBindingContext)null); - itemMock.Object.Result = timeoutResult; - }) - .Returns(() => itemMock.Object); - - TextDocumentPosition textDocument = new TextDocumentPosition - { - TextDocument = new TextDocumentIdentifier { Uri = OwnerUri }, - Position = new Position - { - Line = 0, - Character = 20 - } - }; - LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(); - ScriptFile scriptFile = connectionResult.ScriptFile; - ConnectionInfo connInfo = connectionResult.ConnectionInfo; - scriptFile.Contents = "select * from dbo.func ()"; - - ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; - languageService.ScriptParseInfoMap.Add(OwnerUri, scriptInfo); - - // When I call the language service - var result = languageService.GetDefinition(textDocument, scriptFile, connInfo); - - // Then I expect null locations and an error to be reported - Assert.NotNull(result); - Assert.True(result.IsErrorResult); - // Check timeout message - Assert.Equal(SR.PeekDefinitionTimedoutError, result.Message); - } - - /// - /// Test get definition for a view object with active connection - /// - [Fact] - public void GetValidViewDefinitionTest() - { - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "objects"; - string schemaName = "sys"; - string objectType = "VIEW"; - - Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetViewScripts, objectName, schemaName, objectType); - Assert.NotNull(locations); - Cleanup(locations); - } - - /// - /// Test get definition for an invalid view object with no schema name and with active connection - /// - [Fact] - public void GetViewDefinitionInvalidObjectTest() - { - // Get live connectionInfo and serverConnection - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "objects"; - string schemaName = null; - string objectType = "VIEW"; - - Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetViewScripts, objectName, schemaName, objectType); - Assert.Null(locations); - } - - /// - /// Test get definition for a stored procedure object with active connection - /// - [Fact] - public void GetStoredProcedureDefinitionTest() - { - // Get live connectionInfo and serverConnection - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "sp_MSrepl_startup"; - - string schemaName = "dbo"; - string objectType = "PROCEDURE"; - - Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetStoredProcedureScripts, objectName, schemaName, objectType); - Assert.NotNull(locations); - Cleanup(locations); - } - - /// - /// Test get definition for a stored procedure object that does not exist with active connection - /// - [Fact] - public void GetStoredProcedureDefinitionFailureTest() - { - // Get live connectionInfo and serverConnection - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "SP2"; - string schemaName = "dbo"; - string objectType = "PROCEDURE"; - - Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetStoredProcedureScripts, objectName, schemaName, objectType); - Assert.Null(locations); - } - - /// - /// Test get definition for a stored procedure object with active connection and no schema - /// - [Fact] - public void GetStoredProcedureDefinitionWithoutSchemaTest() - { - // Get live connectionInfo and serverConnection - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "sp_MSrepl_startup"; - string schemaName = null; - string objectType = "PROCEDURE"; - - Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetStoredProcedureScripts, objectName, schemaName, objectType); - Assert.NotNull(locations); - Cleanup(locations); - } - - /// - /// Test get definition for a scalar valued function object with active connection and explicit schema name. Expect non-null locations - /// - [Fact] - public void GetScalarValuedFunctionDefinitionWithSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(AddTwoFunctionQuery, AddTwoFunctionName, ScalarValuedFunctionTypeName); - } - - private void ExecuteAndValidatePeekTest(string query, string objectName, string objectType, string schemaName = "dbo") - { - if (!string.IsNullOrEmpty(query)) - { - using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, query)) - { - ValidatePeekTest(testDb.DatabaseName, objectName, objectType, schemaName, true); - } - } - else - { - ValidatePeekTest(null, objectName, objectType, schemaName, false); - } - } - - private void ValidatePeekTest(string databaseName, string objectName, string objectType, string schemaName, bool shouldReturnValidResult) - { - // Get live connectionInfo and serverConnection - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(databaseName); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - - Scripter.ScriptGetter sqlScriptGetter = null; - switch (objectType) - { - case SynonymTypeName: - sqlScriptGetter = scripter.GetSynonymScripts; - break; - case ScalarValuedFunctionTypeName: - sqlScriptGetter = scripter.GetScalarValuedFunctionScripts; - objectType = "Function"; - break; - case TableValuedFunctionTypeName: - sqlScriptGetter = scripter.GetTableValuedFunctionScripts; - objectType = "Function"; - break; - case TableTypeName: - sqlScriptGetter = scripter.GetTableScripts; - break; - case ViewTypeName: - sqlScriptGetter = scripter.GetViewScripts; - break; - case StoredProcedureTypeName: - sqlScriptGetter = scripter.GetStoredProcedureScripts; - break; - case UserDefinedDataTypeTypeName: - sqlScriptGetter = scripter.GetUserDefinedDataTypeScripts; - objectType = "Type"; - break; - case UserDefinedTableTypeTypeName: - sqlScriptGetter = scripter.GetUserDefinedTableTypeScripts; - objectType = "Type"; - break; - } - - Location[] locations = scripter.GetSqlObjectDefinition(sqlScriptGetter, objectName, schemaName, objectType); - if (shouldReturnValidResult) - { - Assert.NotNull(locations); - Cleanup(locations); - } - else - { - Assert.Null(locations); - } - } - - /// - /// Test get definition for a table valued function object with active connection and explicit schema name. Expect non-null locations - /// - [Fact] - public void GetTableValuedFunctionDefinitionWithSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(ReturnTableTableFunctionQuery, ReturnTableFunctionName, TableValuedFunctionTypeName); - } - - /// - /// Test get definition for a scalar valued function object that doesn't exist with active connection. Expect null locations - /// - [Fact] - public void GetScalarValuedFunctionDefinitionWithNonExistentFailureTest() - { - string objectName = "doesNotExist"; - string schemaName = "dbo"; - string objectType = ScalarValuedFunctionTypeName; - - ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); - } - - /// - /// Test get definition for a table valued function object that doesn't exist with active connection. Expect null locations - /// - [Fact] - public void GetTableValuedFunctionDefinitionWithNonExistentObjectFailureTest() - { - string objectName = "doesNotExist"; - string schemaName = "dbo"; - string objectType = TableValuedFunctionTypeName; - ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); - } - - /// - /// Test get definition for a scalar valued function object with active connection. Expect non-null locations - /// - [Fact] - public void GetScalarValuedFunctionDefinitionWithoutSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(AddTwoFunctionQuery, AddTwoFunctionName, ScalarValuedFunctionTypeName, null); - } - - /// - /// Test get definition for a table valued function object with active connection. Expect non-null locations - /// - [Fact] - public void GetTableValuedFunctionDefinitionWithoutSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(ReturnTableTableFunctionQuery, ReturnTableFunctionName, TableValuedFunctionTypeName, null); - } - - - /// - /// Test get definition for a user defined data type object with active connection and explicit schema name. Expect non-null locations - /// - [Fact] - public void GetUserDefinedDataTypeDefinitionWithSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(SsnTypeQuery, SsnTypeName, UserDefinedDataTypeTypeName); - } - - /// - /// Test get definition for a user defined data type object with active connection. Expect non-null locations - /// - [Fact] - public void GetUserDefinedDataTypeDefinitionWithoutSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(SsnTypeQuery, SsnTypeName, UserDefinedDataTypeTypeName, null); - } - - /// - /// Test get definition for a user defined data type object that doesn't exist with active connection. Expect null locations - /// - [Fact] - public void GetUserDefinedDataTypeDefinitionWithNonExistentFailureTest() - { - string objectName = "doesNotExist"; - string schemaName = "dbo"; - string objectType = UserDefinedDataTypeTypeName; - ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); - } - - /// - /// Test get definition for a user defined table type object with active connection and explicit schema name. Expect non-null locations - /// - [Fact] - public void GetUserDefinedTableTypeDefinitionWithSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(LocationTableTypeQuery, LocationTableTypeName, UserDefinedTableTypeTypeName); - } - - /// - /// Test get definition for a user defined table type object with active connection. Expect non-null locations - /// - [Fact] - public void GetUserDefinedTableTypeDefinitionWithoutSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(LocationTableTypeQuery, LocationTableTypeName, UserDefinedTableTypeTypeName, null); - } - - /// - /// Test get definition for a user defined table type object that doesn't exist with active connection. Expect null locations - /// - [Fact] - public void GetUserDefinedTableTypeDefinitionWithNonExistentFailureTest() - { - string objectName = "doesNotExist"; - string schemaName = "dbo"; - string objectType = UserDefinedTableTypeTypeName; - ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); - - } - - /// - /// Test get definition for a synonym object with active connection and explicit schema name. Expect non-null locations - /// - [Fact] - public void GetSynonymDefinitionWithSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(TestTableSynonymQuery, TestTableSynonymName, SynonymTypeName); - } - - - /// - /// Test get definition for a Synonym object with active connection. Expect non-null locations - /// - [Fact] - public void GetSynonymDefinitionWithoutSchemaNameSuccessTest() - { - ExecuteAndValidatePeekTest(TestTableSynonymQuery, TestTableSynonymName, SynonymTypeName, null); - } - - /// - /// Test get definition for a Synonym object that doesn't exist with active connection. Expect null locations - /// - [Fact] - public void GetSynonymDefinitionWithNonExistentFailureTest() - { - string objectName = "doesNotExist"; - string schemaName = "dbo"; - string objectType = "Synonym"; - ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); - } - - /// - /// Test get definition using declaration type for a view object with active connection - /// Expect a non-null result with location - /// - [Fact] - public void GetDefinitionUsingDeclarationTypeWithValidObjectTest() - { - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "objects"; - string schemaName = "sys"; - - DefinitionResult result = scripter.GetDefinitionUsingDeclarationType(DeclarationType.View, "master.sys.objects", objectName, schemaName); - Assert.NotNull(result); - Assert.NotNull(result.Locations); - Assert.False(result.IsErrorResult); - Cleanup(result.Locations); - - } - - /// - /// Test get definition using declaration type for a non existent view object with active connection - /// Expect a non-null result with location - /// - [Fact] - public void GetDefinitionUsingDeclarationTypeWithNonexistentObjectTest() - { - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "doesNotExist"; - string schemaName = "sys"; - - DefinitionResult result = scripter.GetDefinitionUsingDeclarationType(DeclarationType.View, "master.sys.objects", objectName, schemaName); - Assert.NotNull(result); - Assert.True(result.IsErrorResult); - } - - /// - /// Test get definition using quickInfo text for a view object with active connection - /// Expect a non-null result with location - /// - [Fact] - public void GetDefinitionUsingQuickInfoTextWithValidObjectTest() - { - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "objects"; - string schemaName = "sys"; - string quickInfoText = "view master.sys.objects"; - - DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, schemaName); - Assert.NotNull(result); - Assert.NotNull(result.Locations); - Assert.False(result.IsErrorResult); - Cleanup(result.Locations); - - } - - /// - /// Test get definition using quickInfo text for a view object with active connection - /// Expect a non-null result with location - /// - [Fact] - public void GetDefinitionUsingQuickInfoTextWithNonexistentObjectTest() - { - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - - Scripter scripter = new Scripter(serverConnection, connInfo); - string objectName = "doesNotExist"; - string schemaName = "sys"; - string quickInfoText = "view master.sys.objects"; - - DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, schemaName); - Assert.NotNull(result); - Assert.True(result.IsErrorResult); - } - - /// - /// Test if peek definition default database name is the default server connection database name - /// Given that there is no query connection - /// Expect database name to be "master" - /// - [Fact] - public void GetDatabaseWithNoQueryConnectionTest() - { - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - DbConnection connection; - //Check if query connection is present - Assert.False(connInfo.TryGetConnection(ConnectionType.Query, out connection)); - - Scripter scripter = new Scripter(serverConnection, connInfo); - //Check if database name is the default server connection database name - Assert.Equal(scripter.Database.Name, "master"); - } - - /// - /// Test if the peek definition database name changes to the query connection database name - /// Give that there is a query connection - /// Expect database name to be query connection's database name - /// - [Fact] - public void GetDatabaseWithQueryConnectionTest() - { - ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); - ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); - //Mock a query connection object - var mockQueryConnection = new Mock { CallBase = true }; - mockQueryConnection.SetupGet(x => x.Database).Returns("testdb"); - connInfo.ConnectionTypeToConnectionMap[ConnectionType.Query] = mockQueryConnection.Object; - DbConnection connection; - //Check if query connection is present - Assert.True(connInfo.TryGetConnection(ConnectionType.Query, out connection)); - - Scripter scripter = new Scripter(serverConnection, connInfo); - //Check if database name is the database name in the query connection - Assert.Equal(scripter.Database.Name, "testdb"); - - // remove mock from ConnectionInfo - Assert.True(connInfo.ConnectionTypeToConnectionMap.TryRemove(ConnectionType.Query, out connection)); - } - - - /// - /// Helper method to clean up script files - /// - private void Cleanup(Location[] locations) - { - Uri fileUri = new Uri(locations[0].Uri); - if (File.Exists(fileUri.LocalPath)) - { - try - { - File.Delete(fileUri.LocalPath); - } - catch (Exception) - { - - } - } - } - } -} +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.Scripting; +using Microsoft.SqlTools.ServiceLayer.Test.Common; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Moq; +using System; +using System.Data.Common; +using System.IO; +using System.Threading; +using Xunit; +using ConnectionType = Microsoft.SqlTools.ServiceLayer.Connection.ConnectionType; +using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location; +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.LanguageServices +{ + /// + /// Tests for the language service peek definition/ go to definition feature + /// + public class PeekDefinitionTests + { + private const string OwnerUri = "testFile1"; + private const string TestUri = "testFile2"; + private const string ReturnTableFunctionName = "pd_returnTable"; + private const string ReturnTableTableFunctionQuery = @" +CREATE FUNCTION [dbo].[" + ReturnTableFunctionName + @"] () +RETURNS TABLE +AS +RETURN +( + select * from master.dbo.spt_monitor +); + +GO"; + + private const string AddTwoFunctionName = "pd_addTwo"; + private const string AddTwoFunctionQuery = @" +CREATE FUNCTION[dbo].[" + AddTwoFunctionName + @"](@number int) +RETURNS int +AS +BEGIN + RETURN @number + 2; + END; + +GO"; + + + private const string SsnTypeName = "pd_ssn"; + private const string SsnTypeQuery = @" +CREATE TYPE [dbo].[" + SsnTypeName + @"] FROM [varchar](11) NOT NULL +GO"; + + private const string LocationTableTypeName = "pd_locationTableType"; + + private const string LocationTableTypeQuery = @" +CREATE TYPE [dbo].[" + LocationTableTypeName + @"] AS TABLE( + [LocationName] [varchar](50) NULL, + [CostRate] [int] NULL +) +GO"; + + private const string TestTableSynonymName = "pd_testTable"; + private const string TestTableSynonymQuery = @" +CREATE SYNONYM [dbo].[pd_testTable] FOR master.dbo.spt_monitor +GO"; + + private const string TableValuedFunctionTypeName = "TableValuedFunction"; + private const string ScalarValuedFunctionTypeName = "ScalarValuedFunction"; + private const string UserDefinedDataTypeTypeName = "UserDefinedDataType"; + private const string UserDefinedTableTypeTypeName = "UserDefinedTableType"; + private const string SynonymTypeName = "Synonym"; + private const string StoredProcedureTypeName = "StoredProcedure"; + private const string ViewTypeName = "View"; + private const string TableTypeName = "Table"; + + /// + /// Test get definition for a table object with active connection + /// + [Fact] + public void GetValidTableDefinitionTest() + { + // Get live connectionInfo and serverConnection + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "spt_monitor"; + + string schemaName = null; + string objectType = "TABLE"; + + // Get locations for valid table object + Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetTableScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test get definition for a invalid table object with active connection + /// + [Fact] + public void GetTableDefinitionInvalidObjectTest() + { + // Get live connectionInfo and serverConnection + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "test_invalid"; + string schemaName = null; + string objectType = "TABLE"; + + // Get locations for invalid table object + Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetTableScripts, objectName, schemaName, objectType); + Assert.Null(locations); + } + + /// + /// Test get definition for a valid table object with schema and active connection + /// + [Fact] + public void GetTableDefinitionWithSchemaTest() + { + // Get live connectionInfo and serverConnection + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "spt_monitor"; + + string schemaName = "dbo"; + string objectType = "TABLE"; + + // Get locations for valid table object with schema name + Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetTableScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test GetDefinition with an unsupported type(schema - dbo). Expect a error result. + /// + [Fact] + public void GetUnsupportedDefinitionErrorTest() + { + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "objects"; + string schemaName = "sys"; + // When I try to get definition for 'Collation' + DefinitionResult result = scripter.GetDefinitionUsingDeclarationType(DeclarationType.Collation, "master.sys.objects", objectName, schemaName); + // Then I expect non null result with error flag set + Assert.NotNull(result); + Assert.True(result.IsErrorResult); + } + + /// + /// Get Definition for a object with no definition. Expect a error result + /// + [Fact] + public void GetDefinitionWithNoResultsFoundError() + { + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "from"; + Position position = new Position() + { + Line = 1, + Character = 14 + }; + ScriptParseInfo scriptParseInfo = new ScriptParseInfo() { IsConnected = true }; + Mock bindingContextMock = new Mock(); + DefinitionResult result = scripter.GetScript(scriptParseInfo.ParseResult, position, bindingContextMock.Object.MetadataDisplayInfoProvider, objectName, null); + + Assert.NotNull(result); + Assert.True(result.IsErrorResult); + Assert.Equal(SR.PeekDefinitionNoResultsError, result.Message); + } + + /// + /// Test GetDefinition with a forced timeout. Expect a error result. + /// + [Fact] + public void GetDefinitionTimeoutTest() + { + // Given a binding queue that will automatically time out + var languageService = new LanguageService(); + Mock queueMock = new Mock(); + languageService.BindingQueue = queueMock.Object; + ManualResetEvent mre = new ManualResetEvent(true); // Do not block + Mock itemMock = new Mock(); + itemMock.Setup(i => i.ItemProcessed).Returns(mre); + + DefinitionResult timeoutResult = null; + + queueMock.Setup(q => q.QueueBindingOperation( + It.IsAny(), + It.IsAny>(), + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, Func, int?, int?>( + (key, bindOperation, timeoutOperation, t1, t2) => + { + timeoutResult = (DefinitionResult)timeoutOperation((IBindingContext)null); + itemMock.Object.Result = timeoutResult; + }) + .Returns(() => itemMock.Object); + + TextDocumentPosition textDocument = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier { Uri = OwnerUri }, + Position = new Position + { + Line = 0, + Character = 20 + } + }; + LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(); + ScriptFile scriptFile = connectionResult.ScriptFile; + ConnectionInfo connInfo = connectionResult.ConnectionInfo; + scriptFile.Contents = "select * from dbo.func ()"; + + ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; + languageService.ScriptParseInfoMap.Add(OwnerUri, scriptInfo); + + // When I call the language service + var result = languageService.GetDefinition(textDocument, scriptFile, connInfo); + + // Then I expect null locations and an error to be reported + Assert.NotNull(result); + Assert.True(result.IsErrorResult); + // Check timeout message + Assert.Equal(SR.PeekDefinitionTimedoutError, result.Message); + } + + /// + /// Test get definition for a view object with active connection + /// + [Fact] + public void GetValidViewDefinitionTest() + { + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "objects"; + string schemaName = "sys"; + string objectType = "VIEW"; + + Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetViewScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test get definition for an invalid view object with no schema name and with active connection + /// + [Fact] + public void GetViewDefinitionInvalidObjectTest() + { + // Get live connectionInfo and serverConnection + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "objects"; + string schemaName = null; + string objectType = "VIEW"; + + Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetViewScripts, objectName, schemaName, objectType); + Assert.Null(locations); + } + + /// + /// Test get definition for a stored procedure object with active connection + /// + [Fact] + public void GetStoredProcedureDefinitionTest() + { + // Get live connectionInfo and serverConnection + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "sp_MSrepl_startup"; + + string schemaName = "dbo"; + string objectType = "PROCEDURE"; + + Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetStoredProcedureScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test get definition for a stored procedure object that does not exist with active connection + /// + [Fact] + public void GetStoredProcedureDefinitionFailureTest() + { + // Get live connectionInfo and serverConnection + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "SP2"; + string schemaName = "dbo"; + string objectType = "PROCEDURE"; + + Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetStoredProcedureScripts, objectName, schemaName, objectType); + Assert.Null(locations); + } + + /// + /// Test get definition for a stored procedure object with active connection and no schema + /// + [Fact] + public void GetStoredProcedureDefinitionWithoutSchemaTest() + { + // Get live connectionInfo and serverConnection + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "sp_MSrepl_startup"; + string schemaName = null; + string objectType = "PROCEDURE"; + + Location[] locations = scripter.GetSqlObjectDefinition(scripter.GetStoredProcedureScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test get definition for a scalar valued function object with active connection and explicit schema name. Expect non-null locations + /// + [Fact] + public void GetScalarValuedFunctionDefinitionWithSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(AddTwoFunctionQuery, AddTwoFunctionName, ScalarValuedFunctionTypeName); + } + + private void ExecuteAndValidatePeekTest(string query, string objectName, string objectType, string schemaName = "dbo") + { + if (!string.IsNullOrEmpty(query)) + { + using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, query)) + { + ValidatePeekTest(testDb.DatabaseName, objectName, objectType, schemaName, true); + } + } + else + { + ValidatePeekTest(null, objectName, objectType, schemaName, false); + } + } + + private void ValidatePeekTest(string databaseName, string objectName, string objectType, string schemaName, bool shouldReturnValidResult) + { + // Get live connectionInfo and serverConnection + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(databaseName); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + + Scripter.ScriptGetter sqlScriptGetter = null; + switch (objectType) + { + case SynonymTypeName: + sqlScriptGetter = scripter.GetSynonymScripts; + break; + case ScalarValuedFunctionTypeName: + sqlScriptGetter = scripter.GetScalarValuedFunctionScripts; + objectType = "Function"; + break; + case TableValuedFunctionTypeName: + sqlScriptGetter = scripter.GetTableValuedFunctionScripts; + objectType = "Function"; + break; + case TableTypeName: + sqlScriptGetter = scripter.GetTableScripts; + break; + case ViewTypeName: + sqlScriptGetter = scripter.GetViewScripts; + break; + case StoredProcedureTypeName: + sqlScriptGetter = scripter.GetStoredProcedureScripts; + break; + case UserDefinedDataTypeTypeName: + sqlScriptGetter = scripter.GetUserDefinedDataTypeScripts; + objectType = "Type"; + break; + case UserDefinedTableTypeTypeName: + sqlScriptGetter = scripter.GetUserDefinedTableTypeScripts; + objectType = "Type"; + break; + } + + Location[] locations = scripter.GetSqlObjectDefinition(sqlScriptGetter, objectName, schemaName, objectType); + if (shouldReturnValidResult) + { + Assert.NotNull(locations); + Cleanup(locations); + } + else + { + Assert.Null(locations); + } + } + + /// + /// Test get definition for a table valued function object with active connection and explicit schema name. Expect non-null locations + /// + [Fact] + public void GetTableValuedFunctionDefinitionWithSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(ReturnTableTableFunctionQuery, ReturnTableFunctionName, TableValuedFunctionTypeName); + } + + /// + /// Test get definition for a scalar valued function object that doesn't exist with active connection. Expect null locations + /// + [Fact] + public void GetScalarValuedFunctionDefinitionWithNonExistentFailureTest() + { + string objectName = "doesNotExist"; + string schemaName = "dbo"; + string objectType = ScalarValuedFunctionTypeName; + + ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); + } + + /// + /// Test get definition for a table valued function object that doesn't exist with active connection. Expect null locations + /// + [Fact] + public void GetTableValuedFunctionDefinitionWithNonExistentObjectFailureTest() + { + string objectName = "doesNotExist"; + string schemaName = "dbo"; + string objectType = TableValuedFunctionTypeName; + ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); + } + + /// + /// Test get definition for a scalar valued function object with active connection. Expect non-null locations + /// + [Fact] + public void GetScalarValuedFunctionDefinitionWithoutSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(AddTwoFunctionQuery, AddTwoFunctionName, ScalarValuedFunctionTypeName, null); + } + + /// + /// Test get definition for a table valued function object with active connection. Expect non-null locations + /// + [Fact] + public void GetTableValuedFunctionDefinitionWithoutSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(ReturnTableTableFunctionQuery, ReturnTableFunctionName, TableValuedFunctionTypeName, null); + } + + + /// + /// Test get definition for a user defined data type object with active connection and explicit schema name. Expect non-null locations + /// + [Fact] + public void GetUserDefinedDataTypeDefinitionWithSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(SsnTypeQuery, SsnTypeName, UserDefinedDataTypeTypeName); + } + + /// + /// Test get definition for a user defined data type object with active connection. Expect non-null locations + /// + [Fact] + public void GetUserDefinedDataTypeDefinitionWithoutSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(SsnTypeQuery, SsnTypeName, UserDefinedDataTypeTypeName, null); + } + + /// + /// Test get definition for a user defined data type object that doesn't exist with active connection. Expect null locations + /// + [Fact] + public void GetUserDefinedDataTypeDefinitionWithNonExistentFailureTest() + { + string objectName = "doesNotExist"; + string schemaName = "dbo"; + string objectType = UserDefinedDataTypeTypeName; + ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); + } + + /// + /// Test get definition for a user defined table type object with active connection and explicit schema name. Expect non-null locations + /// + [Fact] + public void GetUserDefinedTableTypeDefinitionWithSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(LocationTableTypeQuery, LocationTableTypeName, UserDefinedTableTypeTypeName); + } + + /// + /// Test get definition for a user defined table type object with active connection. Expect non-null locations + /// + [Fact] + public void GetUserDefinedTableTypeDefinitionWithoutSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(LocationTableTypeQuery, LocationTableTypeName, UserDefinedTableTypeTypeName, null); + } + + /// + /// Test get definition for a user defined table type object that doesn't exist with active connection. Expect null locations + /// + [Fact] + public void GetUserDefinedTableTypeDefinitionWithNonExistentFailureTest() + { + string objectName = "doesNotExist"; + string schemaName = "dbo"; + string objectType = UserDefinedTableTypeTypeName; + ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); + + } + + /// + /// Test get definition for a synonym object with active connection and explicit schema name. Expect non-null locations + /// + [Fact] + public void GetSynonymDefinitionWithSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(TestTableSynonymQuery, TestTableSynonymName, SynonymTypeName); + } + + + /// + /// Test get definition for a Synonym object with active connection. Expect non-null locations + /// + [Fact] + public void GetSynonymDefinitionWithoutSchemaNameSuccessTest() + { + ExecuteAndValidatePeekTest(TestTableSynonymQuery, TestTableSynonymName, SynonymTypeName, null); + } + + /// + /// Test get definition for a Synonym object that doesn't exist with active connection. Expect null locations + /// + [Fact] + public void GetSynonymDefinitionWithNonExistentFailureTest() + { + string objectName = "doesNotExist"; + string schemaName = "dbo"; + string objectType = "Synonym"; + ExecuteAndValidatePeekTest(null, objectName, objectType, schemaName); + } + + /// + /// Test get definition using declaration type for a view object with active connection + /// Expect a non-null result with location + /// + [Fact] + public void GetDefinitionUsingDeclarationTypeWithValidObjectTest() + { + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "objects"; + string schemaName = "sys"; + + DefinitionResult result = scripter.GetDefinitionUsingDeclarationType(DeclarationType.View, "master.sys.objects", objectName, schemaName); + Assert.NotNull(result); + Assert.NotNull(result.Locations); + Assert.False(result.IsErrorResult); + Cleanup(result.Locations); + + } + + /// + /// Test get definition using declaration type for a non existent view object with active connection + /// Expect a non-null result with location + /// + [Fact] + public void GetDefinitionUsingDeclarationTypeWithNonexistentObjectTest() + { + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "doesNotExist"; + string schemaName = "sys"; + + DefinitionResult result = scripter.GetDefinitionUsingDeclarationType(DeclarationType.View, "master.sys.objects", objectName, schemaName); + Assert.NotNull(result); + Assert.True(result.IsErrorResult); + } + + /// + /// Test get definition using quickInfo text for a view object with active connection + /// Expect a non-null result with location + /// + [Fact] + public void GetDefinitionUsingQuickInfoTextWithValidObjectTest() + { + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "objects"; + string schemaName = "sys"; + string quickInfoText = "view master.sys.objects"; + + DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, schemaName); + Assert.NotNull(result); + Assert.NotNull(result.Locations); + Assert.False(result.IsErrorResult); + Cleanup(result.Locations); + + } + + /// + /// Test get definition using quickInfo text for a view object with active connection + /// Expect a non-null result with location + /// + [Fact] + public void GetDefinitionUsingQuickInfoTextWithNonexistentObjectTest() + { + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + + Scripter scripter = new Scripter(serverConnection, connInfo); + string objectName = "doesNotExist"; + string schemaName = "sys"; + string quickInfoText = "view master.sys.objects"; + + DefinitionResult result = scripter.GetDefinitionUsingQuickInfoText(quickInfoText, objectName, schemaName); + Assert.NotNull(result); + Assert.True(result.IsErrorResult); + } + + /// + /// Test if peek definition default database name is the default server connection database name + /// Given that there is no query connection + /// Expect database name to be "master" + /// + [Fact] + public void GetDatabaseWithNoQueryConnectionTest() + { + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + DbConnection connection; + //Check if query connection is present + Assert.False(connInfo.TryGetConnection(ConnectionType.Query, out connection)); + + Scripter scripter = new Scripter(serverConnection, connInfo); + //Check if database name is the default server connection database name + Assert.Equal(scripter.Database.Name, "master"); + } + + /// + /// Test if the peek definition database name changes to the query connection database name + /// Give that there is a query connection + /// Expect database name to be query connection's database name + /// + [Fact] + public void GetDatabaseWithQueryConnectionTest() + { + ConnectionInfo connInfo = LiveConnectionHelper.InitLiveConnectionInfoForDefinition(); + ServerConnection serverConnection = LiveConnectionHelper.InitLiveServerConnectionForDefinition(connInfo); + //Mock a query connection object + var mockQueryConnection = new Mock { CallBase = true }; + mockQueryConnection.SetupGet(x => x.Database).Returns("testdb"); + connInfo.ConnectionTypeToConnectionMap[ConnectionType.Query] = mockQueryConnection.Object; + DbConnection connection; + //Check if query connection is present + Assert.True(connInfo.TryGetConnection(ConnectionType.Query, out connection)); + + Scripter scripter = new Scripter(serverConnection, connInfo); + //Check if database name is the database name in the query connection + Assert.Equal(scripter.Database.Name, "testdb"); + + // remove mock from ConnectionInfo + Assert.True(connInfo.ConnectionTypeToConnectionMap.TryRemove(ConnectionType.Query, out connection)); + } + + /// + /// Get Definition for a object with no definition. Expect a error result + /// + [Fact] + public async void GetDefinitionFromChildrenAndParents() + { + string queryString = "select * from master.sys.objects"; + + // place the cursor on every token + + //cursor on objects + TextDocumentPosition objectDocument = CreateTextDocPositionWithCursor(26, OwnerUri); + + //cursor on sys + TextDocumentPosition sysDocument = CreateTextDocPositionWithCursor(22, OwnerUri); + + //cursor on master + TextDocumentPosition masterDocument = CreateTextDocPositionWithCursor(15, OwnerUri); + + LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(); + ScriptFile scriptFile = connectionResult.ScriptFile; + ConnectionInfo connInfo = connectionResult.ConnectionInfo; + var bindingQueue = new ConnectedBindingQueue(); + bindingQueue.AddConnectionContext(connInfo); + LanguageService.Instance.BindingQueue = bindingQueue; + scriptFile.Contents = queryString; + + var service = LanguageService.Instance; + await service.UpdateLanguageServiceOnConnection(connectionResult.ConnectionInfo); + Thread.Sleep(2000); + + ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; + scriptInfo.ConnectionKey = bindingQueue.AddConnectionContext(connInfo); + LanguageService.Instance.ScriptParseInfoMap.Add(OwnerUri, scriptInfo); + + // When I call the language service + var objectResult = LanguageService.Instance.GetDefinition(objectDocument, scriptFile, connInfo); + var sysResult = LanguageService.Instance.GetDefinition(sysDocument, scriptFile, connInfo); + var masterResult = LanguageService.Instance.GetDefinition(masterDocument, scriptFile, connInfo); + + // Then I expect the results to be non-null + Assert.NotNull(objectResult); + Assert.NotNull(sysResult); + Assert.NotNull(masterResult); + + // And I expect the all results to be the same + Assert.True(CompareLocations(objectResult.Locations, sysResult.Locations)); + Assert.True(CompareLocations(objectResult.Locations, masterResult.Locations)); + + Cleanup(objectResult.Locations); + Cleanup(sysResult.Locations); + Cleanup(masterResult.Locations); + LanguageService.Instance.ScriptParseInfoMap.Remove(OwnerUri); + } + + [Fact] + public async void GetDefinitionFromProcedures() + { + + string queryString = "EXEC master.dbo.sp_MSrepl_startup"; + + // place the cursor on every token + + //cursor on objects + TextDocumentPosition fnDocument = CreateTextDocPositionWithCursor(30, TestUri); + + //cursor on sys + TextDocumentPosition dboDocument = CreateTextDocPositionWithCursor(14, TestUri); + + //cursor on master + TextDocumentPosition masterDocument = CreateTextDocPositionWithCursor(10, TestUri); + + LiveConnectionHelper.TestConnectionResult connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(); + ScriptFile scriptFile = connectionResult.ScriptFile; + ConnectionInfo connInfo = connectionResult.ConnectionInfo; + var bindingQueue = new ConnectedBindingQueue(); + bindingQueue.AddConnectionContext(connInfo); + LanguageService.Instance.BindingQueue = bindingQueue; + scriptFile.Contents = queryString; + + var service = LanguageService.Instance; + await service.UpdateLanguageServiceOnConnection(connectionResult.ConnectionInfo); + Thread.Sleep(2000); + + ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; + scriptInfo.ConnectionKey = bindingQueue.AddConnectionContext(connInfo); + LanguageService.Instance.ScriptParseInfoMap.Add(TestUri, scriptInfo); + + // When I call the language service + var fnResult = LanguageService.Instance.GetDefinition(fnDocument, scriptFile, connInfo); + var sysResult = LanguageService.Instance.GetDefinition(dboDocument, scriptFile, connInfo); + var masterResult = LanguageService.Instance.GetDefinition(masterDocument, scriptFile, connInfo); + + // Then I expect the results to be non-null + Assert.NotNull(fnResult); + Assert.NotNull(sysResult); + Assert.NotNull(masterResult); + + // And I expect the all results to be the same + Assert.True(CompareLocations(fnResult.Locations, sysResult.Locations)); + Assert.True(CompareLocations(fnResult.Locations, masterResult.Locations)); + + Cleanup(fnResult.Locations); + Cleanup(sysResult.Locations); + Cleanup(masterResult.Locations); + LanguageService.Instance.ScriptParseInfoMap.Remove(TestUri); + } + + + /// + /// Helper method to clean up script files + /// + private void Cleanup(Location[] locations) + { + Uri fileUri = new Uri(locations[0].Uri); + if (File.Exists(fileUri.LocalPath)) + { + try + { + File.Delete(fileUri.LocalPath); + } + catch (Exception) + { + + } + } + } + + /// + /// Helper method to compare 2 Locations arrays + /// + /// + /// + /// + private bool CompareLocations(Location[] locationsA, Location[] locationsB) + { + HashSet locationSet = new HashSet(); + foreach (var location in locationsA) + { + locationSet.Add(location); + } + foreach (var location in locationsB) + { + if (!locationSet.Contains(location)) + { + return false; + } + } + return true; + } + + private TextDocumentPosition CreateTextDocPositionWithCursor(int column, string OwnerUri) + { + TextDocumentPosition textDocPos = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier { Uri = OwnerUri }, + Position = new Position + { + Line = 0, + Character = column + } + }; + return textDocPos; + } + } +}