Fix some issues with Script As Select (#474)

This commit is contained in:
Karl Burtram
2017-10-03 14:39:29 -07:00
committed by GitHub
parent 8e78ecf9a4
commit 9091df8f62
5 changed files with 150 additions and 125 deletions

View File

@@ -16,6 +16,7 @@ using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.Workspace;
@@ -64,6 +65,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
private readonly object cancellationTokenSourceLock = new object(); private readonly object cancellationTokenSourceLock = new object();
private ConnectedBindingQueue connectionQueue = new ConnectedBindingQueue(needsMetadata: false);
/// <summary> /// <summary>
/// Map from script URIs to ConnectionInfo objects /// Map from script URIs to ConnectionInfo objects
/// This is internal for testing access only /// This is internal for testing access only
@@ -86,6 +89,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
set; set;
} }
/// <summary>
/// Gets the connection queue
/// </summary>
internal ConnectedBindingQueue ConnectionQueue
{
get
{
return this.connectionQueue;
}
}
/// <summary> /// <summary>
/// Default constructor should be private since it's a singleton class, but we need a constructor /// Default constructor should be private since it's a singleton class, but we need a constructor
/// for use in unit test mocking. /// for use in unit test mocking.

View File

@@ -25,6 +25,12 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
internal const int DefaultMinimumConnectionTimeout = 30; internal const int DefaultMinimumConnectionTimeout = 30;
/// <summary>
/// flag determing if the connection queue requires online metadata objects
/// it's much cheaper to not construct these objects if not needed
/// </summary>
private bool needsMetadata;
/// <summary> /// <summary>
/// Gets the current settings /// Gets the current settings
/// </summary> /// </summary>
@@ -33,6 +39,16 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
get { return WorkspaceService<SqlToolsSettings>.Instance.CurrentSettings; } get { return WorkspaceService<SqlToolsSettings>.Instance.CurrentSettings; }
} }
public ConnectedBindingQueue()
: this(true)
{
}
public ConnectedBindingQueue(bool needsMetadata)
{
this.needsMetadata = needsMetadata;
}
/// <summary> /// <summary>
/// Generate a unique key based on the ConnectionInfo object /// Generate a unique key based on the ConnectionInfo object
/// </summary> /// </summary>
@@ -84,14 +100,18 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo); SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo);
// populate the binding context to work with the SMO metadata provider // populate the binding context to work with the SMO metadata provider
ServerConnection serverConn = new ServerConnection(sqlConn); bindingContext.ServerConnection = new ServerConnection(sqlConn);
bindingContext.SmoMetadataProvider = SmoMetadataProvider.CreateConnectedProvider(serverConn);
bindingContext.MetadataDisplayInfoProvider = new MetadataDisplayInfoProvider(); if (this.needsMetadata)
bindingContext.MetadataDisplayInfoProvider.BuiltInCasing = {
this.CurrentSettings.SqlTools.IntelliSense.LowerCaseSuggestions.Value bindingContext.SmoMetadataProvider = SmoMetadataProvider.CreateConnectedProvider(bindingContext.ServerConnection);
? CasingStyle.Lowercase : CasingStyle.Uppercase; bindingContext.MetadataDisplayInfoProvider = new MetadataDisplayInfoProvider();
bindingContext.Binder = BinderProvider.CreateBinder(bindingContext.SmoMetadataProvider); bindingContext.MetadataDisplayInfoProvider.BuiltInCasing =
bindingContext.ServerConnection = serverConn; this.CurrentSettings.SqlTools.IntelliSense.LowerCaseSuggestions.Value
? CasingStyle.Lowercase : CasingStyle.Uppercase;
bindingContext.Binder = BinderProvider.CreateBinder(bindingContext.SmoMetadataProvider);
}
bindingContext.BindingTimeout = ConnectedBindingQueue.DefaultBindingTimeout; bindingContext.BindingTimeout = ConnectedBindingQueue.DefaultBindingTimeout;
bindingContext.IsConnected = true; bindingContext.IsConnected = true;
} }

View File

@@ -43,7 +43,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
private ConcurrentDictionary<string, ObjectExplorerSession> sessionMap; private ConcurrentDictionary<string, ObjectExplorerSession> sessionMap;
private readonly Lazy<Dictionary<string, HashSet<ChildFactory>>> applicableNodeChildFactories; private readonly Lazy<Dictionary<string, HashSet<ChildFactory>>> applicableNodeChildFactories;
private IMultiServiceProvider serviceProvider; private IMultiServiceProvider serviceProvider;
private ConnectedBindingQueue bindingQueue = new ConnectedBindingQueue(); private ConnectedBindingQueue bindingQueue = new ConnectedBindingQueue(needsMetadata: false);
private const int PrepopulateBindTimeout = 10000; private const int PrepopulateBindTimeout = 10000;

