diff --git a/.gitignore b/.gitignore
index de0fdc5b..a52a0fe0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -36,6 +36,7 @@ node_modules
packages
reports
opencovertests.xml
+outputCobertura.xml
sqltools.xml
# Cross building rootfs
@@ -275,3 +276,7 @@ Session.vim
# Visual Studio Code
.vscode/
+
+# Stuff from cake
+/artifacts/
+/.tools/
\ No newline at end of file
diff --git a/build.cake b/build.cake
index 61047b05..1b7c3b80 100644
--- a/build.cake
+++ b/build.cake
@@ -81,7 +81,8 @@ Task("Cleanup")
/// Pre-build setup tasks.
///
Task("Setup")
- .IsDependentOn("BuildEnvironment")
+ .IsDependentOn("InstallDotnet")
+ .IsDependentOn("InstallXUnit")
.IsDependentOn("PopulateRuntimes")
.Does(() =>
{
@@ -92,7 +93,6 @@ Task("Setup")
/// Use default RID (+ win7-x86 on Windows) for now.
///
Task("PopulateRuntimes")
- .IsDependentOn("BuildEnvironment")
.Does(() =>
{
buildPlan.Rids = new string[]
@@ -112,43 +112,65 @@ Task("PopulateRuntimes")
});
///
-/// Install/update build environment.
+/// Install dotnet if it isn't already installed
///
-Task("BuildEnvironment")
+Task("InstallDotnet")
.Does(() =>
{
- var installScript = $"dotnet-install.{shellExtension}";
- System.IO.Directory.CreateDirectory(dotnetFolder);
- var scriptPath = System.IO.Path.Combine(dotnetFolder, installScript);
- using (WebClient client = new WebClient())
- {
- client.DownloadFile($"{buildPlan.DotNetInstallScriptURL}/{installScript}", scriptPath);
- }
- if (!IsRunningOnWindows())
- {
- Run("chmod", $"+x '{scriptPath}'");
- }
- var installArgs = $"-Channel {buildPlan.DotNetChannel}";
- if (!String.IsNullOrEmpty(buildPlan.DotNetVersion))
- {
- installArgs = $"{installArgs} -Version {buildPlan.DotNetVersion}";
- }
- if (!buildPlan.UseSystemDotNetPath)
- {
- installArgs = $"{installArgs} -InstallDir {dotnetFolder}";
- }
- Run(shell, $"{shellArgument} {scriptPath} {installArgs}");
- try
- {
- Run(dotnetcli, "--info");
- }
- catch (Win32Exception)
- {
- throw new Exception(".NET CLI binary cannot be found.");
- }
+ // Determine if `dotnet` is installed
+ var dotnetInstalled = true;
+ try
+ {
+ Run(dotnetcli, "--info");
+ Information("dotnet is already installed, will skip download/install");
+ }
+ catch(Win32Exception)
+ {
+ // If we get this exception, dotnet isn't installed
+ dotnetInstalled = false;
+ }
- System.IO.Directory.CreateDirectory(toolsFolder);
+ // Install dotnet if it isn't already installed
+ if (!dotnetInstalled)
+ {
+ var installScript = $"dotnet-install.{shellExtension}";
+ System.IO.Directory.CreateDirectory(dotnetFolder);
+ var scriptPath = System.IO.Path.Combine(dotnetFolder, installScript);
+ using (WebClient client = new WebClient())
+ {
+ client.DownloadFile($"{buildPlan.DotNetInstallScriptURL}/{installScript}", scriptPath);
+ }
+ if (!IsRunningOnWindows())
+ {
+ Run("chmod", $"+x '{scriptPath}'");
+ }
+ var installArgs = $"-Channel {buildPlan.DotNetChannel}";
+ if (!String.IsNullOrEmpty(buildPlan.DotNetVersion))
+ {
+ installArgs = $"{installArgs} -Version {buildPlan.DotNetVersion}";
+ }
+ if (!buildPlan.UseSystemDotNetPath)
+ {
+ installArgs = $"{installArgs} -InstallDir {dotnetFolder}";
+ }
+ Run(shell, $"{shellArgument} {scriptPath} {installArgs}");
+ try
+ {
+ Run(dotnetcli, "--info");
+ }
+ catch (Win32Exception)
+ {
+ throw new Exception(".NET CLI failed to be installed");
+ }
+ }
+});
+///
+/// Installs XUnit nuget package
+Task("InstallXUnit")
+ .Does(() =>
+{
+ // Install the tools
var nugetPath = Environment.GetEnvironmentVariable("NUGET_EXE");
var arguments = $"install xunit.runner.console -ExcludeVersion -NoCache -Prerelease -OutputDirectory \"{toolsFolder}\"";
if (IsRunningOnWindows())
@@ -208,14 +230,6 @@ Task("TestAll")
.IsDependentOn("TestCore")
.Does(() =>{});
-///
-/// Run all tests for Travis CI .NET Desktop and .NET Core
-///
-Task("TravisTestAll")
- .IsDependentOn("Cleanup")
- .IsDependentOn("TestAll")
- .Does(() =>{});
-
///
/// Run tests for .NET Core (using .NET CLI).
///
@@ -345,6 +359,7 @@ Task("RestrictToLocalRuntime")
///
Task("LocalPublish")
.IsDependentOn("Restore")
+ .IsDependentOn("SrGen")
.IsDependentOn("RestrictToLocalRuntime")
.IsDependentOn("OnlyPublish")
.Does(() =>
@@ -451,20 +466,6 @@ Task("Local")
{
});
-///
-/// Build centered around producing the final artifacts for Travis
-///
-/// The tests are run as a different task "TestAll"
-///
-Task("Travis")
- .IsDependentOn("Cleanup")
- .IsDependentOn("Restore")
- .IsDependentOn("AllPublish")
- // .IsDependentOn("TestPublished")
- .Does(() =>
-{
-});
-
///
/// Update the package versions within project.json files.
/// Uses depversion.json file as input.
@@ -492,6 +493,46 @@ Task("SetPackageVersions")
}
});
+///
+/// Executes SRGen to create a resx file and associated designer C# file
+///
+Task("SRGen")
+ .Does(() =>
+{
+ var projects = System.IO.Directory.GetFiles(sourceFolder, "project.json", SearchOption.AllDirectories).ToList();
+ foreach(var project in projects) {
+ var projectDir = System.IO.Path.GetDirectoryName(project);
+ var projectName = (new System.IO.DirectoryInfo(projectDir)).Name;
+ var projectStrings = System.IO.Path.Combine(projectDir, "sr.strings");
+
+ if (!System.IO.File.Exists(projectStrings))
+ {
+ Information("Project {0} doesn't contain 'sr.strings' file", projectName);
+ continue;
+ }
+
+ var srgenPath = System.IO.Path.Combine(toolsFolder, "Microsoft.DataTools.SrGen", "lib", "netcoreapp1.0", "srgen.dll");
+ var outputResx = System.IO.Path.Combine(projectDir, "sr.resx");
+ var outputCs = System.IO.Path.Combine(projectDir, "sr.cs");
+
+ // Delete preexisting resx and designer files
+ if (System.IO.File.Exists(outputResx))
+ {
+ System.IO.File.Delete(outputResx);
+ }
+ if (System.IO.File.Exists(outputCs))
+ {
+ System.IO.File.Delete(outputCs);
+ }
+
+ // Run SRGen
+ var dotnetArgs = string.Format("{0} -or \"{1}\" -oc \"{2}\" -ns \"{3}\" -an \"{4}\" -cn SR -l CS \"{5}\"",
+ srgenPath, outputResx, outputCs, projectName, projectName, projectStrings);
+ Information("{0}", dotnetArgs);
+ Run(dotnetcli, dotnetArgs);
+ }
+});
+
///
/// Default Task aliases to Local.
///
diff --git a/build.cmd b/build.cmd
new file mode 100644
index 00000000..9a7e5104
--- /dev/null
+++ b/build.cmd
@@ -0,0 +1 @@
+powershell -File build.ps1 %*
\ No newline at end of file
diff --git a/scripts/cake-bootstrap.ps1 b/scripts/cake-bootstrap.ps1
index a87c1478..7f0bc382 100644
--- a/scripts/cake-bootstrap.ps1
+++ b/scripts/cake-bootstrap.ps1
@@ -106,5 +106,7 @@ if (!(Test-Path $CAKE_EXE)) {
# Start Cake
Write-Host "Running build script..."
-Invoke-Expression "& `"$CAKE_EXE`" `"$Script`" -verbosity=`"$Verbosity`" $UseMono $UseDryRun $ScriptArgs"
+$v = "& `"$CAKE_EXE`" `"$Script`" -verbosity=`"$Verbosity`" $UseMono $UseDryRun $ScriptArgs"
+Write-Host $v
+Invoke-Expression $v
exit $LASTEXITCODE
diff --git a/scripts/packages.config b/scripts/packages.config
index c4feb50f..296eeb59 100644
--- a/scripts/packages.config
+++ b/scripts/packages.config
@@ -2,4 +2,5 @@
+
diff --git a/sqltoolsservice.sln b/sqltoolsservice.sln
index cd55b538..15ece77f 100644
--- a/sqltoolsservice.sln
+++ b/sqltoolsservice.sln
@@ -9,6 +9,7 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{32DC973E-9EEA-4694-B1C2-B031167AB945}"
ProjectSection(SolutionItems) = preProject
.gitignore = .gitignore
+ BUILD.md = BUILD.md
global.json = global.json
nuget.config = nuget.config
README.md = README.md
@@ -18,6 +19,25 @@ Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.SqlTools.ServiceL
EndProject
Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.SqlTools.ServiceLayer.Test", "test\Microsoft.SqlTools.ServiceLayer.Test\Microsoft.SqlTools.ServiceLayer.Test.xproj", "{2D771D16-9D85-4053-9F79-E2034737DEEF}"
EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "scripts", "scripts", "{B7D21727-2926-452B-9610-3ADB0BB6D789}"
+ ProjectSection(SolutionItems) = preProject
+ scripts\archiving.cake = scripts\archiving.cake
+ scripts\artifacts.cake = scripts\artifacts.cake
+ scripts\cake-bootstrap.ps1 = scripts\cake-bootstrap.ps1
+ scripts\cake-bootstrap.sh = scripts\cake-bootstrap.sh
+ scripts\packages.config = scripts\packages.config
+ scripts\runhelpers.cake = scripts\runhelpers.cake
+ EndProjectSection
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "build", "build", "{F9978D78-78FE-4E92-A7D6-D436B7683EF6}"
+ ProjectSection(SolutionItems) = preProject
+ build.cake = build.cake
+ build.cmd = build.cmd
+ build.json = build.json
+ build.ps1 = build.ps1
+ build.sh = build.sh
+ EndProjectSection
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -39,5 +59,6 @@ Global
GlobalSection(NestedProjects) = preSolution
{0D61DC2B-DA66-441D-B9D0-F76C98F780F9} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
{2D771D16-9D85-4053-9F79-E2034737DEEF} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4}
+ {B7D21727-2926-452B-9610-3ADB0BB6D789} = {F9978D78-78FE-4E92-A7D6-D436B7683EF6}
EndGlobalSection
EndGlobal
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
index 57a7ba6e..c694ca93 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
@@ -11,6 +11,7 @@ using System.Data.SqlClient;
using System.Threading.Tasks;
using Microsoft.SqlTools.EditorServices.Utility;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;
@@ -169,12 +170,45 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
+ // Update with the actual database name in connectionInfo and result
+ // Doing this here as we know the connection is open - expect to do this only on connecting
+ connectionInfo.ConnectionDetails.DatabaseName = connectionInfo.SqlConnection.Database;
+ response.ConnectionSummary = new ConnectionSummary()
+ {
+ ServerName = connectionInfo.ConnectionDetails.ServerName,
+ DatabaseName = connectionInfo.ConnectionDetails.DatabaseName,
+ UserName = connectionInfo.ConnectionDetails.UserName,
+ };
+
// invoke callback notifications
foreach (var activity in this.onConnectionActivities)
{
activity(connectionInfo);
}
+ // try to get information about the connected SQL Server instance
+ try
+ {
+ ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(connectionInfo.SqlConnection);
+ response.ServerInfo = new Contracts.ServerInfo()
+ {
+ ServerMajorVersion = serverInfo.ServerMajorVersion,
+ ServerMinorVersion = serverInfo.ServerMinorVersion,
+ ServerReleaseVersion = serverInfo.ServerReleaseVersion,
+ EngineEditionId = serverInfo.EngineEditionId,
+ ServerVersion = serverInfo.ServerVersion,
+ ServerLevel = serverInfo.ServerLevel,
+ ServerEdition = serverInfo.ServerEdition,
+ IsCloud = serverInfo.IsCloud,
+ AzureVersion = serverInfo.AzureVersion,
+ OsVersion = serverInfo.OsVersion
+ };
+ }
+ catch(Exception ex)
+ {
+ response.Messages = ex.ToString();
+ }
+
// return the connection result
response.ConnectionId = connectionInfo.ConnectionId.ToString();
return response;
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs
index 9f2c7356..c345532d 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs
@@ -3,8 +3,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-using System;
-
namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
{
///
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs
index c325c64f..9066efa8 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs
@@ -19,5 +19,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
/// Gets or sets any connection error messages
///
public string Messages { get; set; }
+
+ ///
+ /// Information about the connected server.
+ ///
+ public ServerInfo ServerInfo { get; set; }
+
+ ///
+ /// Gets or sets the actual Connection established, including Database Name
+ ///
+ public ConnectionSummary ConnectionSummary { get; set; }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ServerInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ServerInfo.cs
new file mode 100644
index 00000000..3bc9e73d
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ServerInfo.cs
@@ -0,0 +1,63 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// Contract for information on the connected SQL Server instance.
+ ///
+ public class ServerInfo
+ {
+ ///
+ /// The major version of the SQL Server instance.
+ ///
+ public int ServerMajorVersion { get; set; }
+
+ ///
+ /// The minor version of the SQL Server instance.
+ ///
+ public int ServerMinorVersion { get; set; }
+
+ ///
+ /// The build of the SQL Server instance.
+ ///
+ public int ServerReleaseVersion { get; set; }
+
+ ///
+ /// The ID of the engine edition of the SQL Server instance.
+ ///
+ public int EngineEditionId { get; set; }
+
+ ///
+ /// String containing the full server version text.
+ ///
+ public string ServerVersion { get; set; }
+
+ ///
+ /// String describing the product level of the server.
+ ///
+ public string ServerLevel { get; set; }
+
+ ///
+ /// The edition of the SQL Server instance.
+ ///
+ public string ServerEdition { get; set; }
+
+ ///
+ /// Whether the SQL Server instance is running in the cloud (Azure) or not.
+ ///
+ public bool IsCloud { get; set; }
+
+ ///
+ /// The version of Azure that the SQL Server instance is running on, if applicable.
+ ///
+ public int AzureVersion { get; set; }
+
+ ///
+ /// The Operating System version string of the machine running the SQL Server instance.
+ ///
+ public string OsVersion { get; set; }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/AmbientSettings.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/AmbientSettings.cs
new file mode 100644
index 00000000..e56769dc
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/AmbientSettings.cs
@@ -0,0 +1,452 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using Microsoft.SqlTools.EditorServices.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// This class represents connection (and other) settings specified by called of the DacFx API. DacFx
+ /// cannot rely on the registry to supply override values therefore setting overrides must be made
+ /// by the top-of-the-stack
+ ///
+ internal sealed class AmbientSettings
+ {
+ private const string LogicalContextName = "__LocalContextConfigurationName";
+
+ internal enum StreamBackingStore
+ {
+ // MemoryStream
+ Memory = 0,
+
+ // FileStream
+ File = 1
+ }
+
+ // Internal for test purposes
+ internal const string MasterReferenceFilePathIndex = "MasterReferenceFilePath";
+ internal const string DatabaseLockTimeoutIndex = "DatabaseLockTimeout";
+ internal const string QueryTimeoutIndex = "QueryTimeout";
+ internal const string LongRunningQueryTimeoutIndex = "LongRunningQueryTimeout";
+ internal const string AlwaysRetryOnTransientFailureIndex = "AlwaysRetryOnTransientFailure";
+ internal const string MaxDataReaderDegreeOfParallelismIndex = "MaxDataReaderDegreeOfParallelism";
+ internal const string ConnectionRetryHandlerIndex = "ConnectionRetryHandler";
+ internal const string TraceRowCountFailureIndex = "TraceRowCountFailure";
+ internal const string TableProgressUpdateIntervalIndex = "TableProgressUpdateInterval";
+ internal const string UseOfflineDataReaderIndex = "UseOfflineDataReader";
+ internal const string StreamBackingStoreForOfflineDataReadingIndex = "StreamBackingStoreForOfflineDataReading";
+ internal const string DisableIndexesForDataPhaseIndex = "DisableIndexesForDataPhase";
+ internal const string ReliableDdlEnabledIndex = "ReliableDdlEnabled";
+ internal const string ImportModelDatabaseIndex = "ImportModelDatabase";
+ internal const string SupportAlwaysEncryptedIndex = "SupportAlwaysEncrypted";
+ internal const string SkipObjectTypeBlockingIndex = "SkipObjectTypeBlocking";
+ internal const string DoNotSerializeQueryStoreSettingsIndex = "DoNotSerializeQueryStoreSettings";
+ internal const string AlwaysEncryptedWizardMigrationIndex = "AlwaysEncryptedWizardMigration";
+
+ private static readonly AmbientData _defaultSettings;
+
+ static AmbientSettings()
+ {
+ _defaultSettings = new AmbientData();
+ }
+
+ ///
+ /// Access to the default ambient settings. Access to these settings is made available
+ /// for SSDT scenarios where settings are read from the registry and not set explicitly through
+ /// the API
+ ///
+ public static AmbientData DefaultSettings
+ {
+ get { return _defaultSettings; }
+ }
+
+ public static string MasterReferenceFilePath
+ {
+ get { return GetValue(MasterReferenceFilePathIndex); }
+ }
+
+ public static int LockTimeoutMilliSeconds
+ {
+ get { return GetValue(DatabaseLockTimeoutIndex); }
+ }
+
+ public static int QueryTimeoutSeconds
+ {
+ get { return GetValue(QueryTimeoutIndex); }
+ }
+
+ public static int LongRunningQueryTimeoutSeconds
+ {
+ get { return GetValue(LongRunningQueryTimeoutIndex); }
+ }
+
+ public static Action ConnectionRetryMessageHandler
+ {
+ get { return GetValue>(ConnectionRetryHandlerIndex); }
+ }
+
+ public static bool AlwaysRetryOnTransientFailure
+ {
+ get { return GetValue(AlwaysRetryOnTransientFailureIndex); }
+ }
+
+ public static int MaxDataReaderDegreeOfParallelism
+ {
+ get { return GetValue(MaxDataReaderDegreeOfParallelismIndex); }
+ }
+
+ public static int TableProgressUpdateInterval
+ {
+ // value of zero means do not fire 'heartbeat' progress events. Non-zero values will
+ // fire a heartbeat progress event every n seconds.
+ get { return GetValue(TableProgressUpdateIntervalIndex); }
+ }
+
+ public static bool TraceRowCountFailure
+ {
+ get { return GetValue(TraceRowCountFailureIndex); }
+ }
+
+ public static bool UseOfflineDataReader
+ {
+ get { return GetValue(UseOfflineDataReaderIndex); }
+ }
+
+ public static StreamBackingStore StreamBackingStoreForOfflineDataReading
+ {
+ get { return GetValue(StreamBackingStoreForOfflineDataReadingIndex); }
+ }
+
+ public static bool DisableIndexesForDataPhase
+ {
+ get { return GetValue(DisableIndexesForDataPhaseIndex); }
+ }
+
+ public static bool ReliableDdlEnabled
+ {
+ get { return GetValue(ReliableDdlEnabledIndex); }
+ }
+
+ public static bool ImportModelDatabase
+ {
+ get { return GetValue(ImportModelDatabaseIndex); }
+ }
+
+ ///
+ /// Setting that shows whether Always Encrypted is supported.
+ /// If false, then reverse engineering and script interpretation of a database with any Always Encrypted object will fail
+ ///
+ public static bool SupportAlwaysEncrypted
+ {
+ get { return GetValue(SupportAlwaysEncryptedIndex); }
+ }
+
+ public static bool AlwaysEncryptedWizardMigration
+ {
+ get { return GetValue(AlwaysEncryptedWizardMigrationIndex); }
+ }
+
+ ///
+ /// Setting that determines whether checks for unsupported object types are performed.
+ /// If false, unsupported object types will prevent extract from being performed.
+ /// Default value is false.
+ ///
+ public static bool SkipObjectTypeBlocking
+ {
+ get { return GetValue(SkipObjectTypeBlockingIndex); }
+ }
+
+ ///
+ /// Setting that determines whether the Database Options that store Query Store settings will be left out during package serialization.
+ /// Default value is false.
+ ///
+ public static bool DoNotSerializeQueryStoreSettings
+ {
+ get { return GetValue(DoNotSerializeQueryStoreSettingsIndex); }
+ }
+
+ ///
+ /// Called by top-of-stack API to setup/configure settings that should be used
+ /// throughout the API (lower in the stack). The settings are reverted once the returned context
+ /// has been disposed.
+ ///
+ public static IStackSettingsContext CreateSettingsContext()
+ {
+ return new StackConfiguration();
+ }
+
+ private static T1 GetValue(string configIndex)
+ {
+ IAmbientDataDirectAccess config = _defaultSettings;
+
+ return (T1)config.Data[configIndex].Value;
+ }
+
+ ///
+ /// Data-transfer object that represents a specific configuration
+ ///
+ public class AmbientData : IAmbientDataDirectAccess
+ {
+ private readonly Dictionary _configuration;
+
+ public AmbientData()
+ {
+ _configuration = new Dictionary(StringComparer.OrdinalIgnoreCase);
+ _configuration[DatabaseLockTimeoutIndex] = new AmbientValue(5000);
+ _configuration[QueryTimeoutIndex] = new AmbientValue(60);
+ _configuration[LongRunningQueryTimeoutIndex] = new AmbientValue(0);
+ _configuration[AlwaysRetryOnTransientFailureIndex] = new AmbientValue(false);
+ _configuration[ConnectionRetryHandlerIndex] = new AmbientValue(typeof(Action), null);
+ _configuration[MaxDataReaderDegreeOfParallelismIndex] = new AmbientValue(8);
+ _configuration[TraceRowCountFailureIndex] = new AmbientValue(false); // default: throw DacException on rowcount mismatch during import/export data validation
+ _configuration[TableProgressUpdateIntervalIndex] = new AmbientValue(300); // default: fire heartbeat progress update events every 5 minutes
+ _configuration[UseOfflineDataReaderIndex] = new AmbientValue(false);
+ _configuration[StreamBackingStoreForOfflineDataReadingIndex] = new AmbientValue(StreamBackingStore.File); //applicable only when UseOfflineDataReader is set to true
+ _configuration[MasterReferenceFilePathIndex] = new AmbientValue(typeof(string), null);
+ // Defect 1210884: Enable an option to allow secondary index, check and fk constraints to stay enabled during data upload with import in DACFX for IES
+ _configuration[DisableIndexesForDataPhaseIndex] = new AmbientValue(true);
+ _configuration[ReliableDdlEnabledIndex] = new AmbientValue(false);
+ _configuration[ImportModelDatabaseIndex] = new AmbientValue(true);
+ _configuration[SupportAlwaysEncryptedIndex] = new AmbientValue(false);
+ _configuration[AlwaysEncryptedWizardMigrationIndex] = new AmbientValue(false);
+ _configuration[SkipObjectTypeBlockingIndex] = new AmbientValue(false);
+ _configuration[DoNotSerializeQueryStoreSettingsIndex] = new AmbientValue(false);
+ }
+
+ public string MasterReferenceFilePath
+ {
+ get { return (string)_configuration[MasterReferenceFilePathIndex].Value; }
+ set { _configuration[MasterReferenceFilePathIndex].Value = value; }
+ }
+
+ public int LockTimeoutMilliSeconds
+ {
+ get { return (int)_configuration[DatabaseLockTimeoutIndex].Value; }
+ set { _configuration[DatabaseLockTimeoutIndex].Value = value; }
+ }
+ public int QueryTimeoutSeconds
+ {
+ get { return (int)_configuration[QueryTimeoutIndex].Value; }
+ set { _configuration[QueryTimeoutIndex].Value = value; }
+ }
+ public int LongRunningQueryTimeoutSeconds
+ {
+ get { return (int)_configuration[LongRunningQueryTimeoutIndex].Value; }
+ set { _configuration[LongRunningQueryTimeoutIndex].Value = value; }
+ }
+ public bool AlwaysRetryOnTransientFailure
+ {
+ get { return (bool)_configuration[AlwaysRetryOnTransientFailureIndex].Value; }
+ set { _configuration[AlwaysRetryOnTransientFailureIndex].Value = value; }
+ }
+ public Action ConnectionRetryMessageHandler
+ {
+ get { return (Action)_configuration[ConnectionRetryHandlerIndex].Value; }
+ set { _configuration[ConnectionRetryHandlerIndex].Value = value; }
+ }
+ public bool TraceRowCountFailure
+ {
+ get { return (bool)_configuration[TraceRowCountFailureIndex].Value; }
+ set { _configuration[TraceRowCountFailureIndex].Value = value; }
+ }
+ public int TableProgressUpdateInterval
+ {
+ get { return (int)_configuration[TableProgressUpdateIntervalIndex].Value; }
+ set { _configuration[TableProgressUpdateIntervalIndex].Value = value; }
+ }
+
+ public bool UseOfflineDataReader
+ {
+ get { return (bool)_configuration[UseOfflineDataReaderIndex].Value; }
+ set { _configuration[UseOfflineDataReaderIndex].Value = value; }
+ }
+
+ public StreamBackingStore StreamBackingStoreForOfflineDataReading
+ {
+ get { return (StreamBackingStore)_configuration[StreamBackingStoreForOfflineDataReadingIndex].Value; }
+ set { _configuration[StreamBackingStoreForOfflineDataReadingIndex].Value = value; }
+ }
+
+ public bool DisableIndexesForDataPhase
+ {
+ get { return (bool)_configuration[DisableIndexesForDataPhaseIndex].Value; }
+ set { _configuration[DisableIndexesForDataPhaseIndex].Value = value; }
+ }
+
+ public bool ReliableDdlEnabled
+ {
+ get { return (bool)_configuration[ReliableDdlEnabledIndex].Value; }
+ set { _configuration[ReliableDdlEnabledIndex].Value = value; }
+ }
+
+ public bool ImportModelDatabase
+ {
+ get { return (bool)_configuration[ImportModelDatabaseIndex].Value; }
+ set { _configuration[ImportModelDatabaseIndex].Value = value; }
+ }
+
+ internal bool SupportAlwaysEncrypted
+ {
+ get { return (bool)_configuration[SupportAlwaysEncryptedIndex].Value; }
+ set { _configuration[SupportAlwaysEncryptedIndex].Value = value; }
+ }
+
+ internal bool AlwaysEncryptedWizardMigration
+ {
+ get { return (bool)_configuration[AlwaysEncryptedWizardMigrationIndex].Value; }
+ set { _configuration[AlwaysEncryptedWizardMigrationIndex].Value = value; }
+ }
+
+ internal bool SkipObjectTypeBlocking
+ {
+ get { return (bool)_configuration[SkipObjectTypeBlockingIndex].Value; }
+ set { _configuration[SkipObjectTypeBlockingIndex].Value = value; }
+ }
+
+ internal bool DoNotSerializeQueryStoreSettings
+ {
+ get { return (bool)_configuration[DoNotSerializeQueryStoreSettingsIndex].Value; }
+ set { _configuration[DoNotSerializeQueryStoreSettingsIndex].Value = value; }
+ }
+
+ ///
+ /// Provides a way to bulk populate settings from a dictionary
+ ///
+ public void PopulateSettings(IDictionary settingsCollection)
+ {
+ if (settingsCollection != null)
+ {
+ Dictionary newSettings = new Dictionary();
+
+ // We know all the values are set on the current configuration
+ foreach (KeyValuePair potentialPair in settingsCollection)
+ {
+ AmbientValue currentValue;
+ if (_configuration.TryGetValue(potentialPair.Key, out currentValue))
+ {
+ object newValue = potentialPair.Value;
+ newSettings[potentialPair.Key] = newValue;
+ }
+ }
+
+ if (newSettings.Count > 0)
+ {
+ foreach (KeyValuePair newSetting in newSettings)
+ {
+ _configuration[newSetting.Key].Value = newSetting.Value;
+ }
+ }
+ }
+ }
+
+ ///
+ /// Logs the Ambient Settings
+ ///
+ public void TraceSettings()
+ {
+ // NOTE: logging as warning so we can get this data in the IEService DacFx logs
+ Logger.Write(LogLevel.Warning, Resources.LoggingAmbientSettings);
+
+ foreach (KeyValuePair setting in _configuration)
+ {
+ // Log Ambient Settings
+ Logger.Write(
+ LogLevel.Warning,
+ string.Format(
+ Resources.AmbientSettingFormat,
+ setting.Key,
+ setting.Value == null ? setting.Value : setting.Value.Value));
+ }
+ }
+
+ Dictionary IAmbientDataDirectAccess.Data
+ {
+ get { return _configuration; }
+ }
+ }
+
+ ///
+ /// This class is used as value in the dictionary to ensure that the type of value is correct.
+ ///
+ private class AmbientValue
+ {
+ private readonly Type _type;
+ private readonly bool _isTypeNullable;
+ private object _value;
+
+ public AmbientValue(object value)
+ : this(value == null ? null : value.GetType(), value)
+ {
+ }
+
+ public AmbientValue(Type type, object value)
+ {
+ if (type == null)
+ {
+ throw new ArgumentNullException("type");
+ }
+ _type = type;
+ _isTypeNullable = !type.GetTypeInfo().IsValueType || Nullable.GetUnderlyingType(type) != null;
+ Value = value;
+ }
+
+ public object Value
+ {
+ get { return _value; }
+ set
+ {
+ if ((_isTypeNullable && value == null) || _type.GetTypeInfo().IsInstanceOfType(value))
+ {
+ _value = value;
+ }
+ else
+ {
+ Logger.Write(LogLevel.Error, string.Format(Resources.UnableToAssignValue, value.GetType().FullName, _type.FullName));
+ }
+ }
+ }
+ }
+
+ ///
+ /// This private interface allows pass-through access directly to member data
+ ///
+ private interface IAmbientDataDirectAccess
+ {
+ Dictionary Data { get; }
+ }
+
+ ///
+ /// This class encapsulated the concept of configuration that is set on the stack and
+ /// flows across multiple threads as part of the logical call context
+ ///
+ private sealed class StackConfiguration : IStackSettingsContext
+ {
+ private readonly AmbientData _data;
+
+ public StackConfiguration()
+ {
+ _data = new AmbientData();
+ //CallContext.LogicalSetData(LogicalContextName, _data);
+ }
+
+ public AmbientData Settings
+ {
+ get { return _data; }
+ }
+
+ public void Dispose()
+ {
+ Dispose(true);
+ }
+ private void Dispose(bool disposing)
+ {
+ //CallContext.LogicalSetData(LogicalContextName, null);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs
new file mode 100644
index 00000000..2a495a45
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs
@@ -0,0 +1,137 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections.Concurrent;
+using System.Data;
+using System.Data.SqlClient;
+using System.Linq;
+using Microsoft.SqlTools.EditorServices.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// This class caches server information for subsequent use
+ ///
+ internal static class CachedServerInfo
+ {
+ private struct CachedInfo
+ {
+ public bool IsAzure;
+ public DateTime LastUpdate;
+ }
+
+ private static ConcurrentDictionary _cache;
+ private static object _cacheLock;
+ private const int _maxCacheSize = 1024;
+ private const int _deleteBatchSize = 512;
+
+ private const int MinimalQueryTimeoutSecondsForAzure = 300;
+
+ static CachedServerInfo()
+ {
+ _cache = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase);
+ _cacheLock = new object();
+ }
+
+ public static int GetQueryTimeoutSeconds(IDbConnection connection)
+ {
+ string dataSource = SafeGetDataSourceFromConnection(connection);
+ return GetQueryTimeoutSeconds(dataSource);
+ }
+
+ public static int GetQueryTimeoutSeconds(string dataSource)
+ {
+ //keep existing behavior and return the default ambient settings
+ //if the provided data source is null or whitespace, or the original
+ //setting is already 0 which means no limit.
+ int originalValue = AmbientSettings.QueryTimeoutSeconds;
+ if (string.IsNullOrWhiteSpace(dataSource)
+ || (originalValue == 0))
+ {
+ return originalValue;
+ }
+
+ CachedInfo info;
+ bool hasFound = _cache.TryGetValue(dataSource, out info);
+
+ if (hasFound && info.IsAzure
+ && originalValue < MinimalQueryTimeoutSecondsForAzure)
+ {
+ return MinimalQueryTimeoutSecondsForAzure;
+ }
+ else
+ {
+ return originalValue;
+ }
+ }
+
+ public static void AddOrUpdateIsAzure(IDbConnection connection, bool isAzure)
+ {
+ Validate.IsNotNull(nameof(connection), connection);
+
+ SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connection.ConnectionString);
+ AddOrUpdateIsAzure(builder.DataSource, isAzure);
+ }
+
+ public static void AddOrUpdateIsAzure(string dataSource, bool isAzure)
+ {
+ Validate.IsNotNullOrWhitespaceString(nameof(dataSource), dataSource);
+ CachedInfo info;
+ bool hasFound = _cache.TryGetValue(dataSource, out info);
+
+ if (hasFound && info.IsAzure == isAzure)
+ {
+ return;
+ }
+ else
+ {
+ lock (_cacheLock)
+ {
+ if (! _cache.ContainsKey(dataSource))
+ {
+ //delete a batch of old elements when we try to add a new one and
+ //the capacity limitation is hit
+ if (_cache.Keys.Count > _maxCacheSize - 1)
+ {
+ var keysToDelete = _cache
+ .OrderBy(x => x.Value.LastUpdate)
+ .Take(_deleteBatchSize)
+ .Select(pair => pair.Key);
+
+ foreach (string key in keysToDelete)
+ {
+ _cache.TryRemove(key, out info);
+ }
+ }
+ }
+
+ info.IsAzure = isAzure;
+ info.LastUpdate = DateTime.UtcNow;
+ _cache.AddOrUpdate(dataSource, info, (key, oldValue) => info);
+ }
+ }
+ }
+
+ private static string SafeGetDataSourceFromConnection(IDbConnection connection)
+ {
+ if (connection == null)
+ {
+ return null;
+ }
+
+ try
+ {
+ SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connection.ConnectionString);
+ return builder.DataSource;
+ }
+ catch
+ {
+ Logger.Write(LogLevel.Error, String.Format(Resources.FailedToParseConnectionString, connection.ConnectionString));
+ return null;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Constants.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Constants.cs
new file mode 100644
index 00000000..53c0f96b
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Constants.cs
@@ -0,0 +1,17 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// Contains common constants used throughout ReliableConnection code.
+ ///
+ internal static class Constants
+ {
+ internal const int UndefinedErrorCode = 0;
+
+ internal const string Local = "(local)";
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/DataSchemaError.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/DataSchemaError.cs
new file mode 100644
index 00000000..143b29ec
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/DataSchemaError.cs
@@ -0,0 +1,214 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Globalization;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// This class is used to encapsulate all the information needed by the DataSchemaErrorTaskService to create a corresponding entry in the Visual Studio Error List.
+ /// A component should add this Error Object to the for such purpose.
+ /// Errors and their children are expected to be thread-safe. Ideally, this means that
+ /// the objects are just data-transfer-objects initialized during construction.
+ ///
+ [Serializable]
+ internal class DataSchemaError
+ {
+ internal const string DefaultPrefix = "SQL";
+ private const int MaxErrorCode = 99999;
+ protected const int UndefinedErrorCode = 0;
+
+ public DataSchemaError() : this(string.Empty, ErrorSeverity.Unknown)
+ {
+ }
+
+ public DataSchemaError(string message, ErrorSeverity severity)
+ : this(message, string.Empty, severity)
+ {
+ }
+
+ public DataSchemaError(string message, Exception innerException, ErrorSeverity severity)
+ : this(message, innerException, string.Empty, 0, severity)
+ {
+ }
+
+ public DataSchemaError(string message, string document, ErrorSeverity severity)
+ : this(message, document, 0, 0, DefaultPrefix, UndefinedErrorCode, severity)
+ {
+ }
+
+ public DataSchemaError(string message, string document, int errorCode, ErrorSeverity severity)
+ : this(message, document, 0, 0, DefaultPrefix, errorCode, severity)
+ {
+ }
+
+ public DataSchemaError(string message, string document, int line, int column, ErrorSeverity severity)
+ : this(message, document,line, column, DefaultPrefix, UndefinedErrorCode, severity)
+ {
+ }
+
+ public DataSchemaError(DataSchemaError source, ErrorSeverity severity)
+ : this(source.Message, source.Document, source.Line, source.Column, source.Prefix, source.ErrorCode, severity)
+ {
+ }
+
+ public DataSchemaError(
+ Exception exception,
+ string prefix,
+ int errorCode,
+ ErrorSeverity severity)
+ : this(exception, string.Empty, 0, 0, prefix, errorCode, severity)
+ {
+ }
+
+ public DataSchemaError(
+ string message,
+ Exception exception,
+ string prefix,
+ int errorCode,
+ ErrorSeverity severity)
+ : this(message, exception, string.Empty, 0, 0, prefix, errorCode, severity)
+ {
+ }
+
+ public DataSchemaError(
+ Exception exception,
+ string document,
+ int line,
+ int column,
+ string prefix,
+ int errorCode,
+ ErrorSeverity severity)
+ : this(exception.Message, exception, document, line, column, prefix, errorCode, severity)
+ {
+ }
+
+ public DataSchemaError(
+ string message,
+ string document,
+ int line,
+ int column,
+ string prefix,
+ int errorCode,
+ ErrorSeverity severity)
+ : this(message, null, document, line, column, prefix, errorCode, severity)
+ {
+ }
+
+ public DataSchemaError(
+ string message,
+ Exception exception,
+ string document,
+ int line,
+ int column,
+ string prefix,
+ int errorCode,
+ ErrorSeverity severity)
+ {
+ if (errorCode > MaxErrorCode || errorCode < 0)
+ {
+ throw new ArgumentOutOfRangeException("errorCode");
+ }
+
+ Document = document;
+ Severity = severity;
+ Line = line;
+ Column = column;
+ Message = message;
+ Exception = exception;
+
+ ErrorCode = errorCode;
+ Prefix = prefix;
+ IsPriorityEditable = true;
+ }
+
+ ///
+ /// The filename of the error. It corresponds to the File column on the Visual Studio Error List window.
+ ///
+ public string Document { get; set; }
+
+ ///
+ /// The severity of the error
+ ///
+ public ErrorSeverity Severity { get; private set; }
+
+ public int ErrorCode { get; private set; }
+
+ ///
+ /// Line Number of the error
+ ///
+ public int Line { get; set; }
+
+ ///
+ /// Column Number of the error
+ ///
+ public int Column { get; set; }
+
+ ///
+ /// Prefix of the error
+ ///
+ public string Prefix { get; private set; }
+
+ ///
+ /// If the error has any special help topic, this property may hold the ID to the same.
+ ///
+ public string HelpKeyword { get; set; }
+
+ ///
+ /// Exception associated with the error, or null
+ ///
+ public Exception Exception { get; set; }
+
+ ///
+ /// Message
+ ///
+ public string Message { get; set; }
+
+ ///
+ /// Should this message honor the "treat warnings as error" flag?
+ ///
+ public Boolean IsPriorityEditable { get; set; }
+
+ ///
+ /// Represents the error code used in MSBuild output. This is the prefix and the
+ /// error code
+ ///
+ ///
+ public string BuildErrorCode
+ {
+ get { return FormatErrorCode(Prefix, ErrorCode); }
+ }
+
+ internal Boolean IsBuildErrorCodeDefined
+ {
+ get { return (ErrorCode != UndefinedErrorCode); }
+ }
+
+ ///
+ /// true if this error is being displayed in ErrorList. More of an Accounting Mechanism to be used internally.
+ ///
+ internal bool IsOnDisplay { get; set; }
+
+ internal static string FormatErrorCode(string prefix, int code)
+ {
+ return string.Format(
+ CultureInfo.InvariantCulture,
+ "{0}{1:d5}",
+ prefix,
+ code);
+ }
+
+ ///
+ /// String form of this error.
+ /// NB: This is for debugging only.
+ ///
+ /// String form of the error.
+ public override string ToString()
+ {
+ return string.Format(CultureInfo.CurrentCulture, "{0} - {1}({2},{3}): {4}", FormatErrorCode(Prefix, ErrorCode), Document, Line, Column, Message);
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/DbCommandWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/DbCommandWrapper.cs
new file mode 100644
index 00000000..8620f461
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/DbCommandWrapper.cs
@@ -0,0 +1,71 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Data;
+using System.Data.SqlClient;
+using Microsoft.SqlTools.EditorServices.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+
+ ///
+ /// Wraps objects that could be a or
+ /// a , providing common methods across both.
+ ///
+ internal sealed class DbCommandWrapper
+ {
+ private readonly IDbCommand _command;
+ private readonly bool _isReliableCommand;
+
+ public DbCommandWrapper(IDbCommand command)
+ {
+ Validate.IsNotNull(nameof(command), command);
+ if (command is ReliableSqlConnection.ReliableSqlCommand)
+ {
+ _isReliableCommand = true;
+ }
+ else if (!(command is SqlCommand))
+ {
+ throw new InvalidOperationException(Resources.InvalidCommandType);
+ }
+ _command = command;
+ }
+
+ public static bool IsSupportedCommand(IDbCommand command)
+ {
+ return command is ReliableSqlConnection.ReliableSqlCommand
+ || command is SqlCommand;
+ }
+
+
+ public event StatementCompletedEventHandler StatementCompleted
+ {
+ add
+ {
+ SqlCommand sqlCommand = GetAsSqlCommand();
+ sqlCommand.StatementCompleted += value;
+ }
+ remove
+ {
+ SqlCommand sqlCommand = GetAsSqlCommand();
+ sqlCommand.StatementCompleted -= value;
+ }
+ }
+
+ ///
+ /// Gets this as a SqlCommand by casting (if we know it is actually a SqlCommand)
+ /// or by getting the underlying command (if it's a ReliableSqlCommand)
+ ///
+ private SqlCommand GetAsSqlCommand()
+ {
+ if (_isReliableCommand)
+ {
+ return ((ReliableSqlConnection.ReliableSqlCommand) _command).GetUnderlyingCommand();
+ }
+ return (SqlCommand) _command;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/DbConnectionWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/DbConnectionWrapper.cs
new file mode 100644
index 00000000..c1f1d437
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/DbConnectionWrapper.cs
@@ -0,0 +1,113 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Data;
+using System.Data.SqlClient;
+using Microsoft.SqlTools.EditorServices.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// Wraps objects that could be a or
+ /// a , providing common methods across both.
+ ///
+ internal sealed class DbConnectionWrapper
+ {
+ private readonly IDbConnection _connection;
+ private readonly bool _isReliableConnection;
+
+ public DbConnectionWrapper(IDbConnection connection)
+ {
+ Validate.IsNotNull(nameof(connection), connection);
+ if (connection is ReliableSqlConnection)
+ {
+ _isReliableConnection = true;
+ }
+ else if (!(connection is SqlConnection))
+ {
+ throw new InvalidOperationException(Resources.InvalidConnectionType);
+ }
+
+ _connection = connection;
+ }
+
+ public static bool IsSupportedConnection(IDbConnection connection)
+ {
+ return connection is ReliableSqlConnection
+ || connection is SqlConnection;
+ }
+
+ public event SqlInfoMessageEventHandler InfoMessage
+ {
+ add
+ {
+ SqlConnection conn = GetAsSqlConnection();
+ conn.InfoMessage += value;
+ }
+ remove
+ {
+ SqlConnection conn = GetAsSqlConnection();
+ conn.InfoMessage -= value;
+ }
+ }
+
+ public string DataSource
+ {
+ get
+ {
+ if (_isReliableConnection)
+ {
+ return ((ReliableSqlConnection) _connection).DataSource;
+ }
+ return ((SqlConnection)_connection).DataSource;
+ }
+ }
+
+ public string ServerVersion
+ {
+ get
+ {
+ if (_isReliableConnection)
+ {
+ return ((ReliableSqlConnection)_connection).ServerVersion;
+ }
+ return ((SqlConnection)_connection).ServerVersion;
+ }
+ }
+
+ ///
+ /// Gets this as a SqlConnection by casting (if we know it is actually a SqlConnection)
+ /// or by getting the underlying connection (if it's a ReliableSqlConnection)
+ ///
+ public SqlConnection GetAsSqlConnection()
+ {
+ if (_isReliableConnection)
+ {
+ return ((ReliableSqlConnection) _connection).GetUnderlyingConnection();
+ }
+ return (SqlConnection) _connection;
+ }
+
+ /*
+ TODO - IClonable does not exist in .NET Core.
+ ///
+ /// Clones the connection and ensures it's opened.
+ /// If it's a SqlConnection it will clone it,
+ /// and for ReliableSqlConnection it will clone the underling connection.
+ /// The reason the entire ReliableSqlConnection is not cloned is that it includes
+ /// several callbacks and we don't want to try and handle deciding how to clone these
+ /// yet.
+ ///
+ public SqlConnection CloneAndOpenConnection()
+ {
+ SqlConnection conn = GetAsSqlConnection();
+ SqlConnection clonedConn = ((ICloneable) conn).Clone() as SqlConnection;
+ clonedConn.Open();
+ return clonedConn;
+ }
+ */
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ErrorSeverity.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ErrorSeverity.cs
new file mode 100644
index 00000000..5cb01c6d
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ErrorSeverity.cs
@@ -0,0 +1,15 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ internal enum ErrorSeverity
+ {
+ Unknown = 0,
+ Error,
+ Warning,
+ Message
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/IStackSettingsContext.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/IStackSettingsContext.cs
new file mode 100644
index 00000000..a121ab73
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/IStackSettingsContext.cs
@@ -0,0 +1,19 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// This interface controls the lifetime of settings created as part of the
+ /// top-of-stack API. Changes made to this context's AmbientData instance will
+ /// flow to lower in the stack while this object is not disposed.
+ ///
+ internal interface IStackSettingsContext : IDisposable
+ {
+ AmbientSettings.AmbientData Settings { get; }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs
new file mode 100644
index 00000000..5069fdc9
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs
@@ -0,0 +1,1267 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections.Generic;
+using System.Data;
+using System.Data.SqlClient;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Globalization;
+using System.Security;
+using Microsoft.SqlTools.EditorServices.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ internal static class ReliableConnectionHelper
+ {
+ private const int PCU1BuildNumber = 2816;
+ public readonly static SqlConnectionStringBuilder BuilderWithDefaultApplicationName = new SqlConnectionStringBuilder("server=(local);");
+
+ private const string ServerNameLocalhost = "localhost";
+ private const string SqlProviderName = "System.Data.SqlClient";
+
+ private const string ApplicationIntent = "ApplicationIntent";
+ private const string MultiSubnetFailover = "MultiSubnetFailover";
+ private const string DacFxApplicationName = "DacFx";
+
+ // See MSDN documentation for "SERVERPROPERTY (SQL Azure Database)" for "EngineEdition" property:
+ // http://msdn.microsoft.com/en-us/library/ee336261.aspx
+ private const int SqlAzureEngineEditionId = 5;
+
+ ///
+ /// Opens the connection and sets the lock/command timeout and pooling=false.
+ ///
+ /// The opened connection
+ public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry)
+ {
+ csb.Pooling = false;
+ return OpenConnection(csb.ToString(), useRetry);
+ }
+
+ ///
+ /// Opens the connection and sets the lock/command timeout. This routine
+ /// will assert if pooling!=false.
+ ///
+ /// The opened connection
+ public static IDbConnection OpenConnection(string connectionString, bool useRetry)
+ {
+#if DEBUG
+ try
+ {
+ SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connectionString);
+ Debug.Assert(builder.Pooling == false, "Pooling should be false");
+ }
+ catch (Exception ex)
+ {
+ Debug.Assert(false, "Invalid connectionstring: " + ex.Message);
+ }
+#endif
+
+ if (AmbientSettings.AlwaysRetryOnTransientFailure == true)
+ {
+ useRetry = true;
+ }
+
+ RetryPolicy commandRetryPolicy, connectionRetryPolicy;
+ if (useRetry)
+ {
+ commandRetryPolicy = RetryPolicyFactory.CreateDefaultSchemaCommandRetryPolicy(useRetry: true);
+ connectionRetryPolicy = RetryPolicyFactory.CreateDefaultSchemaConnectionRetryPolicy();
+ }
+ else
+ {
+ commandRetryPolicy = RetryPolicyFactory.CreateNoRetryPolicy();
+ connectionRetryPolicy = RetryPolicyFactory.CreateNoRetryPolicy();
+ }
+
+ ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy);
+
+ try
+ {
+ connection.Open();
+ }
+ catch (Exception ex)
+ {
+
+ string debugMessage = String.Format(CultureInfo.CurrentCulture,
+ "Opening connection using connection string '{0}' failed with exception: {1}", connectionString, ex.Message);
+#if DEBUG
+ Debug.WriteLine(debugMessage);
+#endif
+ connection.Dispose();
+ throw;
+ }
+
+ return connection;
+ }
+
+ ///
+ /// Opens the connection (if it is not already) and sets
+ /// the lock/command timeout.
+ ///
+ /// The connection to open
+ public static void OpenConnection(IDbConnection conn)
+ {
+ if (conn.State == ConnectionState.Closed)
+ {
+ conn.Open();
+ }
+ }
+
+ ///
+ /// Opens a connection using 'csb' as the connection string. Provide
+ /// 'usingConnection' to execute T-SQL against the open connection and
+ /// 'catchException' to handle errors.
+ ///
+ /// The connection string used when opening the IDbConnection
+ /// delegate called when the IDbConnection has been successfully opened
+ /// delegate called when an exception has occurred. Pass back 'true' to handle the
+ /// exception, 'false' to throw. If Null is passed in then all exceptions are thrown.
+ /// Should retry logic be used when opening the connection
+ public static void OpenConnection(
+ SqlConnectionStringBuilder csb,
+ Action usingConnection,
+ Predicate catchException,
+ bool useRetry)
+ {
+ Validate.IsNotNull(nameof(csb), csb);
+ Validate.IsNotNull(nameof(usingConnection), usingConnection);
+
+ try
+ {
+ // Always disable pooling
+ csb.Pooling = false;
+ using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry))
+ {
+ usingConnection(conn);
+ }
+ }
+ catch (Exception ex)
+ {
+ if (catchException == null || !catchException(ex))
+ {
+ throw;
+ }
+ }
+ }
+
+ /*
+ TODO - re-enable if we port ConnectionStringSecurer
+ ///
+ /// This method provides the provides a connection string configured with the specified database name.
+ /// This is also an opportunity to decrypt the connection string based on the encryption/decryption strategy.
+ /// InvalidConnectionStringException could be thrown since this routine attempts to restore the connection
+ /// string if 'restoreConnectionString' is true.
+ ///
+ /// Will only set DatabaseName/ApplicationName if the value is not null.
+ ///
+ ///
+ public static SqlConnectionStringBuilder ConfigureConnectionString(
+ string connectionString,
+ string databaseName,
+ string applicationName,
+ bool restoreConnectionString = true)
+ {
+ if (restoreConnectionString)
+ {
+ // Read the connection string through the persistence layer
+ connectionString = ConnectionStringSecurer.RestoreConnectionString(
+ connectionString,
+ SqlProviderName);
+ }
+
+ SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connectionString);
+
+ builder.Pooling = false;
+ builder.MultipleActiveResultSets = false;
+
+ // Cannot set the applicationName/initialCatalog to null but empty string is valid
+ if (databaseName != null)
+ {
+ builder.InitialCatalog = databaseName;
+ }
+
+ if (applicationName != null)
+ {
+ builder.ApplicationName = applicationName;
+ }
+
+ return builder;
+ }
+ */
+
+ ///
+ /// Optional 'initializeConnection' routine. This sets the lock and command timeout for the connection.
+ ///
+ public static void SetLockAndCommandTimeout(IDbConnection conn)
+ {
+ ReliableSqlConnection.SetLockAndCommandTimeout(conn);
+ }
+
+ ///
+ /// Opens a IDbConnection, creates a IDbCommand and calls ExecuteNonQuery against the connection.
+ ///
+ /// The connection string.
+ /// The scalar T-SQL command.
+ /// Optional delegate to initialize the IDbCommand before execution.
+ /// Default is SqlConnectionHelper.SetCommandTimeout
+ /// delegate called when an exception has occurred. Pass back 'true' to handle the
+ /// exception, 'false' to throw. If Null is passed in then all exceptions are thrown.
+ /// Should a retry policy be used when calling ExecuteNonQuery
+ /// The number of rows affected
+ public static object ExecuteNonQuery(
+ SqlConnectionStringBuilder csb,
+ string commandText,
+ Action initializeCommand,
+ Predicate catchException,
+ bool useRetry)
+ {
+ object retObject = null;
+ OpenConnection(
+ csb,
+ (connection) =>
+ {
+ retObject = ExecuteNonQuery(connection, commandText, initializeCommand, catchException);
+ },
+ catchException,
+ useRetry);
+
+ return retObject;
+ }
+
+ ///
+ /// Creates a IDbCommand and calls ExecuteNonQuery against the connection.
+ ///
+ /// The connection. This must be opened.
+ /// The scalar T-SQL command.
+ /// Optional delegate to initialize the IDbCommand before execution.
+ /// Default is SqlConnectionHelper.SetCommandTimeout
+ /// Optional exception handling. Pass back 'true' to handle the
+ /// exception, 'false' to throw. If Null is passed in then all exceptions are thrown.
+ /// The number of rows affected
+ [SuppressMessage("Microsoft.Security", "CA2100:Review SQL queries for security vulnerabilities")]
+ public static object ExecuteNonQuery(
+ IDbConnection conn,
+ string commandText,
+ Action initializeCommand,
+ Predicate catchException)
+ {
+ Validate.IsNotNull(nameof(conn), conn);
+ Validate.IsNotNullOrEmptyString(nameof(commandText), commandText);
+
+ IDbCommand cmd = null;
+ try
+ {
+
+ Debug.Assert(conn.State == ConnectionState.Open, "connection passed to ExecuteNonQuery should be open.");
+
+ cmd = conn.CreateCommand();
+ if (initializeCommand == null)
+ {
+ initializeCommand = SetCommandTimeout;
+ }
+ initializeCommand(cmd);
+
+ cmd.CommandText = commandText;
+ cmd.CommandType = CommandType.Text;
+
+ return cmd.ExecuteNonQuery();
+ }
+ catch (Exception ex)
+ {
+ if (catchException == null || !catchException(ex))
+ {
+ throw;
+ }
+ }
+ finally
+ {
+ if (cmd != null)
+ {
+ cmd.Dispose();
+ }
+ }
+ return null;
+ }
+
+ ///
+ /// Creates a IDbCommand and calls ExecuteScalar against the connection.
+ ///
+ /// The connection. This must be opened.
+ /// The scalar T-SQL command.
+ /// Optional delegate to initialize the IDbCommand before execution.
+ /// Default is SqlConnectionHelper.SetCommandTimeout
+ /// Optional exception handling. Pass back 'true' to handle the
+ /// exception, 'false' to throw. If Null is passed in then all exceptions are thrown.
+ /// The scalar result
+ [SuppressMessage("Microsoft.Security", "CA2100:Review SQL queries for security vulnerabilities")]
+ public static object ExecuteScalar(
+ IDbConnection conn,
+ string commandText,
+ Action initializeCommand = null,
+ Predicate catchException = null)
+ {
+ Validate.IsNotNull(nameof(conn), conn);
+ Validate.IsNotNullOrEmptyString(nameof(commandText), commandText);
+
+ IDbCommand cmd = null;
+
+ try
+ {
+ Debug.Assert(conn.State == ConnectionState.Open, "connection passed to ExecuteScalar should be open.");
+
+ cmd = conn.CreateCommand();
+ if (initializeCommand == null)
+ {
+ initializeCommand = SetCommandTimeout;
+ }
+ initializeCommand(cmd);
+
+ cmd.CommandText = commandText;
+ cmd.CommandType = CommandType.Text;
+ return cmd.ExecuteScalar();
+ }
+ catch (Exception ex)
+ {
+ if (catchException == null || !catchException(ex))
+ {
+ throw;
+ }
+ }
+ finally
+ {
+ if (cmd != null)
+ {
+ cmd.Dispose();
+ }
+ }
+ return null;
+ }
+
+ ///
+ /// Creates a IDbCommand and calls ExecuteReader against the connection.
+ ///
+ /// The connection to execute the reader on. This must be opened.
+ /// The command text to execute
+ /// A delegate used to read from the reader
+ /// Optional delegate to initialize the IDbCommand object
+ /// Optional exception handling. Pass back 'true' to handle the
+ /// exception, 'false' to throw. If Null is passed in then all exceptions are thrown.
+ [SuppressMessage("Microsoft.Security", "CA2100:Review SQL queries for security vulnerabilities")]
+ public static void ExecuteReader(
+ IDbConnection conn,
+ string commandText,
+ Action readResult,
+ Action initializeCommand = null,
+ Predicate catchException = null)
+ {
+ Validate.IsNotNull(nameof(conn), conn);
+ Validate.IsNotNullOrEmptyString(nameof(commandText), commandText);
+ Validate.IsNotNull(nameof(readResult), readResult);
+
+ IDbCommand cmd = null;
+ try
+ {
+ Debug.Assert(conn.State == ConnectionState.Open, "connection passed to ExecuteReader should be open.");
+
+ cmd = conn.CreateCommand();
+
+ if (initializeCommand == null)
+ {
+ initializeCommand = SetCommandTimeout;
+ }
+
+ initializeCommand(cmd);
+
+ cmd.CommandText = commandText;
+ cmd.CommandType = CommandType.Text;
+ using (IDataReader reader = cmd.ExecuteReader())
+ {
+ readResult(reader);
+ }
+ }
+ catch (Exception ex)
+ {
+ if (catchException == null || !catchException(ex))
+ {
+ throw;
+ }
+ }
+ finally
+ {
+ if (cmd != null)
+ {
+ cmd.Dispose();
+ }
+ }
+ }
+
+ ///
+ /// optional 'initializeCommand' routine. This initializes the IDbCommand
+ ///
+ ///
+ public static void SetCommandTimeout(IDbCommand cmd)
+ {
+ Validate.IsNotNull(nameof(cmd), cmd);
+ cmd.CommandTimeout = CachedServerInfo.GetQueryTimeoutSeconds(cmd.Connection);
+ }
+
+
+ ///
+ /// Return true if the database is an Azure database
+ ///
+ ///
+ ///
+ public static bool IsCloud(IDbConnection connection)
+ {
+ Validate.IsNotNull(nameof(connection), connection);
+ if (!(connection.State == ConnectionState.Open))
+ {
+ Logger.Write(LogLevel.Warning, Resources.ConnectionPassedToIsCloudShouldBeOpen);
+ }
+
+ Func executeCommand = commandText =>
+ {
+ bool result = false;
+ ExecuteReader(connection,
+ commandText,
+ readResult: (reader) =>
+ {
+ reader.Read();
+ int engineEditionId = int.Parse(reader[0].ToString(), CultureInfo.InvariantCulture);
+
+ result = IsCloudEngineId(engineEditionId);
+ }
+ );
+ return result;
+ };
+
+ bool isSqlCloud = false;
+ try
+ {
+ isSqlCloud = executeCommand(SqlConnectionHelperScripts.EngineEdition);
+ }
+ catch (SqlException)
+ {
+ // The default query contains a WITH (NOLOCK). This doesn't work for Azure DW, so when things don't work out,
+ // we'll fall back to a version without NOLOCK and try again.
+ isSqlCloud = executeCommand(SqlConnectionHelperScripts.EngineEditionWithLock);
+ }
+
+ return isSqlCloud;
+ }
+
+ private static bool IsCloudEngineId(int engineEditionId)
+ {
+ return engineEditionId == SqlAzureEngineEditionId;
+ }
+
+ ///
+ /// Handles the exceptions typically thrown when a SQLConnection is being opened
+ ///
+ /// True if the exception was handled
+ public static bool StandardExceptionHandler(Exception ex)
+ {
+ Validate.IsNotNull(nameof(ex), ex);
+
+ if (ex is SqlException ||
+ ex is RetryLimitExceededException)
+ {
+ return true;
+ }
+ if (ex is InvalidCastException ||
+ ex is ArgumentException || // Thrown when a particular connection string property is invalid (i.e. failover parner = "yes")
+ ex is InvalidOperationException || // thrown when the connection pool is empty and SQL is down
+ ex is TimeoutException ||
+ ex is SecurityException)
+ {
+ return true;
+ }
+
+ Logger.Write(LogLevel.Error, ex.ToString());
+ return false;
+ }
+
+ ///
+ /// Returns the default database path.
+ ///
+ /// The connection
+ /// The delegate used to initialize the command
+ /// The exception handler delegate. If Null is passed in then all exceptions are thrown
+ public static string GetDefaultDatabaseFilePath(
+ IDbConnection conn,
+ Action initializeCommand = null,
+ Predicate catchException = null)
+ {
+ Validate.IsNotNull(nameof(conn), conn);
+
+ string filePath = null;
+ ServerInfo info = GetServerVersion(conn);
+
+ if (!info.IsCloud)
+ {
+ filePath = GetDefaultDatabasePath(conn, SqlConnectionHelperScripts.GetDatabaseFilePathAndName, initializeCommand, catchException);
+ }
+
+ return filePath;
+ }
+
+ ///
+ /// Returns the log path or null
+ ///
+ /// The connection
+ /// The delegate used to initialize the command
+ /// The exception handler delegate. If Null is passed in then all exceptions are thrown
+ public static string GetDefaultDatabaseLogPath(
+ IDbConnection conn,
+ Action initializeCommand = null,
+ Predicate catchException = null)
+ {
+ Validate.IsNotNull(nameof(conn), conn);
+
+ string logPath = null;
+ ServerInfo info = GetServerVersion(conn);
+
+ if (!info.IsCloud)
+ {
+ logPath = GetDefaultDatabasePath(conn, SqlConnectionHelperScripts.GetDatabaseLogPathAndName, initializeCommand, catchException);
+ }
+
+ return logPath;
+ }
+
+ ///
+ /// Returns the database path or null
+ ///
+ /// The connection
+ /// The command to issue
+ /// The delegate used to initialize the command
+ /// The exception handler delegate. If Null is passed in then all exceptions are thrown
+ private static string GetDefaultDatabasePath(
+ IDbConnection conn,
+ string commandText,
+ Action initializeCommand = null,
+ Predicate catchException = null)
+ {
+ Validate.IsNotNull(nameof(conn), conn);
+ Validate.IsNotNullOrEmptyString(nameof(commandText), commandText);
+
+ string filePath = ExecuteScalar(conn, commandText, initializeCommand, catchException) as string;
+ if (!String.IsNullOrWhiteSpace(filePath))
+ {
+ // Remove filename from the filePath
+ Uri pathUri;
+ if (Uri.TryCreate(filePath, UriKind.Absolute, out pathUri) == false)
+ {
+ // Invalid Uri
+ return null;
+ }
+
+ // Get a current directory path relative to the pathUri
+ // This will remove filename from the uri.
+ Uri filePathUri = new Uri(pathUri, ".");
+ // For file uri we need to get LocalPath instead of file:// url
+ filePath = filePathUri.IsFile ? filePathUri.LocalPath : filePathUri.OriginalString;
+ }
+ return filePath;
+ }
+
+ ///
+ /// Returns true if the database is readonly. This routine will swallow the exceptions you might expect from SQL using StandardExceptionHandler.
+ ///
+ public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder)
+ {
+ Validate.IsNotNull(nameof(builder), builder);
+
+ if (builder == null)
+ {
+ throw new ArgumentNullException("builder");
+ }
+ bool isDatabaseReadOnly = false;
+
+ OpenConnection(
+ builder,
+ (connection) =>
+ {
+ string commandText = String.Format(CultureInfo.InvariantCulture, SqlConnectionHelperScripts.CheckDatabaseReadonly, builder.InitialCatalog);
+ ExecuteReader(connection,
+ commandText,
+ readResult: (reader) =>
+ {
+ if (reader.Read())
+ {
+ string currentSetting = reader.GetString(1);
+ if (String.Compare(currentSetting, "ON", StringComparison.OrdinalIgnoreCase) == 0)
+ {
+ isDatabaseReadOnly = true;
+ }
+ }
+ });
+ },
+ (ex) =>
+ {
+ Logger.Write(LogLevel.Error, ex.ToString());
+ return StandardExceptionHandler(ex); // handled
+ },
+ useRetry: true);
+
+ return isDatabaseReadOnly;
+ }
+
+ public class ServerInfo
+ {
+ public int ServerMajorVersion;
+ public int ServerMinorVersion;
+ public int ServerReleaseVersion;
+ public int EngineEditionId;
+ public string ServerVersion;
+ public string ServerLevel;
+ public string ServerEdition;
+ public bool IsCloud;
+ public int AzureVersion;
+
+ // In SQL 2012 SP1 Selective XML indexes were added. There is bug where upgraded databases from previous versions
+ // of SQL Server do not have their metadata upgraded to include the xml_index_type column in the sys.xml_indexes view. Because
+ // of this, we must detect the presence of the column to determine if we can query for Selective Xml Indexes
+ public bool IsSelectiveXmlIndexMetadataPresent;
+
+ public bool IsAzureV1
+ {
+ get
+ {
+ return IsCloud && AzureVersion == 1;
+ }
+ }
+
+ public string OsVersion;
+ }
+
+ public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo)
+ {
+ serverInfo = null;
+ SqlConnectionStringBuilder builder;
+ if (!TryGetConnectionStringBuilder(connectionString, out builder))
+ {
+ return false;
+ }
+
+ serverInfo = GetServerVersion(builder);
+ return true;
+ }
+
+ ///
+ /// Returns the version of the server. This routine will throw if an exception is encountered.
+ ///
+ public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb)
+ {
+ Validate.IsNotNull(nameof(csb), csb);
+ ServerInfo serverInfo = null;
+
+ OpenConnection(
+ csb,
+ (connection) =>
+ {
+ serverInfo = GetServerVersion(connection);
+ },
+ catchException: null, // Always throw
+ useRetry: true);
+
+ return serverInfo;
+ }
+
+ ///
+ /// Returns the version of the server. This routine will throw if an exception is encountered.
+ ///
+ public static ServerInfo GetServerVersion(IDbConnection connection)
+ {
+ Validate.IsNotNull(nameof(connection), connection);
+ if (!(connection.State == ConnectionState.Open))
+ {
+ Logger.Write(LogLevel.Error, "connection passed to GetServerVersion should be open.");
+ }
+
+ Func getServerInfo = commandText =>
+ {
+ ServerInfo serverInfo = new ServerInfo();
+ ExecuteReader(
+ connection,
+ commandText,
+ delegate (IDataReader reader)
+ {
+ reader.Read();
+ int engineEditionId = Int32.Parse(reader[0].ToString(), CultureInfo.InvariantCulture);
+
+ serverInfo.EngineEditionId = engineEditionId;
+ serverInfo.IsCloud = IsCloudEngineId(engineEditionId);
+
+ serverInfo.ServerVersion = reader[1].ToString();
+ serverInfo.ServerLevel = reader[2].ToString();
+ serverInfo.ServerEdition = reader[3].ToString();
+
+ if (reader.FieldCount > 4)
+ {
+ // Detect the presence of SXI
+ serverInfo.IsSelectiveXmlIndexMetadataPresent = reader.GetInt32(4) == 1;
+ }
+
+ // The 'ProductVersion' server property is of the form ##.#[#].####.#,
+ Version serverVersion = new Version(serverInfo.ServerVersion);
+
+ // The server version is of the form ##.##.####,
+ serverInfo.ServerMajorVersion = serverVersion.Major;
+ serverInfo.ServerMinorVersion = serverVersion.Minor;
+ serverInfo.ServerReleaseVersion = serverVersion.Build;
+
+ if (serverInfo.IsCloud)
+ {
+ serverInfo.AzureVersion = serverVersion.Major > 11 ? 2 : 1;
+ }
+
+ try
+ {
+ CachedServerInfo.AddOrUpdateIsAzure(connection, serverInfo.IsCloud);
+ }
+ catch (Exception ex)
+ {
+ //we don't want to fail the normal flow if any unexpected thing happens
+ //during caching although it's unlikely. So we just log the exception and ignore it
+ Logger.Write(LogLevel.Error, Resources.FailedToCacheIsCloud);
+ Logger.Write(LogLevel.Error, ex.ToString());
+ }
+ });
+
+ // Also get the OS Version
+ ExecuteReader(
+ connection,
+ SqlConnectionHelperScripts.GetOsVersion,
+ delegate (IDataReader reader)
+ {
+ reader.Read();
+ serverInfo.OsVersion = reader[0].ToString();
+ });
+
+ return serverInfo;
+ };
+
+ ServerInfo result = null;
+ try
+ {
+ result = getServerInfo(SqlConnectionHelperScripts.EngineEdition);
+ }
+ catch (SqlException)
+ {
+ // The default query contains a WITH (NOLOCK). This doesn't work for Azure DW, so when things don't work out,
+ // we'll fall back to a version without NOLOCK and try again.
+ result = getServerInfo(SqlConnectionHelperScripts.EngineEditionWithLock);
+ }
+
+ return result;
+ }
+
+ public static string GetServerName(IDbConnection connection)
+ {
+ return new DbConnectionWrapper(connection).DataSource;
+ }
+
+ public static string ReadServerVersion(IDbConnection connection)
+ {
+ return new DbConnectionWrapper(connection).ServerVersion;
+ }
+
+ ///
+ /// Converts to a SqlConnection by casting (if we know it is actually a SqlConnection)
+ /// or by getting the underlying connection (if it's a ReliableSqlConnection)
+ ///
+ public static SqlConnection GetAsSqlConnection(IDbConnection connection)
+ {
+ return new DbConnectionWrapper(connection).GetAsSqlConnection();
+ }
+
+ /* TODO - CloneAndOpenConnection() requires IClonable, which doesn't exist in .NET Core
+ ///
+ /// Clones a connection and ensures it's opened.
+ /// If it's a SqlConnection it will clone it,
+ /// and for ReliableSqlConnection it will clone the underling connection.
+ /// The reason the entire ReliableSqlConnection is not cloned is that it includes
+ /// several callbacks and we don't want to try and handle deciding how to clone these
+ /// yet.
+ ///
+ public static SqlConnection CloneAndOpenConnection(IDbConnection connection)
+ {
+ return new DbConnectionWrapper(connection).CloneAndOpenConnection();
+ }
+ */
+
+ public class ServerAndDatabaseInfo : ServerInfo
+ {
+ public int DbCompatibilityLevel;
+ public string DatabaseName;
+ }
+
+ private static bool TryGetConnectionStringBuilder(string connectionString, out SqlConnectionStringBuilder builder)
+ {
+ builder = null;
+
+ if (String.IsNullOrEmpty(connectionString))
+ {
+ // Connection string is not valid
+ return false;
+ }
+
+ // Attempt to initialize the builder
+ Exception handledEx = null;
+ try
+ {
+ builder = new SqlConnectionStringBuilder(connectionString);
+ }
+ catch (KeyNotFoundException ex)
+ {
+ handledEx = ex;
+ }
+ catch (FormatException ex)
+ {
+ handledEx = ex;
+ }
+ catch (ArgumentException ex)
+ {
+ handledEx = ex;
+ }
+
+ if (handledEx != null)
+ {
+ Logger.Write(LogLevel.Error, String.Format(Resources.ErrorParsingConnectionString, handledEx));
+ return false;
+ }
+
+ return true;
+ }
+
+ /*
+ ///
+ /// Get the version of the server and database using
+ /// the connection string provided. This routine will
+ /// throw if an exception is encountered.
+ ///
+ /// The connection string used to connect to the database.
+ /// Basic information about the server
+ public static bool GetServerAndDatabaseVersion(string connectionString, out ServerAndDatabaseInfo info)
+ {
+ bool foundVersion = false;
+ info = new ServerAndDatabaseInfo { IsCloud = false, ServerMajorVersion = -1, DbCompatibilityLevel = -1, DatabaseName = String.Empty };
+
+ SqlConnectionStringBuilder builder;
+ if (!TryGetConnectionStringBuilder(connectionString, out builder))
+ {
+ return false;
+ }
+
+ // The database name is either the InitialCatalog or the AttachDBFilename. The
+ // AttachDBFilename is used if an mdf file is specified in the connections dialog.
+ if (String.IsNullOrEmpty(builder.InitialCatalog) ||
+ String.IsNullOrEmpty(builder.AttachDBFilename))
+ {
+ builder.Pooling = false;
+
+ string tempDatabaseName = String.Empty;
+ int tempDbCompatibilityLevel = 0;
+ ServerInfo serverInfo = null;
+
+ OpenConnection(
+ builder,
+ (connection) =>
+ {
+ // Set the lock timeout to 3 seconds
+ SetLockAndCommandTimeout(connection);
+
+ serverInfo = GetServerVersion(connection);
+
+ tempDatabaseName = (String.IsNullOrEmpty(builder.InitialCatalog) == false) ?
+ builder.InitialCatalog : builder.AttachDBFilename;
+
+ // If at this point the dbName remained an empty string then
+ // we should get the database name from the open IDbConnection
+ if (String.IsNullOrEmpty(tempDatabaseName))
+ {
+ tempDatabaseName = connection.Database;
+ }
+
+ // SQL Azure does not support custom DBCompat values.
+ if (!serverInfo.IsAzureV1)
+ {
+ SqlParameter databaseNameParameter = new SqlParameter(
+ "@dbname",
+ SqlDbType.NChar,
+ 128,
+ ParameterDirection.Input,
+ false,
+ 0,
+ 0,
+ null,
+ DataRowVersion.Default,
+ tempDatabaseName);
+
+ object compatibilityLevel;
+
+ using (IDbCommand versionCommand = connection.CreateCommand())
+ {
+ versionCommand.CommandText = "SELECT compatibility_level FROM sys.databases WITH (NOLOCK) WHERE name = @dbname";
+ versionCommand.CommandType = CommandType.Text;
+ versionCommand.Parameters.Add(databaseNameParameter);
+ compatibilityLevel = versionCommand.ExecuteScalar();
+ }
+
+ // value is null if db is not online
+ foundVersion = compatibilityLevel != null && !(compatibilityLevel is DBNull);
+ if(foundVersion)
+ {
+ tempDbCompatibilityLevel = (byte)compatibilityLevel;
+ }
+ else
+ {
+ string conString = connection.ConnectionString == null ? "null" : connection.ConnectionString;
+ string dbName = tempDatabaseName == null ? "null" : tempDatabaseName;
+ string message = string.Format(CultureInfo.CurrentCulture,
+ "Querying database compatibility level failed. Connection string: '{0}'. dbname: '{1}'.",
+ conString, dbName);
+ Tracer.TraceEvent(TraceEventType.Error, TraceId.CoreServices, message);
+ }
+ }
+ else
+ {
+ foundVersion = true;
+ }
+ },
+ catchException: null, // Always throw
+ useRetry: true);
+
+ info.IsCloud = serverInfo.IsCloud;
+ info.ServerMajorVersion = serverInfo.ServerMajorVersion;
+ info.ServerMinorVersion = serverInfo.ServerMinorVersion;
+ info.ServerReleaseVersion = serverInfo.ServerReleaseVersion;
+ info.ServerVersion = serverInfo.ServerVersion;
+ info.ServerLevel = serverInfo.ServerLevel;
+ info.ServerEdition = serverInfo.ServerEdition;
+ info.AzureVersion = serverInfo.AzureVersion;
+ info.DatabaseName = tempDatabaseName;
+ info.DbCompatibilityLevel = tempDbCompatibilityLevel;
+ }
+
+ return foundVersion;
+ }
+ */
+
+ ///
+ /// Returns true if the authenticating database is master, otherwise false. An example of
+ /// false is when the user is a contained user connecting to a contained database.
+ ///
+ public static bool IsAuthenticatingDatabaseMaster(IDbConnection connection)
+ {
+ try
+ {
+ const string sqlCommand =
+ @"use [{0}];
+ if (db_id() = 1)
+ begin
+ -- contained auth is 0 when connected to master
+ select 0
+ end
+ else
+ begin
+ -- need dynamic sql so that we compile this query only when we know resource db is available
+ exec('select case when authenticating_database_id = 1 then 0 else 1 end from sys.dm_exec_sessions where session_id = @@SPID')
+ end";
+
+ string finalCmd = null;
+ if (!String.IsNullOrWhiteSpace(connection.Database))
+ {
+ finalCmd = String.Format(CultureInfo.InvariantCulture, sqlCommand, connection.Database);
+ }
+ else
+ {
+ finalCmd = String.Format(CultureInfo.InvariantCulture, sqlCommand, "master");
+ }
+
+ object retValue = ExecuteScalar(connection, finalCmd);
+ if (retValue != null && retValue.ToString() == "1")
+ {
+ // contained auth is 0 when connected to non-master
+ return false;
+ }
+ return true;
+ }
+ catch (Exception ex)
+ {
+ if (StandardExceptionHandler(ex))
+ {
+ return true;
+ }
+ throw;
+ }
+ }
+
+ ///
+ /// Returns true if the authenticating database is master, otherwise false. An example of
+ /// false is when the user is a contained user connecting to a contained database.
+ ///
+ public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder)
+ {
+ bool authIsMaster = true;
+ OpenConnection(
+ builder,
+ usingConnection: (connection) =>
+ {
+ authIsMaster = IsAuthenticatingDatabaseMaster(connection);
+ },
+ catchException: StandardExceptionHandler, // Don't throw unless it's an unexpected exception
+ useRetry: true);
+ return authIsMaster;
+ }
+
+ ///
+ /// Returns the form of the server as a it's name - replaces . and (localhost)
+ ///
+ public static string GetCompleteServerName(string server)
+ {
+ if (String.IsNullOrEmpty(server))
+ {
+ return server;
+ }
+
+ int nlen = 0;
+ if (server[0] == '.')
+ {
+ nlen = 1;
+ }
+ else if (String.Compare(server, Constants.Local, StringComparison.OrdinalIgnoreCase) == 0)
+ {
+ nlen = Constants.Local.Length;
+ }
+ else if (String.Compare(server, 0, ServerNameLocalhost, 0, ServerNameLocalhost.Length, StringComparison.OrdinalIgnoreCase) == 0)
+ {
+ nlen = ServerNameLocalhost.Length;
+ }
+
+ if (nlen > 0)
+ {
+ string strMachine = Environment.MachineName;
+ if (server.Length == nlen)
+ return strMachine;
+ if (server.Length > (nlen + 1) && server[nlen] == '\\') // instance
+ {
+ string strRet = strMachine + server.Substring(nlen);
+ return strRet;
+ }
+ }
+
+ return server;
+ }
+
+ /*
+ ///
+ /// Processes a user-supplied connection string and provides a trimmed connection string
+ /// that eliminates everything except for DataSource, InitialCatalog, UserId, Password,
+ /// ConnectTimeout, Encrypt, TrustServerCertificate and IntegratedSecurity.
+ ///
+ /// When connection string is invalid
+ public static string TrimConnectionString(string connectionString)
+ {
+ Exception handledException;
+
+ try
+ {
+ SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connectionString);
+ return TrimConnectionStringBuilder(scsb).ConnectionString;
+ }
+ catch (ArgumentException exception)
+ {
+ handledException = exception;
+ }
+ catch (KeyNotFoundException exception)
+ {
+ handledException = exception;
+ }
+ catch (FormatException exception)
+ {
+ handledException = exception;
+ }
+
+ throw new InvalidConnectionStringException(handledException);
+ }
+ */
+
+ ///
+ /// Sql 2012 PCU1 introduces breaking changes to metadata queries and adds new Selective XML Index support.
+ /// This method allows components to detect if the represents a build of SQL 2012 after RTM.
+ ///
+ public static bool IsVersionGreaterThan2012RTM(ServerInfo _serverInfo)
+ {
+ return _serverInfo.ServerMajorVersion > 11 ||
+ // Use the presence of SXI metadata rather than build number as upgrade bugs leave out the SXI metadata for some upgraded databases.
+ _serverInfo.ServerMajorVersion == 11 && _serverInfo.IsSelectiveXmlIndexMetadataPresent;
+ }
+
+
+ // SQL Server: Defect 1122301: ReliableConnectionHelper does not maintain ApplicationIntent
+ // The ApplicationIntent and MultiSubnetFailover property is not introduced to .NET until .NET 4.0 update 2
+ // However, DacFx is officially depends on .NET 4.0 RTM
+ // So here we want to support both senarios, on machine with 4.0 RTM installed, it will ignore these 2 properties
+ // On machine with higher .NET version which included those properties, it will pick them up.
+ public static void TryAddAlwaysOnConnectionProperties(SqlConnectionStringBuilder userBuilder, SqlConnectionStringBuilder trimBuilder)
+ {
+ if (userBuilder.ContainsKey(ApplicationIntent))
+ {
+ trimBuilder[ApplicationIntent] = userBuilder[ApplicationIntent];
+ }
+
+ if (userBuilder.ContainsKey(MultiSubnetFailover))
+ {
+ trimBuilder[MultiSubnetFailover] = userBuilder[MultiSubnetFailover];
+ }
+ }
+
+ /* TODO - this relies on porting SqlAuthenticationMethodUtils
+ ///
+ /// Processes a user-supplied connection string and provides a trimmed connection string
+ /// that eliminates everything except for DataSource, InitialCatalog, UserId, Password,
+ /// ConnectTimeout, Encrypt, TrustServerCertificate, IntegratedSecurity and Pooling.
+ ///
+ ///
+ /// Pooling is always set to false to avoid connections remaining open.
+ ///
+ /// When connection string is invalid
+ public static SqlConnectionStringBuilder TrimConnectionStringBuilder(SqlConnectionStringBuilder userBuilder, Action throwException = null)
+ {
+
+ Exception handledException;
+
+ if (throwException == null)
+ {
+ throwException = (propertyName) =>
+ {
+ throw new InvalidConnectionStringException(String.Format(CultureInfo.CurrentCulture, Resources.UnsupportedConnectionStringArgument, propertyName));
+ };
+ }
+ if (!String.IsNullOrEmpty(userBuilder.AttachDBFilename))
+ {
+ throwException("AttachDBFilename");
+ }
+ if (userBuilder.UserInstance)
+ {
+ throwException("User Instance");
+ }
+
+ try
+ {
+ SqlConnectionStringBuilder trimBuilder = new SqlConnectionStringBuilder();
+
+ if (String.IsNullOrWhiteSpace(userBuilder.DataSource))
+ {
+ throw new InvalidConnectionStringException();
+ }
+
+ trimBuilder.ConnectTimeout = userBuilder.ConnectTimeout;
+ trimBuilder.DataSource = userBuilder.DataSource;
+
+ if (false == String.IsNullOrWhiteSpace(userBuilder.InitialCatalog))
+ {
+ trimBuilder.InitialCatalog = userBuilder.InitialCatalog;
+ }
+
+ trimBuilder.IntegratedSecurity = userBuilder.IntegratedSecurity;
+
+ if (!String.IsNullOrWhiteSpace(userBuilder.UserID))
+ {
+ trimBuilder.UserID = userBuilder.UserID;
+ }
+
+ if (!String.IsNullOrWhiteSpace(userBuilder.Password))
+ {
+ trimBuilder.Password = userBuilder.Password;
+ }
+
+ trimBuilder.TrustServerCertificate = userBuilder.TrustServerCertificate;
+ trimBuilder.Encrypt = userBuilder.Encrypt;
+
+ if (String.IsNullOrWhiteSpace(userBuilder.ApplicationName) ||
+ String.Equals(BuilderWithDefaultApplicationName.ApplicationName, userBuilder.ApplicationName, StringComparison.Ordinal))
+ {
+ trimBuilder.ApplicationName = DacFxApplicationName;
+ }
+ else
+ {
+ trimBuilder.ApplicationName = userBuilder.ApplicationName;
+ }
+
+ TryAddAlwaysOnConnectionProperties(userBuilder, trimBuilder);
+
+ if (SqlAuthenticationMethodUtils.IsAuthenticationSupported())
+ {
+ SqlAuthenticationMethodUtils.SetAuthentication(userBuilder, trimBuilder);
+ }
+
+ if (SqlAuthenticationMethodUtils.IsCertificateSupported())
+ {
+ SqlAuthenticationMethodUtils.SetCertificate(userBuilder, trimBuilder);
+ }
+
+ trimBuilder.Pooling = false;
+ return trimBuilder;
+ }
+ catch (ArgumentException exception)
+ {
+ handledException = exception;
+ }
+ catch (KeyNotFoundException exception)
+ {
+ handledException = exception;
+ }
+ catch (FormatException exception)
+ {
+ handledException = exception;
+ }
+
+ throw new InvalidConnectionStringException(handledException);
+ }
+
+ public static bool TryCreateConnectionStringBuilder(string connectionString, out SqlConnectionStringBuilder builder, out Exception handledException)
+ {
+ bool success = false;
+ builder = null;
+ handledException = null;
+ try
+ {
+ builder = TrimConnectionStringBuilder(new SqlConnectionStringBuilder(connectionString));
+
+ success = true;
+ }
+ catch (InvalidConnectionStringException e)
+ {
+ handledException = e;
+ }
+ catch (ArgumentException exception)
+ {
+ handledException = exception;
+ }
+ catch (KeyNotFoundException exception)
+ {
+ handledException = exception;
+ }
+ catch (FormatException exception)
+ {
+ handledException = exception;
+ }
+ finally
+ {
+ if (handledException != null)
+ {
+ success = false;
+ }
+ }
+ return success;
+ }
+ */
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlCommand.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlCommand.cs
new file mode 100644
index 00000000..4f051688
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlCommand.cs
@@ -0,0 +1,247 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+// This code is copied from the source described in the comment below.
+
+// =======================================================================================
+// Microsoft Windows Server AppFabric Customer Advisory Team (CAT) Best Practices Series
+//
+// This sample is supplemental to the technical guidance published on the community
+// blog at http://blogs.msdn.com/appfabriccat/ and copied from
+// sqlmain ./sql/manageability/mfx/common/
+//
+// =======================================================================================
+// Copyright © 2012 Microsoft Corporation. All rights reserved.
+//
+// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER
+// EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. YOU BEAR THE RISK OF USING IT.
+// =======================================================================================
+
+// namespace Microsoft.AppFabricCAT.Samples.Azure.TransientFaultHandling.SqlAzure
+// namespace Microsoft.SqlServer.Management.Common
+
+using System;
+using System.Data;
+using System.Data.Common;
+using System.Data.SqlClient;
+using System.Diagnostics.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// Provides a reliable way of opening connections to and executing commands
+ /// taking into account potential network unreliability and a requirement for connection retry.
+ ///
+ internal sealed partial class ReliableSqlConnection
+ {
+ internal class ReliableSqlCommand : DbCommand
+ {
+ private const int Dummy = 0;
+ private readonly SqlCommand _command;
+
+ // connection is settable
+ private ReliableSqlConnection _connection;
+
+ public ReliableSqlCommand()
+ : this(null, Dummy)
+ {
+ }
+
+ public ReliableSqlCommand(ReliableSqlConnection connection)
+ : this(connection, Dummy)
+ {
+ Contract.Requires(connection != null);
+ }
+
+ private ReliableSqlCommand(ReliableSqlConnection connection, int dummy)
+ {
+ if (connection != null)
+ {
+ _connection = connection;
+ _command = connection.CreateSqlCommand();
+ }
+ else
+ {
+ _command = new SqlCommand();
+ }
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ _command.Dispose();
+ }
+ }
+
+ ///
+ /// Gets or sets the text command to run against the data source.
+ ///
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Security", "CA2100:Review SQL queries for security vulnerabilities")]
+ public override string CommandText
+ {
+ get { return _command.CommandText; }
+ set { _command.CommandText = value; }
+ }
+
+ ///
+ /// Gets or sets the wait time before terminating the attempt to execute a command and generating an error.
+ ///
+ public override int CommandTimeout
+ {
+ get { return _command.CommandTimeout; }
+ set { _command.CommandTimeout = value; }
+ }
+
+ ///
+ /// Gets or sets a value that specifies how the property is interpreted.
+ ///
+ public override CommandType CommandType
+ {
+ get { return _command.CommandType; }
+ set { _command.CommandType = value; }
+ }
+
+ ///
+ /// Gets or sets the used by this .
+ ///
+ protected override DbConnection DbConnection
+ {
+ get
+ {
+ return _connection;
+ }
+
+ set
+ {
+ if (value == null)
+ {
+ throw new ArgumentNullException("value");
+ }
+
+ ReliableSqlConnection newConnection = value as ReliableSqlConnection;
+
+ if (newConnection == null)
+ {
+ throw new InvalidOperationException(Resources.OnlyReliableConnectionSupported);
+ }
+
+ _connection = newConnection;
+ _command.Connection = _connection._underlyingConnection;
+ }
+ }
+
+ ///
+ /// Gets the .
+ ///
+ protected override DbParameterCollection DbParameterCollection
+ {
+ get { return _command.Parameters; }
+ }
+
+ ///
+ /// Gets or sets the transaction within which the Command object of a .NET Framework data provider executes.
+ ///
+ protected override DbTransaction DbTransaction
+ {
+ get { return _command.Transaction; }
+ set { _command.Transaction = value as SqlTransaction; }
+ }
+
+ ///
+ /// Gets or sets a value indicating whether the command object should be visible in a customized interface control.
+ ///
+ public override bool DesignTimeVisible
+ {
+ get { return _command.DesignTimeVisible; }
+ set { _command.DesignTimeVisible = value; }
+ }
+
+ ///
+ /// Gets or sets how command results are applied to the System.Data.DataRow when
+ /// used by the System.Data.IDataAdapter.Update(System.Data.DataSet) method of
+ /// a .
+ ///
+ public override UpdateRowSource UpdatedRowSource
+ {
+ get { return _command.UpdatedRowSource; }
+ set { _command.UpdatedRowSource = value; }
+ }
+
+ ///
+ /// Attempts to cancels the execution of an .
+ ///
+ public override void Cancel()
+ {
+ _command.Cancel();
+ }
+
+ ///
+ /// Creates a new instance of an object.
+ ///
+ /// An object.
+ protected override DbParameter CreateDbParameter()
+ {
+ return _command.CreateParameter();
+ }
+
+ ///
+ /// Executes an SQL statement against the Connection object of a .NET Framework
+ /// data provider, and returns the number of rows affected.
+ ///
+ /// The number of rows affected.
+ public override int ExecuteNonQuery()
+ {
+ ValidateConnectionIsSet();
+ return _connection.ExecuteNonQuery(_command);
+ }
+
+ ///
+ /// Executes the against the
+ /// and builds an using one of the values.
+ ///
+ /// One of the values.
+ /// An object.
+ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)
+ {
+ ValidateConnectionIsSet();
+ return (DbDataReader)_connection.ExecuteReader(_command, behavior);
+ }
+
+ ///
+ /// Executes the query, and returns the first column of the first row in the
+ /// resultset returned by the query. Extra columns or rows are ignored.
+ ///
+ /// The first column of the first row in the resultset.
+ public override object ExecuteScalar()
+ {
+ ValidateConnectionIsSet();
+ return _connection.ExecuteScalar(_command);
+ }
+
+ ///
+ /// Creates a prepared (or compiled) version of the command on the data source.
+ ///
+ public override void Prepare()
+ {
+ _command.Prepare();
+ }
+
+ internal SqlCommand GetUnderlyingCommand()
+ {
+ return _command;
+ }
+
+ private void ValidateConnectionIsSet()
+ {
+ if (_connection == null)
+ {
+ throw new InvalidOperationException(Resources.ConnectionPropertyNotSet);
+ }
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs
new file mode 100644
index 00000000..b949cd54
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs
@@ -0,0 +1,548 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+// This code is copied from the source described in the comment below.
+
+// =======================================================================================
+// Microsoft Windows Server AppFabric Customer Advisory Team (CAT) Best Practices Series
+//
+// This sample is supplemental to the technical guidance published on the community
+// blog at http://blogs.msdn.com/appfabriccat/ and copied from
+// sqlmain ./sql/manageability/mfx/common/
+//
+// =======================================================================================
+// Copyright © 2012 Microsoft Corporation. All rights reserved.
+//
+// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER
+// EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. YOU BEAR THE RISK OF USING IT.
+// =======================================================================================
+
+// namespace Microsoft.AppFabricCAT.Samples.Azure.TransientFaultHandling.SqlAzure
+// namespace Microsoft.SqlServer.Management.Common
+
+using System;
+using System.Collections.Generic;
+using System.Data;
+using System.Data.Common;
+using System.Data.SqlClient;
+using System.Diagnostics;
+using System.Globalization;
+using System.Text;
+using Microsoft.SqlTools.EditorServices.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// Provides a reliable way of opening connections to and executing commands
+ /// taking into account potential network unreliability and a requirement for connection retry.
+ ///
+ internal sealed partial class ReliableSqlConnection : DbConnection, IDisposable
+ {
+ private const string QueryAzureSessionId = "SELECT CONVERT(NVARCHAR(36), CONTEXT_INFO())";
+
+ private readonly SqlConnection _underlyingConnection;
+ private readonly RetryPolicy _connectionRetryPolicy;
+ private RetryPolicy _commandRetryPolicy;
+ private Guid _azureSessionId;
+
+ ///
+ /// Initializes a new instance of the ReliableSqlConnection class with a given connection string
+ /// and a policy defining whether to retry a request if the connection fails to be opened or a command
+ /// fails to be successfully executed.
+ ///
+ /// The connection string used to open the SQL Azure database.
+ /// The retry policy defining whether to retry a request if a connection fails to be established.
+ /// The retry policy defining whether to retry a request if a command fails to be executed.
+ public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy)
+ {
+ _underlyingConnection = new SqlConnection(connectionString);
+ _connectionRetryPolicy = connectionRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy();
+ _commandRetryPolicy = commandRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy();
+
+ _underlyingConnection.StateChange += OnConnectionStateChange;
+ _connectionRetryPolicy.RetryOccurred += RetryConnectionCallback;
+ _commandRetryPolicy.RetryOccurred += RetryCommandCallback;
+ }
+
+ ///
+ /// Performs application-defined tasks associated with freeing, releasing, or
+ /// resetting managed and unmanaged resources.
+ ///
+ /// A flag indicating that managed resources must be released.
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ if (_connectionRetryPolicy != null)
+ {
+ _connectionRetryPolicy.RetryOccurred -= RetryConnectionCallback;
+ }
+
+ if (_commandRetryPolicy != null)
+ {
+ _commandRetryPolicy.RetryOccurred -= RetryCommandCallback;
+ }
+
+ _underlyingConnection.StateChange -= OnConnectionStateChange;
+ if (_underlyingConnection.State == ConnectionState.Open)
+ {
+ _underlyingConnection.Close();
+ }
+
+ _underlyingConnection.Dispose();
+ }
+ }
+
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Security", "CA2100:Review SQL queries for security vulnerabilities")]
+ internal static void SetLockAndCommandTimeout(IDbConnection conn)
+ {
+ Validate.IsNotNull(nameof(conn), conn);
+
+ // Make sure we use the underlying connection as ReliableConnection.Open also calls
+ // this method
+ ReliableSqlConnection reliableConn = conn as ReliableSqlConnection;
+ if (reliableConn != null)
+ {
+ conn = reliableConn._underlyingConnection;
+ }
+
+ const string setLockTimeout = @"set LOCK_TIMEOUT {0}";
+
+ using (IDbCommand cmd = conn.CreateCommand())
+ {
+ cmd.CommandText = string.Format(CultureInfo.InvariantCulture, setLockTimeout, AmbientSettings.LockTimeoutMilliSeconds);
+ cmd.CommandType = CommandType.Text;
+ cmd.CommandTimeout = CachedServerInfo.GetQueryTimeoutSeconds(conn);
+ cmd.ExecuteNonQuery();
+ }
+ }
+
+ internal static void SetDefaultAnsiSettings(IDbConnection conn)
+ {
+ Validate.IsNotNull(nameof(conn), conn);
+
+ // Make sure we use the underlying connection as ReliableConnection.Open also calls
+ // this method
+ ReliableSqlConnection reliableConn = conn as ReliableSqlConnection;
+ if (reliableConn != null)
+ {
+ conn = reliableConn._underlyingConnection;
+ }
+
+ // Configure the connection with proper ANSI settings and lock timeout
+ using (IDbCommand cmd = conn.CreateCommand())
+ {
+ cmd.CommandTimeout = CachedServerInfo.GetQueryTimeoutSeconds(conn);
+ cmd.CommandText = @"SET ANSI_NULLS, ANSI_PADDING, ANSI_WARNINGS, ARITHABORT, CONCAT_NULL_YIELDS_NULL, QUOTED_IDENTIFIER ON;
+SET NUMERIC_ROUNDABORT OFF;";
+ cmd.ExecuteNonQuery();
+ }
+ }
+
+ ///
+ /// Gets or sets the connection string for opening a connection to the SQL Azure database.
+ ///
+ public override string ConnectionString
+ {
+ get { return _underlyingConnection.ConnectionString; }
+ set { _underlyingConnection.ConnectionString = value; }
+ }
+
+ ///
+ /// Gets the policy which decides whether to retry a connection request, based on how many
+ /// times the request has been made and the reason for the last failure.
+ ///
+ public RetryPolicy ConnectionRetryPolicy
+ {
+ get { return _connectionRetryPolicy; }
+ }
+
+ ///
+ /// Gets the policy which decides whether to retry a command, based on how many
+ /// times the request has been made and the reason for the last failure.
+ ///
+ public RetryPolicy CommandRetryPolicy
+ {
+ get { return _commandRetryPolicy; }
+ set
+ {
+ Validate.IsNotNull(nameof(value), value);
+
+ if (_commandRetryPolicy != null)
+ {
+ _commandRetryPolicy.RetryOccurred -= RetryCommandCallback;
+ }
+
+ _commandRetryPolicy = value;
+ _commandRetryPolicy.RetryOccurred += RetryCommandCallback;
+ }
+ }
+
+ ///
+ /// Gets the server name from the underlying connection.
+ ///
+ public override string DataSource
+ {
+ get { return _underlyingConnection.DataSource; }
+ }
+
+ ///
+ /// Gets the server version from the underlying connection.
+ ///
+ public override string ServerVersion
+ {
+ get { return _underlyingConnection.ServerVersion; }
+ }
+
+ ///
+ /// If the underlying SqlConnection absolutely has to be accessed, for instance
+ /// to pass to external APIs that require this type of connection, then this
+ /// can be used.
+ ///
+ ///
+ public SqlConnection GetUnderlyingConnection()
+ {
+ return _underlyingConnection;
+ }
+
+ ///
+ /// Begins a database transaction with the specified System.Data.IsolationLevel value.
+ ///
+ /// One of the System.Data.IsolationLevel values.
+ /// An object representing the new transaction.
+ protected override DbTransaction BeginDbTransaction(IsolationLevel level)
+ {
+ return _underlyingConnection.BeginTransaction(level);
+ }
+
+ ///
+ /// Changes the current database for an open Connection object.
+ ///
+ /// The name of the database to use in place of the current database.
+ public override void ChangeDatabase(string databaseName)
+ {
+ _underlyingConnection.ChangeDatabase(databaseName);
+ }
+
+ ///
+ /// Opens a database connection with the settings specified by the ConnectionString
+ /// property of the provider-specific Connection object.
+ ///
+ public override void Open()
+ {
+ OpenConnection();
+ }
+
+ ///
+ /// Closes the connection to the database.
+ ///
+ public override void Close()
+ {
+ _underlyingConnection.Close();
+ }
+
+ ///
+ /// Gets the time to wait while trying to establish a connection before terminating
+ /// the attempt and generating an error.
+ ///
+ public override int ConnectionTimeout
+ {
+ get { return _underlyingConnection.ConnectionTimeout; }
+ }
+
+ ///
+ /// Creates and returns an object implementing the IDbCommand interface which is associated
+ /// with the underlying SqlConnection.
+ ///
+ /// A object.
+ protected override DbCommand CreateDbCommand()
+ {
+ return CreateReliableCommand();
+ }
+
+ ///
+ /// Creates and returns an object implementing the IDbCommand interface which is associated
+ /// with the underlying SqlConnection.
+ ///
+ /// A object.
+ public SqlCommand CreateSqlCommand()
+ {
+ return _underlyingConnection.CreateCommand();
+ }
+
+ ///
+ /// Gets the name of the current database or the database to be used after a
+ /// connection is opened.
+ ///
+ public override string Database
+ {
+ get { return _underlyingConnection.Database; }
+ }
+
+ ///
+ /// Gets the current state of the connection.
+ ///
+ public override ConnectionState State
+ {
+ get { return _underlyingConnection.State; }
+ }
+
+ ///
+ /// Adds an info message event listener.
+ ///
+ /// An info message event listener.
+ public void AddInfoMessageHandler(SqlInfoMessageEventHandler handler)
+ {
+ _underlyingConnection.InfoMessage += handler;
+ }
+
+ ///
+ /// Removes an info message event listener.
+ ///
+ /// An info message event listener.
+ public void RemoveInfoMessageHandler(SqlInfoMessageEventHandler handler)
+ {
+ _underlyingConnection.InfoMessage -= handler;
+ }
+
+ ///
+ /// Clears underlying connection pool.
+ ///
+ public void ClearPool()
+ {
+ if (_underlyingConnection != null)
+ {
+ SqlConnection.ClearPool(_underlyingConnection);
+ }
+ }
+
+ private void RetryCommandCallback(RetryState retryState)
+ {
+ RetryPolicyUtils.RaiseSchemaAmbientRetryMessage(retryState, SqlSchemaModelErrorCodes.ServiceActions.CommandRetry, _azureSessionId);
+ }
+
+ private void RetryConnectionCallback(RetryState retryState)
+ {
+ RetryPolicyUtils.RaiseSchemaAmbientRetryMessage(retryState, SqlSchemaModelErrorCodes.ServiceActions.ConnectionRetry, _azureSessionId);
+ }
+
+ ///
+ /// Opens a database connection with the settings specified by the ConnectionString and ConnectionRetryPolicy properties.
+ ///
+ /// An object representing the open connection.
+ private SqlConnection OpenConnection()
+ {
+ // Check if retry policy was specified, if not, disable retries by executing the Open method using RetryPolicy.NoRetry.
+ _connectionRetryPolicy.ExecuteAction(() =>
+ {
+ if (_underlyingConnection.State != ConnectionState.Open)
+ {
+ _underlyingConnection.Open();
+ }
+ SetLockAndCommandTimeout(_underlyingConnection);
+ SetDefaultAnsiSettings(_underlyingConnection);
+ });
+
+ return _underlyingConnection;
+ }
+
+ public void OnConnectionStateChange(object sender, StateChangeEventArgs e)
+ {
+ SqlConnection conn = (SqlConnection)sender;
+ switch (e.CurrentState)
+ {
+ case ConnectionState.Open:
+ RetreiveSessionId();
+ break;
+ case ConnectionState.Broken:
+ case ConnectionState.Closed:
+ _azureSessionId = Guid.Empty;
+ break;
+ case ConnectionState.Connecting:
+ case ConnectionState.Executing:
+ case ConnectionState.Fetching:
+ default:
+ break;
+ }
+ }
+
+ private void RetreiveSessionId()
+ {
+ try
+ {
+ using (IDbCommand command = CreateReliableCommand())
+ {
+ command.CommandText = QueryAzureSessionId;
+ object result = command.ExecuteScalar();
+
+ // Only returns a session id for SQL Azure
+ if (DBNull.Value != result)
+ {
+ string sessionId = (string)command.ExecuteScalar();
+ _azureSessionId = new Guid(sessionId);
+ }
+ }
+ }
+ catch (SqlException exception)
+ {
+ Logger.Write(LogLevel.Error, Resources.UnableToRetrieveAzureSessionId + exception.ToString());
+ }
+ }
+
+ ///
+ /// Creates and returns a ReliableSqlCommand object associated
+ /// with the underlying SqlConnection.
+ ///
+ /// A object.
+ private ReliableSqlCommand CreateReliableCommand()
+ {
+ return new ReliableSqlCommand(this);
+ }
+
+ private void VerifyConnectionOpen(IDbCommand command)
+ {
+ // Verify whether or not the connection is valid and is open. This code may be retried therefore
+ // it is important to ensure that a connection is re-established should it have previously failed.
+ if (command.Connection == null)
+ {
+ command.Connection = this;
+ }
+
+ if (command.Connection.State != ConnectionState.Open)
+ {
+ SqlConnection.ClearPool(_underlyingConnection);
+
+ command.Connection.Open();
+ }
+ }
+
+ private IDataReader ExecuteReader(IDbCommand command, CommandBehavior behavior)
+ {
+ Tuple[] sessionSettings = null;
+ return _commandRetryPolicy.ExecuteAction(() =>
+ {
+ VerifyConnectionOpen(command);
+ sessionSettings = CacheOrReplaySessionSettings(command, sessionSettings);
+
+ return command.ExecuteReader(behavior);
+ });
+ }
+
+ // Because retry loses session settings, cache session settings or reply if the settings are already cached.
+ internal Tuple[] CacheOrReplaySessionSettings(IDbCommand originalCommand, Tuple[] sessionSettings)
+ {
+ if (sessionSettings == null)
+ {
+ sessionSettings = QuerySessionSettings(originalCommand);
+ }
+ else
+ {
+ SetSessionSettings(originalCommand.Connection, sessionSettings);
+ }
+
+ return sessionSettings;
+ }
+
+ private object ExecuteScalar(IDbCommand command)
+ {
+ Tuple[] sessionSettings = null;
+ return _commandRetryPolicy.ExecuteAction(() =>
+ {
+ VerifyConnectionOpen(command);
+ sessionSettings = CacheOrReplaySessionSettings(command, sessionSettings);
+
+ return command.ExecuteScalar();
+ });
+ }
+
+ private Tuple[] QuerySessionSettings(IDbCommand originalCommand)
+ {
+ Tuple[] sessionSettings = new Tuple[2];
+
+ IDbConnection connection = originalCommand.Connection;
+ using (IDbCommand localCommand = connection.CreateCommand())
+ {
+ // Executing a reader requires preservation of any pending transaction created by the calling command
+ localCommand.Transaction = originalCommand.Transaction;
+ localCommand.CommandText = "SELECT ISNULL(SESSIONPROPERTY ('ANSI_NULLS'), 0), ISNULL(SESSIONPROPERTY ('QUOTED_IDENTIFIER'), 1)";
+ using (IDataReader reader = localCommand.ExecuteReader())
+ {
+ if (reader.Read())
+ {
+ sessionSettings[0] = Tuple.Create("ANSI_NULLS", ((int)reader[0] == 1));
+ sessionSettings[1] = Tuple.Create("QUOTED_IDENTIFIER", ((int)reader[1] ==1));
+ }
+ else
+ {
+ Debug.Assert(false, "Reader cannot be empty");
+ }
+ }
+ return sessionSettings;
+ }
+ }
+
+ private void SetSessionSettings(IDbConnection connection, params Tuple[] settings)
+ {
+ List setONOptions = new List();
+ List setOFFOptions = new List();
+ if(settings != null)
+ {
+ foreach (Tuple setting in settings)
+ {
+ if (setting.Item2)
+ {
+ setONOptions.Add(setting.Item1);
+ }
+ else
+ {
+ setOFFOptions.Add(setting.Item1);
+ }
+ }
+ }
+
+ SetSessionSettings(connection, setONOptions, "ON");
+ SetSessionSettings(connection, setOFFOptions, "OFF");
+
+ }
+
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Security", "CA2100:Review SQL queries for security vulnerabilities")]
+ private static void SetSessionSettings(IDbConnection connection, List sessionOptions, string onOff)
+ {
+ if (sessionOptions.Count > 0)
+ {
+ using (IDbCommand localCommand = connection.CreateCommand())
+ {
+ StringBuilder builder = new StringBuilder("SET ");
+ for (int i = 0; i < sessionOptions.Count; i++)
+ {
+ if (i > 0)
+ {
+ builder.Append(',');
+ }
+ builder.Append(sessionOptions[i]);
+ }
+ builder.Append(" ");
+ builder.Append(onOff);
+ localCommand.CommandText = builder.ToString();
+ localCommand.ExecuteNonQuery();
+ }
+ }
+ }
+
+ private int ExecuteNonQuery(IDbCommand command)
+ {
+ Tuple[] sessionSettings = null;
+ return _commandRetryPolicy.ExecuteAction(() =>
+ {
+ VerifyConnectionOpen(command);
+ sessionSettings = CacheOrReplaySessionSettings(command, sessionSettings);
+
+ return command.ExecuteNonQuery();
+ });
+ }
+ }
+}
+
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Resources.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Resources.cs
new file mode 100644
index 00000000..3fcbe225
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Resources.cs
@@ -0,0 +1,149 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// Contains string resources used throughout ReliableConnection code.
+ ///
+ internal static class Resources
+ {
+ internal static string AmbientSettingFormat
+ {
+ get
+ {
+ return "{0}: {1}";
+ }
+ }
+
+ internal static string ConnectionPassedToIsCloudShouldBeOpen
+ {
+ get
+ {
+ return "connection passed to IsCloud should be open.";
+ }
+ }
+
+ internal static string ConnectionPropertyNotSet
+ {
+ get
+ {
+ return "Connection property has not been initialized.";
+ }
+ }
+
+ internal static string ExceptionCannotBeRetried
+ {
+ get
+ {
+ return "Exception cannot be retried because of err #{0}:{1}";
+ }
+ }
+
+ internal static string ErrorParsingConnectionString
+ {
+ get
+ {
+ return "Error parsing connection string {0}";
+ }
+ }
+
+ internal static string FailedToCacheIsCloud
+ {
+ get
+ {
+ return "failed to cache the server property of IsAzure";
+ }
+ }
+
+ internal static string FailedToParseConnectionString
+ {
+ get
+ {
+ return "failed to parse the provided connection string: {0}";
+ }
+ }
+
+ internal static string IgnoreOnException
+ {
+ get
+ {
+ return "Retry number {0}. Ignoring Exception: {1}";
+ }
+ }
+
+ internal static string InvalidCommandType
+ {
+ get
+ {
+ return "Unsupported command object. Use SqlCommand or ReliableSqlCommand.";
+ }
+ }
+
+ internal static string InvalidConnectionType
+ {
+ get
+ {
+ return "Unsupported connection object. Use SqlConnection or ReliableSqlConnection.";
+ }
+ }
+
+ internal static string LoggingAmbientSettings
+ {
+ get
+ {
+ return "Logging Ambient Settings...";
+ }
+ }
+
+ internal static string Mode
+ {
+ get
+ {
+ return "Mode";
+ }
+ }
+
+ internal static string OnlyReliableConnectionSupported
+ {
+ get
+ {
+ return "Connection property can only be set to a value of type ReliableSqlConnection.";
+ }
+ }
+
+ internal static string RetryOnException
+ {
+ get
+ {
+ return "Retry number {0}. Delaying {1} ms before next retry. Exception: {2}";
+ }
+ }
+
+ internal static string ThrottlingTypeInfo
+ {
+ get
+ {
+ return "ThrottlingTypeInfo";
+ }
+ }
+
+ internal static string UnableToAssignValue
+ {
+ get
+ {
+ return "Unable to assign the value of type {0} to {1}";
+ }
+ }
+
+ internal static string UnableToRetrieveAzureSessionId
+ {
+ get
+ {
+ return "Unable to retrieve Azure session-id.";
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryCallbackEventArgs.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryCallbackEventArgs.cs
new file mode 100644
index 00000000..1fa26cee
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryCallbackEventArgs.cs
@@ -0,0 +1,61 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+// This code is copied from the source described in the comment below.
+
+// =======================================================================================
+// Microsoft Windows Server AppFabric Customer Advisory Team (CAT) Best Practices Series
+//
+// This sample is supplemental to the technical guidance published on the community
+// blog at http://blogs.msdn.com/appfabriccat/ and copied from
+// sqlmain ./sql/manageability/mfx/common/
+//
+// =======================================================================================
+// Copyright © 2012 Microsoft Corporation. All rights reserved.
+//
+// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER
+// EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. YOU BEAR THE RISK OF USING IT.
+// =======================================================================================
+
+// namespace Microsoft.SQL.CAT.BestPractices.SqlAzure.Framework
+// namespace Microsoft.SqlServer.Management.Common
+
+using System;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// Defines a arguments for event handler which will be invoked whenever a retry condition is encountered.
+ ///
+ internal sealed class RetryCallbackEventArgs : EventArgs
+ {
+ private readonly int _retryCount;
+ private readonly Exception _exception;
+ private readonly TimeSpan _delay;
+
+ public RetryCallbackEventArgs(int retryCount, Exception exception, TimeSpan delay)
+ {
+ _retryCount = retryCount;
+ _exception = exception;
+ _delay = delay;
+ }
+
+ public TimeSpan Delay
+ {
+ get { return _delay; }
+ }
+
+ public Exception Exception
+ {
+ get { return _exception; }
+ }
+
+ public int RetryCount
+ {
+ get { return _retryCount; }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryLimitExceededException.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryLimitExceededException.cs
new file mode 100644
index 00000000..0ae8b4e7
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryLimitExceededException.cs
@@ -0,0 +1,38 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+// This code is copied from the source described in the comment below.
+
+// =======================================================================================
+// Microsoft Windows Server AppFabric Customer Advisory Team (CAT) Best Practices Series
+//
+// This sample is supplemental to the technical guidance published on the community
+// blog at http://blogs.msdn.com/appfabriccat/ and copied from
+// sqlmain ./sql/manageability/mfx/common/
+//
+// =======================================================================================
+// Copyright © 2012 Microsoft Corporation. All rights reserved.
+//
+// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER
+// EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. YOU BEAR THE RISK OF USING IT.
+// =======================================================================================
+
+// namespace Microsoft.SQL.CAT.BestPractices.SqlAzure.Framework
+// namespace Microsoft.SqlServer.Management.Common
+
+using System;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// The special type of exception that provides managed exit from a retry loop. The user code can use this
+ /// exception to notify the retry policy that no further retry attempts are required.
+ ///
+ [Serializable]
+ internal sealed class RetryLimitExceededException : Exception
+ {
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.DataTransferDetectionErrorStrategy.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.DataTransferDetectionErrorStrategy.cs
new file mode 100644
index 00000000..8876fca7
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.DataTransferDetectionErrorStrategy.cs
@@ -0,0 +1,43 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Data.SqlClient;
+using Microsoft.SqlTools.EditorServices.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ internal abstract partial class RetryPolicy
+ {
+ ///
+ /// Provides the error detection logic for temporary faults that are commonly found during data transfer.
+ ///
+ internal sealed class DataTransferErrorDetectionStrategy : ErrorDetectionStrategyBase, IErrorDetectionStrategy
+ {
+ private static readonly DataTransferErrorDetectionStrategy instance = new DataTransferErrorDetectionStrategy();
+
+ public static DataTransferErrorDetectionStrategy Instance
+ {
+ get { return instance; }
+ }
+
+ protected override bool CanRetrySqlException(SqlException sqlException)
+ {
+ // Enumerate through all errors found in the exception.
+ foreach (SqlError err in sqlException.Errors)
+ {
+ RetryPolicyUtils.AppendThrottlingDataIfIsThrottlingError(sqlException, err);
+ if (RetryPolicyUtils.IsNonRetryableDataTransferError(err.Number))
+ {
+ Logger.Write(LogLevel.Error, string.Format(Resources.ExceptionCannotBeRetried, err.Number, err.Message));
+ return false;
+ }
+ }
+
+ // Default is to treat all SqlException as retriable.
+ return true;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.IErrorDetectionStrategy.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.IErrorDetectionStrategy.cs
new file mode 100644
index 00000000..bc591616
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.IErrorDetectionStrategy.cs
@@ -0,0 +1,97 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Data.SqlClient;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ internal abstract partial class RetryPolicy
+ {
+ public interface IErrorDetectionStrategy
+ {
+ ///
+ /// Determines whether the specified exception represents a temporary failure that can be compensated by a retry.
+ ///
+ /// The exception object to be verified.
+ /// True if the specified exception is considered as temporary, otherwise false.
+ bool CanRetry(Exception ex);
+
+ ///
+ /// Determines whether the specified exception can be ignored.
+ ///
+ /// The exception object to be verified.
+ /// True if the specified exception is considered as non-harmful.
+ bool ShouldIgnoreError(Exception ex);
+ }
+
+ ///
+ /// Base class with common retry logic. The core behavior for retrying non SqlExceptions is the same
+ /// across retry policies
+ ///
+ internal abstract class ErrorDetectionStrategyBase : IErrorDetectionStrategy
+ {
+ public bool CanRetry(Exception ex)
+ {
+ if (ex != null)
+ {
+ SqlException sqlException;
+ if ((sqlException = ex as SqlException) != null)
+ {
+ return CanRetrySqlException(sqlException);
+ }
+ if (ex is InvalidOperationException)
+ {
+ // Operations can throw this exception if the connection is killed before the write starts to the server
+ // However if there's an inner SqlException it may be a CLR load failure or other non-transient error
+ if (ex.InnerException != null
+ && ex.InnerException is SqlException)
+ {
+ return CanRetry(ex.InnerException);
+ }
+ return true;
+ }
+ if (ex is TimeoutException)
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ public bool ShouldIgnoreError(Exception ex)
+ {
+ if (ex != null)
+ {
+ SqlException sqlException;
+ if ((sqlException = ex as SqlException) != null)
+ {
+ return ShouldIgnoreSqlException(sqlException);
+ }
+ if (ex is InvalidOperationException)
+ {
+ // Operations can throw this exception if the connection is killed before the write starts to the server
+ // However if there's an inner SqlException it may be a CLR load failure or other non-transient error
+ if (ex.InnerException != null
+ && ex.InnerException is SqlException)
+ {
+ return ShouldIgnoreError(ex.InnerException);
+ }
+ }
+ }
+
+ return false;
+ }
+
+ protected virtual bool ShouldIgnoreSqlException(SqlException sqlException)
+ {
+ return false;
+ }
+
+ protected abstract bool CanRetrySqlException(SqlException sqlException);
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.NetworkConnectivityErrorStrategy.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.NetworkConnectivityErrorStrategy.cs
new file mode 100644
index 00000000..456426d1
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.NetworkConnectivityErrorStrategy.cs
@@ -0,0 +1,43 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Data.SqlClient;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ internal abstract partial class RetryPolicy
+ {
+ ///
+ /// Provides the error detection logic for temporary faults that are commonly found in SQL Azure.
+ /// The same errors CAN occur on premise also, but they are not seen as often.
+ ///
+ internal sealed class NetworkConnectivityErrorDetectionStrategy : ErrorDetectionStrategyBase, IErrorDetectionStrategy
+ {
+ private static NetworkConnectivityErrorDetectionStrategy instance = new NetworkConnectivityErrorDetectionStrategy();
+
+ public static NetworkConnectivityErrorDetectionStrategy Instance
+ {
+ get { return instance; }
+ }
+
+ protected override bool CanRetrySqlException(SqlException sqlException)
+ {
+ // Enumerate through all errors found in the exception.
+ bool foundRetryableError = false;
+ foreach (SqlError err in sqlException.Errors)
+ {
+ RetryPolicyUtils.AppendThrottlingDataIfIsThrottlingError(sqlException, err);
+ if (!RetryPolicyUtils.IsRetryableNetworkConnectivityError(err.Number))
+ {
+ // If any error is not retryable then cannot retry
+ return false;
+ }
+ foundRetryableError = true;
+ }
+ return foundRetryableError;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryAndIgnorableErrorDetectionStrategy.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryAndIgnorableErrorDetectionStrategy.cs
new file mode 100644
index 00000000..0cf26070
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryAndIgnorableErrorDetectionStrategy.cs
@@ -0,0 +1,63 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Collections.Generic;
+using System.Data.SqlClient;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ internal abstract partial class RetryPolicy
+ {
+ ///
+ /// Provides the error detection logic for temporary faults that are commonly found in SQL Azure.
+ /// This strategy is similar to SqlAzureTemporaryErrorDetectionStrategy, but it exposes ways
+ /// to accept a certain exception and treat it as passing.
+ /// For example, if we are retrying, and we get a failure that an object already exists, we might
+ /// want to consider this as passing since the first execution that has timed out (or failed for some other temporary error)
+ /// might have managed to create the object.
+ ///
+ internal sealed class SqlAzureTemporaryAndIgnorableErrorDetectionStrategy : ErrorDetectionStrategyBase, IErrorDetectionStrategy
+ {
+ ///
+ /// Azure error that can be ignored
+ ///
+ private readonly IList ignorableAzureErrors = null;
+
+ public SqlAzureTemporaryAndIgnorableErrorDetectionStrategy(params int[] ignorableErrors)
+ {
+ this.ignorableAzureErrors = ignorableErrors;
+ }
+
+ protected override bool CanRetrySqlException(SqlException sqlException)
+ {
+ // Enumerate through all errors found in the exception.
+ bool foundRetryableError = false;
+ foreach (SqlError err in sqlException.Errors)
+ {
+ RetryPolicyUtils.AppendThrottlingDataIfIsThrottlingError(sqlException, err);
+ if (!RetryPolicyUtils.IsRetryableAzureError(err.Number))
+ {
+ return false;
+ }
+
+ foundRetryableError = true;
+ }
+ return foundRetryableError;
+ }
+
+ protected override bool ShouldIgnoreSqlException(SqlException sqlException)
+ {
+ int errorNumber = sqlException.Number;
+
+ if (ignorableAzureErrors == null)
+ {
+ return false;
+ }
+
+ return ignorableAzureErrors.Contains(errorNumber);
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryErrorDetectionStrategy.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryErrorDetectionStrategy.cs
new file mode 100644
index 00000000..43233e85
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.SqlAzureTemporaryErrorDetectionStrategy.cs
@@ -0,0 +1,43 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Data.SqlClient;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ internal abstract partial class RetryPolicy
+ {
+ ///
+ /// Provides the error detection logic for temporary faults that are commonly found in SQL Azure.
+ /// The same errors CAN occur on premise also, but they are not seen as often.
+ ///
+ internal sealed class SqlAzureTemporaryErrorDetectionStrategy : ErrorDetectionStrategyBase, IErrorDetectionStrategy
+ {
+ private static SqlAzureTemporaryErrorDetectionStrategy instance = new SqlAzureTemporaryErrorDetectionStrategy();
+
+ public static SqlAzureTemporaryErrorDetectionStrategy Instance
+ {
+ get { return instance; }
+ }
+
+ protected override bool CanRetrySqlException(SqlException sqlException)
+ {
+ // Enumerate through all errors found in the exception.
+ bool foundRetryableError = false;
+ foreach (SqlError err in sqlException.Errors)
+ {
+ RetryPolicyUtils.AppendThrottlingDataIfIsThrottlingError(sqlException, err);
+ if (!RetryPolicyUtils.IsRetryableAzureError(err.Number))
+ {
+ // If any error is not retryable then cannot retry
+ return false;
+ }
+ foundRetryableError = true;
+ }
+ return foundRetryableError;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.ThrottleReason.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.ThrottleReason.cs
new file mode 100644
index 00000000..64eb2102
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.ThrottleReason.cs
@@ -0,0 +1,357 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections.Generic;
+using System.Data.SqlClient;
+using System.Globalization;
+using System.Linq;
+using System.Text;
+using System.Text.RegularExpressions;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ internal abstract partial class RetryPolicy
+ {
+ ///
+ /// Implements an object holding the decoded reason code returned from SQL Azure when encountering throttling conditions.
+ ///
+ [Serializable]
+ public class ThrottlingReason
+ {
+ ///
+ /// Returns the error number that corresponds to throttling conditions reported by SQL Azure.
+ ///
+ public const int ThrottlingErrorNumber = 40501;
+
+ ///
+ /// Gets an unknown throttling condition in the event the actual throttling condition cannot be determined.
+ ///
+ public static ThrottlingReason Unknown
+ {
+ get
+ {
+ var unknownCondition = new ThrottlingReason() { ThrottlingMode = ThrottlingMode.Unknown };
+ unknownCondition.throttledResources.Add(Tuple.Create(ThrottledResourceType.Unknown, ThrottlingType.Unknown));
+
+ return unknownCondition;
+ }
+ }
+
+ ///
+ /// Maintains a collection of key-value pairs where a key is resource type and a value is the type of throttling applied to the given resource type.
+ ///
+ private readonly IList> throttledResources = new List>(9);
+
+ ///
+ /// Provides a compiled regular expression used for extracting the reason code from the error message.
+ ///
+ private static readonly Regex sqlErrorCodeRegEx = new Regex(@"Code:\s*(\d+)", RegexOptions.IgnoreCase | RegexOptions.Compiled);
+
+ ///
+ /// Gets the value that reflects the throttling mode in SQL Azure.
+ ///
+ public ThrottlingMode ThrottlingMode
+ {
+ get;
+ private set;
+ }
+
+ ///
+ /// Gets the list of resources in SQL Azure that were subject to throttling conditions.
+ ///
+ public IEnumerable> ThrottledResources
+ {
+ get
+ {
+ return this.throttledResources;
+ }
+ }
+
+ ///
+ /// Determines throttling conditions from the specified SQL exception.
+ ///
+ /// The object containing information relevant to an error returned by SQL Server when encountering throttling conditions.
+ /// An instance of the object holding the decoded reason codes returned from SQL Azure upon encountering throttling conditions.
+ public static ThrottlingReason FromException(SqlException ex)
+ {
+ if (ex != null)
+ {
+ foreach (SqlError error in ex.Errors)
+ {
+ if (error.Number == ThrottlingErrorNumber)
+ {
+ return FromError(error);
+ }
+ }
+ }
+
+ return Unknown;
+ }
+
+ ///
+ /// Determines the throttling conditions from the specified SQL error.
+ ///
+ /// The object containing information relevant to a warning or error returned by SQL Server.
+ /// An instance of the object holding the decoded reason codes returned from SQL Azure when encountering throttling conditions.
+ public static ThrottlingReason FromError(SqlError error)
+ {
+ if (error != null)
+ {
+ var match = sqlErrorCodeRegEx.Match(error.Message);
+ int reasonCode = 0;
+
+ if (match.Success && Int32.TryParse(match.Groups[1].Value, out reasonCode))
+ {
+ return FromReasonCode(reasonCode);
+ }
+ }
+
+ return Unknown;
+ }
+
+ ///
+ /// Determines the throttling conditions from the specified reason code.
+ ///
+ /// The reason code returned by SQL Azure which contains the throttling mode and the exceeded resource types.
+ /// An instance of the object holding the decoded reason codes returned from SQL Azure when encountering throttling conditions.
+ public static ThrottlingReason FromReasonCode(int reasonCode)
+ {
+ if (reasonCode > 0)
+ {
+ // Decode throttling mode from the last 2 bits.
+ ThrottlingMode throttlingMode = (ThrottlingMode)(reasonCode & 3);
+
+ var condition = new ThrottlingReason() { ThrottlingMode = throttlingMode };
+
+ // Shift 8 bits to truncate throttling mode.
+ int groupCode = reasonCode >> 8;
+
+ // Determine throttling type for all well-known resources that may be subject to throttling conditions.
+ condition.throttledResources.Add(Tuple.Create(ThrottledResourceType.PhysicalDatabaseSpace, (ThrottlingType)(groupCode & 3)));
+ condition.throttledResources.Add(Tuple.Create(ThrottledResourceType.PhysicalLogSpace, (ThrottlingType)((groupCode = groupCode >> 2) & 3)));
+ condition.throttledResources.Add(Tuple.Create(ThrottledResourceType.LogWriteIODelay, (ThrottlingType)((groupCode = groupCode >> 2) & 3)));
+ condition.throttledResources.Add(Tuple.Create(ThrottledResourceType.DataReadIODelay, (ThrottlingType)((groupCode = groupCode >> 2) & 3)));
+ condition.throttledResources.Add(Tuple.Create(ThrottledResourceType.CPU, (ThrottlingType)((groupCode = groupCode >> 2) & 3)));
+ condition.throttledResources.Add(Tuple.Create(ThrottledResourceType.DatabaseSize, (ThrottlingType)((groupCode = groupCode >> 2) & 3)));
+ condition.throttledResources.Add(Tuple.Create(ThrottledResourceType.Internal, (ThrottlingType)((groupCode = groupCode >> 2) & 3)));
+ condition.throttledResources.Add(Tuple.Create(ThrottledResourceType.WorkerThreads, (ThrottlingType)((groupCode = groupCode >> 2) & 3)));
+ condition.throttledResources.Add(Tuple.Create(ThrottledResourceType.Internal, (ThrottlingType)((groupCode = groupCode >> 2) & 3)));
+
+ return condition;
+ }
+ else
+ {
+ return Unknown;
+ }
+ }
+
+ ///
+ /// Gets a value indicating whether physical data file space throttling was reported by SQL Azure.
+ ///
+ public bool IsThrottledOnDataSpace
+ {
+ get
+ {
+ return this.throttledResources.Where(x => x.Item1 == ThrottledResourceType.PhysicalDatabaseSpace).Count() > 0;
+ }
+ }
+
+ ///
+ /// Gets a value indicating whether physical log space throttling was reported by SQL Azure.
+ ///
+ public bool IsThrottledOnLogSpace
+ {
+ get
+ {
+ return this.throttledResources.Where(x => x.Item1 == ThrottledResourceType.PhysicalLogSpace).Count() > 0;
+ }
+ }
+
+ ///
+ /// Gets a value indicating whether transaction activity throttling was reported by SQL Azure.
+ ///
+ public bool IsThrottledOnLogWrite
+ {
+ get { return this.throttledResources.Where(x => x.Item1 == ThrottledResourceType.LogWriteIODelay).Count() > 0; }
+ }
+
+ ///
+ /// Gets a value indicating whether data read activity throttling was reported by SQL Azure.
+ ///
+ public bool IsThrottledOnDataRead
+ {
+ get { return this.throttledResources.Where(x => x.Item1 == ThrottledResourceType.DataReadIODelay).Count() > 0; }
+ }
+
+ ///
+ /// Gets a value indicating whether CPU throttling was reported by SQL Azure.
+ ///
+ public bool IsThrottledOnCPU
+ {
+ get { return this.throttledResources.Where(x => x.Item1 == ThrottledResourceType.CPU).Count() > 0; }
+ }
+
+ ///
+ /// Gets a value indicating whether database size throttling was reported by SQL Azure.
+ ///
+ public bool IsThrottledOnDatabaseSize
+ {
+ get { return this.throttledResources.Where(x => x.Item1 == ThrottledResourceType.DatabaseSize).Count() > 0; }
+ }
+
+ ///
+ /// Gets a value indicating whether concurrent requests throttling was reported by SQL Azure.
+ ///
+ public bool IsThrottledOnWorkerThreads
+ {
+ get { return this.throttledResources.Where(x => x.Item1 == ThrottledResourceType.WorkerThreads).Count() > 0; }
+ }
+
+ ///
+ /// Gets a value indicating whether throttling conditions were not determined with certainty.
+ ///
+ public bool IsUnknown
+ {
+ get { return ThrottlingMode == ThrottlingMode.Unknown; }
+ }
+
+ ///
+ /// Returns a textual representation the current ThrottlingReason object including the information held with respect to throttled resources.
+ ///
+ /// A string that represents the current ThrottlingReason object.
+ public override string ToString()
+ {
+ StringBuilder result = new StringBuilder();
+
+ result.AppendFormat(Resources.Mode, ThrottlingMode);
+
+ var resources = this.throttledResources.Where(x => x.Item1 != ThrottledResourceType.Internal).
+ Select, string>(x => String.Format(CultureInfo.CurrentCulture, Resources.ThrottlingTypeInfo, x.Item1, x.Item2)).
+ OrderBy(x => x).ToArray();
+
+ result.Append(String.Join(", ", resources));
+
+ return result.ToString();
+ }
+ }
+
+ #region ThrottlingMode enumeration
+ ///
+ /// Defines the possible throttling modes in SQL Azure.
+ ///
+ public enum ThrottlingMode
+ {
+ ///
+ /// Corresponds to "No Throttling" throttling mode whereby all SQL statements can be processed.
+ ///
+ NoThrottling = 0,
+
+ ///
+ /// Corresponds to "Reject Update / Insert" throttling mode whereby SQL statements such as INSERT, UPDATE, CREATE TABLE and CREATE INDEX are rejected.
+ ///
+ RejectUpdateInsert = 1,
+
+ ///
+ /// Corresponds to "Reject All Writes" throttling mode whereby SQL statements such as INSERT, UPDATE, DELETE, CREATE, DROP are rejected.
+ ///
+ RejectAllWrites = 2,
+
+ ///
+ /// Corresponds to "Reject All" throttling mode whereby all SQL statements are rejected.
+ ///
+ RejectAll = 3,
+
+ ///
+ /// Corresponds to an unknown throttling mode whereby throttling mode cannot be determined with certainty.
+ ///
+ Unknown = -1
+ }
+ #endregion
+
+ #region ThrottlingType enumeration
+ ///
+ /// Defines the possible throttling types in SQL Azure.
+ ///
+ public enum ThrottlingType
+ {
+ ///
+ /// Indicates that no throttling was applied to a given resource.
+ ///
+ None = 0,
+
+ ///
+ /// Corresponds to a Soft throttling type. Soft throttling is applied when machine resources such as, CPU, IO, storage, and worker threads exceed
+ /// predefined safety thresholds despite the load balancer’s best efforts.
+ ///
+ Soft = 1,
+
+ ///
+ /// Corresponds to a Hard throttling type. Hard throttling is applied when the machine is out of resources, for example storage space.
+ /// With hard throttling, no new connections are allowed to the databases hosted on the machine until resources are freed up.
+ ///
+ Hard = 2,
+
+ ///
+ /// Corresponds to an unknown throttling type in the event when the throttling type cannot be determined with certainty.
+ ///
+ Unknown = 3
+ }
+ #endregion
+
+ #region ThrottledResourceType enumeration
+ ///
+ /// Defines the types of resources in SQL Azure which may be subject to throttling conditions.
+ ///
+ public enum ThrottledResourceType
+ {
+ ///
+ /// Corresponds to "Physical Database Space" resource which may be subject to throttling.
+ ///
+ PhysicalDatabaseSpace = 0,
+
+ ///
+ /// Corresponds to "Physical Log File Space" resource which may be subject to throttling.
+ ///
+ PhysicalLogSpace = 1,
+
+ ///
+ /// Corresponds to "Transaction Log Write IO Delay" resource which may be subject to throttling.
+ ///
+ LogWriteIODelay = 2,
+
+ ///
+ /// Corresponds to "Database Read IO Delay" resource which may be subject to throttling.
+ ///
+ DataReadIODelay = 3,
+
+ ///
+ /// Corresponds to "CPU" resource which may be subject to throttling.
+ ///
+ CPU = 4,
+
+ ///
+ /// Corresponds to "Database Size" resource which may be subject to throttling.
+ ///
+ DatabaseSize = 5,
+
+ ///
+ /// Corresponds to "SQL Worker Thread Pool" resource which may be subject to throttling.
+ ///
+ WorkerThreads = 7,
+
+ ///
+ /// Corresponds to an internal resource which may be subject to throttling.
+ ///
+ Internal = 6,
+
+ ///
+ /// Corresponds to an unknown resource type in the event when the actual resource cannot be determined with certainty.
+ ///
+ Unknown = -1
+ }
+ #endregion
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.cs
new file mode 100644
index 00000000..9b554841
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/RetryPolicy.cs
@@ -0,0 +1,542 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+// This code is copied from the source described in the comment below.
+
+// =======================================================================================
+// Microsoft Windows Server AppFabric Customer Advisory Team (CAT) Best Practices Series
+//
+// This sample is supplemental to the technical guidance published on the community
+// blog at http://blogs.msdn.com/appfabriccat/ and copied from
+// sqlmain ./sql/manageability/mfx/common/
+//
+// =======================================================================================
+// Copyright © 2012 Microsoft Corporation. All rights reserved.
+//
+// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER
+// EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. YOU BEAR THE RISK OF USING IT.
+// =======================================================================================
+
+// namespace Microsoft.SQL.CAT.BestPractices.SqlAzure.Framework
+// namespace Microsoft.SqlServer.Management.Common
+
+using System;
+using System.Data.SqlClient;
+using System.Diagnostics;
+using System.Diagnostics.Contracts;
+using System.Globalization;
+using System.Threading;
+using Microsoft.SqlTools.EditorServices.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
+{
+ ///
+ /// Implements a policy defining and implementing the retry mechanism for unreliable actions.
+ ///
+ internal abstract partial class RetryPolicy
+ {
+ ///
+ /// Defines a callback delegate which will be invoked whenever a retry condition is encountered.
+ ///
+ /// The state of current retry attempt.
+ internal delegate void RetryCallbackDelegate(RetryState retryState);
+
+ ///
+ /// Defines a callback delegate which will be invoked whenever an error is ignored on retry.
+ ///
+ /// The state of current retry attempt.
+ internal delegate void IgnoreErrorCallbackDelegate(RetryState retryState);
+
+ private readonly IErrorDetectionStrategy _errorDetectionStrategy;
+
+ protected RetryPolicy(IErrorDetectionStrategy strategy)
+ {
+ Contract.Assert(strategy != null);
+
+ _errorDetectionStrategy = strategy;
+ this.FastFirstRetry = true;
+
+ //TODO Defect 1078447 Validate whether CommandTimeout needs to be used differently in schema/data scenarios
+ this.CommandTimeoutInSeconds = AmbientSettings.LongRunningQueryTimeoutSeconds;
+ }
+
+ ///
+ /// An instance of a callback delegate which will be invoked whenever a retry condition is encountered.
+ ///
+ public event RetryCallbackDelegate RetryOccurred;
+
+ ///
+ /// An instance of a callback delegate which will be invoked whenever an error is ignored on retry.
+ ///
+ public event IgnoreErrorCallbackDelegate IgnoreErrorOccurred;
+
+ ///
+ /// Gets or sets a value indicating whether or not the very first retry attempt will be made immediately
+ /// whereas the subsequent retries will remain subject to retry interval.
+ ///
+ public bool FastFirstRetry { get; set; }
+
+ ///
+ /// Gets or sets the timeout in seconds of sql commands
+ ///
+ public int CommandTimeoutInSeconds
+ {
+ get;
+ set;
+ }
+
+ ///
+ /// Gets the error detection strategy of this retry policy
+ ///
+ internal IErrorDetectionStrategy ErrorDetectionStrategy
+ {
+ get
+ {
+ return _errorDetectionStrategy;
+ }
+ }
+
+ ///
+ /// We should only ignore errors if they happen after the first retry.
+ /// This flag is used to allow the ignore even on first try, for testing purposes.
+ ///
+ ///
+ /// This flag is currently being used for TESTING PURPOSES ONLY.
+ ///
+ internal bool ShouldIgnoreOnFirstTry
+ {
+ get;
+ set;
+ }
+
+ protected static bool IsLessThanMaxRetryCount(int currentRetryCount, int maxRetryCount)
+ {
+ return currentRetryCount <= maxRetryCount;
+ }
+
+ ///
+ /// Repetitively executes the specified action while it satisfies the current retry policy.
+ ///
+ /// A delegate representing the executable action which doesn't return any results.
+ /// Cancellation token to cancel action between retries.
+ public void ExecuteAction(Action action, CancellationToken? token = null)
+ {
+ ExecuteAction(
+ _ => action(), token);
+ }
+
+ ///
+ /// Repetitively executes the specified action while it satisfies the current retry policy.
+ ///
+ /// A delegate representing the executable action which doesn't return any results.
+ /// Cancellation token to cancel action between retries.
+ public void ExecuteAction(Action action, CancellationToken? token = null)
+ {
+ ExecuteAction
public DbConnection CreateSqlConnection(string connectionString)
{
- return new SqlConnection(connectionString);
+ RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
+ RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
+ return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy);
}
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs
index 5f5ef1df..f252d3c6 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs
@@ -136,13 +136,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting
TextDocumentSync = TextDocumentSyncKind.Incremental,
DefinitionProvider = true,
ReferencesProvider = true,
- DocumentHighlightProvider = true,
- DocumentSymbolProvider = true,
- WorkspaceSymbolProvider = true,
+ DocumentHighlightProvider = true,
CompletionProvider = new CompletionOptions
{
ResolveProvider = true,
- TriggerCharacters = new string[] { ".", "-", ":", "\\" }
+ TriggerCharacters = new string[] { ".", "-", ":", "\\", ",", " " }
},
SignatureHelpProvider = new SignatureHelpOptions
{
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs
new file mode 100644
index 00000000..1069e18d
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteHelper.cs
@@ -0,0 +1,514 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Collections.Generic;
+using Microsoft.SqlServer.Management.SqlParser.Intellisense;
+using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
+{
+ ///
+ /// Main class for Language Service functionality including anything that reqires knowledge of
+ /// the language to perfom, such as definitions, intellisense, etc.
+ ///
+ public static class AutoCompleteHelper
+ {
+ private static readonly string[] DefaultCompletionText = new string[]
+ {
+ "absolute",
+ "accent_sensitivity",
+ "action",
+ "activation",
+ "add",
+ "address",
+ "admin",
+ "after",
+ "aggregate",
+ "algorithm",
+ "allow_page_locks",
+ "allow_row_locks",
+ "allow_snapshot_isolation",
+ "alter",
+ "always",
+ "ansi_null_default",
+ "ansi_nulls",
+ "ansi_padding",
+ "ansi_warnings",
+ "application",
+ "arithabort",
+ "as",
+ "asc",
+ "assembly",
+ "asymmetric",
+ "at",
+ "atomic",
+ "audit",
+ "authentication",
+ "authorization",
+ "auto",
+ "auto_close",
+ "auto_shrink",
+ "auto_update_statistics",
+ "auto_update_statistics_async",
+ "availability",
+ "backup",
+ "before",
+ "begin",
+ "binary",
+ "bit",
+ "block",
+ "break",
+ "browse",
+ "bucket_count",
+ "bulk",
+ "by",
+ "call",
+ "caller",
+ "card",
+ "cascade",
+ "case",
+ "catalog",
+ "catch",
+ "change_tracking",
+ "changes",
+ "char",
+ "character",
+ "check",
+ "checkpoint",
+ "close",
+ "clustered",
+ "collection",
+ "column",
+ "column_encryption_key",
+ "columnstore",
+ "commit",
+ "compatibility_level",
+ "compress_all_row_groups",
+ "compression",
+ "compression_delay",
+ "compute",
+ "concat_null_yields_null",
+ "configuration",
+ "connect",
+ "constraint",
+ "containstable",
+ "continue",
+ "create",
+ "cube",
+ "current",
+ "current_date",
+ "cursor",
+ "cursor_close_on_commit",
+ "cursor_default",
+ "data",
+ "data_compression",
+ "database",
+ "date",
+ "date_correlation_optimization",
+ "datefirst",
+ "datetime",
+ "datetime2",
+ "days",
+ "db_chaining",
+ "dbcc",
+ "deallocate",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delayed_durability",
+ "delete",
+ "deny",
+ "desc",
+ "description",
+ "disable_broker",
+ "disabled",
+ "disk",
+ "distinct",
+ "distributed",
+ "double",
+ "drop",
+ "drop_existing",
+ "dump",
+ "durability",
+ "dynamic",
+ "else",
+ "enable",
+ "encrypted",
+ "encryption_type",
+ "end",
+ "end-exec",
+ "entry",
+ "errlvl",
+ "escape",
+ "event",
+ "except",
+ "exec",
+ "execute",
+ "exit",
+ "external",
+ "fast_forward",
+ "fetch",
+ "file",
+ "filegroup",
+ "filename",
+ "filestream",
+ "fillfactor",
+ "filter",
+ "first",
+ "float",
+ "for",
+ "foreign",
+ "freetext",
+ "freetexttable",
+ "from",
+ "full",
+ "fullscan",
+ "fulltext",
+ "function",
+ "generated",
+ "geography",
+ "get",
+ "global",
+ "go",
+ "goto",
+ "grant",
+ "group",
+ "hash",
+ "hashed",
+ "having",
+ "hidden",
+ "hierarchyid",
+ "holdlock",
+ "hours",
+ "identity",
+ "identity_insert",
+ "identitycol",
+ "if",
+ "ignore_dup_key",
+ "image",
+ "immediate",
+ "include",
+ "index",
+ "inflectional",
+ "insensitive",
+ "insert",
+ "instead",
+ "int",
+ "integer",
+ "integrated",
+ "intersect",
+ "into",
+ "isolation",
+ "json",
+ "key",
+ "kill",
+ "language",
+ "last",
+ "legacy_cardinality_estimation",
+ "level",
+ "lineno",
+ "load",
+ "local",
+ "locate",
+ "location",
+ "login",
+ "masked",
+ "master",
+ "maxdop",
+ "memory_optimized",
+ "merge",
+ "message",
+ "modify",
+ "move",
+ "multi_user",
+ "namespace",
+ "national",
+ "native_compilation",
+ "nchar",
+ "next",
+ "no",
+ "nocheck",
+ "nocount",
+ "nonclustered",
+ "none",
+ "norecompute",
+ "now",
+ "numeric",
+ "numeric_roundabort",
+ "object",
+ "of",
+ "off",
+ "offsets",
+ "on",
+ "online",
+ "open",
+ "opendatasource",
+ "openquery",
+ "openrowset",
+ "openxml",
+ "option",
+ "order",
+ "out",
+ "output",
+ "over",
+ "owner",
+ "pad_index",
+ "page",
+ "page_verify",
+ "parameter_sniffing",
+ "parameterization",
+ "partial",
+ "partition",
+ "password",
+ "path",
+ "percent",
+ "percentage",
+ "period",
+ "persisted",
+ "plan",
+ "policy",
+ "population",
+ "precision",
+ "predicate",
+ "primary",
+ "print",
+ "prior",
+ "proc",
+ "procedure",
+ "public",
+ "query_optimizer_hotfixes",
+ "query_store",
+ "quoted_identifier",
+ "raiserror",
+ "range",
+ "raw",
+ "read",
+ "read_committed_snapshot",
+ "read_only",
+ "read_write",
+ "readonly",
+ "readtext",
+ "real",
+ "rebuild",
+ "receive",
+ "reconfigure",
+ "recovery",
+ "recursive",
+ "recursive_triggers",
+ "references",
+ "relative",
+ "remove",
+ "reorganize",
+ "replication",
+ "required",
+ "restart",
+ "restore",
+ "restrict",
+ "resume",
+ "return",
+ "returns",
+ "revert",
+ "revoke",
+ "role",
+ "rollback",
+ "rollup",
+ "row",
+ "rowcount",
+ "rowguidcol",
+ "rows",
+ "rule",
+ "sample",
+ "save",
+ "schema",
+ "schemabinding",
+ "scoped",
+ "scroll",
+ "secondary",
+ "security",
+ "securityaudit",
+ "select",
+ "semantickeyphrasetable",
+ "semanticsimilaritydetailstable",
+ "semanticsimilaritytable",
+ "send",
+ "sent",
+ "sequence",
+ "server",
+ "session",
+ "set",
+ "sets",
+ "setuser",
+ "shutdown",
+ "simple",
+ "smallint",
+ "smallmoney",
+ "snapshot",
+ "sort_in_tempdb",
+ "sql",
+ "standard",
+ "start",
+ "started",
+ "state",
+ "statement",
+ "static",
+ "statistics",
+ "statistics_norecompute",
+ "status",
+ "stopped",
+ "supported",
+ "symmetric",
+ "sysname",
+ "system",
+ "system_time",
+ "system_versioning",
+ "table",
+ "tablesample",
+ "take",
+ "target",
+ "textimage_on",
+ "textsize",
+ "then",
+ "thesaurus",
+ "throw",
+ "time",
+ "timestamp",
+ "tinyint",
+ "to",
+ "top",
+ "tran",
+ "transaction",
+ "trigger",
+ "truncate",
+ "trustworthy",
+ "try",
+ "tsql",
+ "type",
+ "union",
+ "unique",
+ "uniqueidentifier",
+ "unlimited",
+ "updatetext",
+ "use",
+ "user",
+ "using",
+ "value",
+ "values",
+ "varchar",
+ "varying",
+ "version",
+ "view",
+ "waitfor",
+ "weight",
+ "when",
+ "where",
+ "while",
+ "with",
+ "within",
+ "within group",
+ "without",
+ "writetext",
+ "xact_abort",
+ "xml",
+ "zone"
+ };
+
+ internal static CompletionItem[] GetDefaultCompletionItems(
+ int row,
+ int startColumn,
+ int endColumn)
+ {
+ var completionItems = new CompletionItem[DefaultCompletionText.Length];
+ for (int i = 0; i < DefaultCompletionText.Length; ++i)
+ {
+ completionItems[i] = CreateDefaultCompletionItem(
+ DefaultCompletionText[i].ToUpper(),
+ row,
+ startColumn,
+ endColumn);
+ }
+ return completionItems;
+ }
+
+ private static CompletionItem CreateDefaultCompletionItem(
+ string label,
+ int row,
+ int startColumn,
+ int endColumn)
+ {
+ return new CompletionItem()
+ {
+ Label = label,
+ Kind = CompletionItemKind.Keyword,
+ Detail = label + " keyword",
+ TextEdit = new TextEdit
+ {
+ NewText = label,
+ Range = new Range
+ {
+ Start = new Position
+ {
+ Line = row,
+ Character = startColumn
+ },
+ End = new Position
+ {
+ Line = row,
+ Character = endColumn
+ }
+ }
+ }
+ };
+ }
+
+ ///
+ /// Converts a list of Declaration objects to CompletionItem objects
+ /// since VS Code expects CompletionItems but SQL Parser works with Declarations
+ ///
+ ///
+ ///
+ ///
+ ///
+ internal static CompletionItem[] ConvertDeclarationsToCompletionItems(
+ IEnumerable suggestions,
+ int row,
+ int startColumn,
+ int endColumn)
+ {
+ List completions = new List();
+ foreach (var autoCompleteItem in suggestions)
+ {
+ // convert the completion item candidates into CompletionItems
+ completions.Add(new CompletionItem()
+ {
+ Label = autoCompleteItem.Title,
+ Kind = CompletionItemKind.Variable,
+ Detail = autoCompleteItem.Title,
+ Documentation = autoCompleteItem.Description,
+ TextEdit = new TextEdit
+ {
+ NewText = autoCompleteItem.Title,
+ Range = new Range
+ {
+ Start = new Position
+ {
+ Line = row,
+ Character = startColumn
+ },
+ End = new Position
+ {
+ Line = row,
+ Character = endColumn
+ }
+ }
+ }
+ });
+ }
+
+ return completions.ToArray();
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs
deleted file mode 100644
index 5abe27f7..00000000
--- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs
+++ /dev/null
@@ -1,323 +0,0 @@
-//
-// Copyright (c) Microsoft. All rights reserved.
-// Licensed under the MIT license. See LICENSE file in the project root for full license information.
-//
-
-using System;
-using System.Collections.Generic;
-using System.Data.SqlClient;
-using System.Threading.Tasks;
-using Microsoft.SqlServer.Management.Common;
-using Microsoft.SqlServer.Management.SmoMetadataProvider;
-using Microsoft.SqlServer.Management.SqlParser.Binder;
-using Microsoft.SqlServer.Management.SqlParser.Intellisense;
-using Microsoft.SqlServer.Management.SqlParser.MetadataProvider;
-using Microsoft.SqlTools.ServiceLayer.Connection;
-using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
-using Microsoft.SqlTools.ServiceLayer.Hosting;
-using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
-using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
-using Microsoft.SqlTools.ServiceLayer.SqlContext;
-using Microsoft.SqlTools.ServiceLayer.Workspace;
-using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
-
-namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
-{
- ///
- /// Main class for Autocomplete functionality
- ///
- public class AutoCompleteService
- {
- #region Singleton Instance Implementation
-
- ///
- /// Singleton service instance
- ///
- private static Lazy instance
- = new Lazy(() => new AutoCompleteService());
-
- ///
- /// Gets the singleton service instance
- ///
- public static AutoCompleteService Instance
- {
- get
- {
- return instance.Value;
- }
- }
-
- ///
- /// Default, parameterless constructor.
- /// Internal constructor for use in test cases only
- ///
- internal AutoCompleteService()
- {
- }
-
- #endregion
-
- private ConnectionService connectionService = null;
-
- ///
- /// Internal for testing purposes only
- ///
- internal ConnectionService ConnectionServiceInstance
- {
- get
- {
- if(connectionService == null)
- {
- connectionService = ConnectionService.Instance;
- }
- return connectionService;
- }
-
- set
- {
- connectionService = value;
- }
- }
-
- public void InitializeService(ServiceHost serviceHost)
- {
- // Register auto-complete request handler
- serviceHost.SetRequestHandler(CompletionRequest.Type, HandleCompletionRequest);
-
- // Register a callback for when a connection is created
- ConnectionServiceInstance.RegisterOnConnectionTask(UpdateAutoCompleteCache);
-
- // Register a callback for when a connection is closed
- ConnectionServiceInstance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference);
- }
-
- ///
- /// Auto-complete completion provider request callback
- ///
- ///
- ///
- ///
- private static async Task HandleCompletionRequest(
- TextDocumentPosition textDocumentPosition,
- RequestContext requestContext)
- {
- // get the current list of completion items and return to client
- var scriptFile = WorkspaceService.Instance.Workspace.GetFile(
- textDocumentPosition.TextDocument.Uri);
-
- ConnectionInfo connInfo;
- ConnectionService.Instance.TryFindConnection(
- scriptFile.ClientFilePath,
- out connInfo);
-
- var completionItems = Instance.GetCompletionItems(
- textDocumentPosition, scriptFile, connInfo);
-
- await requestContext.SendResult(completionItems);
- }
-
- ///
- /// Remove a reference to an autocomplete cache from a URI. If
- /// it is the last URI connected to a particular connection,
- /// then remove the cache.
- ///
- public async Task RemoveAutoCompleteCacheUriReference(ConnectionSummary summary)
- {
- // currently this method is disabled, but we need to reimplement now that the
- // implementation of the 'cache' has changed.
- await Task.FromResult(0);
- }
-
- ///
- /// Update the cached autocomplete candidate list when the user connects to a database
- ///
- ///
- public async Task UpdateAutoCompleteCache(ConnectionInfo info)
- {
- await Task.Run( () =>
- {
- if (!LanguageService.Instance.ScriptParseInfoMap.ContainsKey(info.OwnerUri))
- {
- var sqlConn = info.SqlConnection as SqlConnection;
- if (sqlConn != null)
- {
- var srvConn = new ServerConnection(sqlConn);
- var displayInfoProvider = new MetadataDisplayInfoProvider();
- var metadataProvider = SmoMetadataProvider.CreateConnectedProvider(srvConn);
- var binder = BinderProvider.CreateBinder(metadataProvider);
-
- LanguageService.Instance.ScriptParseInfoMap.Add(info.OwnerUri,
- new ScriptParseInfo()
- {
- Binder = binder,
- MetadataProvider = metadataProvider,
- MetadataDisplayInfoProvider = displayInfoProvider
- });
-
- var scriptFile = WorkspaceService.Instance.Workspace.GetFile(info.OwnerUri);
-
- LanguageService.Instance.ParseAndBind(scriptFile, info);
- }
- }
- });
- }
-
- ///
- /// Find the position of the previous delimeter for autocomplete token replacement.
- /// SQL Parser may have similar functionality in which case we'll delete this method.
- ///
- ///
- ///
- ///
- ///
- private int PositionOfPrevDelimeter(string sql, int startRow, int startColumn)
- {
- if (string.IsNullOrWhiteSpace(sql))
- {
- return 1;
- }
-
- int prevLineColumns = 0;
- for (int i = 0; i < startRow; ++i)
- {
- while (sql[prevLineColumns] != '\n' && prevLineColumns < sql.Length)
- {
- ++prevLineColumns;
- }
- ++prevLineColumns;
- }
-
- startColumn += prevLineColumns;
-
- if (startColumn - 1 < sql.Length)
- {
- while (--startColumn >= prevLineColumns)
- {
- if (sql[startColumn] == ' '
- || sql[startColumn] == '\t'
- || sql[startColumn] == '\n'
- || sql[startColumn] == '.'
- || sql[startColumn] == '+'
- || sql[startColumn] == '-'
- || sql[startColumn] == '*'
- || sql[startColumn] == '>'
- || sql[startColumn] == '<'
- || sql[startColumn] == '='
- || sql[startColumn] == '/'
- || sql[startColumn] == '%')
- {
- break;
- }
- }
- }
-
- return startColumn + 1 - prevLineColumns;
- }
-
- ///
- /// Determines whether a reparse and bind is required to provide autocomplete
- ///
- ///
- /// TEMP: Currently hard-coded to false for perf
- private bool RequiresReparse(ScriptParseInfo info)
- {
- return false;
- }
-
- ///
- /// Converts a list of Declaration objects to CompletionItem objects
- /// since VS Code expects CompletionItems but SQL Parser works with Declarations
- ///
- ///
- ///
- ///
- ///
- private CompletionItem[] ConvertDeclarationsToCompletionItems(
- IEnumerable suggestions,
- int row,
- int startColumn,
- int endColumn)
- {
- List completions = new List();
- foreach (var autoCompleteItem in suggestions)
- {
- // convert the completion item candidates into CompletionItems
- completions.Add(new CompletionItem()
- {
- Label = autoCompleteItem.Title,
- Kind = CompletionItemKind.Keyword,
- Detail = autoCompleteItem.Title,
- Documentation = autoCompleteItem.Description,
- TextEdit = new TextEdit
- {
- NewText = autoCompleteItem.Title,
- Range = new Range
- {
- Start = new Position
- {
- Line = row,
- Character = startColumn
- },
- End = new Position
- {
- Line = row,
- Character = endColumn
- }
- }
- }
- });
- }
-
- return completions.ToArray();
- }
-
- ///
- /// Return the completion item list for the current text position.
- /// This method does not await cache builds since it expects to return quickly
- ///
- ///
- public CompletionItem[] GetCompletionItems(
- TextDocumentPosition textDocumentPosition,
- ScriptFile scriptFile,
- ConnectionInfo connInfo)
- {
- string filePath = textDocumentPosition.TextDocument.Uri;
-
- // Take a reference to the list at a point in time in case we update and replace the list
- if (connInfo == null
- || !LanguageService.Instance.ScriptParseInfoMap.ContainsKey(textDocumentPosition.TextDocument.Uri))
- {
- return new CompletionItem[0];
- }
-
- // reparse and bind the SQL statement if needed
- var scriptParseInfo = LanguageService.Instance.ScriptParseInfoMap[textDocumentPosition.TextDocument.Uri];
- if (RequiresReparse(scriptParseInfo))
- {
- LanguageService.Instance.ParseAndBind(scriptFile, connInfo);
- }
-
- if (scriptParseInfo.ParseResult == null)
- {
- return new CompletionItem[0];
- }
-
- // get the completion list from SQL Parser
- var suggestions = Resolver.FindCompletions(
- scriptParseInfo.ParseResult,
- textDocumentPosition.Position.Line + 1,
- textDocumentPosition.Position.Character + 1,
- scriptParseInfo.MetadataDisplayInfoProvider);
-
- // convert the suggestion list to the VS Code format
- return ConvertDeclarationsToCompletionItems(
- suggestions,
- textDocumentPosition.Position.Line,
- PositionOfPrevDelimeter(
- scriptFile.Contents,
- textDocumentPosition.Position.Line,
- textDocumentPosition.Position.Character),
- textDocumentPosition.Position.Character);
- }
- }
-}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/DiagnosticsHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/DiagnosticsHelper.cs
new file mode 100644
index 00000000..0c493b0f
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/DiagnosticsHelper.cs
@@ -0,0 +1,98 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
+using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
+{
+ ///
+ /// Main class for Language Service functionality including anything that reqires knowledge of
+ /// the language to perfom, such as definitions, intellisense, etc.
+ ///
+ public static class DiagnosticsHelper
+ {
+ ///
+ /// Send the diagnostic results back to the host application
+ ///
+ ///
+ ///
+ ///
+ internal static async Task PublishScriptDiagnostics(
+ ScriptFile scriptFile,
+ ScriptFileMarker[] semanticMarkers,
+ EventContext eventContext)
+ {
+ var allMarkers = scriptFile.SyntaxMarkers != null
+ ? scriptFile.SyntaxMarkers.Concat(semanticMarkers)
+ : semanticMarkers;
+
+ // Always send syntax and semantic errors. We want to
+ // make sure no out-of-date markers are being displayed.
+ await eventContext.SendEvent(
+ PublishDiagnosticsNotification.Type,
+ new PublishDiagnosticsNotification
+ {
+ Uri = scriptFile.ClientFilePath,
+ Diagnostics =
+ allMarkers
+ .Select(GetDiagnosticFromMarker)
+ .ToArray()
+ });
+ }
+
+ ///
+ /// Convert a ScriptFileMarker to a Diagnostic that is Language Service compatible
+ ///
+ ///
+ ///
+ internal static Diagnostic GetDiagnosticFromMarker(ScriptFileMarker scriptFileMarker)
+ {
+ return new Diagnostic
+ {
+ Severity = MapDiagnosticSeverity(scriptFileMarker.Level),
+ Message = scriptFileMarker.Message,
+ Range = new Range
+ {
+ Start = new Position
+ {
+ Line = scriptFileMarker.ScriptRegion.StartLineNumber - 1,
+ Character = scriptFileMarker.ScriptRegion.StartColumnNumber - 1
+ },
+ End = new Position
+ {
+ Line = scriptFileMarker.ScriptRegion.EndLineNumber - 1,
+ Character = scriptFileMarker.ScriptRegion.EndColumnNumber - 1
+ }
+ }
+ };
+ }
+
+ ///
+ /// Map ScriptFileMarker severity to Diagnostic severity
+ ///
+ ///
+ internal static DiagnosticSeverity MapDiagnosticSeverity(ScriptFileMarkerLevel markerLevel)
+ {
+ switch (markerLevel)
+ {
+ case ScriptFileMarkerLevel.Error:
+ return DiagnosticSeverity.Error;
+
+ case ScriptFileMarkerLevel.Warning:
+ return DiagnosticSeverity.Warning;
+
+ case ScriptFileMarkerLevel.Information:
+ return DiagnosticSeverity.Information;
+
+ default:
+ return DiagnosticSeverity.Error;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
index d414b41e..d43840a2 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
@@ -5,22 +5,28 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.EditorServices.Utility;
+using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
-using System.Linq;
-using Microsoft.SqlServer.Management.SqlParser.Parser;
-using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location;
-using Microsoft.SqlTools.ServiceLayer.Connection;
-using Microsoft.SqlServer.Management.SqlParser.Binder;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.SqlParser;
+using Microsoft.SqlServer.Management.SqlParser.Binder;
+using Microsoft.SqlServer.Management.SqlParser.Intellisense;
+using Microsoft.SqlServer.Management.SqlParser.MetadataProvider;
+using Microsoft.SqlServer.Management.SqlParser.Parser;
+using Microsoft.SqlServer.Management.SmoMetadataProvider;
+
+using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location;
namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
{
@@ -30,6 +36,42 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
///
public sealed class LanguageService
{
+ public const string DefaultBatchSeperator = "GO";
+
+ private const int DiagnosticParseDelay = 750;
+
+ private const int FindCompletionsTimeout = 3000;
+
+ private const int FindCompletionStartTimeout = 50;
+
+ private const int OnConnectionWaitTimeout = 30000;
+
+ private bool ShouldEnableAutocomplete()
+ {
+ return true;
+ }
+
+ private ConnectionService connectionService = null;
+
+ ///
+ /// Internal for testing purposes only
+ ///
+ internal ConnectionService ConnectionServiceInstance
+ {
+ get
+ {
+ if(connectionService == null)
+ {
+ connectionService = ConnectionService.Instance;
+ }
+ return connectionService;
+ }
+
+ set
+ {
+ connectionService = value;
+ }
+ }
#region Singleton Instance Implementation
@@ -98,8 +140,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
serviceHost.SetRequestHandler(SignatureHelpRequest.Type, HandleSignatureHelpRequest);
serviceHost.SetRequestHandler(DocumentHighlightRequest.Type, HandleDocumentHighlightRequest);
serviceHost.SetRequestHandler(HoverRequest.Type, HandleHoverRequest);
- serviceHost.SetRequestHandler(DocumentSymbolRequest.Type, HandleDocumentSymbolRequest);
- serviceHost.SetRequestHandler(WorkspaceSymbolRequest.Type, HandleWorkspaceSymbolRequest);
+ serviceHost.SetRequestHandler(CompletionRequest.Type, HandleCompletionRequest);
// Register a no-op shutdown task for validation of the shutdown logic
serviceHost.RegisterShutdownTask(async (shutdownParams, shutdownRequestContext) =>
@@ -115,104 +156,47 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
WorkspaceService.Instance.RegisterTextDocChangeCallback(HandleDidChangeTextDocumentNotification);
// Register the file open update handler
- WorkspaceService.Instance.RegisterTextDocOpenCallback(HandleDidOpenTextDocumentNotification);
+ WorkspaceService.Instance.RegisterTextDocOpenCallback(HandleDidOpenTextDocumentNotification);
+
+ // Register a callback for when a connection is created
+ ConnectionServiceInstance.RegisterOnConnectionTask(UpdateLanguageServiceOnConnection);
+
+ // Register a callback for when a connection is closed
+ ConnectionServiceInstance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference);
// Store the SqlToolsContext for future use
Context = context;
}
+ #endregion
+
+ #region Request Handlers
+
///
- /// Parses the SQL text and binds it to the SMO metadata provider if connected
+ /// Auto-complete completion provider request callback
///
- ///
- ///
+ ///
+ ///
///
- public ParseResult ParseAndBind(ScriptFile scriptFile, ConnectionInfo connInfo)
+ private static async Task HandleCompletionRequest(
+ TextDocumentPosition textDocumentPosition,
+ RequestContext requestContext)
{
- ScriptParseInfo parseInfo = null;
- if (this.ScriptParseInfoMap.ContainsKey(scriptFile.ClientFilePath))
- {
- parseInfo = this.ScriptParseInfoMap[scriptFile.ClientFilePath];
- }
+ // get the current list of completion items and return to client
+ var scriptFile = WorkspaceService.Instance.Workspace.GetFile(
+ textDocumentPosition.TextDocument.Uri);
- // parse current SQL file contents to retrieve a list of errors
- ParseOptions parseOptions = new ParseOptions();
- ParseResult parseResult = Parser.IncrementalParse(
- scriptFile.Contents,
- parseInfo != null ? parseInfo.ParseResult : null,
- parseOptions);
-
- // save previous result for next incremental parse
- if (parseInfo != null)
- {
- parseInfo.ParseResult = parseResult;
- }
-
- if (connInfo != null)
- {
- try
- {
- List parseResults = new List();
- parseResults.Add(parseResult);
- parseInfo.Binder.Bind(
- parseResults,
- connInfo.ConnectionDetails.DatabaseName,
- BindMode.Batch);
- }
- catch (ConnectionException)
- {
- Logger.Write(LogLevel.Error, "Hit connection exception while binding - disposing binder object...");
- }
- catch (SqlParserInternalBinderError)
- {
- Logger.Write(LogLevel.Error, "Hit connection exception while binding - disposing binder object...");
- }
- }
-
- return parseResult;
- }
-
- ///
- /// Gets a list of semantic diagnostic marks for the provided script file
- ///
- ///
- public ScriptFileMarker[] GetSemanticMarkers(ScriptFile scriptFile)
- {
ConnectionInfo connInfo;
ConnectionService.Instance.TryFindConnection(
scriptFile.ClientFilePath,
out connInfo);
-
- var parseResult = ParseAndBind(scriptFile, connInfo);
- // build a list of SQL script file markers from the errors
- List markers = new List();
- foreach (var error in parseResult.Errors)
- {
- markers.Add(new ScriptFileMarker()
- {
- Message = error.Message,
- Level = ScriptFileMarkerLevel.Error,
- ScriptRegion = new ScriptRegion()
- {
- File = scriptFile.FilePath,
- StartLineNumber = error.Start.LineNumber,
- StartColumnNumber = error.Start.ColumnNumber,
- StartOffset = 0,
- EndLineNumber = error.End.LineNumber,
- EndColumnNumber = error.End.ColumnNumber,
- EndOffset = 0
- }
- });
- }
+ var completionItems = Instance.GetCompletionItems(
+ textDocumentPosition, scriptFile, connInfo);
- return markers.ToArray();
+ await requestContext.SendResult(completionItems);
}
- #endregion
-
- #region Request Handlers
-
private static async Task HandleDefinitionRequest(
TextDocumentPosition textDocumentPosition,
RequestContext requestContext)
@@ -261,22 +245,6 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
await Task.FromResult(true);
}
- private static async Task HandleDocumentSymbolRequest(
- DocumentSymbolParams documentSymbolParams,
- RequestContext requestContext)
- {
- Logger.Write(LogLevel.Verbose, "HandleDocumentSymbolRequest");
- await Task.FromResult(true);
- }
-
- private static async Task HandleWorkspaceSymbolRequest(
- WorkspaceSymbolParams workspaceSymbolParams,
- RequestContext requestContext)
- {
- Logger.Write(LogLevel.Verbose, "HandleWorkspaceSymbolRequest");
- await Task.FromResult(true);
- }
-
#endregion
#region Handlers for Events from Other Services
@@ -336,7 +304,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
foreach (var scriptFile in WorkspaceService.Instance.Workspace.GetOpenedFiles())
{
- await PublishScriptDiagnostics(scriptFile, emptyAnalysisDiagnostics, eventContext);
+ await DiagnosticsHelper.PublishScriptDiagnostics(scriptFile, emptyAnalysisDiagnostics, eventContext);
}
}
else
@@ -352,7 +320,269 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
#endregion
- #region Private Helpers
+
+ #region "AutoComplete Provider methods"
+
+ ///
+ /// Remove a reference to an autocomplete cache from a URI. If
+ /// it is the last URI connected to a particular connection,
+ /// then remove the cache.
+ ///
+ public async Task RemoveAutoCompleteCacheUriReference(ConnectionSummary summary)
+ {
+ // currently this method is disabled, but we need to reimplement now that the
+ // implementation of the 'cache' has changed.
+ await Task.FromResult(0);
+ }
+
+ ///
+ /// Parses the SQL text and binds it to the SMO metadata provider if connected
+ ///
+ ///
+ ///
+ ///
+ public ParseResult ParseAndBind(ScriptFile scriptFile, ConnectionInfo connInfo)
+ {
+ ScriptParseInfo parseInfo = null;
+ if (this.ScriptParseInfoMap.ContainsKey(scriptFile.ClientFilePath))
+ {
+ parseInfo = this.ScriptParseInfoMap[scriptFile.ClientFilePath];
+ }
+ else
+ {
+ parseInfo = new ScriptParseInfo();
+ this.ScriptParseInfoMap.Add(scriptFile.ClientFilePath, parseInfo);
+ }
+
+ if (parseInfo.BuildingMetadataEvent.WaitOne(LanguageService.FindCompletionsTimeout))
+ {
+ try
+ {
+ parseInfo.BuildingMetadataEvent.Reset();
+
+ // parse current SQL file contents to retrieve a list of errors
+ ParseResult parseResult = Parser.IncrementalParse(
+ scriptFile.Contents,
+ parseInfo.ParseResult,
+ parseInfo.ParseOptions);
+
+ parseInfo.ParseResult = parseResult;
+
+ if (connInfo != null && parseInfo.IsConnected)
+ {
+ try
+ {
+ List parseResults = new List();
+ parseResults.Add(parseResult);
+ parseInfo.Binder.Bind(
+ parseResults,
+ connInfo.ConnectionDetails.DatabaseName,
+ BindMode.Batch);
+ }
+ catch (ConnectionException)
+ {
+ Logger.Write(LogLevel.Error, "Hit connection exception while binding - disposing binder object...");
+ }
+ catch (SqlParserInternalBinderError)
+ {
+ Logger.Write(LogLevel.Error, "Hit connection exception while binding - disposing binder object...");
+ }
+ }
+ }
+ finally
+ {
+ parseInfo.BuildingMetadataEvent.Set();
+ }
+ }
+
+ return parseInfo.ParseResult;
+ }
+
+ ///
+ /// Update the cached autocomplete candidate list when the user connects to a database
+ ///
+ ///
+ public async Task UpdateLanguageServiceOnConnection(ConnectionInfo info)
+ {
+ await Task.Run( () =>
+ {
+ if (ShouldEnableAutocomplete())
+ {
+ ScriptParseInfo scriptInfo =
+ this.ScriptParseInfoMap.ContainsKey(info.OwnerUri)
+ ? this.ScriptParseInfoMap[info.OwnerUri]
+ : new ScriptParseInfo();
+
+ try
+ {
+ scriptInfo.BuildingMetadataEvent.WaitOne(LanguageService.OnConnectionWaitTimeout);
+ scriptInfo.BuildingMetadataEvent.Reset();
+
+ var sqlConn = info.SqlConnection as ReliableSqlConnection;
+ if (sqlConn != null)
+ {
+ ServerConnection serverConn = new ServerConnection(sqlConn.GetUnderlyingConnection());
+ scriptInfo.MetadataDisplayInfoProvider = new MetadataDisplayInfoProvider();
+ scriptInfo.MetadataProvider = SmoMetadataProvider.CreateConnectedProvider(serverConn);
+ scriptInfo.Binder = BinderProvider.CreateBinder(scriptInfo.MetadataProvider);
+ scriptInfo.ServerConnection = new ServerConnection(sqlConn.GetUnderlyingConnection());
+ this.ScriptParseInfoMap[info.OwnerUri] = scriptInfo;
+ }
+ }
+ catch (Exception)
+ {
+ scriptInfo.IsConnected = false;
+ }
+ finally
+ {
+ // Set Metadata Build event to Signal state.
+ // (Tell Language Service that I am ready with Metadata Provider Object)
+ scriptInfo.BuildingMetadataEvent.Set();
+ }
+
+ if (scriptInfo.IsConnected)
+ {
+ var scriptFile = WorkspaceService.Instance.Workspace.GetFile(info.OwnerUri);
+ ParseAndBind(scriptFile, info);
+ }
+ }
+ });
+ }
+
+ ///
+ /// Determines whether a reparse and bind is required to provide autocomplete
+ ///
+ ///
+ private bool RequiresReparse(ScriptParseInfo info, ScriptFile scriptFile)
+ {
+ if (info.ParseResult == null)
+ {
+ return true;
+ }
+
+ string prevSqlText = info.ParseResult.Script.Sql;
+ string currentSqlText = scriptFile.Contents;
+
+ return prevSqlText.Length != currentSqlText.Length
+ || !string.Equals(prevSqlText, currentSqlText);
+ }
+
+ ///
+ /// Return the completion item list for the current text position.
+ /// This method does not await cache builds since it expects to return quickly
+ ///
+ ///
+ public CompletionItem[] GetCompletionItems(
+ TextDocumentPosition textDocumentPosition,
+ ScriptFile scriptFile,
+ ConnectionInfo connInfo)
+ {
+ string filePath = textDocumentPosition.TextDocument.Uri;
+ int startLine = textDocumentPosition.Position.Line;
+ int startColumn = TextUtilities.PositionOfPrevDelimeter(
+ scriptFile.Contents,
+ textDocumentPosition.Position.Line,
+ textDocumentPosition.Position.Character);
+ int endColumn = textDocumentPosition.Position.Character;
+
+ // Take a reference to the list at a point in time in case we update and replace the list
+ if (connInfo == null
+ || !LanguageService.Instance.ScriptParseInfoMap.ContainsKey(textDocumentPosition.TextDocument.Uri))
+ {
+ return AutoCompleteHelper.GetDefaultCompletionItems(startLine, startColumn, endColumn);
+ }
+
+ // reparse and bind the SQL statement if needed
+ var scriptParseInfo = ScriptParseInfoMap[textDocumentPosition.TextDocument.Uri];
+ if (RequiresReparse(scriptParseInfo, scriptFile))
+ {
+ ParseAndBind(scriptFile, connInfo);
+ }
+
+ if (scriptParseInfo.ParseResult == null)
+ {
+ return AutoCompleteHelper.GetDefaultCompletionItems(startLine, startColumn, endColumn);
+ }
+
+ if (scriptParseInfo.IsConnected
+ && scriptParseInfo.BuildingMetadataEvent.WaitOne(LanguageService.FindCompletionStartTimeout))
+ {
+ scriptParseInfo.BuildingMetadataEvent.Reset();
+ Task findCompletionsTask = Task.Run(() => {
+ try
+ {
+ // get the completion list from SQL Parser
+ var suggestions = Resolver.FindCompletions(
+ scriptParseInfo.ParseResult,
+ textDocumentPosition.Position.Line + 1,
+ textDocumentPosition.Position.Character + 1,
+ scriptParseInfo.MetadataDisplayInfoProvider);
+
+ // convert the suggestion list to the VS Code format
+ return AutoCompleteHelper.ConvertDeclarationsToCompletionItems(
+ suggestions,
+ startLine,
+ startColumn,
+ endColumn);
+ }
+ finally
+ {
+ scriptParseInfo.BuildingMetadataEvent.Set();
+ }
+ });
+
+ findCompletionsTask.Wait(LanguageService.FindCompletionsTimeout);
+ if (findCompletionsTask.IsCompleted
+ && findCompletionsTask.Result != null
+ && findCompletionsTask.Result.Length > 0)
+ {
+ return findCompletionsTask.Result;
+ }
+ }
+
+ return AutoCompleteHelper.GetDefaultCompletionItems(startLine, startColumn, endColumn);
+ }
+
+ #endregion
+
+ #region Diagnostic Provider methods
+
+ ///
+ /// Gets a list of semantic diagnostic marks for the provided script file
+ ///
+ ///
+ internal ScriptFileMarker[] GetSemanticMarkers(ScriptFile scriptFile)
+ {
+ ConnectionInfo connInfo;
+ ConnectionService.Instance.TryFindConnection(
+ scriptFile.ClientFilePath,
+ out connInfo);
+
+ var parseResult = ParseAndBind(scriptFile, connInfo);
+
+ // build a list of SQL script file markers from the errors
+ List markers = new List();
+ foreach (var error in parseResult.Errors)
+ {
+ markers.Add(new ScriptFileMarker()
+ {
+ Message = error.Message,
+ Level = ScriptFileMarkerLevel.Error,
+ ScriptRegion = new ScriptRegion()
+ {
+ File = scriptFile.FilePath,
+ StartLineNumber = error.Start.LineNumber,
+ StartColumnNumber = error.Start.ColumnNumber,
+ StartOffset = 0,
+ EndLineNumber = error.End.LineNumber,
+ EndColumnNumber = error.End.ColumnNumber,
+ EndOffset = 0
+ }
+ });
+ }
+
+ return markers.ToArray();
+ }
///
/// Runs script diagnostics on changed files
@@ -401,7 +631,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
Task.Factory.StartNew(
() =>
DelayThenInvokeDiagnostics(
- 750,
+ LanguageService.DiagnosticParseDelay,
filesToAnalyze,
eventContext,
ExistingRequestCancellation.Token),
@@ -451,85 +681,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
ScriptFileMarker[] semanticMarkers = GetSemanticMarkers(scriptFile);
Logger.Write(LogLevel.Verbose, "Analysis complete.");
- await PublishScriptDiagnostics(scriptFile, semanticMarkers, eventContext);
- }
- }
-
- ///
- /// Send the diagnostic results back to the host application
- ///
- ///
- ///
- ///
- private static async Task PublishScriptDiagnostics(
- ScriptFile scriptFile,
- ScriptFileMarker[] semanticMarkers,
- EventContext eventContext)
- {
- var allMarkers = scriptFile.SyntaxMarkers != null
- ? scriptFile.SyntaxMarkers.Concat(semanticMarkers)
- : semanticMarkers;
-
- // Always send syntax and semantic errors. We want to
- // make sure no out-of-date markers are being displayed.
- await eventContext.SendEvent(
- PublishDiagnosticsNotification.Type,
- new PublishDiagnosticsNotification
- {
- Uri = scriptFile.ClientFilePath,
- Diagnostics =
- allMarkers
- .Select(GetDiagnosticFromMarker)
- .ToArray()
- });
- }
-
- ///
- /// Convert a ScriptFileMarker to a Diagnostic that is Language Service compatible
- ///
- ///
- ///
- private static Diagnostic GetDiagnosticFromMarker(ScriptFileMarker scriptFileMarker)
- {
- return new Diagnostic
- {
- Severity = MapDiagnosticSeverity(scriptFileMarker.Level),
- Message = scriptFileMarker.Message,
- Range = new Range
- {
- Start = new Position
- {
- Line = scriptFileMarker.ScriptRegion.StartLineNumber - 1,
- Character = scriptFileMarker.ScriptRegion.StartColumnNumber - 1
- },
- End = new Position
- {
- Line = scriptFileMarker.ScriptRegion.EndLineNumber - 1,
- Character = scriptFileMarker.ScriptRegion.EndColumnNumber - 1
- }
- }
- };
- }
-
- ///
- /// Map ScriptFileMarker severity to Diagnostic severity
- ///
- ///
- private static DiagnosticSeverity MapDiagnosticSeverity(ScriptFileMarkerLevel markerLevel)
- {
- switch (markerLevel)
- {
- case ScriptFileMarkerLevel.Error:
- return DiagnosticSeverity.Error;
-
- case ScriptFileMarkerLevel.Warning:
- return DiagnosticSeverity.Warning;
-
- case ScriptFileMarkerLevel.Information:
- return DiagnosticSeverity.Information;
-
- default:
- return DiagnosticSeverity.Error;
+ await DiagnosticsHelper.PublishScriptDiagnostics(scriptFile, semanticMarkers, eventContext);
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs
index 4da2c57e..48fb2cce 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs
@@ -3,10 +3,14 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
+using Microsoft.SqlServer.Management.Common;
+using Microsoft.SqlServer.Management.SmoMetadataProvider;
using Microsoft.SqlServer.Management.SqlParser.Binder;
+using Microsoft.SqlServer.Management.SqlParser.Common;
using Microsoft.SqlServer.Management.SqlParser.MetadataProvider;
using Microsoft.SqlServer.Management.SqlParser.Parser;
-using Microsoft.SqlServer.Management.SmoMetadataProvider;
+using System;
+using System.Threading;
namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
{
@@ -15,6 +19,110 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
///
internal class ScriptParseInfo
{
+ private ManualResetEvent buildingMetadataEvent = new ManualResetEvent(initialState: true);
+
+ private ParseOptions parseOptions = new ParseOptions();
+
+ private ServerConnection serverConnection;
+
+ ///
+ /// Event which tells if MetadataProvider is built fully or not
+ ///
+ public ManualResetEvent BuildingMetadataEvent
+ {
+ get { return this.buildingMetadataEvent; }
+ }
+
+
+ ///
+ /// Gets or sets a flag determining is the LanguageService is connected
+ ///
+ public bool IsConnected { get; set; }
+
+ ///
+ /// Gets or sets the LanguageService SMO ServerConnection
+ ///
+ public ServerConnection ServerConnection
+ {
+ get
+ {
+ return this.serverConnection;
+ }
+ set
+ {
+ this.serverConnection = value;
+ this.parseOptions = new ParseOptions(
+ batchSeparator: LanguageService.DefaultBatchSeperator,
+ isQuotedIdentifierSet: true,
+ compatibilityLevel: DatabaseCompatibilityLevel,
+ transactSqlVersion: TransactSqlVersion);
+ this.IsConnected = true;
+ }
+ }
+
+ ///
+ /// Gets the Language Service ServerVersion
+ ///
+ public ServerVersion ServerVersion
+ {
+ get
+ {
+ return this.ServerConnection != null
+ ? this.ServerConnection.ServerVersion
+ : null;
+ }
+ }
+
+ ///
+ /// Gets the current DataEngineType
+ ///
+ public DatabaseEngineType DatabaseEngineType
+ {
+ get
+ {
+ return this.ServerConnection != null
+ ? this.ServerConnection.DatabaseEngineType
+ : DatabaseEngineType.Standalone;
+ }
+ }
+
+ ///
+ /// Gets the current connections TransactSqlVersion
+ ///
+ public TransactSqlVersion TransactSqlVersion
+ {
+ get
+ {
+ return this.IsConnected
+ ? GetTransactSqlVersion(this.ServerVersion)
+ : TransactSqlVersion.Current;
+ }
+ }
+
+ ///
+ /// Gets the current DatabaseCompatibilityLevel
+ ///
+ public DatabaseCompatibilityLevel DatabaseCompatibilityLevel
+ {
+ get
+ {
+ return this.IsConnected
+ ? GetDatabaseCompatibilityLevel(this.ServerVersion)
+ : DatabaseCompatibilityLevel.Current;
+ }
+ }
+
+ ///
+ /// Gets the current ParseOptions
+ ///
+ public ParseOptions ParseOptions
+ {
+ get
+ {
+ return this.parseOptions;
+ }
+ }
+
///
/// Gets or sets the SMO binder for schema-aware intellisense
///
@@ -28,13 +136,63 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
///
/// Gets or set the SMO metadata provider that's bound to the current connection
///
- ///
public SmoMetadataProvider MetadataProvider { get; set; }
///
/// Gets or sets the SMO metadata display info provider
///
- ///
public MetadataDisplayInfoProvider MetadataDisplayInfoProvider { get; set; }
+
+ ///
+ /// Gets the database compatibility level from a server version
+ ///
+ ///
+ private static DatabaseCompatibilityLevel GetDatabaseCompatibilityLevel(ServerVersion serverVersion)
+ {
+ int versionMajor = Math.Max(serverVersion.Major, 8);
+
+ switch (versionMajor)
+ {
+ case 8:
+ return DatabaseCompatibilityLevel.Version80;
+ case 9:
+ return DatabaseCompatibilityLevel.Version90;
+ case 10:
+ return DatabaseCompatibilityLevel.Version100;
+ case 11:
+ return DatabaseCompatibilityLevel.Version110;
+ case 12:
+ return DatabaseCompatibilityLevel.Version120;
+ case 13:
+ return DatabaseCompatibilityLevel.Version130;
+ default:
+ return DatabaseCompatibilityLevel.Current;
+ }
+ }
+
+ ///
+ /// Gets the transaction sql version from a server version
+ ///
+ ///
+ private static TransactSqlVersion GetTransactSqlVersion(ServerVersion serverVersion)
+ {
+ int versionMajor = Math.Max(serverVersion.Major, 9);
+
+ switch (versionMajor)
+ {
+ case 9:
+ case 10:
+ // In case of 10.0 we still use Version 10.5 as it is the closest available.
+ return TransactSqlVersion.Version105;
+ case 11:
+ return TransactSqlVersion.Version110;
+ case 12:
+ return TransactSqlVersion.Version120;
+ case 13:
+ return TransactSqlVersion.Version130;
+ default:
+ return TransactSqlVersion.Current;
+ }
+ }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Program.cs b/src/Microsoft.SqlTools.ServiceLayer/Program.cs
index f0d2d6e8..377afe9d 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Program.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Program.cs
@@ -45,7 +45,6 @@ namespace Microsoft.SqlTools.ServiceLayer
// Initialize the services that will be hosted here
WorkspaceService.Instance.InitializeService(serviceHost);
- AutoCompleteService.Instance.InitializeService(serviceHost);
LanguageService.Instance.InitializeService(serviceHost, sqlToolsContext);
ConnectionService.Instance.InitializeService(serviceHost);
CredentialService.Instance.InitializeService(serviceHost);
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs
index 69250afe..98e857ca 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs
@@ -11,6 +11,7 @@ using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.EditorServices.Utility;
+using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
@@ -137,10 +138,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
// Register the message listener to *this instance* of the batch
// Note: This is being done to associate messages with batches
- SqlConnection sqlConn = conn as SqlConnection;
+ ReliableSqlConnection sqlConn = conn as ReliableSqlConnection;
if (sqlConn != null)
{
- sqlConn.InfoMessage += StoreDbMessage;
+ sqlConn.GetUnderlyingConnection().InfoMessage += StoreDbMessage;
}
// Create a command that we'll use for executing the query
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs
index e80eada5..d7d43248 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs
@@ -148,7 +148,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts
}
else
{
- DataType = DataType;
+ DataType = column.DataType;
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs
new file mode 100644
index 00000000..721d13c9
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs
@@ -0,0 +1,96 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts
+{
+ ///
+ /// Parameters for the save results request
+ ///
+ public class SaveResultsRequestParams
+ {
+ ///
+ /// The path of the file to save results in
+ ///
+ public string FilePath { get; set; }
+
+ ///
+ /// Index of the batch to get the results from
+ ///
+ public int BatchIndex { get; set; }
+
+ ///
+ /// Index of the result set to get the results from
+ ///
+ public int ResultSetIndex { get; set; }
+
+ ///
+ /// URI for the editor that called save results
+ ///
+ public string OwnerUri { get; set; }
+ }
+
+ ///
+ /// Parameters to save results as CSV
+ ///
+ public class SaveResultsAsCsvRequestParams: SaveResultsRequestParams{
+
+ ///
+ /// CSV - Write values in quotes
+ ///
+ public Boolean ValueInQuotes { get; set; }
+
+ ///
+ /// The encoding of the file to save results in
+ ///
+ public string FileEncoding { get; set; }
+
+ ///
+ /// Include headers of columns in CSV
+ ///
+ public bool IncludeHeaders { get; set; }
+ }
+
+ ///
+ /// Parameters to save results as JSON
+ ///
+ public class SaveResultsAsJsonRequestParams: SaveResultsRequestParams{
+ //TODO: define config for save as JSON
+ }
+
+ ///
+ /// Parameters for the save results result
+ ///
+ public class SaveResultRequestResult
+ {
+ ///
+ /// Error messages for saving to file.
+ ///
+ public string Messages { get; set; }
+ }
+
+ ///
+ /// Request type to save results as CSV
+ ///
+ public class SaveResultsAsCsvRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("query/saveCsv");
+ }
+
+ ///
+ /// Request type to save results as JSON
+ ///
+ public class SaveResultsAsJsonRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("query/saveJson");
+ }
+
+}
\ No newline at end of file
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs
index 0cfc2466..b6c23349 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs
@@ -14,13 +14,10 @@ using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
{
///
- /// Reader for SSMS formatted file streams
+ /// Reader for service buffer formatted file streams
///
public class ServiceBufferFileStreamReader : IFileStreamReader
{
- // Most of this code is based on code from the Microsoft.SqlServer.Management.UI.Grid, SSMS DataStorage
- // $\Data Tools\SSMS_XPlat\sql\ssms\core\DataStorage\src\FileStreamReader.cs
-
private const int DefaultBufferSize = 8192;
#region Member Variables
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs
index d0a1c2a9..c978bade 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs
@@ -14,13 +14,10 @@ using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
{
///
- /// Writer for SSMS formatted file streams
+ /// Writer for service buffer formatted file streams
///
public class ServiceBufferFileStreamWriter : IFileStreamWriter
{
- // Most of this code is based on code from the Microsoft.SqlServer.Management.UI.Grid, SSMS DataStorage
- // $\Data Tools\SSMS_XPlat\sql\ssms\core\DataStorage\src\FileStreamWriter.cs
-
#region Properties
public const int DefaultBufferLength = 8192;
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
index 2da2a15d..74193d8b 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
@@ -10,6 +10,7 @@ using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlServer.Management.SqlParser.Parser;
using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
@@ -192,11 +193,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
await conn.OpenAsync();
- SqlConnection sqlConn = conn as SqlConnection;
+ ReliableSqlConnection sqlConn = conn as ReliableSqlConnection;
if (sqlConn != null)
{
// Subscribe to database informational messages
- sqlConn.InfoMessage += OnInfoMessage;
+ sqlConn.GetUnderlyingConnection().InfoMessage += OnInfoMessage;
}
// We need these to execute synchronously, otherwise the user will be very unhappy
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
index 97d89fc9..24d4de09 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
@@ -2,9 +2,10 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-
using System;
using System.Collections.Concurrent;
+using System.IO;
+using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Hosting;
@@ -13,6 +14,7 @@ using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;
+using Newtonsoft.Json;
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
@@ -98,6 +100,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
serviceHost.SetRequestHandler(QueryExecuteSubsetRequest.Type, HandleResultSubsetRequest);
serviceHost.SetRequestHandler(QueryDisposeRequest.Type, HandleDisposeRequest);
serviceHost.SetRequestHandler(QueryCancelRequest.Type, HandleCancelRequest);
+ serviceHost.SetRequestHandler(SaveResultsAsCsvRequest.Type, HandleSaveResultsAsCsvRequest);
+ serviceHost.SetRequestHandler(SaveResultsAsJsonRequest.Type, HandleSaveResultsAsJsonRequest);
// Register handler for shutdown event
serviceHost.RegisterShutdownTask((shutdownParams, requestContext) =>
@@ -256,6 +260,124 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
}
+ ///
+ /// Process request to save a resultSet to a file in CSV format
+ ///
+ public async Task HandleSaveResultsAsCsvRequest( SaveResultsAsCsvRequestParams saveParams,
+ RequestContext requestContext)
+ {
+ // retrieve query for OwnerUri
+ Query result;
+ if (!ActiveQueries.TryGetValue(saveParams.OwnerUri, out result))
+ {
+ await requestContext.SendResult(new SaveResultRequestResult
+ {
+ Messages = "Failed to save results, ID not found."
+ });
+ return;
+ }
+ try
+ {
+ using (StreamWriter csvFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create)))
+ {
+ // get the requested resultSet from query
+ Batch selectedBatch = result.Batches[saveParams.BatchIndex];
+ ResultSet selectedResultSet = (selectedBatch.ResultSets.ToList())[saveParams.ResultSetIndex];
+ if (saveParams.IncludeHeaders)
+ {
+ // write column names to csv
+ await csvFile.WriteLineAsync( string.Join( ",", selectedResultSet.Columns.Select( column => SaveResults.EncodeCsvField(column.ColumnName) ?? string.Empty)));
+ }
+
+ // write rows to csv
+ foreach (var row in selectedResultSet.Rows)
+ {
+ await csvFile.WriteLineAsync( string.Join( ",", row.Select( field => SaveResults.EncodeCsvField((field != null) ? field.ToString(): string.Empty))));
+ }
+ }
+ }
+ catch(Exception ex)
+ {
+ // Delete file when exception occurs
+ if (File.Exists(saveParams.FilePath))
+ {
+ File.Delete(saveParams.FilePath);
+ }
+ await requestContext.SendError(ex.Message);
+ return;
+ }
+ await requestContext.SendResult(new SaveResultRequestResult
+ {
+ Messages = "Success"
+ });
+ return;
+ }
+
+ ///
+ /// Process request to save a resultSet to a file in JSON format
+ ///
+ public async Task HandleSaveResultsAsJsonRequest( SaveResultsAsJsonRequestParams saveParams,
+ RequestContext requestContext)
+ {
+ // retrieve query for OwnerUri
+ Query result;
+ if (!ActiveQueries.TryGetValue(saveParams.OwnerUri, out result))
+ {
+ await requestContext.SendResult(new SaveResultRequestResult
+ {
+ Messages = "Failed to save results, ID not found."
+ });
+ return;
+ }
+ try
+ {
+ using (StreamWriter jsonFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create)))
+ using (JsonWriter jsonWriter = new JsonTextWriter(jsonFile) )
+ {
+ jsonWriter.Formatting = Formatting.Indented;
+ jsonWriter.WriteStartArray();
+
+ // get the requested resultSet from query
+ Batch selectedBatch = result.Batches[saveParams.BatchIndex];
+ ResultSet selectedResultSet = (selectedBatch.ResultSets.ToList())[saveParams.ResultSetIndex];
+
+ // write each row to JSON
+ foreach (var row in selectedResultSet.Rows)
+ {
+ jsonWriter.WriteStartObject();
+ foreach (var field in row.Select((value,i) => new {value, i}))
+ {
+ jsonWriter.WritePropertyName(selectedResultSet.Columns[field.i].ColumnName);
+ if (field.value != null)
+ {
+ jsonWriter.WriteValue(field.value);
+ }
+ else
+ {
+ jsonWriter.WriteNull();
+ }
+ }
+ jsonWriter.WriteEndObject();
+ }
+ jsonWriter.WriteEndArray();
+ }
+ }
+ catch(Exception ex)
+ {
+ // Delete file when exception occurs
+ if (File.Exists(saveParams.FilePath))
+ {
+ File.Delete(saveParams.FilePath);
+ }
+ await requestContext.SendError(ex.Message);
+ return;
+ }
+ await requestContext.SendResult(new SaveResultRequestResult
+ {
+ Messages = "Success"
+ });
+ return;
+ }
#endregion
#region Private Helpers
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs
index 84e18c99..41978255 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs
@@ -111,6 +111,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
public long RowCount { get; private set; }
+ ///
+ /// The rows of this result set
+ ///
+ public IEnumerable