Credentials store API (#38)

* CredentialService initial impl with Win32 support

- Basic CredentialService APIs for Save, Read, Delete
- E2E unit tests for Credential Service
- Win32 implementation with unit tests

* Save Password support on Mac v1

- Basic keychain support on Mac using Interop with the KeyChain APIs
- All but 1 unit test passing. This will pass once API is changed, but checking this in with the existing API so that if we decide to alter behavior, we have a reference point.

* Remove Username from Credentials API

- Removed Username option from Credentials as this caused conflicting behavior on Mac vs Windows

* Cleanup Using Statements and add Copyright

* Linux CredentialStore Prototype

* Linux credential store support

- Full support for Linux credential store with tests

* Plumbed CredentialService into Program init

* Addressing Pull Request comments
This commit is contained in:
Kevin Cunnane
2016-09-06 18:12:39 -07:00
committed by GitHub
parent 76e7ea041c
commit 8ca88992be
41 changed files with 3120 additions and 165 deletions

View File

@@ -0,0 +1,287 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Credentials;
using Microsoft.SqlTools.ServiceLayer.Credentials.Contracts;
using Microsoft.SqlTools.ServiceLayer.Credentials.Linux;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Moq;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
{
/// <summary>
/// Credential Service tests that should pass on all platforms, regardless of backing store.
/// These tests run E2E, storing values in the native credential store for whichever platform
/// tests are being run on
/// </summary>
public class CredentialServiceTests : IDisposable
{
private static readonly LinuxCredentialStore.StoreConfig config = new LinuxCredentialStore.StoreConfig()
{
CredentialFolder = ".testsecrets",
CredentialFile = "sqltestsecrets.json",
IsRelativeToUserHomeDir = true
};
const string credentialId = "Microsoft_SqlToolsTest_TestId";
const string password1 = "P@ssw0rd1";
const string password2 = "2Pass2Furious";
const string otherCredId = credentialId + "2345";
const string otherPassword = credentialId + "2345";
// Test-owned credential store used to clean up before/after tests to ensure code works as expected
// even if previous runs stopped midway through
private ICredentialStore credStore;
private CredentialService service;
/// <summary>
/// Constructor called once for every test
/// </summary>
public CredentialServiceTests()
{
credStore = CredentialService.GetStoreForOS(config);
service = new CredentialService(credStore, config);
DeleteDefaultCreds();
}
public void Dispose()
{
DeleteDefaultCreds();
}
private void DeleteDefaultCreds()
{
credStore.DeletePassword(credentialId);
credStore.DeletePassword(otherCredId);
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
string credsFolder = ((LinuxCredentialStore)credStore).CredentialFolderPath;
if (Directory.Exists(credsFolder))
{
Directory.Delete(credsFolder, true);
}
}
}
[Fact]
public async Task SaveCredentialThrowsIfCredentialIdMissing()
{
object errorResponse = null;
var contextMock = RequestContextMocks.Create<bool>(null).AddErrorHandling(obj => errorResponse = obj);
await service.HandleSaveCredentialRequest(new Credential(null), contextMock.Object);
VerifyErrorSent(contextMock);
Assert.True(((string)errorResponse).Contains("ArgumentException"));
}
[Fact]
public async Task SaveCredentialThrowsIfPasswordMissing()
{
object errorResponse = null;
var contextMock = RequestContextMocks.Create<bool>(null).AddErrorHandling(obj => errorResponse = obj);
await service.HandleSaveCredentialRequest(new Credential(credentialId), contextMock.Object);
VerifyErrorSent(contextMock);
Assert.True(((string)errorResponse).Contains("ArgumentException"));
}
[Fact]
public async Task SaveCredentialWorksForSingleCredential()
{
await RunAndVerify<bool>(
test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext),
verify: (actual => Assert.True(actual)));
}
[Fact]
public async Task SaveCredentialSupportsSavingCredentialMultipleTimes()
{
await RunAndVerify<bool>(
test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext),
verify: (actual => Assert.True(actual)));
await RunAndVerify<bool>(
test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext),
verify: (actual => Assert.True(actual)));
}
[Fact]
public async Task ReadCredentialWorksForSingleCredential()
{
// Given we have saved the credential
await RunAndVerify<bool>(
test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext),
verify: (actual => Assert.True(actual, "Expect Credential to be saved successfully")));
// Expect read of the credential to return the password
await RunAndVerify<Credential>(
test: (requestContext) => service.HandleReadCredentialRequest(new Credential(credentialId, null), requestContext),
verify: (actual =>
{
Assert.Equal(password1, actual.Password);
}));
}
[Fact]
public async Task ReadCredentialWorksForMultipleCredentials()
{
// Given we have saved multiple credentials
await RunAndVerify<bool>(
test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext),
verify: (actual => Assert.True(actual, "Expect Credential to be saved successfully")));
await RunAndVerify<bool>(
test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(otherCredId, otherPassword), requestContext),
verify: (actual => Assert.True(actual, "Expect Credential to be saved successfully")));
// Expect read of the credentials to return the right password
await RunAndVerify<Credential>(
test: (requestContext) => service.HandleReadCredentialRequest(new Credential(credentialId, null), requestContext),
verify: (actual =>
{
Assert.Equal(password1, actual.Password);
}));
await RunAndVerify<Credential>(
test: (requestContext) => service.HandleReadCredentialRequest(new Credential(otherCredId, null), requestContext),
verify: (actual =>
{
Assert.Equal(otherPassword, actual.Password);
}));
}
[Fact]
public async Task ReadCredentialHandlesPasswordUpdate()
{
// Given we have saved twice with a different password
await RunAndVerify<bool>(
test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext),
verify: (actual => Assert.True(actual)));
await RunAndVerify<bool>(
test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password2), requestContext),
verify: (actual => Assert.True(actual)));
// When we read the value for this credential
// Then we expect only the last saved password to be found
await RunAndVerify<Credential>(
test: (requestContext) => service.HandleReadCredentialRequest(new Credential(credentialId), requestContext),
verify: (actual =>
{
Assert.Equal(password2, actual.Password);
}));
}
[Fact]
public async Task ReadCredentialThrowsIfCredentialIsNull()
{
object errorResponse = null;
var contextMock = RequestContextMocks.Create<Credential>(null).AddErrorHandling(obj => errorResponse = obj);
// Verify throws on null, and this is sent as an error
await service.HandleReadCredentialRequest(null, contextMock.Object);
VerifyErrorSent(contextMock);
Assert.True(((string)errorResponse).Contains("ArgumentNullException"));
}
[Fact]
public async Task ReadCredentialThrowsIfIdMissing()
{
object errorResponse = null;
var contextMock = RequestContextMocks.Create<Credential>(null).AddErrorHandling(obj => errorResponse = obj);
// Verify throws with no ID
await service.HandleReadCredentialRequest(new Credential(), contextMock.Object);
VerifyErrorSent(contextMock);
Assert.True(((string)errorResponse).Contains("ArgumentException"));
}
[Fact]
public async Task ReadCredentialReturnsNullPasswordForMissingCredential()
{
// Given a credential whose password doesn't exist
string credWithNoPassword = "Microsoft_SqlTools_CredThatDoesNotExist";
// When reading the credential
// Then expect the credential to be returned but password left blank
await RunAndVerify<Credential>(
test: (requestContext) => service.HandleReadCredentialRequest(new Credential(credWithNoPassword, null), requestContext),
verify: (actual =>
{
Assert.NotNull(actual);
Assert.Equal(credWithNoPassword, actual.CredentialId);
Assert.Null(actual.Password);
}));
}
[Fact]
public async Task DeleteCredentialThrowsIfIdMissing()
{
object errorResponse = null;
var contextMock = RequestContextMocks.Create<bool>(null).AddErrorHandling(obj => errorResponse = obj);
// Verify throws with no ID
await service.HandleDeleteCredentialRequest(new Credential(), contextMock.Object);
VerifyErrorSent(contextMock);
Assert.True(((string)errorResponse).Contains("ArgumentException"));
}
[Fact]
public async Task DeleteCredentialReturnsTrueOnlyIfCredentialExisted()
{
// Save should be true
await RunAndVerify<bool>(
test: (requestContext) => service.HandleSaveCredentialRequest(new Credential(credentialId, password1), requestContext),
verify: (actual => Assert.True(actual)));
// Then delete - should return true
await RunAndVerify<bool>(
test: (requestContext) => service.HandleDeleteCredentialRequest(new Credential(credentialId), requestContext),
verify: (actual => Assert.True(actual)));
// Then delete - should return false as no longer exists
await RunAndVerify<bool>(
test: (requestContext) => service.HandleDeleteCredentialRequest(new Credential(credentialId), requestContext),
verify: (actual => Assert.False(actual)));
}
private async Task RunAndVerify<T>(Func<RequestContext<T>, Task> test, Action<T> verify)
{
T result = default(T);
var contextMock = RequestContextMocks.Create<T>(r => result = r).AddErrorHandling(null);
await test(contextMock.Object);
VerifyResult(contextMock, verify, result);
}
private void VerifyErrorSent<T>(Mock<RequestContext<T>> contextMock)
{
contextMock.Verify(c => c.SendResult(It.IsAny<T>()), Times.Never);
contextMock.Verify(c => c.SendError(It.IsAny<string>()), Times.Once);
}
private void VerifyResult<T, U>(Mock<RequestContext<T>> contextMock, U expected, U actual)
{
contextMock.Verify(c => c.SendResult(It.IsAny<T>()), Times.Once);
Assert.Equal(expected, actual);
contextMock.Verify(c => c.SendError(It.IsAny<string>()), Times.Never);
}
private void VerifyResult<T>(Mock<RequestContext<T>> contextMock, Action<T> verify, T actual)
{
contextMock.Verify(c => c.SendResult(It.IsAny<T>()), Times.Once);
contextMock.Verify(c => c.SendError(It.IsAny<string>()), Times.Never);
verify(actual);
}
}
}

