SqlCmd Connect/On Error/Include commands support (#898)

* Initial Investigation

* Working code with include, connect, on error and tests

* Adding some loc strings

* Some cleanup and more tests

* Some dummy change to trigger build

* Adding PR comments

* Addressing PR comments
This commit is contained in:
Udeesha Gautam
2020-01-10 17:54:39 -08:00
committed by GitHub
parent d512c101c0
commit fe17962ac9
25 changed files with 925 additions and 137 deletions

View File

@@ -71,7 +71,8 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
endLine,
startColumn + 1,
endColumn + 1,
batchInfos[0].executionCount
batchInfos[0].executionCount,
batchInfos[0].sqlCmdCommand
);
batchDefinitionList.Add(batchDef);
@@ -100,7 +101,8 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
endLine,
startColumn + 1,
endColumn + 1,
batchInfos[index].executionCount
batchInfos[index].executionCount,
batchInfos[index].sqlCmdCommand
);
batchDefinitionList.Add(batch);
}
@@ -235,7 +237,8 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
endLine,
startColumn + 1,
endColumn + 1,
batchInfo.executionCount
batchInfo.executionCount,
batchInfo.sqlCmdCommand
);
}
@@ -381,7 +384,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
}
// Add the script info
batchInfos.Add(new BatchInfo(args.Batch.TextSpan.iStartLine, args.Batch.TextSpan.iStartIndex, batchText, args.Batch.ExpectedExecutionCount));
batchInfos.Add(new BatchInfo(args.Batch.TextSpan.iStartLine, args.Batch.TextSpan.iStartIndex, batchText, args.SqlCmdCommand, args.Batch.ExpectedExecutionCount));
}
}
catch (NotImplementedException)
@@ -474,17 +477,19 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
private class BatchInfo
{
public BatchInfo(int startLine, int startColumn, string batchText, int repeatCount = 1)
public BatchInfo(int startLine, int startColumn, string batchText, SqlCmdCommand sqlCmdCommand, int repeatCount = 1)
{
this.startLine = startLine;
this.startColumn = startColumn;
this.executionCount = repeatCount;
this.batchText = batchText;
this.sqlCmdCommand = sqlCmdCommand;
}
public int startLine;
public int startColumn;
public int executionCount;
public string batchText;
public SqlCmdCommand sqlCmdCommand;
}
}

View File

@@ -0,0 +1,85 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlServer.Management.Common;
using System;
using System.Data;
using System.Data.Common;
using Microsoft.Data.SqlClient;
namespace Microsoft.SqlTools.ServiceLayer.BatchParser
{
public class ConnectSqlCmdCommand : SqlCmdCommand
{
internal ConnectSqlCmdCommand(string server, string username, string password) : base(LexerTokenType.Connect)
{
Server = server;
UserName = username;
Password = password;
}
public string Server { get; private set; }
public string UserName { get; private set; }
public string Password { get; private set; }
/// <summary>
/// attempts to establish connection with given params
/// </summary>
/// <returns>returns the connection object is successful, throws otherwise</returns>
public DbConnection Connect()
{
//create SqlConnectionInfo object
SqlConnectionInfo connectionInfo = new SqlConnectionInfo();
if (Server != null && Server.Length > 0)
{
connectionInfo.ServerName = Server;
}
if (UserName != null && UserName.Length > 0)
{
connectionInfo.UseIntegratedSecurity = false;
connectionInfo.UserName = UserName;
connectionInfo.Password = Password;
}
else
{
connectionInfo.UseIntegratedSecurity = true;
}
DbConnection dbConnection = AttemptToEstablishCurConnection(connectionInfo);
return dbConnection;
}
/// <summary>
/// called when we need to establish new connection for batch executio as a
/// result of "connect" command processing
/// </summary>
/// <param name="ci"></param>
/// <returns></returns>
private DbConnection AttemptToEstablishCurConnection(SqlConnectionInfo ci)
{
if (ci == null || ci.ServerName == null)
{
return null;
}
IDbConnection conn = null;
try
{
string connString = ci.ConnectionString;
connString += ";Pooling=false"; //turn off connection pooling (this is done in other tools so following the same pattern)
conn = new SqlConnection(connString);
conn.Open();
return conn as DbConnection;
}
catch (Exception ex)
{
throw new Exception($"Failed to Change connection to {ci.ServerName}", ex);
}
}
}
}

View File

@@ -13,7 +13,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
/// <summary>
/// Constructor method for a BatchDefinition
/// </summary>
public BatchDefinition(string batchText, int startLine, int endLine, int startColumn, int endColumn, int executionCount)
public BatchDefinition(string batchText, int startLine, int endLine, int startColumn, int endColumn, int executionCount, SqlCmdCommand command)
{
BatchText = batchText;
StartLine = startLine;
@@ -22,6 +22,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
EndColumn = endColumn;
// set the batch execution count, with min value of 1
BatchExecutionCount = executionCount > 0 ? executionCount : 1;
SqlCmdCommand = command;
}
/// <summary>
@@ -64,6 +65,11 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
get; private set;
}
public SqlCmdCommand SqlCmdCommand
{
get; private set;
}
/// <summary>
/// Get number of times to execute this batch
/// </summary>

