mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-17 09:35:37 -05:00
Add Scripter class and use it in the ScriptingService (#286)
* Initial scripting commit * Fix the script CREATE integration test * Use language service binding queue for scripting * Update scripting service to have methods for each operation * Add scripter class * Fix build break. * Fix a few bugs * Use scripter class instead of PeekDefinition class. * Fix scripting test break * Fix header incorrectly saying file is generated. * Add localization tests to address a large block of code not covered by automation. * Moving system usings to top of using list
This commit is contained in:
457
src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs
Normal file
457
src/Microsoft.SqlTools.ServiceLayer/Scripting/ScripterCore.cs
Normal file
@@ -0,0 +1,457 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Specialized;
|
||||
using System.Data.Common;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using Microsoft.SqlServer.Management.Common;
|
||||
using Microsoft.SqlServer.Management.Smo;
|
||||
using Microsoft.SqlServer.Management.SqlParser.Intellisense;
|
||||
using Microsoft.SqlServer.Management.SqlParser.MetadataProvider;
|
||||
using Microsoft.SqlServer.Management.SqlParser.Parser;
|
||||
using Microsoft.SqlTools.Hosting.Protocol;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
|
||||
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
|
||||
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
|
||||
using Microsoft.SqlTools.Utility;
|
||||
using ConnectionType = Microsoft.SqlTools.ServiceLayer.Connection.ConnectionType;
|
||||
using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
||||
{
|
||||
internal partial class Scripter
|
||||
{
|
||||
private bool error;
|
||||
private string errorMessage;
|
||||
private ServerConnection serverConnection;
|
||||
private ConnectionInfo connectionInfo;
|
||||
private Database database;
|
||||
private string tempPath;
|
||||
|
||||
internal delegate StringCollection ScriptGetter(string objectName, string schemaName, ScriptingOptions scriptingOptions);
|
||||
|
||||
// Dictionary that holds the script getter for each type
|
||||
private Dictionary<DeclarationType, ScriptGetter> sqlScriptGetters =
|
||||
new Dictionary<DeclarationType, ScriptGetter>();
|
||||
|
||||
private Dictionary<string, ScriptGetter> sqlScriptGettersFromQuickInfo =
|
||||
new Dictionary<string, ScriptGetter>();
|
||||
|
||||
// Dictionary that holds the object name (as appears on the TSQL create statement)
|
||||
private Dictionary<DeclarationType, string> sqlObjectTypes = new Dictionary<DeclarationType, string>();
|
||||
|
||||
private Dictionary<string, string> sqlObjectTypesFromQuickInfo = new Dictionary<string, string>();
|
||||
|
||||
/// <summary>
|
||||
/// Initialize a Peek Definition helper object
|
||||
/// </summary>
|
||||
/// <param name="serverConnection">SMO Server connection</param>
|
||||
internal Scripter(ServerConnection serverConnection, ConnectionInfo connInfo)
|
||||
{
|
||||
this.serverConnection = serverConnection;
|
||||
this.connectionInfo = connInfo;
|
||||
this.tempPath = FileUtilities.GetPeekDefinitionTempFolder();
|
||||
Initialize();
|
||||
}
|
||||
|
||||
internal Database Database
|
||||
{
|
||||
get
|
||||
{
|
||||
if (this.database == null)
|
||||
{
|
||||
if (this.serverConnection != null && !string.IsNullOrEmpty(this.serverConnection.DatabaseName))
|
||||
{
|
||||
try
|
||||
{
|
||||
// Reuse existing connection
|
||||
Server server = new Server(this.serverConnection);
|
||||
// The default database name is the database name of the server connection
|
||||
string dbName = this.serverConnection.DatabaseName;
|
||||
if (this.connectionInfo != null)
|
||||
{
|
||||
// If there is a query DbConnection, use that connection to get the database name
|
||||
// This is preferred since it has the most current database name (in case of database switching)
|
||||
DbConnection connection;
|
||||
if (connectionInfo.TryGetConnection(ConnectionType.Query, out connection))
|
||||
{
|
||||
if (!string.IsNullOrEmpty(connection.Database))
|
||||
{
|
||||
dbName = connection.Database;
|
||||
}
|
||||
}
|
||||
}
|
||||
this.database = new Database(server, dbName);
|
||||
this.database.Refresh();
|
||||
}
|
||||
catch (ConnectionFailureException cfe)
|
||||
{
|
||||
Logger.Write(LogLevel.Error, "Exception at PeekDefinition Database.get() : " + cfe.Message);
|
||||
this.error = true;
|
||||
this.errorMessage = (connectionInfo != null && connectionInfo.IsAzure) ? SR.PeekDefinitionAzureError(cfe.Message) : SR.PeekDefinitionError(cfe.Message);
|
||||
return null;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Logger.Write(LogLevel.Error, "Exception at PeekDefinition Database.get() : " + ex.Message);
|
||||
this.error = true;
|
||||
this.errorMessage = SR.PeekDefinitionError(ex.Message);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
return this.database;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Add the given type, scriptgetter and the typeName string to the respective dictionaries
|
||||
/// </summary>
|
||||
private void AddSupportedType(DeclarationType type, ScriptGetter scriptGetter, string typeName, string quickInfoType)
|
||||
{
|
||||
sqlScriptGetters.Add(type, scriptGetter);
|
||||
sqlObjectTypes.Add(type, typeName);
|
||||
if (!string.IsNullOrEmpty(quickInfoType))
|
||||
{
|
||||
sqlScriptGettersFromQuickInfo.Add(quickInfoType.ToLowerInvariant(), scriptGetter);
|
||||
sqlObjectTypesFromQuickInfo.Add(quickInfoType.ToLowerInvariant(), typeName);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the script of the selected token based on the type of the token
|
||||
/// </summary>
|
||||
/// <param name="declarationItems"></param>
|
||||
/// <param name="tokenText"></param>
|
||||
/// <param name="schemaName"></param>
|
||||
/// <returns>Location object of the script file</returns>
|
||||
internal DefinitionResult GetScript(ParseResult parseResult, Position position, IMetadataDisplayInfoProvider metadataDisplayInfoProvider, string tokenText, string schemaName)
|
||||
{
|
||||
int parserLine = position.Line + 1;
|
||||
int parserColumn = position.Character + 1;
|
||||
// Get DeclarationItems from The Intellisense Resolver for the selected token. The type of the selected token is extracted from the declarationItem.
|
||||
IEnumerable<Declaration> declarationItems = GetCompletionsForToken(parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
|
||||
if (declarationItems != null && declarationItems.Count() > 0)
|
||||
{
|
||||
foreach (Declaration declarationItem in declarationItems)
|
||||
{
|
||||
if (declarationItem.Title == null)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if (this.Database == null)
|
||||
{
|
||||
return GetDefinitionErrorResult(SR.PeekDefinitionDatabaseError);
|
||||
}
|
||||
StringComparison caseSensitivity = this.Database.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase;
|
||||
// if declarationItem matches the selected token, script SMO using that type
|
||||
if (declarationItem.Title.Equals(tokenText, caseSensitivity))
|
||||
{
|
||||
return GetDefinitionUsingDeclarationType(declarationItem.Type, declarationItem.DatabaseQualifiedName, tokenText, schemaName);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// if no declarationItem matched the selected token, we try to find the type of the token using QuickInfo.Text
|
||||
string quickInfoText = GetQuickInfoForToken(parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
|
||||
return GetDefinitionUsingQuickInfoText(quickInfoText, tokenText, schemaName);
|
||||
}
|
||||
// no definition found
|
||||
return GetDefinitionErrorResult(SR.PeekDefinitionNoResultsError);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Script an object using the type extracted from quickInfo Text
|
||||
/// </summary>
|
||||
/// <param name="quickInfoText">the text from the quickInfo for the selected token</param>
|
||||
/// <param name="tokenText">The text of the selected token</param>
|
||||
/// <param name="schemaName">Schema name</param>
|
||||
/// <returns></returns>
|
||||
internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText, string tokenText, string schemaName)
|
||||
{
|
||||
if (this.Database == null)
|
||||
{
|
||||
return GetDefinitionErrorResult(SR.PeekDefinitionDatabaseError);
|
||||
}
|
||||
StringComparison caseSensitivity = this.Database.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase;
|
||||
string tokenType = GetTokenTypeFromQuickInfo(quickInfoText, tokenText, caseSensitivity);
|
||||
if (tokenType != null)
|
||||
{
|
||||
if (sqlScriptGettersFromQuickInfo.ContainsKey(tokenType.ToLowerInvariant()))
|
||||
{
|
||||
// With SqlLogin authentication, the defaultSchema property throws an Exception when accessed.
|
||||
// This workaround ensures that a schema name is present by attempting
|
||||
// to get the schema name from the declaration item.
|
||||
// If all fails, the default schema name is assumed to be "dbo"
|
||||
if ((connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType)) && string.IsNullOrEmpty(schemaName))
|
||||
{
|
||||
string fullObjectName = this.GetFullObjectNameFromQuickInfo(quickInfoText, tokenText, caseSensitivity);
|
||||
schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText);
|
||||
}
|
||||
Location[] locations = GetSqlObjectDefinition(
|
||||
sqlScriptGettersFromQuickInfo[tokenType.ToLowerInvariant()],
|
||||
tokenText,
|
||||
schemaName,
|
||||
sqlObjectTypesFromQuickInfo[tokenType.ToLowerInvariant()]
|
||||
);
|
||||
DefinitionResult result = new DefinitionResult
|
||||
{
|
||||
IsErrorResult = this.error,
|
||||
Message = this.errorMessage,
|
||||
Locations = locations
|
||||
};
|
||||
return result;
|
||||
}
|
||||
else
|
||||
{
|
||||
// If a type was found but is not in sqlScriptGettersFromQuickInfo, then the type is not supported
|
||||
return GetDefinitionErrorResult(SR.PeekDefinitionTypeNotSupportedError);
|
||||
}
|
||||
}
|
||||
// no definition found
|
||||
return GetDefinitionErrorResult(SR.PeekDefinitionNoResultsError);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Script a object using the type extracted from declarationItem
|
||||
/// </summary>
|
||||
/// <param name="declarationItem">The Declarartion object that matched with the selected token</param>
|
||||
/// <param name="tokenText">The text of the selected token</param>
|
||||
/// <param name="schemaName">Schema name</param>
|
||||
/// <returns></returns>
|
||||
internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type, string databaseQualifiedName, string tokenText, string schemaName)
|
||||
{
|
||||
if (sqlScriptGetters.ContainsKey(type) && sqlObjectTypes.ContainsKey(type))
|
||||
{
|
||||
// With SqlLogin authentication, the defaultSchema property throws an Exception when accessed.
|
||||
// This workaround ensures that a schema name is present by attempting
|
||||
// to get the schema name from the declaration item.
|
||||
// If all fails, the default schema name is assumed to be "dbo"
|
||||
if ((connectionInfo != null && connectionInfo.ConnectionDetails.AuthenticationType.Equals(Constants.SqlLoginAuthenticationType)) && string.IsNullOrEmpty(schemaName))
|
||||
{
|
||||
string fullObjectName = databaseQualifiedName;
|
||||
schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText);
|
||||
}
|
||||
Location[] locations = GetSqlObjectDefinition(
|
||||
sqlScriptGetters[type],
|
||||
tokenText,
|
||||
schemaName,
|
||||
sqlObjectTypes[type]
|
||||
);
|
||||
DefinitionResult result = new DefinitionResult
|
||||
{
|
||||
IsErrorResult = this.error,
|
||||
Message = this.errorMessage,
|
||||
Locations = locations
|
||||
};
|
||||
return result;
|
||||
}
|
||||
// If a type was found but is not in sqlScriptGetters, then the type is not supported
|
||||
return GetDefinitionErrorResult(SR.PeekDefinitionTypeNotSupportedError);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Script a object using SMO and write to a file.
|
||||
/// </summary>
|
||||
/// <param name="sqlScriptGetter">Function that returns the SMO scripts for an object</param>
|
||||
/// <param name="objectName">SQL object name</param>
|
||||
/// <param name="schemaName">Schema name or null</param>
|
||||
/// <param name="objectType">Type of SQL object</param>
|
||||
/// <returns>Location object representing URI and range of the script file</returns>
|
||||
internal Location[] GetSqlObjectDefinition(
|
||||
ScriptGetter sqlScriptGetter,
|
||||
string objectName,
|
||||
string schemaName,
|
||||
string objectType)
|
||||
{
|
||||
StringCollection scripts = sqlScriptGetter(objectName, schemaName, null);
|
||||
string tempFileName = (schemaName != null) ? Path.Combine(this.tempPath, string.Format("{0}.{1}.sql", schemaName, objectName))
|
||||
: Path.Combine(this.tempPath, string.Format("{0}.sql", objectName));
|
||||
|
||||
if (scripts != null)
|
||||
{
|
||||
int lineNumber = 0;
|
||||
using (StreamWriter scriptFile = new StreamWriter(File.Open(tempFileName, FileMode.Create, FileAccess.ReadWrite)))
|
||||
{
|
||||
|
||||
foreach (string script in scripts)
|
||||
{
|
||||
string createSyntax = string.Format("CREATE {0}", objectType);
|
||||
if (script.IndexOf(createSyntax, StringComparison.OrdinalIgnoreCase) >= 0)
|
||||
{
|
||||
scriptFile.WriteLine(script);
|
||||
lineNumber = GetStartOfCreate(script, createSyntax);
|
||||
}
|
||||
}
|
||||
}
|
||||
return GetLocationFromFile(tempFileName, lineNumber);
|
||||
}
|
||||
else
|
||||
{
|
||||
this.error = true;
|
||||
this.errorMessage = SR.PeekDefinitionNoResultsError;
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
#region Helper Methods
|
||||
/// <summary>
|
||||
/// Return schema name from the full name of the database. If schema is missing return dbo as schema name.
|
||||
/// </summary>
|
||||
/// <param name="fullObjectName"> The full database qualified name(database.schema.object)</param>
|
||||
/// <param name="objectName"> Object name</param>
|
||||
/// <returns>Schema name</returns>
|
||||
internal string GetSchemaFromDatabaseQualifiedName(string fullObjectName, string objectName)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(fullObjectName))
|
||||
{
|
||||
string[] tokens = fullObjectName.Split('.');
|
||||
for (int i = tokens.Length - 1; i > 0; i--)
|
||||
{
|
||||
if (tokens[i].Equals(objectName))
|
||||
{
|
||||
return tokens[i - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
return "dbo";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Convert a file to a location array containing a location object as expected by the extension
|
||||
/// </summary>
|
||||
internal Location[] GetLocationFromFile(string tempFileName, int lineNumber)
|
||||
{
|
||||
// Get absolute Uri based on uri format. This works around a dotnetcore URI bug for linux paths.
|
||||
if (Path.DirectorySeparatorChar.Equals('/'))
|
||||
{
|
||||
tempFileName = "file:" + tempFileName;
|
||||
}
|
||||
else
|
||||
{
|
||||
tempFileName = new Uri(tempFileName).AbsoluteUri;
|
||||
}
|
||||
// Create a location array containing the tempFile Uri, as expected by VSCode.
|
||||
Location[] locations = new[] {
|
||||
new Location {
|
||||
Uri = tempFileName,
|
||||
Range = new Range {
|
||||
Start = new Position { Line = lineNumber, Character = 1},
|
||||
End = new Position { Line = lineNumber + 1, Character = 1}
|
||||
}
|
||||
}
|
||||
};
|
||||
return locations;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get line number for the create statement
|
||||
/// </summary>
|
||||
private int GetStartOfCreate(string script, string createString)
|
||||
{
|
||||
string[] lines = script.Split(new string[] { Environment.NewLine }, StringSplitOptions.None);
|
||||
for (int lineNumber = 0; lineNumber < lines.Length; lineNumber++)
|
||||
{
|
||||
if (lines[lineNumber].IndexOf(createString, StringComparison.OrdinalIgnoreCase) >= 0)
|
||||
{
|
||||
return lineNumber;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
/// <summary>
|
||||
/// Helper method to create definition error result object
|
||||
/// </summary>
|
||||
/// <param name="errorMessage">Error message</param>
|
||||
/// <returns> DefinitionResult</returns>
|
||||
internal DefinitionResult GetDefinitionErrorResult(string errorMessage)
|
||||
{
|
||||
return new DefinitionResult
|
||||
{
|
||||
IsErrorResult = true,
|
||||
Message = errorMessage,
|
||||
Locations = null
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return full object name(database.schema.objectName) from the quickInfo text("type database.schema.objectName")
|
||||
/// </summary>
|
||||
/// <param name="quickInfoText">QuickInfo Text for this token</param>
|
||||
/// <param name="tokenText">Token Text</param>
|
||||
/// <param name="caseSensitivity">StringComparison enum</param>
|
||||
/// <returns></returns>
|
||||
internal string GetFullObjectNameFromQuickInfo(string quickInfoText, string tokenText, StringComparison caseSensitivity)
|
||||
{
|
||||
if (string.IsNullOrEmpty(quickInfoText) || string.IsNullOrEmpty(tokenText))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
// extract full object name from quickInfo text
|
||||
string[] tokens = quickInfoText.Split(' ');
|
||||
List<string> tokenList = tokens.Where(el => el.IndexOf(tokenText, caseSensitivity) >= 0).ToList();
|
||||
return (tokenList?.Count() > 0) ? tokenList[0] : null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Return token type from the quickInfo text("type database.schema.objectName")
|
||||
/// </summary>
|
||||
/// <param name="quickInfoText">QuickInfo Text for this token</param>
|
||||
/// <param name="tokenText"Token Text></param>
|
||||
/// <param name="caseSensitivity">StringComparison enum</param>
|
||||
/// <returns></returns>
|
||||
internal string GetTokenTypeFromQuickInfo(string quickInfoText, string tokenText, StringComparison caseSensitivity)
|
||||
{
|
||||
if (string.IsNullOrEmpty(quickInfoText) || string.IsNullOrEmpty(tokenText))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
// extract string denoting the token type from quickInfo text
|
||||
string[] tokens = quickInfoText.Split(' ');
|
||||
List<int> indexList = tokens.Select((s, i) => new { i, s }).Where(el => (el.s).IndexOf(tokenText, caseSensitivity) >= 0).Select(el => el.i).ToList();
|
||||
return (indexList?.Count() > 0) ? String.Join(" ", tokens.Take(indexList[0])) : null;
|
||||
}
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Wrapper method that calls Resolver.GetQuickInfo
|
||||
/// </summary>
|
||||
internal string GetQuickInfoForToken(ParseResult parseResult, int parserLine, int parserColumn, IMetadataDisplayInfoProvider metadataDisplayInfoProvider)
|
||||
{
|
||||
if (parseResult == null || metadataDisplayInfoProvider == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
Babel.CodeObjectQuickInfo quickInfo = Resolver.GetQuickInfo(
|
||||
parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
|
||||
return quickInfo?.Text;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Wrapper method that calls Resolver.FindCompletions
|
||||
/// </summary>
|
||||
/// <param name="parseResult"></param>
|
||||
/// <param name="parserLine"></param>
|
||||
/// <param name="parserColumn"></param>
|
||||
/// <param name="metadataDisplayInfoProvider"></param>
|
||||
/// <returns></returns>
|
||||
internal IEnumerable<Declaration> GetCompletionsForToken(ParseResult parseResult, int parserLine, int parserColumn, IMetadataDisplayInfoProvider metadataDisplayInfoProvider)
|
||||
{
|
||||
if (parseResult == null || metadataDisplayInfoProvider == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
return Resolver.FindCompletions(
|
||||
parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
|
||||
}
|
||||
#endregion
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user