View File

@@ -0,0 +1,37 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.Credentials;
using Microsoft.SqlTools.ServiceLayer.Credentials.Linux;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.Test.Credentials
{
public class LinuxInteropTests
{
[Fact]
public void GetEUidReturnsInt()
{
TestUtils.RunIfLinux(() =>
{
Assert.NotNull(Interop.Sys.GetEUid());
});
}
[Fact]
public void GetHomeDirectoryFromPwFindsHomeDir()
{
TestUtils.RunIfLinux(() =>
{
string userDir = LinuxCredentialStore.GetHomeDirectoryFromPw();
Assert.StartsWith("/", userDir);
});
}
}
}

View File

@@ -0,0 +1,99 @@
//
// Code originally from http://credentialmanagement.codeplex.com/,
// Licensed under the Apache License 2.0
//
using System;
using Microsoft.SqlTools.ServiceLayer.Credentials.Win32;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.Test.Credentials
{
public class CredentialSetTests
{
[Fact]
public void CredentialSetCreate()
{
TestUtils.RunIfWindows(() =>
{
Assert.NotNull(new CredentialSet());
});
}
[Fact]
public void CredentialSetCreateWithTarget()
{
TestUtils.RunIfWindows(() =>
{
Assert.NotNull(new CredentialSet("target"));
});
}
[Fact]
public void CredentialSetShouldBeIDisposable()
{
TestUtils.RunIfWindows(() =>
{
Assert.True(new CredentialSet() is IDisposable, "CredentialSet needs to implement IDisposable Interface.");
});
}
[Fact]
public void CredentialSetLoad()
{
TestUtils.RunIfWindows(() =>
{
Win32Credential credential = new Win32Credential
{
Username = "username",
Password = "password",
Target = "target",
Type = CredentialType.Generic
};
credential.Save();
CredentialSet set = new CredentialSet();
set.Load();
Assert.NotNull(set);
Assert.NotEmpty(set);
credential.Delete();
set.Dispose();
});
}
[Fact]
public void CredentialSetLoadShouldReturnSelf()
{
TestUtils.RunIfWindows(() =>
{
CredentialSet set = new CredentialSet();
Assert.IsType<CredentialSet>(set.Load());
set.Dispose();
});
}
[Fact]
public void CredentialSetLoadWithTargetFilter()
{
TestUtils.RunIfWindows(() =>
{
Win32Credential credential = new Win32Credential
{
Username = "filteruser",
Password = "filterpassword",
Target = "filtertarget"
};
credential.Save();
CredentialSet set = new CredentialSet("filtertarget");
Assert.Equal(1, set.Load().Count);
set.Dispose();
});
}
}
}