View File

@@ -28,7 +28,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
public delegate void HaltParserDelegate();
public delegate void ScriptMessageDelegate(string message);
public delegate void ScriptErrorDelegate(string message, ScriptMessageType messageType);
public delegate bool ExecuteDelegate(string batchScript, int num, int lineNumber);
public delegate bool ExecuteDelegate(string batchScript, int num, int lineNumber, SqlCmdCommand command);
#endregion
#region Constructors / Destructor
@@ -75,7 +75,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
/// <summary>
/// Take approptiate action on the parsed batches
/// </summary>
public BatchParserAction Go(TextBlock batch, int repeatCount)
public BatchParserAction Go(TextBlock batch, int repeatCount, SqlCmdCommand command)
{
string str;
LineInfo lineInfo;
@@ -85,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
bool executeResult = false;
if (executeDelegate != null)
{
executeResult = executeDelegate(str, repeatCount, lineInfo.GetStreamPositionForOffset(0).Line + startingLine - 1);
executeResult = executeDelegate(str, repeatCount, lineInfo.GetStreamPositionForOffset(0).Line + startingLine - 1, command);
}
return executeResult ? BatchParserAction.Continue : BatchParserAction.Abort;
}

View File

@@ -15,6 +15,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
private readonly Batch batch = null;
private readonly ScriptExecutionResult result;
private readonly SqlCmdCommand sqlCmdCommand;
private BatchParserExecutionFinishedEventArgs()
{
@@ -23,10 +24,11 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
/// <summary>
/// Constructor method for the class
/// </summary>
public BatchParserExecutionFinishedEventArgs(ScriptExecutionResult batchResult, Batch batch)
public BatchParserExecutionFinishedEventArgs(ScriptExecutionResult batchResult, Batch batch, SqlCmdCommand sqlCmdCommand)
{
this.batch = batch;
result = batchResult;
this.sqlCmdCommand = sqlCmdCommand;
}
public Batch Batch
@@ -44,5 +46,13 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
return result;
}
}
public SqlCmdCommand SqlCmdCommand
{
get
{
return sqlCmdCommand;
}
}
}
}

View File

@@ -113,11 +113,72 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
{
stream = null;
newFilename = null;
LineInfo lineInfo;
RaiseScriptError(string.Format(CultureInfo.CurrentCulture, SR.EE_ExecutionError_CommandNotSupported, "Include"), ScriptMessageType.Error);
if (filename == null)
{
stream = null;
return BatchParserAction.Abort;
}
filename.GetText(resolveVariables: true, text: out newFilename, lineInfo: out lineInfo);
string resolvedFileNameWithFullPath = GetFilePath(newFilename);
if (!File.Exists(resolvedFileNameWithFullPath))
{
stream = null;
return BatchParserAction.Abort;
}
else
{
stream = new StreamReader(resolvedFileNameWithFullPath);
}
return BatchParserAction.Continue;
}
private string GetFilePath(string fileName)
{
//try appending the file name with current working directory path
string fullFileName = null;
try
{
if (Environment.CurrentDirectory != null && !File.Exists(fileName))
{
string currentWorkingDirectory = Environment.CurrentDirectory;
if (currentWorkingDirectory != null)
{
fullFileName = Path.GetFullPath(Path.Combine(currentWorkingDirectory, fileName));
if (!File.Exists(fullFileName))
{
fullFileName = null;
}
}
}
if (fullFileName == null)
{
fullFileName = Path.GetFullPath(fileName);
}
return fullFileName;
}
catch (ArgumentException)
{
//path contains invalid path characters.
throw new SqlCmdException(SR.SqlCmd_PathInvalid);
}
catch (PathTooLongException)
{
//path is too long
throw new SqlCmdException(SR.SqlCmd_PathLong);
}
catch (Exception)
{
//catch all other exceptions and report generic error
throw new SqlCmdException(string.Format(SR.SqlCmd_FailedInclude, fileName));
}
}
/// <summary>
/// Method to deal with errors
/// </summary>

View File

