//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Data.Common;
using System.IO;
using System.Linq;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlServer.Management.SqlParser.Intellisense;
using Microsoft.SqlServer.Management.SqlParser.MetadataProvider;
using Microsoft.SqlServer.Management.SqlParser.Parser;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
using Microsoft.SqlTools.Utility;
using ConnectionType = Microsoft.SqlTools.ServiceLayer.Connection.ConnectionType;
using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location;

namespace Microsoft.SqlTools.ServiceLayer.Scripting
{
    internal partial class Scripter
    {
        private bool error;
        private string errorMessage;
        private ServerConnection serverConnection;
        private ConnectionInfo connectionInfo;
        private Database database;
        private string tempPath;

        internal delegate StringCollection ScriptGetter(string objectName, string schemaName, ScriptingOptions scriptingOptions);

        // Dictionary that holds the script getter for each type
        private Dictionary<DeclarationType, ScriptGetter> sqlScriptGetters =
            new Dictionary<DeclarationType, ScriptGetter>();
        private Dictionary<string, ScriptGetter> sqlScriptGettersFromQuickInfo =
            new Dictionary<string, ScriptGetter>();

        // Dictionary that holds the object name (as appears on the TSQL create statement)
        private Dictionary<DeclarationType, string> sqlObjectTypes =
            new Dictionary<DeclarationType, string>();
        private Dictionary<string, string> sqlObjectTypesFromQuickInfo =
            new Dictionary<string, string>();

        /// <summary>
        /// Initialize a Peek Definition helper object
        /// </summary>
        /// <param name="serverConnection">SMO Server connection</param>
        /// <param name="connInfo">Connection info; may be null. Used to pick the current database and auth type.</param>
        internal Scripter(ServerConnection serverConnection, ConnectionInfo connInfo)
        {
            this.serverConnection = serverConnection;
            this.connectionInfo = connInfo;
            this.tempPath = FileUtilities.GetPeekDefinitionTempFolder();
            Initialize();
        }

        /// <summary>
        /// Lazily creates, refreshes and caches the SMO <see cref="Database"/> for the current connection.
        /// Returns null when there is no server connection / database name, or when loading fails;
        /// on failure <c>error</c> and <c>errorMessage</c> are set and the load is retried on the next access.
        /// </summary>
        internal Database Database
        {
            get
            {
                if (this.database != null)
                {
                    return this.database;
                }

                if (this.serverConnection == null || string.IsNullOrEmpty(this.serverConnection.DatabaseName))
                {
                    return null;
                }

                try
                {
                    // Reuse existing connection
                    Server server = new Server(this.serverConnection);

                    // The default database name is the database name of the server connection
                    string dbName = this.serverConnection.DatabaseName;
                    if (this.connectionInfo != null)
                    {
                        // If there is a query DbConnection, use that connection to get the database name
                        // This is preferred since it has the most current database name (in case of database switching)
                        DbConnection connection;
                        if (connectionInfo.TryGetConnection(ConnectionType.Query, out connection)
                            && !string.IsNullOrEmpty(connection.Database))
                        {
                            dbName = connection.Database;
                        }
                    }

                    // Refresh before caching so a failed Refresh() does not leave a
                    // half-initialized Database object cached for later callers.
                    Database loaded = new Database(server, dbName);
                    loaded.Refresh();
                    this.database = loaded;
                }
                catch (ConnectionFailureException cfe)
                {
                    Logger.Write(LogLevel.Error, "Exception at PeekDefinition Database.get() : " + cfe.Message);
                    this.error = true;
                    this.errorMessage = (connectionInfo != null && connectionInfo.IsAzure) ? SR.PeekDefinitionAzureError(cfe.Message) : SR.PeekDefinitionError(cfe.Message);
                    return null;
                }
                catch (Exception ex)
                {
                    Logger.Write(LogLevel.Error, "Exception at PeekDefinition Database.get() : " + ex.Message);
                    this.error = true;
                    this.errorMessage = SR.PeekDefinitionError(ex.Message);
                    return null;
                }

                return this.database;
            }
        }