View File

@@ -740,12 +740,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
internal string SelectFromTableOrView(Server server, Urn urn, bool isDw) internal string SelectFromTableOrView(Server server, Urn urn, bool isDw)
{ {
string script = string.Empty;
DataTable dt = GetColumnNames(server, urn, isDw); DataTable dt = GetColumnNames(server, urn, isDw);
StringBuilder selectQuery = new StringBuilder(); StringBuilder selectQuery = new StringBuilder();
// build the first line // build the first line
if ((dt != null) && (dt.Rows.Count > 0)) if (dt != null && dt.Rows.Count > 0)
{ {
selectQuery.Append("SELECT TOP (1000) "); selectQuery.Append("SELECT TOP (1000) ");
@@ -768,18 +767,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
{ {
selectQuery.Append("SELECT TOP (1000) * "); selectQuery.Append("SELECT TOP (1000) * ");
} }
// from clause // from clause
selectQuery.Append(" FROM "); selectQuery.Append(" FROM ");
if(server.ServerType != DatabaseEngineType.SqlAzureDatabase) if(server.ServerType != DatabaseEngineType.SqlAzureDatabase)
{ //Azure doesn't allow qualifying object names with the DB, so only add it on if we're not in Azure {
// database URN // Azure doesn't allow qualifying object names with the DB, so only add it on if we're not in Azure database URN
Urn dbUrn = urn.Parent; Urn dbUrn = urn.Parent;
selectQuery.AppendFormat("{0}{1}{2}.", selectQuery.AppendFormat("{0}{1}{2}.",
ScriptingGlobals.LeftDelimiter, ScriptingGlobals.LeftDelimiter,
ScriptingUtils.QuoteObjectName(dbUrn.GetAttribute("Name"), ScriptingGlobals.RightDelimiter), ScriptingUtils.QuoteObjectName(dbUrn.GetAttribute("Name"), ScriptingGlobals.RightDelimiter),
ScriptingGlobals.RightDelimiter); ScriptingGlobals.RightDelimiter);
} }
// schema // schema
selectQuery.AppendFormat("{0}{1}{2}.", selectQuery.AppendFormat("{0}{1}{2}.",
ScriptingGlobals.LeftDelimiter, ScriptingGlobals.LeftDelimiter,
@@ -794,16 +795,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
// In Hekaton M5, if it's a memory optimized table, we need to provide SNAPSHOT hint for SELECT. // In Hekaton M5, if it's a memory optimized table, we need to provide SNAPSHOT hint for SELECT.
if (urn.Type.Equals("Table") && ScriptingUtils.IsXTPSupportedOnServer(server)) if (urn.Type.Equals("Table") && ScriptingUtils.IsXTPSupportedOnServer(server))
{ {
Table table = (Table)server.GetSmoObject(urn); try
table.Refresh();
if (table.IsMemoryOptimized)
{ {
selectQuery.Append(" WITH (SNAPSHOT)"); Table table = (Table)server.GetSmoObject(urn);
table.Refresh();
if (table.IsMemoryOptimized)
{
selectQuery.Append(" WITH (SNAPSHOT)");
}
}
catch (Exception ex)
{
// log any exceptions determining if InMemory, but don't treat as fatal exception
Logger.Write(LogLevel.Error, "Could not determine if is InMemory table " + ex.ToString());
} }
} }
script = selectQuery.ToString(); return selectQuery.ToString();
return script;
} }
#endregion #endregion

View File

@@ -16,7 +16,6 @@ using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.Hosting.Protocol.Contracts; using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.Metadata.Contracts; using Microsoft.SqlTools.ServiceLayer.Metadata.Contracts;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts; using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Utility;
@@ -39,8 +38,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
private static ConnectionService connectionService = null; private static ConnectionService connectionService = null;
private static LanguageService languageServices = null;
private readonly Lazy<ConcurrentDictionary<string, ScriptingOperation>> operations = private readonly Lazy<ConcurrentDictionary<string, ScriptingOperation>> operations =
new Lazy<ConcurrentDictionary<string, ScriptingOperation>>(() => new ConcurrentDictionary<string, ScriptingOperation>()); new Lazy<ConcurrentDictionary<string, ScriptingOperation>>(() => new ConcurrentDictionary<string, ScriptingOperation>());
@@ -65,24 +62,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
} }
} }
/// <summary>
/// Internal for testing purposes only
/// </summary>
internal static LanguageService LanguageServiceInstance
{
get
{
if (languageServices == null)
{
languageServices = LanguageService.Instance;
}
return languageServices;
}
set
{
languageServices = value;
}
}
/// <summary> /// <summary>
/// The collection of active operations /// The collection of active operations
@@ -108,61 +87,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
}); });
} }
/// <summary>
/// Handles request to get select script for an smo object
/// </summary>
private void HandleScriptSelectRequest(ScriptingParams parameters, RequestContext<ScriptingResult> requestContext)
{
Task.Run(() =>
{
try
{
string script = String.Empty;
ScriptingObject scriptingObject = parameters.ScriptingObjects[0];
// convert owner uri received from parameters to lookup for its
// associated connection and build a connection string out of it
SqlConnection sqlConn = new SqlConnection(parameters.ConnectionString);
ServerConnection serverConn = new ServerConnection(sqlConn);
Server server = new Server(serverConn);
server.DefaultTextMode = true;
SqlConnectionStringBuilder connStringBuilder = new SqlConnectionStringBuilder(parameters.ConnectionString);
string urnString = string.Format(
"Server[@Name='{0}']/Database[@Name='{1}']/{2}[@Name='{3}' {4}]",
server.Name.ToUpper(),
connStringBuilder.InitialCatalog,
scriptingObject.Type,
scriptingObject.Name,
scriptingObject.Schema != null ? string.Format("and @Schema = '{0}'", scriptingObject.Schema) : string.Empty);
Urn urn = new Urn(urnString);
string name = urn.GetNameForType(scriptingObject.Type);
if (string.Compare(name, "ServiceBroker", StringComparison.CurrentCultureIgnoreCase) == 0)
{
script = Scripter.SelectAllValuesFromTransmissionQueue(urn);
}
else
{
if (string.Compare(name, "Queues", StringComparison.CurrentCultureIgnoreCase) == 0 ||
string.Compare(name, "SystemQueues", StringComparison.CurrentCultureIgnoreCase) == 0)
{
script = Scripter.SelectAllValues(urn);
}
else
{
Database db = server.Databases[connStringBuilder.InitialCatalog];
bool isDw = db.IsSqlDw;
script = new Scripter().SelectFromTableOrView(server, urn, isDw);
}
}
requestContext.SendResult(new ScriptingResult { Script = script});
}
catch (Exception e)
{
requestContext.SendError(e);
}
});
}
/// <summary> /// <summary>
/// Handles request to execute start the list objects operation. /// Handles request to execute start the list objects operation.
/// </summary> /// </summary>
@@ -171,7 +95,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
try try
{ {
ScriptingListObjectsOperation operation = new ScriptingListObjectsOperation(parameters); ScriptingListObjectsOperation operation = new ScriptingListObjectsOperation(parameters);
operation.CompleteNotification += (sender, e) => this.SendEvent(requestContext, ScriptingListObjectsCompleteEvent.Type, e); operation.CompleteNotification += (sender, e) => requestContext.SendEvent(ScriptingListObjectsCompleteEvent.Type, e);
RunTask(requestContext, operation); RunTask(requestContext, operation);
@@ -184,36 +108,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
} }
/// <summary> /// <summary>
/// Handles request to execute start the script operation. /// Handles request to start the scripting operation
/// </summary> /// </summary>
public async Task HandleScriptExecuteRequest(ScriptingParams parameters, RequestContext<ScriptingResult> requestContext) public async Task HandleScriptExecuteRequest(ScriptingParams parameters, RequestContext<ScriptingResult> requestContext)
{ {
try try
{ {
// convert owner uri received from parameters to lookup for its // if a connection string wasn't provided as a parameter then
// associated connection and build a connection string out of it // use the owner uri property to lookup its associated ConnectionInfo
// if a connection string doesn't already exist // and then build a connection string out of that
ConnectionInfo connInfo;
ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo);
if (parameters.ConnectionString == null) if (parameters.ConnectionString == null)
{ {
ConnectionInfo connInfo;
ScriptingService.ConnectionServiceInstance.TryFindConnection(
parameters.OwnerUri, out connInfo);
if (connInfo != null) if (connInfo != null)
{ {
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
} }
else
{
throw new Exception("Could not find ConnectionInfo");
}
} }
// if the scripting operation is for select // if the scripting operation is for SELECT then handle that message differently
// for SELECT we'll build the SQL directly whereas other scripting operations depend on SMO
if (parameters.ScriptOptions.ScriptCreateDrop == "ScriptSelect") if (parameters.ScriptOptions.ScriptCreateDrop == "ScriptSelect")
{ {
RunSelectTask(parameters, requestContext); RunSelectTask(connInfo, parameters, requestContext);
} }
else else
{ {
ScriptingScriptOperation operation = new ScriptingScriptOperation(parameters); ScriptingScriptOperation operation = new ScriptingScriptOperation(parameters);
operation.PlanNotification += (sender, e) => this.SendEvent(requestContext, ScriptingPlanNotificationEvent.Type, e); operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e);
operation.ProgressNotification += (sender, e) => this.SendEvent(requestContext, ScriptingProgressNotificationEvent.Type, e); operation.ProgressNotification += (sender, e) => requestContext.SendEvent(ScriptingProgressNotificationEvent.Type, e);
operation.CompleteNotification += (sender, e) => this.SendScriptingCompleteEvent(requestContext, ScriptingCompleteEvent.Type, e, operation, parameters.ScriptDestination); operation.CompleteNotification += (sender, e) => this.SendScriptingCompleteEvent(requestContext, ScriptingCompleteEvent.Type, e, operation, parameters.ScriptDestination);
RunTask(requestContext, operation); RunTask(requestContext, operation);
@@ -271,30 +199,85 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
}); });
} }
/// <summary> private Urn BuildScriptingObjectUrn(
/// Sends a JSON-RPC event. Server server,
/// </summary> SqlConnectionStringBuilder connectionStringBuilder,
private void SendEvent<TParams>(IEventSender requestContext, EventType<TParams> eventType, TParams parameters) ScriptingObject scriptingObject)
{ {
Task.Run(async () => await requestContext.SendEvent(eventType, parameters)); string serverName = server.Name.ToUpper();
// remove the port from server name if specified
int commaPos = serverName.IndexOf(',');
if (commaPos >= 0)
{
serverName = serverName.Substring(0, commaPos);
}
// build the URN
string urnString = string.Format(
"Server[@Name='{0}']/Database[@Name='{1}']/{2}[@Name='{3}' {4}]",
serverName,
connectionStringBuilder.InitialCatalog,
scriptingObject.Type,
scriptingObject.Name,
scriptingObject.Schema != null ? string.Format("and @Schema = '{0}'", scriptingObject.Schema) : string.Empty);
return new Urn(urnString);
} }
/// <summary> /// <summary>
/// Runs the async task that performs the scripting operation. /// Runs the async task that performs the scripting operation.
/// </summary> /// </summary>
private void RunSelectTask(ScriptingParams parameters, RequestContext<ScriptingResult> requestContext) private void RunSelectTask(ConnectionInfo connInfo, ScriptingParams parameters, RequestContext<ScriptingResult> requestContext)
{ {
Task.Run(() => ConnectionServiceInstance.ConnectionQueue.QueueBindingOperation(
{ key: ConnectionServiceInstance.ConnectionQueue.AddConnectionContext(connInfo),
try bindingTimeout: ScriptingOperationTimeout,
bindOperation: (bindingContext, cancelToken) =>
{ {
this.HandleScriptSelectRequest(parameters, requestContext); string script = string.Empty;
} ScriptingObject scriptingObject = parameters.ScriptingObjects[0];
catch (Exception e) try
{ {
requestContext.SendError(e); Server server = new Server(bindingContext.ServerConnection);
} server.DefaultTextMode = true;
});
// build object URN
SqlConnectionStringBuilder connectionStringBuilder = new SqlConnectionStringBuilder(parameters.ConnectionString);
Urn objectUrn = BuildScriptingObjectUrn(server, connectionStringBuilder, scriptingObject);
string typeName = objectUrn.GetNameForType(scriptingObject.Type);
// select from service broker
if (string.Compare(typeName, "ServiceBroker", StringComparison.CurrentCultureIgnoreCase) == 0)
{
script = Scripter.SelectAllValuesFromTransmissionQueue(objectUrn);
}
// select from queues
else if (string.Compare(typeName, "Queues", StringComparison.CurrentCultureIgnoreCase) == 0 ||
string.Compare(typeName, "SystemQueues", StringComparison.CurrentCultureIgnoreCase) == 0)
{
script = Scripter.SelectAllValues(objectUrn);
}
// select from table or view
else
{
Database db = server.Databases[connectionStringBuilder.InitialCatalog];
bool isDw = db.IsSqlDw;
script = new Scripter().SelectFromTableOrView(server, objectUrn, isDw);
}
// send script result to client
requestContext.SendResult(new ScriptingResult { Script = script });
}
catch (Exception e)
{
requestContext.SendError(e);
}
return null;
});
} }
/// <summary> /// <summary>