From ab9794800531bff31a3d34f871b41471fcbc3e14 Mon Sep 17 00:00:00 2001 From: Sharon Ravindran Date: Wed, 7 Dec 2016 16:52:35 -0800 Subject: [PATCH] Fix/peek def mac (#170) * Fix Integrated auth error and Uri for *nix/Mac * Format code * Add Logging and unit tests * Modify tests for Windows: * Workaround missing default schema on *nix and Mac * Add unit tests * Correct comments * Change loop length * Fix Log message --- .../LanguageServices/PeekDefinition.cs | 62 +++++++++- .../LanguageServer/PeekDefinitionTests.cs | 115 ++++++++++++++++-- 2 files changed, 163 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs index 3d967eb1..07293ba6 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs @@ -6,8 +6,12 @@ using System; using System.IO; using System.Collections.Generic; using System.Collections.Specialized; +using System.Data.SqlClient; +using System.Runtime.InteropServices; using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; @@ -37,9 +41,21 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { if (this.connectionInfo.SqlConnection != null) { - Server server = new Server(this.connectionInfo.SqlConnection.DataSource); - return server.Databases[this.connectionInfo.SqlConnection.Database]; - + try + { + // Get server object from connection + string connectionString = ConnectionService.BuildConnectionString(this.connectionInfo.ConnectionDetails); + SqlConnection sqlConn = new SqlConnection(connectionString); + sqlConn.Open(); + ServerConnection serverConn = new ServerConnection(sqlConn); + Server server = new Server(serverConn); + return server.Databases[this.connectionInfo.SqlConnection.Database]; + } + catch(Exception ex) + { + Logger.Write(LogLevel.Error, "Exception at PeekDefinition Database.get() : " + ex.Message); + return null; + } } return null; } @@ -83,11 +99,19 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// /// Convert a file to a location array containing a location object as expected by the extension /// - private Location[] GetLocationFromFile(string tempFileName, int lineNumber) + internal Location[] GetLocationFromFile(string tempFileName, int lineNumber) { + if (Path.DirectorySeparatorChar.Equals('/')) + { + tempFileName = "file:" + tempFileName; + } + else + { + tempFileName = new Uri(tempFileName).AbsoluteUri; + } Location[] locations = new[] { new Location { - Uri = new Uri(tempFileName).AbsoluteUri, + Uri = tempFileName, Range = new Range { Start = new Position { Line = lineNumber, Character = 1}, End = new Position { Line = lineNumber + 1, Character = 1} @@ -135,6 +159,15 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices DeclarationType type = declarationItem.Type; if (sqlScriptGetters.ContainsKey(type) && sqlObjectTypes.ContainsKey(type)) { + // On *nix and mac systems, the defaultSchema property throws an Exception when accessed. + // This workaround ensures that a schema name is present by attempting + // to get the schema name from the declaration item + // If all fails, the default schema name is assumed to be "dbo" + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && string.IsNullOrEmpty(schemaName)) + { + string fullObjectName = declarationItem.DatabaseQualifiedName; + schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText); + } return GetSqlObjectDefinition( sqlScriptGetters[type], tokenText, @@ -148,6 +181,25 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices return null; } + /// + /// Return schema name from the full name of the database. If schema is missing return dbo as schema name. + /// + /// The full database qualified name(database.schema.object) + /// Object name + /// Schema name + internal string GetSchemaFromDatabaseQualifiedName(string fullObjectName, string objectName) + { + string[] tokens = fullObjectName.Split('.'); + for (int i = tokens.Length - 1; i > 0; i--) + { + if(tokens[i].Equals(objectName)) + { + return tokens[i-1]; + } + } + return "dbo"; + } + /// /// Script a table using SMO /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs index 4588e704..b70c6c54 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs @@ -2,10 +2,11 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using System; +using System.IO; using System.Collections.Generic; using System.Threading.Tasks; -using System.IO; -using System; +using System.Runtime.InteropServices; using Microsoft.SqlServer.Management.SqlParser.Binder; using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; using Microsoft.SqlServer.Management.SqlParser.Parser; @@ -46,10 +47,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices private const string OwnerUri = "testFile1"; - private const string ViewOwnerUri = "testFile2"; - - private const string TriggerOwnerUri = "testFile3"; - private void InitializeTestObjects() { // initial cursor position in the script file @@ -126,15 +123,74 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices // verify that send result was not called requestContext.Verify(m => m.SendResult(It.IsAny()), Times.Never()); - } + /// + /// Tests creating location objects on windows and non-windows systems + /// + [Fact] + public void GetLocationFromFileForValidFilePathTest() + { + String filePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "C:\\test\\script.sql" : "/test/script.sql"; + PeekDefinition peekDefinition = new PeekDefinition(null); + Location[] locations = peekDefinition.GetLocationFromFile(filePath, 0); + + String expectedFilePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "file:///C:/test/script.sql" : "file:/test/script.sql"; + Assert.Equal(locations[0].Uri, expectedFilePath); + } + + /// + /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a valid database name + /// + [Fact] + public void GetSchemaFromDatabaseQualifiedNameWithValidNameTest() + { + PeekDefinition peekDefinition = new PeekDefinition(null); + string validDatabaseQualifiedName = "master.test.test_table"; + string objectName = "test_table"; + string expectedSchemaName = "test"; + + string actualSchemaName = peekDefinition.GetSchemaFromDatabaseQualifiedName(validDatabaseQualifiedName, objectName); + Assert.Equal(actualSchemaName, expectedSchemaName); + } + + /// + /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a valid object name and no schema + /// + + [Fact] + public void GetSchemaFromDatabaseQualifiedNameWithNoSchemaTest() + { + PeekDefinition peekDefinition = new PeekDefinition(null); + string validDatabaseQualifiedName = "test_table"; + string objectName = "test_table"; + string expectedSchemaName = "dbo"; + + string actualSchemaName = peekDefinition.GetSchemaFromDatabaseQualifiedName(validDatabaseQualifiedName, objectName); + Assert.Equal(actualSchemaName, expectedSchemaName); + } + + /// + /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a invalid database name + /// + [Fact] + public void GetSchemaFromDatabaseQualifiedNameWithInvalidNameTest() + { + PeekDefinition peekDefinition = new PeekDefinition(null); + string validDatabaseQualifiedName = "x.y.z"; + string objectName = "test_table"; + string expectedSchemaName = "dbo"; + + string actualSchemaName = peekDefinition.GetSchemaFromDatabaseQualifiedName(validDatabaseQualifiedName, objectName); + Assert.Equal(actualSchemaName, expectedSchemaName); + } + #if LIVE_CONNECTION_TESTS /// /// Test get definition for a table object with active connection /// [Fact] - public void GetTableDefinitionTest() + public void GetValidTableDefinitionTest() { // Get live connectionInfo ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); @@ -143,11 +199,52 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices string schemaName = null; string objectType = "TABLE"; + // Get locations for valid table object Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType); Assert.NotNull(locations); Cleanup(locations); } + /// + /// Test get definition for a invalid table object with active connection + /// + [Fact] + public void GetTableDefinitionInvalidObjectTest() + { + // Get live connectionInfo + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "test_invalid"; + string schemaName = null; + string objectType = "TABLE"; + + // Get locations for invalid table object + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType); + Assert.Null(locations); + } + + /// + /// Test get definition for a valid table object with schema and active connection + /// + [Fact] + public void GetTableDefinitionWithSchemaTest() + { + // Get live connectionInfo + ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); + PeekDefinition peekDefinition = new PeekDefinition(connInfo); + string objectName = "test_table"; + string schemaName = "dbo"; + string objectType = "TABLE"; + + // Get locations for valid table object with schema name + Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType); + Assert.NotNull(locations); + Cleanup(locations); + } + + /// + /// Test GetDefinition with an unsupported type(function) + /// [Fact] public void GetUnsupportedDefinitionForFullScript() { @@ -177,7 +274,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices /// Test get definition for a view object with active connection /// [Fact] - public void GetViewDefinitionTest() + public void GetValidViewDefinitionTest() { ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition(); PeekDefinition peekDefinition = new PeekDefinition(connInfo);