        /// <summary>
        /// Add the given type, scriptgetter and the typeName string to the respective dictionaries
        /// </summary>
        /// <param name="type">Declaration type reported by the IntelliSense resolver</param>
        /// <param name="scriptGetter">Function that scripts an object of this type</param>
        /// <param name="typeName">Object type as it appears after CREATE in the script</param>
        /// <param name="quickInfoType">Type text as it appears in QuickInfo; optional (null/empty to skip)</param>
        private void AddSupportedType(DeclarationType type, ScriptGetter scriptGetter, string typeName, string quickInfoType)
        {
            sqlScriptGetters.Add(type, scriptGetter);
            sqlObjectTypes.Add(type, typeName);
            if (!string.IsNullOrEmpty(quickInfoType))
            {
                // QuickInfo lookups are keyed case-insensitively via lower-invariant keys
                string quickInfoKey = quickInfoType.ToLowerInvariant();
                sqlScriptGettersFromQuickInfo.Add(quickInfoKey, scriptGetter);
                sqlObjectTypesFromQuickInfo.Add(quickInfoKey, typeName);
            }
        }

        /// <summary>
        /// Get the script of the selected token based on the type of the token
        /// </summary>
        /// <param name="parseResult">Parse result of the current script</param>
        /// <param name="position">Zero-based editor position of the token</param>
        /// <param name="metadataDisplayInfoProvider">Metadata display provider used by the resolver</param>
        /// <param name="tokenText">Text of the selected token</param>
        /// <param name="schemaName">Schema name, or null/empty if not known</param>
        /// <returns>Definition result containing the Location object of the script file, or an error result</returns>
        internal DefinitionResult GetScript(ParseResult parseResult, Position position, IMetadataDisplayInfoProvider metadataDisplayInfoProvider, string tokenText, string schemaName)
        {
            // The parser uses 1-based line/column numbers
            int parserLine = position.Line + 1;
            int parserColumn = position.Character + 1;

            // Get DeclarationItems from The Intellisense Resolver for the selected token. The type of the selected token is extracted from the declarationItem.
            // Materialize once: the result is checked for emptiness and then iterated.
            IEnumerable<Declaration> completions = GetCompletionsForToken(parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
            List<Declaration> declarationItems = completions?.ToList();

            if (declarationItems == null || declarationItems.Count == 0)
            {
                // if no declarationItem matched the selected token, we try to find the type of the token using QuickInfo.Text
                string quickInfoText = GetQuickInfoForToken(parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
                return GetDefinitionUsingQuickInfoText(quickInfoText, tokenText, schemaName);
            }

            foreach (Declaration declarationItem in declarationItems)
            {
                if (declarationItem.Title == null)
                {
                    continue;
                }

                if (this.Database == null)
                {
                    return GetDefinitionErrorResult(SR.PeekDefinitionDatabaseError);
                }

                StringComparison caseSensitivity = this.Database.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase;

                // if declarationItem matches the selected token, script SMO using that type
                if (declarationItem.Title.Equals(tokenText, caseSensitivity))
                {
                    return GetDefinitionUsingDeclarationType(declarationItem.Type, declarationItem.DatabaseQualifiedName, tokenText, schemaName);
                }
            }

            // no definition found
            return GetDefinitionErrorResult(SR.PeekDefinitionNoResultsError);
        }

        /// <summary>
        /// Script an object using the type extracted from quickInfo Text
        /// </summary>
        /// <param name="quickInfoText">the text from the quickInfo for the selected token</param>
        /// <param name="tokenText">The text of the selected token</param>
        /// <param name="schemaName">Schema name</param>
        /// <returns>Definition result with the script location, or an error result</returns>
        internal DefinitionResult GetDefinitionUsingQuickInfoText(string quickInfoText, string tokenText, string schemaName)
        {
            if (this.Database == null)
            {
                return GetDefinitionErrorResult(SR.PeekDefinitionDatabaseError);
            }

            StringComparison caseSensitivity = this.Database.CaseSensitive ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase;
            string tokenType = GetTokenTypeFromQuickInfo(quickInfoText, tokenText, caseSensitivity);
            if (tokenType == null)
            {
                // no definition found
                return GetDefinitionErrorResult(SR.PeekDefinitionNoResultsError);
            }

            string typeKey = tokenType.ToLowerInvariant();
            ScriptGetter scriptGetter;
            string objectType;
            if (!sqlScriptGettersFromQuickInfo.TryGetValue(typeKey, out scriptGetter)
                || !sqlObjectTypesFromQuickInfo.TryGetValue(typeKey, out objectType))
            {
                // If a type was found but is not in sqlScriptGettersFromQuickInfo, then the type is not supported
                return GetDefinitionErrorResult(SR.PeekDefinitionTypeNotSupportedError);
            }

            // With SqlLogin authentication, the defaultSchema property throws an Exception when accessed.
            // This workaround ensures that a schema name is present by attempting
            // to get the schema name from the full object name in the QuickInfo text.
            // If all fails, the default schema name is assumed to be "dbo"
            if (NeedsSchemaFromQualifiedName(schemaName))
            {
                string fullObjectName = this.GetFullObjectNameFromQuickInfo(quickInfoText, tokenText, caseSensitivity);
                schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText, caseSensitivity);
            }

            Location[] locations = GetSqlObjectDefinition(scriptGetter, tokenText, schemaName, objectType);
            return new DefinitionResult
            {
                IsErrorResult = this.error,
                Message = this.errorMessage,
                Locations = locations
            };
        }

        /// <summary>
        /// Script a object using the type extracted from declarationItem
        /// </summary>
        /// <param name="type">The declaration type of the object that matched with the selected token</param>
        /// <param name="databaseQualifiedName">Database qualified name (database.schema.object) of the declaration</param>
        /// <param name="tokenText">The text of the selected token</param>
        /// <param name="schemaName">Schema name</param>
        /// <returns>Definition result with the script location, or an error result</returns>
        internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type, string databaseQualifiedName, string tokenText, string schemaName)
        {
            ScriptGetter scriptGetter;
            string objectType;
            if (!sqlScriptGetters.TryGetValue(type, out scriptGetter)
                || !sqlObjectTypes.TryGetValue(type, out objectType))
            {
                // If a type was found but is not in sqlScriptGetters, then the type is not supported
                return GetDefinitionErrorResult(SR.PeekDefinitionTypeNotSupportedError);
            }

            // With SqlLogin authentication, the defaultSchema property throws an Exception when accessed.
            // This workaround ensures that a schema name is present by attempting
            // to get the schema name from the declaration item.
            // If all fails, the default schema name is assumed to be "dbo"
            if (NeedsSchemaFromQualifiedName(schemaName))
            {
                schemaName = this.GetSchemaFromDatabaseQualifiedName(databaseQualifiedName, tokenText);
            }

            Location[] locations = GetSqlObjectDefinition(scriptGetter, tokenText, schemaName, objectType);
            return new DefinitionResult
            {
                IsErrorResult = this.error,
                Message = this.errorMessage,
                Locations = locations
            };
        }