View File

@@ -0,0 +1,145 @@
//
// Code originally from http://credentialmanagement.codeplex.com/,
// Licensed under the Apache License 2.0
//
using System;
using Microsoft.SqlTools.ServiceLayer.Credentials.Win32;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.Test.Credentials
{
public class Win32CredentialTests
{
[Fact]
public void Credential_Create_ShouldNotThrowNull()
{
TestUtils.RunIfWindows(() =>
{
Assert.NotNull(new Win32Credential());
});
}
[Fact]
public void Credential_Create_With_Username_ShouldNotThrowNull()
{
TestUtils.RunIfWindows(() =>
{
Assert.NotNull(new Win32Credential("username"));
});
}
[Fact]
public void Credential_Create_With_Username_And_Password_ShouldNotThrowNull()
{
TestUtils.RunIfWindows(() =>
{
Assert.NotNull(new Win32Credential("username", "password"));
});
}
[Fact]
public void Credential_Create_With_Username_Password_Target_ShouldNotThrowNull()
{
TestUtils.RunIfWindows(() =>
{
Assert.NotNull(new Win32Credential("username", "password", "target"));
});
}
[Fact]
public void Credential_ShouldBe_IDisposable()
{
TestUtils.RunIfWindows(() =>
{
Assert.True(new Win32Credential() is IDisposable, "Credential should implement IDisposable Interface.");
});
}
[Fact]
public void Credential_Dispose_ShouldNotThrowException()
{
TestUtils.RunIfWindows(() =>
{
new Win32Credential().Dispose();
});
}
[Fact]
public void Credential_ShouldThrowObjectDisposedException()
{
TestUtils.RunIfWindows(() =>
{
Win32Credential disposed = new Win32Credential { Password = "password" };
disposed.Dispose();
Assert.Throws<ObjectDisposedException>(() => disposed.Username = "username");
});
}
[Fact]
public void Credential_Save()
{
TestUtils.RunIfWindows(() =>
{
Win32Credential saved = new Win32Credential("username", "password", "target", CredentialType.Generic);
saved.PersistanceType = PersistanceType.LocalComputer;
Assert.True(saved.Save());
});
}
[Fact]
public void Credential_Delete()
{
TestUtils.RunIfWindows(() =>
{
new Win32Credential("username", "password", "target").Save();
Assert.True(new Win32Credential("username", "password", "target").Delete());
});
}
[Fact]
public void Credential_Delete_NullTerminator()
{
TestUtils.RunIfWindows(() =>
{
Win32Credential credential = new Win32Credential((string)null, (string)null, "\0", CredentialType.None);
credential.Description = (string)null;
Assert.False(credential.Delete());
});
}
[Fact]
public void Credential_Load()
{
TestUtils.RunIfWindows(() =>
{
Win32Credential setup = new Win32Credential("username", "password", "target", CredentialType.Generic);
setup.Save();
Win32Credential credential = new Win32Credential { Target = "target", Type = CredentialType.Generic };
Assert.True(credential.Load());
Assert.NotEmpty(credential.Username);
Assert.NotNull(credential.Password);
Assert.Equal("username", credential.Username);
Assert.Equal("password", credential.Password);
Assert.Equal("target", credential.Target);
});
}
[Fact]
public void Credential_Exists_Target_ShouldNotBeNull()
{
TestUtils.RunIfWindows(() =>
{
new Win32Credential { Username = "username", Password = "password", Target = "target" }.Save();
Win32Credential existingCred = new Win32Credential { Target = "target" };
Assert.True(existingCred.Exists());
existingCred.Delete();
});
}
}
}