@@ -281,14 +281,14 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
/// </summary>
/// <param name="batch"></param>
/// <param name="batchResult"></param>
private void RaiseBatchParserExecutionFinished(Batch batch, ScriptExecutionResult batchResult)
private void RaiseBatchParserExecutionFinished(Batch batch, ScriptExecutionResult batchResult, SqlCmdCommand sqlCmdCommand)
{
Debug.Assert(batch != null);
EventHandler<BatchParserExecutionFinishedEventArgs> cache = BatchParserExecutionFinished;
if (cache != null)
{
BatchParserExecutionFinishedEventArgs args = new BatchParserExecutionFinishedEventArgs(batchResult, batch);
BatchParserExecutionFinishedEventArgs args = new BatchParserExecutionFinishedEventArgs(batchResult, batch, sqlCmdCommand);
cache(this, args);
}
}
@@ -336,7 +336,8 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
private bool ExecuteBatchInternal(
string batchScript,
int num,
int lineNumber)
int lineNumber,
SqlCmdCommand sqlCmdCommand)
{
if (lineNumber == -1)
{
@@ -353,7 +354,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
{
bool continueProcessing = true;
numBatchExecutionTimes = num;
ExecuteBatchTextSpanInternal(batchScript, localTextSpan, out continueProcessing);
ExecuteBatchTextSpanInternal(batchScript, localTextSpan, out continueProcessing, sqlCmdCommand);
return continueProcessing;
}
else
@@ -368,7 +369,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
/// <param name="batchScript"></param>
/// <param name="textSpan"></param>
/// <param name="continueProcessing"></param>
private void ExecuteBatchTextSpanInternal(string batchScript, TextSpan textSpan, out bool continueProcessing)
private void ExecuteBatchTextSpanInternal(string batchScript, TextSpan textSpan, out bool continueProcessing, SqlCmdCommand sqlCmdCommand)
{
Debug.Assert(!String.IsNullOrEmpty(batchScript));
continueProcessing = true;
@@ -443,7 +444,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
if (!isExecutionDiscarded)
{
RaiseBatchParserExecutionFinished(currentBatch, batchResult);
RaiseBatchParserExecutionFinished(currentBatch, batchResult, sqlCmdCommand);
}
}
else
@@ -501,7 +502,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode
}
else
{
ExecuteBatchInternal(script, /* num */ 1, /* lineNumber */ 0);
ExecuteBatchInternal(script, num: 1, lineNumber: 0, /* sqlcmdCommand required for parsing only*/ sqlCmdCommand: null);
}
}

View File