        /// <summary>
        /// Script a object using SMO and write to a file.
        /// </summary>
        /// <param name="sqlScriptGetter">Function that returns the SMO scripts for an object</param>
        /// <param name="objectName">SQL object name</param>
        /// <param name="schemaName">Schema name or null</param>
        /// <param name="objectType">Type of SQL object</param>
        /// <returns>Location object representing URI and range of the script file, or null if no script was produced</returns>
        internal Location[] GetSqlObjectDefinition(
            ScriptGetter sqlScriptGetter,
            string objectName,
            string schemaName,
            string objectType)
        {
            // error/errorMessage are instance state; clear anything left from an earlier
            // request so a successful script is not reported as an error.
            this.error = false;
            this.errorMessage = null;

            StringCollection scripts = sqlScriptGetter(objectName, schemaName, null);
            if (scripts == null)
            {
                this.error = true;
                this.errorMessage = SR.PeekDefinitionNoResultsError;
                return null;
            }

            string tempFileName = Path.Combine(this.tempPath, BuildScriptFileName(objectName, schemaName));
            string createSyntax = string.Format("CREATE {0}", objectType);
            int lineNumber = 0;
            int linesWritten = 0;

            using (StreamWriter scriptFile = new StreamWriter(File.Open(tempFileName, FileMode.Create, FileAccess.ReadWrite)))
            {
                foreach (string script in scripts)
                {
                    if (script.IndexOf(createSyntax, StringComparison.OrdinalIgnoreCase) < 0)
                    {
                        continue;
                    }

                    scriptFile.WriteLine(script);

                    // The CREATE line is relative to this script; offset it by the lines
                    // already written so it points at the right place in the file.
                    lineNumber = linesWritten + GetStartOfCreate(script, createSyntax);

                    // WriteLine emits the script followed by a newline, so the file advances
                    // by one line per '\n' in the script plus one.
                    linesWritten += script.Split('\n').Length;
                }
            }

            return GetLocationFromFile(tempFileName, lineNumber);
        }

