diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs
index 3d967eb1..07293ba6 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/PeekDefinition.cs
@@ -6,8 +6,12 @@ using System;
using System.IO;
using System.Collections.Generic;
using System.Collections.Specialized;
+using System.Data.SqlClient;
+using System.Runtime.InteropServices;
using Microsoft.SqlServer.Management.Smo;
+using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.SqlParser.Intellisense;
+using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
@@ -37,9 +41,21 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
{
if (this.connectionInfo.SqlConnection != null)
{
- Server server = new Server(this.connectionInfo.SqlConnection.DataSource);
- return server.Databases[this.connectionInfo.SqlConnection.Database];
-
+ try
+ {
+ // Get server object from connection
+ string connectionString = ConnectionService.BuildConnectionString(this.connectionInfo.ConnectionDetails);
+ SqlConnection sqlConn = new SqlConnection(connectionString);
+ sqlConn.Open();
+ ServerConnection serverConn = new ServerConnection(sqlConn);
+ Server server = new Server(serverConn);
+ return server.Databases[this.connectionInfo.SqlConnection.Database];
+ }
+ catch(Exception ex)
+ {
+ Logger.Write(LogLevel.Error, "Exception at PeekDefinition Database.get() : " + ex.Message);
+ return null;
+ }
}
return null;
}
@@ -83,11 +99,19 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
///
/// Convert a file to a location array containing a location object as expected by the extension
///
- private Location[] GetLocationFromFile(string tempFileName, int lineNumber)
+ internal Location[] GetLocationFromFile(string tempFileName, int lineNumber)
{
+ if (Path.DirectorySeparatorChar.Equals('/'))
+ {
+ tempFileName = "file:" + tempFileName;
+ }
+ else
+ {
+ tempFileName = new Uri(tempFileName).AbsoluteUri;
+ }
Location[] locations = new[] {
new Location {
- Uri = new Uri(tempFileName).AbsoluteUri,
+ Uri = tempFileName,
Range = new Range {
Start = new Position { Line = lineNumber, Character = 1},
End = new Position { Line = lineNumber + 1, Character = 1}
@@ -135,6 +159,15 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
DeclarationType type = declarationItem.Type;
if (sqlScriptGetters.ContainsKey(type) && sqlObjectTypes.ContainsKey(type))
{
+ // On *nix and mac systems, the defaultSchema property throws an Exception when accessed.
+ // This workaround ensures that a schema name is present by attempting
+ // to get the schema name from the declaration item
+ // If all fails, the default schema name is assumed to be "dbo"
+ if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && string.IsNullOrEmpty(schemaName))
+ {
+ string fullObjectName = declarationItem.DatabaseQualifiedName;
+ schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText);
+ }
return GetSqlObjectDefinition(
sqlScriptGetters[type],
tokenText,
@@ -148,6 +181,25 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
return null;
}
+ ///
+ /// Return schema name from the full name of the database. If schema is missing return dbo as schema name.
+ ///
+ /// The full database qualified name(database.schema.object)
+ /// Object name
+ /// Schema name
+ internal string GetSchemaFromDatabaseQualifiedName(string fullObjectName, string objectName)
+ {
+ string[] tokens = fullObjectName.Split('.');
+ for (int i = tokens.Length - 1; i > 0; i--)
+ {
+ if(tokens[i].Equals(objectName))
+ {
+ return tokens[i-1];
+ }
+ }
+ return "dbo";
+ }
+
///
/// Script a table using SMO
///
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs
index 4588e704..b70c6c54 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/PeekDefinitionTests.cs
@@ -2,10 +2,11 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
+using System;
+using System.IO;
using System.Collections.Generic;
using System.Threading.Tasks;
-using System.IO;
-using System;
+using System.Runtime.InteropServices;
using Microsoft.SqlServer.Management.SqlParser.Binder;
using Microsoft.SqlServer.Management.SqlParser.MetadataProvider;
using Microsoft.SqlServer.Management.SqlParser.Parser;
@@ -46,10 +47,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices
private const string OwnerUri = "testFile1";
- private const string ViewOwnerUri = "testFile2";
-
- private const string TriggerOwnerUri = "testFile3";
-
private void InitializeTestObjects()
{
// initial cursor position in the script file
@@ -126,15 +123,74 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices
// verify that send result was not called
requestContext.Verify(m => m.SendResult(It.IsAny()), Times.Never());
-
}
+ ///
+ /// Tests creating location objects on windows and non-windows systems
+ ///
+ [Fact]
+ public void GetLocationFromFileForValidFilePathTest()
+ {
+ String filePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "C:\\test\\script.sql" : "/test/script.sql";
+ PeekDefinition peekDefinition = new PeekDefinition(null);
+ Location[] locations = peekDefinition.GetLocationFromFile(filePath, 0);
+
+ String expectedFilePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "file:///C:/test/script.sql" : "file:/test/script.sql";
+ Assert.Equal(locations[0].Uri, expectedFilePath);
+ }
+
+ ///
+ /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a valid database name
+ ///
+ [Fact]
+ public void GetSchemaFromDatabaseQualifiedNameWithValidNameTest()
+ {
+ PeekDefinition peekDefinition = new PeekDefinition(null);
+ string validDatabaseQualifiedName = "master.test.test_table";
+ string objectName = "test_table";
+ string expectedSchemaName = "test";
+
+ string actualSchemaName = peekDefinition.GetSchemaFromDatabaseQualifiedName(validDatabaseQualifiedName, objectName);
+ Assert.Equal(actualSchemaName, expectedSchemaName);
+ }
+
+ ///
+ /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a valid object name and no schema
+ ///
+
+ [Fact]
+ public void GetSchemaFromDatabaseQualifiedNameWithNoSchemaTest()
+ {
+ PeekDefinition peekDefinition = new PeekDefinition(null);
+ string validDatabaseQualifiedName = "test_table";
+ string objectName = "test_table";
+ string expectedSchemaName = "dbo";
+
+ string actualSchemaName = peekDefinition.GetSchemaFromDatabaseQualifiedName(validDatabaseQualifiedName, objectName);
+ Assert.Equal(actualSchemaName, expectedSchemaName);
+ }
+
+ ///
+ /// Test PeekDefinition.GetSchemaFromDatabaseQualifiedName with a invalid database name
+ ///
+ [Fact]
+ public void GetSchemaFromDatabaseQualifiedNameWithInvalidNameTest()
+ {
+ PeekDefinition peekDefinition = new PeekDefinition(null);
+ string validDatabaseQualifiedName = "x.y.z";
+ string objectName = "test_table";
+ string expectedSchemaName = "dbo";
+
+ string actualSchemaName = peekDefinition.GetSchemaFromDatabaseQualifiedName(validDatabaseQualifiedName, objectName);
+ Assert.Equal(actualSchemaName, expectedSchemaName);
+ }
+
#if LIVE_CONNECTION_TESTS
///
/// Test get definition for a table object with active connection
///
[Fact]
- public void GetTableDefinitionTest()
+ public void GetValidTableDefinitionTest()
{
// Get live connectionInfo
ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition();
@@ -143,11 +199,52 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices
string schemaName = null;
string objectType = "TABLE";
+ // Get locations for valid table object
Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType);
Assert.NotNull(locations);
Cleanup(locations);
}
+ ///
+ /// Test get definition for a invalid table object with active connection
+ ///
+ [Fact]
+ public void GetTableDefinitionInvalidObjectTest()
+ {
+ // Get live connectionInfo
+ ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition();
+ PeekDefinition peekDefinition = new PeekDefinition(connInfo);
+ string objectName = "test_invalid";
+ string schemaName = null;
+ string objectType = "TABLE";
+
+ // Get locations for invalid table object
+ Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType);
+ Assert.Null(locations);
+ }
+
+ ///
+ /// Test get definition for a valid table object with schema and active connection
+ ///
+ [Fact]
+ public void GetTableDefinitionWithSchemaTest()
+ {
+ // Get live connectionInfo
+ ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition();
+ PeekDefinition peekDefinition = new PeekDefinition(connInfo);
+ string objectName = "test_table";
+ string schemaName = "dbo";
+ string objectType = "TABLE";
+
+ // Get locations for valid table object with schema name
+ Location[] locations = peekDefinition.GetSqlObjectDefinition(peekDefinition.GetTableScripts, objectName, schemaName, objectType);
+ Assert.NotNull(locations);
+ Cleanup(locations);
+ }
+
+ ///
+ /// Test GetDefinition with an unsupported type(function)
+ ///
[Fact]
public void GetUnsupportedDefinitionForFullScript()
{
@@ -177,7 +274,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices
/// Test get definition for a view object with active connection
///
[Fact]
- public void GetViewDefinitionTest()
+ public void GetValidViewDefinitionTest()
{
ConnectionInfo connInfo = TestObjects.InitLiveConnectionInfoForDefinition();
PeekDefinition peekDefinition = new PeekDefinition(connInfo);