diff --git a/.gitignore b/.gitignore index 4c997e2b..de0fdc5b 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ project.lock.json *.user *.userosscache *.sln.docstates +*.exe # Build results [Dd]ebug/ @@ -29,7 +30,13 @@ msbuild.log msbuild.err msbuild.wrn - +# code coverage artifacts +coverage.xml +node_modules +packages +reports +opencovertests.xml +sqltools.xml # Cross building rootfs cross/rootfs/ diff --git a/BUILD.md b/BUILD.md new file mode 100644 index 00000000..603d42b4 --- /dev/null +++ b/BUILD.md @@ -0,0 +1,75 @@ +# Usage + +Run `build.(ps1|sh)` with the desired set of arguments (see below for options). +The build script itself is `build.cake`, written in C# using the Cake build automation system. +All build related activites should be encapsulated in this file for cross-platform access. + +# Arguments + +## Primary + + `-target=TargetName`: The name of the build task/target to execute (see below for listing and details). + Defaults to `Default`. + + `-configuration=(Release|Debug)`: The configuration to build. + Defaults to `Release`. + +## Extra + + `-test-configuration=(Release|Debug)`: The configuration to use for the unit tests. + Defaults to `Debug`. + + `-install-path=Path`: Path used for the **Install** target. + Defaults to `(%USERPROFILE%|$HOME)/.sqltoolsservice/local` + + `-archive`: Enable the generation of publishable archives after a build. + +# Targets + +**Default**: Alias for Local. + +**Local**: Full build including testing for the machine-local runtime. + +**All**: Same as local, but targeting all runtimes selected by `PopulateRuntimes` in `build.cake`. + Currently configured to also build for a 32-bit Windows runtime on Windows machines. + No additional runtimes are currently selected on non-Windows machines. + +**Quick**: Local build which skips all testing. + +**Install**: Same as quick, but installs the generated binaries into `install-path`. + +**SetPackageVersions**: Updates the dependency versions found within `project.json` files using information from `depversion.json`. + Used for maintainence within the project, not needed for end-users. More information below. + +# Configuration files + +## build.json + +A number of build-related options, including folder names for different entities. Interesting options: + +**DotNetInstallScriptURL**: The URL where the .NET SDK install script is located. + Can be used to pin to a specific script version, if a breaking change occurs. + +**"DotNetChannel"**: The .NET SDK channel used for retreiving the tools. + +**"DotNetVersion"**: The .NET SDK version used for the build. Can be used to pin to a specific version. + Using the string `Latest` will retrieve the latest version. + +## depversion.json + +A listing of all dependencies (and their desired versions) used by `project.json` files throughout the project. +Allows for quick and automatic updates to the dependency version numbers using the **SetPackageVersions** target. + +# Artifacts generated + +* Binaries of Microsoft.SqlTools.ServiceLayer and its libraries built for the local machine in `artifacts/publish/Microsoft.SqlTools.ServiceLayer/default/{framework}/` +* Scripts to run Microsoft.SqlTools.ServiceLayer at `scripts/SQLTOOLSSERVICE(.Core)(.cmd)` + * These scripts are updated for every build and every install. + * The scripts point to the installed binary after and install, otherwise just the build folder (reset if a new build occurs without an install). +* Binaries of Microsoft.SqlTools.ServiceLayer and its libraries cross-compiled for other runtimes (if selected in **PopulateRuntimes**) `artifacts/publish/Microsoft.SqlTools.ServiceLayer/{runtime}/{framework}/` +* Test logs in `artifacts/logs` +* Archived binaries in `artifacts/package` (only if `-archive` used on command line) + +# Requirements + +The build system requires Mono to be installed on non-Windows machines as Cake is not built using .NET Core (yet). diff --git a/build.cake b/build.cake new file mode 100644 index 00000000..61047b05 --- /dev/null +++ b/build.cake @@ -0,0 +1,507 @@ +#addin "Newtonsoft.Json" + +#load "scripts/runhelpers.cake" +#load "scripts/archiving.cake" +#load "scripts/artifacts.cake" + +using System.ComponentModel; +using System.Net; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +// Basic arguments +var target = Argument("target", "Default"); +var configuration = Argument("configuration", "Release"); +// Optional arguments +var testConfiguration = Argument("test-configuration", "Debug"); +var installFolder = Argument("install-path", System.IO.Path.Combine(Environment.GetEnvironmentVariable(IsRunningOnWindows() ? "USERPROFILE" : "HOME"), + ".sqltoolsservice", "local")); +var requireArchive = HasArgument("archive"); + +// Working directory +var workingDirectory = System.IO.Directory.GetCurrentDirectory(); + +// System specific shell configuration +var shell = IsRunningOnWindows() ? "powershell" : "bash"; +var shellArgument = IsRunningOnWindows() ? "-NoProfile /Command" : "-C"; +var shellExtension = IsRunningOnWindows() ? "ps1" : "sh"; + +/// +/// Class representing build.json +/// +public class BuildPlan +{ + public IDictionary TestProjects { get; set; } + public string BuildToolsFolder { get; set; } + public string ArtifactsFolder { get; set; } + public bool UseSystemDotNetPath { get; set; } + public string DotNetFolder { get; set; } + public string DotNetInstallScriptURL { get; set; } + public string DotNetChannel { get; set; } + public string DotNetVersion { get; set; } + public string[] Frameworks { get; set; } + public string[] Rids { get; set; } + public string MainProject { get; set; } +} + +var buildPlan = JsonConvert.DeserializeObject( + System.IO.File.ReadAllText(System.IO.Path.Combine(workingDirectory, "build.json"))); + +// Folders and tools +var dotnetFolder = System.IO.Path.Combine(workingDirectory, buildPlan.DotNetFolder); +var dotnetcli = buildPlan.UseSystemDotNetPath ? "dotnet" : System.IO.Path.Combine(System.IO.Path.GetFullPath(dotnetFolder), "dotnet"); +var toolsFolder = System.IO.Path.Combine(workingDirectory, buildPlan.BuildToolsFolder); + +var sourceFolder = System.IO.Path.Combine(workingDirectory, "src"); +var testFolder = System.IO.Path.Combine(workingDirectory, "test"); + +var artifactFolder = System.IO.Path.Combine(workingDirectory, buildPlan.ArtifactsFolder); +var publishFolder = System.IO.Path.Combine(artifactFolder, "publish"); +var logFolder = System.IO.Path.Combine(artifactFolder, "logs"); +var packageFolder = System.IO.Path.Combine(artifactFolder, "package"); +var scriptFolder = System.IO.Path.Combine(artifactFolder, "scripts"); + +/// +/// Clean artifacts. +/// +Task("Cleanup") + .Does(() => +{ + if (System.IO.Directory.Exists(artifactFolder)) + { + System.IO.Directory.Delete(artifactFolder, true); + } + System.IO.Directory.CreateDirectory(artifactFolder); + System.IO.Directory.CreateDirectory(logFolder); + System.IO.Directory.CreateDirectory(packageFolder); + System.IO.Directory.CreateDirectory(scriptFolder); +}); + +/// +/// Pre-build setup tasks. +/// +Task("Setup") + .IsDependentOn("BuildEnvironment") + .IsDependentOn("PopulateRuntimes") + .Does(() => +{ +}); + +/// +/// Populate the RIDs for the specific environment. +/// Use default RID (+ win7-x86 on Windows) for now. +/// +Task("PopulateRuntimes") + .IsDependentOn("BuildEnvironment") + .Does(() => +{ + buildPlan.Rids = new string[] + { + "default", // To allow testing the published artifact + "win7-x64", + "win7-x86", + "ubuntu.14.04-x64", + "ubuntu.16.04-x64", + "centos.7-x64", + "rhel.7.2-x64", + "debian.8-x64", + "fedora.23-x64", + "opensuse.13.2-x64", + "osx.10.11-x64" + }; +}); + +/// +/// Install/update build environment. +/// +Task("BuildEnvironment") + .Does(() => +{ + var installScript = $"dotnet-install.{shellExtension}"; + System.IO.Directory.CreateDirectory(dotnetFolder); + var scriptPath = System.IO.Path.Combine(dotnetFolder, installScript); + using (WebClient client = new WebClient()) + { + client.DownloadFile($"{buildPlan.DotNetInstallScriptURL}/{installScript}", scriptPath); + } + if (!IsRunningOnWindows()) + { + Run("chmod", $"+x '{scriptPath}'"); + } + var installArgs = $"-Channel {buildPlan.DotNetChannel}"; + if (!String.IsNullOrEmpty(buildPlan.DotNetVersion)) + { + installArgs = $"{installArgs} -Version {buildPlan.DotNetVersion}"; + } + if (!buildPlan.UseSystemDotNetPath) + { + installArgs = $"{installArgs} -InstallDir {dotnetFolder}"; + } + Run(shell, $"{shellArgument} {scriptPath} {installArgs}"); + try + { + Run(dotnetcli, "--info"); + } + catch (Win32Exception) + { + throw new Exception(".NET CLI binary cannot be found."); + } + + System.IO.Directory.CreateDirectory(toolsFolder); + + var nugetPath = Environment.GetEnvironmentVariable("NUGET_EXE"); + var arguments = $"install xunit.runner.console -ExcludeVersion -NoCache -Prerelease -OutputDirectory \"{toolsFolder}\""; + if (IsRunningOnWindows()) + { + Run(nugetPath, arguments); + } + else + { + Run("mono", $"\"{nugetPath}\" {arguments}"); + } +}); + +/// +/// Restore required NuGet packages. +/// +Task("Restore") + .IsDependentOn("Setup") + .Does(() => +{ + RunRestore(dotnetcli, "restore", sourceFolder) + .ExceptionOnError("Failed to restore projects under source code folder."); + RunRestore(dotnetcli, "restore --infer-runtimes", testFolder) + .ExceptionOnError("Failed to restore projects under test code folder."); +}); + +/// +/// Build Test projects. +/// +Task("BuildTest") + .IsDependentOn("Setup") + .IsDependentOn("Restore") + .Does(() => +{ + foreach (var pair in buildPlan.TestProjects) + { + foreach (var framework in pair.Value) + { + var project = pair.Key; + var projectFolder = System.IO.Path.Combine(testFolder, project); + var runLog = new List(); + Run(dotnetcli, $"build --framework {framework} --configuration {testConfiguration} \"{projectFolder}\"", + new RunOptions + { + StandardOutputListing = runLog + }) + .ExceptionOnError($"Building test {project} failed for {framework}."); + System.IO.File.WriteAllLines(System.IO.Path.Combine(logFolder, $"{project}-{framework}-build.log"), runLog.ToArray()); + } + } +}); + +/// +/// Run all tests for .NET Desktop and .NET Core +/// +Task("TestAll") + .IsDependentOn("Test") + .IsDependentOn("TestCore") + .Does(() =>{}); + +/// +/// Run all tests for Travis CI .NET Desktop and .NET Core +/// +Task("TravisTestAll") + .IsDependentOn("Cleanup") + .IsDependentOn("TestAll") + .Does(() =>{}); + +/// +/// Run tests for .NET Core (using .NET CLI). +/// +Task("TestCore") + .IsDependentOn("Setup") + .IsDependentOn("Restore") + .Does(() => +{ + var testProjects = buildPlan.TestProjects + .Where(pair => pair.Value.Any(framework => framework.Contains("netcoreapp"))) + .Select(pair => pair.Key) + .ToList(); + + foreach (var testProject in testProjects) + { + var logFile = System.IO.Path.Combine(logFolder, $"{testProject}-core-result.xml"); + var testWorkingDir = System.IO.Path.Combine(testFolder, testProject); + Run(dotnetcli, $"test -f netcoreapp1.0 -xml \"{logFile}\" -notrait category=failing", testWorkingDir) + .ExceptionOnError($"Test {testProject} failed for .NET Core."); + } +}); + +/// +/// Run tests for other frameworks (using XUnit2). +/// +Task("Test") + .IsDependentOn("Setup") + .IsDependentOn("BuildTest") + .Does(() => +{ + foreach (var pair in buildPlan.TestProjects) + { + foreach (var framework in pair.Value) + { + // Testing against core happens in TestCore + if (framework.Contains("netcoreapp")) + { + continue; + } + + var project = pair.Key; + var frameworkFolder = System.IO.Path.Combine(testFolder, project, "bin", testConfiguration, framework); + var runtime = System.IO.Directory.GetDirectories(frameworkFolder).First(); + var instanceFolder = System.IO.Path.Combine(frameworkFolder, runtime); + + // Copy xunit executable to test folder to solve path errors + var xunitToolsFolder = System.IO.Path.Combine(toolsFolder, "xunit.runner.console", "tools"); + var xunitInstancePath = System.IO.Path.Combine(instanceFolder, "xunit.console.exe"); + System.IO.File.Copy(System.IO.Path.Combine(xunitToolsFolder, "xunit.console.exe"), xunitInstancePath, true); + System.IO.File.Copy(System.IO.Path.Combine(xunitToolsFolder, "xunit.runner.utility.desktop.dll"), System.IO.Path.Combine(instanceFolder, "xunit.runner.utility.desktop.dll"), true); + var targetPath = System.IO.Path.Combine(instanceFolder, $"{project}.dll"); + var logFile = System.IO.Path.Combine(logFolder, $"{project}-{framework}-result.xml"); + var arguments = $"\"{targetPath}\" -parallel none -xml \"{logFile}\" -notrait category=failing"; + if (IsRunningOnWindows()) + { + Run(xunitInstancePath, arguments, instanceFolder) + .ExceptionOnError($"Test {project} failed for {framework}"); + } + else + { + Run("mono", $"\"{xunitInstancePath}\" {arguments}", instanceFolder) + .ExceptionOnError($"Test {project} failed for {framework}"); + } + } + } +}); + +/// +/// Build, publish and package artifacts. +/// Targets all RIDs specified in build.json unless restricted by RestrictToLocalRuntime. +/// No dependencies on other tasks to support quick builds. +/// +Task("OnlyPublish") + .IsDependentOn("Setup") + .Does(() => +{ + var project = buildPlan.MainProject; + var projectFolder = System.IO.Path.Combine(sourceFolder, project); + foreach (var framework in buildPlan.Frameworks) + { + foreach (var runtime in buildPlan.Rids) + { + var outputFolder = System.IO.Path.Combine(publishFolder, project, runtime, framework); + var publishArguments = "publish"; + if (!runtime.Equals("default")) + { + publishArguments = $"{publishArguments} --runtime {runtime}"; + } + publishArguments = $"{publishArguments} --framework {framework} --configuration {configuration}"; + publishArguments = $"{publishArguments} --output \"{outputFolder}\" \"{projectFolder}\""; + Run(dotnetcli, publishArguments) + .ExceptionOnError($"Failed to publish {project} / {framework}"); + + if (requireArchive) + { + Package(runtime, framework, outputFolder, packageFolder, buildPlan.MainProject.ToLower()); + } + } + } + CreateRunScript(System.IO.Path.Combine(publishFolder, project, "default"), scriptFolder); +}); + +/// +/// Alias for OnlyPublish. +/// Targets all RIDs as specified in build.json. +/// +Task("AllPublish") + .IsDependentOn("Restore") + .IsDependentOn("OnlyPublish") + .Does(() => +{ +}); + +/// +/// Restrict the RIDs for the local default. +/// +Task("RestrictToLocalRuntime") + .IsDependentOn("Setup") + .Does(() => +{ + buildPlan.Rids = new string[] {"default"}; +}); + +/// +/// Alias for OnlyPublish. +/// Restricts publishing to local RID. +/// +Task("LocalPublish") + .IsDependentOn("Restore") + .IsDependentOn("RestrictToLocalRuntime") + .IsDependentOn("OnlyPublish") + .Does(() => +{ +}); + +/// +/// Test the published binaries if they start up without errors. +/// Uses builds corresponding to local RID. +/// +Task("TestPublished") + .IsDependentOn("Setup") + .Does(() => +{ + var project = buildPlan.MainProject; + var projectFolder = System.IO.Path.Combine(sourceFolder, project); + var scriptsToTest = new string[] {"SQLTOOLSSERVICE.Core"};//TODO + foreach (var script in scriptsToTest) + { + var scriptPath = System.IO.Path.Combine(scriptFolder, script); + var didNotExitWithError = Run($"{shell}", $"{shellArgument} \"{scriptPath}\" -s \"{projectFolder}\" --stdio", + new RunOptions + { + TimeOut = 10000 + }) + .DidTimeOut; + if (!didNotExitWithError) + { + throw new Exception($"Failed to run {script}"); + } + } +}); + +/// +/// Clean install path. +/// +Task("CleanupInstall") + .Does(() => +{ + if (System.IO.Directory.Exists(installFolder)) + { + System.IO.Directory.Delete(installFolder, true); + } + System.IO.Directory.CreateDirectory(installFolder); +}); + +/// +/// Quick build. +/// +Task("Quick") + .IsDependentOn("Cleanup") + .IsDependentOn("LocalPublish") + .Does(() => +{ +}); + +/// +/// Quick build + install. +/// +Task("Install") + .IsDependentOn("Cleanup") + .IsDependentOn("LocalPublish") + .IsDependentOn("CleanupInstall") + .Does(() => +{ + var project = buildPlan.MainProject; + foreach (var framework in buildPlan.Frameworks) + { + var outputFolder = System.IO.Path.GetFullPath(System.IO.Path.Combine(publishFolder, project, "default", framework)); + var targetFolder = System.IO.Path.GetFullPath(System.IO.Path.Combine(installFolder, framework)); + // Copy all the folders + foreach (var directory in System.IO.Directory.GetDirectories(outputFolder, "*", SearchOption.AllDirectories)) + System.IO.Directory.CreateDirectory(System.IO.Path.Combine(targetFolder, directory.Substring(outputFolder.Length + 1))); + //Copy all the files + foreach (string file in System.IO.Directory.GetFiles(outputFolder, "*", SearchOption.AllDirectories)) + System.IO.File.Copy(file, System.IO.Path.Combine(targetFolder, file.Substring(outputFolder.Length + 1)), true); + } + CreateRunScript(installFolder, scriptFolder); +}); + +/// +/// Full build targeting all RIDs specified in build.json. +/// +Task("All") + .IsDependentOn("Cleanup") + .IsDependentOn("Restore") + .IsDependentOn("TestAll") + .IsDependentOn("AllPublish") + //.IsDependentOn("TestPublished") + .Does(() => +{ +}); + +/// +/// Full build targeting local RID. +/// +Task("Local") + .IsDependentOn("Cleanup") + .IsDependentOn("Restore") + .IsDependentOn("TestAll") + .IsDependentOn("LocalPublish") + // .IsDependentOn("TestPublished") + .Does(() => +{ +}); + +/// +/// Build centered around producing the final artifacts for Travis +/// +/// The tests are run as a different task "TestAll" +/// +Task("Travis") + .IsDependentOn("Cleanup") + .IsDependentOn("Restore") + .IsDependentOn("AllPublish") + // .IsDependentOn("TestPublished") + .Does(() => +{ +}); + +/// +/// Update the package versions within project.json files. +/// Uses depversion.json file as input. +/// +Task("SetPackageVersions") + .Does(() => +{ + var jDepVersion = JObject.Parse(System.IO.File.ReadAllText(System.IO.Path.Combine(workingDirectory, "depversion.json"))); + var projects = System.IO.Directory.GetFiles(sourceFolder, "project.json", SearchOption.AllDirectories).ToList(); + projects.AddRange(System.IO.Directory.GetFiles(testFolder, "project.json", SearchOption.AllDirectories)); + foreach (var project in projects) + { + var jProject = JObject.Parse(System.IO.File.ReadAllText(project)); + var dependencies = jProject.SelectTokens("dependencies") + .Union(jProject.SelectTokens("frameworks.*.dependencies")) + .SelectMany(dependencyToken => dependencyToken.Children()); + foreach (JProperty dependency in dependencies) + { + if (jDepVersion[dependency.Name] != null) + { + dependency.Value = jDepVersion[dependency.Name]; + } + } + System.IO.File.WriteAllText(project, JsonConvert.SerializeObject(jProject, Formatting.Indented)); + } +}); + +/// +/// Default Task aliases to Local. +/// +Task("Default") + .IsDependentOn("Local") + .Does(() => +{ +}); + +/// +/// Default to Local. +/// +RunTarget(target); diff --git a/build.json b/build.json new file mode 100644 index 00000000..a0741723 --- /dev/null +++ b/build.json @@ -0,0 +1,18 @@ +{ + "UseSystemDotNetPath": "true", + "DotNetFolder": ".dotnet", + "DotNetInstallScriptURL": "https://raw.githubusercontent.com/dotnet/cli/rel/1.0.0-preview2/scripts/obtain", + "DotNetChannel": "preview", + "DotNetVersion": "1.0.0-preview2-003121", + "BuildToolsFolder": ".tools", + "ArtifactsFolder": "artifacts", + "TestProjects": { + "Microsoft.SqlTools.ServiceLayer.Test": [ + "netcoreapp1.0" + ] + }, + "Frameworks": [ + "netcoreapp1.0" + ], + "MainProject": "Microsoft.SqlTools.ServiceLayer" +} diff --git a/build.ps1 b/build.ps1 new file mode 100644 index 00000000..68dc2c1e --- /dev/null +++ b/build.ps1 @@ -0,0 +1,3 @@ +$Env:SQLTOOLSSERVICE_PACKAGE_OSNAME = "win-x64" +.\scripts\cake-bootstrap.ps1 -experimental @args +exit $LASTEXITCODE diff --git a/build.sh b/build.sh new file mode 100644 index 00000000..90332a42 --- /dev/null +++ b/build.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# Handle to many files on osx +if [ "$TRAVIS_OS_NAME" == "osx" ] || [ `uname` == "Darwin" ]; then + ulimit -n 4096 +fi + +if [ "$TRAVIS_OS_NAME" == "osx" ] || [ `uname` == "Darwin" ]; then + export SQLTOOLSSERVICE_PACKAGE_OSNAME=osx-x64 +else + export SQLTOOLSSERVICE_PACKAGE_OSNAME=linux-x64 +fi +bash ./scripts/cake-bootstrap.sh "$@" diff --git a/global.json b/global.json index db6ba19b..9ae78d22 100644 --- a/global.json +++ b/global.json @@ -1,5 +1,8 @@ { - "projects": [ "src", "test" ] + "projects": [ "src", "test" ], + "sdk": { + "version": "1.0.0-preview2-003121" + } } diff --git a/nuget.config b/nuget.config new file mode 100644 index 00000000..edd564a3 --- /dev/null +++ b/nuget.config @@ -0,0 +1,6 @@ + + + + + + diff --git a/scripts/archiving.cake b/scripts/archiving.cake new file mode 100644 index 00000000..4b012986 --- /dev/null +++ b/scripts/archiving.cake @@ -0,0 +1,104 @@ +#load "runhelpers.cake" + +using System.IO.Compression; +using System.Text.RegularExpressions; + +/// +/// Generate the build identifier based on the RID and framework identifier. +/// Special rules when running on Travis (for publishing purposes). +/// +/// The RID +/// The framework identifier +/// The designated build identifier +string GetBuildIdentifier(string runtime, string framework) +{ + var runtimeShort = ""; + // Default RID uses package name set in build script + if (runtime.Equals("default")) + { + runtimeShort = Environment.GetEnvironmentVariable("SQLTOOLSSERVICE_PACKAGE_OSNAME"); + } + else + { + // Remove version number. Note: because there are separate versions for Ubuntu 14 and 16, + // we treat Ubuntu as a special case. + if (runtime.StartsWith("ubuntu.14")) + { + runtimeShort = "ubuntu14-x64"; + } + else if (runtime.StartsWith("ubuntu.16")) + { + runtimeShort = "ubuntu16-x64"; + } + else + { + runtimeShort = Regex.Replace(runtime, "(\\d|\\.)*-", "-"); + } + } + + return $"{runtimeShort}-{framework}"; +} + +/// +/// Generate an archive out of the given published folder. +/// Use ZIP for Windows runtimes. +/// Use TAR.GZ for non-Windows runtimes. +/// Use 7z to generate TAR.GZ on Windows if available. +/// +/// The RID +/// The folder containing the files to package +/// The target archive name (without extension) +void DoArchive(string runtime, string contentFolder, string archiveName) +{ + // On all platforms use ZIP for Windows runtimes + if (runtime.Contains("win") || (runtime.Equals("default") && IsRunningOnWindows())) + { + var zipFile = $"{archiveName}.zip"; + Zip(contentFolder, zipFile); + } + // On all platforms use TAR.GZ for Unix runtimes + else + { + var tarFile = $"{archiveName}.tar.gz"; + // Use 7z to create TAR.GZ on Windows + if (IsRunningOnWindows()) + { + var tempFile = $"{archiveName}.tar"; + try + { + Run("7z", $"a \"{tempFile}\"", contentFolder) + .ExceptionOnError($"Tar-ing failed for {contentFolder} {archiveName}"); + Run("7z", $"a \"{tarFile}\" \"{tempFile}\"", contentFolder) + .ExceptionOnError($"Compression failed for {contentFolder} {archiveName}"); + System.IO.File.Delete(tempFile); + } + catch(Win32Exception) + { + Information("Warning: 7z not available on PATH to pack tar.gz results"); + } + } + // Use tar to create TAR.GZ on Unix + else + { + Run("tar", $"czf \"{tarFile}\" .", contentFolder) + .ExceptionOnError($"Compression failed for {contentFolder} {archiveName}"); + } + } +} + +/// +/// Package a given output folder using a build identifier generated from the RID and framework identifier. +/// +/// The RID +/// The framework identifier +/// The folder containing the files to package +/// The destination folder for the archive +/// The project name +void Package(string runtime, string framework, string contentFolder, string packageFolder, string projectName) +{ + var buildIdentifier = GetBuildIdentifier(runtime, framework); + if (buildIdentifier != null) + { + DoArchive(runtime, contentFolder, $"{packageFolder}/{projectName}-{buildIdentifier}"); + } +} \ No newline at end of file diff --git a/scripts/artifacts.cake b/scripts/artifacts.cake new file mode 100644 index 00000000..f448fe3f --- /dev/null +++ b/scripts/artifacts.cake @@ -0,0 +1,43 @@ +#load "runhelpers.cake" + +/// +/// Generate the scripts which target the SQLTOOLSSERVICE binaries. +/// +/// The root folder where the publised (or installed) binaries are located +void CreateRunScript(string outputRoot, string scriptFolder) +{ + if (IsRunningOnWindows()) + { + var coreScript = System.IO.Path.Combine(scriptFolder, "SQLTOOLSSERVICE.Core.cmd"); + var sqlToolsServicePath = System.IO.Path.Combine(System.IO.Path.GetFullPath(outputRoot), "{0}", "SQLTOOLSSERVICE"); + var content = new string[] { + "SETLOCAL", + "", + $"\"{sqlToolsServicePath}\" %*" + }; + if (System.IO.File.Exists(coreScript)) + { + System.IO.File.Delete(coreScript); + } + content[2] = String.Format(content[2], "netcoreapp1.0"); + System.IO.File.WriteAllLines(coreScript, content); + } + else + { + var coreScript = System.IO.Path.Combine(scriptFolder, "SQLTOOLSSERVICE.Core"); + var sqlToolsServicePath = System.IO.Path.Combine(System.IO.Path.GetFullPath(outputRoot), "{1}", "SQLTOOLSSERVICE"); + var content = new string[] { + "#!/bin/bash", + "", + $"{{0}} \"{sqlToolsServicePath}{{2}}\" \"$@\"" + }; + + if (System.IO.File.Exists(coreScript)) + { + System.IO.File.Delete(coreScript); + } + content[2] = String.Format(content[2], "", "netcoreapp1.0", ""); + System.IO.File.WriteAllLines(coreScript, content); + Run("chmod", $"+x \"{coreScript}\""); + } +} \ No newline at end of file diff --git a/scripts/cake-bootstrap.ps1 b/scripts/cake-bootstrap.ps1 new file mode 100644 index 00000000..a87c1478 --- /dev/null +++ b/scripts/cake-bootstrap.ps1 @@ -0,0 +1,110 @@ +<# + +.SYNOPSIS +This is a Powershell script to bootstrap a Cake build. + +.DESCRIPTION +This Powershell script will download NuGet if missing, restore NuGet tools (including Cake) +and execute your Cake build script with the parameters you provide. + +.PARAMETER Script +The build script to execute. +.PARAMETER Target +The build script target to run. +.PARAMETER Configuration +The build configuration to use. +.PARAMETER Verbosity +Specifies the amount of information to be displayed. +Tells Cake to use the latest Roslyn release. +.PARAMETER WhatIf +Performs a dry run of the build script. +No tasks will be executed. +.PARAMETER Mono +Tells Cake to use the Mono scripting engine. + +.LINK +http://cakebuild.net + +#> + +[CmdletBinding()] +Param( + [string]$Script = "build.cake", + [ValidateSet("Quiet", "Minimal", "Normal", "Verbose", "Diagnostic")] + [string]$Verbosity = "Verbose", + [Alias("DryRun","Noop")] + [switch]$WhatIf, + [switch]$Mono, + [switch]$SkipToolPackageRestore, + [Parameter(Position=0,Mandatory=$false,ValueFromRemainingArguments=$true)] + [string[]]$ScriptArgs +) + +Write-Host "Preparing to run build script..." + +$PS_SCRIPT_ROOT = split-path -parent $MyInvocation.MyCommand.Definition; +$TOOLS_DIR = Join-Path $PSScriptRoot "..\.tools" +$NUGET_EXE = Join-Path $TOOLS_DIR "nuget.exe" +$NUGET_URL = "https://dist.nuget.org/win-x86-commandline/v3.3.0/nuget.exe" +$CAKE_EXE = Join-Path $TOOLS_DIR "Cake/Cake.exe" +$PACKAGES_CONFIG = Join-Path $PS_SCRIPT_ROOT "packages.config" + +# Should we use mono? +$UseMono = ""; +if($Mono.IsPresent) { + Write-Verbose -Message "Using the Mono based scripting engine." + $UseMono = "-mono" +} + +# Is this a dry run? +$UseDryRun = ""; +if($WhatIf.IsPresent) { + $UseDryRun = "-dryrun" +} + +# Make sure tools folder exists +if ((Test-Path $PSScriptRoot) -and !(Test-Path $TOOLS_DIR)) { + Write-Verbose -Message "Creating tools directory..." + New-Item -Path $TOOLS_DIR -Type directory | out-null +} + +# Try download NuGet.exe if not exists +if (!(Test-Path $NUGET_EXE)) { + Write-Verbose -Message "Downloading NuGet.exe..." + try { + (New-Object System.Net.WebClient).DownloadFile($NUGET_URL, $NUGET_EXE) + } catch { + Throw "Could not download NuGet.exe." + } +} + +# Save nuget.exe path to environment to be available to child processed +$ENV:NUGET_EXE = $NUGET_EXE + +# Restore tools from NuGet? +if(-Not $SkipToolPackageRestore.IsPresent) +{ + # Restore packages from NuGet. + Push-Location + Set-Location $TOOLS_DIR + + Write-Verbose -Message "Restoring tools from NuGet..." + $NuGetOutput = Invoke-Expression "&`"$NUGET_EXE`" install $PACKAGES_CONFIG -ExcludeVersion -OutputDirectory `"$TOOLS_DIR`"" + Write-Verbose -Message ($NuGetOutput | out-string) + + Pop-Location + if ($LASTEXITCODE -ne 0) + { + exit $LASTEXITCODE + } +} + +# Make sure that Cake has been installed. +if (!(Test-Path $CAKE_EXE)) { + Throw "Could not find Cake.exe at $CAKE_EXE" +} + +# Start Cake +Write-Host "Running build script..." +Invoke-Expression "& `"$CAKE_EXE`" `"$Script`" -verbosity=`"$Verbosity`" $UseMono $UseDryRun $ScriptArgs" +exit $LASTEXITCODE diff --git a/scripts/cake-bootstrap.sh b/scripts/cake-bootstrap.sh new file mode 100644 index 00000000..abc3ed43 --- /dev/null +++ b/scripts/cake-bootstrap.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash +############################################################### +# This is the Cake bootstrapper script that is responsible for +# downloading Cake and all specified tools from NuGet. +############################################################### + +# Define directories. +SCRIPT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) +TOOLS_DIR=$SCRIPT_DIR/../.tools +export NUGET_EXE=$TOOLS_DIR/nuget.exe +CAKE_EXE=$TOOLS_DIR/Cake/Cake.exe +PACKAGES_CONFIG=$SCRIPT_DIR/packages.config + +# Define default arguments. +SCRIPT="build.cake" +VERBOSITY="verbose" +DRYRUN= +SHOW_VERSION=false +SCRIPT_ARGUMENTS=() + +# Parse arguments. +for i in "$@"; do + case $1 in + -s|--script) SCRIPT="$2"; shift ;; + -v|--verbosity) VERBOSITY="$2"; shift ;; + -d|--dryrun) DRYRUN="-dryrun" ;; + --version) SHOW_VERSION=true ;; + --) shift; SCRIPT_ARGUMENTS+=("$@"); break ;; + *) SCRIPT_ARGUMENTS+=("$1") ;; + esac + shift +done + +# Make sure the tools folder exist. +if [ ! -d "$TOOLS_DIR" ]; then + mkdir "$TOOLS_DIR" +fi + +# Download NuGet if it does not exist. +if [ ! -f "$NUGET_EXE" ]; then + echo "Downloading NuGet..." + curl -Lsfo "$NUGET_EXE" https://dist.nuget.org/win-x86-commandline/v3.3.0/nuget.exe + if [ $? -ne 0 ]; then + echo "An error occured while downloading nuget.exe." + exit 1 + fi +fi + +# Restore tools from NuGet. +pushd "$TOOLS_DIR" >/dev/null +mono "$NUGET_EXE" install "$PACKAGES_CONFIG" -ExcludeVersion -OutputDirectory "$TOOLS_DIR" +if [ $? -ne 0 ]; then + echo "Could not restore NuGet packages." + exit 1 +fi +popd >/dev/null + +# Make sure that Cake has been installed. +if [ ! -f "$CAKE_EXE" ]; then + echo "Could not find Cake.exe at '$CAKE_EXE'." + exit 1 +fi + +# Start Cake +if $SHOW_VERSION; then + exec mono "$CAKE_EXE" -version +else + exec mono "$CAKE_EXE" $SCRIPT -verbosity=$VERBOSITY $DRYRUN "${SCRIPT_ARGUMENTS[@]}" +fi diff --git a/scripts/packages.config b/scripts/packages.config new file mode 100644 index 00000000..c4feb50f --- /dev/null +++ b/scripts/packages.config @@ -0,0 +1,5 @@ + + + + + diff --git a/scripts/runhelpers.cake b/scripts/runhelpers.cake new file mode 100644 index 00000000..03499601 --- /dev/null +++ b/scripts/runhelpers.cake @@ -0,0 +1,204 @@ +using System.Collections.Generic; +using System.Diagnostics; + +/// +/// Class encompassing the optional settings for running processes. +/// +public class RunOptions +{ + /// + /// The working directory of the process. + /// + public string WorkingDirectory { get; set; } + /// + /// Container logging the StandardOutput content. + /// + public IList StandardOutputListing { get; set; } + /// + /// Desired maximum time-out for the process + /// + public int TimeOut { get; set; } +} + +/// +/// Wrapper for the exit code and state. +/// Used to query the result of an execution with method calls. +/// +public struct ExitStatus +{ + private int _code; + private bool _timeOut; + /// + /// Default constructor when the execution finished. + /// + /// The exit code + public ExitStatus(int code) + { + this._code = code; + this._timeOut = false; + } + /// + /// Default constructor when the execution potentially timed out. + /// + /// The exit code + /// True if the execution timed out + public ExitStatus(int code, bool timeOut) + { + this._code = code; + this._timeOut = timeOut; + } + /// + /// Flag signalling that the execution timed out. + /// + public bool DidTimeOut { get { return _timeOut; } } + /// + /// Implicit conversion from ExitStatus to the exit code. + /// + /// The exit status + /// The exit code + public static implicit operator int(ExitStatus exitStatus) + { + return exitStatus._code; + } + /// + /// Trigger Exception for non-zero exit code. + /// + /// The message to use in the Exception + /// The exit status for further queries + public ExitStatus ExceptionOnError(string errorMessage) + { + if (this._code != 0) + { + throw new Exception(errorMessage); + } + return this; + } +} + +/// +/// Run the given executable with the given arguments. +/// +/// Executable to run +/// Arguments +/// The exit status for further queries +ExitStatus Run(string exec, string args) +{ + return Run(exec, args, new RunOptions()); +} + +/// +/// Run the given executable with the given arguments. +/// +/// Executable to run +/// Arguments +/// Working directory +/// The exit status for further queries +ExitStatus Run(string exec, string args, string workingDirectory) +{ + return Run(exec, args, + new RunOptions() + { + WorkingDirectory = workingDirectory + }); +} + +/// +/// Run the given executable with the given arguments. +/// +/// Executable to run +/// Arguments +/// Optional settings +/// The exit status for further queries +ExitStatus Run(string exec, string args, RunOptions runOptions) +{ + var workingDirectory = runOptions.WorkingDirectory ?? System.IO.Directory.GetCurrentDirectory(); + var process = System.Diagnostics.Process.Start( + new ProcessStartInfo(exec, args) + { + WorkingDirectory = workingDirectory, + UseShellExecute = false, + RedirectStandardOutput = runOptions.StandardOutputListing != null + }); + if (runOptions.StandardOutputListing != null) + { + process.OutputDataReceived += (s, e) => + { + if (e.Data != null) + { + runOptions.StandardOutputListing.Add(e.Data); + } + }; + process.BeginOutputReadLine(); + } + if (runOptions.TimeOut == 0) + { + process.WaitForExit(); + return new ExitStatus(process.ExitCode); + } + else + { + bool finished = process.WaitForExit(runOptions.TimeOut); + if (finished) + { + return new ExitStatus(process.ExitCode); + } + else + { + KillProcessTree(process); + return new ExitStatus(0, true); + } + } +} + +/// +/// Run restore with the given arguments +/// +/// Executable to run +/// Arguments +/// Optional settings +/// The exit status for further queries +ExitStatus RunRestore(string exec, string args, string workingDirectory) +{ + Information("Restoring packages...."); + var p = StartAndReturnProcess(exec, + new ProcessSettings + { + Arguments = args, + RedirectStandardOutput = true, + WorkingDirectory = workingDirectory + }); + p.WaitForExit(); + var exitCode = p.GetExitCode(); + + if (exitCode == 0) + { + Information("Package restore successful!"); + } + else + { + Error(string.Join("\n", p.GetStandardOutput())); + } + return new ExitStatus(exitCode); +} + +/// +/// Kill the given process and all its child processes. +/// +/// Root process +public void KillProcessTree(Process process) +{ + // Child processes are not killed on Windows by default + // Use TASKKILL to kill the process hierarchy rooted in the process + if (IsRunningOnWindows()) + { + StartProcess($"TASKKILL", + new ProcessSettings + { + Arguments = $"/PID {process.Id} /T /F", + }); + } + else + { + process.Kill(); + } +} diff --git a/sqltoolsservice.sln b/sqltoolsservice.sln new file mode 100644 index 00000000..cd55b538 --- /dev/null +++ b/sqltoolsservice.sln @@ -0,0 +1,43 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 14 +VisualStudioVersion = 14.0.25420.1 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{2BBD7364-054F-4693-97CD-1C395E3E84A9}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{32DC973E-9EEA-4694-B1C2-B031167AB945}" + ProjectSection(SolutionItems) = preProject + .gitignore = .gitignore + global.json = global.json + nuget.config = nuget.config + README.md = README.md + EndProjectSection +EndProject +Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.SqlTools.ServiceLayer", "src\Microsoft.SqlTools.ServiceLayer\Microsoft.SqlTools.ServiceLayer.xproj", "{0D61DC2B-DA66-441D-B9D0-F76C98F780F9}" +EndProject +Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.SqlTools.ServiceLayer.Test", "test\Microsoft.SqlTools.ServiceLayer.Test\Microsoft.SqlTools.ServiceLayer.Test.xproj", "{2D771D16-9D85-4053-9F79-E2034737DEEF}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {0D61DC2B-DA66-441D-B9D0-F76C98F780F9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0D61DC2B-DA66-441D-B9D0-F76C98F780F9}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0D61DC2B-DA66-441D-B9D0-F76C98F780F9}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0D61DC2B-DA66-441D-B9D0-F76C98F780F9}.Release|Any CPU.Build.0 = Release|Any CPU + {2D771D16-9D85-4053-9F79-E2034737DEEF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2D771D16-9D85-4053-9F79-E2034737DEEF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2D771D16-9D85-4053-9F79-E2034737DEEF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2D771D16-9D85-4053-9F79-E2034737DEEF}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {0D61DC2B-DA66-441D-B9D0-F76C98F780F9} = {2BBD7364-054F-4693-97CD-1C395E3E84A9} + {2D771D16-9D85-4053-9F79-E2034737DEEF} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4} + EndGlobalSection +EndGlobal diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs new file mode 100644 index 00000000..31d0026d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs @@ -0,0 +1,53 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.Common; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + /// + /// Information pertaining to a unique connection instance. + /// + public class ConnectionInfo + { + /// + /// Constructor + /// + public ConnectionInfo(ISqlConnectionFactory factory, string ownerUri, ConnectionDetails details) + { + Factory = factory; + OwnerUri = ownerUri; + ConnectionDetails = details; + ConnectionId = Guid.NewGuid(); + } + + /// + /// Unique Id, helpful to identify a connection info object + /// + public Guid ConnectionId { get; private set; } + + /// + /// URI identifying the owner/user of the connection. Could be a file, service, resource, etc. + /// + public string OwnerUri { get; private set; } + + /// + /// Factory used for creating the SQL connection associated with the connection info. + /// + public ISqlConnectionFactory Factory {get; private set;} + + /// + /// Properties used for creating/opening the SQL connection. + /// + public ConnectionDetails ConnectionDetails { get; private set; } + + /// + /// The connection to the SQL database that commands will be run against. + /// + public DbConnection SqlConnection { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs new file mode 100644 index 00000000..57a7ba6e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -0,0 +1,533 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Data.SqlClient; +using System.Threading.Tasks; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Workspace; + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + /// + /// Main class for the Connection Management services + /// + public class ConnectionService + { + /// + /// Singleton service instance + /// + private static Lazy instance + = new Lazy(() => new ConnectionService()); + + /// + /// Gets the singleton service instance + /// + public static ConnectionService Instance + { + get + { + return instance.Value; + } + } + + /// + /// The SQL connection factory object + /// + private ISqlConnectionFactory connectionFactory; + + private Dictionary ownerToConnectionMap = new Dictionary(); + + /// + /// Service host object for sending/receiving requests/events. + /// Internal for testing purposes. + /// + internal IProtocolEndpoint ServiceHost + { + get; + set; + } + + /// + /// Default constructor is private since it's a singleton class + /// + private ConnectionService() + { + } + + /// + /// Callback for onconnection handler + /// + /// + public delegate Task OnConnectionHandler(ConnectionInfo info); + + /// + // Callback for ondisconnect handler + /// + public delegate Task OnDisconnectHandler(ConnectionSummary summary); + + /// + /// List of onconnection handlers + /// + private readonly List onConnectionActivities = new List(); + + /// + /// List of ondisconnect handlers + /// + private readonly List onDisconnectActivities = new List(); + + /// + /// Gets the SQL connection factory instance + /// + public ISqlConnectionFactory ConnectionFactory + { + get + { + if (this.connectionFactory == null) + { + this.connectionFactory = new SqlConnectionFactory(); + } + return this.connectionFactory; + } + } + + /// + /// Test constructor that injects dependency interfaces + /// + /// + public ConnectionService(ISqlConnectionFactory testFactory) + { + this.connectionFactory = testFactory; + } + + // Attempts to link a URI to an actively used connection for this URI + public bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo) + { + return this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo); + } + + /// + /// Open a connection with the specified connection details + /// + /// + public ConnectResponse Connect(ConnectParams connectionParams) + { + // Validate parameters + string paramValidationErrorMessage; + if (connectionParams == null) + { + return new ConnectResponse() + { + Messages = "Error: Connection parameters cannot be null." + }; + } + else if (!connectionParams.IsValid(out paramValidationErrorMessage)) + { + return new ConnectResponse() + { + Messages = paramValidationErrorMessage + }; + } + + // Resolve if it is an existing connection + // Disconnect active connection if the URI is already connected + ConnectionInfo connectionInfo; + if (ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo) ) + { + var disconnectParams = new DisconnectParams() + { + OwnerUri = connectionParams.OwnerUri + }; + Disconnect(disconnectParams); + } + connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); + + // try to connect + var response = new ConnectResponse(); + try + { + // build the connection string from the input parameters + string connectionString = ConnectionService.BuildConnectionString(connectionInfo.ConnectionDetails); + + // create a sql connection instance + connectionInfo.SqlConnection = connectionInfo.Factory.CreateSqlConnection(connectionString); + connectionInfo.SqlConnection.Open(); + } + catch(Exception ex) + { + response.Messages = ex.ToString(); + return response; + } + + ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; + + // invoke callback notifications + foreach (var activity in this.onConnectionActivities) + { + activity(connectionInfo); + } + + // return the connection result + response.ConnectionId = connectionInfo.ConnectionId.ToString(); + return response; + } + + /// + /// Close a connection with the specified connection details. + /// + public bool Disconnect(DisconnectParams disconnectParams) + { + // Validate parameters + if (disconnectParams == null || string.IsNullOrEmpty(disconnectParams.OwnerUri)) + { + return false; + } + + // Lookup the connection owned by the URI + ConnectionInfo info; + if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info)) + { + return false; + } + + // Close the connection + info.SqlConnection.Close(); + + // Remove URI mapping + ownerToConnectionMap.Remove(disconnectParams.OwnerUri); + + // Invoke callback notifications + foreach (var activity in this.onDisconnectActivities) + { + activity(info.ConnectionDetails); + } + + // Success + return true; + } + + /// + /// List all databases on the server specified + /// + public ListDatabasesResponse ListDatabases(ListDatabasesParams listDatabasesParams) + { + // Verify parameters + var owner = listDatabasesParams.OwnerUri; + if (string.IsNullOrEmpty(owner)) + { + throw new ArgumentException("OwnerUri cannot be null or empty"); + } + + // Use the existing connection as a base for the search + ConnectionInfo info; + if (!TryFindConnection(owner, out info)) + { + throw new Exception("Specified OwnerUri \"" + owner + "\" does not have an existing connection"); + } + ConnectionDetails connectionDetails = info.ConnectionDetails.Clone(); + + // Connect to master and query sys.databases + connectionDetails.DatabaseName = "master"; + var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails)); + connection.Open(); + + DbCommand command = connection.CreateCommand(); + command.CommandText = "SELECT name FROM sys.databases"; + command.CommandTimeout = 15; + command.CommandType = CommandType.Text; + var reader = command.ExecuteReader(); + + List results = new List(); + while (reader.Read()) + { + results.Add(reader[0].ToString()); + } + + connection.Close(); + + ListDatabasesResponse response = new ListDatabasesResponse(); + response.DatabaseNames = results.ToArray(); + + return response; + } + + public void InitializeService(IProtocolEndpoint serviceHost) + { + this.ServiceHost = serviceHost; + + // Register request and event handlers with the Service Host + serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest); + serviceHost.SetRequestHandler(DisconnectRequest.Type, HandleDisconnectRequest); + serviceHost.SetRequestHandler(ListDatabasesRequest.Type, HandleListDatabasesRequest); + + // Register the configuration update handler + WorkspaceService.Instance.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification); + } + + /// + /// Add a new method to be called when the onconnection request is submitted + /// + /// + public void RegisterOnConnectionTask(OnConnectionHandler activity) + { + onConnectionActivities.Add(activity); + } + + /// + /// Add a new method to be called when the ondisconnect request is submitted + /// + public void RegisterOnDisconnectTask(OnDisconnectHandler activity) + { + onDisconnectActivities.Add(activity); + } + + /// + /// Handle new connection requests + /// + /// + /// + /// + protected async Task HandleConnectRequest( + ConnectParams connectParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleConnectRequest"); + + try + { + // open connection base on request details + ConnectResponse result = ConnectionService.Instance.Connect(connectParams); + await requestContext.SendResult(result); + } + catch(Exception ex) + { + await requestContext.SendError(ex.ToString()); + } + } + + /// + /// Handle disconnect requests + /// + protected async Task HandleDisconnectRequest( + DisconnectParams disconnectParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleDisconnectRequest"); + + try + { + bool result = ConnectionService.Instance.Disconnect(disconnectParams); + await requestContext.SendResult(result); + } + catch(Exception ex) + { + await requestContext.SendError(ex.ToString()); + } + + } + + /// + /// Handle requests to list databases on the current server + /// + protected async Task HandleListDatabasesRequest( + ListDatabasesParams listDatabasesParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "ListDatabasesRequest"); + + try + { + ListDatabasesResponse result = ConnectionService.Instance.ListDatabases(listDatabasesParams); + await requestContext.SendResult(result); + } + catch(Exception ex) + { + await requestContext.SendError(ex.ToString()); + } + } + + public Task HandleDidChangeConfigurationNotification( + SqlToolsSettings newSettings, + SqlToolsSettings oldSettings, + EventContext eventContext) + { + return Task.FromResult(true); + } + + /// + /// Build a connection string from a connection details instance + /// + /// + public static string BuildConnectionString(ConnectionDetails connectionDetails) + { + SqlConnectionStringBuilder connectionBuilder = new SqlConnectionStringBuilder(); + connectionBuilder["Data Source"] = connectionDetails.ServerName; + connectionBuilder["User Id"] = connectionDetails.UserName; + connectionBuilder["Password"] = connectionDetails.Password; + + // Check for any optional parameters + if (!string.IsNullOrEmpty(connectionDetails.DatabaseName)) + { + connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName; + } + if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType)) + { + switch(connectionDetails.AuthenticationType) + { + case "Integrated": + connectionBuilder.IntegratedSecurity = true; + break; + case "SqlLogin": + break; + default: + throw new ArgumentException(string.Format("Invalid value \"{0}\" for AuthenticationType. Valid values are \"Integrated\" and \"SqlLogin\".", connectionDetails.AuthenticationType)); + } + } + if (connectionDetails.Encrypt.HasValue) + { + connectionBuilder.Encrypt = connectionDetails.Encrypt.Value; + } + if (connectionDetails.TrustServerCertificate.HasValue) + { + connectionBuilder.TrustServerCertificate = connectionDetails.TrustServerCertificate.Value; + } + if (connectionDetails.PersistSecurityInfo.HasValue) + { + connectionBuilder.PersistSecurityInfo = connectionDetails.PersistSecurityInfo.Value; + } + if (connectionDetails.ConnectTimeout.HasValue) + { + connectionBuilder.ConnectTimeout = connectionDetails.ConnectTimeout.Value; + } + if (connectionDetails.ConnectRetryCount.HasValue) + { + connectionBuilder.ConnectRetryCount = connectionDetails.ConnectRetryCount.Value; + } + if (connectionDetails.ConnectRetryInterval.HasValue) + { + connectionBuilder.ConnectRetryInterval = connectionDetails.ConnectRetryInterval.Value; + } + if (!string.IsNullOrEmpty(connectionDetails.ApplicationName)) + { + connectionBuilder.ApplicationName = connectionDetails.ApplicationName; + } + if (!string.IsNullOrEmpty(connectionDetails.WorkstationId)) + { + connectionBuilder.WorkstationID = connectionDetails.WorkstationId; + } + if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent)) + { + ApplicationIntent intent; + switch (connectionDetails.ApplicationIntent) + { + case "ReadOnly": + intent = ApplicationIntent.ReadOnly; + break; + case "ReadWrite": + intent = ApplicationIntent.ReadWrite; + break; + default: + throw new ArgumentException(string.Format("Invalid value \"{0}\" for ApplicationIntent. Valid values are \"ReadWrite\" and \"ReadOnly\".", connectionDetails.ApplicationIntent)); + } + connectionBuilder.ApplicationIntent = intent; + } + if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage)) + { + connectionBuilder.CurrentLanguage = connectionDetails.CurrentLanguage; + } + if (connectionDetails.Pooling.HasValue) + { + connectionBuilder.Pooling = connectionDetails.Pooling.Value; + } + if (connectionDetails.MaxPoolSize.HasValue) + { + connectionBuilder.MaxPoolSize = connectionDetails.MaxPoolSize.Value; + } + if (connectionDetails.MinPoolSize.HasValue) + { + connectionBuilder.MinPoolSize = connectionDetails.MinPoolSize.Value; + } + if (connectionDetails.LoadBalanceTimeout.HasValue) + { + connectionBuilder.LoadBalanceTimeout = connectionDetails.LoadBalanceTimeout.Value; + } + if (connectionDetails.Replication.HasValue) + { + connectionBuilder.Replication = connectionDetails.Replication.Value; + } + if (!string.IsNullOrEmpty(connectionDetails.AttachDbFilename)) + { + connectionBuilder.AttachDBFilename = connectionDetails.AttachDbFilename; + } + if (!string.IsNullOrEmpty(connectionDetails.FailoverPartner)) + { + connectionBuilder.FailoverPartner = connectionDetails.FailoverPartner; + } + if (connectionDetails.MultiSubnetFailover.HasValue) + { + connectionBuilder.MultiSubnetFailover = connectionDetails.MultiSubnetFailover.Value; + } + if (connectionDetails.MultipleActiveResultSets.HasValue) + { + connectionBuilder.MultipleActiveResultSets = connectionDetails.MultipleActiveResultSets.Value; + } + if (connectionDetails.PacketSize.HasValue) + { + connectionBuilder.PacketSize = connectionDetails.PacketSize.Value; + } + if (!string.IsNullOrEmpty(connectionDetails.TypeSystemVersion)) + { + connectionBuilder.TypeSystemVersion = connectionDetails.TypeSystemVersion; + } + + return connectionBuilder.ToString(); + } + + /// + /// Change the database context of a connection. + /// + /// URI of the owner of the connection + /// Name of the database to change the connection to + public void ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName) + { + ConnectionInfo info; + if (TryFindConnection(ownerUri, out info)) + { + try + { + if (info.SqlConnection.State == ConnectionState.Open) + { + info.SqlConnection.ChangeDatabase(newDatabaseName); + } + info.ConnectionDetails.DatabaseName = newDatabaseName; + + // Fire a connection changed event + ConnectionChangedParams parameters = new ConnectionChangedParams(); + ConnectionSummary summary = (ConnectionSummary)(info.ConnectionDetails); + parameters.Connection = summary.Clone(); + parameters.OwnerUri = ownerUri; + ServiceHost.SendEvent(ConnectionChangedNotification.Type, parameters); + } + catch (Exception e) + { + Logger.Write( + LogLevel.Error, + string.Format( + "Exception caught while trying to change database context to [{0}] for OwnerUri [{1}]. Exception:{2}", + newDatabaseName, + ownerUri, + e.ToString()) + ); + } + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs new file mode 100644 index 00000000..31dad8c5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Parameters for the Connect Request. + /// + public class ConnectParams + { + /// + /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace + /// or a virtual file representing an object in a database. + /// + public string OwnerUri { get; set; } + /// + /// Contains the required parameters to initialize a connection to a database. + /// A connection will identified by its server name, database name and user name. + /// This may be changed in the future to support multiple connections with different + /// connection properties to the same database. + /// + public ConnectionDetails Connection { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs new file mode 100644 index 00000000..9f2c7356 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs @@ -0,0 +1,56 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Extension methods to ConnectParams + /// + public static class ConnectParamsExtensions + { + /// + /// Check that the fields in ConnectParams are all valid + /// + public static bool IsValid(this ConnectParams parameters, out string errorMessage) + { + errorMessage = string.Empty; + if (string.IsNullOrEmpty(parameters.OwnerUri)) + { + errorMessage = "Error: OwnerUri cannot be null or empty."; + } + else if (parameters.Connection == null) + { + errorMessage = "Error: Connection details object cannot be null."; + } + else if (string.IsNullOrEmpty(parameters.Connection.ServerName)) + { + errorMessage = "Error: ServerName cannot be null or empty."; + } + else if (string.IsNullOrEmpty(parameters.Connection.AuthenticationType) || parameters.Connection.AuthenticationType == "SqlLogin") + { + // For SqlLogin, username/password cannot be empty + if (string.IsNullOrEmpty(parameters.Connection.UserName)) + { + errorMessage = "Error: UserName cannot be null or empty when using SqlLogin authentication."; + } + else if( string.IsNullOrEmpty(parameters.Connection.Password)) + { + errorMessage = "Error: Password cannot be null or empty when using SqlLogin authentication."; + } + } + + if (string.IsNullOrEmpty(errorMessage)) + { + return true; + } + else + { + return false; + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs new file mode 100644 index 00000000..c325c64f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs @@ -0,0 +1,23 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Message format for the connection result response + /// + public class ConnectResponse + { + /// + /// A GUID representing a unique connection ID + /// + public string ConnectionId { get; set; } + + /// + /// Gets or sets any connection error messages + /// + public string Messages { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs new file mode 100644 index 00000000..c0daee6d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// ConnectionChanged notification mapping entry + /// + public class ConnectionChangedNotification + { + public static readonly + EventType Type = + EventType.Create("connection/connectionchanged"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs new file mode 100644 index 00000000..3db86f34 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs @@ -0,0 +1,23 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Parameters for the ConnectionChanged Notification. + /// + public class ConnectionChangedParams + { + /// + /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace + /// or a virtual file representing an object in a database. + /// + public string OwnerUri { get; set; } + /// + /// Contains the high-level properties about the connection, for display to the user. + /// + public ConnectionSummary Connection { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs new file mode 100644 index 00000000..ce1c6208 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -0,0 +1,132 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Message format for the initial connection request + /// + /// + /// If this contract is ever changed, be sure to update ConnectionDetailsExtensions methods. + /// + public class ConnectionDetails : ConnectionSummary + { + /// + /// Gets or sets the connection password + /// + /// + public string Password { get; set; } + + /// + /// Gets or sets the authentication to use. + /// + public string AuthenticationType { get; set; } + + /// + /// Gets or sets a Boolean value that indicates whether SQL Server uses SSL encryption for all data sent between the client and server if the server has a certificate installed. + /// + public bool? Encrypt { get; set; } + + /// + /// Gets or sets a value that indicates whether the channel will be encrypted while bypassing walking the certificate chain to validate trust. + /// + public bool? TrustServerCertificate { get; set; } + + /// + /// Gets or sets a Boolean value that indicates if security-sensitive information, such as the password, is not returned as part of the connection if the connection is open or has ever been in an open state. + /// + public bool? PersistSecurityInfo { get; set; } + + /// + /// Gets or sets the length of time (in seconds) to wait for a connection to the server before terminating the attempt and generating an error. + /// + public int? ConnectTimeout { get; set; } + + /// + /// The number of reconnections attempted after identifying that there was an idle connection failure. + /// + public int? ConnectRetryCount { get; set; } + + /// + /// Amount of time (in seconds) between each reconnection attempt after identifying that there was an idle connection failure. + /// + public int? ConnectRetryInterval { get; set; } + + /// + /// Gets or sets the name of the application associated with the connection string. + /// + public string ApplicationName { get; set; } + + /// + /// Gets or sets the name of the workstation connecting to SQL Server. + /// + public string WorkstationId { get; set; } + + /// + /// Declares the application workload type when connecting to a database in an SQL Server Availability Group. + /// + public string ApplicationIntent { get; set; } + + /// + /// Gets or sets the SQL Server Language record name. + /// + public string CurrentLanguage { get; set; } + + /// + /// Gets or sets a Boolean value that indicates whether the connection will be pooled or explicitly opened every time that the connection is requested. + /// + public bool? Pooling { get; set; } + + /// + /// Gets or sets the maximum number of connections allowed in the connection pool for this specific connection string. + /// + public int? MaxPoolSize { get; set; } + + /// + /// Gets or sets the minimum number of connections allowed in the connection pool for this specific connection string. + /// + public int? MinPoolSize { get; set; } + + /// + /// Gets or sets the minimum time, in seconds, for the connection to live in the connection pool before being destroyed. + /// + public int? LoadBalanceTimeout { get; set; } + + /// + /// Gets or sets a Boolean value that indicates whether replication is supported using the connection. + /// + public bool? Replication { get; set; } + + /// + /// Gets or sets a string that contains the name of the primary data file. This includes the full path name of an attachable database. + /// + public string AttachDbFilename { get; set; } + + /// + /// Gets or sets the name or address of the partner server to connect to if the primary server is down. + /// + public string FailoverPartner { get; set; } + + /// + /// If your application is connecting to an AlwaysOn availability group (AG) on different subnets, setting MultiSubnetFailover=true provides faster detection of and connection to the (currently) active server. + /// + public bool? MultiSubnetFailover { get; set; } + + /// + /// When true, an application can maintain multiple active result sets (MARS). + /// + public bool? MultipleActiveResultSets { get; set; } + + /// + /// Gets or sets the size in bytes of the network packets used to communicate with an instance of SQL Server. + /// + public int? PacketSize { get; set; } + + /// + /// Gets or sets a string value that indicates the type system the application expects. + /// + public string TypeSystemVersion { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs new file mode 100644 index 00000000..106fa06e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs @@ -0,0 +1,49 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Extension methods for the ConnectionDetails contract class + /// + public static class ConnectionDetailsExtensions + { + /// + /// Create a copy of a connection details object. + /// + public static ConnectionDetails Clone(this ConnectionDetails details) + { + return new ConnectionDetails() + { + ServerName = details.ServerName, + DatabaseName = details.DatabaseName, + UserName = details.UserName, + Password = details.Password, + AuthenticationType = details.AuthenticationType, + Encrypt = details.Encrypt, + TrustServerCertificate = details.TrustServerCertificate, + PersistSecurityInfo = details.PersistSecurityInfo, + ConnectTimeout = details.ConnectTimeout, + ConnectRetryCount = details.ConnectRetryCount, + ConnectRetryInterval = details.ConnectRetryInterval, + ApplicationName = details.ApplicationName, + WorkstationId = details.WorkstationId, + ApplicationIntent = details.ApplicationIntent, + CurrentLanguage = details.CurrentLanguage, + Pooling = details.Pooling, + MaxPoolSize = details.MaxPoolSize, + MinPoolSize = details.MinPoolSize, + LoadBalanceTimeout = details.LoadBalanceTimeout, + Replication = details.Replication, + AttachDbFilename = details.AttachDbFilename, + FailoverPartner = details.FailoverPartner, + MultiSubnetFailover = details.MultiSubnetFailover, + MultipleActiveResultSets = details.MultipleActiveResultSets, + PacketSize = details.PacketSize, + TypeSystemVersion = details.TypeSystemVersion + }; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs new file mode 100644 index 00000000..50251e12 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Connect request mapping entry + /// + public class ConnectionRequest + { + public static readonly + RequestType Type = + RequestType.Create("connection/connect"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs new file mode 100644 index 00000000..11549e85 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Provides high level information about a connection. + /// + public class ConnectionSummary + { + /// + /// Gets or sets the connection server name + /// + public string ServerName { get; set; } + + /// + /// Gets or sets the connection database name + /// + public string DatabaseName { get; set; } + + /// + /// Gets or sets the connection user name + /// + public string UserName { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs new file mode 100644 index 00000000..dfeb0ab4 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs @@ -0,0 +1,53 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + + /// + /// Treats connections as the same if their server, db and usernames all match + /// + public class ConnectionSummaryComparer : IEqualityComparer + { + public bool Equals(ConnectionSummary x, ConnectionSummary y) + { + if(x == y) { return true; } + else if(x != null) + { + if(y == null) { return false; } + + // Compare server, db, username. Note: server is case-insensitive in the driver + return string.Compare(x.ServerName, y.ServerName, StringComparison.OrdinalIgnoreCase) == 0 + && string.Compare(x.DatabaseName, y.DatabaseName, StringComparison.Ordinal) == 0 + && string.Compare(x.UserName, y.UserName, StringComparison.Ordinal) == 0; + } + return false; + } + + public int GetHashCode(ConnectionSummary obj) + { + int hashcode = 31; + if(obj != null) + { + if(obj.ServerName != null) + { + hashcode ^= obj.ServerName.GetHashCode(); + } + if (obj.DatabaseName != null) + { + hashcode ^= obj.DatabaseName.GetHashCode(); + } + if (obj.UserName != null) + { + hashcode ^= obj.UserName.GetHashCode(); + } + } + return hashcode; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs new file mode 100644 index 00000000..02bc7623 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Extension methods to ConnectionSummary + /// + public static class ConnectionSummaryExtensions + { + /// + /// Create a copy of a ConnectionSummary object + /// + public static ConnectionSummary Clone(this ConnectionSummary summary) + { + return new ConnectionSummary() + { + ServerName = summary.ServerName, + DatabaseName = summary.DatabaseName, + UserName = summary.UserName + }; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs new file mode 100644 index 00000000..91bc7faf --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Parameters for the Disconnect Request. + /// + public class DisconnectParams + { + /// + /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace + /// or a virtual file representing an object in a database. + /// + public string OwnerUri { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs new file mode 100644 index 00000000..cbf67ef2 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Disconnect request mapping entry + /// + public class DisconnectRequest + { + public static readonly + RequestType Type = + RequestType.Create("connection/disconnect"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesParams.cs new file mode 100644 index 00000000..fa607e75 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesParams.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Parameters for the List Databases Request. + /// + public class ListDatabasesParams + { + /// + /// URI of the owner of the connection requesting the list of databases. + /// + public string OwnerUri { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesRequest.cs new file mode 100644 index 00000000..01c12a45 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesRequest.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// List databases request mapping entry + /// + public class ListDatabasesRequest + { + public static readonly + RequestType Type = + RequestType.Create("connection/listdatabases"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesResponse.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesResponse.cs new file mode 100644 index 00000000..68610803 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesResponse.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Message format for the list databases response + /// + public class ListDatabasesResponse + { + /// + /// Gets or sets the list of database names. + /// + public string[] DatabaseNames { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ISqlConnectionFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ISqlConnectionFactory.cs new file mode 100644 index 00000000..ed0cc01b --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ISqlConnectionFactory.cs @@ -0,0 +1,20 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Data.Common; + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + /// + /// Interface for the SQL Connection factory + /// + public interface ISqlConnectionFactory + { + /// + /// Create a new SQL Connection object + /// + DbConnection CreateSqlConnection(string connectionString); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/SqlConnectionFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/SqlConnectionFactory.cs new file mode 100644 index 00000000..cffb690d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/SqlConnectionFactory.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Data.Common; +using System.Data.SqlClient; + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + /// + /// Factory class to create SqlClientConnections + /// The purpose of the factory is to make it easier to mock out the database + /// in 'offline' unit test scenarios. + /// + public class SqlConnectionFactory : ISqlConnectionFactory + { + /// + /// Creates a new SqlConnection object + /// + public DbConnection CreateSqlConnection(string connectionString) + { + return new SqlConnection(connectionString); + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Contracts/Credential.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Contracts/Credential.cs new file mode 100644 index 00000000..be595ec8 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Contracts/Credential.cs @@ -0,0 +1,124 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Contracts +{ + /// + /// A Credential containing information needed to log into a resource. This is primarily + /// defined as a unique with an associated + /// that's linked to it. + /// + public class Credential + { + /// + /// A unique ID to identify the credential being saved. + /// + public string CredentialId { get; set; } + + /// + /// The Password stored for this credential. + /// + public string Password { get; set; } + + /// + /// Default Constructor + /// + public Credential() + { + } + + /// + /// Constructor used when only is known + /// + /// + public Credential(string credentialId) + : this(credentialId, null) + { + + } + + /// + /// Constructor + /// + /// + /// + public Credential(string credentialId, string password) + { + CredentialId = credentialId; + Password = password; + } + + internal static Credential Copy(Credential credential) + { + return new Credential + { + CredentialId = credential.CredentialId, + Password = credential.Password + }; + } + + /// + /// Validates the credential has all the properties needed to look up the password + /// + public static void ValidateForLookup(Credential credential) + { + Validate.IsNotNull("credential", credential); + Validate.IsNotNullOrEmptyString("credential.CredentialId", credential.CredentialId); + } + + + /// + /// Validates the credential has all the properties needed to save a password + /// + public static void ValidateForSave(Credential credential) + { + ValidateForLookup(credential); + Validate.IsNotNullOrEmptyString("credential.Password", credential.Password); + } + } + + /// + /// Read Credential request mapping entry. Expects a Credential with CredentialId, + /// and responds with the filled in if found + /// + public class ReadCredentialRequest + { + /// + /// Request definition + /// + public static readonly + RequestType Type = + RequestType.Create("credential/read"); + } + + /// + /// Save Credential request mapping entry + /// + public class SaveCredentialRequest + { + /// + /// Request definition + /// + public static readonly + RequestType Type = + RequestType.Create("credential/save"); + } + + /// + /// Delete Credential request mapping entry + /// + public class DeleteCredentialRequest + { + /// + /// Request definition + /// + public static readonly + RequestType Type = + RequestType.Create("credential/delete"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/CredentialService.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/CredentialService.cs new file mode 100644 index 00000000..f1a80807 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/CredentialService.cs @@ -0,0 +1,151 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts; +using Microsoft.SqlTools.ServiceLayer.Credentials.Linux; +using Microsoft.SqlTools.ServiceLayer.Credentials.OSX; +using Microsoft.SqlTools.ServiceLayer.Credentials.Win32; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials +{ + /// + /// Service responsible for securing credentials in a platform-neutral manner. This provides + /// a generic API for read, save and delete credentials + /// + public class CredentialService + { + internal static string DefaultSecretsFolder = ".sqlsecrets"; + internal const string DefaultSecretsFile = "sqlsecrets.json"; + + + /// + /// Singleton service instance + /// + private static Lazy instance + = new Lazy(() => new CredentialService()); + + /// + /// Gets the singleton service instance + /// + public static CredentialService Instance + { + get + { + return instance.Value; + } + } + + private ICredentialStore credStore; + + /// + /// Default constructor is private since it's a singleton class + /// + private CredentialService() + : this(null, new LinuxCredentialStore.StoreConfig() + { CredentialFolder = DefaultSecretsFolder, CredentialFile = DefaultSecretsFile, IsRelativeToUserHomeDir = true}) + { + } + + /// + /// Internal for testing purposes only + /// + internal CredentialService(ICredentialStore store, LinuxCredentialStore.StoreConfig config) + { + this.credStore = store != null ? store : GetStoreForOS(config); + } + + /// + /// Internal for testing purposes only + /// + internal static ICredentialStore GetStoreForOS(LinuxCredentialStore.StoreConfig config) + { + if(RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return new Win32CredentialStore(); + } + else if(RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return new OSXCredentialStore(); + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return new LinuxCredentialStore(config); + } + throw new InvalidOperationException("Platform not currently supported"); + } + + public void InitializeService(IProtocolEndpoint serviceHost) + { + // Register request and event handlers with the Service Host + serviceHost.SetRequestHandler(ReadCredentialRequest.Type, HandleReadCredentialRequest); + serviceHost.SetRequestHandler(SaveCredentialRequest.Type, HandleSaveCredentialRequest); + serviceHost.SetRequestHandler(DeleteCredentialRequest.Type, HandleDeleteCredentialRequest); + } + + public async Task HandleReadCredentialRequest(Credential credential, RequestContext requestContext) + { + Func doRead = () => + { + return ReadCredential(credential); + }; + await HandleRequest(doRead, requestContext, "HandleReadCredentialRequest"); + } + + + private Credential ReadCredential(Credential credential) + { + Credential.ValidateForLookup(credential); + + Credential result = Credential.Copy(credential); + string password; + if (credStore.TryGetPassword(credential.CredentialId, out password)) + { + result.Password = password; + } + return result; + } + + public async Task HandleSaveCredentialRequest(Credential credential, RequestContext requestContext) + { + Func doSave = () => + { + Credential.ValidateForSave(credential); + return credStore.Save(credential); + }; + await HandleRequest(doSave, requestContext, "HandleSaveCredentialRequest"); + } + + public async Task HandleDeleteCredentialRequest(Credential credential, RequestContext requestContext) + { + Func doDelete = () => + { + Credential.ValidateForLookup(credential); + return credStore.DeletePassword(credential.CredentialId); + }; + await HandleRequest(doDelete, requestContext, "HandleDeleteCredentialRequest"); + } + + private async Task HandleRequest(Func handler, RequestContext requestContext, string requestType) + { + Logger.Write(LogLevel.Verbose, requestType); + + try + { + T result = handler(); + await requestContext.SendResult(result); + } + catch (Exception ex) + { + await requestContext.SendError(ex.ToString()); + } + } + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/ICredentialStore.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/ICredentialStore.cs new file mode 100644 index 00000000..0fa51cdd --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/ICredentialStore.cs @@ -0,0 +1,41 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials +{ + /// + /// An support securely saving and retrieving passwords + /// + public interface ICredentialStore + { + /// + /// Saves a Password linked to a given Credential + /// + /// + /// A to be saved. + /// and are required + /// + /// True if successful, false otherwise + bool Save(Credential credential); + + /// + /// Gets a Password and sets it into a object + /// + /// The name of the credential to find the password for. This is required + /// Out value + /// true if password was found, false otherwise + bool TryGetPassword(string credentialId, out string password); + + /// + /// Deletes a password linked to a given credential + /// + /// The name of the credential to find the password for. This is required + /// True if password existed and was deleted, false otherwise + bool DeletePassword(string credentialId); + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/InteropUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/InteropUtils.cs new file mode 100644 index 00000000..fdb5343e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/InteropUtils.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials +{ + internal static class InteropUtils + { + + /// + /// Gets the length in bytes for a Unicode string, for use in interop where length must be defined + /// + public static UInt32 GetLengthInBytes(string value) + { + + return Convert.ToUInt32( (value != null ? Encoding.Unicode.GetByteCount(value) : 0) ); + } + + public static string CopyToString(IntPtr ptr, int length) + { + if (ptr == IntPtr.Zero || length == 0) + { + return null; + } + byte[] pwdBytes = new byte[length]; + Marshal.Copy(ptr, pwdBytes, 0, (int)length); + return Encoding.Unicode.GetString(pwdBytes, 0, (int)length); + } + + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/CredentialsWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/CredentialsWrapper.cs new file mode 100644 index 00000000..3deab819 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/CredentialsWrapper.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Linux +{ + /// + /// Simplified class to enable writing a set of credentials to/from disk + /// + public class CredentialsWrapper + { + public List Credentials { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/FileTokenStorage.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/FileTokenStorage.cs new file mode 100644 index 00000000..ef2c2a67 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/FileTokenStorage.cs @@ -0,0 +1,87 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Newtonsoft.Json; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Linux +{ + public class FileTokenStorage + { + private const int OwnerAccessMode = 384; // Permission 0600 - owner read/write, nobody else has access + + private object lockObject = new object(); + + private string fileName; + + public FileTokenStorage(string fileName) + { + Validate.IsNotNullOrEmptyString("fileName", fileName); + this.fileName = fileName; + } + + public void AddEntries(IEnumerable newEntries, IEnumerable existingEntries) + { + var allEntries = existingEntries.Concat(newEntries); + this.SaveEntries(allEntries); + } + + public void Clear() + { + this.SaveEntries(new List()); + } + + public IEnumerable LoadEntries() + { + if(!File.Exists(this.fileName)) + { + return Enumerable.Empty(); + } + + string serializedCreds; + lock (lockObject) + { + serializedCreds = File.ReadAllText(this.fileName); + } + + CredentialsWrapper creds = JsonConvert.DeserializeObject(serializedCreds, Constants.JsonSerializerSettings); + if(creds != null) + { + return creds.Credentials; + } + return Enumerable.Empty(); + } + + public void SaveEntries(IEnumerable entries) + { + CredentialsWrapper credentials = new CredentialsWrapper() { Credentials = entries.ToList() }; + string serializedCreds = JsonConvert.SerializeObject(credentials, Constants.JsonSerializerSettings); + + lock(lockObject) + { + WriteToFile(this.fileName, serializedCreds); + } + } + + private static void WriteToFile(string filePath, string fileContents) + { + string dir = Path.GetDirectoryName(filePath); + if(!Directory.Exists(dir)) + { + Directory.CreateDirectory(dir); + } + + // Overwrite file, then use ChMod to ensure we have + File.WriteAllText(filePath, fileContents); + // set appropriate permissions so only current user can read/write + Interop.Sys.ChMod(filePath, OwnerAccessMode); + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Errors.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Errors.cs new file mode 100644 index 00000000..f3b1d5f5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Errors.cs @@ -0,0 +1,221 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials +{ + internal static partial class Interop + { + /// Common Unix errno error codes. + internal enum Error + { + // These values were defined in src/Native/System.Native/fxerrno.h + // + // They compare against values obtained via Interop.Sys.GetLastError() not Marshal.GetLastWin32Error() + // which obtains the raw errno that varies between unixes. The strong typing as an enum is meant to + // prevent confusing the two. Casting to or from int is suspect. Use GetLastErrorInfo() if you need to + // correlate these to the underlying platform values or obtain the corresponding error message. + // + + SUCCESS = 0, + + E2BIG = 0x10001, // Argument list too long. + EACCES = 0x10002, // Permission denied. + EADDRINUSE = 0x10003, // Address in use. + EADDRNOTAVAIL = 0x10004, // Address not available. + EAFNOSUPPORT = 0x10005, // Address family not supported. + EAGAIN = 0x10006, // Resource unavailable, try again (same value as EWOULDBLOCK), + EALREADY = 0x10007, // Connection already in progress. + EBADF = 0x10008, // Bad file descriptor. + EBADMSG = 0x10009, // Bad message. + EBUSY = 0x1000A, // Device or resource busy. + ECANCELED = 0x1000B, // Operation canceled. + ECHILD = 0x1000C, // No child processes. + ECONNABORTED = 0x1000D, // Connection aborted. + ECONNREFUSED = 0x1000E, // Connection refused. + ECONNRESET = 0x1000F, // Connection reset. + EDEADLK = 0x10010, // Resource deadlock would occur. + EDESTADDRREQ = 0x10011, // Destination address required. + EDOM = 0x10012, // Mathematics argument out of domain of function. + EDQUOT = 0x10013, // Reserved. + EEXIST = 0x10014, // File exists. + EFAULT = 0x10015, // Bad address. + EFBIG = 0x10016, // File too large. + EHOSTUNREACH = 0x10017, // Host is unreachable. + EIDRM = 0x10018, // Identifier removed. + EILSEQ = 0x10019, // Illegal byte sequence. + EINPROGRESS = 0x1001A, // Operation in progress. + EINTR = 0x1001B, // Interrupted function. + EINVAL = 0x1001C, // Invalid argument. + EIO = 0x1001D, // I/O error. + EISCONN = 0x1001E, // Socket is connected. + EISDIR = 0x1001F, // Is a directory. + ELOOP = 0x10020, // Too many levels of symbolic links. + EMFILE = 0x10021, // File descriptor value too large. + EMLINK = 0x10022, // Too many links. + EMSGSIZE = 0x10023, // Message too large. + EMULTIHOP = 0x10024, // Reserved. + ENAMETOOLONG = 0x10025, // Filename too long. + ENETDOWN = 0x10026, // Network is down. + ENETRESET = 0x10027, // Connection aborted by network. + ENETUNREACH = 0x10028, // Network unreachable. + ENFILE = 0x10029, // Too many files open in system. + ENOBUFS = 0x1002A, // No buffer space available. + ENODEV = 0x1002C, // No such device. + ENOENT = 0x1002D, // No such file or directory. + ENOEXEC = 0x1002E, // Executable file format error. + ENOLCK = 0x1002F, // No locks available. + ENOLINK = 0x10030, // Reserved. + ENOMEM = 0x10031, // Not enough space. + ENOMSG = 0x10032, // No message of the desired type. + ENOPROTOOPT = 0x10033, // Protocol not available. + ENOSPC = 0x10034, // No space left on device. + ENOSYS = 0x10037, // Function not supported. + ENOTCONN = 0x10038, // The socket is not connected. + ENOTDIR = 0x10039, // Not a directory or a symbolic link to a directory. + ENOTEMPTY = 0x1003A, // Directory not empty. + ENOTSOCK = 0x1003C, // Not a socket. + ENOTSUP = 0x1003D, // Not supported (same value as EOPNOTSUP). + ENOTTY = 0x1003E, // Inappropriate I/O control operation. + ENXIO = 0x1003F, // No such device or address. + EOVERFLOW = 0x10040, // Value too large to be stored in data type. + EPERM = 0x10042, // Operation not permitted. + EPIPE = 0x10043, // Broken pipe. + EPROTO = 0x10044, // Protocol error. + EPROTONOSUPPORT = 0x10045, // Protocol not supported. + EPROTOTYPE = 0x10046, // Protocol wrong type for socket. + ERANGE = 0x10047, // Result too large. + EROFS = 0x10048, // Read-only file system. + ESPIPE = 0x10049, // Invalid seek. + ESRCH = 0x1004A, // No such process. + ESTALE = 0x1004B, // Reserved. + ETIMEDOUT = 0x1004D, // Connection timed out. + ETXTBSY = 0x1004E, // Text file busy. + EXDEV = 0x1004F, // Cross-device link. + ESOCKTNOSUPPORT = 0x1005E, // Socket type not supported. + EPFNOSUPPORT = 0x10060, // Protocol family not supported. + ESHUTDOWN = 0x1006C, // Socket shutdown. + EHOSTDOWN = 0x10070, // Host is down. + ENODATA = 0x10071, // No data available. + + // POSIX permits these to have the same value and we make them always equal so + // that CoreFX cannot introduce a dependency on distinguishing between them that + // would not work on all platforms. + EOPNOTSUPP = ENOTSUP, // Operation not supported on socket. + EWOULDBLOCK = EAGAIN, // Operation would block. + } + + + // Represents a platform-agnostic Error and underlying platform-specific errno + internal struct ErrorInfo + { + private Error _error; + private int _rawErrno; + + internal ErrorInfo(int errno) + { + _error = Interop.Sys.ConvertErrorPlatformToPal(errno); + _rawErrno = errno; + } + + internal ErrorInfo(Error error) + { + _error = error; + _rawErrno = -1; + } + + internal Error Error + { + get { return _error; } + } + + internal int RawErrno + { + get { return _rawErrno == -1 ? (_rawErrno = Interop.Sys.ConvertErrorPalToPlatform(_error)) : _rawErrno; } + } + + internal string GetErrorMessage() + { + return Interop.Sys.StrError(RawErrno); + } + + public override string ToString() + { + return string.Format( + "RawErrno: {0} Error: {1} GetErrorMessage: {2}", // No localization required; text is member names used for debugging purposes + RawErrno, Error, GetErrorMessage()); + } + } + + internal partial class Sys + { + internal static Error GetLastError() + { + return ConvertErrorPlatformToPal(Marshal.GetLastWin32Error()); + } + + internal static ErrorInfo GetLastErrorInfo() + { + return new ErrorInfo(Marshal.GetLastWin32Error()); + } + + internal static string StrError(int platformErrno) + { + int maxBufferLength = 1024; // should be long enough for most any UNIX error + IntPtr buffer = Marshal.AllocHGlobal(maxBufferLength); + try + { + IntPtr message = StrErrorR(platformErrno, buffer, maxBufferLength); + + if (message == IntPtr.Zero) + { + // This means the buffer was not large enough, but still contains + // as much of the error message as possible and is guaranteed to + // be null-terminated. We're not currently resizing/retrying because + // maxBufferLength is large enough in practice, but we could do + // so here in the future if necessary. + message = buffer; + } + + string returnMsg = Marshal.PtrToStringAnsi(message); + return returnMsg; + } + finally + { + // Deallocate the buffer we created + Marshal.FreeHGlobal(buffer); + } + } + + [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_ConvertErrorPlatformToPal")] + internal static extern Error ConvertErrorPlatformToPal(int platformErrno); + + [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_ConvertErrorPalToPlatform")] + internal static extern int ConvertErrorPalToPlatform(Error error); + + [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_StrErrorR")] + private static extern IntPtr StrErrorR(int platformErrno, IntPtr buffer, int bufferSize); + } + } + + // NOTE: extension method can't be nested inside Interop class. + internal static class InteropErrorExtensions + { + // Intended usage is e.g. Interop.Error.EFAIL.Info() for brevity + // vs. new Interop.ErrorInfo(Interop.Error.EFAIL) for synthesizing + // errors. Errors originated from the system should be obtained + // via GetLastErrorInfo(), not GetLastError().Info() as that will + // convert twice, which is not only inefficient but also lossy if + // we ever encounter a raw errno that no equivalent in the Error + // enum. + public static Interop.ErrorInfo Info(this Interop.Error error) + { + return new Interop.ErrorInfo(error); + } + } + +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Sys.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Sys.cs new file mode 100644 index 00000000..8777ab0c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/Interop.Sys.cs @@ -0,0 +1,42 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials +{ + internal static partial class Interop + { + internal static partial class Sys + { + [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_ChMod", SetLastError = true)] + internal static extern int ChMod(string path, int mode); + + internal struct Passwd + { + internal IntPtr Name; // char* + internal IntPtr Password; // char* + internal uint UserId; + internal uint GroupId; + internal IntPtr UserInfo; // char* + internal IntPtr HomeDirectory; // char* + internal IntPtr Shell; // char* + }; + + [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetPwUidR", SetLastError = false)] + internal static extern int GetPwUidR(uint uid, out Passwd pwd, IntPtr buf, int bufLen); + + [DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetEUid")] + internal static extern uint GetEUid(); + + private static partial class Libraries + { + internal const string SystemNative = "System.Native"; + } + } + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/LinuxCredentialStore.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/LinuxCredentialStore.cs new file mode 100644 index 00000000..6d6b5908 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Linux/LinuxCredentialStore.cs @@ -0,0 +1,231 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Linux +{ + /// + /// Linux implementation of the credential store. + /// + /// + /// This entire implementation may need to be revised to support encryption of + /// passwords and protection of them when loaded into memory. + /// + /// + internal class LinuxCredentialStore : ICredentialStore + { + internal struct StoreConfig + { + public string CredentialFolder { get; set; } + public string CredentialFile { get; set; } + public bool IsRelativeToUserHomeDir { get; set; } + } + + private string credentialFolderPath; + private string credentialFileName; + private FileTokenStorage storage; + + public LinuxCredentialStore(StoreConfig config) + { + Validate.IsNotNull("config", config); + Validate.IsNotNullOrEmptyString("credentialFolder", config.CredentialFolder); + Validate.IsNotNullOrEmptyString("credentialFileName", config.CredentialFile); + + this.credentialFolderPath = config.IsRelativeToUserHomeDir ? GetUserScopedDirectory(config.CredentialFolder) : config.CredentialFolder; + this.credentialFileName = config.CredentialFile; + + + string combinedPath = Path.Combine(this.credentialFolderPath, this.credentialFileName); + storage = new FileTokenStorage(combinedPath); + } + + public bool DeletePassword(string credentialId) + { + Validate.IsNotNullOrEmptyString("credentialId", credentialId); + IEnumerable creds; + if (LoadCredentialsAndFilterById(credentialId, out creds)) + { + storage.SaveEntries(creds); + return true; + } + + return false; + } + + /// + /// Gets filtered credentials with a specific ID filtered out + /// + /// True if the credential to filter was removed, false if it was not found + private bool LoadCredentialsAndFilterById(string idToFilter, out IEnumerable creds) + { + bool didRemove = false; + creds = storage.LoadEntries().Where(cred => + { + if (IsCredentialMatch(idToFilter, cred)) + { + didRemove = true; + return false; // filter this out + } + return true; + }).ToList(); // Call ToList ensures Where clause is executed so didRemove can be evaluated + + return didRemove; + } + + private static bool IsCredentialMatch(string credentialId, Credential cred) + { + return string.Equals(credentialId, cred.CredentialId, StringComparison.Ordinal); + } + + public bool TryGetPassword(string credentialId, out string password) + { + Validate.IsNotNullOrEmptyString("credentialId", credentialId); + Credential cred = storage.LoadEntries().FirstOrDefault(c => IsCredentialMatch(credentialId, c)); + if (cred != null) + { + password = cred.Password; + return true; + } + + // Else this was not found in the list + password = null; + return false; + } + + public bool Save(Credential credential) + { + Credential.ValidateForSave(credential); + + // Load the credentials, removing the existing Cred for this + IEnumerable creds; + LoadCredentialsAndFilterById(credential.CredentialId, out creds); + storage.SaveEntries(creds.Append(credential)); + + return true; + } + + + /// + /// Internal for testing purposes only + /// + internal string CredentialFolderPath + { + get { return this.credentialFolderPath; } + } + + /// + /// Concatenates a directory to the user home directory's path + /// + internal static string GetUserScopedDirectory(string userPath) + { + string homeDir = GetHomeDirectory() ?? string.Empty; + return Path.Combine(homeDir, userPath); + } + + + /// Gets the current user's home directory. + /// The path to the home directory, or null if it could not be determined. + internal static string GetHomeDirectory() + { + // First try to get the user's home directory from the HOME environment variable. + // This should work in most cases. + string userHomeDirectory = Environment.GetEnvironmentVariable("HOME"); + if (!string.IsNullOrEmpty(userHomeDirectory)) + { + return userHomeDirectory; + } + + // In initialization conditions, however, the "HOME" environment variable may + // not yet be set. For such cases, consult with the password entry. + + // First try with a buffer that should suffice for 99% of cases. + // Note that, theoretically, userHomeDirectory may be null in the success case + // if we simply couldn't find a home directory for the current user. + // In that case, we pass back the null value and let the caller decide + // what to do. + return GetHomeDirectoryFromPw(); + } + + internal static string GetHomeDirectoryFromPw() + { + string userHomeDirectory = null; + const int BufLen = 1024; + if (TryGetHomeDirectoryFromPasswd(BufLen, out userHomeDirectory)) + { + return userHomeDirectory; + } + // Fallback to heap allocations if necessary, growing the buffer until + // we succeed. TryGetHomeDirectory will throw if there's an unexpected error. + int lastBufLen = BufLen; + while (true) + { + lastBufLen *= 2; + if (TryGetHomeDirectoryFromPasswd(lastBufLen, out userHomeDirectory)) + { + return userHomeDirectory; + } + } + } + + /// Wrapper for getpwuid_r. + /// The length of the buffer to use when storing the password result. + /// The resulting path; null if the user didn't have an entry. + /// true if the call was successful (path may still be null); false is a larger buffer is needed. + private static bool TryGetHomeDirectoryFromPasswd(int bufLen, out string path) + { + // Call getpwuid_r to get the passwd struct + Interop.Sys.Passwd passwd; + IntPtr buffer = Marshal.AllocHGlobal(bufLen); + try + { + int error = Interop.Sys.GetPwUidR(Interop.Sys.GetEUid(), out passwd, buffer, bufLen); + + // If the call succeeds, give back the home directory path retrieved + if (error == 0) + { + Debug.Assert(passwd.HomeDirectory != IntPtr.Zero); + path = Marshal.PtrToStringAnsi(passwd.HomeDirectory); + return true; + } + + // If the current user's entry could not be found, give back null + // path, but still return true as false indicates the buffer was + // too small. + if (error == -1) + { + path = null; + return true; + } + + var errorInfo = new Interop.ErrorInfo(error); + + // If the call failed because the buffer was too small, return false to + // indicate the caller should try again with a larger buffer. + if (errorInfo.Error == Interop.Error.ERANGE) + { + path = null; + return false; + } + + // Otherwise, fail. + throw new IOException(errorInfo.GetErrorMessage(), errorInfo.RawErrno); + } + finally + { + // Deallocate the buffer we created + Marshal.FreeHGlobal(buffer); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/Interop.CoreFoundation.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/Interop.CoreFoundation.cs new file mode 100644 index 00000000..140dfc63 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/Interop.CoreFoundation.cs @@ -0,0 +1,105 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials +{ + internal static partial class Interop + { + internal static partial class CoreFoundation + { + /// + /// Tells the OS what encoding the passed in String is in. These come from the CFString.h header file in the CoreFoundation framework. + /// + private enum CFStringBuiltInEncodings : uint + { + kCFStringEncodingMacRoman = 0, + kCFStringEncodingWindowsLatin1 = 0x0500, + kCFStringEncodingISOLatin1 = 0x0201, + kCFStringEncodingNextStepLatin = 0x0B01, + kCFStringEncodingASCII = 0x0600, + kCFStringEncodingUnicode = 0x0100, + kCFStringEncodingUTF8 = 0x08000100, + kCFStringEncodingNonLossyASCII = 0x0BFF, + + kCFStringEncodingUTF16 = 0x0100, + kCFStringEncodingUTF16BE = 0x10000100, + kCFStringEncodingUTF16LE = 0x14000100, + kCFStringEncodingUTF32 = 0x0c000100, + kCFStringEncodingUTF32BE = 0x18000100, + kCFStringEncodingUTF32LE = 0x1c000100 + } + + /// + /// Creates a CFStringRef from a 8-bit String object. Follows the "Create Rule" where if you create it, you delete it. + /// + /// Should be IntPtr.Zero + /// The string to get a CFStringRef for + /// The encoding of the str variable. This should be UTF 8 for OS X + /// Returns a pointer to a CFString on success; otherwise, returns IntPtr.Zero + /// For *nix systems, the CLR maps ANSI to UTF-8, so be explicit about that + [DllImport(Interop.Libraries.CoreFoundationLibrary, CharSet = CharSet.Ansi)] + private static extern SafeCreateHandle CFStringCreateWithCString( + IntPtr allocator, + string str, + CFStringBuiltInEncodings encoding); + + /// + /// Creates a CFStringRef from a 8-bit String object. Follows the "Create Rule" where if you create it, you delete it. + /// + /// The string to get a CFStringRef for + /// Returns a valid SafeCreateHandle to a CFString on success; otherwise, returns an invalid SafeCreateHandle + internal static SafeCreateHandle CFStringCreateWithCString(string str) + { + return CFStringCreateWithCString(IntPtr.Zero, str, CFStringBuiltInEncodings.kCFStringEncodingUTF8); + } + + /// + /// Creates a pointer to an unmanaged CFArray containing the input values. Follows the "Create Rule" where if you create it, you delete it. + /// + /// Should be IntPtr.Zero + /// The values to put in the array + /// The number of values in the array + /// Should be IntPtr.Zero + /// Returns a pointer to a CFArray on success; otherwise, returns IntPtr.Zero + [DllImport(Interop.Libraries.CoreFoundationLibrary)] + private static extern SafeCreateHandle CFArrayCreate( + IntPtr allocator, + [MarshalAs(UnmanagedType.LPArray)] + IntPtr[] values, + ulong numValues, + IntPtr callbacks); + + /// + /// Creates a pointer to an unmanaged CFArray containing the input values. Follows the "Create Rule" where if you create it, you delete it. + /// + /// The values to put in the array + /// The number of values in the array + /// Returns a valid SafeCreateHandle to a CFArray on success; otherwise, returns an invalid SafeCreateHandle + internal static SafeCreateHandle CFArrayCreate(IntPtr[] values, ulong numValues) + { + return CFArrayCreate(IntPtr.Zero, values, numValues, IntPtr.Zero); + } + + /// + /// You should retain a Core Foundation object when you receive it from elsewhere + /// (that is, you did not create or copy it) and you want it to persist. If you + /// retain a Core Foundation object you are responsible for releasing it + /// + /// The CFType object to retain. This value must not be NULL + /// The input value + [DllImport(Interop.Libraries.CoreFoundationLibrary)] + internal extern static IntPtr CFRetain(IntPtr ptr); + + /// + /// Decrements the reference count on the specified object and, if the ref count hits 0, cleans up the object. + /// + /// The pointer on which to decrement the reference count. + [DllImport(Interop.Libraries.CoreFoundationLibrary)] + internal extern static void CFRelease(IntPtr ptr); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/Interop.Libraries.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/Interop.Libraries.cs new file mode 100644 index 00000000..7ad5b639 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/Interop.Libraries.cs @@ -0,0 +1,17 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Credentials +{ + internal static partial class Interop + { + private static partial class Libraries + { + internal const string CoreFoundationLibrary = "/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation"; + internal const string CoreServicesLibrary = "/System/Library/Frameworks/CoreServices.framework/CoreServices"; + internal const string SecurityLibrary = "/System/Library/Frameworks/Security.framework/Security"; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/Interop.Security.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/Interop.Security.cs new file mode 100644 index 00000000..0a6209e8 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/Interop.Security.cs @@ -0,0 +1,459 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials +{ + internal partial class Interop + { + internal partial class Security + { + + [DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern OSStatus SecKeychainAddGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName, + UInt32 accountNameLength, string accountName, UInt32 passwordLength, IntPtr password, [Out] IntPtr itemRef); + + /// + /// Find a generic password based on the attributes passed + /// + /// + /// A reference to an array of keychains to search, a single keychain, or NULL to search the user's default keychain search list. + /// + /// The length of the buffer pointed to by serviceName. + /// A pointer to a string containing the service name. + /// The length of the buffer pointed to by accountName. + /// A pointer to a string containing the account name. + /// On return, the length of the buffer pointed to by passwordData. + /// + /// On return, a pointer to a data buffer containing the password. + /// Your application must call SecKeychainItemFreeContent(NULL, passwordData) + /// to release this data buffer when it is no longer needed.Pass NULL if you are not interested in retrieving the password data at + /// this time, but simply want to find the item reference. + /// + /// On return, a reference to the keychain item which was found. + /// A result code that should be in + /// + /// The SecKeychainFindGenericPassword function finds the first generic password item which matches the attributes you provide. + /// Most attributes are optional; you should pass only as many as you need to narrow the search sufficiently for your application's intended use. + /// SecKeychainFindGenericPassword optionally returns a reference to the found item. + /// + [DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern OSStatus SecKeychainFindGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName, + UInt32 accountNameLength, string accountName, out UInt32 passwordLength, out IntPtr password, out IntPtr itemRef); + + /// + /// Releases the memory used by the keychain attribute list and the keychain data retrieved in a previous call to SecKeychainItemCopyContent. + /// + /// A pointer to the attribute list to release. Pass NULL to ignore this parameter. + /// A pointer to the data buffer to release. Pass NULL to ignore this parameter. + /// A result code that should be in + [DllImport(Libraries.SecurityLibrary, SetLastError = true)] + internal static extern OSStatus SecKeychainItemFreeContent([In] IntPtr attrList, [In] IntPtr data); + + /// + /// Deletes a keychain item from the default keychain's permanent data store. + /// + /// A keychain item reference of the item to delete. + /// A result code that should be in + /// + /// If itemRef has not previously been added to the keychain, SecKeychainItemDelete does nothing and returns ErrSecSuccess. + /// IMPORTANT: SecKeychainItemDelete does not dispose the memory occupied by the item reference itself; + /// use the CFRelease function when you are completely * * finished with an item. + /// + [DllImport(Libraries.SecurityLibrary, SetLastError = true)] + internal static extern OSStatus SecKeychainItemDelete(SafeHandle itemRef); + + #region OSStatus Codes + /// Common Unix errno error codes. + internal enum OSStatus + { + ErrSecSuccess = 0, /* No error. */ + ErrSecUnimplemented = -4, /* Function or operation not implemented. */ + ErrSecDskFull = -34, + ErrSecIO = -36, /*I/O error (bummers)*/ + + ErrSecParam = -50, /* One or more parameters passed to a function were not valid. */ + ErrSecWrPerm = -61, /* write permissions error*/ + ErrSecAllocate = -108, /* Failed to allocate memory. */ + ErrSecUserCanceled = -128, /* User canceled the operation. */ + ErrSecBadReq = -909, /* Bad parameter or invalid state for operation. */ + + ErrSecInternalComponent = -2070, + ErrSecCoreFoundationUnknown = -4960, + + ErrSecNotAvailable = -25291, /* No keychain is available. You may need to restart your computer. */ + ErrSecReadOnly = -25292, /* This keychain cannot be modified. */ + ErrSecAuthFailed = -25293, /* The user name or passphrase you entered is not correct. */ + ErrSecNoSuchKeychain = -25294, /* The specified keychain could not be found. */ + ErrSecInvalidKeychain = -25295, /* The specified keychain is not a valid keychain file. */ + ErrSecDuplicateKeychain = -25296, /* A keychain with the same name already exists. */ + ErrSecDuplicateCallback = -25297, /* The specified callback function is already installed. */ + ErrSecInvalidCallback = -25298, /* The specified callback function is not valid. */ + ErrSecDuplicateItem = -25299, /* The specified item already exists in the keychain. */ + ErrSecItemNotFound = -25300, /* The specified item could not be found in the keychain. */ + ErrSecBufferTooSmall = -25301, /* There is not enough memory available to use the specified item. */ + ErrSecDataTooLarge = -25302, /* This item contains information which is too large or in a format that cannot be displayed. */ + ErrSecNoSuchAttr = -25303, /* The specified attribute does not exist. */ + ErrSecInvalidItemRef = -25304, /* The specified item is no longer valid. It may have been deleted from the keychain. */ + ErrSecInvalidSearchRef = -25305, /* Unable to search the current keychain. */ + ErrSecNoSuchClass = -25306, /* The specified item does not appear to be a valid keychain item. */ + ErrSecNoDefaultKeychain = -25307, /* A default keychain could not be found. */ + ErrSecInteractionNotAllowed = -25308, /* User interaction is not allowed. */ + ErrSecReadOnlyAttr = -25309, /* The specified attribute could not be modified. */ + ErrSecWrongSecVersion = -25310, /* This keychain was created by a different version of the system software and cannot be opened. */ + ErrSecKeySizeNotAllowed = -25311, /* This item specifies a key size which is too large. */ + ErrSecNoStorageModule = -25312, /* A required component (data storage module) could not be loaded. You may need to restart your computer. */ + ErrSecNoCertificateModule = -25313, /* A required component (certificate module) could not be loaded. You may need to restart your computer. */ + ErrSecNoPolicyModule = -25314, /* A required component (policy module) could not be loaded. You may need to restart your computer. */ + ErrSecInteractionRequired = -25315, /* User interaction is required, but is currently not allowed. */ + ErrSecDataNotAvailable = -25316, /* The contents of this item cannot be retrieved. */ + ErrSecDataNotModifiable = -25317, /* The contents of this item cannot be modified. */ + ErrSecCreateChainFailed = -25318, /* One or more certificates required to validate this certificate cannot be found. */ + ErrSecInvalidPrefsDomain = -25319, /* The specified preferences domain is not valid. */ + ErrSecInDarkWake = -25320, /* In dark wake, no UI possible */ + + ErrSecACLNotSimple = -25240, /* The specified access control list is not in standard (simple) form. */ + ErrSecPolicyNotFound = -25241, /* The specified policy cannot be found. */ + ErrSecInvalidTrustSetting = -25242, /* The specified trust setting is invalid. */ + ErrSecNoAccessForItem = -25243, /* The specified item has no access control. */ + ErrSecInvalidOwnerEdit = -25244, /* Invalid attempt to change the owner of this item. */ + ErrSecTrustNotAvailable = -25245, /* No trust results are available. */ + ErrSecUnsupportedFormat = -25256, /* Import/Export format unsupported. */ + ErrSecUnknownFormat = -25257, /* Unknown format in import. */ + ErrSecKeyIsSensitive = -25258, /* Key material must be wrapped for export. */ + ErrSecMultiplePrivKeys = -25259, /* An attempt was made to import multiple private keys. */ + ErrSecPassphraseRequired = -25260, /* Passphrase is required for import/export. */ + ErrSecInvalidPasswordRef = -25261, /* The password reference was invalid. */ + ErrSecInvalidTrustSettings = -25262, /* The Trust Settings Record was corrupted. */ + ErrSecNoTrustSettings = -25263, /* No Trust Settings were found. */ + ErrSecPkcs12VerifyFailure = -25264, /* MAC verification failed during PKCS12 import (wrong password?) */ + ErrSecNotSigner = -26267, /* A certificate was not signed by its proposed parent. */ + + ErrSecDecode = -26275, /* Unable to decode the provided data. */ + + ErrSecServiceNotAvailable = -67585, /* The required service is not available. */ + ErrSecInsufficientClientID = -67586, /* The client ID is not correct. */ + ErrSecDeviceReset = -67587, /* A device reset has occurred. */ + ErrSecDeviceFailed = -67588, /* A device failure has occurred. */ + ErrSecAppleAddAppACLSubject = -67589, /* Adding an application ACL subject failed. */ + ErrSecApplePublicKeyIncomplete = -67590, /* The public key is incomplete. */ + ErrSecAppleSignatureMismatch = -67591, /* A signature mismatch has occurred. */ + ErrSecAppleInvalidKeyStartDate = -67592, /* The specified key has an invalid start date. */ + ErrSecAppleInvalidKeyEndDate = -67593, /* The specified key has an invalid end date. */ + ErrSecConversionError = -67594, /* A conversion error has occurred. */ + ErrSecAppleSSLv2Rollback = -67595, /* A SSLv2 rollback error has occurred. */ + ErrSecDiskFull = -34, /* The disk is full. */ + ErrSecQuotaExceeded = -67596, /* The quota was exceeded. */ + ErrSecFileTooBig = -67597, /* The file is too big. */ + ErrSecInvalidDatabaseBlob = -67598, /* The specified database has an invalid blob. */ + ErrSecInvalidKeyBlob = -67599, /* The specified database has an invalid key blob. */ + ErrSecIncompatibleDatabaseBlob = -67600, /* The specified database has an incompatible blob. */ + ErrSecIncompatibleKeyBlob = -67601, /* The specified database has an incompatible key blob. */ + ErrSecHostNameMismatch = -67602, /* A host name mismatch has occurred. */ + ErrSecUnknownCriticalExtensionFlag = -67603, /* There is an unknown critical extension flag. */ + ErrSecNoBasicConstraints = -67604, /* No basic constraints were found. */ + ErrSecNoBasicConstraintsCA = -67605, /* No basic CA constraints were found. */ + ErrSecInvalidAuthorityKeyID = -67606, /* The authority key ID is not valid. */ + ErrSecInvalidSubjectKeyID = -67607, /* The subject key ID is not valid. */ + ErrSecInvalidKeyUsageForPolicy = -67608, /* The key usage is not valid for the specified policy. */ + ErrSecInvalidExtendedKeyUsage = -67609, /* The extended key usage is not valid. */ + ErrSecInvalidIDLinkage = -67610, /* The ID linkage is not valid. */ + ErrSecPathLengthConstraintExceeded = -67611, /* The path length constraint was exceeded. */ + ErrSecInvalidRoot = -67612, /* The root or anchor certificate is not valid. */ + ErrSecCRLExpired = -67613, /* The CRL has expired. */ + ErrSecCRLNotValidYet = -67614, /* The CRL is not yet valid. */ + ErrSecCRLNotFound = -67615, /* The CRL was not found. */ + ErrSecCRLServerDown = -67616, /* The CRL server is down. */ + ErrSecCRLBadURI = -67617, /* The CRL has a bad Uniform Resource Identifier. */ + ErrSecUnknownCertExtension = -67618, /* An unknown certificate extension was encountered. */ + ErrSecUnknownCRLExtension = -67619, /* An unknown CRL extension was encountered. */ + ErrSecCRLNotTrusted = -67620, /* The CRL is not trusted. */ + ErrSecCRLPolicyFailed = -67621, /* The CRL policy failed. */ + ErrSecIDPFailure = -67622, /* The issuing distribution point was not valid. */ + ErrSecSMIMEEmailAddressesNotFound = -67623, /* An email address mismatch was encountered. */ + ErrSecSMIMEBadExtendedKeyUsage = -67624, /* The appropriate extended key usage for SMIME was not found. */ + ErrSecSMIMEBadKeyUsage = -67625, /* The key usage is not compatible with SMIME. */ + ErrSecSMIMEKeyUsageNotCritical = -67626, /* The key usage extension is not marked as critical. */ + ErrSecSMIMENoEmailAddress = -67627, /* No email address was found in the certificate. */ + ErrSecSMIMESubjAltNameNotCritical = -67628, /* The subject alternative name extension is not marked as critical. */ + ErrSecSSLBadExtendedKeyUsage = -67629, /* The appropriate extended key usage for SSL was not found. */ + ErrSecOCSPBadResponse = -67630, /* The OCSP response was incorrect or could not be parsed. */ + ErrSecOCSPBadRequest = -67631, /* The OCSP request was incorrect or could not be parsed. */ + ErrSecOCSPUnavailable = -67632, /* OCSP service is unavailable. */ + ErrSecOCSPStatusUnrecognized = -67633, /* The OCSP server did not recognize this certificate. */ + ErrSecEndOfData = -67634, /* An end-of-data was detected. */ + ErrSecIncompleteCertRevocationCheck = -67635, /* An incomplete certificate revocation check occurred. */ + ErrSecNetworkFailure = -67636, /* A network failure occurred. */ + ErrSecOCSPNotTrustedToAnchor = -67637, /* The OCSP response was not trusted to a root or anchor certificate. */ + ErrSecRecordModified = -67638, /* The record was modified. */ + ErrSecOCSPSignatureError = -67639, /* The OCSP response had an invalid signature. */ + ErrSecOCSPNoSigner = -67640, /* The OCSP response had no signer. */ + ErrSecOCSPResponderMalformedReq = -67641, /* The OCSP responder was given a malformed request. */ + ErrSecOCSPResponderInternalError = -67642, /* The OCSP responder encountered an internal error. */ + ErrSecOCSPResponderTryLater = -67643, /* The OCSP responder is busy, try again later. */ + ErrSecOCSPResponderSignatureRequired = -67644, /* The OCSP responder requires a signature. */ + ErrSecOCSPResponderUnauthorized = -67645, /* The OCSP responder rejected this request as unauthorized. */ + ErrSecOCSPResponseNonceMismatch = -67646, /* The OCSP response nonce did not match the request. */ + ErrSecCodeSigningBadCertChainLength = -67647, /* Code signing encountered an incorrect certificate chain length. */ + ErrSecCodeSigningNoBasicConstraints = -67648, /* Code signing found no basic constraints. */ + ErrSecCodeSigningBadPathLengthConstraint= -67649, /* Code signing encountered an incorrect path length constraint. */ + ErrSecCodeSigningNoExtendedKeyUsage = -67650, /* Code signing found no extended key usage. */ + ErrSecCodeSigningDevelopment = -67651, /* Code signing indicated use of a development-only certificate. */ + ErrSecResourceSignBadCertChainLength = -67652, /* Resource signing has encountered an incorrect certificate chain length. */ + ErrSecResourceSignBadExtKeyUsage = -67653, /* Resource signing has encountered an error in the extended key usage. */ + ErrSecTrustSettingDeny = -67654, /* The trust setting for this policy was set to Deny. */ + ErrSecInvalidSubjectName = -67655, /* An invalid certificate subject name was encountered. */ + ErrSecUnknownQualifiedCertStatement = -67656, /* An unknown qualified certificate statement was encountered. */ + ErrSecMobileMeRequestQueued = -67657, /* The MobileMe request will be sent during the next connection. */ + ErrSecMobileMeRequestRedirected = -67658, /* The MobileMe request was redirected. */ + ErrSecMobileMeServerError = -67659, /* A MobileMe server error occurred. */ + ErrSecMobileMeServerNotAvailable = -67660, /* The MobileMe server is not available. */ + ErrSecMobileMeServerAlreadyExists = -67661, /* The MobileMe server reported that the item already exists. */ + ErrSecMobileMeServerServiceErr = -67662, /* A MobileMe service error has occurred. */ + ErrSecMobileMeRequestAlreadyPending = -67663, /* A MobileMe request is already pending. */ + ErrSecMobileMeNoRequestPending = -67664, /* MobileMe has no request pending. */ + ErrSecMobileMeCSRVerifyFailure = -67665, /* A MobileMe CSR verification failure has occurred. */ + ErrSecMobileMeFailedConsistencyCheck = -67666, /* MobileMe has found a failed consistency check. */ + ErrSecNotInitialized = -67667, /* A function was called without initializing CSSM. */ + ErrSecInvalidHandleUsage = -67668, /* The CSSM handle does not match with the service type. */ + ErrSecPVCReferentNotFound = -67669, /* A reference to the calling module was not found in the list of authorized callers. */ + ErrSecFunctionIntegrityFail = -67670, /* A function address was not within the verified module. */ + ErrSecInternalError = -67671, /* An internal error has occurred. */ + ErrSecMemoryError = -67672, /* A memory error has occurred. */ + ErrSecInvalidData = -67673, /* Invalid data was encountered. */ + ErrSecMDSError = -67674, /* A Module Directory Service error has occurred. */ + ErrSecInvalidPointer = -67675, /* An invalid pointer was encountered. */ + ErrSecSelfCheckFailed = -67676, /* Self-check has failed. */ + ErrSecFunctionFailed = -67677, /* A function has failed. */ + ErrSecModuleManifestVerifyFailed = -67678, /* A module manifest verification failure has occurred. */ + ErrSecInvalidGUID = -67679, /* An invalid GUID was encountered. */ + ErrSecInvalidHandle = -67680, /* An invalid handle was encountered. */ + ErrSecInvalidDBList = -67681, /* An invalid DB list was encountered. */ + ErrSecInvalidPassthroughID = -67682, /* An invalid passthrough ID was encountered. */ + ErrSecInvalidNetworkAddress = -67683, /* An invalid network address was encountered. */ + ErrSecCRLAlreadySigned = -67684, /* The certificate revocation list is already signed. */ + ErrSecInvalidNumberOfFields = -67685, /* An invalid number of fields were encountered. */ + ErrSecVerificationFailure = -67686, /* A verification failure occurred. */ + ErrSecUnknownTag = -67687, /* An unknown tag was encountered. */ + ErrSecInvalidSignature = -67688, /* An invalid signature was encountered. */ + ErrSecInvalidName = -67689, /* An invalid name was encountered. */ + ErrSecInvalidCertificateRef = -67690, /* An invalid certificate reference was encountered. */ + ErrSecInvalidCertificateGroup = -67691, /* An invalid certificate group was encountered. */ + ErrSecTagNotFound = -67692, /* The specified tag was not found. */ + ErrSecInvalidQuery = -67693, /* The specified query was not valid. */ + ErrSecInvalidValue = -67694, /* An invalid value was detected. */ + ErrSecCallbackFailed = -67695, /* A callback has failed. */ + ErrSecACLDeleteFailed = -67696, /* An ACL delete operation has failed. */ + ErrSecACLReplaceFailed = -67697, /* An ACL replace operation has failed. */ + ErrSecACLAddFailed = -67698, /* An ACL add operation has failed. */ + ErrSecACLChangeFailed = -67699, /* An ACL change operation has failed. */ + ErrSecInvalidAccessCredentials = -67700, /* Invalid access credentials were encountered. */ + ErrSecInvalidRecord = -67701, /* An invalid record was encountered. */ + ErrSecInvalidACL = -67702, /* An invalid ACL was encountered. */ + ErrSecInvalidSampleValue = -67703, /* An invalid sample value was encountered. */ + ErrSecIncompatibleVersion = -67704, /* An incompatible version was encountered. */ + ErrSecPrivilegeNotGranted = -67705, /* The privilege was not granted. */ + ErrSecInvalidScope = -67706, /* An invalid scope was encountered. */ + ErrSecPVCAlreadyConfigured = -67707, /* The PVC is already configured. */ + ErrSecInvalidPVC = -67708, /* An invalid PVC was encountered. */ + ErrSecEMMLoadFailed = -67709, /* The EMM load has failed. */ + ErrSecEMMUnloadFailed = -67710, /* The EMM unload has failed. */ + ErrSecAddinLoadFailed = -67711, /* The add-in load operation has failed. */ + ErrSecInvalidKeyRef = -67712, /* An invalid key was encountered. */ + ErrSecInvalidKeyHierarchy = -67713, /* An invalid key hierarchy was encountered. */ + ErrSecAddinUnloadFailed = -67714, /* The add-in unload operation has failed. */ + ErrSecLibraryReferenceNotFound = -67715, /* A library reference was not found. */ + ErrSecInvalidAddinFunctionTable = -67716, /* An invalid add-in function table was encountered. */ + ErrSecInvalidServiceMask = -67717, /* An invalid service mask was encountered. */ + ErrSecModuleNotLoaded = -67718, /* A module was not loaded. */ + ErrSecInvalidSubServiceID = -67719, /* An invalid subservice ID was encountered. */ + ErrSecAttributeNotInContext = -67720, /* An attribute was not in the context. */ + ErrSecModuleManagerInitializeFailed = -67721, /* A module failed to initialize. */ + ErrSecModuleManagerNotFound = -67722, /* A module was not found. */ + ErrSecEventNotificationCallbackNotFound = -67723, /* An event notification callback was not found. */ + ErrSecInputLengthError = -67724, /* An input length error was encountered. */ + ErrSecOutputLengthError = -67725, /* An output length error was encountered. */ + ErrSecPrivilegeNotSupported = -67726, /* The privilege is not supported. */ + ErrSecDeviceError = -67727, /* A device error was encountered. */ + ErrSecAttachHandleBusy = -67728, /* The CSP handle was busy. */ + ErrSecNotLoggedIn = -67729, /* You are not logged in. */ + ErrSecAlgorithmMismatch = -67730, /* An algorithm mismatch was encountered. */ + ErrSecKeyUsageIncorrect = -67731, /* The key usage is incorrect. */ + ErrSecKeyBlobTypeIncorrect = -67732, /* The key blob type is incorrect. */ + ErrSecKeyHeaderInconsistent = -67733, /* The key header is inconsistent. */ + ErrSecUnsupportedKeyFormat = -67734, /* The key header format is not supported. */ + ErrSecUnsupportedKeySize = -67735, /* The key size is not supported. */ + ErrSecInvalidKeyUsageMask = -67736, /* The key usage mask is not valid. */ + ErrSecUnsupportedKeyUsageMask = -67737, /* The key usage mask is not supported. */ + ErrSecInvalidKeyAttributeMask = -67738, /* The key attribute mask is not valid. */ + ErrSecUnsupportedKeyAttributeMask = -67739, /* The key attribute mask is not supported. */ + ErrSecInvalidKeyLabel = -67740, /* The key label is not valid. */ + ErrSecUnsupportedKeyLabel = -67741, /* The key label is not supported. */ + ErrSecInvalidKeyFormat = -67742, /* The key format is not valid. */ + ErrSecUnsupportedVectorOfBuffers = -67743, /* The vector of buffers is not supported. */ + ErrSecInvalidInputVector = -67744, /* The input vector is not valid. */ + ErrSecInvalidOutputVector = -67745, /* The output vector is not valid. */ + ErrSecInvalidContext = -67746, /* An invalid context was encountered. */ + ErrSecInvalidAlgorithm = -67747, /* An invalid algorithm was encountered. */ + ErrSecInvalidAttributeKey = -67748, /* A key attribute was not valid. */ + ErrSecMissingAttributeKey = -67749, /* A key attribute was missing. */ + ErrSecInvalidAttributeInitVector = -67750, /* An init vector attribute was not valid. */ + ErrSecMissingAttributeInitVector = -67751, /* An init vector attribute was missing. */ + ErrSecInvalidAttributeSalt = -67752, /* A salt attribute was not valid. */ + ErrSecMissingAttributeSalt = -67753, /* A salt attribute was missing. */ + ErrSecInvalidAttributePadding = -67754, /* A padding attribute was not valid. */ + ErrSecMissingAttributePadding = -67755, /* A padding attribute was missing. */ + ErrSecInvalidAttributeRandom = -67756, /* A random number attribute was not valid. */ + ErrSecMissingAttributeRandom = -67757, /* A random number attribute was missing. */ + ErrSecInvalidAttributeSeed = -67758, /* A seed attribute was not valid. */ + ErrSecMissingAttributeSeed = -67759, /* A seed attribute was missing. */ + ErrSecInvalidAttributePassphrase = -67760, /* A passphrase attribute was not valid. */ + ErrSecMissingAttributePassphrase = -67761, /* A passphrase attribute was missing. */ + ErrSecInvalidAttributeKeyLength = -67762, /* A key length attribute was not valid. */ + ErrSecMissingAttributeKeyLength = -67763, /* A key length attribute was missing. */ + ErrSecInvalidAttributeBlockSize = -67764, /* A block size attribute was not valid. */ + ErrSecMissingAttributeBlockSize = -67765, /* A block size attribute was missing. */ + ErrSecInvalidAttributeOutputSize = -67766, /* An output size attribute was not valid. */ + ErrSecMissingAttributeOutputSize = -67767, /* An output size attribute was missing. */ + ErrSecInvalidAttributeRounds = -67768, /* The number of rounds attribute was not valid. */ + ErrSecMissingAttributeRounds = -67769, /* The number of rounds attribute was missing. */ + ErrSecInvalidAlgorithmParms = -67770, /* An algorithm parameters attribute was not valid. */ + ErrSecMissingAlgorithmParms = -67771, /* An algorithm parameters attribute was missing. */ + ErrSecInvalidAttributeLabel = -67772, /* A label attribute was not valid. */ + ErrSecMissingAttributeLabel = -67773, /* A label attribute was missing. */ + ErrSecInvalidAttributeKeyType = -67774, /* A key type attribute was not valid. */ + ErrSecMissingAttributeKeyType = -67775, /* A key type attribute was missing. */ + ErrSecInvalidAttributeMode = -67776, /* A mode attribute was not valid. */ + ErrSecMissingAttributeMode = -67777, /* A mode attribute was missing. */ + ErrSecInvalidAttributeEffectiveBits = -67778, /* An effective bits attribute was not valid. */ + ErrSecMissingAttributeEffectiveBits = -67779, /* An effective bits attribute was missing. */ + ErrSecInvalidAttributeStartDate = -67780, /* A start date attribute was not valid. */ + ErrSecMissingAttributeStartDate = -67781, /* A start date attribute was missing. */ + ErrSecInvalidAttributeEndDate = -67782, /* An end date attribute was not valid. */ + ErrSecMissingAttributeEndDate = -67783, /* An end date attribute was missing. */ + ErrSecInvalidAttributeVersion = -67784, /* A version attribute was not valid. */ + ErrSecMissingAttributeVersion = -67785, /* A version attribute was missing. */ + ErrSecInvalidAttributePrime = -67786, /* A prime attribute was not valid. */ + ErrSecMissingAttributePrime = -67787, /* A prime attribute was missing. */ + ErrSecInvalidAttributeBase = -67788, /* A base attribute was not valid. */ + ErrSecMissingAttributeBase = -67789, /* A base attribute was missing. */ + ErrSecInvalidAttributeSubprime = -67790, /* A subprime attribute was not valid. */ + ErrSecMissingAttributeSubprime = -67791, /* A subprime attribute was missing. */ + ErrSecInvalidAttributeIterationCount = -67792, /* An iteration count attribute was not valid. */ + ErrSecMissingAttributeIterationCount = -67793, /* An iteration count attribute was missing. */ + ErrSecInvalidAttributeDLDBHandle = -67794, /* A database handle attribute was not valid. */ + ErrSecMissingAttributeDLDBHandle = -67795, /* A database handle attribute was missing. */ + ErrSecInvalidAttributeAccessCredentials = -67796, /* An access credentials attribute was not valid. */ + ErrSecMissingAttributeAccessCredentials = -67797, /* An access credentials attribute was missing. */ + ErrSecInvalidAttributePublicKeyFormat = -67798, /* A public key format attribute was not valid. */ + ErrSecMissingAttributePublicKeyFormat = -67799, /* A public key format attribute was missing. */ + ErrSecInvalidAttributePrivateKeyFormat = -67800, /* A private key format attribute was not valid. */ + ErrSecMissingAttributePrivateKeyFormat = -67801, /* A private key format attribute was missing. */ + ErrSecInvalidAttributeSymmetricKeyFormat = -67802, /* A symmetric key format attribute was not valid. */ + ErrSecMissingAttributeSymmetricKeyFormat = -67803, /* A symmetric key format attribute was missing. */ + ErrSecInvalidAttributeWrappedKeyFormat = -67804, /* A wrapped key format attribute was not valid. */ + ErrSecMissingAttributeWrappedKeyFormat = -67805, /* A wrapped key format attribute was missing. */ + ErrSecStagedOperationInProgress = -67806, /* A staged operation is in progress. */ + ErrSecStagedOperationNotStarted = -67807, /* A staged operation was not started. */ + ErrSecVerifyFailed = -67808, /* A cryptographic verification failure has occurred. */ + ErrSecQuerySizeUnknown = -67809, /* The query size is unknown. */ + ErrSecBlockSizeMismatch = -67810, /* A block size mismatch occurred. */ + ErrSecPublicKeyInconsistent = -67811, /* The public key was inconsistent. */ + ErrSecDeviceVerifyFailed = -67812, /* A device verification failure has occurred. */ + ErrSecInvalidLoginName = -67813, /* An invalid login name was detected. */ + ErrSecAlreadyLoggedIn = -67814, /* The user is already logged in. */ + ErrSecInvalidDigestAlgorithm = -67815, /* An invalid digest algorithm was detected. */ + ErrSecInvalidCRLGroup = -67816, /* An invalid CRL group was detected. */ + ErrSecCertificateCannotOperate = -67817, /* The certificate cannot operate. */ + ErrSecCertificateExpired = -67818, /* An expired certificate was detected. */ + ErrSecCertificateNotValidYet = -67819, /* The certificate is not yet valid. */ + ErrSecCertificateRevoked = -67820, /* The certificate was revoked. */ + ErrSecCertificateSuspended = -67821, /* The certificate was suspended. */ + ErrSecInsufficientCredentials = -67822, /* Insufficient credentials were detected. */ + ErrSecInvalidAction = -67823, /* The action was not valid. */ + ErrSecInvalidAuthority = -67824, /* The authority was not valid. */ + ErrSecVerifyActionFailed = -67825, /* A verify action has failed. */ + ErrSecInvalidCertAuthority = -67826, /* The certificate authority was not valid. */ + ErrSecInvaldCRLAuthority = -67827, /* The CRL authority was not valid. */ + ErrSecInvalidCRLEncoding = -67828, /* The CRL encoding was not valid. */ + ErrSecInvalidCRLType = -67829, /* The CRL type was not valid. */ + ErrSecInvalidCRL = -67830, /* The CRL was not valid. */ + ErrSecInvalidFormType = -67831, /* The form type was not valid. */ + ErrSecInvalidID = -67832, /* The ID was not valid. */ + ErrSecInvalidIdentifier = -67833, /* The identifier was not valid. */ + ErrSecInvalidIndex = -67834, /* The index was not valid. */ + ErrSecInvalidPolicyIdentifiers = -67835, /* The policy identifiers are not valid. */ + ErrSecInvalidTimeString = -67836, /* The time specified was not valid. */ + ErrSecInvalidReason = -67837, /* The trust policy reason was not valid. */ + ErrSecInvalidRequestInputs = -67838, /* The request inputs are not valid. */ + ErrSecInvalidResponseVector = -67839, /* The response vector was not valid. */ + ErrSecInvalidStopOnPolicy = -67840, /* The stop-on policy was not valid. */ + ErrSecInvalidTuple = -67841, /* The tuple was not valid. */ + ErrSecMultipleValuesUnsupported = -67842, /* Multiple values are not supported. */ + ErrSecNotTrusted = -67843, /* The trust policy was not trusted. */ + ErrSecNoDefaultAuthority = -67844, /* No default authority was detected. */ + ErrSecRejectedForm = -67845, /* The trust policy had a rejected form. */ + ErrSecRequestLost = -67846, /* The request was lost. */ + ErrSecRequestRejected = -67847, /* The request was rejected. */ + ErrSecUnsupportedAddressType = -67848, /* The address type is not supported. */ + ErrSecUnsupportedService = -67849, /* The service is not supported. */ + ErrSecInvalidTupleGroup = -67850, /* The tuple group was not valid. */ + ErrSecInvalidBaseACLs = -67851, /* The base ACLs are not valid. */ + ErrSecInvalidTupleCredendtials = -67852, /* The tuple credentials are not valid. */ + ErrSecInvalidEncoding = -67853, /* The encoding was not valid. */ + ErrSecInvalidValidityPeriod = -67854, /* The validity period was not valid. */ + ErrSecInvalidRequestor = -67855, /* The requestor was not valid. */ + ErrSecRequestDescriptor = -67856, /* The request descriptor was not valid. */ + ErrSecInvalidBundleInfo = -67857, /* The bundle information was not valid. */ + ErrSecInvalidCRLIndex = -67858, /* The CRL index was not valid. */ + ErrSecNoFieldValues = -67859, /* No field values were detected. */ + ErrSecUnsupportedFieldFormat = -67860, /* The field format is not supported. */ + ErrSecUnsupportedIndexInfo = -67861, /* The index information is not supported. */ + ErrSecUnsupportedLocality = -67862, /* The locality is not supported. */ + ErrSecUnsupportedNumAttributes = -67863, /* The number of attributes is not supported. */ + ErrSecUnsupportedNumIndexes = -67864, /* The number of indexes is not supported. */ + ErrSecUnsupportedNumRecordTypes = -67865, /* The number of record types is not supported. */ + ErrSecFieldSpecifiedMultiple = -67866, /* Too many fields were specified. */ + ErrSecIncompatibleFieldFormat = -67867, /* The field format was incompatible. */ + ErrSecInvalidParsingModule = -67868, /* The parsing module was not valid. */ + ErrSecDatabaseLocked = -67869, /* The database is locked. */ + ErrSecDatastoreIsOpen = -67870, /* The data store is open. */ + ErrSecMissingValue = -67871, /* A missing value was detected. */ + ErrSecUnsupportedQueryLimits = -67872, /* The query limits are not supported. */ + ErrSecUnsupportedNumSelectionPreds = -67873, /* The number of selection predicates is not supported. */ + ErrSecUnsupportedOperator = -67874, /* The operator is not supported. */ + ErrSecInvalidDBLocation = -67875, /* The database location is not valid. */ + ErrSecInvalidAccessRequest = -67876, /* The access request is not valid. */ + ErrSecInvalidIndexInfo = -67877, /* The index information is not valid. */ + ErrSecInvalidNewOwner = -67878, /* The new owner is not valid. */ + ErrSecInvalidModifyMode = -67879, /* The modify mode is not valid. */ + ErrSecMissingRequiredExtension = -67880, /* A required certificate extension is missing. */ + ErrSecExtendedKeyUsageNotCritical = -67881, /* The extended key usage extension was not marked critical. */ + ErrSecTimestampMissing = -67882, /* A timestamp was expected but was not found. */ + ErrSecTimestampInvalid = -67883, /* The timestamp was not valid. */ + ErrSecTimestampNotTrusted = -67884, /* The timestamp was not trusted. */ + ErrSecTimestampServiceNotAvailable = -67885, /* The timestamp service is not available. */ + ErrSecTimestampBadAlg = -67886, /* An unrecognized or unsupported Algorithm Identifier in timestamp. */ + ErrSecTimestampBadRequest = -67887, /* The timestamp transaction is not permitted or supported. */ + ErrSecTimestampBadDataFormat = -67888, /* The timestamp data submitted has the wrong format. */ + ErrSecTimestampTimeNotAvailable = -67889, /* The time source for the Timestamp Authority is not available. */ + ErrSecTimestampUnacceptedPolicy = -67890, /* The requested policy is not supported by the Timestamp Authority. */ + ErrSecTimestampUnacceptedExtension = -67891, /* The requested extension is not supported by the Timestamp Authority. */ + ErrSecTimestampAddInfoNotAvailable = -67892, /* The additional information requested is not available. */ + ErrSecTimestampSystemFailure = -67893, /* The timestamp request cannot be handled due to system failure. */ + ErrSecSigningTimeMissing = -67894, /* A signing time was expected but was not found. */ + ErrSecTimestampRejection = -67895, /* A timestamp transaction was rejected. */ + ErrSecTimestampWaiting = -67896, /* A timestamp transaction is waiting. */ + ErrSecTimestampRevocationWarning = -67897, /* A timestamp authority revocation warning was issued. */ + ErrSecTimestampRevocationNotification = -67898, /* A timestamp authority revocation notification was issued. */ + } + + #endregion + } + } +} + diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/OSXCredentialStore.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/OSXCredentialStore.cs new file mode 100644 index 00000000..dc868040 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/OSXCredentialStore.cs @@ -0,0 +1,158 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Runtime.InteropServices; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.OSX +{ + /// + /// OSX implementation of the credential store + /// + internal class OSXCredentialStore : ICredentialStore + { + public bool DeletePassword(string credentialId) + { + Validate.IsNotNullOrEmptyString("credentialId", credentialId); + return DeletePasswordImpl(credentialId); + } + + public bool TryGetPassword(string credentialId, out string password) + { + Validate.IsNotNullOrEmptyString("credentialId", credentialId); + return FindPassword(credentialId, out password); + } + + public bool Save(Credential credential) + { + Credential.ValidateForSave(credential); + bool result = false; + + // Note: OSX blocks AddPassword if the credential + // already exists, so for now we delete the password if already present since we're updating + // the value. In the future, we could consider updating but it's low value to solve this + DeletePasswordImpl(credential.CredentialId); + + // Now add the password + result = AddGenericPassword(credential); + return result; + } + + private bool AddGenericPassword(Credential credential) + { + IntPtr passwordPtr = Marshal.StringToCoTaskMemUni(credential.Password); + Interop.Security.OSStatus status = Interop.Security.SecKeychainAddGenericPassword( + IntPtr.Zero, + InteropUtils.GetLengthInBytes(credential.CredentialId), + credential.CredentialId, + 0, + null, + InteropUtils.GetLengthInBytes(credential.Password), + passwordPtr, + IntPtr.Zero); + + return status == Interop.Security.OSStatus.ErrSecSuccess; + } + + /// + /// Finds the first password matching this credential + /// + private bool FindPassword(string credentialId, out string password) + { + password = null; + using (KeyChainItemHandle handle = LookupKeyChainItem(credentialId)) + { + if( handle == null) + { + return false; + } + password = handle.Password; + } + + return true; + } + + private KeyChainItemHandle LookupKeyChainItem(string credentialId) + { + UInt32 passwordLength; + IntPtr passwordPtr; + IntPtr item; + Interop.Security.OSStatus status = Interop.Security.SecKeychainFindGenericPassword( + IntPtr.Zero, + InteropUtils.GetLengthInBytes(credentialId), + credentialId, + 0, + null, + out passwordLength, + out passwordPtr, + out item); + + if(status == Interop.Security.OSStatus.ErrSecSuccess) + { + return new KeyChainItemHandle(item, passwordPtr, passwordLength); + } + return null; + } + + private bool DeletePasswordImpl(string credentialId) + { + // Find password, then Delete, then cleanup + using (KeyChainItemHandle handle = LookupKeyChainItem(credentialId)) + { + if (handle == null) + { + return false; + } + Interop.Security.OSStatus status = Interop.Security.SecKeychainItemDelete(handle); + return status == Interop.Security.OSStatus.ErrSecSuccess; + } + } + + private class KeyChainItemHandle : SafeCreateHandle + { + private IntPtr passwordPtr; + private int passwordLength; + + public KeyChainItemHandle() : base() + { + + } + + public KeyChainItemHandle(IntPtr itemPtr) : this(itemPtr, IntPtr.Zero, 0) + { + + } + + public KeyChainItemHandle(IntPtr itemPtr, IntPtr passwordPtr, UInt32 passwordLength) + : base(itemPtr) + { + this.passwordPtr = passwordPtr; + this.passwordLength = (int) passwordLength; + } + + public string Password + { + get { + if (IsInvalid) + { + return null; + } + return InteropUtils.CopyToString(passwordPtr, passwordLength); + } + } + protected override bool ReleaseHandle() + { + if (passwordPtr != IntPtr.Zero) + { + Interop.Security.SecKeychainItemFreeContent(IntPtr.Zero, passwordPtr); + } + base.ReleaseHandle(); + return true; + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/SafeCreateHandle.OSX.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/SafeCreateHandle.OSX.cs new file mode 100644 index 00000000..5beaaf26 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/OSX/SafeCreateHandle.OSX.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials + +{ + /// + /// This class is a wrapper around the Create pattern in OS X where + /// if a Create* function is called, the caller must also CFRelease + /// on the same pointer in order to correctly free the memory. + /// + [System.Security.SecurityCritical] + internal partial class SafeCreateHandle : SafeHandle + { + internal SafeCreateHandle() : base(IntPtr.Zero, true) { } + + internal SafeCreateHandle(IntPtr ptr) : base(IntPtr.Zero, true) + { + this.SetHandle(ptr); + } + + [System.Security.SecurityCritical] + protected override bool ReleaseHandle() + { + Interop.CoreFoundation.CFRelease(handle); + + return true; + } + + public override bool IsInvalid + { + [System.Security.SecurityCritical] + get + { + return handle == IntPtr.Zero; + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/SecureStringHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/SecureStringHelper.cs new file mode 100644 index 00000000..070e0b20 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/SecureStringHelper.cs @@ -0,0 +1,42 @@ +// +// Code originally from http://credentialmanagement.codeplex.com/, +// Licensed under the Apache License 2.0 +// + +using System; +using System.Runtime.InteropServices; +using System.Security; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Win32 +{ + internal static class SecureStringHelper + { + // Methods + internal static SecureString CreateSecureString(string plainString) + { + SecureString str = new SecureString(); + if (!string.IsNullOrEmpty(plainString)) + { + foreach (char c in plainString) + { + str.AppendChar(c); + } + } + str.MakeReadOnly(); + return str; + } + + internal static string CreateString(SecureString value) + { + IntPtr ptr = SecureStringMarshal.SecureStringToGlobalAllocUnicode(value); + try + { + return Marshal.PtrToStringUni(ptr); + } + finally + { + Marshal.ZeroFreeGlobalAllocUnicode(ptr); + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/CredentialResources.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/CredentialResources.cs new file mode 100644 index 00000000..d3834c42 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/CredentialResources.cs @@ -0,0 +1,16 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Win32 +{ + // TODO Replace this strings class with a resx file + internal class CredentialResources + { + public const string PasswordLengthExceeded = "The password has exceeded 512 bytes."; + public const string TargetRequiredForDelete = "Target must be specified to delete a credential."; + public const string TargetRequiredForLookup = "Target must be specified to check existance of a credential."; + public const string CredentialDisposed = "Win32Credential object is already disposed."; + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/CredentialSet.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/CredentialSet.cs new file mode 100644 index 00000000..f94a0f57 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/CredentialSet.cs @@ -0,0 +1,113 @@ +// +// Code originally from http://credentialmanagement.codeplex.com/, +// Licensed under the Apache License 2.0 +// + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Runtime.InteropServices; +using Microsoft.SqlTools.EditorServices.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Win32 +{ + public class CredentialSet: List, IDisposable + { + bool _disposed; + + public CredentialSet() + { + } + + public CredentialSet(string target) + : this() + { + if (string.IsNullOrEmpty(target)) + { + throw new ArgumentNullException("target"); + } + Target = target; + } + + public string Target { get; set; } + + + public void Dispose() + { + Dispose(true); + + // Prevent GC Collection since we have already disposed of this object + GC.SuppressFinalize(this); + } + + ~CredentialSet() + { + Dispose(false); + } + + private void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + if (Count > 0) + { + ForEach(cred => cred.Dispose()); + } + } + } + _disposed = true; + } + + public CredentialSet Load() + { + LoadInternal(); + return this; + } + + private void LoadInternal() + { + uint count; + + IntPtr pCredentials = IntPtr.Zero; + bool result = NativeMethods.CredEnumerateW(Target, 0, out count, out pCredentials); + if (!result) + { + Logger.Write(LogLevel.Error, string.Format("Win32Exception: {0}", new Win32Exception(Marshal.GetLastWin32Error()).ToString())); + return; + } + + // Read in all of the pointers first + IntPtr[] ptrCredList = new IntPtr[count]; + for (int i = 0; i < count; i++) + { + ptrCredList[i] = Marshal.ReadIntPtr(pCredentials, IntPtr.Size*i); + } + + // Now let's go through all of the pointers in the list + // and create our Credential object(s) + List credentialHandles = + ptrCredList.Select(ptrCred => new NativeMethods.CriticalCredentialHandle(ptrCred)).ToList(); + + IEnumerable existingCredentials = credentialHandles + .Select(handle => handle.GetCredential()) + .Select(nativeCredential => + { + Win32Credential credential = new Win32Credential(); + credential.LoadInternal(nativeCredential); + return credential; + }); + AddRange(existingCredentials); + + // The individual credentials should not be free'd + credentialHandles.ForEach(handle => handle.SetHandleAsInvalid()); + + // Clean up memory to the Enumeration pointer + NativeMethods.CredFree(pCredentials); + } + + } + +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/CredentialType.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/CredentialType.cs new file mode 100644 index 00000000..edc16d0d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/CredentialType.cs @@ -0,0 +1,16 @@ +// +// Code originally from http://credentialmanagement.codeplex.com/, +// Licensed under the Apache License 2.0 +// + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Win32 +{ + public enum CredentialType: uint + { + None = 0, + Generic = 1, + DomainPassword = 2, + DomainCertificate = 3, + DomainVisiblePassword = 4 + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/GlobalSuppressions.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/GlobalSuppressions.cs new file mode 100644 index 00000000..3ee40dc8 Binary files /dev/null and b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/GlobalSuppressions.cs differ diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/NativeMethods.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/NativeMethods.cs new file mode 100644 index 00000000..1e43205c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/NativeMethods.cs @@ -0,0 +1,109 @@ +// +// Code originally from http://credentialmanagement.codeplex.com/, +// Licensed under the Apache License 2.0 +// + +using System; +using System.Runtime.InteropServices; +using System.Text; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Win32 +{ + internal class NativeMethods + { + + [StructLayout(LayoutKind.Sequential)] + internal struct CREDENTIAL + { + public int Flags; + public int Type; + [MarshalAs(UnmanagedType.LPWStr)] + public string TargetName; + [MarshalAs(UnmanagedType.LPWStr)] + public string Comment; + public long LastWritten; + public int CredentialBlobSize; + public IntPtr CredentialBlob; + public int Persist; + public int AttributeCount; + public IntPtr Attributes; + [MarshalAs(UnmanagedType.LPWStr)] + public string TargetAlias; + [MarshalAs(UnmanagedType.LPWStr)] + public string UserName; + } + + [DllImport("Advapi32.dll", EntryPoint = "CredReadW", CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern bool CredRead(string target, CredentialType type, int reservedFlag, out IntPtr CredentialPtr); + + [DllImport("Advapi32.dll", EntryPoint = "CredWriteW", CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern bool CredWrite([In] ref CREDENTIAL userCredential, [In] UInt32 flags); + + [DllImport("Advapi32.dll", EntryPoint = "CredFree", SetLastError = true)] + internal static extern bool CredFree([In] IntPtr cred); + + [DllImport("advapi32.dll", EntryPoint = "CredDeleteW", CharSet = CharSet.Unicode)] + internal static extern bool CredDelete(StringBuilder target, CredentialType type, int flags); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + internal static extern bool CredEnumerateW(string filter, int flag, out uint count, out IntPtr pCredentials); + + [DllImport("ole32.dll")] + internal static extern void CoTaskMemFree(IntPtr ptr); + + + internal abstract class CriticalHandleZeroOrMinusOneIsInvalid : CriticalHandle + { + protected CriticalHandleZeroOrMinusOneIsInvalid() : base(IntPtr.Zero) + { + } + + public override bool IsInvalid + { + get { return handle == new IntPtr(0) || handle == new IntPtr(-1); } + } + } + + internal sealed class CriticalCredentialHandle : CriticalHandleZeroOrMinusOneIsInvalid + { + // Set the handle. + internal CriticalCredentialHandle(IntPtr preexistingHandle) + { + SetHandle(preexistingHandle); + } + + internal CREDENTIAL GetCredential() + { + if (!IsInvalid) + { + // Get the Credential from the mem location + return (CREDENTIAL)Marshal.PtrToStructure(handle); + } + else + { + throw new InvalidOperationException("Invalid CriticalHandle!"); + } + } + + // Perform any specific actions to release the handle in the ReleaseHandle method. + // Often, you need to use Pinvoke to make a call into the Win32 API to release the + // handle. In this case, however, we can use the Marshal class to release the unmanaged memory. + + override protected bool ReleaseHandle() + { + // If the handle was set, free it. Return success. + if (!IsInvalid) + { + // NOTE: We should also ZERO out the memory allocated to the handle, before free'ing it + // so there are no traces of the sensitive data left in memory. + CredFree(handle); + // Mark the handle as invalid for future users. + SetHandleAsInvalid(); + return true; + } + // Return false. + return false; + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/PersistanceType.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/PersistanceType.cs new file mode 100644 index 00000000..b08eff08 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/PersistanceType.cs @@ -0,0 +1,14 @@ +// +// Code originally from http://credentialmanagement.codeplex.com/, +// Licensed under the Apache License 2.0 +// + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Win32 +{ + public enum PersistanceType : uint + { + Session = 1, + LocalComputer = 2, + Enterprise = 3 + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/Win32Credential.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/Win32Credential.cs new file mode 100644 index 00000000..21c1c8b9 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/Win32Credential.cs @@ -0,0 +1,290 @@ +// +// Code originally from http://credentialmanagement.codeplex.com/, +// Licensed under the Apache License 2.0 +// + +using System; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Win32 +{ + public class Win32Credential: IDisposable + { + bool disposed; + + CredentialType type; + string target; + SecureString password; + string username; + string description; + DateTime lastWriteTime; + PersistanceType persistanceType; + + public Win32Credential() + : this(null) + { + } + + public Win32Credential(string username) + : this(username, null) + { + } + + public Win32Credential(string username, string password) + : this(username, password, null) + { + } + + public Win32Credential(string username, string password, string target) + : this(username, password, target, CredentialType.Generic) + { + } + + public Win32Credential(string username, string password, string target, CredentialType type) + { + Username = username; + Password = password; + Target = target; + Type = type; + PersistanceType = PersistanceType.Session; + lastWriteTime = DateTime.MinValue; + } + + + public void Dispose() + { + Dispose(true); + + // Prevent GC Collection since we have already disposed of this object + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!disposed) + { + if (disposing) + { + SecurePassword.Clear(); + SecurePassword.Dispose(); + } + } + disposed = true; + } + + private void CheckNotDisposed() + { + if (disposed) + { + throw new ObjectDisposedException(CredentialResources.CredentialDisposed); + } + } + + + public string Username { + get + { + CheckNotDisposed(); + return username; + } + set + { + CheckNotDisposed(); + username = value; + } + } + public string Password + { + get + { + return SecureStringHelper.CreateString(SecurePassword); + } + set + { + CheckNotDisposed(); + SecurePassword = SecureStringHelper.CreateSecureString(string.IsNullOrEmpty(value) ? string.Empty : value); + } + } + public SecureString SecurePassword + { + get + { + CheckNotDisposed(); + return null == password ? new SecureString() : password.Copy(); + } + set + { + CheckNotDisposed(); + if (null != password) + { + password.Clear(); + password.Dispose(); + } + password = null == value ? new SecureString() : value.Copy(); + } + } + public string Target + { + get + { + CheckNotDisposed(); + return target; + } + set + { + CheckNotDisposed(); + target = value; + } + } + + public string Description + { + get + { + CheckNotDisposed(); + return description; + } + set + { + CheckNotDisposed(); + description = value; + } + } + + public DateTime LastWriteTime + { + get + { + return LastWriteTimeUtc.ToLocalTime(); + } + } + public DateTime LastWriteTimeUtc + { + get + { + CheckNotDisposed(); + return lastWriteTime; + } + private set { lastWriteTime = value; } + } + + public CredentialType Type + { + get + { + CheckNotDisposed(); + return type; + } + set + { + CheckNotDisposed(); + type = value; + } + } + + public PersistanceType PersistanceType + { + get + { + CheckNotDisposed(); + return persistanceType; + } + set + { + CheckNotDisposed(); + persistanceType = value; + } + } + + public bool Save() + { + CheckNotDisposed(); + + byte[] passwordBytes = Encoding.Unicode.GetBytes(Password); + if (Password.Length > (512)) + { + throw new ArgumentOutOfRangeException(CredentialResources.PasswordLengthExceeded); + } + + NativeMethods.CREDENTIAL credential = new NativeMethods.CREDENTIAL(); + credential.TargetName = Target; + credential.UserName = Username; + credential.CredentialBlob = Marshal.StringToCoTaskMemUni(Password); + credential.CredentialBlobSize = passwordBytes.Length; + credential.Comment = Description; + credential.Type = (int)Type; + credential.Persist = (int) PersistanceType; + + bool result = NativeMethods.CredWrite(ref credential, 0); + if (!result) + { + return false; + } + LastWriteTimeUtc = DateTime.UtcNow; + return true; + } + + public bool Delete() + { + CheckNotDisposed(); + + if (string.IsNullOrEmpty(Target)) + { + throw new InvalidOperationException(CredentialResources.TargetRequiredForDelete); + } + + StringBuilder target = string.IsNullOrEmpty(Target) ? new StringBuilder() : new StringBuilder(Target); + bool result = NativeMethods.CredDelete(target, Type, 0); + return result; + } + + public bool Load() + { + CheckNotDisposed(); + + IntPtr credPointer; + + bool result = NativeMethods.CredRead(Target, Type, 0, out credPointer); + if (!result) + { + return false; + } + using (NativeMethods.CriticalCredentialHandle credentialHandle = new NativeMethods.CriticalCredentialHandle(credPointer)) + { + LoadInternal(credentialHandle.GetCredential()); + } + return true; + } + + public bool Exists() + { + CheckNotDisposed(); + + if (string.IsNullOrEmpty(Target)) + { + throw new InvalidOperationException(CredentialResources.TargetRequiredForLookup); + } + + using (Win32Credential existing = new Win32Credential { Target = Target, Type = Type }) + { + return existing.Load(); + } + } + + internal void LoadInternal(NativeMethods.CREDENTIAL credential) + { + Username = credential.UserName; + if (credential.CredentialBlobSize > 0) + { + Password = Marshal.PtrToStringUni(credential.CredentialBlob, credential.CredentialBlobSize / 2); + } + Target = credential.TargetName; + Type = (CredentialType)credential.Type; + PersistanceType = (PersistanceType)credential.Persist; + Description = credential.Comment; + LastWriteTimeUtc = DateTime.FromFileTimeUtc(credential.LastWritten); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/Win32CredentialStore.cs b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/Win32CredentialStore.cs new file mode 100644 index 00000000..8a219854 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Credentials/Win32/Win32CredentialStore.cs @@ -0,0 +1,63 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Credentials.Win32 +{ + /// + /// Win32 implementation of the credential store + /// + internal class Win32CredentialStore : ICredentialStore + { + private const string AnyUsername = "*"; + + public bool DeletePassword(string credentialId) + { + using (Win32Credential cred = new Win32Credential() { Target = credentialId, Username = AnyUsername }) + { + return cred.Delete(); + } + } + + public bool TryGetPassword(string credentialId, out string password) + { + Validate.IsNotNullOrEmptyString("credentialId", credentialId); + password = null; + + using (CredentialSet set = new CredentialSet(credentialId).Load()) + { + // Note: Credentials are disposed on disposal of the set + Win32Credential foundCred = null; + if (set.Count > 0) + { + foundCred = set[0]; + } + + if (foundCred != null) + { + password = foundCred.Password; + return true; + } + return false; + } + } + + public bool Save(Credential credential) + { + Credential.ValidateForSave(credential); + + using (Win32Credential cred = + new Win32Credential(AnyUsername, credential.Password, credential.CredentialId, CredentialType.Generic) + { PersistanceType = PersistanceType.LocalComputer }) + { + return cred.Save(); + } + + } + } + +} diff --git a/src/ServiceHost/LanguageServer/ClientCapabilities.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/ClientCapabilities.cs similarity index 85% rename from src/ServiceHost/LanguageServer/ClientCapabilities.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/ClientCapabilities.cs index 70e2d068..397deceb 100644 --- a/src/ServiceHost/LanguageServer/ClientCapabilities.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/ClientCapabilities.cs @@ -4,7 +4,7 @@ // -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts { /// /// Defines a class that describes the capabilities of a language diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs new file mode 100644 index 00000000..d6e65801 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts +{ + /// + /// Parameters to be used for reporting hosting-level errors, such as protocol violations + /// + public class HostingErrorParams + { + /// + /// The message of the error + /// + public string Message { get; set; } + } + + public class HostingErrorEvent + { + public static readonly + EventType Type = + EventType.Create("hosting/error"); + + } +} diff --git a/src/ServiceHost/LanguageServer/Initialize.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/Initialize.cs similarity index 91% rename from src/ServiceHost/LanguageServer/Initialize.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/Initialize.cs index 7551835e..215edf87 100644 --- a/src/ServiceHost/LanguageServer/Initialize.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/Initialize.cs @@ -3,9 +3,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts { public class InitializeRequest { diff --git a/src/ServiceHost/LanguageServer/ServerCapabilities.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/ServerCapabilities.cs similarity index 96% rename from src/ServiceHost/LanguageServer/ServerCapabilities.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/ServerCapabilities.cs index 2f7404d9..32f0e736 100644 --- a/src/ServiceHost/LanguageServer/ServerCapabilities.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/ServerCapabilities.cs @@ -3,7 +3,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts { public class ServerCapabilities { diff --git a/src/ServiceHost/LanguageServer/Shutdown.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/Shutdown.cs similarity index 86% rename from src/ServiceHost/LanguageServer/Shutdown.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/Shutdown.cs index f0a7bbd2..1ccb9cfc 100644 --- a/src/ServiceHost/LanguageServer/Shutdown.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/Shutdown.cs @@ -3,9 +3,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts { /// /// Defines a message that is sent from the client to request diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/VersionRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/VersionRequest.cs new file mode 100644 index 00000000..ed7ab358 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/VersionRequest.cs @@ -0,0 +1,20 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts +{ + /// + /// Defines a message that is sent from the client to request + /// the version of the server. + /// + public class VersionRequest + { + public static readonly + RequestType Type = + RequestType.Create("version"); + } +} diff --git a/src/ServiceHost/MessageProtocol/Channel/ChannelBase.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Channel/ChannelBase.cs similarity index 94% rename from src/ServiceHost/MessageProtocol/Channel/ChannelBase.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Channel/ChannelBase.cs index 848da39f..48cd66aa 100644 --- a/src/ServiceHost/MessageProtocol/Channel/ChannelBase.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Channel/ChannelBase.cs @@ -3,10 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Serializers; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Channel +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel { /// /// Defines a base implementation for servers and their clients over a diff --git a/src/ServiceHost/MessageProtocol/Channel/StdioClientChannel.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Channel/StdioClientChannel.cs similarity index 96% rename from src/ServiceHost/MessageProtocol/Channel/StdioClientChannel.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Channel/StdioClientChannel.cs index 5390f52d..02b79c6b 100644 --- a/src/ServiceHost/MessageProtocol/Channel/StdioClientChannel.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Channel/StdioClientChannel.cs @@ -7,8 +7,9 @@ using System.Diagnostics; using System.IO; using System.Text; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Channel +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel { /// /// Provides a client implementation for the standard I/O channel. diff --git a/src/ServiceHost/MessageProtocol/Channel/StdioServerChannel.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Channel/StdioServerChannel.cs similarity index 93% rename from src/ServiceHost/MessageProtocol/Channel/StdioServerChannel.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Channel/StdioServerChannel.cs index 0b9376d4..204a9908 100644 --- a/src/ServiceHost/MessageProtocol/Channel/StdioServerChannel.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Channel/StdioServerChannel.cs @@ -6,8 +6,9 @@ using System.IO; using System.Text; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Channel +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel { /// /// Provides a server implementation for the standard I/O channel. diff --git a/src/ServiceHost/MessageProtocol/Constants.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Constants.cs similarity index 91% rename from src/ServiceHost/MessageProtocol/Constants.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Constants.cs index 0fae5d8d..14f3d762 100644 --- a/src/ServiceHost/MessageProtocol/Constants.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Constants.cs @@ -6,7 +6,7 @@ using Newtonsoft.Json; using Newtonsoft.Json.Serialization; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { public static class Constants { diff --git a/src/ServiceHost/MessageProtocol/EventType.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Contracts/EventType.cs similarity index 93% rename from src/ServiceHost/MessageProtocol/EventType.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Contracts/EventType.cs index dd460817..4d9a251b 100644 --- a/src/ServiceHost/MessageProtocol/EventType.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Contracts/EventType.cs @@ -3,7 +3,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts { /// /// Defines an event type with a particular method name. diff --git a/src/ServiceHost/MessageProtocol/Message.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Contracts/Message.cs similarity index 98% rename from src/ServiceHost/MessageProtocol/Message.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Contracts/Message.cs index 75dab5cd..6af6a101 100644 --- a/src/ServiceHost/MessageProtocol/Message.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Contracts/Message.cs @@ -6,7 +6,7 @@ using System.Diagnostics; using Newtonsoft.Json.Linq; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts { /// /// Defines all possible message types. diff --git a/src/ServiceHost/MessageProtocol/RequestType.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Contracts/RequestType.cs similarity index 89% rename from src/ServiceHost/MessageProtocol/RequestType.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Contracts/RequestType.cs index 29fc11c5..67676259 100644 --- a/src/ServiceHost/MessageProtocol/RequestType.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Contracts/RequestType.cs @@ -5,7 +5,7 @@ using System.Diagnostics; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts { [DebuggerDisplay("RequestType MethodName = {MethodName}")] public class RequestType diff --git a/src/ServiceHost/MessageProtocol/EventContext.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/EventContext.cs similarity index 86% rename from src/ServiceHost/MessageProtocol/EventContext.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/EventContext.cs index eb42ebbb..4a1bc40e 100644 --- a/src/ServiceHost/MessageProtocol/EventContext.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/EventContext.cs @@ -4,8 +4,9 @@ // using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { /// /// Provides context for a received event so that handlers diff --git a/src/ServiceHost/MessageProtocol/IMessageSender.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IMessageSender.cs similarity index 75% rename from src/ServiceHost/MessageProtocol/IMessageSender.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IMessageSender.cs index 7f331eed..583fb3b0 100644 --- a/src/ServiceHost/MessageProtocol/IMessageSender.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IMessageSender.cs @@ -4,10 +4,11 @@ // using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { - internal interface IMessageSender + public interface IMessageSender { Task SendEvent( EventType eventType, diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IProtocolEndpoint.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IProtocolEndpoint.cs new file mode 100644 index 00000000..496e3d56 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IProtocolEndpoint.cs @@ -0,0 +1,32 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol +{ + /// + /// A ProtocolEndpoint is used for inter-process communication. Services can register to + /// respond to requests and events, send their own requests, and listen for notifications + /// sent by the other side of the endpoint + /// + public interface IProtocolEndpoint : IMessageSender + { + void SetRequestHandler( + RequestType requestType, + Func, Task> requestHandler); + + void SetEventHandler( + EventType eventType, + Func eventHandler); + + void SetEventHandler( + EventType eventType, + Func eventHandler, + bool overrideExisting); + } +} diff --git a/src/ServiceHost/MessageProtocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs similarity index 86% rename from src/ServiceHost/MessageProtocol/MessageDispatcher.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index 21c179e2..c4cf5365 100644 --- a/src/ServiceHost/MessageProtocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -3,15 +3,17 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Channel; -using Microsoft.SqlTools.EditorServices.Utility; using System; using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.EditorServices.Utility; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { public class MessageDispatcher { @@ -197,10 +199,9 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol this.SynchronizationContext = SynchronizationContext.Current; // Run the message loop - bool isRunning = true; - while (isRunning && !cancellationToken.IsCancellationRequested) + while (!cancellationToken.IsCancellationRequested) { - Message newMessage = null; + Message newMessage; try { @@ -209,12 +210,12 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol } catch (MessageParseException e) { - // TODO: Write an error response - - Logger.Write( - LogLevel.Error, - "Could not parse a message that was received:\r\n\r\n" + - e.ToString()); + string message = string.Format("Exception occurred while parsing message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); // Continue the loop continue; @@ -226,18 +227,29 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol } catch (Exception e) { - var b = e.Message; - newMessage = null; + // Log the error and send an error event to the client + string message = string.Format("Exception occurred while receiving message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); + + // Continue the loop + continue; } // The message could be null if there was an error parsing the // previous message. In this case, do not try to dispatch it. if (newMessage != null) { + // Verbose logging + string logMessage = string.Format("Received message of type[{0}] and method[{1}]", + newMessage.MessageType, newMessage.Method); + Logger.Write(LogLevel.Verbose, logMessage); + // Process the message - await this.DispatchMessage( - newMessage, - this.MessageWriter); + await this.DispatchMessage(newMessage, this.MessageWriter); } } } diff --git a/src/ServiceHost/MessageProtocol/MessageParseException.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageParseException.cs similarity index 89% rename from src/ServiceHost/MessageProtocol/MessageParseException.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageParseException.cs index 98a17c20..b4ef94c2 100644 --- a/src/ServiceHost/MessageProtocol/MessageParseException.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageParseException.cs @@ -5,7 +5,7 @@ using System; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { public class MessageParseException : Exception { diff --git a/src/ServiceHost/MessageProtocol/MessageProtocolType.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageProtocolType.cs similarity index 89% rename from src/ServiceHost/MessageProtocol/MessageProtocolType.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageProtocolType.cs index 5484ae3c..480332fa 100644 --- a/src/ServiceHost/MessageProtocol/MessageProtocolType.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageProtocolType.cs @@ -3,7 +3,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { /// /// Defines the possible message protocol types. diff --git a/src/ServiceHost/MessageProtocol/MessageReader.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs similarity index 60% rename from src/ServiceHost/MessageProtocol/MessageReader.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs index a2df43ba..e70f33f4 100644 --- a/src/ServiceHost/MessageProtocol/MessageReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs @@ -3,16 +3,17 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Utility; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.IO; using System.Text; using System.Threading.Tasks; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Newtonsoft.Json.Linq; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { public class MessageReader { @@ -23,22 +24,22 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol private const int CR = 0x0D; private const int LF = 0x0A; - private static string[] NewLineDelimiters = new string[] { Environment.NewLine }; + private static readonly string[] NewLineDelimiters = { Environment.NewLine }; - private Stream inputStream; - private IMessageSerializer messageSerializer; - private Encoding messageEncoding; + private readonly Stream inputStream; + private readonly IMessageSerializer messageSerializer; + private readonly Encoding messageEncoding; private ReadState readState; private bool needsMoreData = true; private int readOffset; private int bufferEndOffset; - private byte[] messageBuffer = new byte[DefaultBufferSize]; + private byte[] messageBuffer; private int expectedContentLength; private Dictionary messageHeaders; - enum ReadState + private enum ReadState { Headers, Content @@ -83,7 +84,7 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol this.needsMoreData = false; // Do we need to look for message headers? - if (this.readState == ReadState.Headers && + if (this.readState == ReadState.Headers && !this.TryReadMessageHeaders()) { // If we don't have enough data to read headers yet, keep reading @@ -92,7 +93,7 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol } // Do we need to look for message content? - if (this.readState == ReadState.Content && + if (this.readState == ReadState.Content && !this.TryReadMessageContent(out messageContent)) { // If we don't have enough data yet to construct the content, keep reading @@ -104,16 +105,12 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol break; } + // Now that we have a message, reset the buffer's state + ShiftBufferBytesAndShrink(readOffset); + // Get the JObject for the JSON content JObject messageObject = JObject.Parse(messageContent); - // Load the message - Logger.Write( - LogLevel.Verbose, - string.Format( - "READ MESSAGE:\r\n\r\n{0}", - messageObject.ToString(Formatting.Indented))); - // Return the parsed message return this.messageSerializer.DeserializeMessage(messageObject); } @@ -160,8 +157,7 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol { int scanOffset = this.readOffset; - // Scan for the final double-newline that marks the - // end of the header lines + // Scan for the final double-newline that marks the end of the header lines while (scanOffset + 3 < this.bufferEndOffset && (this.messageBuffer[scanOffset] != CR || this.messageBuffer[scanOffset + 1] != LF || @@ -171,45 +167,51 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol scanOffset++; } - // No header or body separator found (e.g CRLFCRLF) + // Make sure we haven't reached the end of the buffer without finding a separator (e.g CRLFCRLF) if (scanOffset + 3 >= this.bufferEndOffset) { return false; } - this.messageHeaders = new Dictionary(); + // Convert the header block into a array of lines + var headers = Encoding.ASCII.GetString(this.messageBuffer, this.readOffset, scanOffset) + .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - var headers = - Encoding.ASCII - .GetString(this.messageBuffer, this.readOffset, scanOffset) - .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - - // Read each header and store it in the dictionary - foreach (var header in headers) + try { - int currentLength = header.IndexOf(':'); - if (currentLength == -1) + // Read each header and store it in the dictionary + this.messageHeaders = new Dictionary(); + foreach (var header in headers) { - throw new ArgumentException("Message header must separate key and value using :"); + int currentLength = header.IndexOf(':'); + if (currentLength == -1) + { + throw new ArgumentException("Message header must separate key and value using :"); + } + + var key = header.Substring(0, currentLength); + var value = header.Substring(currentLength + 1).Trim(); + this.messageHeaders[key] = value; } - var key = header.Substring(0, currentLength); - var value = header.Substring(currentLength + 1).Trim(); - this.messageHeaders[key] = value; - } + // Parse out the content length as an int + string contentLengthString; + if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) + { + throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + } - // Make sure a Content-Length header was present, otherwise it - // is a fatal error - string contentLengthString = null; - if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) - { - throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + // Parse the content length to an integer + if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + { + throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + } } - - // Parse the content length to an integer - if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + catch (Exception) { - throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + // The content length was invalid or missing. Trash the buffer we've read + ShiftBufferBytesAndShrink(scanOffset + 4); + throw; } // Skip past the headers plus the newline characters @@ -232,31 +234,40 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol } // Convert the message contents to a string using the specified encoding - messageContent = - this.messageEncoding.GetString( - this.messageBuffer, - this.readOffset, - this.expectedContentLength); + messageContent = this.messageEncoding.GetString( + this.messageBuffer, + this.readOffset, + this.expectedContentLength); - // Move the remaining bytes to the front of the buffer for the next message - var remainingByteCount = this.bufferEndOffset - (this.expectedContentLength + this.readOffset); - Buffer.BlockCopy( - this.messageBuffer, - this.expectedContentLength + this.readOffset, - this.messageBuffer, - 0, - remainingByteCount); + readOffset += expectedContentLength; - // Reset the offsets for the next read - this.readOffset = 0; - this.bufferEndOffset = remainingByteCount; - - // Done reading content, now look for headers + // Done reading content, now look for headers for the next message this.readState = ReadState.Headers; return true; } + private void ShiftBufferBytesAndShrink(int bytesToRemove) + { + // Create a new buffer that is shrunken by the number of bytes to remove + // Note: by using Max, we can guarantee a buffer of at least default buffer size + byte[] newBuffer = new byte[Math.Max(messageBuffer.Length - bytesToRemove, DefaultBufferSize)]; + + // If we need to do shifting, do the shifting + if (bytesToRemove <= messageBuffer.Length) + { + // Copy the existing buffer starting at the offset to remove + Buffer.BlockCopy(messageBuffer, bytesToRemove, newBuffer, 0, bufferEndOffset - bytesToRemove); + } + + // Make the new buffer the message buffer + messageBuffer = newBuffer; + + // Reset the read offset and the end offset + readOffset = 0; + bufferEndOffset -= bytesToRemove; + } + #endregion } } diff --git a/src/ServiceHost/MessageProtocol/MessageWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageWriter.cs similarity index 95% rename from src/ServiceHost/MessageProtocol/MessageWriter.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageWriter.cs index 96e13bcd..b269f750 100644 --- a/src/ServiceHost/MessageProtocol/MessageWriter.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageWriter.cs @@ -3,14 +3,16 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Utility; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; using System.IO; using System.Text; using System.Threading.Tasks; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { public class MessageWriter { diff --git a/src/ServiceHost/MessageProtocol/ProtocolEndpoint.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/ProtocolEndpoint.cs similarity index 97% rename from src/ServiceHost/MessageProtocol/ProtocolEndpoint.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/ProtocolEndpoint.cs index daead186..5a18f85b 100644 --- a/src/ServiceHost/MessageProtocol/ProtocolEndpoint.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/ProtocolEndpoint.cs @@ -3,19 +3,20 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Channel; using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { /// /// Provides behavior for a client or server endpoint that /// communicates using the specified protocol. /// - public class ProtocolEndpoint : IMessageSender + public class ProtocolEndpoint : IMessageSender, IProtocolEndpoint { private bool isStarted; private int currentMessageId; diff --git a/src/ServiceHost/MessageProtocol/RequestContext.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs similarity index 73% rename from src/ServiceHost/MessageProtocol/RequestContext.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs index a35bb136..a2811f6a 100644 --- a/src/ServiceHost/MessageProtocol/RequestContext.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs @@ -3,10 +3,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Newtonsoft.Json.Linq; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Newtonsoft.Json.Linq; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { public class RequestContext { @@ -19,7 +20,9 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol this.messageWriter = messageWriter; } - public async Task SendResult(TResult resultDetails) + public RequestContext() { } + + public virtual async Task SendResult(TResult resultDetails) { await this.messageWriter.WriteResponse( resultDetails, @@ -27,14 +30,14 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol requestMessage.Id); } - public async Task SendEvent(EventType eventType, TParams eventParams) + public virtual async Task SendEvent(EventType eventType, TParams eventParams) { await this.messageWriter.WriteEvent( eventType, eventParams); } - public async Task SendError(object errorDetails) + public virtual async Task SendError(object errorDetails) { await this.messageWriter.WriteMessage( Message.ResponseError( diff --git a/src/ServiceHost/MessageProtocol/IMessageSerializer.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Serializers/IMessageSerializer.cs similarity index 87% rename from src/ServiceHost/MessageProtocol/IMessageSerializer.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Serializers/IMessageSerializer.cs index 81b23fa6..6a1133ff 100644 --- a/src/ServiceHost/MessageProtocol/IMessageSerializer.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Serializers/IMessageSerializer.cs @@ -3,9 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Newtonsoft.Json.Linq; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers { /// /// Defines a common interface for message serializers. diff --git a/src/ServiceHost/MessageProtocol/Serializers/JsonRpcMessageSerializer.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Serializers/JsonRpcMessageSerializer.cs similarity index 95% rename from src/ServiceHost/MessageProtocol/Serializers/JsonRpcMessageSerializer.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Serializers/JsonRpcMessageSerializer.cs index fa1d1518..0ccca078 100644 --- a/src/ServiceHost/MessageProtocol/Serializers/JsonRpcMessageSerializer.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Serializers/JsonRpcMessageSerializer.cs @@ -3,9 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Newtonsoft.Json.Linq; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Serializers +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers { /// /// Serializes messages in the JSON RPC format. Used primarily diff --git a/src/ServiceHost/MessageProtocol/Serializers/V8MessageSerializer.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Serializers/V8MessageSerializer.cs similarity index 96% rename from src/ServiceHost/MessageProtocol/Serializers/V8MessageSerializer.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Serializers/V8MessageSerializer.cs index 941e249a..f1385e00 100644 --- a/src/ServiceHost/MessageProtocol/Serializers/V8MessageSerializer.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/Serializers/V8MessageSerializer.cs @@ -5,8 +5,9 @@ using Newtonsoft.Json.Linq; using System; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Serializers +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers { /// /// Serializes messages in the V8 format. Used primarily for debug adapters. diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs new file mode 100644 index 00000000..5f5ef1df --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHost.cs @@ -0,0 +1,168 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Linq; +using System.Threading.Tasks; +using System.Collections.Generic; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; +using System.Reflection; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting +{ + /// + /// SQL Tools VS Code Language Server request handler. Provides the entire JSON RPC + /// implementation for sending/receiving JSON requests and dispatching the requests to + /// handlers that are registered prior to startup. + /// + public sealed class ServiceHost : ServiceHostBase + { + #region Singleton Instance Code + + /// + /// Singleton instance of the service host for internal storage + /// + private static readonly Lazy instance = new Lazy(() => new ServiceHost()); + + /// + /// Current instance of the ServiceHost + /// + public static ServiceHost Instance + { + get { return instance.Value; } + } + + /// + /// Constructs new instance of ServiceHost using the host and profile details provided. + /// Access is private to ensure only one instance exists at a time. + /// + private ServiceHost() : base(new StdioServerChannel()) + { + // Initialize the shutdown activities + shutdownCallbacks = new List(); + initializeCallbacks = new List(); + } + + /// + /// Provide initialization that must occur after the service host is started + /// + public void Initialize() + { + // Register the requests that this service host will handle + this.SetRequestHandler(InitializeRequest.Type, this.HandleInitializeRequest); + this.SetRequestHandler(ShutdownRequest.Type, this.HandleShutdownRequest); + this.SetRequestHandler(VersionRequest.Type, HandleVersionRequest); + } + + #endregion + + #region Member Variables + + public delegate Task ShutdownCallback(object shutdownParams, RequestContext shutdownRequestContext); + + public delegate Task InitializeCallback(InitializeRequest startupParams, RequestContext requestContext); + + private readonly List shutdownCallbacks; + + private readonly List initializeCallbacks; + + private static readonly Version serviceVersion = Assembly.GetEntryAssembly().GetName().Version; + + #endregion + + #region Public Methods + + /// + /// Adds a new callback to be called when the shutdown request is submitted + /// + /// Callback to perform when a shutdown request is submitted + public void RegisterShutdownTask(ShutdownCallback callback) + { + shutdownCallbacks.Add(callback); + } + + /// + /// Add a new method to be called when the initialize request is submitted + /// + /// Callback to perform when an initialize request is submitted + public void RegisterInitializeTask(InitializeCallback callback) + { + initializeCallbacks.Add(callback); + } + + #endregion + + #region Request Handlers + + /// + /// Handles the shutdown event for the Language Server + /// + private async Task HandleShutdownRequest(object shutdownParams, RequestContext requestContext) + { + Logger.Write(LogLevel.Normal, "Service host is shutting down..."); + + // Call all the shutdown methods provided by the service components + Task[] shutdownTasks = shutdownCallbacks.Select(t => t(shutdownParams, requestContext)).ToArray(); + await Task.WhenAll(shutdownTasks); + } + + /// + /// Handles the initialization request + /// + /// + /// + /// + private async Task HandleInitializeRequest(InitializeRequest initializeParams, RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleInitializationRequest"); + + // Call all tasks that registered on the initialize request + var initializeTasks = initializeCallbacks.Select(t => t(initializeParams, requestContext)); + await Task.WhenAll(initializeTasks); + + // TODO: Figure out where this needs to go to be agnostic of the language + + // Send back what this server can do + await requestContext.SendResult( + new InitializeResult + { + Capabilities = new ServerCapabilities + { + TextDocumentSync = TextDocumentSyncKind.Incremental, + DefinitionProvider = true, + ReferencesProvider = true, + DocumentHighlightProvider = true, + DocumentSymbolProvider = true, + WorkspaceSymbolProvider = true, + CompletionProvider = new CompletionOptions + { + ResolveProvider = true, + TriggerCharacters = new string[] { ".", "-", ":", "\\" } + }, + SignatureHelpProvider = new SignatureHelpOptions + { + TriggerCharacters = new string[] { " " } // TODO: Other characters here? + } + } + }); + } + + /// + /// Handles the version request. Sends back the server version as result. + /// + private static async Task HandleVersionRequest( + object versionRequestParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleVersionRequest"); + await requestContext.SendResult(serviceVersion.ToString()); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHostBase.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHostBase.cs new file mode 100644 index 00000000..8158822b --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHostBase.cs @@ -0,0 +1,47 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting +{ + public abstract class ServiceHostBase : ProtocolEndpoint + { + private bool isStarted; + private TaskCompletionSource serverExitedTask; + + protected ServiceHostBase(ChannelBase serverChannel) : + base(serverChannel, MessageProtocolType.LanguageServer) + { + } + + protected override Task OnStart() + { + // Register handlers for server lifetime messages + + this.SetEventHandler(ExitNotification.Type, this.HandleExitNotification); + + return Task.FromResult(true); + } + + private async Task HandleExitNotification( + object exitParams, + EventContext eventContext) + { + // Stop the server channel + await this.Stop(); + + // Notify any waiter that the server has exited + if (this.serverExitedTask != null) + { + this.serverExitedTask.SetResult(true); + } + } + } +} + diff --git a/src/ServiceHost/Server/LanguageServerEditorOperations.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHostEditorOperations.cs similarity index 100% rename from src/ServiceHost/Server/LanguageServerEditorOperations.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/ServiceHostEditorOperations.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs new file mode 100644 index 00000000..5abe27f7 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs @@ -0,0 +1,323 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.SqlClient; +using System.Threading.Tasks; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.SmoMetadataProvider; +using Microsoft.SqlServer.Management.SqlParser.Binder; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Workspace; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices +{ + /// + /// Main class for Autocomplete functionality + /// + public class AutoCompleteService + { + #region Singleton Instance Implementation + + /// + /// Singleton service instance + /// + private static Lazy instance + = new Lazy(() => new AutoCompleteService()); + + /// + /// Gets the singleton service instance + /// + public static AutoCompleteService Instance + { + get + { + return instance.Value; + } + } + + /// + /// Default, parameterless constructor. + /// Internal constructor for use in test cases only + /// + internal AutoCompleteService() + { + } + + #endregion + + private ConnectionService connectionService = null; + + /// + /// Internal for testing purposes only + /// + internal ConnectionService ConnectionServiceInstance + { + get + { + if(connectionService == null) + { + connectionService = ConnectionService.Instance; + } + return connectionService; + } + + set + { + connectionService = value; + } + } + + public void InitializeService(ServiceHost serviceHost) + { + // Register auto-complete request handler + serviceHost.SetRequestHandler(CompletionRequest.Type, HandleCompletionRequest); + + // Register a callback for when a connection is created + ConnectionServiceInstance.RegisterOnConnectionTask(UpdateAutoCompleteCache); + + // Register a callback for when a connection is closed + ConnectionServiceInstance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference); + } + + /// + /// Auto-complete completion provider request callback + /// + /// + /// + /// + private static async Task HandleCompletionRequest( + TextDocumentPosition textDocumentPosition, + RequestContext requestContext) + { + // get the current list of completion items and return to client + var scriptFile = WorkspaceService.Instance.Workspace.GetFile( + textDocumentPosition.TextDocument.Uri); + + ConnectionInfo connInfo; + ConnectionService.Instance.TryFindConnection( + scriptFile.ClientFilePath, + out connInfo); + + var completionItems = Instance.GetCompletionItems( + textDocumentPosition, scriptFile, connInfo); + + await requestContext.SendResult(completionItems); + } + + /// + /// Remove a reference to an autocomplete cache from a URI. If + /// it is the last URI connected to a particular connection, + /// then remove the cache. + /// + public async Task RemoveAutoCompleteCacheUriReference(ConnectionSummary summary) + { + // currently this method is disabled, but we need to reimplement now that the + // implementation of the 'cache' has changed. + await Task.FromResult(0); + } + + /// + /// Update the cached autocomplete candidate list when the user connects to a database + /// + /// + public async Task UpdateAutoCompleteCache(ConnectionInfo info) + { + await Task.Run( () => + { + if (!LanguageService.Instance.ScriptParseInfoMap.ContainsKey(info.OwnerUri)) + { + var sqlConn = info.SqlConnection as SqlConnection; + if (sqlConn != null) + { + var srvConn = new ServerConnection(sqlConn); + var displayInfoProvider = new MetadataDisplayInfoProvider(); + var metadataProvider = SmoMetadataProvider.CreateConnectedProvider(srvConn); + var binder = BinderProvider.CreateBinder(metadataProvider); + + LanguageService.Instance.ScriptParseInfoMap.Add(info.OwnerUri, + new ScriptParseInfo() + { + Binder = binder, + MetadataProvider = metadataProvider, + MetadataDisplayInfoProvider = displayInfoProvider + }); + + var scriptFile = WorkspaceService.Instance.Workspace.GetFile(info.OwnerUri); + + LanguageService.Instance.ParseAndBind(scriptFile, info); + } + } + }); + } + + /// + /// Find the position of the previous delimeter for autocomplete token replacement. + /// SQL Parser may have similar functionality in which case we'll delete this method. + /// + /// + /// + /// + /// + private int PositionOfPrevDelimeter(string sql, int startRow, int startColumn) + { + if (string.IsNullOrWhiteSpace(sql)) + { + return 1; + } + + int prevLineColumns = 0; + for (int i = 0; i < startRow; ++i) + { + while (sql[prevLineColumns] != '\n' && prevLineColumns < sql.Length) + { + ++prevLineColumns; + } + ++prevLineColumns; + } + + startColumn += prevLineColumns; + + if (startColumn - 1 < sql.Length) + { + while (--startColumn >= prevLineColumns) + { + if (sql[startColumn] == ' ' + || sql[startColumn] == '\t' + || sql[startColumn] == '\n' + || sql[startColumn] == '.' + || sql[startColumn] == '+' + || sql[startColumn] == '-' + || sql[startColumn] == '*' + || sql[startColumn] == '>' + || sql[startColumn] == '<' + || sql[startColumn] == '=' + || sql[startColumn] == '/' + || sql[startColumn] == '%') + { + break; + } + } + } + + return startColumn + 1 - prevLineColumns; + } + + /// + /// Determines whether a reparse and bind is required to provide autocomplete + /// + /// + /// TEMP: Currently hard-coded to false for perf + private bool RequiresReparse(ScriptParseInfo info) + { + return false; + } + + /// + /// Converts a list of Declaration objects to CompletionItem objects + /// since VS Code expects CompletionItems but SQL Parser works with Declarations + /// + /// + /// + /// + /// + private CompletionItem[] ConvertDeclarationsToCompletionItems( + IEnumerable suggestions, + int row, + int startColumn, + int endColumn) + { + List completions = new List(); + foreach (var autoCompleteItem in suggestions) + { + // convert the completion item candidates into CompletionItems + completions.Add(new CompletionItem() + { + Label = autoCompleteItem.Title, + Kind = CompletionItemKind.Keyword, + Detail = autoCompleteItem.Title, + Documentation = autoCompleteItem.Description, + TextEdit = new TextEdit + { + NewText = autoCompleteItem.Title, + Range = new Range + { + Start = new Position + { + Line = row, + Character = startColumn + }, + End = new Position + { + Line = row, + Character = endColumn + } + } + } + }); + } + + return completions.ToArray(); + } + + /// + /// Return the completion item list for the current text position. + /// This method does not await cache builds since it expects to return quickly + /// + /// + public CompletionItem[] GetCompletionItems( + TextDocumentPosition textDocumentPosition, + ScriptFile scriptFile, + ConnectionInfo connInfo) + { + string filePath = textDocumentPosition.TextDocument.Uri; + + // Take a reference to the list at a point in time in case we update and replace the list + if (connInfo == null + || !LanguageService.Instance.ScriptParseInfoMap.ContainsKey(textDocumentPosition.TextDocument.Uri)) + { + return new CompletionItem[0]; + } + + // reparse and bind the SQL statement if needed + var scriptParseInfo = LanguageService.Instance.ScriptParseInfoMap[textDocumentPosition.TextDocument.Uri]; + if (RequiresReparse(scriptParseInfo)) + { + LanguageService.Instance.ParseAndBind(scriptFile, connInfo); + } + + if (scriptParseInfo.ParseResult == null) + { + return new CompletionItem[0]; + } + + // get the completion list from SQL Parser + var suggestions = Resolver.FindCompletions( + scriptParseInfo.ParseResult, + textDocumentPosition.Position.Line + 1, + textDocumentPosition.Position.Character + 1, + scriptParseInfo.MetadataDisplayInfoProvider); + + // convert the suggestion list to the VS Code format + return ConvertDeclarationsToCompletionItems( + suggestions, + textDocumentPosition.Position.Line, + PositionOfPrevDelimeter( + scriptFile.Contents, + textDocumentPosition.Position.Line, + textDocumentPosition.Position.Character), + textDocumentPosition.Position.Character); + } + } +} diff --git a/src/ServiceHost/LanguageServer/Completion.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Completion.cs similarity index 92% rename from src/ServiceHost/LanguageServer/Completion.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Completion.cs index 5f26ea96..64c0464f 100644 --- a/src/ServiceHost/LanguageServer/Completion.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Completion.cs @@ -4,9 +4,10 @@ // using System.Diagnostics; -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public class CompletionRequest { diff --git a/src/ServiceHost/LanguageServer/Definition.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Definition.cs similarity index 67% rename from src/ServiceHost/LanguageServer/Definition.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Definition.cs index b18845c3..d17930e1 100644 --- a/src/ServiceHost/LanguageServer/Definition.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Definition.cs @@ -3,9 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public class DefinitionRequest { diff --git a/src/ServiceHost/LanguageServer/Diagnostics.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Diagnostics.cs similarity index 90% rename from src/ServiceHost/LanguageServer/Diagnostics.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Diagnostics.cs index a5472607..73d75e5e 100644 --- a/src/ServiceHost/LanguageServer/Diagnostics.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Diagnostics.cs @@ -3,9 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public class PublishDiagnosticsNotification { diff --git a/src/ServiceHost/LanguageServer/DocumentHighlight.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/DocumentHighlight.cs similarity index 78% rename from src/ServiceHost/LanguageServer/DocumentHighlight.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/DocumentHighlight.cs index 6849ddfb..968e390b 100644 --- a/src/ServiceHost/LanguageServer/DocumentHighlight.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/DocumentHighlight.cs @@ -3,9 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public enum DocumentHighlightKind { diff --git a/src/ServiceHost/LanguageServer/ExpandAliasRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/ExpandAliasRequest.cs similarity index 73% rename from src/ServiceHost/LanguageServer/ExpandAliasRequest.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/ExpandAliasRequest.cs index d7f9fde4..e758aa3d 100644 --- a/src/ServiceHost/LanguageServer/ExpandAliasRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/ExpandAliasRequest.cs @@ -3,9 +3,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public class ExpandAliasRequest { diff --git a/src/ServiceHost/LanguageServer/FindModuleRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/FindModuleRequest.cs similarity index 80% rename from src/ServiceHost/LanguageServer/FindModuleRequest.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/FindModuleRequest.cs index ab78a158..5de004d5 100644 --- a/src/ServiceHost/LanguageServer/FindModuleRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/FindModuleRequest.cs @@ -3,10 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; using System.Collections.Generic; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public class FindModuleRequest { diff --git a/src/ServiceHost/LanguageServer/Hover.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Hover.cs similarity index 77% rename from src/ServiceHost/LanguageServer/Hover.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Hover.cs index 2e196fba..fbce0883 100644 --- a/src/ServiceHost/LanguageServer/Hover.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/Hover.cs @@ -3,9 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public class MarkedString { diff --git a/src/ServiceHost/LanguageServer/InstallModuleRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/InstallModuleRequest.cs similarity index 73% rename from src/ServiceHost/LanguageServer/InstallModuleRequest.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/InstallModuleRequest.cs index b03b8864..fc5bc289 100644 --- a/src/ServiceHost/LanguageServer/InstallModuleRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/InstallModuleRequest.cs @@ -3,9 +3,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { class InstallModuleRequest { diff --git a/src/ServiceHost/LanguageServer/References.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/References.cs similarity index 76% rename from src/ServiceHost/LanguageServer/References.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/References.cs index 25a92b12..90b0b3f6 100644 --- a/src/ServiceHost/LanguageServer/References.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/References.cs @@ -3,9 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public class ReferencesRequest { diff --git a/src/ServiceHost/LanguageServer/ShowOnlineHelpRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/ShowOnlineHelpRequest.cs similarity index 73% rename from src/ServiceHost/LanguageServer/ShowOnlineHelpRequest.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/ShowOnlineHelpRequest.cs index 8f21fb1b..d8cf3f48 100644 --- a/src/ServiceHost/LanguageServer/ShowOnlineHelpRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/ShowOnlineHelpRequest.cs @@ -3,9 +3,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public class ShowOnlineHelpRequest { diff --git a/src/ServiceHost/LanguageServer/SignatureHelp.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/SignatureHelp.cs similarity index 83% rename from src/ServiceHost/LanguageServer/SignatureHelp.cs rename to src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/SignatureHelp.cs index 5d4233e3..e8fab16b 100644 --- a/src/ServiceHost/LanguageServer/SignatureHelp.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/Contracts/SignatureHelp.cs @@ -3,9 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts { public class SignatureHelpRequest { diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs new file mode 100644 index 00000000..d414b41e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -0,0 +1,538 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Hosting; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Workspace; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using System.Linq; +using Microsoft.SqlServer.Management.SqlParser.Parser; +using Location = Microsoft.SqlTools.ServiceLayer.Workspace.Contracts.Location; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlServer.Management.SqlParser.Binder; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.SqlParser; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices +{ + /// + /// Main class for Language Service functionality including anything that reqires knowledge of + /// the language to perfom, such as definitions, intellisense, etc. + /// + public sealed class LanguageService + { + + #region Singleton Instance Implementation + + private static readonly Lazy instance = new Lazy(() => new LanguageService()); + + private Lazy> scriptParseInfoMap + = new Lazy>(() => new Dictionary()); + + internal Dictionary ScriptParseInfoMap + { + get + { + return this.scriptParseInfoMap.Value; + } + } + + public static LanguageService Instance + { + get { return instance.Value; } + } + + /// + /// Default, parameterless constructor. + /// + internal LanguageService() + { + } + + #endregion + + #region Properties + + private static CancellationTokenSource ExistingRequestCancellation { get; set; } + + private SqlToolsSettings CurrentSettings + { + get { return WorkspaceService.Instance.CurrentSettings; } + } + + private Workspace.Workspace CurrentWorkspace + { + get { return WorkspaceService.Instance.Workspace; } + } + + /// + /// Gets or sets the current SQL Tools context + /// + /// + private SqlToolsContext Context { get; set; } + + #endregion + + #region Public Methods + + /// + /// Initializes the Language Service instance + /// + /// + /// + public void InitializeService(ServiceHost serviceHost, SqlToolsContext context) + { + // Register the requests that this service will handle + serviceHost.SetRequestHandler(DefinitionRequest.Type, HandleDefinitionRequest); + serviceHost.SetRequestHandler(ReferencesRequest.Type, HandleReferencesRequest); + serviceHost.SetRequestHandler(CompletionResolveRequest.Type, HandleCompletionResolveRequest); + serviceHost.SetRequestHandler(SignatureHelpRequest.Type, HandleSignatureHelpRequest); + serviceHost.SetRequestHandler(DocumentHighlightRequest.Type, HandleDocumentHighlightRequest); + serviceHost.SetRequestHandler(HoverRequest.Type, HandleHoverRequest); + serviceHost.SetRequestHandler(DocumentSymbolRequest.Type, HandleDocumentSymbolRequest); + serviceHost.SetRequestHandler(WorkspaceSymbolRequest.Type, HandleWorkspaceSymbolRequest); + + // Register a no-op shutdown task for validation of the shutdown logic + serviceHost.RegisterShutdownTask(async (shutdownParams, shutdownRequestContext) => + { + Logger.Write(LogLevel.Verbose, "Shutting down language service"); + await Task.FromResult(0); + }); + + // Register the configuration update handler + WorkspaceService.Instance.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification); + + // Register the file change update handler + WorkspaceService.Instance.RegisterTextDocChangeCallback(HandleDidChangeTextDocumentNotification); + + // Register the file open update handler + WorkspaceService.Instance.RegisterTextDocOpenCallback(HandleDidOpenTextDocumentNotification); + + // Store the SqlToolsContext for future use + Context = context; + } + + /// + /// Parses the SQL text and binds it to the SMO metadata provider if connected + /// + /// + /// + /// + public ParseResult ParseAndBind(ScriptFile scriptFile, ConnectionInfo connInfo) + { + ScriptParseInfo parseInfo = null; + if (this.ScriptParseInfoMap.ContainsKey(scriptFile.ClientFilePath)) + { + parseInfo = this.ScriptParseInfoMap[scriptFile.ClientFilePath]; + } + + // parse current SQL file contents to retrieve a list of errors + ParseOptions parseOptions = new ParseOptions(); + ParseResult parseResult = Parser.IncrementalParse( + scriptFile.Contents, + parseInfo != null ? parseInfo.ParseResult : null, + parseOptions); + + // save previous result for next incremental parse + if (parseInfo != null) + { + parseInfo.ParseResult = parseResult; + } + + if (connInfo != null) + { + try + { + List parseResults = new List(); + parseResults.Add(parseResult); + parseInfo.Binder.Bind( + parseResults, + connInfo.ConnectionDetails.DatabaseName, + BindMode.Batch); + } + catch (ConnectionException) + { + Logger.Write(LogLevel.Error, "Hit connection exception while binding - disposing binder object..."); + } + catch (SqlParserInternalBinderError) + { + Logger.Write(LogLevel.Error, "Hit connection exception while binding - disposing binder object..."); + } + } + + return parseResult; + } + + /// + /// Gets a list of semantic diagnostic marks for the provided script file + /// + /// + public ScriptFileMarker[] GetSemanticMarkers(ScriptFile scriptFile) + { + ConnectionInfo connInfo; + ConnectionService.Instance.TryFindConnection( + scriptFile.ClientFilePath, + out connInfo); + + var parseResult = ParseAndBind(scriptFile, connInfo); + + // build a list of SQL script file markers from the errors + List markers = new List(); + foreach (var error in parseResult.Errors) + { + markers.Add(new ScriptFileMarker() + { + Message = error.Message, + Level = ScriptFileMarkerLevel.Error, + ScriptRegion = new ScriptRegion() + { + File = scriptFile.FilePath, + StartLineNumber = error.Start.LineNumber, + StartColumnNumber = error.Start.ColumnNumber, + StartOffset = 0, + EndLineNumber = error.End.LineNumber, + EndColumnNumber = error.End.ColumnNumber, + EndOffset = 0 + } + }); + } + + return markers.ToArray(); + } + + #endregion + + #region Request Handlers + + private static async Task HandleDefinitionRequest( + TextDocumentPosition textDocumentPosition, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleDefinitionRequest"); + await Task.FromResult(true); + } + + private static async Task HandleReferencesRequest( + ReferencesParams referencesParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleReferencesRequest"); + await Task.FromResult(true); + } + + private static async Task HandleCompletionResolveRequest( + CompletionItem completionItem, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleCompletionResolveRequest"); + await Task.FromResult(true); + } + + private static async Task HandleSignatureHelpRequest( + TextDocumentPosition textDocumentPosition, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleSignatureHelpRequest"); + await Task.FromResult(true); + } + + private static async Task HandleDocumentHighlightRequest( + TextDocumentPosition textDocumentPosition, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleDocumentHighlightRequest"); + await Task.FromResult(true); + } + + private static async Task HandleHoverRequest( + TextDocumentPosition textDocumentPosition, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleHoverRequest"); + await Task.FromResult(true); + } + + private static async Task HandleDocumentSymbolRequest( + DocumentSymbolParams documentSymbolParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleDocumentSymbolRequest"); + await Task.FromResult(true); + } + + private static async Task HandleWorkspaceSymbolRequest( + WorkspaceSymbolParams workspaceSymbolParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleWorkspaceSymbolRequest"); + await Task.FromResult(true); + } + + #endregion + + #region Handlers for Events from Other Services + + /// + /// Handle the file open notification + /// + /// + /// + /// + public async Task HandleDidOpenTextDocumentNotification( + ScriptFile scriptFile, + EventContext eventContext) + { + await this.RunScriptDiagnostics( + new ScriptFile[] { scriptFile }, + eventContext); + + await Task.FromResult(true); + } + + /// + /// Handles text document change events + /// + /// + /// + /// + public async Task HandleDidChangeTextDocumentNotification(ScriptFile[] changedFiles, EventContext eventContext) + { + await this.RunScriptDiagnostics( + changedFiles.ToArray(), + eventContext); + + await Task.FromResult(true); + } + + /// + /// Handle the file configuration change notification + /// + /// + /// + /// + public async Task HandleDidChangeConfigurationNotification( + SqlToolsSettings newSettings, + SqlToolsSettings oldSettings, + EventContext eventContext) + { + // If script analysis settings have changed we need to clear & possibly update the current diagnostic records. + bool oldScriptAnalysisEnabled = oldSettings.ScriptAnalysis.Enable.HasValue; + if ((oldScriptAnalysisEnabled != newSettings.ScriptAnalysis.Enable)) + { + // If the user just turned off script analysis or changed the settings path, send a diagnostics + // event to clear the analysis markers that they already have. + if (!newSettings.ScriptAnalysis.Enable.Value) + { + ScriptFileMarker[] emptyAnalysisDiagnostics = new ScriptFileMarker[0]; + + foreach (var scriptFile in WorkspaceService.Instance.Workspace.GetOpenedFiles()) + { + await PublishScriptDiagnostics(scriptFile, emptyAnalysisDiagnostics, eventContext); + } + } + else + { + await this.RunScriptDiagnostics(CurrentWorkspace.GetOpenedFiles(), eventContext); + } + } + + // Update the settings in the current + CurrentSettings.EnableProfileLoading = newSettings.EnableProfileLoading; + CurrentSettings.ScriptAnalysis.Update(newSettings.ScriptAnalysis, CurrentWorkspace.WorkspacePath); + } + + #endregion + + #region Private Helpers + + /// + /// Runs script diagnostics on changed files + /// + /// + /// + private Task RunScriptDiagnostics(ScriptFile[] filesToAnalyze, EventContext eventContext) + { + if (!CurrentSettings.ScriptAnalysis.Enable.Value) + { + // If the user has disabled script analysis, skip it entirely + return Task.FromResult(true); + } + + // If there's an existing task, attempt to cancel it + try + { + if (ExistingRequestCancellation != null) + { + // Try to cancel the request + ExistingRequestCancellation.Cancel(); + + // If cancellation didn't throw an exception, + // clean up the existing token + ExistingRequestCancellation.Dispose(); + ExistingRequestCancellation = null; + } + } + catch (Exception e) + { + Logger.Write( + LogLevel.Error, + string.Format( + "Exception while cancelling analysis task:\n\n{0}", + e.ToString())); + + TaskCompletionSource cancelTask = new TaskCompletionSource(); + cancelTask.SetCanceled(); + return cancelTask.Task; + } + + // Create a fresh cancellation token and then start the task. + // We create this on a different TaskScheduler so that we + // don't block the main message loop thread. + ExistingRequestCancellation = new CancellationTokenSource(); + Task.Factory.StartNew( + () => + DelayThenInvokeDiagnostics( + 750, + filesToAnalyze, + eventContext, + ExistingRequestCancellation.Token), + CancellationToken.None, + TaskCreationOptions.None, + TaskScheduler.Default); + + return Task.FromResult(true); + } + + /// + /// Actually run the script diagnostics after waiting for some small delay + /// + /// + /// + /// + /// + private async Task DelayThenInvokeDiagnostics( + int delayMilliseconds, + ScriptFile[] filesToAnalyze, + EventContext eventContext, + CancellationToken cancellationToken) + { + // First of all, wait for the desired delay period before + // analyzing the provided list of files + try + { + await Task.Delay(delayMilliseconds, cancellationToken); + } + catch (TaskCanceledException) + { + // If the task is cancelled, exit directly + return; + } + + // If we've made it past the delay period then we don't care + // about the cancellation token anymore. This could happen + // when the user stops typing for long enough that the delay + // period ends but then starts typing while analysis is going + // on. It makes sense to send back the results from the first + // delay period while the second one is ticking away. + + // Get the requested files + foreach (ScriptFile scriptFile in filesToAnalyze) + { + Logger.Write(LogLevel.Verbose, "Analyzing script file: " + scriptFile.FilePath); + ScriptFileMarker[] semanticMarkers = GetSemanticMarkers(scriptFile); + Logger.Write(LogLevel.Verbose, "Analysis complete."); + + await PublishScriptDiagnostics(scriptFile, semanticMarkers, eventContext); + } + } + + /// + /// Send the diagnostic results back to the host application + /// + /// + /// + /// + private static async Task PublishScriptDiagnostics( + ScriptFile scriptFile, + ScriptFileMarker[] semanticMarkers, + EventContext eventContext) + { + var allMarkers = scriptFile.SyntaxMarkers != null + ? scriptFile.SyntaxMarkers.Concat(semanticMarkers) + : semanticMarkers; + + // Always send syntax and semantic errors. We want to + // make sure no out-of-date markers are being displayed. + await eventContext.SendEvent( + PublishDiagnosticsNotification.Type, + new PublishDiagnosticsNotification + { + Uri = scriptFile.ClientFilePath, + Diagnostics = + allMarkers + .Select(GetDiagnosticFromMarker) + .ToArray() + }); + } + + /// + /// Convert a ScriptFileMarker to a Diagnostic that is Language Service compatible + /// + /// + /// + private static Diagnostic GetDiagnosticFromMarker(ScriptFileMarker scriptFileMarker) + { + return new Diagnostic + { + Severity = MapDiagnosticSeverity(scriptFileMarker.Level), + Message = scriptFileMarker.Message, + Range = new Range + { + Start = new Position + { + Line = scriptFileMarker.ScriptRegion.StartLineNumber - 1, + Character = scriptFileMarker.ScriptRegion.StartColumnNumber - 1 + }, + End = new Position + { + Line = scriptFileMarker.ScriptRegion.EndLineNumber - 1, + Character = scriptFileMarker.ScriptRegion.EndColumnNumber - 1 + } + } + }; + } + + /// + /// Map ScriptFileMarker severity to Diagnostic severity + /// + /// + private static DiagnosticSeverity MapDiagnosticSeverity(ScriptFileMarkerLevel markerLevel) + { + switch (markerLevel) + { + case ScriptFileMarkerLevel.Error: + return DiagnosticSeverity.Error; + + case ScriptFileMarkerLevel.Warning: + return DiagnosticSeverity.Warning; + + case ScriptFileMarkerLevel.Information: + return DiagnosticSeverity.Information; + + default: + return DiagnosticSeverity.Error; + } + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs new file mode 100644 index 00000000..4da2c57e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ScriptParseInfo.cs @@ -0,0 +1,40 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlServer.Management.SqlParser.Binder; +using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using Microsoft.SqlServer.Management.SqlParser.Parser; +using Microsoft.SqlServer.Management.SmoMetadataProvider; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices +{ + /// + /// Class for storing cached metadata regarding a parsed SQL file + /// + internal class ScriptParseInfo + { + /// + /// Gets or sets the SMO binder for schema-aware intellisense + /// + public IBinder Binder { get; set; } + + /// + /// Gets or sets the previous SQL parse result + /// + public ParseResult ParseResult { get; set; } + + /// + /// Gets or set the SMO metadata provider that's bound to the current connection + /// + /// + public SmoMetadataProvider MetadataProvider { get; set; } + + /// + /// Gets or sets the SMO metadata display info provider + /// + /// + public MetadataDisplayInfoProvider MetadataDisplayInfoProvider { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.xproj b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.xproj new file mode 100644 index 00000000..358bb7c3 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.xproj @@ -0,0 +1,19 @@ + + + + 14.0 + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) + + + + {0D61DC2B-DA66-441D-B9D0-F76C98F780F9} + Microsoft.SqlTools.ServiceLayer + .\obj + .\bin\ + v4.5.2 + + + 2.0 + + + \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/Program.cs b/src/Microsoft.SqlTools.ServiceLayer/Program.cs new file mode 100644 index 00000000..f0d2d6e8 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Program.cs @@ -0,0 +1,58 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Hosting; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Workspace; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Credentials; + +namespace Microsoft.SqlTools.ServiceLayer +{ + /// + /// Main application class for SQL Tools API Service Host executable + /// + class Program + { + /// + /// Main entry point into the SQL Tools API Service Host + /// + static void Main(string[] args) + { + // turn on Verbose logging during early development + // we need to switch to Normal when preparing for public preview + Logger.Initialize(minimumLogLevel: LogLevel.Verbose); + Logger.Write(LogLevel.Normal, "Starting SQL Tools Service Host"); + + const string hostName = "SQL Tools Service Host"; + const string hostProfileId = "SQLToolsService"; + Version hostVersion = new Version(1,0); + + // set up the host details and profile paths + var hostDetails = new HostDetails(hostName, hostProfileId, hostVersion); + var profilePaths = new ProfilePaths(hostProfileId, "baseAllUsersPath", "baseCurrentUserPath"); + SqlToolsContext sqlToolsContext = new SqlToolsContext(hostDetails, profilePaths); + + // Grab the instance of the service host + ServiceHost serviceHost = ServiceHost.Instance; + + // Start the service + serviceHost.Start().Wait(); + + // Initialize the services that will be hosted here + WorkspaceService.Instance.InitializeService(serviceHost); + AutoCompleteService.Instance.InitializeService(serviceHost); + LanguageService.Instance.InitializeService(serviceHost, sqlToolsContext); + ConnectionService.Instance.InitializeService(serviceHost); + CredentialService.Instance.InitializeService(serviceHost); + QueryExecutionService.Instance.InitializeService(serviceHost); + + serviceHost.Initialize(); + serviceHost.WaitForExit(); + } + } +} diff --git a/src/ServiceHost/Properties/AssemblyInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Properties/AssemblyInfo.cs similarity index 95% rename from src/ServiceHost/Properties/AssemblyInfo.cs rename to src/Microsoft.SqlTools.ServiceLayer/Properties/AssemblyInfo.cs index 27c1daba..33b9e4a8 100644 --- a/src/ServiceHost/Properties/AssemblyInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Properties/AssemblyInfo.cs @@ -41,4 +41,4 @@ using System.Runtime.InteropServices; [assembly: AssemblyFileVersion("1.0.0.0")] [assembly: AssemblyInformationalVersion("1.0.0.0")] -[assembly: InternalsVisibleTo("Microsoft.SqlTools.EditorServices.Test.Protocol")] +[assembly: InternalsVisibleTo("Microsoft.SqlTools.ServiceLayer.Test")] diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs new file mode 100644 index 00000000..69250afe --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs @@ -0,0 +1,301 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Data.SqlClient; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + /// + /// This class represents a batch within a query + /// + public class Batch : IDisposable + { + private const string RowsAffectedFormat = "({0} row(s) affected)"; + + #region Member Variables + + /// + /// For IDisposable implementation, whether or not this has been disposed + /// + private bool disposed; + + /// + /// Factory for creating readers/writrs for the output of the batch + /// + private readonly IFileStreamFactory outputFileFactory; + + /// + /// Internal representation of the messages so we can modify internally + /// + private readonly List resultMessages; + + /// + /// Internal representation of the result sets so we can modify internally + /// + private readonly List resultSets; + + #endregion + + internal Batch(string batchText, int startLine, IFileStreamFactory outputFileFactory) + { + // Sanity check for input + Validate.IsNotNullOrEmptyString(nameof(batchText), batchText); + Validate.IsNotNull(nameof(outputFileFactory), outputFileFactory); + + // Initialize the internal state + BatchText = batchText; + StartLine = startLine - 1; // -1 to make sure that the line number of the batch is 0-indexed, since SqlParser gives 1-indexed line numbers + HasExecuted = false; + resultSets = new List(); + resultMessages = new List(); + this.outputFileFactory = outputFileFactory; + } + + #region Properties + + /// + /// The text of batch that will be executed + /// + public string BatchText { get; set; } + + /// + /// Whether or not this batch has an error + /// + public bool HasError { get; set; } + + /// + /// Whether or not this batch has been executed, regardless of success or failure + /// + public bool HasExecuted { get; set; } + + /// + /// Messages that have come back from the server + /// + public IEnumerable ResultMessages + { + get { return resultMessages; } + } + + /// + /// The result sets of the batch execution + /// + public IEnumerable ResultSets + { + get { return resultSets; } + } + + /// + /// Property for generating a set result set summaries from the result sets + /// + public ResultSetSummary[] ResultSummaries + { + get + { + return ResultSets.Select((set, index) => new ResultSetSummary() + { + ColumnInfo = set.Columns, + Id = index, + RowCount = set.RowCount + }).ToArray(); + } + } + + /// + /// The 0-indexed line number that this batch started on + /// + internal int StartLine { get; set; } + + #endregion + + #region Public Methods + + /// + /// Executes this batch and captures any server messages that are returned. + /// + /// The connection to use to execute the batch + /// Token for cancelling the execution + public async Task Execute(DbConnection conn, CancellationToken cancellationToken) + { + // Sanity check to make sure we haven't already run this batch + if (HasExecuted) + { + throw new InvalidOperationException("Batch has already executed."); + } + + try + { + // Register the message listener to *this instance* of the batch + // Note: This is being done to associate messages with batches + SqlConnection sqlConn = conn as SqlConnection; + if (sqlConn != null) + { + sqlConn.InfoMessage += StoreDbMessage; + } + + // Create a command that we'll use for executing the query + using (DbCommand command = conn.CreateCommand()) + { + command.CommandText = BatchText; + command.CommandType = CommandType.Text; + + // Execute the command to get back a reader + using (DbDataReader reader = await command.ExecuteReaderAsync(cancellationToken)) + { + do + { + // Skip this result set if there aren't any rows + if (!reader.HasRows && reader.FieldCount == 0) + { + // Create a message with the number of affected rows -- IF the query affects rows + resultMessages.Add(reader.RecordsAffected >= 0 + ? string.Format(RowsAffectedFormat, reader.RecordsAffected) + : "Command(s) completed successfully."); + continue; + } + + // Read until we hit the end of the result set + ResultSet resultSet = new ResultSet(reader, outputFileFactory); + await resultSet.ReadResultToEnd(cancellationToken); + + // Add the result set to the results of the query + resultSets.Add(resultSet); + + // Add a message for the number of rows the query returned + resultMessages.Add(string.Format(RowsAffectedFormat, resultSet.RowCount)); + } while (await reader.NextResultAsync(cancellationToken)); + } + } + } + catch (DbException dbe) + { + HasError = true; + UnwrapDbException(dbe); + } + catch (Exception) + { + HasError = true; + throw; + } + finally + { + // Remove the message event handler from the connection + SqlConnection sqlConn = conn as SqlConnection; + if (sqlConn != null) + { + sqlConn.InfoMessage -= StoreDbMessage; + } + + // Mark that we have executed + HasExecuted = true; + } + } + + /// + /// Generates a subset of the rows from a result set of the batch + /// + /// The index for selecting the result set + /// The starting row of the results + /// How many rows to retrieve + /// A subset of results + public Task GetSubset(int resultSetIndex, int startRow, int rowCount) + { + // Sanity check to make sure we have valid numbers + if (resultSetIndex < 0 || resultSetIndex >= resultSets.Count) + { + throw new ArgumentOutOfRangeException(nameof(resultSetIndex), "Result set index cannot be less than 0" + + "or greater than the number of result sets"); + } + + // Retrieve the result set + return resultSets[resultSetIndex].GetSubset(startRow, rowCount); + } + + #endregion + + #region Private Helpers + + /// + /// Delegate handler for storing messages that are returned from the server + /// NOTE: Only messages that are below a certain severity will be returned via this + /// mechanism. Anything above that level will trigger an exception. + /// + /// Object that fired the event + /// Arguments from the event + private void StoreDbMessage(object sender, SqlInfoMessageEventArgs args) + { + resultMessages.Add(args.Message); + } + + /// + /// Attempts to convert a to a that + /// contains much more info about Sql Server errors. The exception is then unwrapped and + /// messages are formatted and stored in . If the exception + /// cannot be converted to SqlException, the message is written to the messages list. + /// + /// The exception to unwrap + private void UnwrapDbException(DbException dbe) + { + SqlException se = dbe as SqlException; + if (se != null) + { + foreach (var error in se.Errors) + { + SqlError sqlError = error as SqlError; + if (sqlError != null) + { + int lineNumber = sqlError.LineNumber + StartLine; + string message = String.Format("Msg {0}, Level {1}, State {2}, Line {3}{4}{5}", + sqlError.Number, sqlError.Class, sqlError.State, lineNumber, + Environment.NewLine, sqlError.Message); + resultMessages.Add(message); + } + } + } + else + { + resultMessages.Add(dbe.Message); + } + } + + #endregion + + #region IDisposable Implementation + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + foreach (ResultSet r in ResultSets) + { + r.Dispose(); + } + } + + disposed = true; + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/BatchSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/BatchSummary.cs new file mode 100644 index 00000000..73d1d4c8 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/BatchSummary.cs @@ -0,0 +1,33 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Summary of a batch within a query + /// + public class BatchSummary + { + /// + /// Whether or not the batch was successful. True indicates errors, false indicates success + /// + public bool HasError { get; set; } + + /// + /// The ID of the result set within the query results + /// + public int Id { get; set; } + + /// + /// Any messages that came back from the server during execution of the batch + /// + public string[] Messages { get; set; } + + /// + /// The summaries of the result sets inside the batch + /// + public ResultSetSummary[] ResultSetSummaries { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs new file mode 100644 index 00000000..e80eada5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs @@ -0,0 +1,226 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Data.SqlTypes; +using System.Diagnostics; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Wrapper around a DbColumn, which provides extra functionality, but can be used as a + /// regular DbColumn + /// + public class DbColumnWrapper : DbColumn + { + /// + /// All types supported by the server, stored as a hash set to provide O(1) lookup + /// + private static readonly HashSet AllServerDataTypes = new HashSet + { + "bigint", + "binary", + "bit", + "char", + "datetime", + "decimal", + "float", + "image", + "int", + "money", + "nchar", + "ntext", + "nvarchar", + "real", + "uniqueidentifier", + "smalldatetime", + "smallint", + "smallmoney", + "text", + "timestamp", + "tinyint", + "varbinary", + "varchar", + "sql_variant", + "xml", + "date", + "time", + "datetimeoffset", + "datetime2" + }; + + private readonly DbColumn internalColumn; + + /// + /// Constructor for a DbColumnWrapper + /// + /// Most of this logic is taken from SSMS ColumnInfo class + /// The column we're wrapping around + public DbColumnWrapper(DbColumn column) + { + internalColumn = column; + + switch (column.DataTypeName) + { + case "varchar": + case "nvarchar": + IsChars = true; + + Debug.Assert(column.ColumnSize.HasValue); + if (column.ColumnSize.Value == int.MaxValue) + { + //For Yukon, special case nvarchar(max) with column name == "Microsoft SQL Server 2005 XML Showplan" - + //assume it is an XML showplan. + //Please note this field must be in sync with a similar field defined in QESQLBatch.cs. + //This is not the best fix that we could do but we are trying to minimize code impact + //at this point. Post Yukon we should review this code again and avoid + //hard-coding special column name in multiple places. + const string YukonXmlShowPlanColumn = "Microsoft SQL Server 2005 XML Showplan"; + if (column.ColumnName == YukonXmlShowPlanColumn) + { + // Indicate that this is xml to apply the right size limit + // Note we leave chars type as well to use the right retrieval mechanism. + IsXml = true; + } + IsLong = true; + } + break; + case "text": + case "ntext": + IsChars = true; + IsLong = true; + break; + case "xml": + IsXml = true; + IsLong = true; + break; + case "binary": + case "image": + IsBytes = true; + IsLong = true; + break; + case "varbinary": + case "rowversion": + IsBytes = true; + + Debug.Assert(column.ColumnSize.HasValue); + if (column.ColumnSize.Value == int.MaxValue) + { + IsLong = true; + } + break; + case "sql_variant": + IsSqlVariant = true; + break; + default: + if (!AllServerDataTypes.Contains(column.DataTypeName)) + { + // treat all UDT's as long/bytes data types to prevent the CLR from attempting + // to load the UDT assembly into our process to call ToString() on the object. + + IsUdt = true; + IsBytes = true; + IsLong = true; + } + break; + } + + + if (IsUdt) + { + // udtassemblyqualifiedname property is used to find if the datatype is of hierarchyid assembly type + // Internally hiearchyid is sqlbinary so providerspecific type and type is changed to sqlbinarytype + object assemblyQualifiedName = internalColumn.UdtAssemblyQualifiedName; + const string hierarchyId = "MICROSOFT.SQLSERVER.TYPES.SQLHIERARCHYID"; + + if (assemblyQualifiedName != null && + string.Equals(assemblyQualifiedName.ToString(), hierarchyId, StringComparison.OrdinalIgnoreCase)) + { + DataType = typeof(SqlBinary); + } + else + { + DataType = typeof(byte[]); + } + } + else + { + DataType = DataType; + } + } + + #region Properties + + /// + /// Whether or not the column is bytes + /// + public bool IsBytes { get; private set; } + + /// + /// Whether or not the column is a character type + /// + public bool IsChars { get; private set; } + + /// + /// Whether or not the column is a long type (eg, varchar(MAX)) + /// + public new bool IsLong { get; private set; } + + /// + /// Whether or not the column is a SqlVariant type + /// + public bool IsSqlVariant { get; private set; } + + /// + /// Whether or not the column is a user-defined type + /// + public bool IsUdt { get; private set; } + + /// + /// Whether or not the column is XML + /// + public bool IsXml { get; private set; } + + #endregion + + #region DbColumn Fields + + /// + /// Override for column name, if null or empty, we default to a "no column name" value + /// + public new string ColumnName + { + get + { + // TODO: Localize + return string.IsNullOrEmpty(internalColumn.ColumnName) ? "(No column name)" : internalColumn.ColumnName; + } + } + + public new bool? AllowDBNull { get { return internalColumn.AllowDBNull; } } + public new string BaseCatalogName { get { return internalColumn.BaseCatalogName; } } + public new string BaseColumnName { get { return internalColumn.BaseColumnName; } } + public new string BaseServerName { get { return internalColumn.BaseServerName; } } + public new string BaseTableName { get { return internalColumn.BaseTableName; } } + public new int? ColumnOrdinal { get { return internalColumn.ColumnOrdinal; } } + public new int? ColumnSize { get { return internalColumn.ColumnSize; } } + public new bool? IsAliased { get { return internalColumn.IsAliased; } } + public new bool? IsAutoIncrement { get { return internalColumn.IsAutoIncrement; } } + public new bool? IsExpression { get { return internalColumn.IsExpression; } } + public new bool? IsHidden { get { return internalColumn.IsHidden; } } + public new bool? IsIdentity { get { return internalColumn.IsIdentity; } } + public new bool? IsKey { get { return internalColumn.IsKey; } } + public new bool? IsReadOnly { get { return internalColumn.IsReadOnly; } } + public new bool? IsUnique { get { return internalColumn.IsUnique; } } + public new int? NumericPrecision { get { return internalColumn.NumericPrecision; } } + public new int? NumericScale { get { return internalColumn.NumericScale; } } + public new string UdtAssemblyQualifiedName { get { return internalColumn.UdtAssemblyQualifiedName; } } + public new Type DataType { get; private set; } + public new string DataTypeName { get { return internalColumn.DataTypeName; } } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs new file mode 100644 index 00000000..3eb87f4f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for the query cancellation request + /// + public class QueryCancelParams + { + public string OwnerUri { get; set; } + } + + /// + /// Parameters to return as the result of a query dispose request + /// + public class QueryCancelResult + { + /// + /// Any error messages that occurred during disposing the result set. Optional, can be set + /// to null if there were no errors. + /// + public string Messages { get; set; } + } + + public class QueryCancelRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/cancel"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs new file mode 100644 index 00000000..70e6631c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for the query dispose request + /// + public class QueryDisposeParams + { + public string OwnerUri { get; set; } + } + + /// + /// Parameters to return as the result of a query dispose request + /// + public class QueryDisposeResult + { + /// + /// Any error messages that occurred during disposing the result set. Optional, can be set + /// to null if there were no errors. + /// + public string Messages { get; set; } + } + + public class QueryDisposeRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/dispose"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs new file mode 100644 index 00000000..90c8c7b3 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs @@ -0,0 +1,32 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters to be sent back with a query execution complete event + /// + public class QueryExecuteCompleteParams + { + /// + /// URI for the editor that owns the query + /// + public string OwnerUri { get; set; } + + /// + /// Summaries of the result sets that were returned with the query + /// + public BatchSummary[] BatchSummaries { get; set; } + } + + public class QueryExecuteCompleteEvent + { + public static readonly + EventType Type = + EventType.Create("query/complete"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs new file mode 100644 index 00000000..cac98c1a --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs @@ -0,0 +1,43 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for the query execute request + /// + public class QueryExecuteParams + { + /// + /// The text of the query to execute + /// + public string QueryText { get; set; } + + /// + /// URI for the editor that is asking for the query execute + /// + public string OwnerUri { get; set; } + } + + /// + /// Parameters for the query execute result + /// + public class QueryExecuteResult + { + /// + /// Connection error messages. Optional, can be set to null to indicate no errors + /// + public string Messages { get; set; } + } + + public class QueryExecuteRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/execute"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteSubsetRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteSubsetRequest.cs new file mode 100644 index 00000000..2c861502 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteSubsetRequest.cs @@ -0,0 +1,66 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for a query result subset retrieval request + /// + public class QueryExecuteSubsetParams + { + /// + /// URI for the file that owns the query to look up the results for + /// + public string OwnerUri { get; set; } + + /// + /// Index of the batch to get the results from + /// + public int BatchIndex { get; set; } + + /// + /// Index of the result set to get the results from + /// + public int ResultSetIndex { get; set; } + + /// + /// Beginning index of the rows to return from the selected resultset. This index will be + /// included in the results. + /// + public int RowsStartIndex { get; set; } + + /// + /// Number of rows to include in the result of this request. If the number of the rows + /// exceeds the number of rows available after the start index, all available rows after + /// the start index will be returned. + /// + public int RowsCount { get; set; } + } + + /// + /// Parameters for the result of a subset retrieval request + /// + public class QueryExecuteSubsetResult + { + /// + /// Subset request error messages. Optional, can be set to null to indicate no errors + /// + public string Message { get; set; } + + /// + /// The requested subset of results. Optional, can be set to null to indicate an error + /// + public ResultSetSubset ResultSubset { get; set; } + } + + public class QueryExecuteSubsetRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/subset"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs new file mode 100644 index 00000000..8e2b49a9 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs @@ -0,0 +1,24 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Class used to represent a subset of results from a query for transmission across JSON RPC + /// + public class ResultSetSubset + { + /// + /// The number of rows returned from result set, useful for determining if less rows were + /// returned than requested. + /// + public int RowCount { get; set; } + + /// + /// 2D array of the cell values requested from result set + /// + public object[][] Rows { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs new file mode 100644 index 00000000..c8705d8b --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Represents a summary of information about a result without returning any cells of the results + /// + public class ResultSetSummary + { + /// + /// The ID of the result set within the batch results + /// + public int Id { get; set; } + + /// + /// The number of rows that was returned with the resultset + /// + public long RowCount { get; set; } + + /// + /// Details about the columns that are provided as solutions + /// + public DbColumnWrapper[] ColumnInfo { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs new file mode 100644 index 00000000..61ee62e0 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs @@ -0,0 +1,50 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Represents a value returned from a read from a file stream. This is used to eliminate ref + /// parameters used in the read methods. + /// + /// The type of the value that was read + public struct FileStreamReadResult + { + /// + /// Whether or not the value of the field is null + /// + public bool IsNull { get; set; } + + /// + /// The value of the field. If is true, this will be set to default(T) + /// + public T Value { get; set; } + + /// + /// The total length in bytes of the value, (including the bytes used to store the length + /// of the value) + /// + /// + /// Cell values are stored such that the length of the value is stored first, then the + /// value itself is stored. Eg, a string may be stored as 0x03 0x6C 0x6F 0x6C. Under this + /// system, the value would be "lol", the length would be 3, and the total length would be + /// 4 bytes. + /// + public int TotalLength { get; set; } + + /// + /// Constructs a new FileStreamReadResult + /// + /// The value of the result + /// The number of bytes for the used to store the value's length and value + /// Whether or not the value is null + public FileStreamReadResult(T value, int totalLength, bool isNull) + { + Value = value; + TotalLength = totalLength; + IsNull = isNull; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs new file mode 100644 index 00000000..3a6c3ecf --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs @@ -0,0 +1,278 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Diagnostics; +using System.IO; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Wrapper for a file stream, providing simplified creation, deletion, read, and write + /// functionality. + /// + public class FileStreamWrapper : IFileStreamWrapper + { + #region Member Variables + + private byte[] buffer; + private int bufferDataSize; + private FileStream fileStream; + private long startOffset; + private long currentOffset; + + #endregion + + /// + /// Constructs a new FileStreamWrapper and initializes its state. + /// + public FileStreamWrapper() + { + // Initialize the internal state + bufferDataSize = 0; + startOffset = 0; + currentOffset = 0; + } + + #region IFileStreamWrapper Implementation + + /// + /// Initializes the wrapper by creating the internal buffer and opening the requested file. + /// If the file does not already exist, it will be created. + /// + /// Name of the file to open/create + /// The length of the internal buffer + /// + /// Whether or not the wrapper will be used for reading. If true, any calls to a + /// method that writes will cause an InvalidOperationException + /// + public void Init(string fileName, int bufferLength, FileAccess accessMethod) + { + // Sanity check for valid buffer length, fileName, and accessMethod + if (bufferLength <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferLength), "Buffer length must be a positive value"); + } + if (string.IsNullOrWhiteSpace(fileName)) + { + throw new ArgumentNullException(nameof(fileName), "File name cannot be null or whitespace"); + } + if (accessMethod == FileAccess.Write) + { + throw new ArgumentException("Access method cannot be write-only", nameof(fileName)); + } + + // Setup the buffer + buffer = new byte[bufferLength]; + + // Open the requested file for reading/writing, creating one if it doesn't exist + fileStream = new FileStream(fileName, FileMode.OpenOrCreate, accessMethod, FileShare.ReadWrite, + bufferLength, false /*don't use asyncio*/); + } + + /// + /// Reads data into a buffer from the current offset into the file + /// + /// The buffer to output the read data to + /// The number of bytes to read into the buffer + /// The number of bytes read + public int ReadData(byte[] buf, int bytes) + { + return ReadData(buf, bytes, currentOffset); + } + + /// + /// Reads data into a buffer from the specified offset into the file + /// + /// The buffer to output the read data to + /// The number of bytes to read into the buffer + /// The offset into the file to start reading bytes from + /// The number of bytes read + public int ReadData(byte[] buf, int bytes, long offset) + { + // Make sure that we're initialized before performing operations + if (buffer == null) + { + throw new InvalidOperationException("FileStreamWrapper must be initialized before performing operations"); + } + + MoveTo(offset); + + int bytesCopied = 0; + while (bytesCopied < bytes) + { + int bufferOffset, bytesToCopy; + GetByteCounts(bytes, bytesCopied, out bufferOffset, out bytesToCopy); + Buffer.BlockCopy(buffer, bufferOffset, buf, bytesCopied, bytesToCopy); + bytesCopied += bytesToCopy; + + if (bytesCopied < bytes && // did not get all the bytes yet + bufferDataSize == buffer.Length) // since current data buffer is full we should continue reading the file + { + // move forward one full length of the buffer + MoveTo(startOffset + buffer.Length); + } + else + { + // copied all the bytes requested or possible, adjust the current buffer pointer + currentOffset += bytesToCopy; + break; + } + } + return bytesCopied; + } + + /// + /// Writes data to the underlying filestream, with buffering. + /// + /// The buffer of bytes to write to the filestream + /// The number of bytes to write + /// The number of bytes written + public int WriteData(byte[] buf, int bytes) + { + // Make sure that we're initialized before performing operations + if (buffer == null) + { + throw new InvalidOperationException("FileStreamWrapper must be initialized before performing operations"); + } + if (!fileStream.CanWrite) + { + throw new InvalidOperationException("This FileStreamWrapper canot be used for writing"); + } + + int bytesCopied = 0; + while (bytesCopied < bytes) + { + int bufferOffset, bytesToCopy; + GetByteCounts(bytes, bytesCopied, out bufferOffset, out bytesToCopy); + Buffer.BlockCopy(buf, bytesCopied, buffer, bufferOffset, bytesToCopy); + bytesCopied += bytesToCopy; + + // adjust the current buffer pointer + currentOffset += bytesToCopy; + + if (bytesCopied < bytes) // did not get all the bytes yet + { + Debug.Assert((int)(currentOffset - startOffset) == buffer.Length); + // flush buffer + Flush(); + } + } + Debug.Assert(bytesCopied == bytes); + return bytesCopied; + } + + /// + /// Flushes the internal buffer to the filestream + /// + public void Flush() + { + // Make sure that we're initialized before performing operations + if (buffer == null) + { + throw new InvalidOperationException("FileStreamWrapper must be initialized before performing operations"); + } + if (!fileStream.CanWrite) + { + throw new InvalidOperationException("This FileStreamWrapper cannot be used for writing"); + } + + // Make sure we are at the right place in the file + Debug.Assert(fileStream.Position == startOffset); + + int bytesToWrite = (int)(currentOffset - startOffset); + fileStream.Write(buffer, 0, bytesToWrite); + startOffset += bytesToWrite; + fileStream.Flush(); + + Debug.Assert(startOffset == currentOffset); + } + + /// + /// Deletes the given file (ideally, created with this wrapper) from the filesystem + /// + /// The path to the file to delete + public static void DeleteFile(string fileName) + { + File.Delete(fileName); + } + + #endregion + + /// + /// Perform calculations to determine how many bytes to copy and what the new buffer offset + /// will be for copying. + /// + /// Number of bytes requested to copy + /// Number of bytes copied so far + /// New offset to start copying from/to + /// Number of bytes to copy in this iteration + private void GetByteCounts(int bytes, int bytesCopied, out int bufferOffset, out int bytesToCopy) + { + bufferOffset = (int) (currentOffset - startOffset); + bytesToCopy = bytes - bytesCopied; + if (bytesToCopy > buffer.Length - bufferOffset) + { + bytesToCopy = buffer.Length - bufferOffset; + } + } + + /// + /// Moves the internal buffer to the specified offset into the file + /// + /// Offset into the file to move to + private void MoveTo(long offset) + { + if (buffer.Length > bufferDataSize || // buffer is not completely filled + offset < startOffset || // before current buffer start + offset >= (startOffset + buffer.Length)) // beyond current buffer end + { + // init the offset + startOffset = offset; + + // position file pointer + fileStream.Seek(startOffset, SeekOrigin.Begin); + + // fill in the buffer + bufferDataSize = fileStream.Read(buffer, 0, buffer.Length); + } + // make sure to record where we are + currentOffset = offset; + } + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing && fileStream != null) + { + if(fileStream.CanWrite) { Flush(); } + fileStream.Dispose(); + } + + disposed = true; + } + + ~FileStreamWrapper() + { + Dispose(false); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamFactory.cs new file mode 100644 index 00000000..6cb50095 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamFactory.cs @@ -0,0 +1,22 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Interface for a factory that creates filesystem readers/writers + /// + public interface IFileStreamFactory + { + string CreateFile(); + + IFileStreamReader GetReader(string fileName); + + IFileStreamWriter GetWriter(string fileName, int maxCharsToStore, int maxXmlCharsToStore); + + void DisposeFile(string fileName); + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs new file mode 100644 index 00000000..ea5584f1 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs @@ -0,0 +1,35 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Interface for a object that reads from the filesystem + /// + public interface IFileStreamReader : IDisposable + { + object[] ReadRow(long offset, IEnumerable columns); + FileStreamReadResult ReadInt16(long i64Offset); + FileStreamReadResult ReadInt32(long i64Offset); + FileStreamReadResult ReadInt64(long i64Offset); + FileStreamReadResult ReadByte(long i64Offset); + FileStreamReadResult ReadChar(long i64Offset); + FileStreamReadResult ReadBoolean(long i64Offset); + FileStreamReadResult ReadSingle(long i64Offset); + FileStreamReadResult ReadDouble(long i64Offset); + FileStreamReadResult ReadSqlDecimal(long i64Offset); + FileStreamReadResult ReadDecimal(long i64Offset); + FileStreamReadResult ReadDateTime(long i64Offset); + FileStreamReadResult ReadTimeSpan(long i64Offset); + FileStreamReadResult ReadString(long i64Offset); + FileStreamReadResult ReadBytes(long i64Offset); + FileStreamReadResult ReadDateTimeOffset(long i64Offset); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs new file mode 100644 index 00000000..38c283c5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs @@ -0,0 +1,22 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Interface for a wrapper around a filesystem reader/writer, mainly for unit testing purposes + /// + public interface IFileStreamWrapper : IDisposable + { + void Init(string fileName, int bufferSize, FileAccess fileAccessMode); + int ReadData(byte[] buffer, int bytes); + int ReadData(byte[] buffer, int bytes, long fileOffset); + int WriteData(byte[] buffer, int bytes); + void Flush(); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs new file mode 100644 index 00000000..968701ed --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs @@ -0,0 +1,35 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.SqlTypes; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Interface for a object that writes to a filesystem wrapper + /// + public interface IFileStreamWriter : IDisposable + { + int WriteRow(StorageDataReader dataReader); + int WriteNull(); + int WriteInt16(short val); + int WriteInt32(int val); + int WriteInt64(long val); + int WriteByte(byte val); + int WriteChar(char val); + int WriteBoolean(bool val); + int WriteSingle(float val); + int WriteDouble(double val); + int WriteDecimal(decimal val); + int WriteSqlDecimal(SqlDecimal val); + int WriteDateTime(DateTime val); + int WriteDateTimeOffset(DateTimeOffset dtoVal); + int WriteTimeSpan(TimeSpan val); + int WriteString(string val); + int WriteBytes(byte[] bytes, int length); + void FlushBuffer(); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs new file mode 100644 index 00000000..c06a13ac --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs @@ -0,0 +1,64 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.IO; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Factory that creates file reader/writers that process rows in an internal, non-human readable file format + /// + public class ServiceBufferFileStreamFactory : IFileStreamFactory + { + /// + /// Creates a new temporary file + /// + /// The name of the temporary file + public string CreateFile() + { + return Path.GetTempFileName(); + } + + /// + /// Creates a new for reading values back from + /// an SSMS formatted buffer file + /// + /// The file to read values from + /// A + public IFileStreamReader GetReader(string fileName) + { + return new ServiceBufferFileStreamReader(new FileStreamWrapper(), fileName); + } + + /// + /// Creates a new for writing values out to an + /// SSMS formatted buffer file + /// + /// The file to write values to + /// The maximum number of characters to store from long text fields + /// The maximum number of characters to store from xml fields + /// A + public IFileStreamWriter GetWriter(string fileName, int maxCharsToStore, int maxXmlCharsToStore) + { + return new ServiceBufferFileStreamWriter(new FileStreamWrapper(), fileName, maxCharsToStore, maxXmlCharsToStore); + } + + /// + /// Disposes of a file created via this factory + /// + /// The file to dispose of + public void DisposeFile(string fileName) + { + try + { + FileStreamWrapper.DeleteFile(fileName); + } + catch + { + // If we have problems deleting the file from a temp location, we don't really care + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs new file mode 100644 index 00000000..0cfc2466 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs @@ -0,0 +1,889 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.IO; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Reader for SSMS formatted file streams + /// + public class ServiceBufferFileStreamReader : IFileStreamReader + { + // Most of this code is based on code from the Microsoft.SqlServer.Management.UI.Grid, SSMS DataStorage + // $\Data Tools\SSMS_XPlat\sql\ssms\core\DataStorage\src\FileStreamReader.cs + + private const int DefaultBufferSize = 8192; + + #region Member Variables + + private byte[] buffer; + + private readonly IFileStreamWrapper fileStream; + + #endregion + + /// + /// Constructs a new ServiceBufferFileStreamReader and initializes its state + /// + /// The filestream wrapper to read from + /// The name of the file to read from + public ServiceBufferFileStreamReader(IFileStreamWrapper fileWrapper, string fileName) + { + // Open file for reading/writing + fileStream = fileWrapper; + fileStream.Init(fileName, DefaultBufferSize, FileAccess.Read); + + // Create internal buffer + buffer = new byte[DefaultBufferSize]; + } + + #region IFileStreamStorage Implementation + + /// + /// Reads a row from the file, based on the columns provided + /// + /// Offset into the file where the row starts + /// The columns that were encoded + /// The objects from the row + public object[] ReadRow(long fileOffset, IEnumerable columns) + { + // Initialize for the loop + long currentFileOffset = fileOffset; + List results = new List(); + + // Iterate over the columns + foreach (DbColumnWrapper column in columns) + { + // We will pivot based on the type of the column + Type colType; + if (column.IsSqlVariant) + { + // For SQL Variant columns, the type is written first in string format + FileStreamReadResult sqlVariantTypeResult = ReadString(currentFileOffset); + currentFileOffset += sqlVariantTypeResult.TotalLength; + + // If the typename is null, then the whole value is null + if (sqlVariantTypeResult.IsNull) + { + results.Add(null); + continue; + } + + // The typename is stored in the string + colType = Type.GetType(sqlVariantTypeResult.Value); + + // Workaround .NET bug, see sqlbu# 440643 and vswhidbey# 599834 + // TODO: Is this workaround necessary for .NET Core? + if (colType == null && sqlVariantTypeResult.Value == "System.Data.SqlTypes.SqlSingle") + { + colType = typeof(SqlSingle); + } + } + else + { + colType = column.DataType; + } + + if (colType == typeof(string)) + { + // String - most frequently used data type + FileStreamReadResult result = ReadString(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? null : result.Value); + } + else if (colType == typeof(SqlString)) + { + // SqlString + FileStreamReadResult result = ReadString(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? null : (SqlString) result.Value); + } + else if (colType == typeof(short)) + { + // Int16 + FileStreamReadResult result = ReadInt16(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlInt16)) + { + // SqlInt16 + FileStreamReadResult result = ReadInt16(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlInt16)result.Value); + } + } + else if (colType == typeof(int)) + { + // Int32 + FileStreamReadResult result = ReadInt32(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlInt32)) + { + // SqlInt32 + FileStreamReadResult result = ReadInt32(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlInt32)result.Value); + } + } + else if (colType == typeof(long)) + { + // Int64 + FileStreamReadResult result = ReadInt64(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlInt64)) + { + // SqlInt64 + FileStreamReadResult result = ReadInt64(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlInt64)result.Value); + } + } + else if (colType == typeof(byte)) + { + // byte + FileStreamReadResult result = ReadByte(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlByte)) + { + // SqlByte + FileStreamReadResult result = ReadByte(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlByte)result.Value); + } + } + else if (colType == typeof(char)) + { + // Char + FileStreamReadResult result = ReadChar(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(bool)) + { + // Bool + FileStreamReadResult result = ReadBoolean(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlBoolean)) + { + // SqlBoolean + FileStreamReadResult result = ReadBoolean(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlBoolean)result.Value); + } + } + else if (colType == typeof(double)) + { + // double + FileStreamReadResult result = ReadDouble(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlDouble)) + { + // SqlByte + FileStreamReadResult result = ReadDouble(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlDouble)result.Value); + } + } + else if (colType == typeof(float)) + { + // float + FileStreamReadResult result = ReadSingle(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlSingle)) + { + // SqlSingle + FileStreamReadResult result = ReadSingle(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlSingle)result.Value); + } + } + else if (colType == typeof(decimal)) + { + // Decimal + FileStreamReadResult result = ReadDecimal(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlDecimal)) + { + // SqlDecimal + FileStreamReadResult result = ReadSqlDecimal(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(DateTime)) + { + // DateTime + FileStreamReadResult result = ReadDateTime(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlDateTime)) + { + // SqlDateTime + FileStreamReadResult result = ReadDateTime(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlDateTime)result.Value); + } + } + else if (colType == typeof(DateTimeOffset)) + { + // DateTimeOffset + FileStreamReadResult result = ReadDateTimeOffset(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(TimeSpan)) + { + // TimeSpan + FileStreamReadResult result = ReadTimeSpan(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(byte[])) + { + // Byte Array + FileStreamReadResult result = ReadBytes(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull || (column.IsUdt && result.Value.Length == 0)) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlBytes)) + { + // SqlBytes + FileStreamReadResult result = ReadBytes(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? null : new SqlBytes(result.Value)); + } + else if (colType == typeof(SqlBinary)) + { + // SqlBinary + FileStreamReadResult result = ReadBytes(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? null : new SqlBinary(result.Value)); + } + else if (colType == typeof(SqlGuid)) + { + // SqlGuid + FileStreamReadResult result = ReadBytes(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(new SqlGuid(result.Value)); + } + } + else if (colType == typeof(SqlMoney)) + { + // SqlMoney + FileStreamReadResult result = ReadDecimal(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(new SqlMoney(result.Value)); + } + } + else + { + // Treat everything else as a string + FileStreamReadResult result = ReadString(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? null : result.Value); + } + } + + return results.ToArray(); + } + + /// + /// Reads a short from the file at the offset provided + /// + /// Offset into the file to read the short from + /// A short + public FileStreamReadResult ReadInt16(long fileOffset) + { + + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 2, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + short val = default(short); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToInt16(buffer, 0); + } + + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a int from the file at the offset provided + /// + /// Offset into the file to read the int from + /// An int + public FileStreamReadResult ReadInt32(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 4, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + int val = default(int); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToInt32(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a long from the file at the offset provided + /// + /// Offset into the file to read the long from + /// A long + public FileStreamReadResult ReadInt64(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 8, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + long val = default(long); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToInt64(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a byte from the file at the offset provided + /// + /// Offset into the file to read the byte from + /// A byte + public FileStreamReadResult ReadByte(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 1, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + byte val = default(byte); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = buffer[0]; + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a char from the file at the offset provided + /// + /// Offset into the file to read the char from + /// A char + public FileStreamReadResult ReadChar(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 2, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + char val = default(char); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToChar(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a bool from the file at the offset provided + /// + /// Offset into the file to read the bool from + /// A bool + public FileStreamReadResult ReadBoolean(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 1, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + bool val = default(bool); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = buffer[0] == 0x01; + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a single from the file at the offset provided + /// + /// Offset into the file to read the single from + /// A single + public FileStreamReadResult ReadSingle(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 4, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + float val = default(float); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToSingle(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a double from the file at the offset provided + /// + /// Offset into the file to read the double from + /// A double + public FileStreamReadResult ReadDouble(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 8, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + double val = default(double); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToDouble(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a SqlDecimal from the file at the offset provided + /// + /// Offset into the file to read the SqlDecimal from + /// A SqlDecimal + public FileStreamReadResult ReadSqlDecimal(long offset) + { + LengthResult length = ReadLength(offset); + Debug.Assert(length.ValueLength == 0 || (length.ValueLength - 3)%4 == 0, + string.Format("Invalid data length: {0}", length.ValueLength)); + + bool isNull = length.ValueLength == 0; + SqlDecimal val = default(SqlDecimal); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + + int[] arrInt32 = new int[(length.ValueLength - 3)/4]; + Buffer.BlockCopy(buffer, 3, arrInt32, 0, length.ValueLength - 3); + val = new SqlDecimal(buffer[0], buffer[1], 1 == buffer[2], arrInt32); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a decimal from the file at the offset provided + /// + /// Offset into the file to read the decimal from + /// A decimal + public FileStreamReadResult ReadDecimal(long offset) + { + LengthResult length = ReadLength(offset); + Debug.Assert(length.ValueLength%4 == 0, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + decimal val = default(decimal); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + + int[] arrInt32 = new int[length.ValueLength/4]; + Buffer.BlockCopy(buffer, 0, arrInt32, 0, length.ValueLength); + val = new decimal(arrInt32); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a DateTime from the file at the offset provided + /// + /// Offset into the file to read the DateTime from + /// A DateTime + public FileStreamReadResult ReadDateTime(long offset) + { + FileStreamReadResult ticks = ReadInt64(offset); + DateTime val = default(DateTime); + if (!ticks.IsNull) + { + val = new DateTime(ticks.Value); + } + return new FileStreamReadResult(val, ticks.TotalLength, ticks.IsNull); + } + + /// + /// Reads a DateTimeOffset from the file at the offset provided + /// + /// Offset into the file to read the DateTimeOffset from + /// A DateTimeOffset + public FileStreamReadResult ReadDateTimeOffset(long offset) + { + // DateTimeOffset is represented by DateTime.Ticks followed by TimeSpan.Ticks + // both as Int64 values + + // read the DateTime ticks + DateTimeOffset val = default(DateTimeOffset); + FileStreamReadResult dateTimeTicks = ReadInt64(offset); + int totalLength = dateTimeTicks.TotalLength; + if (dateTimeTicks.TotalLength > 0 && !dateTimeTicks.IsNull) + { + // read the TimeSpan ticks + FileStreamReadResult timeSpanTicks = ReadInt64(offset + dateTimeTicks.TotalLength); + Debug.Assert(!timeSpanTicks.IsNull, "TimeSpan ticks cannot be null if DateTime ticks are not null!"); + + totalLength += timeSpanTicks.TotalLength; + + // build the DateTimeOffset + val = new DateTimeOffset(new DateTime(dateTimeTicks.Value), new TimeSpan(timeSpanTicks.Value)); + } + return new FileStreamReadResult(val, totalLength, dateTimeTicks.IsNull); + } + + /// + /// Reads a TimeSpan from the file at the offset provided + /// + /// Offset into the file to read the TimeSpan from + /// A TimeSpan + public FileStreamReadResult ReadTimeSpan(long offset) + { + FileStreamReadResult timeSpanTicks = ReadInt64(offset); + TimeSpan val = default(TimeSpan); + if (!timeSpanTicks.IsNull) + { + val = new TimeSpan(timeSpanTicks.Value); + } + return new FileStreamReadResult(val, timeSpanTicks.TotalLength, timeSpanTicks.IsNull); + } + + /// + /// Reads a string from the file at the offset provided + /// + /// Offset into the file to read the string from + /// A string + public FileStreamReadResult ReadString(long offset) + { + LengthResult fieldLength = ReadLength(offset); + Debug.Assert(fieldLength.ValueLength%2 == 0, "Invalid data length"); + + if (fieldLength.ValueLength == 0) // there is no data + { + // If the total length is 5 (5 bytes for length, 0 for value), then the string is empty + // Otherwise, the string is null + bool isNull = fieldLength.TotalLength != 5; + return new FileStreamReadResult(isNull ? null : string.Empty, + fieldLength.TotalLength, isNull); + } + + // positive length + AssureBufferLength(fieldLength.ValueLength); + fileStream.ReadData(buffer, fieldLength.ValueLength); + return new FileStreamReadResult(Encoding.Unicode.GetString(buffer, 0, fieldLength.ValueLength), fieldLength.TotalLength, false); + } + + /// + /// Reads bytes from the file at the offset provided + /// + /// Offset into the file to read the bytes from + /// A byte array + public FileStreamReadResult ReadBytes(long offset) + { + LengthResult fieldLength = ReadLength(offset); + + if (fieldLength.ValueLength == 0) + { + // If the total length is 5 (5 bytes for length, 0 for value), then the byte array is 0x + // Otherwise, the byte array is null + bool isNull = fieldLength.TotalLength != 5; + return new FileStreamReadResult(isNull ? null : new byte[0], + fieldLength.TotalLength, isNull); + } + + // positive length + byte[] val = new byte[fieldLength.ValueLength]; + fileStream.ReadData(val, fieldLength.ValueLength); + return new FileStreamReadResult(val, fieldLength.TotalLength, false); + } + + /// + /// Reads the length of a field at the specified offset in the file + /// + /// Offset into the file to read the field length from + /// A LengthResult + internal LengthResult ReadLength(long offset) + { + // read in length information + int lengthValue; + int lengthLength = fileStream.ReadData(buffer, 1, offset); + if (buffer[0] != 0xFF) + { + // one byte is enough + lengthValue = Convert.ToInt32(buffer[0]); + } + else + { + // read in next 4 bytes + lengthLength += fileStream.ReadData(buffer, 4); + + // reconstruct the length + lengthValue = BitConverter.ToInt32(buffer, 0); + } + + return new LengthResult {LengthLength = lengthLength, ValueLength = lengthValue}; + } + + #endregion + + /// + /// Internal struct used for representing the length of a field from the file + /// + internal struct LengthResult + { + /// + /// How many bytes the length takes up + /// + public int LengthLength { get; set; } + + /// + /// How many bytes the value takes up + /// + public int ValueLength { get; set; } + + /// + /// + + /// + public int TotalLength + { + get { return LengthLength + ValueLength; } + } + } + + /// + /// Creates a new buffer that is of the specified length if the buffer is not already + /// at least as long as specified. + /// + /// The minimum buffer size + private void AssureBufferLength(int newBufferLength) + { + if (buffer.Length < newBufferLength) + { + buffer = new byte[newBufferLength]; + } + } + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + fileStream.Dispose(); + } + + disposed = true; + } + + ~ServiceBufferFileStreamReader() + { + Dispose(false); + } + + #endregion + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs new file mode 100644 index 00000000..d0a1c2a9 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs @@ -0,0 +1,749 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Writer for SSMS formatted file streams + /// + public class ServiceBufferFileStreamWriter : IFileStreamWriter + { + // Most of this code is based on code from the Microsoft.SqlServer.Management.UI.Grid, SSMS DataStorage + // $\Data Tools\SSMS_XPlat\sql\ssms\core\DataStorage\src\FileStreamWriter.cs + + #region Properties + + public const int DefaultBufferLength = 8192; + + private int MaxCharsToStore { get; set; } + private int MaxXmlCharsToStore { get; set; } + + private IFileStreamWrapper FileStream { get; set; } + private byte[] byteBuffer; + private readonly short[] shortBuffer; + private readonly int[] intBuffer; + private readonly long[] longBuffer; + private readonly char[] charBuffer; + private readonly double[] doubleBuffer; + private readonly float[] floatBuffer; + + #endregion + + /// + /// Constructs a new writer + /// + /// The file wrapper to use as the underlying file stream + /// Name of the file to write to + /// Maximum number of characters to store for long text fields + /// Maximum number of characters to store for XML fields + public ServiceBufferFileStreamWriter(IFileStreamWrapper fileWrapper, string fileName, int maxCharsToStore, int maxXmlCharsToStore) + { + // open file for reading/writing + FileStream = fileWrapper; + FileStream.Init(fileName, DefaultBufferLength, FileAccess.ReadWrite); + + // create internal buffer + byteBuffer = new byte[DefaultBufferLength]; + + // Create internal buffers for blockcopy of contents to byte array + // Note: We create them now to avoid the overhead of creating a new array for every write call + shortBuffer = new short[1]; + intBuffer = new int[1]; + longBuffer = new long[1]; + charBuffer = new char[1]; + doubleBuffer = new double[1]; + floatBuffer = new float[1]; + + // Store max chars to store + MaxCharsToStore = maxCharsToStore; + MaxXmlCharsToStore = maxXmlCharsToStore; + } + + #region IFileStreamWriter Implementation + + /// + /// Writes an entire row to the file stream + /// + /// A primed reader + /// Number of bytes used to write the row + public int WriteRow(StorageDataReader reader) + { + // Determine if we have any long fields + bool hasLongFields = reader.Columns.Any(column => column.IsLong); + + object[] values = new object[reader.Columns.Length]; + int rowBytes = 0; + if (!hasLongFields) + { + // get all record values in one shot if there are no extra long fields + reader.GetValues(values); + } + + // Loop over all the columns and write the values to the temp file + for (int i = 0; i < reader.Columns.Length; i++) + { + DbColumnWrapper ci = reader.Columns[i]; + if (hasLongFields) + { + if (reader.IsDBNull(i)) + { + // Need special case for DBNull because + // reader.GetValue doesn't return DBNull in case of SqlXml and CLR type + values[i] = DBNull.Value; + } + else + { + if (!ci.IsLong) + { + // not a long field + values[i] = reader.GetValue(i); + } + else + { + // this is a long field + if (ci.IsBytes) + { + values[i] = reader.GetBytesWithMaxCapacity(i, MaxCharsToStore); + } + else if (ci.IsChars) + { + Debug.Assert(MaxCharsToStore > 0); + values[i] = reader.GetCharsWithMaxCapacity(i, + ci.IsXml ? MaxXmlCharsToStore : MaxCharsToStore); + } + else if (ci.IsXml) + { + Debug.Assert(MaxXmlCharsToStore > 0); + values[i] = reader.GetXmlWithMaxCapacity(i, MaxXmlCharsToStore); + } + else + { + // we should never get here + Debug.Assert(false); + } + } + } + } + + Type tVal = values[i].GetType(); // get true type of the object + + if (tVal == typeof(DBNull)) + { + rowBytes += WriteNull(); + } + else + { + if (ci.IsSqlVariant) + { + // serialize type information as a string before the value + string val = tVal.ToString(); + rowBytes += WriteString(val); + } + + if (tVal == typeof(string)) + { + // String - most frequently used data type + string val = (string)values[i]; + rowBytes += WriteString(val); + } + else if (tVal == typeof(SqlString)) + { + // SqlString + SqlString val = (SqlString)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteString(val.Value); + } + } + else if (tVal == typeof(short)) + { + // Int16 + short val = (short)values[i]; + rowBytes += WriteInt16(val); + } + else if (tVal == typeof(SqlInt16)) + { + // SqlInt16 + SqlInt16 val = (SqlInt16)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteInt16(val.Value); + } + } + else if (tVal == typeof(int)) + { + // Int32 + int val = (int)values[i]; + rowBytes += WriteInt32(val); + } + else if (tVal == typeof(SqlInt32)) + { + // SqlInt32 + SqlInt32 val = (SqlInt32)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteInt32(val.Value); + } + } + else if (tVal == typeof(long)) + { + // Int64 + long val = (long)values[i]; + rowBytes += WriteInt64(val); + } + else if (tVal == typeof(SqlInt64)) + { + // SqlInt64 + SqlInt64 val = (SqlInt64)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteInt64(val.Value); + } + } + else if (tVal == typeof(byte)) + { + // Byte + byte val = (byte)values[i]; + rowBytes += WriteByte(val); + } + else if (tVal == typeof(SqlByte)) + { + // SqlByte + SqlByte val = (SqlByte)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteByte(val.Value); + } + } + else if (tVal == typeof(char)) + { + // Char + char val = (char)values[i]; + rowBytes += WriteChar(val); + } + else if (tVal == typeof(bool)) + { + // Boolean + bool val = (bool)values[i]; + rowBytes += WriteBoolean(val); + } + else if (tVal == typeof(SqlBoolean)) + { + // SqlBoolean + SqlBoolean val = (SqlBoolean)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteBoolean(val.Value); + } + } + else if (tVal == typeof(double)) + { + // Double + double val = (double)values[i]; + rowBytes += WriteDouble(val); + } + else if (tVal == typeof(SqlDouble)) + { + // SqlDouble + SqlDouble val = (SqlDouble)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteDouble(val.Value); + } + } + else if (tVal == typeof(SqlSingle)) + { + // SqlSingle + SqlSingle val = (SqlSingle)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteSingle(val.Value); + } + } + else if (tVal == typeof(decimal)) + { + // Decimal + decimal val = (decimal)values[i]; + rowBytes += WriteDecimal(val); + } + else if (tVal == typeof(SqlDecimal)) + { + // SqlDecimal + SqlDecimal val = (SqlDecimal)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteSqlDecimal(val); + } + } + else if (tVal == typeof(DateTime)) + { + // DateTime + DateTime val = (DateTime)values[i]; + rowBytes += WriteDateTime(val); + } + else if (tVal == typeof(DateTimeOffset)) + { + // DateTimeOffset + DateTimeOffset val = (DateTimeOffset)values[i]; + rowBytes += WriteDateTimeOffset(val); + } + else if (tVal == typeof(SqlDateTime)) + { + // SqlDateTime + SqlDateTime val = (SqlDateTime)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteDateTime(val.Value); + } + } + else if (tVal == typeof(TimeSpan)) + { + // TimeSpan + TimeSpan val = (TimeSpan)values[i]; + rowBytes += WriteTimeSpan(val); + } + else if (tVal == typeof(byte[])) + { + // Bytes + byte[] val = (byte[])values[i]; + rowBytes += WriteBytes(val, val.Length); + } + else if (tVal == typeof(SqlBytes)) + { + // SqlBytes + SqlBytes val = (SqlBytes)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteBytes(val.Value, val.Value.Length); + } + } + else if (tVal == typeof(SqlBinary)) + { + // SqlBinary + SqlBinary val = (SqlBinary)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteBytes(val.Value, val.Value.Length); + } + } + else if (tVal == typeof(SqlGuid)) + { + // SqlGuid + SqlGuid val = (SqlGuid)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + byte[] bytesVal = val.ToByteArray(); + rowBytes += WriteBytes(bytesVal, bytesVal.Length); + } + } + else if (tVal == typeof(SqlMoney)) + { + // SqlMoney + SqlMoney val = (SqlMoney)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteDecimal(val.Value); + } + } + else + { + // treat everything else as string + string val = values[i].ToString(); + rowBytes += WriteString(val); + } + } + } + + // Flush the buffer after every row + FlushBuffer(); + return rowBytes; + } + + /// + /// Writes null to the file as one 0x00 byte + /// + /// Number of bytes used to store the null + public int WriteNull() + { + byteBuffer[0] = 0x00; + return FileStream.WriteData(byteBuffer, 1); + } + + /// + /// Writes a short to the file + /// + /// Number of bytes used to store the short + public int WriteInt16(short val) + { + byteBuffer[0] = 0x02; // length + shortBuffer[0] = val; + Buffer.BlockCopy(shortBuffer, 0, byteBuffer, 1, 2); + return FileStream.WriteData(byteBuffer, 3); + } + + /// + /// Writes a int to the file + /// + /// Number of bytes used to store the int + public int WriteInt32(int val) + { + byteBuffer[0] = 0x04; // length + intBuffer[0] = val; + Buffer.BlockCopy(intBuffer, 0, byteBuffer, 1, 4); + return FileStream.WriteData(byteBuffer, 5); + } + + /// + /// Writes a long to the file + /// + /// Number of bytes used to store the long + public int WriteInt64(long val) + { + byteBuffer[0] = 0x08; // length + longBuffer[0] = val; + Buffer.BlockCopy(longBuffer, 0, byteBuffer, 1, 8); + return FileStream.WriteData(byteBuffer, 9); + } + + /// + /// Writes a char to the file + /// + /// Number of bytes used to store the char + public int WriteChar(char val) + { + byteBuffer[0] = 0x02; // length + charBuffer[0] = val; + Buffer.BlockCopy(charBuffer, 0, byteBuffer, 1, 2); + return FileStream.WriteData(byteBuffer, 3); + } + + /// + /// Writes a bool to the file + /// + /// Number of bytes used to store the bool + public int WriteBoolean(bool val) + { + byteBuffer[0] = 0x01; // length + byteBuffer[1] = (byte) (val ? 0x01 : 0x00); + return FileStream.WriteData(byteBuffer, 2); + } + + /// + /// Writes a byte to the file + /// + /// Number of bytes used to store the byte + public int WriteByte(byte val) + { + byteBuffer[0] = 0x01; // length + byteBuffer[1] = val; + return FileStream.WriteData(byteBuffer, 2); + } + + /// + /// Writes a float to the file + /// + /// Number of bytes used to store the float + public int WriteSingle(float val) + { + byteBuffer[0] = 0x04; // length + floatBuffer[0] = val; + Buffer.BlockCopy(floatBuffer, 0, byteBuffer, 1, 4); + return FileStream.WriteData(byteBuffer, 5); + } + + /// + /// Writes a double to the file + /// + /// Number of bytes used to store the double + public int WriteDouble(double val) + { + byteBuffer[0] = 0x08; // length + doubleBuffer[0] = val; + Buffer.BlockCopy(doubleBuffer, 0, byteBuffer, 1, 8); + return FileStream.WriteData(byteBuffer, 9); + } + + /// + /// Writes a SqlDecimal to the file + /// + /// Number of bytes used to store the SqlDecimal + public int WriteSqlDecimal(SqlDecimal val) + { + int[] arrInt32 = val.Data; + int iLen = 3 + (arrInt32.Length * 4); + int iTotalLen = WriteLength(iLen); // length + + // precision + byteBuffer[0] = val.Precision; + + // scale + byteBuffer[1] = val.Scale; + + // positive + byteBuffer[2] = (byte)(val.IsPositive ? 0x01 : 0x00); + + // data value + Buffer.BlockCopy(arrInt32, 0, byteBuffer, 3, iLen - 3); + iTotalLen += FileStream.WriteData(byteBuffer, iLen); + return iTotalLen; // len+data + } + + /// + /// Writes a decimal to the file + /// + /// Number of bytes used to store the decimal + public int WriteDecimal(decimal val) + { + int[] arrInt32 = decimal.GetBits(val); + + int iLen = arrInt32.Length * 4; + int iTotalLen = WriteLength(iLen); // length + + Buffer.BlockCopy(arrInt32, 0, byteBuffer, 0, iLen); + iTotalLen += FileStream.WriteData(byteBuffer, iLen); + + return iTotalLen; // len+data + } + + /// + /// Writes a DateTime to the file + /// + /// Number of bytes used to store the DateTime + public int WriteDateTime(DateTime dtVal) + { + return WriteInt64(dtVal.Ticks); + } + + /// + /// Writes a DateTimeOffset to the file + /// + /// Number of bytes used to store the DateTimeOffset + public int WriteDateTimeOffset(DateTimeOffset dtoVal) + { + // DateTimeOffset gets written as a DateTime + TimeOffset + // both represented as 'Ticks' written as Int64's + return WriteInt64(dtoVal.Ticks) + WriteInt64(dtoVal.Offset.Ticks); + } + + /// + /// Writes a TimeSpan to the file + /// + /// Number of bytes used to store the TimeSpan + public int WriteTimeSpan(TimeSpan timeSpan) + { + return WriteInt64(timeSpan.Ticks); + } + + /// + /// Writes a string to the file + /// + /// Number of bytes used to store the string + public int WriteString(string sVal) + { + if (sVal == null) + { + throw new ArgumentNullException(nameof(sVal), "String to store must be non-null."); + } + + int iTotalLen; + if (0 == sVal.Length) // special case of 0 length string + { + const int iLen = 5; + + AssureBufferLength(iLen); + byteBuffer[0] = 0xFF; + byteBuffer[1] = 0x00; + byteBuffer[2] = 0x00; + byteBuffer[3] = 0x00; + byteBuffer[4] = 0x00; + + iTotalLen = FileStream.WriteData(byteBuffer, 5); + } + else + { + // Convert to a unicode byte array + byte[] bytes = Encoding.Unicode.GetBytes(sVal); + + // convert char array into byte array and write it out + iTotalLen = WriteLength(bytes.Length); + iTotalLen += FileStream.WriteData(bytes, bytes.Length); + } + return iTotalLen; // len+data + } + + /// + /// Writes a byte[] to the file + /// + /// Number of bytes used to store the byte[] + public int WriteBytes(byte[] bytesVal, int iLen) + { + if (bytesVal == null) + { + throw new ArgumentNullException(nameof(bytesVal), "Byte array to store must be non-null."); + } + + int iTotalLen; + if (0 == iLen) // special case of 0 length byte array "0x" + { + iLen = 5; + + AssureBufferLength(iLen); + byteBuffer[0] = 0xFF; + byteBuffer[1] = 0x00; + byteBuffer[2] = 0x00; + byteBuffer[3] = 0x00; + byteBuffer[4] = 0x00; + + iTotalLen = FileStream.WriteData(byteBuffer, iLen); + } + else + { + iTotalLen = WriteLength(iLen); + iTotalLen += FileStream.WriteData(bytesVal, iLen); + } + return iTotalLen; // len+data + } + + /// + /// Writes the length of the field using the appropriate number of bytes (ie, 1 if the + /// length is <255, 5 if the length is >=255) + /// + /// Number of bytes used to store the length + private int WriteLength(int iLen) + { + if (iLen < 0xFF) + { + // fits in one byte of memory only need to write one byte + int iTmp = iLen & 0x000000FF; + + byteBuffer[0] = Convert.ToByte(iTmp); + return FileStream.WriteData(byteBuffer, 1); + } + // The length won't fit in 1 byte, so we need to use 1 byte to signify that the length + // is a full 4 bytes. + byteBuffer[0] = 0xFF; + + // convert int32 into array of bytes + intBuffer[0] = iLen; + Buffer.BlockCopy(intBuffer, 0, byteBuffer, 1, 4); + return FileStream.WriteData(byteBuffer, 5); + } + + /// + /// Flushes the internal buffer to the file stream + /// + public void FlushBuffer() + { + FileStream.Flush(); + } + + #endregion + + private void AssureBufferLength(int newBufferLength) + { + if (newBufferLength > byteBuffer.Length) + { + byteBuffer = new byte[byteBuffer.Length]; + } + } + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + FileStream.Flush(); + FileStream.Dispose(); + } + + disposed = true; + } + + ~ServiceBufferFileStreamWriter() + { + Dispose(false); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs new file mode 100644 index 00000000..f63046b1 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs @@ -0,0 +1,356 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.Common; +using System.Data.SqlClient; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using System.Xml; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Wrapper around a DbData reader to perform some special operations more simply + /// + public class StorageDataReader + { + // This code is based on code from Microsoft.SqlServer.Management.UI.Grid, SSMS DataStorage, + // StorageDataReader + // $\Data Tools\SSMS_XPlat\sql\ssms\core\DataStorage\src\StorageDataReader.cs + + #region Member Variables + + /// + /// If the DbDataReader is a SqlDataReader, it will be set here + /// + private readonly SqlDataReader sqlDataReader; + + /// + /// Whether or not the data reader supports SqlXml types + /// + private readonly bool supportSqlXml; + + #endregion + + /// + /// Constructs a new wrapper around the provided reader + /// + /// The reader to wrap around + public StorageDataReader(DbDataReader reader) + { + // Sanity check to make sure there is a data reader + Validate.IsNotNull(nameof(reader), reader); + + // Attempt to use this reader as a SqlDataReader + sqlDataReader = reader as SqlDataReader; + supportSqlXml = sqlDataReader != null; + DbDataReader = reader; + + // Read the columns into a set of wrappers + Columns = DbDataReader.GetColumnSchema().Select(column => new DbColumnWrapper(column)).ToArray(); + } + + #region Properties + + /// + /// All the columns that this reader currently contains + /// + public DbColumnWrapper[] Columns { get; private set; } + + /// + /// The that will be read from + /// + public DbDataReader DbDataReader { get; private set; } + + #endregion + + #region DbDataReader Methods + + /// + /// Pass-through to DbDataReader.ReadAsync() + /// + /// The cancellation token to use for cancelling a query + /// + public Task ReadAsync(CancellationToken cancellationToken) + { + return DbDataReader.ReadAsync(cancellationToken); + } + + /// + /// Retrieves a value + /// + /// Column ordinal + /// The value of the given column + public object GetValue(int i) + { + return sqlDataReader == null ? DbDataReader.GetValue(i) : sqlDataReader.GetValue(i); + } + + /// + /// Stores all values of the current row into the provided object array + /// + /// Where to store the values from this row + public void GetValues(object[] values) + { + if (sqlDataReader == null) + { + DbDataReader.GetValues(values); + } + else + { + sqlDataReader.GetValues(values); + } + } + + /// + /// Whether or not the cell of the given column at the current row is a DBNull + /// + /// Column ordinal + /// True if the cell is DBNull, false otherwise + public bool IsDBNull(int i) + { + return DbDataReader.IsDBNull(i); + } + + #endregion + + #region Public Methods + + /// + /// Retrieves bytes with a maximum number of bytes to return + /// + /// Column ordinal + /// Number of bytes to return at maximum + /// Byte array + public byte[] GetBytesWithMaxCapacity(int iCol, int maxNumBytesToReturn) + { + if (maxNumBytesToReturn <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxNumBytesToReturn), "Maximum number of bytes to return must be greater than zero."); + } + + //first, ask provider how much data it has and calculate the final # of bytes + //NOTE: -1 means that it doesn't know how much data it has + long neededLength; + long origLength = neededLength = GetBytes(iCol, 0, null, 0, 0); + if (neededLength == -1 || neededLength > maxNumBytesToReturn) + { + neededLength = maxNumBytesToReturn; + } + + //get the data up to the maxNumBytesToReturn + byte[] bytesBuffer = new byte[neededLength]; + GetBytes(iCol, 0, bytesBuffer, 0, (int)neededLength); + + //see if server sent back more data than we should return + if (origLength == -1 || origLength > neededLength) + { + //pump the rest of data from the reader and discard it right away + long dataIndex = neededLength; + const int tmpBufSize = 100000; + byte[] tmpBuf = new byte[tmpBufSize]; + while (GetBytes(iCol, dataIndex, tmpBuf, 0, tmpBufSize) == tmpBufSize) + { + dataIndex += tmpBufSize; + } + } + + return bytesBuffer; + } + + /// + /// Retrieves characters with a maximum number of charss to return + /// + /// Column ordinal + /// Number of chars to return at maximum + /// String + public string GetCharsWithMaxCapacity(int iCol, int maxCharsToReturn) + { + if (maxCharsToReturn <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxCharsToReturn), "Maximum number of chars to return must be greater than zero"); + } + + //first, ask provider how much data it has and calculate the final # of chars + //NOTE: -1 means that it doesn't know how much data it has + long neededLength; + long origLength = neededLength = GetChars(iCol, 0, null, 0, 0); + if (neededLength == -1 || neededLength > maxCharsToReturn) + { + neededLength = maxCharsToReturn; + } + Debug.Assert(neededLength < int.MaxValue); + + //get the data up to maxCharsToReturn + char[] buffer = new char[neededLength]; + if (neededLength > 0) + { + GetChars(iCol, 0, buffer, 0, (int)neededLength); + } + + //see if server sent back more data than we should return + if (origLength == -1 || origLength > neededLength) + { + //pump the rest of data from the reader and discard it right away + long dataIndex = neededLength; + const int tmpBufSize = 100000; + char[] tmpBuf = new char[tmpBufSize]; + while (GetChars(iCol, dataIndex, tmpBuf, 0, tmpBufSize) == tmpBufSize) + { + dataIndex += tmpBufSize; + } + } + string res = new string(buffer); + return res; + } + + /// + /// Retrieves xml with a maximum number of bytes to return + /// + /// Column ordinal + /// Number of chars to return at maximum + /// String + public string GetXmlWithMaxCapacity(int iCol, int maxCharsToReturn) + { + if (supportSqlXml) + { + SqlXml sm = GetSqlXml(iCol); + if (sm == null) + { + return null; + } + + //this code is mostly copied from SqlClient implementation of returning value for XML data type + StringWriterWithMaxCapacity sw = new StringWriterWithMaxCapacity(null, maxCharsToReturn); + XmlWriterSettings writerSettings = new XmlWriterSettings + { + CloseOutput = false, + ConformanceLevel = ConformanceLevel.Fragment + }; + // don't close the memory stream + XmlWriter ww = XmlWriter.Create(sw, writerSettings); + + XmlReader reader = sm.CreateReader(); + reader.Read(); + + while (!reader.EOF) + { + ww.WriteNode(reader, true); + } + ww.Flush(); + return sw.ToString(); + } + + object o = GetValue(iCol); + return o?.ToString(); + } + + #endregion + + #region Private Helpers + + private long GetBytes(int i, long dataIndex, byte[] buffer, int bufferIndex, int length) + { + return DbDataReader.GetBytes(i, dataIndex, buffer, bufferIndex, length); + } + + private long GetChars(int i, long dataIndex, char[] buffer, int bufferIndex, int length) + { + return DbDataReader.GetChars(i, dataIndex, buffer, bufferIndex, length); + } + + private SqlXml GetSqlXml(int i) + { + if (sqlDataReader == null) + { + // We need a Sql data reader in order to retrieve sql xml + throw new InvalidOperationException("Cannot retrieve SqlXml without a SqlDataReader"); + } + + return sqlDataReader.GetSqlXml(i); + } + + #endregion + + /// + /// Internal class for writing strings with a maximum capacity + /// + /// + /// This code is take almost verbatim from Microsoft.SqlServer.Management.UI.Grid, SSMS + /// DataStorage, StorageDataReader class. + /// + private class StringWriterWithMaxCapacity : StringWriter + { + private bool stopWriting; + + private int CurrentLength + { + get { return GetStringBuilder().Length; } + } + + public StringWriterWithMaxCapacity(IFormatProvider formatProvider, int capacity) : base(formatProvider) + { + MaximumCapacity = capacity; + } + + private int MaximumCapacity { get; set; } + + public override void Write(char value) + { + if (stopWriting) { return; } + + if (CurrentLength < MaximumCapacity) + { + base.Write(value); + } + else + { + stopWriting = true; + } + } + + public override void Write(char[] buffer, int index, int count) + { + if (stopWriting) { return; } + + int curLen = CurrentLength; + if (curLen + (count - index) > MaximumCapacity) + { + stopWriting = true; + + count = MaximumCapacity - curLen + index; + if (count < 0) + { + count = 0; + } + } + base.Write(buffer, index, count); + } + + public override void Write(string value) + { + if (stopWriting) { return; } + + int curLen = CurrentLength; + if (value.Length + curLen > MaximumCapacity) + { + stopWriting = true; + base.Write(value.Substring(0, MaximumCapacity - curLen)); + } + else + { + base.Write(value); + } + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs new file mode 100644 index 00000000..2da2a15d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -0,0 +1,290 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Data.Common; +using System.Data.SqlClient; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlServer.Management.SqlParser.Parser; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Microsoft.SqlTools.ServiceLayer.SqlContext; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + /// + /// Internal representation of an active query + /// + public class Query : IDisposable + { + /// + /// "Error" code produced by SQL Server when the database context (name) for a connection changes. + /// + private const int DatabaseContextChangeErrorNumber = 5701; + + #region Member Variables + + /// + /// Cancellation token source, used for cancelling async db actions + /// + private readonly CancellationTokenSource cancellationSource; + + /// + /// For IDisposable implementation, whether or not this object has been disposed + /// + private bool disposed; + + /// + /// The connection info associated with the file editor owner URI, used to create a new + /// connection upon execution of the query + /// + private readonly ConnectionInfo editorConnection; + + /// + /// Whether or not the execute method has been called for this query + /// + private bool hasExecuteBeenCalled; + + /// + /// The factory to use for outputting the results of this query + /// + private readonly IFileStreamFactory outputFileFactory; + + #endregion + + /// + /// Constructor for a query + /// + /// The text of the query to execute + /// The information of the connection to use to execute the query + /// Settings for how to execute the query, from the user + /// Factory for creating output files + public Query(string queryText, ConnectionInfo connection, QueryExecutionSettings settings, IFileStreamFactory outputFactory) + { + // Sanity check for input + if (string.IsNullOrEmpty(queryText)) + { + throw new ArgumentNullException(nameof(queryText), "Query text cannot be null"); + } + if (connection == null) + { + throw new ArgumentNullException(nameof(connection), "Connection cannot be null"); + } + if (settings == null) + { + throw new ArgumentNullException(nameof(settings), "Settings cannot be null"); + } + if (outputFactory == null) + { + throw new ArgumentNullException(nameof(outputFactory), "Output file factory cannot be null"); + } + + // Initialize the internal state + QueryText = queryText; + editorConnection = connection; + cancellationSource = new CancellationTokenSource(); + outputFileFactory = outputFactory; + + // Process the query into batches + ParseResult parseResult = Parser.Parse(queryText, new ParseOptions + { + BatchSeparator = settings.BatchSeparator + }); + // NOTE: We only want to process batches that have statements (ie, ignore comments and empty lines) + Batches = parseResult.Script.Batches.Where(b => b.Statements.Count > 0) + .Select(b => new Batch(b.Sql, b.StartLocation.LineNumber, outputFileFactory)).ToArray(); + } + + #region Properties + + /// + /// The batches underneath this query + /// + internal Batch[] Batches { get; set; } + + /// + /// The summaries of the batches underneath this query + /// + public BatchSummary[] BatchSummaries + { + get + { + if (!HasExecuted) + { + throw new InvalidOperationException("Query has not been executed."); + } + + return Batches.Select((batch, index) => new BatchSummary + { + Id = index, + HasError = batch.HasError, + Messages = batch.ResultMessages.ToArray(), + ResultSetSummaries = batch.ResultSummaries + }).ToArray(); + } + } + + /// + /// Whether or not the query has completed executed, regardless of success or failure + /// + /// + /// Don't touch the setter unless you're doing unit tests! + /// + public bool HasExecuted + { + get { return Batches.Length == 0 ? hasExecuteBeenCalled : Batches.All(b => b.HasExecuted); } + internal set + { + hasExecuteBeenCalled = value; + foreach (var batch in Batches) + { + batch.HasExecuted = value; + } + } + } + + /// + /// The text of the query to execute + /// + public string QueryText { get; set; } + + #endregion + + #region Public Methods + + /// + /// Cancels the query by issuing the cancellation token + /// + public void Cancel() + { + // Make sure that the query hasn't completed execution + if (HasExecuted) + { + throw new InvalidOperationException("The query has already completed, it cannot be cancelled."); + } + + // Issue the cancellation token for the query + cancellationSource.Cancel(); + } + + /// + /// Executes this query asynchronously and collects all result sets + /// + public async Task Execute() + { + // Mark that we've internally executed + hasExecuteBeenCalled = true; + + // Don't actually execute if there aren't any batches to execute + if (Batches.Length == 0) + { + return; + } + + // Open up a connection for querying the database + string connectionString = ConnectionService.BuildConnectionString(editorConnection.ConnectionDetails); + // TODO: Don't create a new connection every time, see TFS #834978 + using (DbConnection conn = editorConnection.Factory.CreateSqlConnection(connectionString)) + { + await conn.OpenAsync(); + + SqlConnection sqlConn = conn as SqlConnection; + if (sqlConn != null) + { + // Subscribe to database informational messages + sqlConn.InfoMessage += OnInfoMessage; + } + + // We need these to execute synchronously, otherwise the user will be very unhappy + foreach (Batch b in Batches) + { + await b.Execute(conn, cancellationSource.Token); + } + + // TODO: Close connection after eliminating using statement for above TODO + } + } + + /// + /// Handler for database messages during query execution + /// + private void OnInfoMessage(object sender, SqlInfoMessageEventArgs args) + { + SqlConnection conn = sender as SqlConnection; + if (conn == null) + { + throw new InvalidOperationException("Sender for OnInfoMessage event must be a SqlConnection"); + } + + foreach(SqlError error in args.Errors) + { + // Did the database context change (error code 5701)? + if (error.Number == DatabaseContextChangeErrorNumber) + { + ConnectionService.Instance.ChangeConnectionDatabaseContext(editorConnection.OwnerUri, conn.Database); + } + } + } + + /// + /// Retrieves a subset of the result sets + /// + /// The index for selecting the batch item + /// The index for selecting the result set + /// The starting row of the results + /// How many rows to retrieve + /// A subset of results + public Task GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount) + { + // Sanity check that the results are available + if (!HasExecuted) + { + throw new InvalidOperationException("The query has not completed, yet."); + } + + // Sanity check to make sure that the batch is within bounds + if (batchIndex < 0 || batchIndex >= Batches.Length) + { + throw new ArgumentOutOfRangeException(nameof(batchIndex), "Result set index cannot be less than 0" + + "or greater than the number of result sets"); + } + + return Batches[batchIndex].GetSubset(resultSetIndex, startRow, rowCount); + } + + #endregion + + #region IDisposable Implementation + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + cancellationSource.Dispose(); + foreach (Batch b in Batches) + { + b.Dispose(); + } + } + + disposed = true; + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs new file mode 100644 index 00000000..97d89fc9 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -0,0 +1,372 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Concurrent; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Hosting; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Workspace; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + /// + /// Service for executing queries + /// + public sealed class QueryExecutionService : IDisposable + { + #region Singleton Instance Implementation + + private static readonly Lazy instance = new Lazy(() => new QueryExecutionService()); + + /// + /// Singleton instance of the query execution service + /// + public static QueryExecutionService Instance + { + get { return instance.Value; } + } + + private QueryExecutionService() + { + ConnectionService = ConnectionService.Instance; + } + + internal QueryExecutionService(ConnectionService connService) + { + ConnectionService = connService; + } + + #endregion + + #region Properties + + /// + /// File factory to be used to create a buffer file for results. + /// + /// + /// Made internal here to allow for overriding in unit testing + /// + internal IFileStreamFactory BufferFileStreamFactory; + + /// + /// File factory to be used to create a buffer file for results + /// + private IFileStreamFactory BufferFileFactory + { + get { return BufferFileStreamFactory ?? (BufferFileStreamFactory = new ServiceBufferFileStreamFactory()); } + } + + /// + /// The collection of active queries + /// + internal ConcurrentDictionary ActiveQueries + { + get { return queries.Value; } + } + + /// + /// Instance of the connection service, used to get the connection info for a given owner URI + /// + private ConnectionService ConnectionService { get; set; } + + /// + /// Internal storage of active queries, lazily constructed as a threadsafe dictionary + /// + private readonly Lazy> queries = + new Lazy>(() => new ConcurrentDictionary()); + + private SqlToolsSettings Settings { get { return WorkspaceService.Instance.CurrentSettings; } } + + #endregion + + /// + /// Initializes the service with the service host, registers request handlers and shutdown + /// event handler. + /// + /// The service host instance to register with + public void InitializeService(ServiceHost serviceHost) + { + // Register handlers for requests + serviceHost.SetRequestHandler(QueryExecuteRequest.Type, HandleExecuteRequest); + serviceHost.SetRequestHandler(QueryExecuteSubsetRequest.Type, HandleResultSubsetRequest); + serviceHost.SetRequestHandler(QueryDisposeRequest.Type, HandleDisposeRequest); + serviceHost.SetRequestHandler(QueryCancelRequest.Type, HandleCancelRequest); + + // Register handler for shutdown event + serviceHost.RegisterShutdownTask((shutdownParams, requestContext) => + { + Dispose(); + return Task.FromResult(0); + }); + + // Register a handler for when the configuration changes + WorkspaceService.Instance.RegisterConfigChangeCallback((oldSettings, newSettings, eventContext) => + { + Settings.QueryExecutionSettings.Update(newSettings.QueryExecutionSettings); + return Task.FromResult(0); + }); + } + + #region Request Handlers + + public async Task HandleExecuteRequest(QueryExecuteParams executeParams, + RequestContext requestContext) + { + try + { + // Get a query new active query + Query newQuery = await CreateAndActivateNewQuery(executeParams, requestContext); + + // Execute the query + await ExecuteAndCompleteQuery(executeParams, requestContext, newQuery); + } + catch (Exception e) + { + // Dump any unexpected exceptions as errors + await requestContext.SendError(e.Message); + } + } + + public async Task HandleResultSubsetRequest(QueryExecuteSubsetParams subsetParams, + RequestContext requestContext) + { + try + { + // Attempt to load the query + Query query; + if (!ActiveQueries.TryGetValue(subsetParams.OwnerUri, out query)) + { + await requestContext.SendResult(new QueryExecuteSubsetResult + { + Message = "The requested query does not exist." + }); + return; + } + + // Retrieve the requested subset and return it + var result = new QueryExecuteSubsetResult + { + Message = null, + ResultSubset = await query.GetSubset(subsetParams.BatchIndex, + subsetParams.ResultSetIndex, subsetParams.RowsStartIndex, subsetParams.RowsCount) + }; + await requestContext.SendResult(result); + } + catch (InvalidOperationException ioe) + { + // Return the error as a result + await requestContext.SendResult(new QueryExecuteSubsetResult + { + Message = ioe.Message + }); + } + catch (ArgumentOutOfRangeException aoore) + { + // Return the error as a result + await requestContext.SendResult(new QueryExecuteSubsetResult + { + Message = aoore.Message + }); + } + catch (Exception e) + { + // This was unexpected, so send back as error + await requestContext.SendError(e.Message); + } + } + + public async Task HandleDisposeRequest(QueryDisposeParams disposeParams, + RequestContext requestContext) + { + try + { + // Attempt to remove the query for the owner uri + Query result; + if (!ActiveQueries.TryRemove(disposeParams.OwnerUri, out result)) + { + await requestContext.SendResult(new QueryDisposeResult + { + Messages = "Failed to dispose query, ID not found." + }); + return; + } + + // Success + await requestContext.SendResult(new QueryDisposeResult + { + Messages = null + }); + } + catch (Exception e) + { + await requestContext.SendError(e.Message); + } + } + + public async Task HandleCancelRequest(QueryCancelParams cancelParams, + RequestContext requestContext) + { + try + { + // Attempt to find the query for the owner uri + Query result; + if (!ActiveQueries.TryGetValue(cancelParams.OwnerUri, out result)) + { + await requestContext.SendResult(new QueryCancelResult + { + Messages = "Failed to cancel query, ID not found." + }); + return; + } + + // Cancel the query + result.Cancel(); + + // Attempt to dispose the query + if (!ActiveQueries.TryRemove(cancelParams.OwnerUri, out result)) + { + // It really shouldn't be possible to get to this scenario, but we'll cover it anyhow + await requestContext.SendResult(new QueryCancelResult + { + Messages = "Query successfully cancelled, failed to dispose query. ID not found." + }); + return; + } + + await requestContext.SendResult(new QueryCancelResult()); + } + catch (InvalidOperationException e) + { + // If this exception occurred, we most likely were trying to cancel a completed query + await requestContext.SendResult(new QueryCancelResult + { + Messages = e.Message + }); + } + catch (Exception e) + { + await requestContext.SendError(e.Message); + } + } + + #endregion + + #region Private Helpers + + private async Task CreateAndActivateNewQuery(QueryExecuteParams executeParams, RequestContext requestContext) + { + try + { + // Attempt to get the connection for the editor + ConnectionInfo connectionInfo; + if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connectionInfo)) + { + await requestContext.SendResult(new QueryExecuteResult + { + Messages = "This editor is not connected to a database." + }); + return null; + } + + // Attempt to clean out any old query on the owner URI + Query oldQuery; + if (ActiveQueries.TryGetValue(executeParams.OwnerUri, out oldQuery) && oldQuery.HasExecuted) + { + ActiveQueries.TryRemove(executeParams.OwnerUri, out oldQuery); + } + + // Retrieve the current settings for executing the query with + QueryExecutionSettings settings = WorkspaceService.Instance.CurrentSettings.QueryExecutionSettings; + + // If we can't add the query now, it's assumed the query is in progress + Query newQuery = new Query(executeParams.QueryText, connectionInfo, settings, BufferFileFactory); + if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery)) + { + await requestContext.SendResult(new QueryExecuteResult + { + Messages = "A query is already in progress for this editor session." + + "Please cancel this query or wait for its completion." + }); + return null; + } + + return newQuery; + } + catch (ArgumentNullException ane) + { + await requestContext.SendResult(new QueryExecuteResult { Messages = ane.Message }); + return null; + } + // Any other exceptions will fall through here and be collected at the end + } + + private async Task ExecuteAndCompleteQuery(QueryExecuteParams executeParams, RequestContext requestContext, Query query) + { + // Skip processing if the query is null + if (query == null) + { + return; + } + + // Launch the query and respond with successfully launching it + Task executeTask = query.Execute(); + await requestContext.SendResult(new QueryExecuteResult + { + Messages = null + }); + + // Wait for query execution and then send back the results + await Task.WhenAll(executeTask); + QueryExecuteCompleteParams eventParams = new QueryExecuteCompleteParams + { + OwnerUri = executeParams.OwnerUri, + BatchSummaries = query.BatchSummaries + }; + await requestContext.SendEvent(QueryExecuteCompleteEvent.Type, eventParams); + } + + #endregion + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + foreach (var query in ActiveQueries) + { + query.Value.Dispose(); + } + } + + disposed = true; + } + + ~QueryExecutionService() + { + Dispose(false); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs new file mode 100644 index 00000000..84e18c99 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -0,0 +1,218 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + public class ResultSet : IDisposable + { + #region Constants + + private const int DefaultMaxCharsToStore = 65535; // 64 KB - QE default + + // xml is a special case so number of chars to store is usually greater than for other long types + private const int DefaultMaxXmlCharsToStore = 2097152; // 2 MB - QE default + + #endregion + + #region Member Variables + + /// + /// For IDisposable pattern, whether or not object has been disposed + /// + private bool disposed; + + /// + /// The factory to use to get reading/writing handlers + /// + private readonly IFileStreamFactory fileStreamFactory; + + /// + /// File stream reader that will be reused to make rapid-fire retrieval of result subsets + /// quick and low perf impact. + /// + private IFileStreamReader fileStreamReader; + + /// + /// Whether or not the result set has been read in from the database + /// + private bool hasBeenRead; + + /// + /// The name of the temporary file we're using to output these results in + /// + private readonly string outputFileName; + + #endregion + + /// + /// Creates a new result set and initializes its state + /// + /// The reader from executing a query + /// Factory for creating a reader/writer + public ResultSet(DbDataReader reader, IFileStreamFactory factory) + { + // Sanity check to make sure we got a reader + if (reader == null) + { + throw new ArgumentNullException(nameof(reader), "Reader cannot be null"); + } + DataReader = new StorageDataReader(reader); + + // Initialize the storage + outputFileName = factory.CreateFile(); + FileOffsets = new LongList(); + + // Store the factory + fileStreamFactory = factory; + hasBeenRead = false; + } + + #region Properties + + /// + /// The columns for this result set + /// + public DbColumnWrapper[] Columns { get; private set; } + + /// + /// The reader to use for this resultset + /// + private StorageDataReader DataReader { get; set; } + + /// + /// A list of offsets into the buffer file that correspond to where rows start + /// + private LongList FileOffsets { get; set; } + + /// + /// Maximum number of characters to store for a field + /// + public int MaxCharsToStore { get { return DefaultMaxCharsToStore; } } + + /// + /// Maximum number of characters to store for an XML field + /// + public int MaxXmlCharsToStore { get { return DefaultMaxXmlCharsToStore; } } + + /// + /// The number of rows for this result set + /// + public long RowCount { get; private set; } + + #endregion + + #region Public Methods + + /// + /// Generates a subset of the rows from the result set + /// + /// The starting row of the results + /// How many rows to retrieve + /// A subset of results + public Task GetSubset(int startRow, int rowCount) + { + // Sanity check to make sure that the results have been read beforehand + if (!hasBeenRead || fileStreamReader == null) + { + throw new InvalidOperationException("Cannot read subset unless the results have been read from the server"); + } + + // Sanity check to make sure that the row and the row count are within bounds + if (startRow < 0 || startRow >= RowCount) + { + throw new ArgumentOutOfRangeException(nameof(startRow), "Start row cannot be less than 0 " + + "or greater than the number of rows in the resultset"); + } + if (rowCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(rowCount), "Row count must be a positive integer"); + } + + return Task.Factory.StartNew(() => + { + // Figure out which rows we need to read back + IEnumerable rowOffsets = FileOffsets.Skip(startRow).Take(rowCount); + + // Iterate over the rows we need and process them into output + object[][] rows = rowOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns)).ToArray(); + + // Retrieve the subset of the results as per the request + return new ResultSetSubset + { + Rows = rows, + RowCount = rows.Length + }; + }); + } + + /// + /// Reads from the reader until there are no more results to read + /// + /// Cancellation token for cancelling the query + public async Task ReadResultToEnd(CancellationToken cancellationToken) + { + // Open a writer for the file + using (IFileStreamWriter fileWriter = fileStreamFactory.GetWriter(outputFileName, MaxCharsToStore, MaxXmlCharsToStore)) + { + // If we can initialize the columns using the column schema, use that + if (!DataReader.DbDataReader.CanGetColumnSchema()) + { + throw new InvalidOperationException("Could not retrieve column schema for result set."); + } + Columns = DataReader.Columns; + long currentFileOffset = 0; + + while (await DataReader.ReadAsync(cancellationToken)) + { + RowCount++; + FileOffsets.Add(currentFileOffset); + currentFileOffset += fileWriter.WriteRow(DataReader); + } + } + + // Mark that result has been read + hasBeenRead = true; + fileStreamReader = fileStreamFactory.GetReader(outputFileName); + } + + #endregion + + #region IDisposable Implementation + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + fileStreamReader?.Dispose(); + fileStreamFactory.DisposeFile(outputFileName); + } + + disposed = true; + } + + #endregion + } +} diff --git a/src/ServiceHost/Session/HostDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/HostDetails.cs similarity index 98% rename from src/ServiceHost/Session/HostDetails.cs rename to src/Microsoft.SqlTools.ServiceLayer/SqlContext/HostDetails.cs index 1a5fc80d..1b78faa4 100644 --- a/src/ServiceHost/Session/HostDetails.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/HostDetails.cs @@ -5,7 +5,7 @@ using System; -namespace Microsoft.SqlTools.EditorServices.Session +namespace Microsoft.SqlTools.ServiceLayer.SqlContext { /// /// Contains details about the current host application (most diff --git a/src/ServiceHost/Session/ProfilePaths.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ProfilePaths.cs similarity index 98% rename from src/ServiceHost/Session/ProfilePaths.cs rename to src/Microsoft.SqlTools.ServiceLayer/SqlContext/ProfilePaths.cs index 4af38521..f841970d 100644 --- a/src/ServiceHost/Session/ProfilePaths.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ProfilePaths.cs @@ -3,12 +3,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using System; using System.Collections.Generic; using System.IO; using System.Linq; -namespace Microsoft.SqlTools.EditorServices.Session +namespace Microsoft.SqlTools.ServiceLayer.SqlContext { /// /// Provides profile path resolution behavior relative to the name diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/QueryExecutionSettings.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/QueryExecutionSettings.cs new file mode 100644 index 00000000..4934a4da --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/QueryExecutionSettings.cs @@ -0,0 +1,38 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.SqlContext +{ + /// + /// Collection of settings related to the execution of queries + /// + public class QueryExecutionSettings + { + /// + /// Default value for batch separator (de facto standard as per SSMS) + /// + private const string DefaultBatchSeparator = "GO"; + + private string batchSeparator; + + /// + /// The configured batch separator, will use a default if a value was not configured + /// + public string BatchSeparator + { + get { return batchSeparator ?? DefaultBatchSeparator; } + set { batchSeparator = value; } + } + + /// + /// Update the current settings with the new settings + /// + /// The new settings + public void Update(QueryExecutionSettings newSettings) + { + BatchSeparator = newSettings.BatchSeparator; + } + } +} diff --git a/src/ServiceHost/Session/SqlToolsContext.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsContext.cs similarity index 81% rename from src/ServiceHost/Session/SqlToolsContext.cs rename to src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsContext.cs index d8016afd..d110f28c 100644 --- a/src/ServiceHost/Session/SqlToolsContext.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsContext.cs @@ -5,8 +5,11 @@ using System; -namespace Microsoft.SqlTools.EditorServices.Session +namespace Microsoft.SqlTools.ServiceLayer.SqlContext { + /// + /// Context for SQL Tools + /// public class SqlToolsContext { /// diff --git a/src/ServiceHost/Server/LanguageServerSettings.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs similarity index 76% rename from src/ServiceHost/Server/LanguageServerSettings.cs rename to src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs index be09984a..198884f2 100644 --- a/src/ServiceHost/Server/LanguageServerSettings.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettings.cs @@ -1,25 +1,30 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System.IO; +using System.IO; using Microsoft.SqlTools.EditorServices.Utility; -namespace Microsoft.SqlTools.EditorServices.Protocol.Server +namespace Microsoft.SqlTools.ServiceLayer.SqlContext { - public class LanguageServerSettings + /// + /// Class for serialization and deserialization of the settings the SQL Tools Service needs. + /// + public class SqlToolsSettings { + // TODO: Is this needed? I can't make sense of this comment. + // NOTE: This property is capitalized as 'SqlTools' because the + // mode name sent from the client is written as 'SqlTools' and + // JSON.net is using camelCasing. + //public ServiceHostSettings SqlTools { get; set; } + + public SqlToolsSettings() + { + this.ScriptAnalysis = new ScriptAnalysisSettings(); + this.QueryExecutionSettings = new QueryExecutionSettings(); + } + public bool EnableProfileLoading { get; set; } public ScriptAnalysisSettings ScriptAnalysis { get; set; } - public LanguageServerSettings() - { - this.ScriptAnalysis = new ScriptAnalysisSettings(); - } - - public void Update(LanguageServerSettings settings, string workspaceRootPath) + public void Update(SqlToolsSettings settings, string workspaceRootPath) { if (settings != null) { @@ -27,9 +32,13 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.Server this.ScriptAnalysis.Update(settings.ScriptAnalysis, workspaceRootPath); } } - } - + public QueryExecutionSettings QueryExecutionSettings { get; set; } + } + + /// + /// Sub class for serialization and deserialization of script analysis settings + /// public class ScriptAnalysisSettings { public bool? Enable { get; set; } @@ -77,14 +86,4 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.Server } } } - - - public class LanguageServerSettingsWrapper - { - // NOTE: This property is capitalized as 'SqlTools' because the - // mode name sent from the client is written as 'SqlTools' and - // JSON.net is using camelCasing. - - public LanguageServerSettings SqlTools { get; set; } - } } diff --git a/src/ServiceHost/Utility/AsyncContext.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/AsyncContext.cs similarity index 100% rename from src/ServiceHost/Utility/AsyncContext.cs rename to src/Microsoft.SqlTools.ServiceLayer/Utility/AsyncContext.cs diff --git a/src/ServiceHost/Utility/AsyncContextThread.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/AsyncContextThread.cs similarity index 100% rename from src/ServiceHost/Utility/AsyncContextThread.cs rename to src/Microsoft.SqlTools.ServiceLayer/Utility/AsyncContextThread.cs diff --git a/src/ServiceHost/Utility/AsyncLock.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/AsyncLock.cs similarity index 100% rename from src/ServiceHost/Utility/AsyncLock.cs rename to src/Microsoft.SqlTools.ServiceLayer/Utility/AsyncLock.cs diff --git a/src/ServiceHost/Utility/AsyncQueue.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/AsyncQueue.cs similarity index 100% rename from src/ServiceHost/Utility/AsyncQueue.cs rename to src/Microsoft.SqlTools.ServiceLayer/Utility/AsyncQueue.cs diff --git a/src/ServiceHost/Utility/Extensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/Extensions.cs similarity index 100% rename from src/ServiceHost/Utility/Extensions.cs rename to src/Microsoft.SqlTools.ServiceLayer/Utility/Extensions.cs diff --git a/src/ServiceHost/Utility/Logger.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/Logger.cs similarity index 100% rename from src/ServiceHost/Utility/Logger.cs rename to src/Microsoft.SqlTools.ServiceLayer/Utility/Logger.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/LongList.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/LongList.cs new file mode 100644 index 00000000..afacc98f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/LongList.cs @@ -0,0 +1,259 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.Utility +{ + /// + /// Collection class that permits storage of over int.MaxValue items. This is performed + /// by using a 2D list of lists. The internal lists are only initialized as necessary. This + /// collection implements IEnumerable to make it easier to run LINQ queries against it. + /// + /// + /// This class is based on code from $\Data Tools\SSMS_Main\sql\ssms\core\DataStorage\ArrayList64.cs + /// with additions to bring it up to .NET 4.5 standards + /// + /// Type of the values to store + public class LongList : IEnumerable + { + #region Member Variables + + private List> expandedList; + private readonly List shortList; + + #endregion + + /// + /// Creates a new long list + /// + public LongList() + { + shortList = new List(); + Count = 0; + } + + #region Properties + + /// + /// The total number of elements in the array + /// + public long Count { get; private set; } + + public T this[long index] + { + get { return GetItem(index); } + } + + #endregion + + #region Public Methods + + /// + /// Adds the specified value to the end of the list + /// + /// Value to add to the list + /// Index of the item that was just added + public long Add(T val) + { + if (Count <= int.MaxValue) + { + shortList.Add(val); + } + else // need to split values into several arrays + { + if (expandedList == null) + { + // very inefficient so delay as much as possible + // immediately add 0th array + expandedList = new List> {shortList}; + } + + int arrayIndex = (int)(Count/int.MaxValue); // 0 based + + List arr; + if (expandedList.Count <= arrayIndex) // need to make a new array + { + arr = new List(); + expandedList.Add(arr); + } + else // use existing array + { + arr = expandedList[arrayIndex]; + } + arr.Add(val); + } + return (++Count); + } + + /// + /// Returns the item at the specified index + /// + /// Index of the item to return + /// The item at the index specified + public T GetItem(long index) + { + T val = default(T); + + if (Count <= int.MaxValue) + { + int i32Index = Convert.ToInt32(index); + val = shortList[i32Index]; + } + else + { + int iArray32Index = (int) (Count/int.MaxValue); + if (expandedList.Count > iArray32Index) + { + List arr = expandedList[iArray32Index]; + + int i32Index = (int) (Count%int.MaxValue); + if (arr.Count > i32Index) + { + val = arr[i32Index]; + } + } + } + return val; + } + + /// + /// Removes an item at the specified location and shifts all the items after the provided + /// index up by one. + /// + /// The index to remove from the list + public void RemoveAt(long index) + { + if (Count <= int.MaxValue) + { + int iArray32MemberIndex = Convert.ToInt32(index); // 0 based + shortList.RemoveAt(iArray32MemberIndex); + } + else // handle the case of multiple arrays + { + // find out which array it is in + int arrayIndex = (int) (index/int.MaxValue); + List arr = expandedList[arrayIndex]; + + // find out index into this array + int iArray32MemberIndex = (int) (index%int.MaxValue); + arr.RemoveAt(iArray32MemberIndex); + + // now shift members of the array back one + int iArray32TotalIndex = (int) (Count/Int32.MaxValue); + for (int i = arrayIndex + 1; i < iArray32TotalIndex; i++) + { + List arr1 = expandedList[i - 1]; + List arr2 = expandedList[i]; + + arr1.Add(arr2[int.MaxValue - 1]); + arr2.RemoveAt(0); + } + } + --Count; + } + + #endregion + + #region IEnumerable Implementation + + /// + /// Returns a generic enumerator for enumeration of this LongList + /// + /// Enumerator for LongList + public IEnumerator GetEnumerator() + { + return new LongListEnumerator(this); + } + + /// + /// Returns an enumerator for enumeration of this LongList + /// + /// + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + #endregion + + public class LongListEnumerator : IEnumerator + { + #region Member Variables + + /// + /// The index into the list of the item that is the current item + /// + private long index; + + /// + /// The current list that we're iterating over. + /// + private readonly LongList localList; + + #endregion + + /// + /// Constructs a new enumerator for a given LongList + /// + /// The list to enumerate + public LongListEnumerator(LongList list) + { + localList = list; + index = 0; + Current = default(TEt); + } + + #region IEnumerator Implementation + + /// + /// Returns the current item in the enumeration + /// + public TEt Current { get; private set; } + + object IEnumerator.Current + { + get { return Current; } + } + + /// + /// Moves to the next item in the list we're iterating over + /// + /// Whether or not the move was successful + public bool MoveNext() + { + if (index < localList.Count) + { + Current = localList[index]; + index++; + return true; + } + Current = default(TEt); + return false; + } + + /// + /// Resets the enumeration + /// + public void Reset() + { + index = 0; + Current = default(TEt); + } + + /// + /// Disposal method. Does nothing. + /// + public void Dispose() + { + } + + #endregion + } + } +} + diff --git a/src/ServiceHost/Utility/ThreadSynchronizationContext.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/ThreadSynchronizationContext.cs similarity index 100% rename from src/ServiceHost/Utility/ThreadSynchronizationContext.cs rename to src/Microsoft.SqlTools.ServiceLayer/Utility/ThreadSynchronizationContext.cs diff --git a/src/ServiceHost/Utility/Validate.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/Validate.cs similarity index 100% rename from src/ServiceHost/Utility/Validate.cs rename to src/Microsoft.SqlTools.ServiceLayer/Utility/Validate.cs diff --git a/src/ServiceHost/Workspace/BufferPosition.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/BufferPosition.cs similarity index 98% rename from src/ServiceHost/Workspace/BufferPosition.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/BufferPosition.cs index 8f790d85..713736a7 100644 --- a/src/ServiceHost/Workspace/BufferPosition.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/BufferPosition.cs @@ -5,7 +5,7 @@ using System.Diagnostics; -namespace Microsoft.SqlTools.EditorServices +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { /// /// Provides details about a position in a file buffer. All diff --git a/src/ServiceHost/Workspace/BufferRange.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/BufferRange.cs similarity index 98% rename from src/ServiceHost/Workspace/BufferRange.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/BufferRange.cs index 5d20598f..f8253d02 100644 --- a/src/ServiceHost/Workspace/BufferRange.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/BufferRange.cs @@ -6,7 +6,7 @@ using System; using System.Diagnostics; -namespace Microsoft.SqlTools.EditorServices +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { /// /// Provides details about a range between two positions in diff --git a/src/ServiceHost/LanguageServer/Configuration.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/Configuration.cs similarity index 80% rename from src/ServiceHost/LanguageServer/Configuration.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/Configuration.cs index b9ad87db..af50835f 100644 --- a/src/ServiceHost/LanguageServer/Configuration.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/Configuration.cs @@ -3,9 +3,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { public class DidChangeConfigurationNotification { diff --git a/src/ServiceHost/Workspace/FileChange.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/FileChange.cs similarity index 94% rename from src/ServiceHost/Workspace/FileChange.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/FileChange.cs index 2f6efdf8..e3d49471 100644 --- a/src/ServiceHost/Workspace/FileChange.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/FileChange.cs @@ -3,7 +3,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -namespace Microsoft.SqlTools.EditorServices +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { /// /// Contains details relating to a content change in an open file. diff --git a/src/ServiceHost/Workspace/FilePosition.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/FilePosition.cs similarity index 98% rename from src/ServiceHost/Workspace/FilePosition.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/FilePosition.cs index 2cb58745..65fe268a 100644 --- a/src/ServiceHost/Workspace/FilePosition.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/FilePosition.cs @@ -3,7 +3,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -namespace Microsoft.SqlTools.EditorServices +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { /// /// Provides details and operations for a buffer position in a diff --git a/src/ServiceHost/Workspace/ScriptFile.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFile.cs similarity index 95% rename from src/ServiceHost/Workspace/ScriptFile.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFile.cs index 90d66244..74592ae5 100644 --- a/src/ServiceHost/Workspace/ScriptFile.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFile.cs @@ -9,19 +9,13 @@ using System.Collections.Generic; using System.IO; using System.Linq; -namespace Microsoft.SqlTools.EditorServices +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { /// /// Contains the details and contents of an open script file. /// public class ScriptFile { - #region Private Fields - - private Version SqlToolsVersion; - - #endregion - #region Properties /// @@ -40,9 +34,10 @@ namespace Microsoft.SqlTools.EditorServices public string FilePath { get; private set; } /// - /// Gets the path which the editor client uses to identify this file. + /// Gets or sets the path which the editor client uses to identify this file. + /// Setter for testing purposes only /// - public string ClientFilePath { get; private set; } + public string ClientFilePath { get; internal set; } /// /// Gets or sets a boolean that determines whether @@ -58,7 +53,8 @@ namespace Microsoft.SqlTools.EditorServices public bool IsInMemory { get; private set; } /// - /// Gets a string containing the full contents of the file. + /// Gets or sets a string containing the full contents of the file. + /// Setter for testing purposes only /// public string Contents { @@ -66,6 +62,10 @@ namespace Microsoft.SqlTools.EditorServices { return string.Join("\r\n", this.FileLines); } + set + { + this.FileLines = value != null ? value.Split('\n') : null; + } } /// @@ -106,6 +106,14 @@ namespace Microsoft.SqlTools.EditorServices #region Constructors + /// + /// Add a default constructor for testing + /// + internal ScriptFile() + { + ClientFilePath = "test.sql"; + } + /// /// Creates a new ScriptFile instance by reading file contents from /// the given TextReader. @@ -113,18 +121,15 @@ namespace Microsoft.SqlTools.EditorServices /// The path at which the script file resides. /// The path which the client uses to identify the file. /// The TextReader to use for reading the file's contents. - /// The version of SqlTools for which the script is being parsed. public ScriptFile( string filePath, string clientFilePath, - TextReader textReader, - Version SqlToolsVersion) + TextReader textReader) { this.FilePath = filePath; this.ClientFilePath = clientFilePath; this.IsAnalysisEnabled = true; this.IsInMemory = Workspace.IsPathInMemory(filePath); - this.SqlToolsVersion = SqlToolsVersion; this.SetFileContents(textReader.ReadToEnd()); } @@ -135,17 +140,14 @@ namespace Microsoft.SqlTools.EditorServices /// The path at which the script file resides. /// The path which the client uses to identify the file. /// The initial contents of the script file. - /// The version of SqlTools for which the script is being parsed. public ScriptFile( string filePath, string clientFilePath, - string initialBuffer, - Version SqlToolsVersion) + string initialBuffer) { this.FilePath = filePath; this.ClientFilePath = clientFilePath; this.IsAnalysisEnabled = true; - this.SqlToolsVersion = SqlToolsVersion; this.SetFileContents(initialBuffer); } @@ -433,11 +435,11 @@ namespace Microsoft.SqlTools.EditorServices return new BufferRange(startPosition, endPosition); } - #endregion - - #region Private Methods - - private void SetFileContents(string fileContents) + /// + /// Set the script files contents + /// + /// + public void SetFileContents(string fileContents) { // Split the file contents into lines and trim // any carriage returns from the strings. @@ -451,6 +453,10 @@ namespace Microsoft.SqlTools.EditorServices this.ParseFileContents(); } + #endregion + + #region Private Methods + /// /// Parses the current file contents to get the AST, tokens, /// and parse errors. diff --git a/src/ServiceHost/Workspace/ScriptFileMarker.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFileMarker.cs similarity index 95% rename from src/ServiceHost/Workspace/ScriptFileMarker.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFileMarker.cs index 87c2576c..743ef3d6 100644 --- a/src/ServiceHost/Workspace/ScriptFileMarker.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptFileMarker.cs @@ -3,7 +3,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -namespace Microsoft.SqlTools.EditorServices +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { /// /// Defines the message level of a script file marker. diff --git a/src/ServiceHost/Workspace/ScriptRegion.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptRegion.cs similarity index 97% rename from src/ServiceHost/Workspace/ScriptRegion.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptRegion.cs index f2fa4ac8..e68400e9 100644 --- a/src/ServiceHost/Workspace/ScriptRegion.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/ScriptRegion.cs @@ -5,7 +5,7 @@ //using System.Management.Automation.Language; -namespace Microsoft.SqlTools.EditorServices +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { /// /// Contains details about a specific region of text in script file. diff --git a/src/ServiceHost/LanguageServer/TextDocument.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/TextDocument.cs similarity index 69% rename from src/ServiceHost/LanguageServer/TextDocument.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/TextDocument.cs index 9f477374..cf3f8468 100644 --- a/src/ServiceHost/LanguageServer/TextDocument.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/TextDocument.cs @@ -4,9 +4,9 @@ // using System.Diagnostics; -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { /// /// Defines a base parameter class for identifying a text document. @@ -19,37 +19,70 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer /// text document. /// public string Uri { get; set; } - } + } /// /// Defines a position in a text document. /// [DebuggerDisplay("TextDocumentPosition = {Position.Line}:{Position.Character}")] - public class TextDocumentPosition : TextDocumentIdentifier + public class TextDocumentPosition { + /// + /// Gets or sets the document identifier. + /// + public TextDocumentIdentifier TextDocument { get; set; } + /// /// Gets or sets the position in the document. /// public Position Position { get; set; } } - public class DidOpenTextDocumentNotification : TextDocumentIdentifier + /// + /// Defines a text document. + /// + [DebuggerDisplay("TextDocumentItem = {Uri}")] + public class TextDocumentItem + { + /// + /// Gets or sets the URI which identifies the path of the + /// text document. + /// + public string Uri { get; set; } + + /// + /// Gets or sets the language of the document + /// + public string LanguageId { get; set; } + + /// + /// Gets or sets the version of the document + /// + public int Version { get; set; } + + /// + /// Gets or sets the full content of the document. + /// + public string Text { get; set; } + } + + public class DidOpenTextDocumentNotification { public static readonly EventType Type = EventType.Create("textDocument/didOpen"); /// - /// Gets or sets the full content of the opened document. + /// Gets or sets the opened document. /// - public string Text { get; set; } + public TextDocumentItem TextDocument { get; set; } } public class DidCloseTextDocumentNotification { public static readonly - EventType Type = - EventType.Create("textDocument/didClose"); + EventType Type = + EventType.Create("textDocument/didClose"); } public class DidChangeTextDocumentNotification @@ -59,9 +92,20 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer EventType.Create("textDocument/didChange"); } - public class DidChangeTextDocumentParams : TextDocumentIdentifier + public class DidCloseTextDocumentParams { - public TextDocumentUriChangeEvent TextDocument { get; set; } + /// + /// Gets or sets the closed document. + /// + public TextDocumentItem TextDocument { get; set; } + } + + public class DidChangeTextDocumentParams + { + /// + /// Gets or sets the changed document. + /// + public VersionedTextDocumentIdentifier TextDocument { get; set; } /// /// Gets or sets the list of changes to the document content. @@ -69,13 +113,11 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer public TextDocumentChangeEvent[] ContentChanges { get; set; } } - public class TextDocumentUriChangeEvent - { - /// - /// Gets or sets the Uri of the changed text document - /// - public string Uri { get; set; } - + /// + /// Define a specific version of a text document + /// + public class VersionedTextDocumentIdentifier : TextDocumentIdentifier + { /// /// Gets or sets the Version of the changed text document /// diff --git a/src/ServiceHost/LanguageServer/WorkspaceSymbols.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/WorkspaceSymbols.cs similarity index 70% rename from src/ServiceHost/LanguageServer/WorkspaceSymbols.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/WorkspaceSymbols.cs index 25a554b5..93140df3 100644 --- a/src/ServiceHost/LanguageServer/WorkspaceSymbols.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Contracts/WorkspaceSymbols.cs @@ -3,9 +3,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; -namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.Workspace.Contracts { public enum SymbolKind { @@ -43,8 +43,16 @@ namespace Microsoft.SqlTools.EditorServices.Protocol.LanguageServer public class DocumentSymbolRequest { public static readonly - RequestType Type = - RequestType.Create("textDocument/documentSymbol"); + RequestType Type = + RequestType.Create("textDocument/documentSymbol"); + } + + /// + /// Defines a set of parameters to send document symbol request + /// + public class DocumentSymbolParams + { + public TextDocumentIdentifier TextDocument { get; set; } } public class WorkspaceSymbolRequest diff --git a/src/ServiceHost/Workspace/Workspace.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs similarity index 90% rename from src/ServiceHost/Workspace/Workspace.cs rename to src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs index 39e1d70f..3099a3d5 100644 --- a/src/ServiceHost/Workspace/Workspace.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs @@ -3,25 +3,25 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Utility; using System; using System.Collections.Generic; using System.IO; using System.Text; using System.Text.RegularExpressions; using System.Linq; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -namespace Microsoft.SqlTools.EditorServices +namespace Microsoft.SqlTools.ServiceLayer.Workspace { /// /// Manages a "workspace" of script files that are open for a particular /// editing session. Also helps to navigate references between ScriptFiles. /// - public class Workspace + public class Workspace : IDisposable { - #region Private Fields + #region Private Fields - private Version SqlToolsVersion; private Dictionary workspaceFiles = new Dictionary(); #endregion @@ -40,10 +40,8 @@ namespace Microsoft.SqlTools.EditorServices /// /// Creates a new instance of the Workspace class. /// - /// The version of SqlTools for which scripts will be parsed. - public Workspace(Version SqlToolsVersion) + public Workspace() { - this.SqlToolsVersion = SqlToolsVersion; } #endregion @@ -78,12 +76,7 @@ namespace Microsoft.SqlTools.EditorServices using (FileStream fileStream = new FileStream(resolvedFilePath, FileMode.Open, FileAccess.Read)) using (StreamReader streamReader = new StreamReader(fileStream, Encoding.UTF8)) { - scriptFile = - new ScriptFile( - resolvedFilePath, - filePath, - streamReader, - this.SqlToolsVersion); + scriptFile = new ScriptFile(resolvedFilePath, filePath,streamReader); this.workspaceFiles.Add(keyName, scriptFile); } @@ -131,6 +124,7 @@ namespace Microsoft.SqlTools.EditorServices // type SqlTools have a path starting with 'untitled'. return filePath.StartsWith("inmemory") || + filePath.StartsWith("tsqloutput") || filePath.StartsWith("untitled"); } @@ -169,12 +163,7 @@ namespace Microsoft.SqlTools.EditorServices ScriptFile scriptFile = null; if (!this.workspaceFiles.TryGetValue(keyName, out scriptFile)) { - scriptFile = - new ScriptFile( - resolvedFilePath, - filePath, - initialBuffer, - this.SqlToolsVersion); + scriptFile = new ScriptFile(resolvedFilePath, filePath, initialBuffer); this.workspaceFiles.Add(keyName, scriptFile); @@ -244,5 +233,17 @@ namespace Microsoft.SqlTools.EditorServices } #endregion + + #region IDisposable Implementation + + /// + /// Disposes of any Runspaces that were created for the + /// services used in this session. + /// + public void Dispose() + { + } + + #endregion } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs new file mode 100644 index 00000000..0e5e4a25 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs @@ -0,0 +1,267 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.Hosting; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Workspace +{ + /// + /// Class for handling requests/events that deal with the state of the workspace, including the + /// opening and closing of files, the changing of configuration, etc. + /// + /// + /// The type of the class used for serializing and deserializing the configuration. Must be the + /// actual type of the instance otherwise deserialization will be incomplete. + /// + public class WorkspaceService where TConfig : class, new() + { + + #region Singleton Instance Implementation + + private static readonly Lazy> instance = new Lazy>(() => new WorkspaceService()); + + public static WorkspaceService Instance + { + get { return instance.Value; } + } + + /// + /// Default, parameterless constructor. + /// TODO: Figure out how to make this truely singleton even with dependency injection for tests + /// + public WorkspaceService() + { + ConfigChangeCallbacks = new List(); + TextDocChangeCallbacks = new List(); + TextDocOpenCallbacks = new List(); + + CurrentSettings = new TConfig(); + } + + #endregion + + #region Properties + + public Workspace Workspace { get; private set; } + + public TConfig CurrentSettings { get; internal set; } + + /// + /// Delegate for callbacks that occur when the configuration for the workspace changes + /// + /// The settings that were just set + /// The settings before they were changed + /// Context of the event that triggered the callback + /// + public delegate Task ConfigChangeCallback(TConfig newSettings, TConfig oldSettings, EventContext eventContext); + + /// + /// Delegate for callbacks that occur when the current text document changes + /// + /// Array of files that changed + /// Context of the event raised for the changed files + public delegate Task TextDocChangeCallback(ScriptFile[] changedFiles, EventContext eventContext); + + /// + /// Delegate for callbacks that occur when a text document is opened + /// + /// File that was opened + /// Context of the event raised for the changed files + public delegate Task TextDocOpenCallback(ScriptFile openFile, EventContext eventContext); + + /// + /// List of callbacks to call when the configuration of the workspace changes + /// + private List ConfigChangeCallbacks { get; set; } + + /// + /// List of callbacks to call when the current text document changes + /// + private List TextDocChangeCallbacks { get; set; } + + /// + /// List of callbacks to call when a text document is opened + /// + private List TextDocOpenCallbacks { get; set; } + + + #endregion + + #region Public Methods + + public void InitializeService(ServiceHost serviceHost) + { + // Create a workspace that will handle state for the session + Workspace = new Workspace(); + + // Register the handlers for when changes to the workspae occur + serviceHost.SetEventHandler(DidChangeTextDocumentNotification.Type, HandleDidChangeTextDocumentNotification); + serviceHost.SetEventHandler(DidOpenTextDocumentNotification.Type, HandleDidOpenTextDocumentNotification); + serviceHost.SetEventHandler(DidCloseTextDocumentNotification.Type, HandleDidCloseTextDocumentNotification); + serviceHost.SetEventHandler(DidChangeConfigurationNotification.Type, HandleDidChangeConfigurationNotification); + + // Register an initialization handler that sets the workspace path + serviceHost.RegisterInitializeTask(async (parameters, contect) => + { + Logger.Write(LogLevel.Verbose, "Initializing workspace service"); + + if (Workspace != null) + { + Workspace.WorkspacePath = parameters.RootPath; + } + await Task.FromResult(0); + }); + + // Register a shutdown request that disposes the workspace + serviceHost.RegisterShutdownTask(async (parameters, context) => + { + Logger.Write(LogLevel.Verbose, "Shutting down workspace service"); + + if (Workspace != null) + { + Workspace.Dispose(); + Workspace = null; + } + await Task.FromResult(0); + }); + } + + /// + /// Adds a new task to be called when the configuration has been changed. Use this to + /// handle changing configuration and changing the current configuration. + /// + /// Task to handle the request + public void RegisterConfigChangeCallback(ConfigChangeCallback task) + { + ConfigChangeCallbacks.Add(task); + } + + /// + /// Adds a new task to be called when the text of a document changes. + /// + /// Delegate to call when the document changes + public void RegisterTextDocChangeCallback(TextDocChangeCallback task) + { + TextDocChangeCallbacks.Add(task); + } + + /// + /// Adds a new task to be called when a file is opened + /// + /// Delegate to call when a document is opened + public void RegisterTextDocOpenCallback(TextDocOpenCallback task) + { + TextDocOpenCallbacks.Add(task); + } + + #endregion + + #region Event Handlers + + /// + /// Handles text document change events + /// + protected Task HandleDidChangeTextDocumentNotification( + DidChangeTextDocumentParams textChangeParams, + EventContext eventContext) + { + StringBuilder msg = new StringBuilder(); + msg.Append("HandleDidChangeTextDocumentNotification"); + List changedFiles = new List(); + + // A text change notification can batch multiple change requests + foreach (var textChange in textChangeParams.ContentChanges) + { + string fileUri = textChangeParams.TextDocument.Uri ?? textChangeParams.TextDocument.Uri; + msg.AppendLine(string.Format(" File: {0}", fileUri)); + + ScriptFile changedFile = Workspace.GetFile(fileUri); + + changedFile.ApplyChange( + GetFileChangeDetails( + textChange.Range.Value, + textChange.Text)); + + changedFiles.Add(changedFile); + } + + Logger.Write(LogLevel.Verbose, msg.ToString()); + + var handlers = TextDocChangeCallbacks.Select(t => t(changedFiles.ToArray(), eventContext)); + return Task.WhenAll(handlers); + } + + protected async Task HandleDidOpenTextDocumentNotification( + DidOpenTextDocumentNotification openParams, + EventContext eventContext) + { + Logger.Write(LogLevel.Verbose, "HandleDidOpenTextDocumentNotification"); + + // read the SQL file contents into the ScriptFile + ScriptFile openedFile = Workspace.GetFileBuffer(openParams.TextDocument.Uri, openParams.TextDocument.Text); + + // Propagate the changes to the event handlers + var textDocOpenTasks = TextDocOpenCallbacks.Select( + t => t(openedFile, eventContext)); + + await Task.WhenAll(textDocOpenTasks); + } + + protected Task HandleDidCloseTextDocumentNotification( + DidCloseTextDocumentParams closeParams, + EventContext eventContext) + { + Logger.Write(LogLevel.Verbose, "HandleDidCloseTextDocumentNotification"); + return Task.FromResult(true); + } + + /// + /// Handles the configuration change event + /// + protected async Task HandleDidChangeConfigurationNotification( + DidChangeConfigurationParams configChangeParams, + EventContext eventContext) + { + Logger.Write(LogLevel.Verbose, "HandleDidChangeConfigurationNotification"); + + // Propagate the changes to the event handlers + var configUpdateTasks = ConfigChangeCallbacks.Select( + t => t(configChangeParams.Settings, CurrentSettings, eventContext)); + await Task.WhenAll(configUpdateTasks); + } + + #endregion + + #region Private Helpers + + /// + /// Switch from 0-based offsets to 1 based offsets + /// + /// + /// + private static FileChange GetFileChangeDetails(Range changeRange, string insertString) + { + // The protocol's positions are zero-based so add 1 to all offsets + return new FileChange + { + InsertString = insertString, + Line = changeRange.Start.Line + 1, + Offset = changeRange.Start.Character + 1, + EndLine = changeRange.End.Line + 1, + EndOffset = changeRange.End.Character + 1 + }; + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/project.json b/src/Microsoft.SqlTools.ServiceLayer/project.json new file mode 100644 index 00000000..02e977aa --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/project.json @@ -0,0 +1,41 @@ +{ + "name": "Microsoft.SqlTools.ServiceLayer", + "version": "1.0.0-*", + "buildOptions": { + "debugType": "portable", + "emitEntryPoint": true + }, + "dependencies": { + "Newtonsoft.Json": "9.0.1", + "Microsoft.SqlServer.SqlParser": "140.1.5", + "System.Data.Common": "4.1.0", + "System.Data.SqlClient": "4.1.0", + "Microsoft.SqlServer.Smo": "140.1.5", + "System.Security.SecureString": "4.0.0", + "System.Collections.Specialized": "4.0.1", + "System.ComponentModel.TypeConverter": "4.1.0", + "System.Diagnostics.TraceSource": "4.0.0", + "NETStandard.Library": "1.6.0", + "Microsoft.NETCore.Runtime.CoreCLR": "1.0.2", + "Microsoft.NETCore.DotNetHostPolicy": "1.0.1", + "System.Diagnostics.Process": "4.1.0", + "System.Threading.Thread": "4.0.0" + }, + "frameworks": { + "netcoreapp1.0": { + "imports": "dnxcore50" + } + }, + "runtimes": { + "win7-x64": {}, + "win7-x86": {}, + "osx.10.11-x64": {}, + "ubuntu.14.04-x64": {}, + "ubuntu.16.04-x64": {}, + "centos.7-x64": {}, + "rhel.7.2-x64": {}, + "debian.8-x64": {}, + "fedora.23-x64": {}, + "opensuse.13.2-x64": {} + } +} diff --git a/src/ServiceHost/LanguageSupport/LanguageService.cs b/src/ServiceHost/LanguageSupport/LanguageService.cs deleted file mode 100644 index 3ab77697..00000000 --- a/src/ServiceHost/LanguageSupport/LanguageService.cs +++ /dev/null @@ -1,57 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using Microsoft.SqlTools.EditorServices; -using Microsoft.SqlTools.EditorServices.Session; - -namespace Microsoft.SqlTools.LanguageSupport -{ - /// - /// Main class for Language Service functionality - /// - public class LanguageService - { - /// - /// Gets or sets the current SQL Tools context - /// - /// - private SqlToolsContext Context { get; set; } - - /// - /// Constructor for the Language Service class - /// - /// - public LanguageService(SqlToolsContext context) - { - this.Context = context; - } - - /// - /// Gets a list of semantic diagnostic marks for the provided script file - /// - /// - public ScriptFileMarker[] GetSemanticMarkers(ScriptFile scriptFile) - { - // the commented out snippet is an example of how to create a error marker - // semanticMarkers = new ScriptFileMarker[1]; - // semanticMarkers[0] = new ScriptFileMarker() - // { - // Message = "Error message", - // Level = ScriptFileMarkerLevel.Error, - // ScriptRegion = new ScriptRegion() - // { - // File = scriptFile.FilePath, - // StartLineNumber = 2, - // StartColumnNumber = 2, - // StartOffset = 0, - // EndLineNumber = 4, - // EndColumnNumber = 10, - // EndOffset = 0 - // } - // }; - return new ScriptFileMarker[0]; - } - } -} diff --git a/src/ServiceHost/Program.cs b/src/ServiceHost/Program.cs deleted file mode 100644 index 6bfd0f24..00000000 --- a/src/ServiceHost/Program.cs +++ /dev/null @@ -1,40 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System; -using Microsoft.SqlTools.EditorServices.Protocol.Server; -using Microsoft.SqlTools.EditorServices.Session; -using Microsoft.SqlTools.EditorServices.Utility; - -namespace Microsoft.SqlTools.ServiceHost -{ - /// - /// Main application class for SQL Tools API Service Host executable - /// - class Program - { - /// - /// Main entry point into the SQL Tools API Service Host - /// - static void Main(string[] args) - { - // turn on Verbose logging during early development - // we need to switch to Normal when preparing for public preview - Logger.Initialize(minimumLogLevel: LogLevel.Verbose); - Logger.Write(LogLevel.Normal, "Starting SQL Tools Service Host"); - - const string hostName = "SQL Tools Service Host"; - const string hostProfileId = "SQLToolsService"; - Version hostVersion = new Version(1,0); - - // set up the host details and profile paths - var hostDetails = new HostDetails(hostName, hostProfileId, hostVersion); - var profilePaths = new ProfilePaths(hostProfileId, "baseAllUsersPath", "baseCurrentUserPath"); - - // create and run the language server - var languageServer = new LanguageServer(hostDetails, profilePaths); - languageServer.Start().Wait(); - languageServer.WaitForExit(); - } - } -} diff --git a/src/ServiceHost/Server/LanguageServer.cs b/src/ServiceHost/Server/LanguageServer.cs deleted file mode 100644 index d6719141..00000000 --- a/src/ServiceHost/Server/LanguageServer.cs +++ /dev/null @@ -1,517 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// -using Microsoft.SqlTools.EditorServices.Protocol.LanguageServer; -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Channel; -using Microsoft.SqlTools.EditorServices.Session; -using System.Threading.Tasks; -using Microsoft.SqlTools.EditorServices.Utility; -using System.Collections.Generic; -using System.Text; -using System.Threading; -using System.Linq; -using System; - -namespace Microsoft.SqlTools.EditorServices.Protocol.Server -{ - /// - /// SQL Tools VS Code Language Server request handler - /// - public class LanguageServer : LanguageServerBase - { - private static CancellationTokenSource existingRequestCancellation; - - private LanguageServerSettings currentSettings = new LanguageServerSettings(); - - private EditorSession editorSession; - - /// - /// Provides details about the host application. - /// - public LanguageServer(HostDetails hostDetails, ProfilePaths profilePaths) - : base(new StdioServerChannel()) - { - this.editorSession = new EditorSession(); - this.editorSession.StartSession(hostDetails, profilePaths); - } - - /// - /// Initialize the VS Code request/response callbacks - /// - protected override void Initialize() - { - // Register all supported message types - this.SetRequestHandler(InitializeRequest.Type, this.HandleInitializeRequest); - this.SetEventHandler(DidChangeTextDocumentNotification.Type, this.HandleDidChangeTextDocumentNotification); - this.SetEventHandler(DidOpenTextDocumentNotification.Type, this.HandleDidOpenTextDocumentNotification); - this.SetEventHandler(DidCloseTextDocumentNotification.Type, this.HandleDidCloseTextDocumentNotification); - this.SetEventHandler(DidChangeConfigurationNotification.Type, this.HandleDidChangeConfigurationNotification); - - this.SetRequestHandler(DefinitionRequest.Type, this.HandleDefinitionRequest); - this.SetRequestHandler(ReferencesRequest.Type, this.HandleReferencesRequest); - this.SetRequestHandler(CompletionRequest.Type, this.HandleCompletionRequest); - this.SetRequestHandler(CompletionResolveRequest.Type, this.HandleCompletionResolveRequest); - this.SetRequestHandler(SignatureHelpRequest.Type, this.HandleSignatureHelpRequest); - this.SetRequestHandler(DocumentHighlightRequest.Type, this.HandleDocumentHighlightRequest); - this.SetRequestHandler(HoverRequest.Type, this.HandleHoverRequest); - this.SetRequestHandler(DocumentSymbolRequest.Type, this.HandleDocumentSymbolRequest); - this.SetRequestHandler(WorkspaceSymbolRequest.Type, this.HandleWorkspaceSymbolRequest); - } - - /// - /// Handles the shutdown event for the Language Server - /// - protected override async Task Shutdown() - { - Logger.Write(LogLevel.Normal, "Language service is shutting down..."); - - if (this.editorSession != null) - { - this.editorSession.Dispose(); - this.editorSession = null; - } - - await Task.FromResult(true); - } - - /// - /// Handles the initialization request - /// - /// - /// - /// - protected async Task HandleInitializeRequest( - InitializeRequest initializeParams, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleDidChangeTextDocumentNotification"); - - // Grab the workspace path from the parameters - editorSession.Workspace.WorkspacePath = initializeParams.RootPath; - - await requestContext.SendResult( - new InitializeResult - { - Capabilities = new ServerCapabilities - { - TextDocumentSync = TextDocumentSyncKind.Incremental, - DefinitionProvider = true, - ReferencesProvider = true, - DocumentHighlightProvider = true, - DocumentSymbolProvider = true, - WorkspaceSymbolProvider = true, - HoverProvider = true, - CompletionProvider = new CompletionOptions - { - ResolveProvider = true, - TriggerCharacters = new string[] { ".", "-", ":", "\\" } - }, - SignatureHelpProvider = new SignatureHelpOptions - { - TriggerCharacters = new string[] { " " } // TODO: Other characters here? - } - } - }); - } - - /// - /// Handles text document change events - /// - /// - /// - /// - protected Task HandleDidChangeTextDocumentNotification( - DidChangeTextDocumentParams textChangeParams, - EventContext eventContext) - { - StringBuilder msg = new StringBuilder(); - msg.Append("HandleDidChangeTextDocumentNotification"); - List changedFiles = new List(); - - // A text change notification can batch multiple change requests - foreach (var textChange in textChangeParams.ContentChanges) - { - string fileUri = textChangeParams.TextDocument.Uri; - msg.AppendLine(); - msg.Append(" File: "); - msg.Append(fileUri); - - ScriptFile changedFile = editorSession.Workspace.GetFile(fileUri); - - changedFile.ApplyChange( - GetFileChangeDetails( - textChange.Range.Value, - textChange.Text)); - - changedFiles.Add(changedFile); - } - - Logger.Write(LogLevel.Verbose, msg.ToString()); - - this.RunScriptDiagnostics( - changedFiles.ToArray(), - editorSession, - eventContext); - - return Task.FromResult(true); - } - - protected Task HandleDidOpenTextDocumentNotification( - DidOpenTextDocumentNotification openParams, - EventContext eventContext) - { - Logger.Write(LogLevel.Verbose, "HandleDidOpenTextDocumentNotification"); - return Task.FromResult(true); - } - - protected Task HandleDidCloseTextDocumentNotification( - TextDocumentIdentifier closeParams, - EventContext eventContext) - { - Logger.Write(LogLevel.Verbose, "HandleDidCloseTextDocumentNotification"); - return Task.FromResult(true); - } - - /// - /// Handles the configuration change event - /// - /// - /// - protected async Task HandleDidChangeConfigurationNotification( - DidChangeConfigurationParams configChangeParams, - EventContext eventContext) - { - Logger.Write(LogLevel.Verbose, "HandleDidChangeConfigurationNotification"); - - bool oldLoadProfiles = this.currentSettings.EnableProfileLoading; - bool oldScriptAnalysisEnabled = - this.currentSettings.ScriptAnalysis.Enable.HasValue; - string oldScriptAnalysisSettingsPath = - this.currentSettings.ScriptAnalysis.SettingsPath; - - this.currentSettings.Update( - configChangeParams.Settings.SqlTools, - this.editorSession.Workspace.WorkspacePath); - - // If script analysis settings have changed we need to clear & possibly update the current diagnostic records. - if ((oldScriptAnalysisEnabled != this.currentSettings.ScriptAnalysis.Enable)) - { - // If the user just turned off script analysis or changed the settings path, send a diagnostics - // event to clear the analysis markers that they already have. - if (!this.currentSettings.ScriptAnalysis.Enable.Value) - { - ScriptFileMarker[] emptyAnalysisDiagnostics = new ScriptFileMarker[0]; - - foreach (var scriptFile in editorSession.Workspace.GetOpenedFiles()) - { - await PublishScriptDiagnostics( - scriptFile, - emptyAnalysisDiagnostics, - eventContext); - } - } - else - { - await this.RunScriptDiagnostics( - this.editorSession.Workspace.GetOpenedFiles(), - this.editorSession, - eventContext); - } - } - - await Task.FromResult(true); - } - - protected async Task HandleDefinitionRequest( - TextDocumentPosition textDocumentPosition, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleDefinitionRequest"); - await Task.FromResult(true); - } - - protected async Task HandleReferencesRequest( - ReferencesParams referencesParams, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleReferencesRequest"); - await Task.FromResult(true); - } - - protected async Task HandleCompletionRequest( - TextDocumentPosition textDocumentPosition, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleCompletionRequest"); - await Task.FromResult(true); - } - - protected async Task HandleCompletionResolveRequest( - CompletionItem completionItem, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleCompletionResolveRequest"); - await Task.FromResult(true); - } - - protected async Task HandleSignatureHelpRequest( - TextDocumentPosition textDocumentPosition, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleSignatureHelpRequest"); - await Task.FromResult(true); - } - - protected async Task HandleDocumentHighlightRequest( - TextDocumentPosition textDocumentPosition, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleDocumentHighlightRequest"); - await Task.FromResult(true); - } - - protected async Task HandleHoverRequest( - TextDocumentPosition textDocumentPosition, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleHoverRequest"); - await Task.FromResult(true); - } - - protected async Task HandleDocumentSymbolRequest( - TextDocumentIdentifier textDocumentIdentifier, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleDocumentSymbolRequest"); - await Task.FromResult(true); - } - - protected async Task HandleWorkspaceSymbolRequest( - WorkspaceSymbolParams workspaceSymbolParams, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleWorkspaceSymbolRequest"); - await Task.FromResult(true); - } - - /// - /// Runs script diagnostics on changed files - /// - /// - /// - /// - private Task RunScriptDiagnostics( - ScriptFile[] filesToAnalyze, - EditorSession editorSession, - EventContext eventContext) - { - if (!this.currentSettings.ScriptAnalysis.Enable.Value) - { - // If the user has disabled script analysis, skip it entirely - return Task.FromResult(true); - } - - // If there's an existing task, attempt to cancel it - try - { - if (existingRequestCancellation != null) - { - // Try to cancel the request - existingRequestCancellation.Cancel(); - - // If cancellation didn't throw an exception, - // clean up the existing token - existingRequestCancellation.Dispose(); - existingRequestCancellation = null; - } - } - catch (Exception e) - { - Logger.Write( - LogLevel.Error, - string.Format( - "Exception while cancelling analysis task:\n\n{0}", - e.ToString())); - - TaskCompletionSource cancelTask = new TaskCompletionSource(); - cancelTask.SetCanceled(); - return cancelTask.Task; - } - - // Create a fresh cancellation token and then start the task. - // We create this on a different TaskScheduler so that we - // don't block the main message loop thread. - existingRequestCancellation = new CancellationTokenSource(); - Task.Factory.StartNew( - () => - DelayThenInvokeDiagnostics( - 750, - filesToAnalyze, - editorSession, - eventContext, - existingRequestCancellation.Token), - CancellationToken.None, - TaskCreationOptions.None, - TaskScheduler.Default); - - return Task.FromResult(true); - } - - /// - /// Actually run the script diagnostics after waiting for some small delay - /// - /// - /// - /// - /// - /// - private static async Task DelayThenInvokeDiagnostics( - int delayMilliseconds, - ScriptFile[] filesToAnalyze, - EditorSession editorSession, - EventContext eventContext, - CancellationToken cancellationToken) - { - // First of all, wait for the desired delay period before - // analyzing the provided list of files - try - { - await Task.Delay(delayMilliseconds, cancellationToken); - } - catch (TaskCanceledException) - { - // If the task is cancelled, exit directly - return; - } - - // If we've made it past the delay period then we don't care - // about the cancellation token anymore. This could happen - // when the user stops typing for long enough that the delay - // period ends but then starts typing while analysis is going - // on. It makes sense to send back the results from the first - // delay period while the second one is ticking away. - - // Get the requested files - foreach (ScriptFile scriptFile in filesToAnalyze) - { - ScriptFileMarker[] semanticMarkers = null; - if (editorSession.LanguageService != null) - { - Logger.Write(LogLevel.Verbose, "Analyzing script file: " + scriptFile.FilePath); - semanticMarkers = editorSession.LanguageService.GetSemanticMarkers(scriptFile); - Logger.Write(LogLevel.Verbose, "Analysis complete."); - } - else - { - // Semantic markers aren't available if the AnalysisService - // isn't available - semanticMarkers = new ScriptFileMarker[0]; - } - - await PublishScriptDiagnostics( - scriptFile, - semanticMarkers, - eventContext); - } - } - - /// - /// Send the diagnostic results back to the host application - /// - /// - /// - /// - private static async Task PublishScriptDiagnostics( - ScriptFile scriptFile, - ScriptFileMarker[] semanticMarkers, - EventContext eventContext) - { - var allMarkers = scriptFile.SyntaxMarkers != null - ? scriptFile.SyntaxMarkers.Concat(semanticMarkers) - : semanticMarkers; - - // Always send syntax and semantic errors. We want to - // make sure no out-of-date markers are being displayed. - await eventContext.SendEvent( - PublishDiagnosticsNotification.Type, - new PublishDiagnosticsNotification - { - Uri = scriptFile.ClientFilePath, - Diagnostics = - allMarkers - .Select(GetDiagnosticFromMarker) - .ToArray() - }); - } - - /// - /// Convert a ScriptFileMarker to a Diagnostic that is Language Service compatible - /// - /// - /// - private static Diagnostic GetDiagnosticFromMarker(ScriptFileMarker scriptFileMarker) - { - return new Diagnostic - { - Severity = MapDiagnosticSeverity(scriptFileMarker.Level), - Message = scriptFileMarker.Message, - Range = new Range - { - // TODO: What offsets should I use? - Start = new Position - { - Line = scriptFileMarker.ScriptRegion.StartLineNumber - 1, - Character = scriptFileMarker.ScriptRegion.StartColumnNumber - 1 - }, - End = new Position - { - Line = scriptFileMarker.ScriptRegion.EndLineNumber - 1, - Character = scriptFileMarker.ScriptRegion.EndColumnNumber - 1 - } - } - }; - } - - /// - /// Map ScriptFileMarker severity to Diagnostic severity - /// - /// - private static DiagnosticSeverity MapDiagnosticSeverity(ScriptFileMarkerLevel markerLevel) - { - switch (markerLevel) - { - case ScriptFileMarkerLevel.Error: - return DiagnosticSeverity.Error; - - case ScriptFileMarkerLevel.Warning: - return DiagnosticSeverity.Warning; - - case ScriptFileMarkerLevel.Information: - return DiagnosticSeverity.Information; - - default: - return DiagnosticSeverity.Error; - } - } - - /// - /// Switch from 0-based offsets to 1 based offsets - /// - /// - /// - private static FileChange GetFileChangeDetails(Range changeRange, string insertString) - { - // The protocol's positions are zero-based so add 1 to all offsets - return new FileChange - { - InsertString = insertString, - Line = changeRange.Start.Line + 1, - Offset = changeRange.Start.Character + 1, - EndLine = changeRange.End.Line + 1, - EndOffset = changeRange.End.Character + 1 - }; - } - } -} diff --git a/src/ServiceHost/Server/LanguageServerBase.cs b/src/ServiceHost/Server/LanguageServerBase.cs deleted file mode 100644 index 0128484b..00000000 --- a/src/ServiceHost/Server/LanguageServerBase.cs +++ /dev/null @@ -1,84 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using Microsoft.SqlTools.EditorServices.Protocol.LanguageServer; -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Channel; -using System.Threading.Tasks; - -namespace Microsoft.SqlTools.EditorServices.Protocol.Server -{ - public abstract class LanguageServerBase : ProtocolEndpoint - { - private bool isStarted; - private ChannelBase serverChannel; - private TaskCompletionSource serverExitedTask; - - public LanguageServerBase(ChannelBase serverChannel) : - base(serverChannel, MessageProtocolType.LanguageServer) - { - this.serverChannel = serverChannel; - } - - protected override Task OnStart() - { - // Register handlers for server lifetime messages - this.SetRequestHandler(ShutdownRequest.Type, this.HandleShutdownRequest); - this.SetEventHandler(ExitNotification.Type, this.HandleExitNotification); - - // Initialize the implementation class - this.Initialize(); - - return Task.FromResult(true); - } - - protected override async Task OnStop() - { - await this.Shutdown(); - } - - /// - /// Overridden by the subclass to provide initialization - /// logic after the server channel is started. - /// - protected abstract void Initialize(); - - /// - /// Can be overridden by the subclass to provide shutdown - /// logic before the server exits. Subclasses do not need - /// to invoke or return the value of the base implementation. - /// - protected virtual Task Shutdown() - { - // No default implementation yet. - return Task.FromResult(true); - } - - private async Task HandleShutdownRequest( - object shutdownParams, - RequestContext requestContext) - { - // Allow the implementor to shut down gracefully - await this.Shutdown(); - - await requestContext.SendResult(new object()); - } - - private async Task HandleExitNotification( - object exitParams, - EventContext eventContext) - { - // Stop the server channel - await this.Stop(); - - // Notify any waiter that the server has exited - if (this.serverExitedTask != null) - { - this.serverExitedTask.SetResult(true); - } - } - } -} - diff --git a/src/ServiceHost/Session/EditorSession.cs b/src/ServiceHost/Session/EditorSession.cs deleted file mode 100644 index 3c592a8f..00000000 --- a/src/ServiceHost/Session/EditorSession.cs +++ /dev/null @@ -1,75 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; -using Microsoft.SqlTools.EditorServices.Session; -using Microsoft.SqlTools.LanguageSupport; - -namespace Microsoft.SqlTools.EditorServices -{ - /// - /// Manages a single session for all editor services. This - /// includes managing all open script files for the session. - /// - public class EditorSession : IDisposable - { - #region Properties - - /// - /// Gets the Workspace instance for this session. - /// - public Workspace Workspace { get; private set; } - - /// - /// Gets or sets the Language Service - /// - /// - public LanguageService LanguageService { get; set; } - - /// - /// Gets the SqlToolsContext instance for this session. - /// - public SqlToolsContext SqlToolsContext { get; private set; } - - #endregion - - #region Public Methods - - /// - /// Starts the session using the provided IConsoleHost implementation - /// for the ConsoleService. - /// - /// - /// Provides details about the host application. - /// - /// - /// An object containing the profile paths for the session. - /// - public void StartSession(HostDetails hostDetails, ProfilePaths profilePaths) - { - // Initialize all services - this.SqlToolsContext = new SqlToolsContext(hostDetails, profilePaths); - this.LanguageService = new LanguageService(this.SqlToolsContext); - - // Create a workspace to contain open files - this.Workspace = new Workspace(this.SqlToolsContext.SqlToolsVersion); - } - - #endregion - - #region IDisposable Implementation - - /// - /// Disposes of any Runspaces that were created for the - /// services used in this session. - /// - public void Dispose() - { - } - - #endregion - - } -} diff --git a/src/ServiceHost/Session/OutputType.cs b/src/ServiceHost/Session/OutputType.cs deleted file mode 100644 index 8ba866d7..00000000 --- a/src/ServiceHost/Session/OutputType.cs +++ /dev/null @@ -1,41 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -namespace Microsoft.SqlTools.EditorServices -{ - /// - /// Enumerates the types of output lines that will be sent - /// to an IConsoleHost implementation. - /// - public enum OutputType - { - /// - /// A normal output line, usually written with the or Write-Host or - /// Write-Output cmdlets. - /// - Normal, - - /// - /// A debug output line, written with the Write-Debug cmdlet. - /// - Debug, - - /// - /// A verbose output line, written with the Write-Verbose cmdlet. - /// - Verbose, - - /// - /// A warning output line, written with the Write-Warning cmdlet. - /// - Warning, - - /// - /// An error output line, written with the Write-Error cmdlet or - /// as a result of some error during SqlTools pipeline execution. - /// - Error - } -} diff --git a/src/ServiceHost/Session/OutputWrittenEventArgs.cs b/src/ServiceHost/Session/OutputWrittenEventArgs.cs deleted file mode 100644 index 4b1dbbe3..00000000 --- a/src/ServiceHost/Session/OutputWrittenEventArgs.cs +++ /dev/null @@ -1,65 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; - -namespace Microsoft.SqlTools.EditorServices -{ - /// - /// Provides details about output that has been written to the - /// SqlTools host. - /// - public class OutputWrittenEventArgs - { - /// - /// Gets the text of the output. - /// - public string OutputText { get; private set; } - - /// - /// Gets the type of the output. - /// - public OutputType OutputType { get; private set; } - - /// - /// Gets a boolean which indicates whether a newline - /// should be written after the output. - /// - public bool IncludeNewLine { get; private set; } - - /// - /// Gets the foreground color of the output text. - /// - public ConsoleColor ForegroundColor { get; private set; } - - /// - /// Gets the background color of the output text. - /// - public ConsoleColor BackgroundColor { get; private set; } - - /// - /// Creates an instance of the OutputWrittenEventArgs class. - /// - /// The text of the output. - /// A boolean which indicates whether a newline should be written after the output. - /// The type of the output. - /// The foreground color of the output text. - /// The background color of the output text. - public OutputWrittenEventArgs( - string outputText, - bool includeNewLine, - OutputType outputType, - ConsoleColor foregroundColor, - ConsoleColor backgroundColor) - { - this.OutputText = outputText; - this.IncludeNewLine = includeNewLine; - this.OutputType = outputType; - this.ForegroundColor = foregroundColor; - this.BackgroundColor = backgroundColor; - } - } -} - diff --git a/src/ServiceHost/project.json b/src/ServiceHost/project.json deleted file mode 100644 index 11340892..00000000 --- a/src/ServiceHost/project.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "version": "1.0.0-*", - "buildOptions": { - "debugType": "portable", - "emitEntryPoint": true - }, - "dependencies": { - "Newtonsoft.Json": "9.0.1" - }, - "frameworks": { - "netcoreapp1.0": { - "dependencies": { - "Microsoft.NETCore.App": { - "type": "platform", - "version": "1.0.0" - } - }, - "imports": "dnxcore50" - } - } -} diff --git a/test/CodeCoverage/ReplaceText.vbs b/test/CodeCoverage/ReplaceText.vbs new file mode 100644 index 00000000..31e27994 --- /dev/null +++ b/test/CodeCoverage/ReplaceText.vbs @@ -0,0 +1,55 @@ +' ReplaceText.vbs +' Copied from answer at http://stackoverflow.com/questions/1115508/batch-find-and-edit-lines-in-txt-file + +Option Explicit + +Const ForAppending = 8 +Const TristateFalse = 0 ' the value for ASCII +Const Overwrite = True + +Const WindowsFolder = 0 +Const SystemFolder = 1 +Const TemporaryFolder = 2 + +Dim FileSystem +Dim Filename, OldText, NewText +Dim OriginalFile, TempFile, Line +Dim TempFilename + +If WScript.Arguments.Count = 3 Then + Filename = WScript.Arguments.Item(0) + OldText = WScript.Arguments.Item(1) + NewText = WScript.Arguments.Item(2) +Else + Wscript.Echo "Usage: ReplaceText.vbs " + Wscript.Quit +End If + +Set FileSystem = CreateObject("Scripting.FileSystemObject") +Dim tempFolder: tempFolder = FileSystem.GetSpecialFolder(TemporaryFolder) +TempFilename = FileSystem.GetTempName + +If FileSystem.FileExists(TempFilename) Then + FileSystem.DeleteFile TempFilename +End If + +Set TempFile = FileSystem.CreateTextFile(TempFilename, Overwrite, TristateFalse) +Set OriginalFile = FileSystem.OpenTextFile(Filename) + +Do Until OriginalFile.AtEndOfStream + Line = OriginalFile.ReadLine + + If InStr(Line, OldText) > 0 Then + Line = Replace(Line, OldText, NewText) + End If + + TempFile.WriteLine(Line) +Loop + +OriginalFile.Close +TempFile.Close + +FileSystem.DeleteFile Filename +FileSystem.MoveFile TempFilename, Filename + +Wscript.Quit diff --git a/test/CodeCoverage/codecoverage.bat b/test/CodeCoverage/codecoverage.bat new file mode 100644 index 00000000..098ec1a1 --- /dev/null +++ b/test/CodeCoverage/codecoverage.bat @@ -0,0 +1,25 @@ +SET WORKINGDIR=%~dp0 + +REM clean-up results from previous run +RMDIR %WORKINGDIR%reports\ /S /Q +DEL %WORKINGDIR%coverage.xml +MKDIR reports + +REM backup current project.json +COPY /Y %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json.BAK + +REM switch PDB type to Full since that is required by OpenCover for now +REM we should remove this step on OpenCover supports portable PDB +cscript /nologo ReplaceText.vbs %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json portable full + +REM rebuild the SqlToolsService project +dotnet build %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json + +REM run the tests through OpenCover and generate a report +"%WORKINGDIR%packages\OpenCover.4.6.519\tools\OpenCover.Console.exe" -register:user -target:dotnet.exe -targetargs:"test %WORKINGDIR%..\Microsoft.SqlTools.ServiceLayer.Test\project.json" -oldstyle -filter:"+[Microsoft.SqlTools.*]* -[xunit*]*" -output:coverage.xml -searchdirs:%WORKINGDIR%..\Microsoft.SqlTools.ServiceLayer.Test\bin\Debug\netcoreapp1.0 +"%WORKINGDIR%packages\ReportGenerator.2.4.5.0\tools\ReportGenerator.exe" "-reports:coverage.xml" "-targetdir:%WORKINGDIR%\reports" + +REM restore original project.json +COPY /Y %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json.BAK %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json +DEL %WORKINGDIR%..\..\src\Microsoft.SqlTools.ServiceLayer\project.json.BAK +EXIT diff --git a/test/CodeCoverage/gulpfile.js b/test/CodeCoverage/gulpfile.js new file mode 100644 index 00000000..38b1576c --- /dev/null +++ b/test/CodeCoverage/gulpfile.js @@ -0,0 +1,107 @@ +var gulp = require('gulp'); +var del = require('del'); +var request = require('request'); +var fs = require('fs'); +var gutil = require('gulp-util'); +var through = require('through2'); +var cproc = require('child_process'); +var os = require('os'); + +function nugetRestoreArgs(nupkg, options) { + var args = new Array(); + if (os.platform() != 'win32') { + args.push('./nuget.exe'); + } + args.push('restore'); + args.push(nupkg); + + var withValues = [ + 'source', + 'configFile', + 'packagesDirectory', + 'solutionDirectory', + 'msBuildVersion' + ]; + + var withoutValues = [ + 'noCache', + 'requireConsent', + 'disableParallelProcessing' + ]; + + withValues.forEach(function(prop) { + var value = options[prop]; + if(value) { + args.push('-' + prop); + args.push(value); + } + }); + + withoutValues.forEach(function(prop) { + var value = options[prop]; + if(value) { + args.push('-' + prop); + } + }); + + args.push('-noninteractive'); + + return args; +}; + +function nugetRestore(options) { + options = options || {}; + options.nuget = options.nuget || './nuget.exe'; + if (os.platform() != 'win32') { + options.nuget = 'mono'; + } + + return through.obj(function(file, encoding, done) { + var args = nugetRestoreArgs(file.path, options); + cproc.execFile(options.nuget, args, function(err, stdout) { + if (err) { + throw new gutil.PluginError('gulp-nuget', err); + } + + gutil.log(stdout.trim()); + done(null, file); + }); + }); +}; + +gulp.task('ext:nuget-download', function(done) { + if(fs.existsSync('nuget.exe')) { + return done(); + } + + request.get('http://nuget.org/nuget.exe') + .pipe(fs.createWriteStream('nuget.exe')) + .on('close', done); +}); + +gulp.task('ext:nuget-restore', function() { + + var options = { + configFile: './nuget.config', + packagesDirectory: './packages' + }; + + return gulp.src('./packages.config') + .pipe(nugetRestore(options)); +}); + + +gulp.task('ext:code-coverage', function(done) { + cproc.execFile('cmd.exe', [ '/c', 'codecoverage.bat' ], function(err, stdout) { + if (err) { + throw new gutil.PluginError('ext:code-coverage', err); + } + + gutil.log(stdout.trim()); + }); + return done(); +}); + +gulp.task('test', gulp.series('ext:nuget-download', 'ext:nuget-restore', 'ext:code-coverage')); + +gulp.task('default', gulp.series('test')); diff --git a/test/CodeCoverage/nuget.config b/test/CodeCoverage/nuget.config new file mode 100644 index 00000000..1eab8195 --- /dev/null +++ b/test/CodeCoverage/nuget.config @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/test/CodeCoverage/package.json b/test/CodeCoverage/package.json new file mode 100644 index 00000000..0d4efd10 --- /dev/null +++ b/test/CodeCoverage/package.json @@ -0,0 +1,16 @@ +{ + "name": "sqltoolsservice", + "version": "0.1.0", + "description": "SQL Tools Service Layer", + "main": "gulpfile.js", + "dependencies": { + "gulp": "github:gulpjs/gulp#4.0", + "del": "^2.2.1", + "gulp-hub": "frankwallis/gulp-hub#registry-init", + "gulp-install": "^0.6.0", + "request": "^2.73.0" + }, + "devDependencies": {}, + "author": "Microsoft", + "license": "MIT" +} diff --git a/test/CodeCoverage/packages.config b/test/CodeCoverage/packages.config new file mode 100644 index 00000000..4a7355aa --- /dev/null +++ b/test/CodeCoverage/packages.config @@ -0,0 +1,5 @@ + + + + + diff --git a/test/ServiceHost.Test/App.config b/test/Microsoft.SqlTools.ServiceLayer.Test/App.config similarity index 100% rename from test/ServiceHost.Test/App.config rename to test/Microsoft.SqlTools.ServiceLayer.Test/App.config diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs new file mode 100644 index 00000000..72bafabf --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs @@ -0,0 +1,586 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Microsoft.SqlTools.Test.Utility; +using Moq; +using Moq.Protected; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Connection +{ + /// + /// Tests for the ServiceHost Connection Service tests + /// + public class ConnectionServiceTests + { + /// + /// Creates a mock db command that returns a predefined result set + /// + public static DbCommand CreateTestCommand(Dictionary[][] data) + { + var commandMock = new Mock { CallBase = true }; + var commandMockSetup = commandMock.Protected() + .Setup("ExecuteDbDataReader", It.IsAny()); + + commandMockSetup.Returns(new TestDbDataReader(data)); + + return commandMock.Object; + } + + /// + /// Creates a mock db connection that returns predefined data when queried for a result set + /// + public DbConnection CreateMockDbConnection(Dictionary[][] data) + { + var connectionMock = new Mock { CallBase = true }; + connectionMock.Protected() + .Setup("CreateDbCommand") + .Returns(CreateTestCommand(data)); + + return connectionMock.Object; + } + + /// Verify that we can connect to the default database when no database name is + /// provided as a parameter. + /// + [Theory] + [InlineDataAttribute(null)] + [InlineDataAttribute("")] + public void CanConnectWithEmptyDatabaseName(string databaseName) + { + // Connect + var connectionDetails = TestObjects.GetTestConnectionDetails(); + connectionDetails.DatabaseName = databaseName; + var connectionResult = + TestObjects.GetTestConnectionService() + .Connect(new ConnectParams() + { + OwnerUri = "file:///my/test/file.sql", + Connection = connectionDetails + }); + + // check that a connection was created + Assert.NotEmpty(connectionResult.ConnectionId); + } + + /// + /// Verify that when a connection is started for a URI with an already existing + /// connection, we disconnect first before connecting. + /// + [Fact] + public void ConnectingWhenConnectionExistCausesDisconnectThenConnect() + { + bool callbackInvoked = false; + + // first connect + string ownerUri = "file://my/sample/file.sql"; + var connectionService = TestObjects.GetTestConnectionService(); + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + // verify that we are connected + Assert.NotEmpty(connectionResult.ConnectionId); + + // register disconnect callback + connectionService.RegisterOnDisconnectTask( + (result) => { + callbackInvoked = true; + return Task.FromResult(true); + } + ); + + // send annother connect request + connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + // verify that the event was fired (we disconnected first before connecting) + Assert.True(callbackInvoked); + + // verify that we connected again + Assert.NotEmpty(connectionResult.ConnectionId); + } + + /// + /// Verify that when connecting with invalid credentials, an error is thrown. + /// + [Fact] + public void ConnectingWithInvalidCredentialsYieldsErrorMessage() + { + var testConnectionDetails = TestObjects.GetTestConnectionDetails(); + var invalidConnectionDetails = new ConnectionDetails(); + invalidConnectionDetails.ServerName = testConnectionDetails.ServerName; + invalidConnectionDetails.DatabaseName = testConnectionDetails.DatabaseName; + invalidConnectionDetails.UserName = "invalidUsername"; // triggers exception when opening mock connection + invalidConnectionDetails.Password = "invalidPassword"; + + // Connect to test db with invalid credentials + var connectionResult = + TestObjects.GetTestConnectionService() + .Connect(new ConnectParams() + { + OwnerUri = "file://my/sample/file.sql", + Connection = invalidConnectionDetails + }); + + // check that an error was caught + Assert.NotNull(connectionResult.Messages); + Assert.NotEqual(String.Empty, connectionResult.Messages); + } + + /// + /// Verify that when connecting with invalid parameters, an error is thrown. + /// + [Theory] + [InlineData("SqlLogin", null, "my-server", "test", "sa", "123456")] + [InlineData("SqlLogin", "file://my/sample/file.sql", null, "test", "sa", "123456")] + [InlineData("SqlLogin", "file://my/sample/file.sql", "my-server", "test", null, "123456")] + [InlineData("SqlLogin", "file://my/sample/file.sql", "my-server", "test", "sa", null)] + [InlineData("SqlLogin", "", "my-server", "test", "sa", "123456")] + [InlineData("SqlLogin", "file://my/sample/file.sql", "", "test", "sa", "123456")] + [InlineData("SqlLogin", "file://my/sample/file.sql", "my-server", "test", "", "123456")] + [InlineData("SqlLogin", "file://my/sample/file.sql", "my-server", "test", "sa", "")] + [InlineData("Integrated", null, "my-server", "test", "sa", "123456")] + [InlineData("Integrated", "file://my/sample/file.sql", null, "test", "sa", "123456")] + [InlineData("Integrated", "", "my-server", "test", "sa", "123456")] + [InlineData("Integrated", "file://my/sample/file.sql", "", "test", "sa", "123456")] + public void ConnectingWithInvalidParametersYieldsErrorMessage(string authType, string ownerUri, string server, string database, string userName, string password) + { + // Connect with invalid parameters + var connectionResult = + TestObjects.GetTestConnectionService() + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = new ConnectionDetails() { + ServerName = server, + DatabaseName = database, + UserName = userName, + Password = password, + AuthenticationType = authType + } + }); + + // check that an error was caught + Assert.NotNull(connectionResult.Messages); + Assert.NotEqual(String.Empty, connectionResult.Messages); + } + + /// + /// Verify that when using integrated authentication, the username and/or password can be empty. + /// + [Theory] + [InlineData(null, null)] + [InlineData(null, "")] + [InlineData("", null)] + [InlineData("", "")] + [InlineData("sa", null)] + [InlineData("sa", "")] + [InlineData(null, "12345678")] + [InlineData("", "12345678")] + public void ConnectingWithNoUsernameOrPasswordWorksForIntegratedAuth(string userName, string password) + { + // Connect + var connectionResult = + TestObjects.GetTestConnectionService() + .Connect(new ConnectParams() + { + OwnerUri = "file:///my/test/file.sql", + Connection = new ConnectionDetails() { + ServerName = "my-server", + DatabaseName = "test", + UserName = userName, + Password = password, + AuthenticationType = "Integrated" + } + }); + + // check that the connection was successful + Assert.NotEmpty(connectionResult.ConnectionId); + Assert.Null(connectionResult.Messages); + } + + /// + /// Verify that when connecting with a null parameters object, an error is thrown. + /// + [Fact] + public void ConnectingWithNullParametersObjectYieldsErrorMessage() + { + // Connect with null parameters + var connectionResult = + TestObjects.GetTestConnectionService() + .Connect(null); + + // check that an error was caught + Assert.NotNull(connectionResult.Messages); + Assert.NotEqual(String.Empty, connectionResult.Messages); + } + + /// + /// Verify that optional parameters can be built into a connection string for connecting. + /// + [Theory] + [InlineData("AuthenticationType", "Integrated", "Integrated Security")] + [InlineData("AuthenticationType", "SqlLogin", "")] + [InlineData("Encrypt", true, "Encrypt")] + [InlineData("Encrypt", false, "Encrypt")] + [InlineData("TrustServerCertificate", true, "TrustServerCertificate")] + [InlineData("TrustServerCertificate", false, "TrustServerCertificate")] + [InlineData("PersistSecurityInfo", true, "Persist Security Info")] + [InlineData("PersistSecurityInfo", false, "Persist Security Info")] + [InlineData("ConnectTimeout", 15, "Connect Timeout")] + [InlineData("ConnectRetryCount", 1, "ConnectRetryCount")] + [InlineData("ConnectRetryInterval", 10, "ConnectRetryInterval")] + [InlineData("ApplicationName", "vscode-mssql", "Application Name")] + [InlineData("WorkstationId", "mycomputer", "Workstation ID")] + [InlineData("ApplicationIntent", "ReadWrite", "ApplicationIntent")] + [InlineData("ApplicationIntent", "ReadOnly", "ApplicationIntent")] + [InlineData("CurrentLanguage", "test", "Current Language")] + [InlineData("Pooling", false, "Pooling")] + [InlineData("Pooling", true, "Pooling")] + [InlineData("MaxPoolSize", 100, "Max Pool Size")] + [InlineData("MinPoolSize", 0, "Min Pool Size")] + [InlineData("LoadBalanceTimeout", 0, "Load Balance Timeout")] + [InlineData("Replication", true, "Replication")] + [InlineData("Replication", false, "Replication")] + [InlineData("AttachDbFilename", "myfile", "AttachDbFilename")] + [InlineData("FailoverPartner", "partner", "Failover Partner")] + [InlineData("MultiSubnetFailover", true, "MultiSubnetFailover")] + [InlineData("MultiSubnetFailover", false, "MultiSubnetFailover")] + [InlineData("MultipleActiveResultSets", false, "MultipleActiveResultSets")] + [InlineData("MultipleActiveResultSets", true, "MultipleActiveResultSets")] + [InlineData("PacketSize", 8192, "Packet Size")] + [InlineData("TypeSystemVersion", "Latest", "Type System Version")] + public void ConnectingWithOptionalParametersBuildsConnectionString(string propertyName, object propertyValue, string connectionStringMarker) + { + // Create a test connection details object and set the property to a specific value + ConnectionDetails details = TestObjects.GetTestConnectionDetails(); + PropertyInfo info = details.GetType().GetProperty(propertyName); + info.SetValue(details, propertyValue); + + // Test that a connection string can be created without exceptions + string connectionString = ConnectionService.BuildConnectionString(details); + Assert.NotNull(connectionString); + Assert.NotEmpty(connectionString); + + // Verify that the parameter is in the connection string + Assert.True(connectionString.Contains(connectionStringMarker)); + } + + /// + /// Verify that a connection changed event is fired when the database context changes. + /// + [Fact] + public void ConnectionChangedEventIsFiredWhenDatabaseContextChanges() + { + var serviceHostMock = new Mock(); + + var connectionService = TestObjects.GetTestConnectionService(); + connectionService.ServiceHost = serviceHostMock.Object; + + // Set up an initial connection + string ownerUri = "file://my/sample/file.sql"; + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + // verify that a valid connection id was returned + Assert.NotEmpty(connectionResult.ConnectionId); + + ConnectionInfo info; + Assert.True(connectionService.TryFindConnection(ownerUri, out info)); + + // Tell the connection manager that the database change ocurred + connectionService.ChangeConnectionDatabaseContext(ownerUri, "myOtherDb"); + + // Verify that the connection changed event was fired + serviceHostMock.Verify(x => x.SendEvent(ConnectionChangedNotification.Type, It.IsAny()), Times.Once()); + } + + /// + /// Verify that the SQL parser correctly detects errors in text + /// + [Fact] + public void ConnectToDatabaseTest() + { + // connect to a database instance + string ownerUri = "file://my/sample/file.sql"; + var connectionResult = + TestObjects.GetTestConnectionService() + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + // verify that a valid connection id was returned + Assert.NotEmpty(connectionResult.ConnectionId); + } + + /// + /// Verify that we can disconnect from an active connection succesfully + /// + [Fact] + public void DisconnectFromDatabaseTest() + { + // first connect + string ownerUri = "file://my/sample/file.sql"; + var connectionService = TestObjects.GetTestConnectionService(); + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + // verify that we are connected + Assert.NotEmpty(connectionResult.ConnectionId); + + // send disconnect request + var disconnectResult = + connectionService + .Disconnect(new DisconnectParams() + { + OwnerUri = ownerUri + }); + Assert.True(disconnectResult); + } + + /// + /// Test that when a disconnect is performed, the callback event is fired + /// + [Fact] + public void DisconnectFiresCallbackEvent() + { + bool callbackInvoked = false; + + // first connect + string ownerUri = "file://my/sample/file.sql"; + var connectionService = TestObjects.GetTestConnectionService(); + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + // verify that we are connected + Assert.NotEmpty(connectionResult.ConnectionId); + + // register disconnect callback + connectionService.RegisterOnDisconnectTask( + (result) => { + callbackInvoked = true; + return Task.FromResult(true); + } + ); + + // send disconnect request + var disconnectResult = + connectionService + .Disconnect(new DisconnectParams() + { + OwnerUri = ownerUri + }); + Assert.True(disconnectResult); + + // verify that the event was fired + Assert.True(callbackInvoked); + } + + /// + /// Test that disconnecting an active connection removes the Owner URI -> ConnectionInfo mapping + /// + [Fact] + public void DisconnectRemovesOwnerMapping() + { + // first connect + string ownerUri = "file://my/sample/file.sql"; + var connectionService = TestObjects.GetTestConnectionService(); + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + // verify that we are connected + Assert.NotEmpty(connectionResult.ConnectionId); + + // check that the owner mapping exists + ConnectionInfo info; + Assert.True(connectionService.TryFindConnection(ownerUri, out info)); + + // send disconnect request + var disconnectResult = + connectionService + .Disconnect(new DisconnectParams() + { + OwnerUri = ownerUri + }); + Assert.True(disconnectResult); + + // check that the owner mapping no longer exists + Assert.False(connectionService.TryFindConnection(ownerUri, out info)); + } + + /// + /// Test that disconnecting validates parameters and doesn't succeed when they are invalid + /// + [Theory] + [InlineDataAttribute(null)] + [InlineDataAttribute("")] + + public void DisconnectValidatesParameters(string disconnectUri) + { + // first connect + string ownerUri = "file://my/sample/file.sql"; + var connectionService = TestObjects.GetTestConnectionService(); + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + // verify that we are connected + Assert.NotEmpty(connectionResult.ConnectionId); + + // send disconnect request + var disconnectResult = + connectionService + .Disconnect(new DisconnectParams() + { + OwnerUri = disconnectUri + }); + + // verify that disconnect failed + Assert.False(disconnectResult); + } + + /// + /// Verifies the the list databases operation lists database names for the server used by a connection. + /// + [Fact] + public void ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames() + { + // Result set for the query of database names + Dictionary[] data = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "model" } }, + new Dictionary { {"name", "msdb" } }, + new Dictionary { {"name", "tempdb" } }, + new Dictionary { {"name", "mydatabase" } }, + }; + + // Setup mock connection factory to inject query results + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateMockDbConnection(new[] {data})); + var connectionService = new ConnectionService(mockFactory.Object); + + // connect to a database instance + string ownerUri = "file://my/sample/file.sql"; + var connectionResult = + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + // verify that a valid connection id was returned + Assert.NotEmpty(connectionResult.ConnectionId); + + // list databases for the connection + ListDatabasesParams parameters = new ListDatabasesParams(); + parameters.OwnerUri = ownerUri; + var listDatabasesResult = connectionService.ListDatabases(parameters); + string[] databaseNames = listDatabasesResult.DatabaseNames; + + Assert.Equal(databaseNames.Length, 5); + Assert.Equal(databaseNames[0], "master"); + Assert.Equal(databaseNames[1], "model"); + Assert.Equal(databaseNames[2], "msdb"); + Assert.Equal(databaseNames[3], "tempdb"); + Assert.Equal(databaseNames[4], "mydatabase"); + } + + /// + /// Verify that the SQL parser correctly detects errors in text + /// + [Fact] + public void OnConnectionCallbackHandlerTest() + { + bool callbackInvoked = false; + + // setup connection service with callback + var connectionService = TestObjects.GetTestConnectionService(); + connectionService.RegisterOnConnectionTask( + (sqlConnection) => { + callbackInvoked = true; + return Task.FromResult(true); + } + ); + + // connect to a database instance + var connectionResult = connectionService.Connect(TestObjects.GetTestConnectionParams()); + + // verify that a valid connection id was returned + Assert.True(callbackInvoked); + } + + /// + /// Verify when a connection is created that the URI -> Connection mapping is created in the connection service. + /// + [Fact] + public void TestConnectRequestRegistersOwner() + { + // Given a request to connect to a database + var service = TestObjects.GetTestConnectionService(); + var connectParams = TestObjects.GetTestConnectionParams(); + + // connect to a database instance + var connectionResult = service.Connect(connectParams); + + // verify that a valid connection id was returned + Assert.NotNull(connectionResult.ConnectionId); + Assert.NotEqual(String.Empty, connectionResult.ConnectionId); + Assert.NotNull(new Guid(connectionResult.ConnectionId)); + + // verify that the (URI -> connection) mapping was created + ConnectionInfo info; + Assert.True(service.TryFindConnection(connectParams.OwnerUri, out info)); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/CredentialServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/CredentialServiceTests.cs new file mode 100644 index 00000000..7adbdebe --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/CredentialServiceTests.cs @@ -0,0 +1,287 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Credentials; +using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts; +using Microsoft.SqlTools.ServiceLayer.Credentials.Linux; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Connection +{ + /// + /// Credential Service tests that should pass on all platforms, regardless of backing store. + /// These tests run E2E, storing values in the native credential store for whichever platform + /// tests are being run on + /// + public class CredentialServiceTests : IDisposable + { + private static readonly LinuxCredentialStore.StoreConfig config = new LinuxCredentialStore.StoreConfig() + { + CredentialFolder = ".testsecrets", + CredentialFile = "sqltestsecrets.json", + IsRelativeToUserHomeDir = true + }; + + const string credentialId = "Microsoft_SqlToolsTest_TestId"; + const string password1 = "P@ssw0rd1"; + const string password2 = "2Pass2Furious"; + + const string otherCredId = credentialId + "2345"; + const string otherPassword = credentialId + "2345"; + + // Test-owned credential store used to clean up before/after tests to ensure code works as expected + // even if previous runs stopped midway through + private ICredentialStore credStore; + private CredentialService service; + /// + /// Constructor called once for every test + /// + public CredentialServiceTests() + { + credStore = CredentialService.GetStoreForOS(config); + service = new CredentialService(credStore, config); + DeleteDefaultCreds(); + } + + public void Dispose() + { + DeleteDefaultCreds(); + } + + private void DeleteDefaultCreds() + { + credStore.DeletePassword(credentialId); + credStore.DeletePassword(otherCredId); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + string credsFolder = ((LinuxCredentialStore)credStore).CredentialFolderPath; + if (Directory.Exists(credsFolder)) + { + Directory.Delete(credsFolder, true); + } + } + } + + [Fact] + public async Task SaveCredentialThrowsIfCredentialIdMissing() + { + object errorResponse = null; + var contextMock = RequestContextMocks.Create(null).AddErrorHandling(obj => errorResponse = obj); + + await service.HandleSaveCredentialRequest(new Credential(null), contextMock.Object); + VerifyErrorSent(contextMock); + Assert.True(((string)errorResponse).Contains("ArgumentException")); + } + + [Fact] + public async Task SaveCredentialThrowsIfPasswordMissing() + { + object errorResponse = null; + var contextMock = RequestContextMocks.Create(null).AddErrorHandling(obj => errorResponse = obj); + + await service.HandleSaveCredentialRequest(new Credential(credentialId), contextMock.Object); + VerifyErrorSent(contextMock); + Assert.True(((string)errorResponse).Contains("ArgumentException")); + } + + [Fact] + public async Task SaveCredentialWorksForSingleCredential() + { + await RunAndVerify( + test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext), + verify: (actual => Assert.True(actual))); + } + + [Fact] + public async Task SaveCredentialSupportsSavingCredentialMultipleTimes() + { + await RunAndVerify( + test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext), + verify: (actual => Assert.True(actual))); + + await RunAndVerify( + test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext), + verify: (actual => Assert.True(actual))); + } + + [Fact] + public async Task ReadCredentialWorksForSingleCredential() + { + // Given we have saved the credential + await RunAndVerify( + test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext), + verify: (actual => Assert.True(actual, "Expect Credential to be saved successfully"))); + + + // Expect read of the credential to return the password + await RunAndVerify( + test: (requestContext) => service.HandleReadCredentialRequest(new Credential(credentialId, null), requestContext), + verify: (actual => + { + Assert.Equal(password1, actual.Password); + })); + } + + [Fact] + public async Task ReadCredentialWorksForMultipleCredentials() + { + + // Given we have saved multiple credentials + await RunAndVerify( + test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext), + verify: (actual => Assert.True(actual, "Expect Credential to be saved successfully"))); + await RunAndVerify( + test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(otherCredId, otherPassword), requestContext), + verify: (actual => Assert.True(actual, "Expect Credential to be saved successfully"))); + + + // Expect read of the credentials to return the right password + await RunAndVerify( + test: (requestContext) => service.HandleReadCredentialRequest(new Credential(credentialId, null), requestContext), + verify: (actual => + { + Assert.Equal(password1, actual.Password); + })); + await RunAndVerify( + test: (requestContext) => service.HandleReadCredentialRequest(new Credential(otherCredId, null), requestContext), + verify: (actual => + { + Assert.Equal(otherPassword, actual.Password); + })); + } + + [Fact] + public async Task ReadCredentialHandlesPasswordUpdate() + { + // Given we have saved twice with a different password + await RunAndVerify( + test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext), + verify: (actual => Assert.True(actual))); + + await RunAndVerify( + test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password2), requestContext), + verify: (actual => Assert.True(actual))); + + // When we read the value for this credential + // Then we expect only the last saved password to be found + await RunAndVerify( + test: (requestContext) => service.HandleReadCredentialRequest(new Credential(credentialId), requestContext), + verify: (actual => + { + Assert.Equal(password2, actual.Password); + })); + } + + [Fact] + public async Task ReadCredentialThrowsIfCredentialIsNull() + { + object errorResponse = null; + var contextMock = RequestContextMocks.Create(null).AddErrorHandling(obj => errorResponse = obj); + + // Verify throws on null, and this is sent as an error + await service.HandleReadCredentialRequest(null, contextMock.Object); + VerifyErrorSent(contextMock); + Assert.True(((string)errorResponse).Contains("ArgumentNullException")); + } + + [Fact] + public async Task ReadCredentialThrowsIfIdMissing() + { + object errorResponse = null; + var contextMock = RequestContextMocks.Create(null).AddErrorHandling(obj => errorResponse = obj); + + // Verify throws with no ID + await service.HandleReadCredentialRequest(new Credential(), contextMock.Object); + VerifyErrorSent(contextMock); + Assert.True(((string)errorResponse).Contains("ArgumentException")); + } + + [Fact] + public async Task ReadCredentialReturnsNullPasswordForMissingCredential() + { + // Given a credential whose password doesn't exist + string credWithNoPassword = "Microsoft_SqlTools_CredThatDoesNotExist"; + + // When reading the credential + // Then expect the credential to be returned but password left blank + await RunAndVerify( + test: (requestContext) => service.HandleReadCredentialRequest(new Credential(credWithNoPassword, null), requestContext), + verify: (actual => + { + Assert.NotNull(actual); + Assert.Equal(credWithNoPassword, actual.CredentialId); + Assert.Null(actual.Password); + })); + } + + [Fact] + public async Task DeleteCredentialThrowsIfIdMissing() + { + object errorResponse = null; + var contextMock = RequestContextMocks.Create(null).AddErrorHandling(obj => errorResponse = obj); + + // Verify throws with no ID + await service.HandleDeleteCredentialRequest(new Credential(), contextMock.Object); + VerifyErrorSent(contextMock); + Assert.True(((string)errorResponse).Contains("ArgumentException")); + } + + [Fact] + public async Task DeleteCredentialReturnsTrueOnlyIfCredentialExisted() + { + // Save should be true + await RunAndVerify( + test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext), + verify: (actual => Assert.True(actual))); + + // Then delete - should return true + await RunAndVerify( + test: (requestContext) => service.HandleDeleteCredentialRequest(new Credential(credentialId), requestContext), + verify: (actual => Assert.True(actual))); + + // Then delete - should return false as no longer exists + await RunAndVerify( + test: (requestContext) => service.HandleDeleteCredentialRequest(new Credential(credentialId), requestContext), + verify: (actual => Assert.False(actual))); + } + + private async Task RunAndVerify(Func, Task> test, Action verify) + { + T result = default(T); + var contextMock = RequestContextMocks.Create(r => result = r).AddErrorHandling(null); + await test(contextMock.Object); + VerifyResult(contextMock, verify, result); + } + + private void VerifyErrorSent(Mock> contextMock) + { + contextMock.Verify(c => c.SendResult(It.IsAny()), Times.Never); + contextMock.Verify(c => c.SendError(It.IsAny()), Times.Once); + } + + private void VerifyResult(Mock> contextMock, U expected, U actual) + { + contextMock.Verify(c => c.SendResult(It.IsAny()), Times.Once); + Assert.Equal(expected, actual); + contextMock.Verify(c => c.SendError(It.IsAny()), Times.Never); + } + + private void VerifyResult(Mock> contextMock, Action verify, T actual) + { + contextMock.Verify(c => c.SendResult(It.IsAny()), Times.Once); + contextMock.Verify(c => c.SendError(It.IsAny()), Times.Never); + verify(actual); + } + + } +} + diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Linux/LinuxInteropTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Linux/LinuxInteropTests.cs new file mode 100644 index 00000000..1dcff8e6 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Linux/LinuxInteropTests.cs @@ -0,0 +1,37 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Credentials; +using Microsoft.SqlTools.ServiceLayer.Credentials.Linux; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Credentials +{ + public class LinuxInteropTests + { + [Fact] + public void GetEUidReturnsInt() + { + TestUtils.RunIfLinux(() => + { + Assert.NotNull(Interop.Sys.GetEUid()); + }); + } + + [Fact] + public void GetHomeDirectoryFromPwFindsHomeDir() + { + + TestUtils.RunIfLinux(() => + { + string userDir = LinuxCredentialStore.GetHomeDirectoryFromPw(); + Assert.StartsWith("/", userDir); + }); + } + + } +} + diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Win32/CredentialSetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Win32/CredentialSetTests.cs new file mode 100644 index 00000000..19edb663 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Win32/CredentialSetTests.cs @@ -0,0 +1,99 @@ +// +// Code originally from http://credentialmanagement.codeplex.com/, +// Licensed under the Apache License 2.0 +// + +using System; +using Microsoft.SqlTools.ServiceLayer.Credentials.Win32; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Credentials +{ + public class CredentialSetTests + { + [Fact] + public void CredentialSetCreate() + { + TestUtils.RunIfWindows(() => + { + Assert.NotNull(new CredentialSet()); + }); + } + + [Fact] + public void CredentialSetCreateWithTarget() + { + TestUtils.RunIfWindows(() => + { + Assert.NotNull(new CredentialSet("target")); + }); + } + + [Fact] + public void CredentialSetShouldBeIDisposable() + { + TestUtils.RunIfWindows(() => + { + Assert.True(new CredentialSet() is IDisposable, "CredentialSet needs to implement IDisposable Interface."); + }); + } + + [Fact] + public void CredentialSetLoad() + { + TestUtils.RunIfWindows(() => + { + Win32Credential credential = new Win32Credential + { + Username = "username", + Password = "password", + Target = "target", + Type = CredentialType.Generic + }; + credential.Save(); + + CredentialSet set = new CredentialSet(); + set.Load(); + Assert.NotNull(set); + Assert.NotEmpty(set); + + credential.Delete(); + + set.Dispose(); + }); + } + + [Fact] + public void CredentialSetLoadShouldReturnSelf() + { + TestUtils.RunIfWindows(() => + { + CredentialSet set = new CredentialSet(); + Assert.IsType(set.Load()); + + set.Dispose(); + }); + } + + [Fact] + public void CredentialSetLoadWithTargetFilter() + { + TestUtils.RunIfWindows(() => + { + Win32Credential credential = new Win32Credential + { + Username = "filteruser", + Password = "filterpassword", + Target = "filtertarget" + }; + credential.Save(); + + CredentialSet set = new CredentialSet("filtertarget"); + Assert.Equal(1, set.Load().Count); + set.Dispose(); + }); + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Win32/Win32CredentialTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Win32/Win32CredentialTests.cs new file mode 100644 index 00000000..ae3f18b2 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Credentials/Win32/Win32CredentialTests.cs @@ -0,0 +1,145 @@ +// +// Code originally from http://credentialmanagement.codeplex.com/, +// Licensed under the Apache License 2.0 +// + +using System; +using Microsoft.SqlTools.ServiceLayer.Credentials.Win32; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Credentials +{ + public class Win32CredentialTests + { + [Fact] + public void Credential_Create_ShouldNotThrowNull() + { + TestUtils.RunIfWindows(() => + { + Assert.NotNull(new Win32Credential()); + }); + } + + [Fact] + public void Credential_Create_With_Username_ShouldNotThrowNull() + { + TestUtils.RunIfWindows(() => + { + Assert.NotNull(new Win32Credential("username")); + }); + } + + [Fact] + public void Credential_Create_With_Username_And_Password_ShouldNotThrowNull() + { + TestUtils.RunIfWindows(() => + { + Assert.NotNull(new Win32Credential("username", "password")); + }); + } + + [Fact] + public void Credential_Create_With_Username_Password_Target_ShouldNotThrowNull() + { + TestUtils.RunIfWindows(() => + { + Assert.NotNull(new Win32Credential("username", "password", "target")); + }); + } + + [Fact] + public void Credential_ShouldBe_IDisposable() + { + TestUtils.RunIfWindows(() => + { + Assert.True(new Win32Credential() is IDisposable, "Credential should implement IDisposable Interface."); + }); + } + + [Fact] + public void Credential_Dispose_ShouldNotThrowException() + { + TestUtils.RunIfWindows(() => + { + new Win32Credential().Dispose(); + }); + } + + [Fact] + public void Credential_ShouldThrowObjectDisposedException() + { + TestUtils.RunIfWindows(() => + { + Win32Credential disposed = new Win32Credential { Password = "password" }; + disposed.Dispose(); + Assert.Throws(() => disposed.Username = "username"); + }); + } + + [Fact] + public void Credential_Save() + { + TestUtils.RunIfWindows(() => + { + Win32Credential saved = new Win32Credential("username", "password", "target", CredentialType.Generic); + saved.PersistanceType = PersistanceType.LocalComputer; + Assert.True(saved.Save()); + }); + } + + [Fact] + public void Credential_Delete() + { + TestUtils.RunIfWindows(() => + { + new Win32Credential("username", "password", "target").Save(); + Assert.True(new Win32Credential("username", "password", "target").Delete()); + }); + } + + [Fact] + public void Credential_Delete_NullTerminator() + { + TestUtils.RunIfWindows(() => + { + Win32Credential credential = new Win32Credential((string)null, (string)null, "\0", CredentialType.None); + credential.Description = (string)null; + Assert.False(credential.Delete()); + }); + } + + [Fact] + public void Credential_Load() + { + TestUtils.RunIfWindows(() => + { + Win32Credential setup = new Win32Credential("username", "password", "target", CredentialType.Generic); + setup.Save(); + + Win32Credential credential = new Win32Credential { Target = "target", Type = CredentialType.Generic }; + Assert.True(credential.Load()); + + Assert.NotEmpty(credential.Username); + Assert.NotNull(credential.Password); + Assert.Equal("username", credential.Username); + Assert.Equal("password", credential.Password); + Assert.Equal("target", credential.Target); + }); + } + + [Fact] + public void Credential_Exists_Target_ShouldNotBeNull() + { + TestUtils.RunIfWindows(() => + { + new Win32Credential { Username = "username", Password = "password", Target = "target" }.Save(); + + Win32Credential existingCred = new Win32Credential { Target = "target" }; + Assert.True(existingCred.Exists()); + + existingCred.Delete(); + }); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs new file mode 100644 index 00000000..0725c209 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs @@ -0,0 +1,205 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Data.SqlClient; +using System.IO; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlServer.Management.SmoMetadataProvider; +using Microsoft.SqlServer.Management.SqlParser; +using Microsoft.SqlServer.Management.SqlParser.Binder; +using Microsoft.SqlServer.Management.SqlParser.Intellisense; +using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using Microsoft.SqlServer.Management.SqlParser.Parser; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Microsoft.SqlTools.Test.Utility; +using Moq; +using Moq.Protected; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices +{ + /// + /// Tests for the ServiceHost Language Service tests + /// + public class LanguageServiceTests + { + #region "Diagnostics tests" + + /// + /// Verify that the latest SqlParser (2016 as of this writing) is used by default + /// + [Fact] + public void LatestSqlParserIsUsedByDefault() + { + // This should only parse correctly on SQL server 2016 or newer + const string sql2016Text = + @"CREATE SECURITY POLICY [FederatedSecurityPolicy]" + "\r\n" + + @"ADD FILTER PREDICATE [rls].[fn_securitypredicate]([CustomerId])" + "\r\n" + + @"ON [dbo].[Customer];"; + + LanguageService service = TestObjects.GetTestLanguageService(); + + // parse + var scriptFile = new ScriptFile(); + scriptFile.SetFileContents(sql2016Text); + ScriptFileMarker[] fileMarkers = service.GetSemanticMarkers(scriptFile); + + // verify that no errors are detected + Assert.Equal(0, fileMarkers.Length); + } + + /// + /// Verify that the SQL parser correctly detects errors in text + /// + [Fact] + public void ParseSelectStatementWithoutErrors() + { + // sql statement with no errors + const string sqlWithErrors = "SELECT * FROM sys.objects"; + + // get the test service + LanguageService service = TestObjects.GetTestLanguageService(); + + // parse the sql statement + var scriptFile = new ScriptFile(); + scriptFile.SetFileContents(sqlWithErrors); + ScriptFileMarker[] fileMarkers = service.GetSemanticMarkers(scriptFile); + + // verify there are no errors + Assert.Equal(0, fileMarkers.Length); + } + + /// + /// Verify that the SQL parser correctly detects errors in text + /// + [Fact] + public void ParseSelectStatementWithError() + { + // sql statement with errors + const string sqlWithErrors = "SELECT *** FROM sys.objects"; + + // get test service + LanguageService service = TestObjects.GetTestLanguageService(); + + // parse sql statement + var scriptFile = new ScriptFile(); + scriptFile.SetFileContents(sqlWithErrors); + ScriptFileMarker[] fileMarkers = service.GetSemanticMarkers(scriptFile); + + // verify there is one error + Assert.Equal(1, fileMarkers.Length); + + // verify the position of the error + Assert.Equal(9, fileMarkers[0].ScriptRegion.StartColumnNumber); + Assert.Equal(1, fileMarkers[0].ScriptRegion.StartLineNumber); + Assert.Equal(10, fileMarkers[0].ScriptRegion.EndColumnNumber); + Assert.Equal(1, fileMarkers[0].ScriptRegion.EndLineNumber); + } + + /// + /// Verify that the SQL parser correctly detects errors in text + /// + [Fact] + public void ParseMultilineSqlWithErrors() + { + // multiline sql with errors + const string sqlWithErrors = + "SELECT *** FROM sys.objects;\n" + + "GO\n" + + "SELECT *** FROM sys.objects;\n"; + + // get test service + LanguageService service = TestObjects.GetTestLanguageService(); + + // parse sql + var scriptFile = new ScriptFile(); + scriptFile.SetFileContents(sqlWithErrors); + ScriptFileMarker[] fileMarkers = service.GetSemanticMarkers(scriptFile); + + // verify there are two errors + Assert.Equal(2, fileMarkers.Length); + + // check position of first error + Assert.Equal(9, fileMarkers[0].ScriptRegion.StartColumnNumber); + Assert.Equal(1, fileMarkers[0].ScriptRegion.StartLineNumber); + Assert.Equal(10, fileMarkers[0].ScriptRegion.EndColumnNumber); + Assert.Equal(1, fileMarkers[0].ScriptRegion.EndLineNumber); + + // check position of second error + Assert.Equal(9, fileMarkers[1].ScriptRegion.StartColumnNumber); + Assert.Equal(3, fileMarkers[1].ScriptRegion.StartLineNumber); + Assert.Equal(10, fileMarkers[1].ScriptRegion.EndColumnNumber); + Assert.Equal(3, fileMarkers[1].ScriptRegion.EndLineNumber); + } + + #endregion + + #region "Autocomplete Tests" + + // This test currently requires a live database connection to initialize + // SMO connected metadata provider. Since we don't want a live DB dependency + // in the CI unit tests this scenario is currently disabled. + //[Fact] + public void AutoCompleteFindCompletions() + { + TextDocumentPosition textDocument; + ConnectionInfo connInfo; + ScriptFile scriptFile; + Common.GetAutoCompleteTestObjects(out textDocument, out scriptFile, out connInfo); + + textDocument.Position.Character = 7; + scriptFile.Contents = "select "; + + var autoCompleteService = AutoCompleteService.Instance; + var completions = autoCompleteService.GetCompletionItems( + textDocument, + scriptFile, + connInfo); + + Assert.True(completions.Length > 0); + } + + /// + /// Creates a mock db command that returns a predefined result set + /// + public static DbCommand CreateTestCommand(Dictionary[][] data) + { + var commandMock = new Mock { CallBase = true }; + var commandMockSetup = commandMock.Protected() + .Setup("ExecuteDbDataReader", It.IsAny()); + + commandMockSetup.Returns(new TestDbDataReader(data)); + + return commandMock.Object; + } + + /// + /// Creates a mock db connection that returns predefined data when queried for a result set + /// + public DbConnection CreateMockDbConnection(Dictionary[][] data) + { + var connectionMock = new Mock { CallBase = true }; + connectionMock.Protected() + .Setup("CreateDbCommand") + .Returns(CreateTestCommand(data)); + + return connectionMock.Object; + } + + #endregion + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs new file mode 100644 index 00000000..b575fe0c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs @@ -0,0 +1,17 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Text; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class Common + { + public const string TestEventString = @"{""type"":""event"",""event"":""testEvent"",""body"":null}"; + public const string TestEventFormatString = @"{{""event"":""testEvent"",""body"":{{""someString"":""{0}""}},""seq"":0,""type"":""event""}}"; + public static readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs new file mode 100644 index 00000000..0a12dc3e --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs @@ -0,0 +1,241 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Newtonsoft.Json; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageReaderTests + { + + private readonly IMessageSerializer messageSerializer; + + public MessageReaderTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public void ReadsMessage() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = new MessageReader(inputStream, this.messageSerializer); + + // Write a message to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + inputStream.Write(this.GetMessageBytes(Common.TestEventString), 0, messageBuffer.Length); + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + [Fact] + public void ReadsManyBufferedMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message to use for writing to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + + // How many messages of this size should we write to overflow the buffer? + int overflowMessageCount = + (int)Math.Ceiling( + (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); + + // Write the necessary number of messages to the stream + for (int i = 0; i < overflowMessageCount; i++) + { + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + } + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Read the written messages from the stream + for (int i = 0; i < overflowMessageCount; i++) + { + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + } + + inputStream.Dispose(); + } + + [Fact] + public void ReadMalformedMissingHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthNonIntegerTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a non-integer content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMissingContentLengthHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a a message without a content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Type: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthTooShortTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... Pass in an event that has an incorrect content length + // ... And pass in an event that is correct + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: 10\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.UTF8.GetBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.ASCII.GetBytes("\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... The first read should fail with an exception while deserializing + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... The second read should fail with an exception while reading headers + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedThenValidTest() + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + // ... Then pass it information that is valid and attempt to read a message from it + using (MemoryStream inputStream = new MemoryStream()) + { + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = GetMessageBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading the first one + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... A test event should be successfully read from the second one + Message messageResult = messageReader.ReadMessage().Result; + Assert.NotNull(messageResult); + Assert.Equal("testEvent", messageResult.Method); + } + } + + [Fact] + public void ReaderResizesBufferForLargeMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message with content so large that the buffer will need + // to be resized to fit it all. + byte[] messageBuffer = this.GetMessageBytes( + string.Format( + Common.TestEventFormatString, + new String('X', (int) (MessageReader.DefaultBufferSize*3)))); + + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + private byte[] GetMessageBytes(string messageString, Encoding encoding = null) + { + if (encoding == null) + { + encoding = Encoding.UTF8; + } + + byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); + byte[] headerBytes = Encoding.ASCII.GetBytes(string.Format(Constants.ContentLengthFormatString, messageBytes.Length)); + + // Copy the bytes into a single buffer + byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; + Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); + Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); + + return finalBytes; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs new file mode 100644 index 00000000..3c007a85 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs @@ -0,0 +1,55 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageWriterTests + { + private readonly IMessageSerializer messageSerializer; + + public MessageWriterTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public async Task WritesMessage() + { + MemoryStream outputStream = new MemoryStream(); + MessageWriter messageWriter = new MessageWriter(outputStream, this.messageSerializer); + + // Write the message and then roll back the stream to be read + // TODO: This will need to be redone! + await messageWriter.WriteMessage(Hosting.Protocol.Contracts.Message.Event("testEvent", null)); + outputStream.Seek(0, SeekOrigin.Begin); + + string expectedHeaderString = string.Format(Constants.ContentLengthFormatString, + Common.ExpectedMessageByteCount); + + byte[] buffer = new byte[128]; + await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); + + Assert.Equal( + expectedHeaderString, + Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); + + // Read the message + await outputStream.ReadAsync(buffer, 0, Common.ExpectedMessageByteCount); + + Assert.Equal(Common.TestEventString, + Encoding.UTF8.GetString(buffer, 0, Common.ExpectedMessageByteCount)); + + outputStream.Dispose(); + } + + } +} diff --git a/test/ServiceHost.Test/Message/TestMessageTypes.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs similarity index 74% rename from test/ServiceHost.Test/Message/TestMessageTypes.cs rename to test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs index cc5981dc..89238098 100644 --- a/test/ServiceHost.Test/Message/TestMessageTypes.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs @@ -3,19 +3,16 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; -using System; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -namespace Microsoft.SqlTools.EditorServices.Test.Protocol.MessageProtocol +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging { #region Request Types internal class TestRequest { - public Task ProcessMessage( - EditorSession editorSession, - MessageWriter messageWriter) + public Task ProcessMessage(MessageWriter messageWriter) { return Task.FromResult(false); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Microsoft.SqlTools.ServiceLayer.Test.xproj b/test/Microsoft.SqlTools.ServiceLayer.Test/Microsoft.SqlTools.ServiceLayer.Test.xproj new file mode 100644 index 00000000..cb4c13ed --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Microsoft.SqlTools.ServiceLayer.Test.xproj @@ -0,0 +1,22 @@ + + + + 14.0 + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) + + + + 2d771d16-9d85-4053-9f79-e2034737deef + Microsoft.SqlTools.ServiceLayer.Test + .\obj + .\bin\ + v4.5.2 + + + 2.0 + + + + + + \ No newline at end of file diff --git a/test/ServiceHost.Test/Properties/AssemblyInfo.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Properties/AssemblyInfo.cs similarity index 97% rename from test/ServiceHost.Test/Properties/AssemblyInfo.cs rename to test/Microsoft.SqlTools.ServiceLayer.Test/Properties/AssemblyInfo.cs index 5cf54b90..a1587015 100644 --- a/test/ServiceHost.Test/Properties/AssemblyInfo.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Properties/AssemblyInfo.cs @@ -4,7 +4,6 @@ // using System.Reflection; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices; // General Information about an assembly is controlled through the following diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs new file mode 100644 index 00000000..7e0e5a4d --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs @@ -0,0 +1,127 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class CancelTests + { + [Fact] + public void CancelInProgressQueryTest() + { + // If: + // ... I request a query (doesn't matter what kind) and execute it + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri }; + var executeRequest = + RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Fake that it hasn't completed execution + + // ... And then I request to cancel the query + var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a successful event (no messages) + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.Null(result.Messages); + + // ... The query should have been disposed as well + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public void CancelExecutedQueryTest() + { + // If: + // ... I request a query (doesn't matter what kind) and wait for execution + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams {QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri}; + var executeRequest = + RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And then I request to cancel the query + var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a result event with an error message + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + + // ... The query should not have been disposed + Assert.NotEmpty(queryService.ActiveQueries); + } + + [Fact] + public void CancelNonExistantTest() + { + // If: + // ... I request to cancel a query that doesn't exist + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false); + var cancelParams = new QueryCancelParams {OwnerUri = "Doesn't Exist"}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a result event with an error message + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + } + + #region Mocking + + private static Mock> GetQueryCancelResultContextMock( + Action resultCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext + .Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + private static void VerifyQueryCancelCallCount(Mock> mock, + Times sendResultCalls, Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + #endregion + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs new file mode 100644 index 00000000..c7a00c09 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -0,0 +1,288 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.IO; +using System.Data.SqlClient; +using System.Threading; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.SmoMetadataProvider; +using Microsoft.SqlServer.Management.SqlParser.Binder; +using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Moq; +using Moq.Protected; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class Common + { + public const string StandardQuery = "SELECT * FROM sys.objects"; + + public const string InvalidQuery = "SELECT *** FROM sys.objects"; + + public const string NoOpQuery = "-- No ops here, just us chickens."; + + public const string OwnerUri = "testFile"; + + public const int StandardRows = 5; + + public const int StandardColumns = 5; + + public static string TestServer { get; set; } + + public static string TestDatabase { get; set; } + + static Common() + { + TestServer = "sqltools11"; + TestDatabase = "master"; + } + + public static Dictionary[] StandardTestData + { + get { return GetTestData(StandardRows, StandardColumns); } + } + + public static Dictionary[] GetTestData(int columns, int rows) + { + Dictionary[] output = new Dictionary[rows]; + for (int row = 0; row < rows; row++) + { + Dictionary rowDictionary = new Dictionary(); + for (int column = 0; column < columns; column++) + { + rowDictionary.Add(string.Format("column{0}", column), string.Format("val{0}{1}", column, row)); + } + output[row] = rowDictionary; + } + + return output; + } + + public static Batch GetBasicExecutedBatch() + { + Batch batch = new Batch(StandardQuery, 1, GetFileStreamFactory()); + batch.Execute(CreateTestConnection(new[] {StandardTestData}, false), CancellationToken.None).Wait(); + return batch; + } + + public static Query GetBasicExecutedQuery() + { + ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false); + Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory()); + query.Execute().Wait(); + return query; + } + + #region FileStreamWriteMocking + + public static IFileStreamFactory GetFileStreamFactory() + { + Mock mock = new Mock(); + mock.Setup(fsf => fsf.GetReader(It.IsAny())) + .Returns(new ServiceBufferFileStreamReader(new InMemoryWrapper(), It.IsAny())); + mock.Setup(fsf => fsf.GetWriter(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(new ServiceBufferFileStreamWriter(new InMemoryWrapper(), It.IsAny(), 1024, + 1024)); + + return mock.Object; + } + + public class InMemoryWrapper : IFileStreamWrapper + { + private readonly byte[] storage = new byte[8192]; + private readonly MemoryStream memoryStream; + private bool readingOnly; + + public InMemoryWrapper() + { + memoryStream = new MemoryStream(storage); + } + + public void Dispose() + { + // We'll dispose this via a special method + } + + public void Init(string fileName, int bufferSize, FileAccess fAccess) + { + readingOnly = fAccess == FileAccess.Read; + } + + public int ReadData(byte[] buffer, int bytes) + { + return ReadData(buffer, bytes, memoryStream.Position); + } + + public int ReadData(byte[] buffer, int bytes, long fileOffset) + { + memoryStream.Seek(fileOffset, SeekOrigin.Begin); + return memoryStream.Read(buffer, 0, bytes); + } + + public int WriteData(byte[] buffer, int bytes) + { + if (readingOnly) { throw new InvalidOperationException(); } + memoryStream.Write(buffer, 0, bytes); + memoryStream.Flush(); + return bytes; + } + + public void Flush() + { + if (readingOnly) { throw new InvalidOperationException(); } + } + + public void Close() + { + memoryStream.Dispose(); + } + } + + #endregion + + #region DbConnection Mocking + + public static DbCommand CreateTestCommand(Dictionary[][] data, bool throwOnRead) + { + var commandMock = new Mock { CallBase = true }; + var commandMockSetup = commandMock.Protected() + .Setup("ExecuteDbDataReader", It.IsAny()); + + // Setup the expected behavior + if (throwOnRead) + { + var mockException = new Mock(); + mockException.SetupGet(dbe => dbe.Message).Returns("Message"); + commandMockSetup.Throws(mockException.Object); + } + else + { + commandMockSetup.Returns(new TestDbDataReader(data)); + } + + + return commandMock.Object; + } + + public static DbConnection CreateTestConnection(Dictionary[][] data, bool throwOnRead) + { + var connectionMock = new Mock { CallBase = true }; + connectionMock.Protected() + .Setup("CreateDbCommand") + .Returns(CreateTestCommand(data, throwOnRead)); + + return connectionMock.Object; + } + + public static ISqlConnectionFactory CreateMockFactory(Dictionary[][] data, bool throwOnRead) + { + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateTestConnection(data, throwOnRead)); + + return mockFactory.Object; + } + + public static ConnectionInfo CreateTestConnectionInfo(Dictionary[][] data, bool throwOnRead) + { + // Create connection info + ConnectionDetails connDetails = new ConnectionDetails + { + UserName = "sa", + Password = "Yukon900", + DatabaseName = Common.TestDatabase, + ServerName = Common.TestServer + }; + + return new ConnectionInfo(CreateMockFactory(data, throwOnRead), OwnerUri, connDetails); + } + + #endregion + + #region Service Mocking + + public static void GetAutoCompleteTestObjects( + out TextDocumentPosition textDocument, + out ScriptFile scriptFile, + out ConnectionInfo connInfo + ) + { + textDocument = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier {Uri = OwnerUri}, + Position = new Position + { + Line = 0, + Character = 0 + } + }; + + connInfo = Common.CreateTestConnectionInfo(null, false); + + var srvConn = GetServerConnection(connInfo); + var displayInfoProvider = new MetadataDisplayInfoProvider(); + var metadataProvider = SmoMetadataProvider.CreateConnectedProvider(srvConn); + var binder = BinderProvider.CreateBinder(metadataProvider); + + LanguageService.Instance.ScriptParseInfoMap.Add(textDocument.TextDocument.Uri, + new ScriptParseInfo + { + Binder = binder, + MetadataProvider = metadataProvider, + MetadataDisplayInfoProvider = displayInfoProvider + }); + + scriptFile = new ScriptFile {ClientFilePath = textDocument.TextDocument.Uri}; + + } + + public static ServerConnection GetServerConnection(ConnectionInfo connection) + { + string connectionString = ConnectionService.BuildConnectionString(connection.ConnectionDetails); + var sqlConnection = new SqlConnection(connectionString); + return new ServerConnection(sqlConnection); + } + + public static ConnectionDetails GetTestConnectionDetails() + { + return new ConnectionDetails + { + DatabaseName = "123", + Password = "456", + ServerName = "789", + UserName = "012" + }; + } + + public static QueryExecutionService GetPrimedExecutionService(ISqlConnectionFactory factory, bool isConnected) + { + var connectionService = new ConnectionService(factory); + if (isConnected) + { + connectionService.Connect(new ConnectParams + { + Connection = GetTestConnectionDetails(), + OwnerUri = OwnerUri + }); + } + return new QueryExecutionService(connectionService) {BufferFileStreamFactory = GetFileStreamFactory()}; + } + + #endregion + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs new file mode 100644 index 00000000..5911c577 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs @@ -0,0 +1,218 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage +{ + public class FileStreamWrapperTests + { + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" ")] + public void InitInvalidFilenameParameter(string fileName) + { + // If: + // ... I have a file stream wrapper that is initialized with invalid fileName + // Then: + // ... It should throw an argument null exception + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + Assert.Throws(() => fsw.Init(fileName, 8192, FileAccess.Read)); + } + } + + [Theory] + [InlineData(0)] + [InlineData(-1)] + public void InitInvalidBufferLength(int bufferLength) + { + // If: + // ... I have a file stream wrapper that is initialized with an invalid buffer length + // Then: + // ... I should throw an argument out of range exception + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + Assert.Throws(() => fsw.Init("validFileName", bufferLength, FileAccess.Read)); + } + } + + [Fact] + public void InitInvalidFileAccessMode() + { + // If: + // ... I attempt to open a file stream wrapper that is initialized with an invalid file + // access mode + // Then: + // ... I should get an invalid argument exception + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + Assert.Throws(() => fsw.Init("validFileName", 8192, FileAccess.Write)); + } + } + + [Fact] + public void InitSuccessful() + { + string fileName = Path.GetTempFileName(); + + try + { + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + // If: + // ... I have a file stream wrapper that is initialized with valid parameters + fsw.Init(fileName, 8192, FileAccess.ReadWrite); + + // Then: + // ... The file should exist + FileInfo fileInfo = new FileInfo(fileName); + Assert.True(fileInfo.Exists); + } + } + finally + { + // Cleanup: + // ... Delete the file that was created + try { File.Delete(fileName); } catch { /* Don't care */ } + } + } + + [Fact] + public void PerformOpWithoutInit() + { + byte[] buf = new byte[10]; + + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + // If: + // ... I have a file stream wrapper that hasn't been initialized + // Then: + // ... Attempting to perform any operation will result in an exception + Assert.Throws(() => fsw.ReadData(buf, 1)); + Assert.Throws(() => fsw.ReadData(buf, 1, 0)); + Assert.Throws(() => fsw.WriteData(buf, 1)); + Assert.Throws(() => fsw.Flush()); + } + } + + [Fact] + public void PerformWriteOpOnReadOnlyWrapper() + { + byte[] buf = new byte[10]; + + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + // If: + // ... I have a readonly file stream wrapper + // Then: + // ... Attempting to perform any write operation should result in an exception + Assert.Throws(() => fsw.WriteData(buf, 1)); + Assert.Throws(() => fsw.Flush()); + } + } + + [Theory] + [InlineData(1024, 20, 10)] // Standard scenario + [InlineData(1024, 100, 100)] // Requested more bytes than there are + [InlineData(5, 20, 10)] // Internal buffer too small, force a move-to operation + public void ReadData(int internalBufferLength, int outBufferLength, int requestedBytes) + { + // Setup: + // ... I have a file that has a handful of bytes in it + string fileName = Path.GetTempFileName(); + const string stringToWrite = "hello"; + CreateTestFile(fileName, stringToWrite); + byte[] targetBytes = Encoding.Unicode.GetBytes(stringToWrite); + + try + { + // If: + // ... I have a file stream wrapper that has been initialized to an existing file + // ... And I read some bytes from it + int bytesRead; + byte[] buf = new byte[outBufferLength]; + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + fsw.Init(fileName, internalBufferLength, FileAccess.Read); + bytesRead = fsw.ReadData(buf, targetBytes.Length); + } + + // Then: + // ... I should get those bytes back + Assert.Equal(targetBytes.Length, bytesRead); + Assert.True(targetBytes.Take(targetBytes.Length).SequenceEqual(buf.Take(targetBytes.Length))); + + } + finally + { + // Cleanup: + // ... Delete the test file + CleanupTestFile(fileName); + } + } + + [Theory] + [InlineData(1024)] // Standard scenario + [InlineData(10)] // Internal buffer too small, forces a flush + public void WriteData(int internalBufferLength) + { + string fileName = Path.GetTempFileName(); + byte[] bytesToWrite = Encoding.Unicode.GetBytes("hello"); + + try + { + // If: + // ... I have a file stream that has been initialized + // ... And I write some bytes to it + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + fsw.Init(fileName, internalBufferLength, FileAccess.ReadWrite); + int bytesWritten = fsw.WriteData(bytesToWrite, bytesToWrite.Length); + + Assert.Equal(bytesToWrite.Length, bytesWritten); + } + + // Then: + // ... The file I wrote to should contain only the bytes I wrote out + using (FileStream fs = File.OpenRead(fileName)) + { + byte[] readBackBytes = new byte[1024]; + int bytesRead = fs.Read(readBackBytes, 0, readBackBytes.Length); + + Assert.Equal(bytesToWrite.Length, bytesRead); // If bytes read is not equal, then more or less of the original string was written to the file + Assert.True(bytesToWrite.SequenceEqual(readBackBytes.Take(bytesRead))); + } + } + finally + { + // Cleanup: + // ... Delete the test file + CleanupTestFile(fileName); + } + } + + private static void CreateTestFile(string fileName, string value) + { + using (FileStream fs = new FileStream(fileName, FileMode.OpenOrCreate, FileAccess.ReadWrite)) + { + byte[] bytesToWrite = Encoding.Unicode.GetBytes(value); + fs.Write(bytesToWrite, 0, bytesToWrite.Length); + fs.Flush(); + } + } + + private static void CleanupTestFile(string fileName) + { + try { File.Delete(fileName); } catch { /* Don't Care */} + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs new file mode 100644 index 00000000..b10a7f92 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs @@ -0,0 +1,295 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage +{ + public class ReaderWriterPairTest + { + private static void VerifyReadWrite(int valueLength, T value, Func writeFunc, Func> readFunc) + { + // Setup: Create a mock file stream wrapper + Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper(); + try + { + // If: + // ... I write a type T to the writer + using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10)) + { + int writtenBytes = writeFunc(writer, value); + Assert.Equal(valueLength, writtenBytes); + } + + // ... And read the type T back + FileStreamReadResult outValue; + using (ServiceBufferFileStreamReader reader = new ServiceBufferFileStreamReader(mockWrapper, "abc")) + { + outValue = readFunc(reader); + } + + // Then: + Assert.Equal(value, outValue.Value); + Assert.Equal(valueLength, outValue.TotalLength); + Assert.False(outValue.IsNull); + } + finally + { + // Cleanup: Close the wrapper + mockWrapper.Close(); + } + } + + [Theory] + [InlineData(0)] + [InlineData(10)] + [InlineData(-10)] + [InlineData(short.MaxValue)] // Two byte number + [InlineData(short.MinValue)] // Negative two byte number + public void Int16(short value) + { + VerifyReadWrite(sizeof(short) + 1, value, (writer, val) => writer.WriteInt16(val), reader => reader.ReadInt16(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10)] + [InlineData(-10)] + [InlineData(short.MaxValue)] // Two byte number + [InlineData(short.MinValue)] // Negative two byte number + [InlineData(int.MaxValue)] // Four byte number + [InlineData(int.MinValue)] // Negative four byte number + public void Int32(int value) + { + VerifyReadWrite(sizeof(int) + 1, value, (writer, val) => writer.WriteInt32(val), reader => reader.ReadInt32(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10)] + [InlineData(-10)] + [InlineData(short.MaxValue)] // Two byte number + [InlineData(short.MinValue)] // Negative two byte number + [InlineData(int.MaxValue)] // Four byte number + [InlineData(int.MinValue)] // Negative four byte number + [InlineData(long.MaxValue)] // Eight byte number + [InlineData(long.MinValue)] // Negative eight byte number + public void Int64(long value) + { + VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteInt64(val), reader => reader.ReadInt64(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10)] + public void Byte(byte value) + { + VerifyReadWrite(sizeof(byte) + 1, value, (writer, val) => writer.WriteByte(val), reader => reader.ReadByte(0)); + } + + [Theory] + [InlineData('a')] + [InlineData('1')] + [InlineData((char)0x9152)] // Test something in the UTF-16 space + public void Char(char value) + { + VerifyReadWrite(sizeof(char) + 1, value, (writer, val) => writer.WriteChar(val), reader => reader.ReadChar(0)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Boolean(bool value) + { + VerifyReadWrite(sizeof(bool) + 1, value, (writer, val) => writer.WriteBoolean(val), reader => reader.ReadBoolean(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10.1)] + [InlineData(-10.1)] + [InlineData(float.MinValue)] + [InlineData(float.MaxValue)] + [InlineData(float.PositiveInfinity)] + [InlineData(float.NegativeInfinity)] + public void Single(float value) + { + VerifyReadWrite(sizeof(float) + 1, value, (writer, val) => writer.WriteSingle(val), reader => reader.ReadSingle(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10.1)] + [InlineData(-10.1)] + [InlineData(float.MinValue)] + [InlineData(float.MaxValue)] + [InlineData(float.PositiveInfinity)] + [InlineData(float.NegativeInfinity)] + [InlineData(double.PositiveInfinity)] + [InlineData(double.NegativeInfinity)] + [InlineData(double.MinValue)] + [InlineData(double.MaxValue)] + public void Double(double value) + { + VerifyReadWrite(sizeof(double) + 1, value, (writer, val) => writer.WriteDouble(val), reader => reader.ReadDouble(0)); + } + + [Fact] + public void SqlDecimalTest() + { + // Setup: Create some test values + // NOTE: We are doing these here instead of InlineData because SqlDecimal values can't be written as constant expressions + SqlDecimal[] testValues = + { + SqlDecimal.MaxValue, SqlDecimal.MinValue, new SqlDecimal(0x01, 0x01, true, 0, 0, 0, 0) + }; + foreach (SqlDecimal value in testValues) + { + int valueLength = 4 + value.BinData.Length; + VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteSqlDecimal(val), reader => reader.ReadSqlDecimal(0)); + } + } + + [Fact] + public void Decimal() + { + // Setup: Create some test values + // NOTE: We are doing these here instead of InlineData because Decimal values can't be written as constant expressions + decimal[] testValues = + { + decimal.Zero, decimal.One, decimal.MinusOne, decimal.MinValue, decimal.MaxValue + }; + + foreach (decimal value in testValues) + { + int valueLength = decimal.GetBits(value).Length*4 + 1; + VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteDecimal(val), reader => reader.ReadDecimal(0)); + } + } + + [Fact] + public void DateTimeTest() + { + // Setup: Create some test values + // NOTE: We are doing these here instead of InlineData because DateTime values can't be written as constant expressions + DateTime[] testValues = + { + DateTime.Now, DateTime.UtcNow, DateTime.MinValue, DateTime.MaxValue + }; + foreach (DateTime value in testValues) + { + VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteDateTime(val), reader => reader.ReadDateTime(0)); + } + } + + [Fact] + public void DateTimeOffsetTest() + { + // Setup: Create some test values + // NOTE: We are doing these here instead of InlineData because DateTimeOffset values can't be written as constant expressions + DateTimeOffset[] testValues = + { + DateTimeOffset.Now, DateTimeOffset.UtcNow, DateTimeOffset.MinValue, DateTimeOffset.MaxValue + }; + foreach (DateTimeOffset value in testValues) + { + VerifyReadWrite((sizeof(long) + 1)*2, value, (writer, val) => writer.WriteDateTimeOffset(val), reader => reader.ReadDateTimeOffset(0)); + } + } + + [Fact] + public void TimeSpanTest() + { + // Setup: Create some test values + // NOTE: We are doing these here instead of InlineData because TimeSpan values can't be written as constant expressions + TimeSpan[] testValues = + { + TimeSpan.Zero, TimeSpan.MinValue, TimeSpan.MaxValue, TimeSpan.FromMinutes(60) + }; + foreach (TimeSpan value in testValues) + { + VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteTimeSpan(val), reader => reader.ReadTimeSpan(0)); + } + } + + [Fact] + public void StringNullTest() + { + // Setup: Create a mock file stream wrapper + Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper(); + + // If: + // ... I write null as a string to the writer + using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10)) + { + // Then: + // ... I should get an argument null exception + Assert.Throws(() => writer.WriteString(null)); + } + } + + [Theory] + [InlineData(0, null)] // Test of empty string + [InlineData(1, new[] { 'j' })] + [InlineData(1, new[] { (char)0x9152 })] + [InlineData(100, new[] { 'j', (char)0x9152 })] // Test alternating utf-16/ascii characters + [InlineData(512, new[] { 'j', (char)0x9152 })] // Test that requires a 4 byte length + public void StringTest(int length, char[] values) + { + // Setup: + // ... Generate the test value + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < length; i++) + { + sb.Append(values[i%values.Length]); + } + string value = sb.ToString(); + int lengthLength = length == 0 || length > 255 ? 5 : 1; + VerifyReadWrite(sizeof(char)*length + lengthLength, value, (writer, val) => writer.WriteString(value), reader => reader.ReadString(0)); + } + + [Fact] + public void BytesNullTest() + { + // Setup: Create a mock file stream wrapper + Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper(); + + // If: + // ... I write null as a string to the writer + using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10)) + { + // Then: + // ... I should get an argument null exception + Assert.Throws(() => writer.WriteBytes(null, 0)); + } + } + + [Theory] + [InlineData(0, new byte[] { 0x00 })] // Test of empty byte[] + [InlineData(1, new byte[] { 0x00 })] + [InlineData(1, new byte[] { 0xFF })] + [InlineData(100, new byte[] { 0x10, 0xFF, 0x00 })] + [InlineData(512, new byte[] { 0x10, 0xFF, 0x00 })] // Test that requires a 4 byte length + public void Bytes(int length, byte[] values) + { + // Setup: + // ... Generate the test value + List sb = new List(); + for (int i = 0; i < length; i++) + { + sb.Add(values[i % values.Length]); + } + byte[] value = sb.ToArray(); + int lengthLength = length == 0 || length > 255 ? 5 : 1; + int valueLength = sizeof(byte)*length + lengthLength; + VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteBytes(value, length), reader => reader.ReadBytes(0)); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs new file mode 100644 index 00000000..8c79296d --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs @@ -0,0 +1,99 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class DisposeTests + { + [Fact] + public void DisposeExecutedQuery() + { + // If: + // ... I request a query (doesn't matter what kind) + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams {QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri}; + var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And then I dispose of the query + var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; + QueryDisposeResult result = null; + var disposeRequest = GetQueryDisposeResultContextMock(qdr => result = qdr, null); + queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object).Wait(); + + // Then: + // ... I should have seen a successful result + // ... And the active queries should be empty + VerifyQueryDisposeCallCount(disposeRequest, Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public void QueryDisposeMissingQuery() + { + // If: + // ... I attempt to dispose a query that doesn't exist + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false); + var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; + QueryDisposeResult result = null; + var disposeRequest = GetQueryDisposeResultContextMock(qdr => result = qdr, null); + queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object).Wait(); + + // Then: + // ... I should have gotten an error result + VerifyQueryDisposeCallCount(disposeRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + Assert.NotEmpty(result.Messages); + } + + #region Mocking + + private Mock> GetQueryDisposeResultContextMock( + Action resultCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext + .Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + private void VerifyQueryDisposeCallCount(Mock> mock, Times sendResultCalls, + Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + #endregion + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs new file mode 100644 index 00000000..bda6ca0d --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs @@ -0,0 +1,627 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.Common; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Microsoft.SqlTools.ServiceLayer.Workspace; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class ExecuteTests + { + #region Batch Class Tests + + [Fact] + public void BatchCreationTest() + { + // If I create a new batch... + Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory()); + + // Then: + // ... The text of the batch should be stored + Assert.NotEmpty(batch.BatchText); + + // ... It should not have executed and no error + Assert.False(batch.HasExecuted, "The query should not have executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // ... The results should be empty + Assert.Empty(batch.ResultSets); + Assert.Empty(batch.ResultSummaries); + Assert.Empty(batch.ResultMessages); + + // ... The start line of the batch should be 0 + Assert.Equal(0, batch.StartLine); + } + + [Fact] + public void BatchExecuteNoResultSets() + { + // If I execute a query that should get no result sets + Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory()); + batch.Execute(GetConnection(Common.CreateTestConnectionInfo(null, false)), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed without error + Assert.True(batch.HasExecuted, "The query should have been marked executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // ... The results should be empty + Assert.Empty(batch.ResultSets); + Assert.Empty(batch.ResultSummaries); + + // ... The results should not be null + Assert.NotNull(batch.ResultSets); + Assert.NotNull(batch.ResultSummaries); + + // ... There should be a message for how many rows were affected + Assert.Equal(1, batch.ResultMessages.Count()); + } + + [Fact] + public void BatchExecuteOneResultSet() + { + int resultSets = 1; + ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false); + + // If I execute a query that should get one result set + Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory()); + batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed without error + Assert.True(batch.HasExecuted, "The batch should have been marked executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // ... There should be exactly one result set + Assert.Equal(resultSets, batch.ResultSets.Count()); + Assert.Equal(resultSets, batch.ResultSummaries.Length); + + // ... Inside the result set should be with 5 rows + Assert.Equal(Common.StandardRows, batch.ResultSets.First().RowCount); + Assert.Equal(Common.StandardRows, batch.ResultSummaries[0].RowCount); + + // ... Inside the result set should have 5 columns + Assert.Equal(Common.StandardColumns, batch.ResultSets.First().Columns.Length); + Assert.Equal(Common.StandardColumns, batch.ResultSummaries[0].ColumnInfo.Length); + + // ... There should be a message for how many rows were affected + Assert.Equal(resultSets, batch.ResultMessages.Count()); + } + + [Fact] + public void BatchExecuteTwoResultSets() + { + var dataset = new[] { Common.StandardTestData, Common.StandardTestData }; + int resultSets = dataset.Length; + ConnectionInfo ci = Common.CreateTestConnectionInfo(dataset, false); + + // If I execute a query that should get two result sets + Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory()); + batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed without error + Assert.True(batch.HasExecuted, "The batch should have been marked executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // ... There should be exactly two result sets + Assert.Equal(resultSets, batch.ResultSets.Count()); + + foreach (ResultSet rs in batch.ResultSets) + { + // ... Each result set should have 5 rows + Assert.Equal(Common.StandardRows, rs.RowCount); + + // ... Inside each result set should be 5 columns + Assert.Equal(Common.StandardColumns, rs.Columns.Length); + } + + // ... There should be exactly two result set summaries + Assert.Equal(resultSets, batch.ResultSummaries.Length); + + foreach (ResultSetSummary rs in batch.ResultSummaries) + { + // ... Inside each result summary, there should be 5 rows + Assert.Equal(Common.StandardRows, rs.RowCount); + + // ... Inside each result summary, there should be 5 column definitions + Assert.Equal(Common.StandardColumns, rs.ColumnInfo.Length); + } + + // ... There should be a message for how many rows were affected + Assert.Equal(resultSets, batch.ResultMessages.Count()); + } + + [Fact] + public void BatchExecuteInvalidQuery() + { + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true); + + // If I execute a batch that is invalid + Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory()); + batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed with error + Assert.True(batch.HasExecuted); + Assert.True(batch.HasError); + + // ... There should be no result sets + Assert.Empty(batch.ResultSets); + Assert.Empty(batch.ResultSummaries); + + // ... There should be plenty of messages for the error + Assert.NotEmpty(batch.ResultMessages); + } + + [Fact] + public async Task BatchExecuteExecuted() + { + ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false); + + // If I execute a batch + Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory()); + batch.Execute(GetConnection(ci), CancellationToken.None).Wait(); + + // Then: + // ... It should have executed without error + Assert.True(batch.HasExecuted, "The batch should have been marked executed."); + Assert.False(batch.HasError, "The batch should not have an error"); + + // If I execute it again + // Then: + // ... It should throw an invalid operation exception + await Assert.ThrowsAsync(() => + batch.Execute(GetConnection(ci), CancellationToken.None)); + + // ... The data should still be available without error + Assert.False(batch.HasError, "The batch should not be in an error condition"); + Assert.True(batch.HasExecuted, "The batch should still be marked executed."); + Assert.NotEmpty(batch.ResultSets); + Assert.NotEmpty(batch.ResultSummaries); + } + + [Theory] + [InlineData("")] + [InlineData(null)] + public void BatchExecuteNoSql(string query) + { + // If: + // ... I create a batch that has an empty query + // Then: + // ... It should throw an exception + Assert.Throws(() => new Batch(query, 1, Common.GetFileStreamFactory())); + } + + [Fact] + public void BatchNoBufferFactory() + { + // If: + // ... I create a batch that has no file stream factory + // Then: + // ... It should throw an exception + Assert.Throws(() => new Batch("stuff", 1, null)); + } + + #endregion + + #region Query Class Tests + + [Fact] + public void QueryExecuteNoQueryText() + { + // If: + // ... I create a query that has a null query text + // Then: + // ... It should throw an exception + Assert.Throws(() => + new Query(null, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), Common.GetFileStreamFactory())); + } + + [Fact] + public void QueryExecuteNoConnectionInfo() + { + // If: + // ... I create a query that has a null connection info + // Then: + // ... It should throw an exception + Assert.Throws(() => new Query("Some Query", null, new QueryExecutionSettings(), Common.GetFileStreamFactory())); + } + + [Fact] + public void QueryExecuteNoSettings() + { + // If: + // ... I create a query that has a null settings + // Then: + // ... It should throw an exception + Assert.Throws(() => + new Query("Some query", Common.CreateTestConnectionInfo(null, false), null, Common.GetFileStreamFactory())); + } + + [Fact] + public void QueryExecuteNoBufferFactory() + { + // If: + // ... I create a query that has a null file stream factory + // Then: + // ... It should throw an exception + Assert.Throws(() => + new Query("Some query", Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(),null)); + } + + [Fact] + public void QueryExecuteSingleBatch() + { + // If: + // ... I create a query from a single batch (without separator) + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); + Query query = new Query(Common.StandardQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); + + // Then: + // ... I should get a single batch to execute that hasn't been executed + Assert.NotEmpty(query.QueryText); + Assert.NotEmpty(query.Batches); + Assert.Equal(1, query.Batches.Length); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // ... I then execute the query + query.Execute().Wait(); + + // Then: + // ... The query should have completed successfully with one batch summary returned + Assert.True(query.HasExecuted); + Assert.NotEmpty(query.BatchSummaries); + Assert.Equal(1, query.BatchSummaries.Length); + } + + [Fact] + public void QueryExecuteNoOpBatch() + { + // If: + // ... I create a query from a single batch that does nothing + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); + Query query = new Query(Common.NoOpQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); + + // Then: + // ... I should get no batches back + Assert.NotEmpty(query.QueryText); + Assert.Empty(query.Batches); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // ... I Then execute the query + query.Execute().Wait(); + + // Then: + // ... The query should have completed successfully with no batch summaries returned + Assert.True(query.HasExecuted); + Assert.Empty(query.BatchSummaries); + } + + [Fact] + public void QueryExecuteMultipleBatches() + { + // If: + // ... I create a query from two batches (with separator) + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); + string queryText = string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery); + Query query = new Query(queryText, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); + + // Then: + // ... I should get back two batches to execute that haven't been executed + Assert.NotEmpty(query.QueryText); + Assert.NotEmpty(query.Batches); + Assert.Equal(2, query.Batches.Length); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // ... I then execute the query + query.Execute().Wait(); + + // Then: + // ... The query should have completed successfully with two batch summaries returned + Assert.True(query.HasExecuted); + Assert.NotEmpty(query.BatchSummaries); + Assert.Equal(2, query.BatchSummaries.Length); + } + + [Fact] + public void QueryExecuteMultipleBatchesWithNoOp() + { + // If: + // ... I create a query from a two batches (with separator) + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); + string queryText = string.Format("{0}\r\nGO\r\n{1}", Common.StandardQuery, Common.NoOpQuery); + Query query = new Query(queryText, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); + + // Then: + // ... I should get back one batch to execute that hasn't been executed + Assert.NotEmpty(query.QueryText); + Assert.NotEmpty(query.Batches); + Assert.Equal(1, query.Batches.Length); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // .. I then execute the query + query.Execute().Wait(); + + // ... The query should have completed successfully with one batch summary returned + Assert.True(query.HasExecuted); + Assert.NotEmpty(query.BatchSummaries); + Assert.Equal(1, query.BatchSummaries.Length); + } + + [Fact] + public void QueryExecuteInvalidBatch() + { + // If: + // ... I create a query from an invalid batch + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true); + Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory()); + + // Then: + // ... I should get back a query with one batch not executed + Assert.NotEmpty(query.QueryText); + Assert.NotEmpty(query.Batches); + Assert.Equal(1, query.Batches.Length); + Assert.False(query.HasExecuted); + Assert.Throws(() => query.BatchSummaries); + + // If: + // ... I then execute the query + query.Execute().Wait(); + + // Then: + // ... There should be an error on the batch + Assert.True(query.HasExecuted); + Assert.NotEmpty(query.BatchSummaries); + Assert.Equal(1, query.BatchSummaries.Length); + Assert.True(query.BatchSummaries[0].HasError); + Assert.NotEmpty(query.BatchSummaries[0].Messages); + } + + #endregion + + #region Service Tests + + [Fact] + public void QueryExecuteValidNoResultsTest() + { + // Given: + // ... Default settings are stored in the workspace service + WorkspaceService.Instance.CurrentSettings = new SqlToolsSettings(); + + // If: + // ... I request to execute a valid query with no results + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var queryParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams completeParams = null; + var requestContext = + RequestContextMocks.SetupRequestContextMock( + resultCallback: qer => result = qer, + expectedEvent: QueryExecuteCompleteEvent.Type, + eventCallback: (et, cp) => completeParams = cp, + errorCallback: null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... No Errors should have been sent + // ... A successful result should have been sent with messages on the first batch + // ... A completion event should have been fired with empty results + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.Equal(1, completeParams.BatchSummaries.Length); + Assert.Empty(completeParams.BatchSummaries[0].ResultSetSummaries); + Assert.NotEmpty(completeParams.BatchSummaries[0].Messages); + + // ... There should be one active query + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Fact] + public void QueryExecuteValidResultsTest() + { + // If: + // ... I request to execute a valid query with results + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = Common.StandardQuery }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams completeParams = null; + var requestContext = + RequestContextMocks.SetupRequestContextMock( + resultCallback: qer => result = qer, + expectedEvent: QueryExecuteCompleteEvent.Type, + eventCallback: (et, cp) => completeParams = cp, + errorCallback: null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A successful result should have been sent with messages + // ... A completion event should have been fired with one result + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.Equal(1, completeParams.BatchSummaries.Length); + Assert.NotEmpty(completeParams.BatchSummaries[0].ResultSetSummaries); + Assert.NotEmpty(completeParams.BatchSummaries[0].Messages); + Assert.False(completeParams.BatchSummaries[0].HasError); + + // ... There should be one active query + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Fact] + public void QueryExecuteUnconnectedUriTest() + { + // If: + // ... I request to execute a query using a file URI that isn't connected + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false); + var queryParams = new QueryExecuteParams { OwnerUri = "notConnected", QueryText = Common.StandardQuery }; + + QueryExecuteResult result = null; + var requestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... An error message should have been returned via the result + // ... No completion event should have been fired + // ... No error event should have been fired + // ... There should be no active queries + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Never(), Times.Never()); + Assert.NotNull(result.Messages); + Assert.NotEmpty(result.Messages); + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public void QueryExecuteInProgressTest() + { + // If: + // ... I request to execute a query + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = Common.StandardQuery }; + + // Note, we don't care about the results of the first request + var firstRequestContext = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait(); + + // ... And then I request another query without waiting for the first to complete + queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Simulate query hasn't finished + QueryExecuteResult result = null; + var secondRequestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with an error message + // ... No completion event should have been fired + // ... There should only be one active query + VerifyQueryExecuteCallCount(secondRequestContext, Times.Once(), Times.AtMostOnce(), Times.Never()); + Assert.NotNull(result.Messages); + Assert.NotEmpty(result.Messages); + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Fact] + public void QueryExecuteCompletedTest() + { + // If: + // ... I request to execute a query + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = Common.StandardQuery }; + + // Note, we don't care about the results of the first request + var firstRequestContext = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + + queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait(); + + // ... And then I request another query after waiting for the first to complete + QueryExecuteResult result = null; + QueryExecuteCompleteParams complete = null; + var secondRequestContext = + RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null); + queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with no errors + // ... There should only be one active query + VerifyQueryExecuteCallCount(secondRequestContext, Times.Once(), Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.False(complete.BatchSummaries.Any(b => b.HasError)); + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Theory] + [InlineData("")] + [InlineData(null)] + public void QueryExecuteMissingQueryTest(string query) + { + // If: + // ... I request to execute a query with a missing query string + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = query }; + + QueryExecuteResult result = null; + var requestContext = + RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with an error message + // ... No completion event should have been fired + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Never(), Times.Never()); + Assert.NotNull(result.Messages); + Assert.NotEmpty(result.Messages); + + // ... There should not be an active query + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public void QueryExecuteInvalidQueryTest() + { + // If: + // ... I request to execute a query that is invalid + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, true), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = Common.StandardQuery }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams complete = null; + var requestContext = + RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with success (we successfully started the query) + // ... A completion event should have been sent with error + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.Equal(1, complete.BatchSummaries.Length); + Assert.True(complete.BatchSummaries[0].HasError); + Assert.NotEmpty(complete.BatchSummaries[0].Messages); + } + + #endregion + + private void VerifyQueryExecuteCallCount(Mock> mock, Times sendResultCalls, Times sendEventCalls, Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendEvent( + It.Is>(m => m == QueryExecuteCompleteEvent.Type), + It.IsAny()), sendEventCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + private DbConnection GetConnection(ConnectionInfo info) + { + return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs new file mode 100644 index 00000000..2968e709 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs @@ -0,0 +1,231 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class SubsetTests + { + #region Batch Class Tests + + [Theory] + [InlineData(2)] + [InlineData(20)] + public void BatchSubsetValidTest(int rowCount) + { + // If I have an executed batch + Batch b = Common.GetBasicExecutedBatch(); + + // ... And I ask for a subset with valid arguments + ResultSetSubset subset = b.GetSubset(0, 0, rowCount).Result; + + // Then: + // I should get the requested number of rows + Assert.Equal(Math.Min(rowCount, Common.StandardTestData.Length), subset.RowCount); + Assert.Equal(Math.Min(rowCount, Common.StandardTestData.Length), subset.Rows.Length); + } + + [Theory] + [InlineData(-1, 0, 2)] // Invalid result set, too low + [InlineData(2, 0, 2)] // Invalid result set, too high + [InlineData(0, -1, 2)] // Invalid start index, too low + [InlineData(0, 10, 2)] // Invalid start index, too high + [InlineData(0, 0, -1)] // Invalid row count, too low + [InlineData(0, 0, 0)] // Invalid row count, zero + public void BatchSubsetInvalidParamsTest(int resultSetIndex, int rowStartInex, int rowCount) + { + // If I have an executed batch + Batch b = Common.GetBasicExecutedBatch(); + + // ... And I ask for a subset with an invalid result set index + // Then: + // ... It should throw an exception + Assert.ThrowsAsync(() => b.GetSubset(resultSetIndex, rowStartInex, rowCount)).Wait(); + } + + #endregion + + #region Query Class Tests + + [Fact] + public void SubsetUnexecutedQueryTest() + { + // If I have a query that has *not* been executed + Query q = new Query(Common.StandardQuery, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), Common.GetFileStreamFactory()); + + // ... And I ask for a subset with valid arguments + // Then: + // ... It should throw an exception + Assert.ThrowsAsync(() => q.GetSubset(0, 0, 0, 2)).Wait(); + } + + [Theory] + [InlineData(-1)] // Invalid batch, too low + [InlineData(2)] // Invalid batch, too high + public void QuerySubsetInvalidParamsTest(int batchIndex) + { + // If I have an executed query + Query q = Common.GetBasicExecutedQuery(); + + // ... And I ask for a subset with an invalid result set index + // Then: + // ... It should throw an exception + Assert.ThrowsAsync(() => q.GetSubset(batchIndex, 0, 0, 1)).Wait(); + } + + #endregion + + #region Service Intergration Tests + + [Fact] + public void SubsetServiceValidTest() + { + // If: + // ... I have a query that has results (doesn't matter what) + var queryService =Common.GetPrimedExecutionService( + Common.CreateMockFactory(new[] {Common.StandardTestData}, false), true); + var executeParams = new QueryExecuteParams {QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri}; + var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And I then ask for a valid set of results from it + var subsetParams = new QueryExecuteSubsetParams {OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0}; + QueryExecuteSubsetResult result = null; + var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); + queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + + // Then: + // ... I should have a successful result + // ... There should be rows there (other test validate that the rows are correct) + // ... There should not be any error calls + VerifyQuerySubsetCallCount(subsetRequest, Times.Once(), Times.Never()); + Assert.Null(result.Message); + Assert.NotNull(result.ResultSubset); + } + + [Fact] + public void SubsetServiceMissingQueryTest() + { + // If: + // ... I ask for a set of results for a file that hasn't executed a query + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; + QueryExecuteSubsetResult result = null; + var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); + queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + + // Then: + // ... I should have an error result + // ... There should be no rows in the result set + // ... There should not be any error calls + VerifyQuerySubsetCallCount(subsetRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Message); + Assert.Null(result.ResultSubset); + } + + [Fact] + public void SubsetServiceUnexecutedQueryTest() + { + // If: + // ... I have a query that hasn't finished executing (doesn't matter what) + var queryService = Common.GetPrimedExecutionService( + Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true); + var executeParams = new QueryExecuteParams { QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri }; + var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; + + // ... And I then ask for a valid set of results from it + var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; + QueryExecuteSubsetResult result = null; + var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); + queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + + // Then: + // ... I should get an error result + // ... There should not be rows + // ... There should not be any error calls + VerifyQuerySubsetCallCount(subsetRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Message); + Assert.Null(result.ResultSubset); + } + + [Fact] + public void SubsetServiceOutOfRangeSubsetTest() + { + // If: + // ... I have a query that doesn't have any result sets + var queryService = Common.GetPrimedExecutionService( + Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams { QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri }; + var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And I then ask for a set of results from it + var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; + QueryExecuteSubsetResult result = null; + var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); + queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + + // Then: + // ... I should get an error result + // ... There should not be rows + // ... There should not be any error calls + VerifyQuerySubsetCallCount(subsetRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Message); + Assert.Null(result.ResultSubset); + } + + #endregion + + #region Mocking + + private Mock> GetQuerySubsetResultContextMock( + Action resultCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext + .Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + private void VerifyQuerySubsetCallCount(Mock> mock, Times sendResultCalls, + Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + #endregion + + } +} diff --git a/test/ServiceHost.Test/LanguageServer/JsonRpcMessageSerializerTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/JsonRpcMessageSerializerTests.cs similarity index 90% rename from test/ServiceHost.Test/LanguageServer/JsonRpcMessageSerializerTests.cs rename to test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/JsonRpcMessageSerializerTests.cs index 9ec341c5..d50f02da 100644 --- a/test/ServiceHost.Test/LanguageServer/JsonRpcMessageSerializerTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/ServiceHost/JsonRpcMessageSerializerTests.cs @@ -3,12 +3,12 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Serializers; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using HostingMessage = Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts.Message; using Newtonsoft.Json.Linq; using Xunit; -namespace Microsoft.SqlTools.EditorServices.Test.Protocol.LanguageServer +namespace Microsoft.SqlTools.ServiceLayer.Test.ServiceHost { public class TestMessageContents { @@ -44,7 +44,7 @@ namespace Microsoft.SqlTools.EditorServices.Test.Protocol.LanguageServer { var messageObj = this.messageSerializer.SerializeMessage( - Message.Request( + HostingMessage.Request( MessageId, MethodName, MessageContent)); @@ -61,7 +61,7 @@ namespace Microsoft.SqlTools.EditorServices.Test.Protocol.LanguageServer { var messageObj = this.messageSerializer.SerializeMessage( - Message.Event( + HostingMessage.Event( MethodName, MessageContent)); @@ -76,7 +76,7 @@ namespace Microsoft.SqlTools.EditorServices.Test.Protocol.LanguageServer { var messageObj = this.messageSerializer.SerializeMessage( - Message.Response( + HostingMessage.Response( MessageId, null, MessageContent)); @@ -92,7 +92,7 @@ namespace Microsoft.SqlTools.EditorServices.Test.Protocol.LanguageServer { var messageObj = this.messageSerializer.SerializeMessage( - Message.ResponseError( + HostingMessage.ResponseError( MessageId, null, MessageContent)); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/RequestContextMocks.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/RequestContextMocks.cs new file mode 100644 index 00000000..91e05a76 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/RequestContextMocks.cs @@ -0,0 +1,77 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Moq; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Utility +{ + public static class RequestContextMocks + { + + public static Mock> Create(Action resultCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + return requestContext; + } + + public static Mock> AddEventHandling( + this Mock> mock, + EventType expectedEvent, + Action, TParams> eventCallback) + { + var flow = mock.Setup(rc => rc.SendEvent( + It.Is>(m => m == expectedEvent), + It.IsAny())) + .Returns(Task.FromResult(0)); + if (eventCallback != null) + { + flow.Callback(eventCallback); + } + + return mock; + } + + public static Mock> AddErrorHandling( + this Mock> mock, + Action errorCallback) + { + + // Setup the mock for SendError + var sendErrorFlow = mock.Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (mock != null && errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return mock; + } + + public static Mock> SetupRequestContextMock( + Action resultCallback, + EventType expectedEvent, + Action, TParams> eventCallback, + Action errorCallback) + { + return Create(resultCallback) + .AddEventHandling(expectedEvent, eventCallback) + .AddErrorHandling(errorCallback); + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs new file mode 100644 index 00000000..c2765783 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs @@ -0,0 +1,21 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Data.Common; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Utility +{ + public class TestDbColumn : DbColumn + { + public TestDbColumn() + { + base.IsLong = false; + base.ColumnName = "Test Column"; + base.ColumnSize = 128; + base.AllowDBNull = true; + base.DataType = typeof(string); + base.DataTypeName = "nvarchar"; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs new file mode 100644 index 00000000..0330cda0 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs @@ -0,0 +1,216 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Data.Common; +using System.Linq; +using Moq; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Utility +{ + public class TestDbDataReader : DbDataReader, IDbColumnSchemaGenerator + { + + #region Test Specific Implementations + + private Dictionary[][] Data { get; set; } + + public IEnumerator[]> ResultSet { get; private set; } + + private IEnumerator> Rows { get; set; } + + public TestDbDataReader(Dictionary[][] data) + { + Data = data; + if (Data != null) + { + ResultSet = ((IEnumerable[]>) Data).GetEnumerator(); + ResultSet.MoveNext(); + } + } + + #endregion + + public override bool HasRows + { + get { return ResultSet != null && ResultSet.Current.Length > 0; } + } + + public override bool Read() + { + if (Rows == null) + { + Rows = ((IEnumerable>) ResultSet.Current).GetEnumerator(); + } + return Rows.MoveNext(); + } + + public override bool NextResult() + { + if (Data == null || !ResultSet.MoveNext()) + { + return false; + } + Rows = ((IEnumerable>)ResultSet.Current).GetEnumerator(); + return true; + } + + public override object GetValue(int ordinal) + { + return this[ordinal]; + } + + public override int GetValues(object[] values) + { + for(int i = 0; i < Rows.Current.Count; i++) + { + values[i] = this[i]; + } + return Rows.Current.Count; + } + + public override object this[string name] + { + get { return Rows.Current[name]; } + } + + public override object this[int ordinal] + { + get { return Rows.Current[Rows.Current.Keys.AsEnumerable().ToArray()[ordinal]]; } + } + + public ReadOnlyCollection GetColumnSchema() + { + if (ResultSet?.Current == null || ResultSet.Current.Length <= 0) + { + return new ReadOnlyCollection(new List()); + } + + List columns = new List(); + for (int i = 0; i < ResultSet.Current[0].Count; i++) + { + columns.Add(new TestDbColumn()); + } + return new ReadOnlyCollection(columns); + } + + public override bool IsDBNull(int ordinal) + { + return this[ordinal] == null; + } + + public override int FieldCount { get { return Rows?.Current.Count ?? 0; } } + + public override int RecordsAffected + { + // Mimics the behavior of SqlDataReader + get { return Rows != null ? -1 : 1; } + } + + #region Not Implemented + + public override bool GetBoolean(int ordinal) + { + throw new NotImplementedException(); + } + + public override byte GetByte(int ordinal) + { + throw new NotImplementedException(); + } + + public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) + { + throw new NotImplementedException(); + } + + public override char GetChar(int ordinal) + { + throw new NotImplementedException(); + } + + public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) + { + throw new NotImplementedException(); + } + + public override string GetDataTypeName(int ordinal) + { + throw new NotImplementedException(); + } + + public override DateTime GetDateTime(int ordinal) + { + throw new NotImplementedException(); + } + + public override decimal GetDecimal(int ordinal) + { + throw new NotImplementedException(); + } + + public override double GetDouble(int ordinal) + { + throw new NotImplementedException(); + } + + public override int GetOrdinal(string name) + { + throw new NotImplementedException(); + } + + public override string GetName(int ordinal) + { + throw new NotImplementedException(); + } + + public override long GetInt64(int ordinal) + { + throw new NotImplementedException(); + } + + public override int GetInt32(int ordinal) + { + throw new NotImplementedException(); + } + + public override short GetInt16(int ordinal) + { + throw new NotImplementedException(); + } + + public override Guid GetGuid(int ordinal) + { + throw new NotImplementedException(); + } + + public override float GetFloat(int ordinal) + { + throw new NotImplementedException(); + } + + public override Type GetFieldType(int ordinal) + { + throw new NotImplementedException(); + } + + public override string GetString(int ordinal) + { + throw new NotImplementedException(); + } + + public override IEnumerator GetEnumerator() + { + throw new NotImplementedException(); + } + + public override int Depth { get; } + public override bool IsClosed { get; } + + #endregion + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs new file mode 100644 index 00000000..23ff7260 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs @@ -0,0 +1,207 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +//#define USE_LIVE_CONNECTION + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; + +namespace Microsoft.SqlTools.Test.Utility +{ + /// + /// Tests for the ServiceHost Connection Service tests + /// + public class TestObjects + { + /// + /// Creates a test connection service + /// + public static ConnectionService GetTestConnectionService() + { +#if !USE_LIVE_CONNECTION + // use mock database connection + return new ConnectionService(new TestSqlConnectionFactory()); +#else + // connect to a real server instance + return ConnectionService.Instance; +#endif + } + + public static ConnectParams GetTestConnectionParams() + { + return new ConnectParams() + { + OwnerUri = "file://some/file.sql", + Connection = GetTestConnectionDetails() + }; + } + + /// + /// Creates a test connection details object + /// + public static ConnectionDetails GetTestConnectionDetails() + { + return new ConnectionDetails() + { + UserName = "sa", + Password = "Yukon900", + DatabaseName = "AdventureWorks2016CTP3_2", + ServerName = "sqltools11" + }; + } + + /// + /// Create a test language service instance + /// + /// + public static LanguageService GetTestLanguageService() + { + return new LanguageService(); + } + + /// + /// Creates a test autocomplete service instance + /// + public static AutoCompleteService GetAutoCompleteService() + { + return AutoCompleteService.Instance; + } + + /// + /// Creates a test sql connection factory instance + /// + public static ISqlConnectionFactory GetTestSqlConnectionFactory() + { +#if !USE_LIVE_CONNECTION + // use mock database connection + return new TestSqlConnectionFactory(); +#else + // connect to a real server instance + return ConnectionService.Instance.ConnectionFactory; +#endif + + } + } + + /// + /// Test mock class for IDbCommand + /// + public class TestSqlCommand : DbCommand + { + internal TestSqlCommand(Dictionary[][] data) + { + Data = data; + } + + internal Dictionary[][] Data { get; set; } + + public override void Cancel() + { + throw new NotImplementedException(); + } + + public override int ExecuteNonQuery() + { + throw new NotImplementedException(); + } + + public override object ExecuteScalar() + { + throw new NotImplementedException(); + } + + public override void Prepare() + { + throw new NotImplementedException(); + } + + public override string CommandText { get; set; } + public override int CommandTimeout { get; set; } + public override CommandType CommandType { get; set; } + public override UpdateRowSource UpdatedRowSource { get; set; } + protected override DbConnection DbConnection { get; set; } + protected override DbParameterCollection DbParameterCollection { get; } + protected override DbTransaction DbTransaction { get; set; } + public override bool DesignTimeVisible { get; set; } + + protected override DbParameter CreateDbParameter() + { + throw new NotImplementedException(); + } + + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) + { + return new TestDbDataReader(Data); + } + } + + /// + /// Test mock class for SqlConnection wrapper + /// + public class TestSqlConnection : DbConnection + { + internal TestSqlConnection(Dictionary[][] data) + { + Data = data; + } + + internal Dictionary[][] Data { get; set; } + + protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) + { + throw new NotImplementedException(); + } + + public override void Close() + { + // No Op + } + + public override void Open() + { + // No Op, unless credentials are bad + if(ConnectionString.Contains("invalidUsername")) + { + throw new Exception("Invalid credentials provided"); + } + } + + public override string ConnectionString { get; set; } + public override string Database { get; } + public override ConnectionState State { get; } + public override string DataSource { get; } + public override string ServerVersion { get; } + + protected override DbCommand CreateDbCommand() + { + return new TestSqlCommand(Data); + } + + public override void ChangeDatabase(string databaseName) + { + // No Op + } + } + + /// + /// Test mock class for SqlConnection factory + /// + public class TestSqlConnectionFactory : ISqlConnectionFactory + { + public DbConnection CreateSqlConnection(string connectionString) + { + return new TestSqlConnection(null) + { + ConnectionString = connectionString + }; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs new file mode 100644 index 00000000..9a5f8ce1 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs @@ -0,0 +1,25 @@ +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Utility +{ + public static class TestUtils + { + + public static void RunIfLinux(Action test) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + test(); + } + } + + public static void RunIfWindows(Action test) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + test(); + } + } + } +} diff --git a/test/ServiceHost.Test/packages.config b/test/Microsoft.SqlTools.ServiceLayer.Test/packages.config similarity index 100% rename from test/ServiceHost.Test/packages.config rename to test/Microsoft.SqlTools.ServiceLayer.Test/packages.config diff --git a/test/ServiceHost.Test/project.json b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json similarity index 57% rename from test/ServiceHost.Test/project.json rename to test/Microsoft.SqlTools.ServiceLayer.Test/project.json index b7f00724..06a46957 100644 --- a/test/ServiceHost.Test/project.json +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json @@ -1,4 +1,5 @@ { + "name": "Microsoft.SqlTools.ServiceLayer.Test", "version": "1.0.0-*", "buildOptions": { "debugType": "portable" @@ -6,11 +7,18 @@ "dependencies": { "Newtonsoft.Json": "9.0.1", "System.Runtime.Serialization.Primitives": "4.1.1", + "System.Data.Common": "4.1.0", + "System.Data.SqlClient": "4.1.0", + "Microsoft.SqlServer.Smo": "140.1.5", + "System.Security.SecureString": "4.0.0", + "System.Collections.Specialized": "4.0.1", + "System.ComponentModel.TypeConverter": "4.1.0", "xunit": "2.1.0", "dotnet-test-xunit": "1.0.0-rc2-192208-24", - "ServiceHost": { - "target": "project" - } + "moq": "4.6.36-alpha", + "Microsoft.SqlTools.ServiceLayer": { + "target": "project" + } }, "testRunner": "xunit", "frameworks": { diff --git a/test/ServiceHost.Test/Message/MessageReaderWriterTests.cs b/test/ServiceHost.Test/Message/MessageReaderWriterTests.cs deleted file mode 100644 index 82e619f5..00000000 --- a/test/ServiceHost.Test/Message/MessageReaderWriterTests.cs +++ /dev/null @@ -1,177 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol; -using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Serializers; -using System; -using System.IO; -using System.Text; -using System.Threading.Tasks; -using Xunit; - -namespace Microsoft.SqlTools.EditorServices.Test.Protocol.MessageProtocol -{ - public class MessageReaderWriterTests - { - const string TestEventString = "{\"type\":\"event\",\"event\":\"testEvent\",\"body\":null}"; - const string TestEventFormatString = "{{\"event\":\"testEvent\",\"body\":{{\"someString\":\"{0}\"}},\"seq\":0,\"type\":\"event\"}}"; - readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); - - private IMessageSerializer messageSerializer; - - public MessageReaderWriterTests() - { - this.messageSerializer = new V8MessageSerializer(); - } - - [Fact] - public async Task WritesMessage() - { - MemoryStream outputStream = new MemoryStream(); - - MessageWriter messageWriter = - new MessageWriter( - outputStream, - this.messageSerializer); - - // Write the message and then roll back the stream to be read - // TODO: This will need to be redone! - await messageWriter.WriteMessage(Message.Event("testEvent", null)); - outputStream.Seek(0, SeekOrigin.Begin); - - string expectedHeaderString = - string.Format( - Constants.ContentLengthFormatString, - ExpectedMessageByteCount); - - byte[] buffer = new byte[128]; - await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); - - Assert.Equal( - expectedHeaderString, - Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); - - // Read the message - await outputStream.ReadAsync(buffer, 0, ExpectedMessageByteCount); - - Assert.Equal( - TestEventString, - Encoding.UTF8.GetString(buffer, 0, ExpectedMessageByteCount)); - - outputStream.Dispose(); - } - - [Fact] - public void ReadsMessage() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Write a message to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - inputStream.Write( - this.GetMessageBytes(TestEventString), - 0, - messageBuffer.Length); - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - Message messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - [Fact] - public void ReadsManyBufferedMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message to use for writing to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - - // How many messages of this size should we write to overflow the buffer? - int overflowMessageCount = - (int)Math.Ceiling( - (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); - - // Write the necessary number of messages to the stream - for (int i = 0; i < overflowMessageCount; i++) - { - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - } - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - // Read the written messages from the stream - for (int i = 0; i < overflowMessageCount; i++) - { - Message messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - } - - inputStream.Dispose(); - } - - [Fact] - public void ReaderResizesBufferForLargeMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message with content so large that the buffer will need - // to be resized to fit it all. - byte[] messageBuffer = - this.GetMessageBytes( - string.Format( - TestEventFormatString, - new String('X', (int)(MessageReader.DefaultBufferSize * 3)))); - - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - Message messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - private byte[] GetMessageBytes(string messageString, Encoding encoding = null) - { - if (encoding == null) - { - encoding = Encoding.UTF8; - } - - byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); - byte[] headerBytes = - Encoding.ASCII.GetBytes( - string.Format( - Constants.ContentLengthFormatString, - messageBytes.Length)); - - // Copy the bytes into a single buffer - byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; - Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); - Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); - - return finalBytes; - } - } -} - diff --git a/test/ServiceHost.Test/PowerShellEditorServices.Test.Protocol.csproj b/test/ServiceHost.Test/PowerShellEditorServices.Test.Protocol.csproj deleted file mode 100644 index 54e20896..00000000 --- a/test/ServiceHost.Test/PowerShellEditorServices.Test.Protocol.csproj +++ /dev/null @@ -1,109 +0,0 @@ - - - - - - - Debug - AnyCPU - {E3A5CF5D-6E41-44AC-AE0A-4C227E4BACD4} - Library - Properties - Microsoft.SqlTools.EditorServices.Test.Protocol - Microsoft.SqlTools.EditorServices.Test.Protocol - v4.6.1 - 512 - 69e9ba79 - ..\..\ - true - - - - true - full - false - bin\Debug\ - DEBUG;TRACE - prompt - 4 - - - pdbonly - true - bin\Release\ - TRACE - prompt - 4 - - - - ..\..\packages\Newtonsoft.Json.8.0.2\lib\net45\Newtonsoft.Json.dll - True - - - - - - - - - - ..\..\packages\xunit.abstractions.2.0.0\lib\net35\xunit.abstractions.dll - True - - - ..\..\packages\xunit.assert.2.1.0\lib\portable-net45+win8+wp8+wpa81\xunit.assert.dll - True - - - ..\..\packages\xunit.extensibility.core.2.1.0\lib\portable-net45+win8+wp8+wpa81\xunit.core.dll - True - - - ..\..\packages\xunit.extensibility.execution.2.1.0\lib\net45\xunit.execution.desktop.dll - True - - - - - - - - - - - - - - - - - {f8a0946a-5d25-4651-8079-b8d5776916fb} - SqlToolsEditorServices.Protocol - - - {81e8cbcd-6319-49e7-9662-0475bd0791f4} - SqlToolsEditorServices - - - - - - - - - This project references NuGet package(s) that are missing on this computer. Enable NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - - - - \ No newline at end of file