View File

@@ -4,7 +4,6 @@
//
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
// General Information about an assembly is controlled through the following

View File

@@ -7,6 +7,7 @@ using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Moq;
using Xunit;
@@ -21,7 +22,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// ... I request a query (doesn't matter what kind) and execute it
var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true);
var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri };
var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null);
var executeRequest =
RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(null, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();
queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Fake that it hasn't completed execution
@@ -47,7 +49,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// ... I request a query (doesn't matter what kind) and wait for execution
var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true);
var executeParams = new QueryExecuteParams {QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri};
var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null);
var executeRequest =
RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(null, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();
// ... And then I request to cancel the query

View File

@@ -3,24 +3,19 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.SmoMetadataProvider;
using Microsoft.SqlServer.Management.SqlParser.Binder;
using Microsoft.SqlServer.Management.SqlParser.MetadataProvider;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
@@ -215,46 +210,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
}
#endregion
#region Request Mocking
public static Mock<RequestContext<QueryExecuteResult>> GetQueryExecuteResultContextMock(
Action<QueryExecuteResult> resultCallback,
Action<EventType<QueryExecuteCompleteParams>, QueryExecuteCompleteParams> eventCallback,
Action<object> errorCallback)
{
var requestContext = new Mock<RequestContext<QueryExecuteResult>>();
// Setup the mock for SendResult
var sendResultFlow = requestContext
.Setup(rc => rc.SendResult(It.IsAny<QueryExecuteResult>()))
.Returns(Task.FromResult(0));
if (resultCallback != null)
{
sendResultFlow.Callback(resultCallback);
}
// Setup the mock for SendEvent
var sendEventFlow = requestContext.Setup(rc => rc.SendEvent(
It.Is<EventType<QueryExecuteCompleteParams>>(m => m == QueryExecuteCompleteEvent.Type),
It.IsAny<QueryExecuteCompleteParams>()))
.Returns(Task.FromResult(0));
if (eventCallback != null)
{
sendEventFlow.Callback(eventCallback);
}
// Setup the mock for SendError
var sendErrorFlow = requestContext.Setup(rc => rc.SendError(It.IsAny<object>()))
.Returns(Task.FromResult(0));
if (errorCallback != null)
{
sendErrorFlow.Callback(errorCallback);
}
return requestContext;
}
#endregion
}
}