        #region Helper Methods

        /// <summary>
        /// True when the connection uses SqlLogin authentication and no schema name is known.
        /// In that case the schema has to be inferred from the object's qualified name.
        /// Null connection details are treated as "not SqlLogin" instead of throwing.
        /// </summary>
        private bool NeedsSchemaFromQualifiedName(string schemaName)
        {
            return string.IsNullOrEmpty(schemaName)
                && this.connectionInfo?.ConnectionDetails != null
                && string.Equals(this.connectionInfo.ConnectionDetails.AuthenticationType, Constants.SqlLoginAuthenticationType);
        }

        /// <summary>
        /// Build the temp script file name ("schema.object.sql" or "object.sql").
        /// Characters that are invalid in file names are replaced with '_', so an object name
        /// cannot produce an invalid path or escape the temp folder.
        /// </summary>
        private static string BuildScriptFileName(string objectName, string schemaName)
        {
            string safeObjectName = SanitizeFileNamePart(objectName);
            return (schemaName != null)
                ? string.Format("{0}.{1}.sql", SanitizeFileNamePart(schemaName), safeObjectName)
                : string.Format("{0}.sql", safeObjectName);
        }

        /// <summary>
        /// Replace every character of <paramref name="part"/> that is invalid in a file name with '_'.
        /// Returns an empty string for null input.
        /// </summary>
        private static string SanitizeFileNamePart(string part)
        {
            if (string.IsNullOrEmpty(part))
            {
                return part ?? string.Empty;
            }

            char[] invalidChars = Path.GetInvalidFileNameChars();
            return new string(part.Select(c => invalidChars.Contains(c) ? '_' : c).ToArray());
        }

        /// <summary>
        /// Return schema name from the full name of the database. If schema is missing return dbo as schema name.
        /// </summary>
        /// <param name="fullObjectName">The full database qualified name(database.schema.object)</param>
        /// <param name="objectName">Object name</param>
        /// <param name="comparison">How to compare name parts with <paramref name="objectName"/>; defaults to ordinal (case-sensitive)</param>
        /// <returns>Schema name</returns>
        internal string GetSchemaFromDatabaseQualifiedName(string fullObjectName, string objectName, StringComparison comparison = StringComparison.Ordinal)
        {
            if (!string.IsNullOrEmpty(fullObjectName))
            {
                string[] tokens = fullObjectName.Split('.');
                // Walk right-to-left; the part just before the object name is the schema
                for (int i = tokens.Length - 1; i > 0; i--)
                {
                    if (string.Equals(tokens[i], objectName, comparison))
                    {
                        return tokens[i - 1];
                    }
                }
            }
            return "dbo";
        }

        /// <summary>
        /// Convert a file to a location array containing a location object as expected by the extension
        /// </summary>
        /// <param name="tempFileName">Path of the script file</param>
        /// <param name="lineNumber">Zero-based line where the range starts</param>
        internal Location[] GetLocationFromFile(string tempFileName, int lineNumber)
        {
            // Get absolute Uri based on uri format. This works around a dotnetcore URI bug for linux paths.
            if (Path.DirectorySeparatorChar.Equals('/'))
            {
                tempFileName = "file:" + tempFileName;
            }
            else
            {
                tempFileName = new Uri(tempFileName).AbsoluteUri;
            }

            // Create a location array containing the tempFile Uri, as expected by VSCode.
            Location[] locations = new[]
            {
                new Location
                {
                    Uri = tempFileName,
                    Range = new Range
                    {
                        Start = new Position { Line = lineNumber, Character = 1 },
                        End = new Position { Line = lineNumber + 1, Character = 1 }
                    }
                }
            };
            return locations;
        }

