diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs index f217e0d3..cdf39e29 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs @@ -150,7 +150,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting Capabilities = new ServerCapabilities { TextDocumentSync = TextDocumentSyncKind.Incremental, - DefinitionProvider = false, + DefinitionProvider = true, ReferencesProvider = false, DocumentHighlightProvider = false, HoverProvider = true, diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index 8d164861..716d7cd8 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -33,6 +33,8 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// public sealed class LanguageService { + private const int OneSecond = 1000; + internal const string DefaultBatchSeperator = "GO"; internal const int DiagnosticParseDelay = 750; @@ -41,7 +43,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices internal const int BindingTimeout = 500; - internal const int OnConnectionWaitTimeout = 300000; + internal const int OnConnectionWaitTimeout = 300 * OneSecond; + + internal const int PeekDefinitionTimeout = 10 * OneSecond; private static ConnectionService connectionService = null; @@ -198,7 +202,6 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices // Register the requests that this service will handle // turn off until needed (10/28/2016) - // serviceHost.SetRequestHandler(DefinitionRequest.Type, HandleDefinitionRequest); // serviceHost.SetRequestHandler(ReferencesRequest.Type, HandleReferencesRequest); // serviceHost.SetRequestHandler(DocumentHighlightRequest.Type, HandleDocumentHighlightRequest); @@ -206,6 +209,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices serviceHost.SetRequestHandler(CompletionResolveRequest.Type, HandleCompletionResolveRequest); serviceHost.SetRequestHandler(HoverRequest.Type, HandleHoverRequest); serviceHost.SetRequestHandler(CompletionRequest.Type, HandleCompletionRequest); + serviceHost.SetRequestHandler(DefinitionRequest.Type, HandleDefinitionRequest); // Register a no-op shutdown task for validation of the shutdown logic serviceHost.RegisterShutdownTask(async (shutdownParams, shutdownRequestContext) => @@ -293,15 +297,25 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } -// turn off this code until needed (10/28/2016) -#if false - private static async Task HandleDefinitionRequest( - TextDocumentPosition textDocumentPosition, - RequestContext requestContext) + internal static async Task HandleDefinitionRequest(TextDocumentPosition textDocumentPosition, RequestContext requestContext) { - await Task.FromResult(true); + if (WorkspaceService.Instance.CurrentSettings.IsIntelliSenseEnabled) + { + // Retrieve document and connection + ConnectionInfo connInfo; + var scriptFile = LanguageService.WorkspaceServiceInstance.Workspace.GetFile(textDocumentPosition.TextDocument.Uri); + LanguageService.ConnectionServiceInstance.TryFindConnection(scriptFile.ClientFilePath, out connInfo); + + Location[] locations = LanguageService.Instance.GetDefinition(textDocumentPosition, scriptFile, connInfo); + if (locations != null) + { + await requestContext.SendResult(locations); + } + } } +// turn off this code until needed (10/28/2016) +#if false private static async Task HandleReferencesRequest( ReferencesParams referencesParams, RequestContext requestContext) @@ -654,6 +668,106 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices return completionItem; } + /// + /// Get definition for a selected sql object using SMO Scripting + /// + /// + /// + /// + /// Location with the URI of the script file + internal Location[] GetDefinition(TextDocumentPosition textDocumentPosition, ScriptFile scriptFile, ConnectionInfo connInfo) + { + // Parse sql + ScriptParseInfo scriptParseInfo = GetScriptParseInfo(textDocumentPosition.TextDocument.Uri); + if (scriptParseInfo == null) + { + return null; + } + + if (RequiresReparse(scriptParseInfo, scriptFile)) + { + scriptParseInfo.ParseResult = ParseAndBind(scriptFile, connInfo); + } + + // Get token from selected text + Token selectedToken = GetToken(scriptParseInfo, textDocumentPosition.Position.Line + 1, textDocumentPosition.Position.Character); + if (selectedToken == null) + { + return null; + } + // Strip "[" and "]"(if present) from the token text to enable matching with the suggestions. + // The suggestion title does not contain any sql punctuation + string tokenText = TextUtilities.RemoveSquareBracketSyntax(selectedToken.Text); + + if (scriptParseInfo.IsConnected && Monitor.TryEnter(scriptParseInfo.BuildingMetadataLock)) + { + try + { + // Queue the task with the binding queue + QueueItem queueItem = this.BindingQueue.QueueBindingOperation( + key: scriptParseInfo.ConnectionKey, + bindingTimeout: LanguageService.PeekDefinitionTimeout, + bindOperation: (bindingContext, cancelToken) => + { + // Get suggestions for the token + int parserLine = textDocumentPosition.Position.Line + 1; + int parserColumn = textDocumentPosition.Position.Character + 1; + IEnumerable declarationItems = Resolver.FindCompletions( + scriptParseInfo.ParseResult, + parserLine, parserColumn, + bindingContext.MetadataDisplayInfoProvider); + + // Match token with the suggestions(declaration items) returned + string schemaName = this.GetSchemaName(scriptParseInfo, textDocumentPosition.Position, scriptFile); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + return peekDefinition.GetScript(declarationItems, tokenText, schemaName); + + + }); + + // wait for the queue item + queueItem.ItemProcessed.WaitOne(); + return queueItem.GetResultAsT(); + } + finally + { + Monitor.Exit(scriptParseInfo.BuildingMetadataLock); + } + } + + return null; + } + + /// + /// Extract schema name for a token, if present + /// + /// + /// + /// + /// schema nama + private string GetSchemaName(ScriptParseInfo scriptParseInfo, Position position, ScriptFile scriptFile) + { + // Offset index by 1 for sql parser + int startLine = position.Line + 1; + int startColumn = position.Character + 1; + + // Get schema name + if (scriptParseInfo != null && scriptParseInfo.ParseResult != null && scriptParseInfo.ParseResult.Script != null && scriptParseInfo.ParseResult.Script.Tokens != null) + { + var tokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.FindToken(startLine, startColumn); + var prevTokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.GetPreviousSignificantTokenIndex(tokenIndex); + var prevTokenText = scriptParseInfo.ParseResult.Script.TokenManager.GetText(prevTokenIndex); + if (prevTokenText != null && prevTokenText.Equals(".")) + { + var schemaTokenIndex = scriptParseInfo.ParseResult.Script.TokenManager.GetPreviousSignificantTokenIndex(prevTokenIndex); + Token schemaToken = scriptParseInfo.ParseResult.Script.TokenManager.GetToken(schemaTokenIndex); + return TextUtilities.RemoveSquareBracketSyntax(schemaToken.Text); + } + } + // if no schema name, returns null + return null; + } + /// /// Get quick info hover tooltips for the current position /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs new file mode 100644 index 00000000..3d967eb1 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs @@ -0,0 +1,227 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.IO; +using System.Collections.Generic; +using System.Collections.Specialized; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices +{ + /// + /// Peek Definition/ Go to definition implementation + /// Script sql objects and write create scripts to file + /// + internal class PeekDefinition + { + private ConnectionInfo connectionInfo; + private string tempPath; + + internal delegate StringCollection ScriptGetter(string objectName, string schemaName); + + // Dictionary that holds the script getter for each type + private Dictionary sqlScriptGetters = + new Dictionary(); + + // Dictionary that holds the object name (as appears on the TSQL create statement) + private Dictionary sqlObjectTypes = new Dictionary(); + + private Database Database + { + get + { + if (this.connectionInfo.SqlConnection != null) + { + Server server = new Server(this.connectionInfo.SqlConnection.DataSource); + return server.Databases[this.connectionInfo.SqlConnection.Database]; + + } + return null; + } + } + + internal PeekDefinition(ConnectionInfo connInfo) + { + this.connectionInfo = connInfo; + DirectoryInfo tempScriptDirectory = Directory.CreateDirectory(Path.GetTempPath() + "mssql_definition"); + this.tempPath = tempScriptDirectory.FullName; + Initialize(); + } + + /// + /// Add getters for each sql object supported by peek definition + /// + private void Initialize() + { + //Add script getters for each sql object + + //Add tables to supported types + AddSupportedType(DeclarationType.Table, GetTableScripts, "Table"); + + //Add views to supported types + AddSupportedType(DeclarationType.View, GetViewScripts, "view"); + + //Add stored procedures to supported types + AddSupportedType(DeclarationType.StoredProcedure, GetStoredProcedureScripts, "Procedure"); + } + + /// + /// Add the given type, scriptgetter and the typeName string to the respective dictionaries + /// + private void AddSupportedType(DeclarationType type, ScriptGetter scriptGetter, string typeName) + { + sqlScriptGetters.Add(type, scriptGetter); + sqlObjectTypes.Add(type, typeName); + + } + + /// + /// Convert a file to a location array containing a location object as expected by the extension + /// + private Location[] GetLocationFromFile(string tempFileName, int lineNumber) + { + Location[] locations = new[] { + new Location { + Uri = new Uri(tempFileName).AbsoluteUri, + Range = new Range { + Start = new Position { Line = lineNumber, Character = 1}, + End = new Position { Line = lineNumber + 1, Character = 1} + } + } + }; + return locations; + } + + /// + /// Get line number for the create statement + /// + private int GetStartOfCreate(string script, string createString) + { + string[] lines = script.Split(new string[] { Environment.NewLine }, StringSplitOptions.None); + for (int lineNumber = 0; lineNumber < lines.Length; lineNumber++) + { + if (lines[lineNumber].IndexOf( createString, StringComparison.OrdinalIgnoreCase) >= 0) + { + return lineNumber; + } + } + return 0; + } + + /// + /// Get the script of the selected token based on the type of the token + /// + /// + /// + /// + /// Location object of the script file + internal Location[] GetScript(IEnumerable declarationItems, string tokenText, string schemaName) + { + foreach (Declaration declarationItem in declarationItems) + { + if (declarationItem.Title == null) + { + continue; + } + + if (declarationItem.Title.Equals(tokenText)) + { + // Script object using SMO based on type + DeclarationType type = declarationItem.Type; + if (sqlScriptGetters.ContainsKey(type) && sqlObjectTypes.ContainsKey(type)) + { + return GetSqlObjectDefinition( + sqlScriptGetters[type], + tokenText, + schemaName, + sqlObjectTypes[type] + ); + } + return null; + } + } + return null; + } + + /// + /// Script a table using SMO + /// + /// Table name + /// Schema name + /// String collection of scripts + internal StringCollection GetTableScripts(string tableName, string schemaName) + { + return (schemaName != null) ? Database?.Tables[tableName, schemaName]?.Script() + : Database?.Tables[tableName]?.Script(); + } + + /// + /// Script a view using SMO + /// + /// View name + /// Schema name + /// String collection of scripts + internal StringCollection GetViewScripts(string viewName, string schemaName) + { + return (schemaName != null) ? Database?.Views[viewName, schemaName]?.Script() + : Database?.Views[viewName]?.Script(); + } + + /// + /// Script a stored procedure using SMO + /// + /// Stored Procedure name + /// Schema Name + /// String collection of scripts + internal StringCollection GetStoredProcedureScripts(string viewName, string schemaName) + { + return (schemaName != null) ? Database?.StoredProcedures[viewName, schemaName]?.Script() + : Database?.StoredProcedures[viewName]?.Script(); + } + + /// + /// Script a object using SMO and write to a file. + /// + /// Function that returns the SMO scripts for an object + /// SQL object name + /// Schema name or null + /// Type of SQL object + /// Location object representing URI and range of the script file + internal Location[] GetSqlObjectDefinition( + ScriptGetter sqlScriptGetter, + string objectName, + string schemaName, + string objectType) + { + StringCollection scripts = sqlScriptGetter(objectName, schemaName); + string tempFileName = (schemaName != null) ? Path.Combine(this.tempPath, string.Format("{0}.{1}.sql", schemaName, objectName)) + : Path.Combine(this.tempPath, string.Format("{0}.sql", objectName)); + + if (scripts != null) + { + int lineNumber = 0; + using (StreamWriter scriptFile = new StreamWriter(File.Open(tempFileName, FileMode.Create, FileAccess.ReadWrite))) + { + + foreach (string script in scripts) + { + string createSyntax = string.Format("CREATE {0}", objectType); + if (script.IndexOf(createSyntax, StringComparison.OrdinalIgnoreCase) >= 0) + { + scriptFile.WriteLine(script); + lineNumber = GetStartOfCreate(script, createSyntax); + } + } + } + return GetLocationFromFile(tempFileName, lineNumber); + } + + return null; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs index 37c35ebf..0359df16 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs @@ -68,7 +68,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext } /// - /// Gets a flag determining if suggestons are enabled + /// Gets a flag determining if suggestions are enabled /// public bool IsSuggestionsEnabled { @@ -90,6 +90,17 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext && this.SqlTools.IntelliSense.EnableQuickInfo.Value; } } + + /// + /// Gets a flag determining if IntelliSense is enabled + /// + public bool IsIntelliSenseEnabled + { + get + { + return this.SqlTools.IntelliSense.EnableIntellisense; + } + } } /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs index 29ca9e7f..979d697b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/TextUtilities.cs @@ -109,5 +109,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility || ch == '(' || ch == ')'; } + + /// + /// Remove square bracket syntax from a token string + /// + /// + /// string with outer brackets removed + public static string RemoveSquareBracketSyntax(string tokenText) + { + if(tokenText.StartsWith("[") && tokenText.EndsWith("]")) + { + return tokenText.Substring(1, tokenText.Length - 2); + } + return tokenText; + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs new file mode 100644 index 00000000..4588e704 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs @@ -0,0 +1,279 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System.Collections.Generic; +using System.Threading.Tasks; +using System.IO; +using System; +using Microsoft.SqlServer.Management.SqlParser.Binder; +using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using Microsoft.SqlServer.Management.SqlParser.Parser; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Workspace; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Microsoft.SqlTools.Test.Utility; +using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices +{ + /// + /// Tests for the language service peek definition/ go to definition feature + /// + public class PeekDefinitionTests + { + private const int TaskTimeout = 30000; + + private readonly string testScriptUri = TestObjects.ScriptUri; + + private readonly string testConnectionKey = "testdbcontextkey"; + + private Mock bindingQueue; + + private Mock> workspaceService; + + private Mock> requestContext; + + private Mock binder; + + private TextDocumentPosition textDocument; + + private const string OwnerUri = "testFile1"; + + private const string ViewOwnerUri = "testFile2"; + + private const string TriggerOwnerUri = "testFile3"; + + private void InitializeTestObjects() + { + // initial cursor position in the script file + textDocument = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier {Uri = this.testScriptUri}, + Position = new Position + { + Line = 0, + Character = 23 + } + }; + + // default settings are stored in the workspace service + WorkspaceService.Instance.CurrentSettings = new SqlToolsSettings(); + + // set up file for returning the query + var fileMock = new Mock(); + fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); + fileMock.SetupGet(file => file.ClientFilePath).Returns(this.testScriptUri); + + // set up workspace mock + workspaceService = new Mock>(); + workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) + .Returns(fileMock.Object); + + // setup binding queue mock + bindingQueue = new Mock(); + bindingQueue.Setup(q => q.AddConnectionContext(It.IsAny())) + .Returns(this.testConnectionKey); + + // inject mock instances into the Language Service + LanguageService.WorkspaceServiceInstance = workspaceService.Object; + LanguageService.ConnectionServiceInstance = TestObjects.GetTestConnectionService(); + ConnectionInfo connectionInfo = TestObjects.GetTestConnectionInfo(); + LanguageService.ConnectionServiceInstance.OwnerToConnectionMap.Add(this.testScriptUri, connectionInfo); + LanguageService.Instance.BindingQueue = bindingQueue.Object; + + // setup the mock for SendResult + requestContext = new Mock>(); + requestContext.Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + + // setup the IBinder mock + binder = new Mock(); + binder.Setup(b => b.Bind( + It.IsAny>(), + It.IsAny(), + It.IsAny())); + + var testScriptParseInfo = new ScriptParseInfo(); + LanguageService.Instance.AddOrUpdateScriptParseInfo(this.testScriptUri, testScriptParseInfo); + testScriptParseInfo.IsConnected = true; + testScriptParseInfo.ConnectionKey = LanguageService.Instance.BindingQueue.AddConnectionContext(connectionInfo); + + // setup the binding context object + ConnectedBindingContext bindingContext = new ConnectedBindingContext(); + bindingContext.Binder = binder.Object; + bindingContext.MetadataDisplayInfoProvider = new MetadataDisplayInfoProvider(); + LanguageService.Instance.BindingQueue.BindingContextMap.Add(testScriptParseInfo.ConnectionKey, bindingContext); + } + + + /// + /// Tests the definition event handler. When called with no active connection, no definition is sent + /// + [Fact] + public void DefinitionsHandlerWithNoConnectionTest() + { + InitializeTestObjects(); + // request the completion list + Task handleCompletion = LanguageService.HandleDefinitionRequest(textDocument, requestContext.Object); + handleCompletion.Wait(TaskTimeout); + + // verify that send result was not called + requestContext.Verify(m => m.SendResult(It.IsAny()), Times.Never()); + + } + +#if LIVE_CONNECTION_TESTS + /// + /// Test get definition for a table object with active connection + /// + [Fact] + public void GetTableDefinitionTest() + { + // Get live connectionInfo + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "test_table"; + string schemaName = null; + string objectType = "TABLE"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + [Fact] + public void GetUnsupportedDefinitionForFullScript() + { + + ScriptFile scriptFile; + TextDocumentPosition textDocument = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier { Uri = OwnerUri }, + Position = new Position + { + Line = 0, + Character = 20 + } + }; + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfo(out scriptFile); + scriptFile.Contents = "select * from dbo.func ()"; + + var languageService = LanguageService.Instance; + ScriptParseInfo scriptInfo = new ScriptParseInfo { IsConnected = true }; + languageService.ScriptParseInfoMap.Add(OwnerUri, scriptInfo); + + var locations = languageService.GetDefinition(textDocument, scriptFile, connInfo); + Assert.Null(locations); + } + + /// + /// Test get definition for a view object with active connection + /// + [Fact] + public void GetViewDefinitionTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "objects"; + string schemaName = "sys"; + string objectType = "VIEW"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetViewScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test get definition for an invalid view object with no schema name and with active connection + /// + [Fact] + public void GetViewDefinitionInvalidObjectTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "objects"; + string schemaName = null; + string objectType = "VIEW"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetViewScripts, objectName, schemaName, objectType); + Assert.Null(locations); + } + + /// + /// Test get definition for a stored procedure object with active connection + /// + [Fact] + public void GetStoredProcedureDefinitionTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "SP1"; + string schemaName = "dbo"; + string objectType = "PROCEDURE"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetStoredProcedureScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test get definition for a stored procedure object that does not exist with active connection + /// + [Fact] + public void GetStoredProcedureDefinitionFailureTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "SP2"; + string schemaName = "dbo"; + string objectType = "PROCEDURE"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetStoredProcedureScripts, objectName, schemaName, objectType); + Assert.Null(locations); + } + + /// + /// Test get definition for a stored procedure object with active connection and no schema + /// + [Fact] + public void GetStoredProcedureDefinitionWithoutSchemaTest() + { + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "SP1"; + string schemaName = null; + string objectType = "PROCEDURE"; + + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetStoredProcedureScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Helper method to clean up script files + /// + private void Cleanup(Location[] locations) + { + Uri fileUri = new Uri(locations[0].Uri); + if (File.Exists(fileUri.LocalPath)) + { + try + { + File.Delete(fileUri.LocalPath); + } + catch(Exception) + { + + } + } + } +#endif + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs index 01707778..0f861433 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs @@ -191,6 +191,27 @@ namespace Microsoft.SqlTools.Test.Utility connectionService.TryFindConnection(ownerUri, out connInfo); return connInfo; } + + public static ConnectionInfo InitLiveConnectionInfoForDefinition() + { + TestObjects.InitializeTestServices(); + + string ownerUri = ScriptUri; + var connectionService = TestObjects.GetLiveTestConnectionService(); + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetIntegratedTestConnectionDetails() + }); + + connectionResult.Wait(); + + ConnectionInfo connInfo = null; + connectionService.TryFindConnection(ownerUri, out connInfo); + return connInfo; + } } /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/LanguageServiceTests.cs index dd93a8e6..dccee6c8 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/LanguageServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/LanguageServiceTests.cs @@ -211,6 +211,49 @@ namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests } } + /// + /// Peek Definition/ Go to definition + /// + /// + [Fact] + public async Task DefinitionTest() + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + using (TestHelper testHelper = new TestHelper()) + { + string query = "SELECT * FROM sys.objects"; + int lineNumber = 0; + int position = 23; + + testHelper.WriteToFile(queryTempFile.FilePath, query); + + DidOpenTextDocumentNotification openParams = new DidOpenTextDocumentNotification + { + TextDocument = new TextDocumentItem + { + Uri = queryTempFile.FilePath, + LanguageId = "enu", + Version = 1, + Text = query + } + }; + + await testHelper.RequestOpenDocumentNotification(openParams); + + Thread.Sleep(500); + + bool connected = await testHelper.Connect(queryTempFile.FilePath, ConnectionTestUtils.LocalhostConnection); + Assert.True(connected, "Connection is successful"); + + Thread.Sleep(10000); + // Request definition for "objects" + Location[] locations = await testHelper.RequestDefinition(queryTempFile.FilePath, query, lineNumber, position); + + Assert.True(locations != null, "Location is not null and not empty"); + await testHelper.Disconnect(queryTempFile.FilePath); + } + } + /// /// Validate the configuration change event /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/TestHelper.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/TestHelper.cs index 54d2690b..d248ffff 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/TestHelper.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Tests/TestHelper.cs @@ -257,6 +257,29 @@ namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests return result; } + /// + /// Request definition( peek definition/go to definition) for a sql object in a sql string + /// + public async Task RequestDefinition(string ownerUri, string text, int line, int character) + { + // Write the text to a backing file + lock (fileLock) + { + System.IO.File.WriteAllText(ownerUri, text); + } + + var definitionParams = new TextDocumentPosition(); + definitionParams.TextDocument = new TextDocumentIdentifier(); + definitionParams.TextDocument.Uri = ownerUri; + definitionParams.Position = new Position(); + definitionParams.Position.Line = line; + definitionParams.Position.Character = character; + + // Send definition request + var result = await Driver.SendRequest(DefinitionRequest.Type, definitionParams); + return result; + } + /// /// Run a query using a given connection bound to a URI ///