View File

@@ -7,6 +7,7 @@ using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Moq;
using Xunit;
@@ -21,7 +22,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// ... I request a query (doesn't matter what kind)
var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true);
var executeParams = new QueryExecuteParams {QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri};
var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null);
var executeRequest = RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(null, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();
// ... And then I dispose of the query

View File

@@ -14,6 +14,7 @@ using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Microsoft.SqlTools.ServiceLayer.Workspace;
using Moq;
using Xunit;
@@ -399,7 +400,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
QueryExecuteResult result = null;
QueryExecuteCompleteParams completeParams = null;
var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, (et, cp) => completeParams = cp, null);
var requestContext =
RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(
resultCallback: qer => result = qer,
expectedEvent: QueryExecuteCompleteEvent.Type,
eventCallback: (et, cp) => completeParams = cp,
errorCallback: null);
queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait();
// Then:
@@ -426,7 +432,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
QueryExecuteResult result = null;
QueryExecuteCompleteParams completeParams = null;
var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, (et, cp) => completeParams = cp, null);
var requestContext =
RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(
resultCallback: qer => result = qer,
expectedEvent: QueryExecuteCompleteEvent.Type,
eventCallback: (et, cp) => completeParams = cp,
errorCallback: null);
queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait();
// Then:
@@ -453,7 +464,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
var queryParams = new QueryExecuteParams { OwnerUri = "notConnected", QueryText = Common.StandardQuery };
QueryExecuteResult result = null;
var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, null, null);
var requestContext = RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait();
// Then:
@@ -476,13 +487,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = Common.StandardQuery };
// Note, we don't care about the results of the first request
var firstRequestContext = Common.GetQueryExecuteResultContextMock(null, null, null);
var firstRequestContext = RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(null, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait();
// ... And then I request another query without waiting for the first to complete
queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Simulate query hasn't finished
QueryExecuteResult result = null;
var secondRequestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, null, null);
var secondRequestContext = RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait();
// Then:
@@ -505,13 +516,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = Common.StandardQuery };
// Note, we don't care about the results of the first request
var firstRequestContext = Common.GetQueryExecuteResultContextMock(null, null, null);
var firstRequestContext = RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(null, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait();
// ... And then I request another query after waiting for the first to complete
QueryExecuteResult result = null;
QueryExecuteCompleteParams complete = null;
var secondRequestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, (et, qecp) => complete = qecp, null);
var secondRequestContext =
RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null);
queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait();
// Then:
@@ -535,7 +548,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = query };
QueryExecuteResult result = null;
var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, null, null);
var requestContext =
RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait();
// Then:
@@ -560,7 +574,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
QueryExecuteResult result = null;
QueryExecuteCompleteParams complete = null;
var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, (et, qecp) => complete = qecp, null);
var requestContext =
RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null);
queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait();
// Then:

View File

@@ -9,6 +9,7 @@ using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Moq;
using Xunit;
@@ -95,7 +96,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
var queryService =Common.GetPrimedExecutionService(
Common.CreateMockFactory(new[] {Common.StandardTestData}, false), true);
var executeParams = new QueryExecuteParams {QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri};
var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null);
var executeRequest = RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(null, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();
// ... And I then ask for a valid set of results from it
@@ -141,7 +142,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
var queryService = Common.GetPrimedExecutionService(
Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true);
var executeParams = new QueryExecuteParams { QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri };
var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null);
var executeRequest = RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(null, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();
queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false;
@@ -168,7 +169,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
var queryService = Common.GetPrimedExecutionService(
Common.CreateMockFactory(null, false), true);
var executeParams = new QueryExecuteParams { QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri };
var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null);
var executeRequest = RequestContextMocks.SetupRequestContextMock<QueryExecuteResult, QueryExecuteCompleteParams>(null, QueryExecuteCompleteEvent.Type, null, null);
queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();
// ... And I then ask for a set of results from it

View File

@@ -0,0 +1,77 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
using Moq;
namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
{
public static class RequestContextMocks
{
public static Mock<RequestContext<TResponse>> Create<TResponse>(Action<TResponse> resultCallback)
{
var requestContext = new Mock<RequestContext<TResponse>>();
// Setup the mock for SendResult
var sendResultFlow = requestContext
.Setup(rc => rc.SendResult(It.IsAny<TResponse>()))
.Returns(Task.FromResult(0));
if (resultCallback != null)
{
sendResultFlow.Callback(resultCallback);
}
return requestContext;
}
public static Mock<RequestContext<TResponse>> AddEventHandling<TResponse, TParams>(
this Mock<RequestContext<TResponse>> mock,
EventType<TParams> expectedEvent,
Action<EventType<TParams>, TParams> eventCallback)
{
var flow = mock.Setup(rc => rc.SendEvent(
It.Is<EventType<TParams>>(m => m == expectedEvent),
It.IsAny<TParams>()))
.Returns(Task.FromResult(0));
if (eventCallback != null)
{
flow.Callback(eventCallback);
}
return mock;
}
public static Mock<RequestContext<TResponse>> AddErrorHandling<TResponse>(
this Mock<RequestContext<TResponse>> mock,
Action<object> errorCallback)
{
// Setup the mock for SendError
var sendErrorFlow = mock.Setup(rc => rc.SendError(It.IsAny<object>()))
.Returns(Task.FromResult(0));
if (mock != null && errorCallback != null)
{
sendErrorFlow.Callback(errorCallback);
}
return mock;
}
public static Mock<RequestContext<TResponse>> SetupRequestContextMock<TResponse, TParams>(
Action<TResponse> resultCallback,
EventType<TParams> expectedEvent,
Action<EventType<TParams>, TParams> eventCallback,
Action<object> errorCallback)
{
return Create(resultCallback)
.AddEventHandling(expectedEvent, eventCallback)
.AddErrorHandling(errorCallback);
}
}
}

View File

@@ -6,20 +6,13 @@
//#define USE_LIVE_CONNECTION
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Xunit;
namespace Microsoft.SqlTools.Test.Utility
{

View File

@@ -0,0 +1,25 @@
using System;
using System.Runtime.InteropServices;
namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
{
public static class TestUtils
{
public static void RunIfLinux(Action test)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
test();
}
}
public static void RunIfWindows(Action test)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
test();
}
}
}
}