@@ -9,7 +9,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
{
public interface ICommandHandler
{
BatchParserAction Go(TextBlock batch, int repeatCount);
BatchParserAction Go(TextBlock batch, int repeatCount, SqlCmdCommand tokenType);
BatchParserAction OnError(Token token, OnErrorAction action);
BatchParserAction Include(TextBlock filename, out TextReader stream, out string newFilename);
}

View File

@@ -0,0 +1,17 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.ServiceLayer.BatchParser
{
public class OnErrorSqlCmdCommand : SqlCmdCommand
{
internal OnErrorSqlCmdCommand(OnErrorAction action) : base(LexerTokenType.OnError)
{
Action = action;
}
public OnErrorAction Action { get; private set; }
}
}

View File

@@ -23,6 +23,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
private Lexer lexer;
private List<Token> tokenBuffer;
private readonly IVariableResolver variableResolver;
private SqlCmdCommand sqlCmdCommand;
/// <summary>
/// Constructor for the Parser class
@@ -200,7 +201,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
private void ExecuteBatch(int repeatCount)
{
BatchParserAction action;
action = commandHandler.Go(new TextBlock(this, tokenBuffer), repeatCount);
action = commandHandler.Go(new TextBlock(this, tokenBuffer), repeatCount, this.sqlCmdCommand);
if (action == BatchParserAction.Abort)
{
@@ -406,6 +407,11 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
ParseSetvar(setvarToken);
break;
case LexerTokenType.Connect:
Token connectToken = LookaheadToken;
RemoveLastWhitespaceToken();
Accept();
ParseConnect(connectToken);
break;
case LexerTokenType.Ed:
case LexerTokenType.ErrorCommand:
case LexerTokenType.Execute:
@@ -472,6 +478,8 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
parserAction = commandHandler.OnError(onErrorToken, onErrorAction);
this.sqlCmdCommand = new OnErrorSqlCmdCommand(onErrorAction);
if (parserAction == BatchParserAction.Abort)
{
RaiseError(ErrorCode.Aborted);
@@ -522,6 +530,83 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser
variableResolver.SetVariable(setvarToken.Begin, variableName, variableValue);
}
private void ParseConnect(Token connectToken)
{
string serverName = null;
string userName = null;
string password = null;
Accept(LexerTokenType.Whitespace);
Expect(LexerTokenType.Text);
serverName = ResolveVariables(LookaheadToken, 0, null);
if (serverName == null)
{
//found some text but couldn't parse for servername
RaiseError(ErrorCode.UnrecognizedToken);
}
Accept();
Accept(LexerTokenType.Whitespace);
switch (LookaheadTokenType)
{
case LexerTokenType.Text:
userName = ParseUserName();
password = ParsePassword();
if(userName == null || password == null)
{
//found some text but couldn't parse for user/password
RaiseError(ErrorCode.UnrecognizedToken);
}
break;
case LexerTokenType.NewLine:
case LexerTokenType.Eof:
Accept();
break;
default:
RaiseError(ErrorCode.UnrecognizedToken);
break;
}
this.sqlCmdCommand = new ConnectSqlCmdCommand(serverName, userName, password);
}
private string ParseUserName()
{
string username = null;
if (LookaheadToken.Text == "-U")
{
Accept();
Accept(LexerTokenType.Whitespace);
if (LookaheadTokenType == LexerTokenType.Text)
{
username = ResolveVariables(LookaheadToken, 0, null);
Accept();
Accept(LexerTokenType.Whitespace);
}
}
return username;
}
private string ParsePassword()
{
string password = null;
if (LookaheadToken.Text == "-P")
{
Accept();
Accept(LexerTokenType.Whitespace);
if (LookaheadTokenType == LexerTokenType.Text)
{
password = ResolveVariables(LookaheadToken, 0, null);
Accept();
Accept(LexerTokenType.Whitespace);
}
}
return password;
}
internal void RaiseError(ErrorCode errorCode, string message = null)
{
RaiseError(errorCode, LookaheadToken, message);

View File

@@ -0,0 +1,20 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.ServiceLayer.BatchParser
{
/// <summary>
/// Class to pass back SqlCmd specific properties from Parser to Query Execution
/// </summary>
public class SqlCmdCommand
{
internal SqlCmdCommand(LexerTokenType tokenType)
{
this.LexerTokenType = tokenType;
}
public LexerTokenType LexerTokenType { get; private set; }
}
}

View File

@@ -0,0 +1,18 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
namespace Microsoft.SqlTools.ServiceLayer.BatchParser
{
/// <summary>
/// Specific exception type for SQLCMD related issues
/// </summary>
public class SqlCmdException : Exception
{
public SqlCmdException(string message) : base(message)
{
}
}
}

View File

@@ -261,6 +261,30 @@ namespace Microsoft.SqlTools.ManagedBatchParser
}
}
public static string SqlCmd_PathInvalid
{
get
{
return Keys.GetString(Keys.SqlCmd_PathInvalid);
}
}
public static string SqlCmd_PathLong
{
get
{
return Keys.GetString(Keys.SqlCmd_PathLong);
}
}
public static string SqlCmd_FailedInclude
{
get
{
return Keys.GetString(Keys.SqlCmd_FailedInclude);
}
}
[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]
public class Keys
{
@@ -356,6 +380,14 @@ namespace Microsoft.SqlTools.ManagedBatchParser
public const string BatchParser_VariableNotDefined = "BatchParser_VariableNotDefined";
public const string SqlCmd_PathInvalid = "SqlCmd_PathInvalid";
public const string SqlCmd_PathLong = "SqlCmd_PathLong";
public const string SqlCmd_FailedInclude = "SqlCmd_FailedInclude";
private Keys()
{ }

View File

@@ -233,4 +233,16 @@
<value>Variable {0} is not defined.</value>
<comment></comment>
</data>
<data name="SqlCmd_PathInvalid" xml:space="preserve">
<value>Path contains invalid characters.</value>
<comment></comment>
</data>
<data name="SqlCmd_PathLong" xml:space="preserve">
<value>Path too long.</value>
<comment></comment>
</data>
<data name="SqlCmd_FailedInclude" xml:space="preserve">
<value>Could not find included file {0}.</value>
<comment></comment>
</data>
</root>

View File

@@ -82,3 +82,9 @@ BatchParser_IncorrectSyntax = Incorrect syntax was encountered while parsing '{0
BatchParser_VariableNotDefined = Variable {0} is not defined.
SqlCmd_PathInvalid = Path contains invalid characters.
SqlCmd_PathLong = Path too long.
SqlCmd_FailedInclude = Could not find included file {0}.

View File

@@ -147,6 +147,21 @@
<target state="new">Variable {0} is not defined.</target>
<note></note>
</trans-unit>
<trans-unit id="SqlCmd_PathInvalid">
<source>Path contains invalid characters.</source>
<target state="new">Path contains invalid characters.</target>
<note></note>
</trans-unit>
<trans-unit id="SqlCmd_PathLong">
<source>Path too long.</source>
<target state="new">Path too long.</target>
<note></note>
</trans-unit>
<trans-unit id="SqlCmd_FailedInclude">
<source>Could not find included file {0}.</source>
<target state="new">Could not find included file {0}.</target>
<note></note>
</trans-unit>
</body>
</file>
</xliff>

View File

@@ -357,6 +357,21 @@ namespace Microsoft.SqlTools.ServiceLayer
}
}
public static string SqlCmdExitOnError
{
get
{
return Keys.GetString(Keys.SqlCmdExitOnError);
}
}
public static string SqlCmdUnsupportedToken
{
get
{
return Keys.GetString(Keys.SqlCmdUnsupportedToken);
}
}
public static string PeekDefinitionNoResultsError
{
get
@@ -3280,6 +3295,12 @@ namespace Microsoft.SqlTools.ServiceLayer
public const string QueryServiceExecutionPlanNotFound = "QueryServiceExecutionPlanNotFound";
public const string SqlCmdExitOnError = "SqlCmdExitOnError";
public const string SqlCmdUnsupportedToken = "SqlCmdUnsupportedToken";
public const string SerializationServiceUnsupportedFormat = "SerializationServiceUnsupportedFormat";

View File

@@ -322,6 +322,14 @@
<value>Could not retrieve an execution plan from the result set </value>
<comment></comment>
</data>
<data name="SqlCmdExitOnError" xml:space="preserve">
<value>An error was encountered during execution of batch. Exiting.</value>
<comment></comment>
</data>
<data name="SqlCmdUnsupportedToken" xml:space="preserve">
<value>Encountered unsupported token {0}</value>
<comment></comment>
</data>
<data name="SerializationServiceUnsupportedFormat" xml:space="preserve">
<value>Unsupported Save Format: {0}</value>
<comment>.

View File

@@ -140,6 +140,10 @@ QueryServiceResultSetNoColumnSchema = Could not retrieve column schema for resul
QueryServiceExecutionPlanNotFound = Could not retrieve an execution plan from the result set
SqlCmdExitOnError = An error was encountered during execution of batch. Exiting.
SqlCmdUnsupportedToken = Encountered unsupported token {0}
############################################################################
# Serialization Service

View File

@@ -18,6 +18,7 @@ using Microsoft.SqlTools.Utility;
using System.Globalization;
using System.Collections.ObjectModel;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.BatchParser;
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
@@ -70,6 +71,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
#endregion
internal Batch(string batchText, SelectionData selection, int ordinalId,
IFileStreamFactory outputFileFactory, SqlCmdCommand sqlCmdCommand, int executionCount = 1, bool getFullColumnSchema = false) : this(batchText, selection, ordinalId,
outputFileFactory, executionCount, getFullColumnSchema)
{
this.SqlCmdCommand = sqlCmdCommand;
}
internal Batch(string batchText, SelectionData selection, int ordinalId,
IFileStreamFactory outputFileFactory, int executionCount = 1, bool getFullColumnSchema = false)
{
@@ -138,6 +146,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
/// called from the Batch but from the ResultSet instance.
/// </summary>
public event ResultSet.ResultSetAsyncEventHandler ResultSetUpdated;
/// <summary>
/// Event that will be called when additional rows in the result set are available (rowCount available has increased). It will not be
/// called from the Batch but from the ResultSet instance.
/// </summary>
public event EventHandler<bool> HandleOnErrorAction;
#endregion
#region Properties
@@ -147,6 +161,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
/// </summary>
public string BatchText { get; set; }
public SqlCmdCommand SqlCmdCommand { get; set; }
public int BatchExecutionCount { get; private set; }
/// <summary>
/// Localized timestamp for when the execution completed.
@@ -250,6 +266,17 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
/// <param name="conn">The connection to use to execute the batch</param>
/// <param name="cancellationToken">Token for cancelling the execution</param>
public async Task Execute(DbConnection conn, CancellationToken cancellationToken)
{
await Execute(conn, cancellationToken, OnErrorAction.Ignore);
}
/// <summary>
/// Executes this batch and captures any server messages that are returned.
/// </summary>
/// <param name="conn">The connection to use to execute the batch</param>
/// <param name="cancellationToken">Token for cancelling the execution</param>
/// <param name="onErrorAction">Continue (Ignore) or Exit on Error. This comes only in SQLCMD mode</param>
public async Task Execute(DbConnection conn, CancellationToken cancellationToken, OnErrorAction onErrorAction = OnErrorAction.Ignore)
{
// Sanity check to make sure we haven't already run this batch
if (HasExecuted)
@@ -273,7 +300,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
try
{
await DoExecute(conn, cancellationToken);
await DoExecute(conn, cancellationToken, onErrorAction);
}
catch (TaskCanceledException)
{
@@ -308,7 +335,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
private async Task DoExecute(DbConnection conn, CancellationToken cancellationToken)
private async Task DoExecute(DbConnection conn, CancellationToken cancellationToken, OnErrorAction onErrorAction = OnErrorAction.Ignore)
{
bool canContinue = true;
int timesLoop = this.BatchExecutionCount;
@@ -326,6 +353,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
catch (DbException dbe)
{
HasError = true;
if (onErrorAction == OnErrorAction.Exit)
{
throw new SqlCmdException(dbe.Message);
}
canContinue = await UnwrapDbException(dbe);
if (canContinue)
{
@@ -679,6 +710,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
this.HasError = true;
}
if (this.HandleOnErrorAction != null)
{
HandleOnErrorAction(this, isError);
}
}
/// <summary>
@@ -693,6 +729,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
bool canIgnore = true;
SqlException se = dbe as SqlException;
if (se != null)
{
var errors = se.Errors.Cast<SqlError>().ToList();

View File

@@ -86,6 +86,22 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
/// </summary>
private string newDatabaseName;
/// <summary>
/// On Error Action for the query in SQLCMD Mode- Ignore or Exit
/// </summary>
private OnErrorAction onErrorAction;
/// <summary>
/// Connection that is used for query to run
/// This is always initialized from editor connection but might be different in case of SQLCMD mode
/// </summary>
private DbConnection queryConnection;
/// <summary>
/// Cancelled but not user but by SQLCMD settings
/// </summary>
private bool CancelledBySqlCmd;
#endregion
/// <summary>
@@ -132,6 +148,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
batchDefinition.EndLine-1,
batchDefinition.EndColumn-1),
index, outputFactory,
batchDefinition.SqlCmdCommand,
batchDefinition.BatchExecutionCount,
getFullColumnSchema));
@@ -401,7 +418,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
// Locate and setup the connection
DbConnection queryConnection = await ConnectionService.Instance.GetOrOpenConnection(editorConnection.OwnerUri, ConnectionType.Query);
queryConnection = await ConnectionService.Instance.GetOrOpenConnection(editorConnection.OwnerUri, ConnectionType.Query);
onErrorAction = OnErrorAction.Ignore;
sqlConn = queryConnection as ReliableSqlConnection;
if (sqlConn != null)
{
@@ -427,7 +445,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
b.ResultSetCompletion += ResultSetCompleted;
b.ResultSetAvailable += ResultSetAvailable;
b.ResultSetUpdated += ResultSetUpdated;
await b.Execute(queryConnection, cancellationSource.Token);
await ExecuteBatch(b);
}
// Execute afterBatches synchronously, after the user defined batches
@@ -445,7 +464,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
catch (Exception e)
{
HasErrored = true;
if (e is OperationCanceledException)
if (e is SqlCmdException || CancelledBySqlCmd)
{
await BatchMessageSent(new ResultMessage(SR.SqlCmdExitOnError, false, null));
}
else if (e is OperationCanceledException)
{
await BatchMessageSent(new ResultMessage(SR.QueryServiceQueryCancelled, false, null));
}
@@ -480,6 +503,63 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
}
}
private async Task ExecuteBatch(Batch b)
{
if (b.SqlCmdCommand != null)
{
b.HandleOnErrorAction += HandleOnErrorAction;
await PerformInplaceSqlCmdAction(b);
}
await b.Execute(queryConnection, cancellationSource.Token, this.onErrorAction);
if (CancelledBySqlCmd)
{
throw new SqlCmdException(SR.SqlCmdExitOnError);
}
}
public async Task PerformInplaceSqlCmdAction(Batch b)
{
try
{
switch (b.SqlCmdCommand.LexerTokenType)
{
case LexerTokenType.Connect:
var qc = (b.SqlCmdCommand as ConnectSqlCmdCommand)?.Connect();
queryConnection = qc ?? queryConnection;
break;
case LexerTokenType.OnError:
onErrorAction = (b.SqlCmdCommand as OnErrorSqlCmdCommand).Action;
break;
default:
throw new SqlCmdException(string.Format(SR.SqlCmdUnsupportedToken, b.SqlCmdCommand.LexerTokenType));
}
}
catch (Exception ex)
{
b.HasError = true;
await BatchMessageSent(new ResultMessage(ex.Message, true, null));
if (this.onErrorAction == OnErrorAction.Exit)
{
HasCancelled = true;
CancelledBySqlCmd = true;
cancellationSource.Cancel();
throw new SqlCmdException(SR.SqlCmdExitOnError);
}
}
}
private void HandleOnErrorAction(object sender, bool iserror)
{
if (iserror && this.onErrorAction == OnErrorAction.Exit)
{
HasCancelled = true;
CancelledBySqlCmd = true;
cancellationSource.Cancel();
}
}
/// <summary>
/// Handler for database messages during query execution

View File

@@ -362,6 +362,129 @@ GO";
}
}
/// <summary>
/// Verify whether the batchParser parsed :connect command successfully
/// </summary>
[Fact]
public void VerifyConnectSqlCmd()
{
using (ExecutionEngine executionEngine = new ExecutionEngine())
{
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master");
string serverName = liveConnection.ConnectionInfo.ConnectionDetails.ServerName;
string userName = liveConnection.ConnectionInfo.ConnectionDetails.UserName;
string password = liveConnection.ConnectionInfo.ConnectionDetails.Password;
string sqlCmdQuery = $@"
:Connect {serverName} -U {userName} -P {password}
GO
select * from sys.databases where name = 'master'
GO";
string sqlCmdQueryIncorrect = $@"
:Connect {serverName} -u {userName} -p {password}
GO
select * from sys.databases where name = 'master'
GO";
var condition = new ExecutionEngineConditions() { IsSqlCmd = true };
using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(liveConnection.ConnectionInfo))
using (TestExecutor testExecutor = new TestExecutor(sqlCmdQuery, sqlConn, condition))
{
testExecutor.Run();
Assert.True(testExecutor.ParserExecutionError == false, "Parse Execution error should be false");
Assert.True(testExecutor.ResultCountQueue.Count == 1, $"Unexpected number of ResultCount items - expected 1 but got {testExecutor.ResultCountQueue.Count}");
Assert.True(testExecutor.ErrorMessageQueue.Count == 0, $"Unexpected error messages from test executor : {string.Join(Environment.NewLine, testExecutor.ErrorMessageQueue)}");
}
using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(liveConnection.ConnectionInfo))
using (TestExecutor testExecutor = new TestExecutor(sqlCmdQueryIncorrect, sqlConn, condition))
{
testExecutor.Run();
Assert.True(testExecutor.ParserExecutionError == true, "Parse Execution error should be true");
}
}
}
/// <summary>
/// Verify whether the batchParser parsed :on error successfully
/// </summary>
[Fact]
public void VerifyOnErrorSqlCmd()
{
using (ExecutionEngine executionEngine = new ExecutionEngine())
{
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master");
string serverName = liveConnection.ConnectionInfo.ConnectionDetails.ServerName;
string sqlCmdQuery = $@"
:on error ignore
GO
select * from sys.databases_wrong where name = 'master'
GO
select* from sys.databases where name = 'master'
GO
:on error exit
GO
select* from sys.databases_wrong where name = 'master'
GO
select* from sys.databases where name = 'master'
GO";
var condition = new ExecutionEngineConditions() { IsSqlCmd = true };
using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(liveConnection.ConnectionInfo))
using (TestExecutor testExecutor = new TestExecutor(sqlCmdQuery, sqlConn, condition))
{
testExecutor.Run();
Assert.True(testExecutor.ResultCountQueue.Count == 1, $"Unexpected number of ResultCount items - expected only 1 since the later should not be executed but got {testExecutor.ResultCountQueue.Count}");
Assert.True(testExecutor.ErrorMessageQueue.Count == 2, $"Unexpected number error messages from test executor expected 2, actual : {string.Join(Environment.NewLine, testExecutor.ErrorMessageQueue)}");
}
}
}
/// <summary>
/// Verify whether the batchParser parses Include command i.e. :r successfully
/// </summary>
[Fact]
public void VerifyIncludeSqlCmd()
{
string file = "VerifyIncludeSqlCmd_test.sql";
try
{
using (ExecutionEngine executionEngine = new ExecutionEngine())
{
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master");
string serverName = liveConnection.ConnectionInfo.ConnectionDetails.ServerName;
string sqlCmdFile = $@"
select * from sys.databases where name = 'master'
GO";
File.WriteAllText("VerifyIncludeSqlCmd_test.sql", sqlCmdFile);
string sqlCmdQuery = $@"
:r {file}
GO
select * from sys.databases where name = 'master'
GO";
var condition = new ExecutionEngineConditions() { IsSqlCmd = true };
using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(liveConnection.ConnectionInfo))
using (TestExecutor testExecutor = new TestExecutor(sqlCmdQuery, sqlConn, condition))
{
testExecutor.Run();
Assert.True(testExecutor.ResultCountQueue.Count == 2, $"Unexpected number of ResultCount items - expected 1 but got {testExecutor.ResultCountQueue.Count}");
Assert.True(testExecutor.ErrorMessageQueue.Count == 0, $"Unexpected error messages from test executor : {string.Join(Environment.NewLine, testExecutor.ErrorMessageQueue)}");
}
File.Delete(file);
}
}
catch
{
if (File.Exists(file))
{
File.Delete(file);
}
}
}
// Verify whether the executionEngine execute Batch
[Fact]
public void VerifyExecuteBatch()

