//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using Microsoft.SqlTools.Credentials;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.ObjectExplorer;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.Test.Common
{
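/// <summary>
/// Singleton that gives tests access to the SQL Tools service instances
/// (credential, connection, object explorer, connection profile and workspace services).
/// </summary>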
public class TestServiceProvider
{
/// <summary>
/// Private so that the only instance is the one exposed through <c>Instance</c>.
/// Makes sure the test service host has been set up before the provider is used.
/// </summary>
private TestServiceProvider()
{
    TestServiceProvider.InitializeTestServices();
}
// Lock taken by InitializeTestServices. It must stay declared BEFORE _instance:
// static field initializers run in textual order, and constructing _instance
// calls InitializeTestServices, which locks on this object.
private static readonly object _lockObject = new object();

// The single shared provider, created by the static field initializer.
private static readonly TestServiceProvider _instance = new TestServiceProvider();

/// <summary>
/// Gets the shared <c>TestServiceProvider</c> instance.
/// </summary>
public static TestServiceProvider Instance
{
    get { return _instance; }
}
/// <summary>
/// Gets the credential service singleton (<c>CredentialService.Instance</c>).
/// </summary>
public CredentialService CredentialService
{
    get { return CredentialService.Instance; }
}
/// <summary>
/// Gets the connection service singleton (<c>ConnectionService.Instance</c>).
/// </summary>
public ConnectionService ConnectionService
{
    get { return ConnectionService.Instance; }
}
/// <summary>
/// Gets the object explorer service, resolved from the service provider
/// exposed by <c>ServiceHost.Instance</c>.
/// </summary>
public ObjectExplorerService ObjectExplorerService
{
    get
    {
        var serviceProvider = ServiceHost.Instance.ServiceProvider;

        // The requested service type has to be given explicitly; it cannot be
        // inferred from the property's return type.
        return serviceProvider.GetService<ObjectExplorerService>();
    }
}
/// <summary>
/// Gets the test connection profile service singleton
/// (<c>TestConnectionProfileService.Instance</c>).
/// </summary>
public TestConnectionProfileService ConnectionProfileService
{
    get { return TestConnectionProfileService.Instance; }
}
/// <summary>
/// Gets the workspace service singleton (<c>WorkspaceService.Instance</c>).
/// </summary>
public WorkspaceService WorkspaceService
{
    get { return WorkspaceService.Instance; }
}
/// <summary>
/// Runs a query by calling the services directly (not using the test driver),
/// blocking the calling thread until <c>RunQueryAsync</c> has finished.
/// </summary>
/// <param name="serverType">Passed through to <c>RunQueryAsync</c>.</param>
/// <param name="databaseName">Passed through to <c>RunQueryAsync</c>.</param>
/// <param name="queryText">Passed through to <c>RunQueryAsync</c>.</param>
/// <param name="throwOnError">Passed through to <c>RunQueryAsync</c>.</param>
/// <remarks>
/// Because the task is waited on with <c>Task.Wait</c>, a failure in the
/// asynchronous call surfaces here wrapped in an <c>AggregateException</c>.
/// </remarks>
public void RunQuery(TestServerType serverType, string databaseName, string queryText, bool throwOnError = false)
{
    Task pendingQuery = RunQueryAsync(serverType, databaseName, queryText, throwOnError);
    pendingQuery.Wait();
}
/// <summary>
/// Runs a query by calling the services directly (not using the test driver).
/// A temporary file path is used as the owner URI of a connection that is opened
/// for the query and disconnected again before the method completes.
/// </summary>
/// <param name="serverType">Server type used to look up the connection parameters.</param>
/// <param name="databaseName">Database name used to look up the connection parameters.</param>
/// <param name="queryText">Text of the query to execute.</param>
/// <param name="throwOnError">
/// When true, an exception is thrown if any batch of the query reports an error.
/// </param>
/// <exception cref="InvalidOperationException">
/// <paramref name="throwOnError"/> is true and at least one batch has an error.
/// </exception>
public async Task RunQueryAsync(TestServerType serverType, string databaseName, string queryText, bool throwOnError = false)
{
    using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
    {
        string uri = queryTempFile.FilePath;
        ConnectionInfo connInfo = InitLiveConnectionInfo(serverType, databaseName, uri);

        try
        {
            Query query = new Query(queryText, connInfo, new QueryExecutionSettings(), MemoryFileSystem.GetFileStreamFactory());
            query.Execute();
            await query.ExecutionTask;

            if (throwOnError)
            {
                // Materialize once: the failing batches are both tested for
                // existence and listed in the message.
                var errorBatches = query.Batches.Where(b => b.HasError).ToList();
                if (errorBatches.Any())
                {
                    throw new InvalidOperationException(
                        string.Format(
                            "The query encountered an error. The batches with errors: {0}",
                            string.Join(Environment.NewLine, errorBatches.Select(b => b.BatchText))));
                }
            }
        }
        finally
        {
            // Always release the connection, including when execution fails or
            // the error-batch exception above is thrown.
            DisconnectConnection(uri);
        }
    }
}
/// <summary>
/// Awaits the given test delegate between creating a <c>TestTimer</c> and calling
/// its <c>EndAndPrint</c>, then returns the delegate's result.
/// </summary>
/// <typeparam name="T">Type of the value produced by the test delegate.</typeparam>
/// <param name="testToRun">Asynchronous test body to execute.</param>
/// <param name="printResult">Value assigned to the timer's <c>PrintResult</c> property.</param>
/// <param name="testName">
/// Name passed to <c>EndAndPrint</c>; filled in with the caller's member name
/// when omitted.
/// </param>
/// <returns>The value returned by <paramref name="testToRun"/>.</returns>
public static async Task<T> CalculateRunTime<T>(Func<Task<T>> testToRun, bool printResult, [CallerMemberName] string testName = "")
{
    TestTimer timer = new TestTimer() { PrintResult = printResult };
    T result = await testToRun();
    timer.EndAndPrint(testName);
    return result;
}
/// <summary>
/// Opens a connection owned by <paramref name="scriptFilePath"/> and returns the
/// connection info that the connection service registered for it.
/// </summary>
/// <param name="serverType">Server type used to look up the connection parameters.</param>
/// <param name="databaseName">Database name used to look up the connection parameters.</param>
/// <param name="scriptFilePath">Path used as the owner URI of the new connection.</param>
/// <returns>The connection info found for the owner URI; asserted to be non-null.</returns>
private ConnectionInfo InitLiveConnectionInfo(TestServerType serverType, string databaseName, string scriptFilePath)
{
    ConnectParams profileParams = ConnectionProfileService.GetConnectionParameters(serverType, databaseName);
    ConnectionService service = ConnectionService.Instance;

    // Reuse the profile's connection details, but own the connection by the script path.
    ConnectParams request = new ConnectParams();
    request.OwnerUri = scriptFilePath;
    request.Connection = profileParams.Connection;

    // Block until the connect call has finished before looking the connection up.
    service.Connect(request).Wait();

    ConnectionInfo liveConnection = null;
    service.TryFindConnection(scriptFilePath, out liveConnection);
    Assert.NotNull(liveConnection);
    return liveConnection;
}
/// <summary>
/// Asks the connection service to disconnect the default-type connection
/// owned by the given URI.
/// </summary>
/// <param name="uri">Owner URI of the connection to disconnect.</param>
private void DisconnectConnection(string uri)
{
    ConnectionService service = ConnectionService.Instance;

    DisconnectParams request = new DisconnectParams();
    request.OwnerUri = uri;
    request.Type = ConnectionType.Default;

    service.Disconnect(request);
}
// True once the test service host has been created and started.
// volatile because it is read outside the lock in InitializeTestServices.
private static volatile bool hasInitServices = false;

/// <summary>
/// Creates and starts the test service host the first time it is called;
/// later calls return immediately.
/// </summary>
private static void InitializeTestServices()
{
    // Fast path: skip the lock once initialization has completed.
    if (TestServiceProvider.hasInitServices)
    {
        return;
    }

    lock (_lockObject)
    {
        // Re-check under the lock in case another thread finished first.
        if (TestServiceProvider.hasInitServices)
        {
            return;
        }

        const string hostName = "SQL Tools Test Service Host";
        const string hostProfileId = "SQLToolsTestService";
        Version hostVersion = new Version(1, 0);

        // set up the host details and profile paths
        var hostDetails = new HostDetails(hostName, hostProfileId, hostVersion);
        SqlToolsContext sqlToolsContext = new SqlToolsContext(hostDetails);

        // Create and start the service host; the returned instance is not needed here.
        HostLoader.CreateAndStartServiceHost(sqlToolsContext);

        // Set the flag only after the host is up, so an unlocked reader never sees
        // "initialized" early and a failed start is not recorded as done.
        TestServiceProvider.hasInitServices = true;
    }
}
}
}