Task/script refactor (#446)

* scripting working with race conditions

* new service works with no race conditions

* use new scripting service and commented out tests

* refactored peek definition to use mssql-scripter

* fixed peek definition tests

* removed auto gen comment

* fixed peek definition highlighting bug

* made scripting async and fixed event handlers

* fixed tests (without cancel and plan notifs)

* removed dead code

* added nuget package

* CR comments + select script service implementation

* minor fixes and added test

* CR comments and script select

* added unit tests

* code review comments and cleanup
This commit is contained in:
Aditya Bist
2017-10-02 12:02:43 -07:00
committed by GitHub
parent 9d898f0d0c
commit e7756b0bf1
24 changed files with 718 additions and 993 deletions

View File

@@ -5,7 +5,6 @@
using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Data.Common;
using System.IO;
using System.Linq;
@@ -19,10 +18,13 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility;
using ConnectionType = Microsoft.SqlTools.ServiceLayer.Connection.ConnectionType;
using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location;
using Microsoft.SqlServer.Management.Sdk.Sfc;
using System.Text;
using System.Data;
namespace Microsoft.SqlTools.ServiceLayer.Scripting
{
@@ -35,20 +37,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
private Database database;
private string tempPath;
internal delegate StringCollection ScriptGetter(string objectName, string schemaName, ScriptingOptions scriptingOptions);
// Dictionary that holds the script getter for each type
private Dictionary<DeclarationType, ScriptGetter> sqlScriptGetters =
new Dictionary<DeclarationType, ScriptGetter>();
private Dictionary<string, ScriptGetter> sqlScriptGettersFromQuickInfo =
new Dictionary<string, ScriptGetter>();
// Dictionary that holds the object name (as appears on the TSQL create statement)
private Dictionary<DeclarationType, string> sqlObjectTypes = new Dictionary<DeclarationType, string>();
private Dictionary<string, string> sqlObjectTypesFromQuickInfo = new Dictionary<string, string>();
private Dictionary<DatabaseEngineEdition, string> targetDatabaseEngineEditionMap = new Dictionary<DatabaseEngineEdition, string>();
private Dictionary<int, string> serverVersionMap = new Dictionary<int, string>();
private Dictionary<string, string> objectScriptMap = new Dictionary<string, string>();
internal Scripter() {}
/// <summary>
/// Initialize a Peek Definition helper object
/// </summary>
@@ -60,7 +61,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
this.tempPath = FileUtilities.GetPeekDefinitionTempFolder();
Initialize();
}
internal Database Database
{
get
@@ -114,13 +115,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
/// <summary>
/// Add the given type, scriptgetter and the typeName string to the respective dictionaries
/// </summary>
private void AddSupportedType(DeclarationType type, ScriptGetter scriptGetter, string typeName, string quickInfoType)
private void AddSupportedType(DeclarationType type, string typeName, string quickInfoType, Type smoObjectType)
{
sqlScriptGetters.Add(type, scriptGetter);
sqlObjectTypes.Add(type, typeName);
if (!string.IsNullOrEmpty(quickInfoType))
{
sqlScriptGettersFromQuickInfo.Add(quickInfoType.ToLowerInvariant(), scriptGetter);
sqlObjectTypesFromQuickInfo.Add(quickInfoType.ToLowerInvariant(), typeName);
}
}
@@ -186,7 +185,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
string tokenType = GetTokenTypeFromQuickInfo(quickInfoText, tokenText, caseSensitivity);
if (tokenType != null)
{
if (sqlScriptGettersFromQuickInfo.ContainsKey(tokenType.ToLowerInvariant()))
if (sqlObjectTypesFromQuickInfo.ContainsKey(tokenType.ToLowerInvariant()))
{
// With SqlLogin authentication, the defaultSchema property throws an Exception when accessed.
// This workaround ensures that a schema name is present by attempting
@@ -198,7 +197,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText);
}
Location[] locations = GetSqlObjectDefinition(
sqlScriptGettersFromQuickInfo[tokenType.ToLowerInvariant()],
tokenText,
schemaName,
sqlObjectTypesFromQuickInfo[tokenType.ToLowerInvariant()]
@@ -224,13 +222,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
/// <summary>
/// Script a object using the type extracted from declarationItem
/// </summary>
/// <param name="declarationItem">The Declarartion object that matched with the selected token</param>
/// <param name="declarationItem">The Declaration object that matched with the selected token</param>
/// <param name="tokenText">The text of the selected token</param>
/// <param name="schemaName">Schema name</param>
/// <returns></returns>
internal DefinitionResult GetDefinitionUsingDeclarationType(DeclarationType type, string databaseQualifiedName, string tokenText, string schemaName)
{
if (sqlScriptGetters.ContainsKey(type) && sqlObjectTypes.ContainsKey(type))
if (sqlObjectTypes.ContainsKey(type))
{
// With SqlLogin authentication, the defaultSchema property throws an Exception when accessed.
// This workaround ensures that a schema name is present by attempting
@@ -242,7 +240,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
schemaName = this.GetSchemaFromDatabaseQualifiedName(fullObjectName, tokenText);
}
Location[] locations = GetSqlObjectDefinition(
sqlScriptGetters[type],
tokenText,
schemaName,
sqlObjectTypes[type]
@@ -268,32 +265,43 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
/// <param name="objectType">Type of SQL object</param>
/// <returns>Location object representing URI and range of the script file</returns>
internal Location[] GetSqlObjectDefinition(
ScriptGetter sqlScriptGetter,
string objectName,
string schemaName,
string objectType)
{
StringCollection scripts = sqlScriptGetter(objectName, schemaName, null);
// script file destination
string tempFileName = (schemaName != null) ? Path.Combine(this.tempPath, string.Format("{0}.{1}.sql", schemaName, objectName))
: Path.Combine(this.tempPath, string.Format("{0}.sql", objectName));
if (scripts != null)
{
int lineNumber = 0;
using (StreamWriter scriptFile = new StreamWriter(File.Open(tempFileName, FileMode.Create, FileAccess.ReadWrite)))
{
ScriptingScriptOperation operation = InitScriptOperation(objectName, schemaName, objectType);
operation.Execute();
string script = operation.ScriptText;
foreach (string script in scripts)
bool objectFound = false;
int createStatementLineNumber = 0;
File.WriteAllText(tempFileName, script);
string[] lines = File.ReadAllLines(tempFileName);
int lineCount = 0;
string createSyntax = null;
if (objectScriptMap.ContainsKey(objectType.ToLower()))
{
createSyntax = string.Format("CREATE {0}", objectScriptMap[objectType.ToLower()]);
foreach (string line in lines)
{
if (LineContainsObject(line, objectName, createSyntax))
{
string createSyntax = string.Format("CREATE {0}", objectType);
if (script.IndexOf(createSyntax, StringComparison.OrdinalIgnoreCase) >= 0)
{
scriptFile.WriteLine(script);
lineNumber = GetStartOfCreate(script, createSyntax);
}
createStatementLineNumber = lineCount;
objectFound = true;
break;
}
lineCount++;
}
return GetLocationFromFile(tempFileName, lineNumber);
}
if (objectFound)
{
Location[] locations = GetLocationFromFile(tempFileName, createStatementLineNumber);
return locations;
}
else
{
@@ -341,33 +349,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
tempFileName = new Uri(tempFileName).AbsoluteUri;
}
// Create a location array containing the tempFile Uri, as expected by VSCode.
Location[] locations = new[] {
new Location {
Location[] locations = new[]
{
new Location
{
Uri = tempFileName,
Range = new Range {
Start = new Position { Line = lineNumber, Character = 1},
End = new Position { Line = lineNumber + 1, Character = 1}
Range = new Range
{
Start = new Position { Line = lineNumber, Character = 0},
End = new Position { Line = lineNumber + 1, Character = 0}
}
}
};
return locations;
}
/// <summary>
/// Get line number for the create statement
/// </summary>
private int GetStartOfCreate(string script, string createString)
{
string[] lines = script.Split(new string[] { Environment.NewLine }, StringSplitOptions.None);
for (int lineNumber = 0; lineNumber < lines.Length; lineNumber++)
{
if (lines[lineNumber].IndexOf(createString, StringComparison.OrdinalIgnoreCase) >= 0)
{
return lineNumber;
}
}
return 0;
}
/// <summary>
/// Helper method to create definition error result object
/// </summary>
@@ -453,7 +449,363 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
return Resolver.FindCompletions(
parseResult, parserLine, parserColumn, metadataDisplayInfoProvider);
}
#endregion
/// <summary>
/// Wrapper method that calls Resolver.FindCompletions
/// </summary>
/// <param name="objectName"></param>
/// <param name="schemaName"></param>
/// <param name="objectType"></param>
/// <param name="tempFileName"></param>
/// <returns></returns>
internal ScriptingScriptOperation InitScriptOperation(string objectName, string schemaName, string objectType)
{
// object that has to be scripted
ScriptingObject scriptingObject = new ScriptingObject
{
Name = objectName,
Schema = schemaName,
Type = objectType
};
// scripting options
ScriptOptions options = new ScriptOptions
{
ScriptCreateDrop = "ScriptCreate",
TypeOfDataToScript = "SchemaOnly",
ScriptStatistics = "ScriptStatsNone",
TargetDatabaseEngineEdition = GetTargetDatabaseEngineEdition(),
TargetDatabaseEngineType = GetTargetDatabaseEngineType(),
ScriptCompatibilityOption = GetScriptCompatibilityOption(),
IncludeIfNotExists = true
};
List<ScriptingObject> objectList = new List<ScriptingObject>();
objectList.Add(scriptingObject);
// create parameters for the scripting operation
ScriptingParams parameters = new ScriptingParams
{
ConnectionString = ConnectionService.BuildConnectionString(this.connectionInfo.ConnectionDetails),
ScriptingObjects = objectList,
ScriptOptions = options,
ScriptDestination = "ToEditor"
};
return new ScriptingScriptOperation(parameters);
}
internal string GetTargetDatabaseEngineEdition()
{
DatabaseEngineEdition dbEngineEdition = this.serverConnection.DatabaseEngineEdition;
string dbEngineEditionString = targetDatabaseEngineEditionMap[dbEngineEdition];
return (dbEngineEditionString != null) ? dbEngineEditionString : "SqlServerEnterpriseEdition";
}
internal string GetScriptCompatibilityOption()
{
int serverVersion = this.serverConnection.ServerVersion.Major;
string dbEngineTypeString = serverVersionMap[serverVersion];
return (dbEngineTypeString != null) ? dbEngineTypeString : "Script140Compat";
}
internal string GetTargetDatabaseEngineType()
{
return connectionInfo.IsAzure ? "SqlAzure" : "SingleInstance";
}
internal bool LineContainsObject(string line, string objectName, string createSyntax)
{
if (line.IndexOf(createSyntax, StringComparison.OrdinalIgnoreCase) >= 0 &&
line.IndexOf(objectName, StringComparison.OrdinalIgnoreCase) >=0)
{
return true;
}
return false;
}
internal static class ScriptingGlobals
{
/// <summary>
/// Left delimiter for an named object
/// </summary>
public const char LeftDelimiter = '[';
/// <summary>
/// right delimiter for a named object
/// </summary>
public const char RightDelimiter = ']';
}
internal static class ScriptingUtils
{
/// <summary>
/// Quote the name of a given sql object.
/// </summary>
/// <param name="sqlObject">object</param>
/// <returns>quoted object name</returns>
internal static string QuoteObjectName(string sqlObject)
{
return QuoteObjectName(sqlObject, ']');
}
/// <summary>
/// Quotes the name of a given sql object
/// </summary>
/// <param name="sqlObject">object</param>
/// <param name="quote">quote to use</param>
/// <returns></returns>
internal static string QuoteObjectName(string sqlObject, char quote)
{
int len = sqlObject.Length;
StringBuilder result = new StringBuilder(sqlObject.Length);
for (int i = 0; i < len; i++)
{
if (sqlObject[i] == quote)
{
result.Append(quote);
}
result.Append(sqlObject[i]);
}
return result.ToString();
}
/// <summary>
/// Returns the value whether the server supports XTP or not s
internal static bool IsXTPSupportedOnServer(Server server)
{
bool isXTPSupported = false;
if (server.ConnectionContext.ExecuteScalar("SELECT SERVERPROPERTY('IsXTPSupported')") != DBNull.Value)
{
isXTPSupported = server.IsXTPSupported;
}
return isXTPSupported;
}
}
internal static string SelectAllValuesFromTransmissionQueue(Urn urn)
{
string script = string.Empty;
StringBuilder selectQuery = new StringBuilder();
/*
SELECT TOP *, casted_message_body =
CASE MESSAGE_TYPE_NAME WHEN 'X'
THEN CAST(MESSAGE_BODY AS NVARCHAR(MAX))
ELSE MESSAGE_BODY
END
FROM [new].[sys].[transmission_queue]
*/
selectQuery.Append("SELECT TOP (1000) ");
selectQuery.Append("*, casted_message_body = \r\nCASE message_type_name WHEN 'X' \r\n THEN CAST(message_body AS NVARCHAR(MAX)) \r\n ELSE message_body \r\nEND \r\n");
// from clause
selectQuery.Append("FROM ");
Urn dbUrn = urn;
// database
while (dbUrn.Parent != null && dbUrn.Type != "Database")
{
dbUrn = dbUrn.Parent;
}
selectQuery.AppendFormat("{0}{1}{2}",
ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(dbUrn.GetAttribute("Name"), ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter);
//SYS
selectQuery.AppendFormat(".{0}sys{1}",
ScriptingGlobals.LeftDelimiter,
ScriptingGlobals.RightDelimiter);
//TRANSMISSION QUEUE
selectQuery.AppendFormat(".{0}transmission_queue{1}",
ScriptingGlobals.LeftDelimiter,
ScriptingGlobals.RightDelimiter);
script = selectQuery.ToString();
return script;
}
internal static string SelectAllValues(Urn urn)
{
string script = string.Empty;
StringBuilder selectQuery = new StringBuilder();
selectQuery.Append("SELECT TOP (1000) ");
selectQuery.Append("*, casted_message_body = \r\nCASE message_type_name WHEN 'X' \r\n THEN CAST(message_body AS NVARCHAR(MAX)) \r\n ELSE message_body \r\nEND \r\n");
// from clause
selectQuery.Append("FROM ");
Urn dbUrn = urn;
// database
while (dbUrn.Parent != null && dbUrn.Type != "Database")
{
dbUrn = dbUrn.Parent;
}
selectQuery.AppendFormat("{0}{1}{2}",
ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(dbUrn.GetAttribute("Name"), ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter);
// schema
selectQuery.AppendFormat(".{0}{1}{2}",
ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(urn.GetAttribute("Schema"), ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter);
// object
selectQuery.AppendFormat(".{0}{1}{2}",
ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(urn.GetAttribute("Name"), ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter);
//Adding no lock in the end.
selectQuery.AppendFormat(" WITH(NOLOCK)");
script = selectQuery.ToString();
return script;
}
internal DataTable GetColumnNames(Server server, Urn urn, bool isDw)
{
List<string> filterExpressions = new List<string>();
if (server.Version.Major >= 10)
{
// We don't have to include sparce columns as all the sparce columns data.
// Can be obtain from column set columns.
filterExpressions.Add("@IsSparse=0");
}
// Check if we're called for EDIT for SQL2016+/Sterling+.
// We need to omit temporal columns if such are present on this table.
if (server.Version.Major >= 13 || (DatabaseEngineType.SqlAzureDatabase == server.DatabaseEngineType && server.Version.Major >= 12))
{
// We're called in order to generate a list of columns for EDIT TOP N rows.
// Don't return auto-generated, auto-populated, read-only temporal columns.
filterExpressions.Add("@GeneratedAlwaysType=0");
}
// Check if we're called for SQL2017/Sterling+.
// We need to omit graph internal columns if such are present on this table.
if (server.Version.Major >= 14 || (DatabaseEngineType.SqlAzureDatabase == server.DatabaseEngineType && !isDw))
{
// from Smo.GraphType:
// 0 = None
// 1 = GraphId
// 2 = GraphIdComputed
// 3 = GraphFromId
// 4 = GraphFromObjId
// 5 = GraphFromIdComputed
// 6 = GraphToId
// 7 = GraphToObjId
// 8 = GraphToIdComputed
//
// We only want to show types 0, 2, 5, and 8:
filterExpressions.Add("(@GraphType=0 or @GraphType=2 or @GraphType=5 or @GraphType=8)");
}
Request request = new Request();
// If we have any filters on the columns, add them.
if (filterExpressions.Count > 0)
{
request.Urn = String.Format("{0}/Column[{1}]", urn.ToString(), string.Join(" and ", filterExpressions.ToArray()));
}
else
{
request.Urn = String.Format("{0}/Column", urn.ToString());
}
request.Fields = new String[] { "Name" };
// get the columns in the order they were created
OrderBy order = new OrderBy();
order.Dir = OrderBy.Direction.Asc;
order.Field = "ID";
request.OrderByList = new OrderBy[] { order };
Enumerator en = new Enumerator();
// perform the query.
DataTable dt = null;
EnumResult result = en.Process(server.ConnectionContext, request);
if (result.Type == ResultType.DataTable)
{
dt = result;
}
else
{
dt = ((DataSet)result).Tables[0];
}
return dt;
}
internal string SelectFromTableOrView(Server server, Urn urn, bool isDw)
{
string script = string.Empty;
DataTable dt = GetColumnNames(server, urn, isDw);
StringBuilder selectQuery = new StringBuilder();
// build the first line
if ((dt != null) && (dt.Rows.Count > 0))
{
selectQuery.Append("SELECT TOP (1000) ");
// first column
selectQuery.AppendFormat("{0}{1}{2}\r\n",
ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(dt.Rows[0][0] as string, ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter);
// add all other columns on separate lines. Make the names align.
for (int i = 1; i < dt.Rows.Count; i++)
{
selectQuery.AppendFormat(" ,{0}{1}{2}\r\n",
ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(dt.Rows[i][0] as string, ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter);
}
}
else
{
selectQuery.Append("SELECT TOP (1000) * ");
}
// from clause
selectQuery.Append(" FROM ");
if(server.ServerType != DatabaseEngineType.SqlAzureDatabase)
{ //Azure doesn't allow qualifying object names with the DB, so only add it on if we're not in Azure
// database URN
Urn dbUrn = urn.Parent;
selectQuery.AppendFormat("{0}{1}{2}.",
ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(dbUrn.GetAttribute("Name"), ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter);
}
// schema
selectQuery.AppendFormat("{0}{1}{2}.",
ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(urn.GetAttribute("Schema"), ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter);
// object
selectQuery.AppendFormat("{0}{1}{2}",
ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(urn.GetAttribute("Name"), ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter);
// In Hekaton M5, if it's a memory optimized table, we need to provide SNAPSHOT hint for SELECT.
if (urn.Type.Equals("Table") && ScriptingUtils.IsXTPSupportedOnServer(server))
{
Table table = (Table)server.GetSmoObject(urn);
table.Refresh();
if (table.IsMemoryOptimized)
{
selectQuery.Append(" WITH (SNAPSHOT)");
}
}
script = selectQuery.ToString();
return script;
}
#endregion
}
}
}