        /// <summary>
        /// Get the zero-based line number (within <paramref name="script"/>) of the first line containing
        /// <paramref name="createString"/> (case-insensitive); 0 if none.
        /// Splitting on '\n' handles both "\r\n" and "\n" line endings.
        /// </summary>
        private int GetStartOfCreate(string script, string createString)
        {
            string[] lines = script.Split('\n');
            for (int lineNumber = 0; lineNumber < lines.Length; lineNumber++)
            {
                if (lines[lineNumber].IndexOf(createString, StringComparison.OrdinalIgnoreCase) >= 0)
                {
                    return lineNumber;
                }
            }
            return 0;
        }

        /// <summary>
        /// Helper method to create definition error result object
        /// </summary>
        /// <param name="errorMessage">Error message</param>
        /// <returns>DefinitionResult</returns>
        internal DefinitionResult GetDefinitionErrorResult(string errorMessage)
        {
            return new DefinitionResult
            {
                IsErrorResult = true,
                Message = errorMessage,
                Locations = null
            };
        }

        /// <summary>
        /// Return full object name(database.schema.objectName) from the quickInfo text("type database.schema.objectName")
        /// </summary>
        /// <param name="quickInfoText">QuickInfo Text for this token</param>
        /// <param name="tokenText">Token Text</param>
        /// <param name="caseSensitivity">StringComparison enum</param>
        /// <returns>The first space-separated word containing the token text, or null</returns>
        internal string GetFullObjectNameFromQuickInfo(string quickInfoText, string tokenText, StringComparison caseSensitivity)
        {
            if (string.IsNullOrEmpty(quickInfoText) || string.IsNullOrEmpty(tokenText))
            {
                return null;
            }

            // extract full object name from quickInfo text
            return quickInfoText
                .Split(' ')
                .FirstOrDefault(word => word.IndexOf(tokenText, caseSensitivity) >= 0);
        }

        /// <summary>
        /// Return token type from the quickInfo text("type database.schema.objectName")
        /// </summary>
        /// <param name="quickInfoText">QuickInfo Text for this token</param>
        /// <param name="tokenText">Token Text</param>
        /// <param name="caseSensitivity">StringComparison enum</param>
        /// <returns>All words before the first word containing the token text, joined by spaces; null if no word matches</returns>
        internal string GetTokenTypeFromQuickInfo(string quickInfoText, string tokenText, StringComparison caseSensitivity)
        {
            if (string.IsNullOrEmpty(quickInfoText) || string.IsNullOrEmpty(tokenText))
            {
                return null;
            }

            // extract string denoting the token type from quickInfo text
            string[] tokens = quickInfoText.Split(' ');
            int nameIndex = Array.FindIndex(tokens, word => word.IndexOf(tokenText, caseSensitivity) >= 0);
            return (nameIndex >= 0) ? string.Join(" ", tokens.Take(nameIndex)) : null;
        }

        /// <summary>
        /// Wrapper method that calls Resolver.GetQuickInfo
        /// </summary>
        /// <returns>The QuickInfo text, or null when inputs are missing or no QuickInfo is available</returns>
        internal string GetQuickInfoForToken(ParseResult parseResult, int parserLine, int parserColumn, IMetadataDisplayInfoProvider metadataDisplayInfoProvider)
        {
            if (parseResult == null || metadataDisplayInfoProvider == null)
            {
                return null;
            }
            Babel.CodeObjectQuickInfo quickInfo = Resolver.GetQuickInfo(
                parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
            return quickInfo?.Text;
        }

        /// <summary>
        /// Wrapper method that calls Resolver.FindCompletions
        /// </summary>
        /// <param name="parseResult">Parse result of the current script</param>
        /// <param name="parserLine">1-based line</param>
        /// <param name="parserColumn">1-based column</param>
        /// <param name="metadataDisplayInfoProvider">Metadata display provider</param>
        /// <returns>Resolver completions, or null when inputs are missing</returns>
        internal IEnumerable<Declaration> GetCompletionsForToken(ParseResult parseResult, int parserLine, int parserColumn, IMetadataDisplayInfoProvider metadataDisplayInfoProvider)
        {
            if (parseResult == null || metadataDisplayInfoProvider == null)
            {
                return null;
            }
            return Resolver.FindCompletions(
                parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
        }

        #endregion
    }
}