View File

@@ -26,7 +26,7 @@ namespace Microsoft.SqlTools.ManagedBatchParser.UnitTests.BatchParser
this.parser = parser;
}
public BatchParserAction Go(TextBlock batch, int repeatCount)
public BatchParserAction Go(TextBlock batch, int repeatCount, SqlCmdCommand command)
{
string textWithVariablesResolved;
string textWithVariablesUnresolved;

View File

@@ -115,9 +115,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.QueryExecution
Assert.True(multipleElapsedTime > elapsedTime);
}
public Query CreateAndExecuteQuery(string queryText, ConnectionInfo connectionInfo, IFileStreamFactory fileStreamFactory)
public static Query CreateAndExecuteQuery(string queryText, ConnectionInfo connectionInfo, IFileStreamFactory fileStreamFactory, bool IsSqlCmd = false)
{
Query query = new Query(queryText, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
var settings = new QueryExecutionSettings() { IsSqlCmdMode = IsSqlCmd };
Query query = new Query(queryText, connectionInfo, settings, fileStreamFactory);
query.Execute();
query.ExecutionTask.Wait();
return query;

View File

@@ -0,0 +1,141 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
using System;
using System.IO;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.QueryExecution
{
public class SqlCmdExecutionTest
{
[Fact]
public void TestConnectSqlCmdCommand()
{
var fileStreamFactory = MemoryFileSystem.GetFileStreamFactory();
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master");
ConnectionInfo connInfo = liveConnection.ConnectionInfo;
string serverName = liveConnection.ConnectionInfo.ConnectionDetails.ServerName;
string sqlCmdQuerySuccess = $@"
:Connect {serverName}
select * from sys.databases where name = 'master'
GO";
Query query = ExecuteTests.CreateAndExecuteQuery(sqlCmdQuerySuccess, connInfo, fileStreamFactory, IsSqlCmd: true);
Assert.True(query.Batches.Length == 1, $"Expected: 1 parsed batch, actual : {query.Batches.Length}");
Assert.True(query.Batches[0].HasExecuted && !query.Batches[0].HasError && query.Batches[0].ResultSets.Count == 1, "Query should be executed and have one result set");
string sqlCmdQueryFilaure = $@"
:Connect SomeWrongName
select * from sys.databases where name = 'master'
GO";
query = ExecuteTests.CreateAndExecuteQuery(sqlCmdQueryFilaure, connInfo, fileStreamFactory, IsSqlCmd: true);
Assert.True(query.Batches.Length == 1, $"Expected: 1 parsed batch, actual : {query.Batches.Length}");
Assert.True(query.Batches[0].HasError, "Query should have error");
}
[Fact]
public void TestOnErrorSqlCmdCommand()
{
var fileStreamFactory = MemoryFileSystem.GetFileStreamFactory();
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master");
ConnectionInfo connInfo = liveConnection.ConnectionInfo;
string sqlCmdQuerySuccess = $@"
:on error ignore
GO
select * from sys.databases_wrong where name = 'master'
GO
select * from sys.databases where name = 'master'
GO";
Query query = ExecuteTests.CreateAndExecuteQuery(sqlCmdQuerySuccess, connInfo, fileStreamFactory, IsSqlCmd: true);
Assert.True(query.Batches[0].HasExecuted && query.Batches[0].HasError, "first batch should be executed and have error");
Assert.True(query.Batches[1].HasExecuted, "last batch should be executed");
string sqlCmdQueryFilaure = $@"
:on error exit
GO
select * from sys.databases_wrong where name = 'master'
GO
select * from sys.databases where name = 'master'
GO";
query = ExecuteTests.CreateAndExecuteQuery(sqlCmdQueryFilaure, connInfo, fileStreamFactory, IsSqlCmd: true);
Assert.True(query.Batches[0].HasExecuted && query.Batches[0].HasError, "first batch should be executed and have error");
Assert.False(query.Batches[1].HasExecuted, "last batch should NOT be executed");
}
[Fact]
public void TestIncludeSqlCmdCommand()
{
var fileStreamFactory = MemoryFileSystem.GetFileStreamFactory();
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo("master");
ConnectionInfo connInfo = liveConnection.ConnectionInfo;
string path = Path.Combine(Environment.CurrentDirectory, "mysqlfile.sql");
string sqlPath = "\"" + path + "\"";
// correct sql file text
string correctfileText = $@"
select * from sys.databases where name = 'msdb' or name = 'master'
GO";
// incorrect sql file text
string incorrectfileText = $@"
select * from sys.databases_wrong where name = 'msdb' or name = 'master'
GO";
File.WriteAllText(path, correctfileText);
string sqlCmdQuerySuccess = $@"
:on error exit
:setvar mypath {sqlPath}
GO
:r $(mypath)
GO
select * from sys.databases where name = 'master'
GO";
Query query = ExecuteTests.CreateAndExecuteQuery(sqlCmdQuerySuccess, connInfo, fileStreamFactory, IsSqlCmd: true);
Assert.True(query.Batches.Length == 2, $"Batches should be parsed and should be 2, actual number {query.Batches.Length}");
Assert.True(query.Batches[0].HasExecuted && !query.Batches[0].HasError && query.Batches[0].ResultSets.Count == 1 && query.Batches[0].ResultSets[0].RowCount == 2, "first batch should be executed and have 2 results");
Assert.True(query.Batches[1].HasExecuted && !query.Batches[1].HasError && query.Batches[1].ResultSets.Count == 1 && query.Batches[1].ResultSets[0].RowCount == 1, "second batch should be executed and have 1 result");
string sqlCmdQueryFilaure1 = $@"
:on error exit
:setvar mypath somewrongpath
GO
:r $(mypath)
GO
select * from sys.databases where name = 'master'
GO";
query = ExecuteTests.CreateAndExecuteQuery(sqlCmdQueryFilaure1, connInfo, fileStreamFactory, IsSqlCmd: true);
Assert.True(query.Batches.Length == 0, $"Batches should be 0 since parsing was aborted, actual number {query.Batches.Length}");
File.WriteAllText(path, incorrectfileText);
string sqlCmdQueryFilaure2 = $@"
:on error exit
:setvar mypath {sqlPath}
GO
:r $(mypath)
GO
select * from sys.databases where name = 'master'
GO";
query = ExecuteTests.CreateAndExecuteQuery(sqlCmdQueryFilaure2, connInfo, fileStreamFactory, IsSqlCmd: true);
Assert.True(query.Batches.Length == 2, $"Batches should be parsed and should be 2, actual number {query.Batches.Length}");
Assert.True(query.Batches[0].HasExecuted && query.Batches[0].HasError, "first batch should be executed and have error");
Assert.True(!query.Batches[1].HasExecuted, "second batch should not get to be executed because of the first error");
}
}
}