//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Diagnostics;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.SqlTools.ServiceLayer.AzureFunctions.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
using Microsoft.CodeAnalysis.CSharp;
namespace Microsoft.SqlTools.ServiceLayer.AzureFunctions
{
/// <summary>
/// Represents the operation of inserting a SQL binding into an Azure Function.
/// </summary>
class AddSqlBindingOperation
{
/// <summary>
/// Namespace whose using directive <c>AddBinding</c> inserts into the file when it adds an input binding.
/// </summary>
public const string GenericClass = "System.Collections.Generic";

/// <summary>
/// The request this operation applies: target file, function name, binding type and binding details.
/// </summary>
public AddSqlBindingParams Parameters { get; }

/// <summary>
/// Creates an operation that will apply the given binding request.
/// </summary>
/// <param name="parameters">The binding request; checked with <c>Validate.IsNotNull</c> before it is stored.</param>
public AddSqlBindingOperation(AddSqlBindingParams parameters)
{
    Validate.IsNotNull(nameof(parameters), parameters);
    Parameters = parameters;
}
/// <summary>
/// Adds a SQL input or output binding parameter to the Azure Function whose function-name
/// attribute matches <c>Parameters.functionName</c> in the C# file at <c>Parameters.filePath</c>.
/// The file is then overwritten with the updated and formatted source.
/// </summary>
/// <returns>
/// A successful <see cref="ResultStatus"/> when the file was rewritten. Returns a failed status with
/// an error message when no function, or more than one function, with the requested name is found.
/// </returns>
/// <remarks>
/// Any exception raised while reading, parsing, rewriting or writing the file is logged and rethrown.
/// </remarks>
public ResultStatus AddBinding()
{
    try
    {
        string text = File.ReadAllText(Parameters.filePath);
        SyntaxTree tree = CSharpSyntaxTree.ParseText(text);
        CompilationUnitSyntax root = tree.GetCompilationUnitRoot();

        // Look for the Azure Function to update. The result is materialized once so the
        // query is not re-run for each of the count checks and the lookup below.
        IEnumerable<MethodDeclarationSyntax> azureFunctionMethods = AzureFunctionsUtils.GetMethodsWithFunctionAttributes(root);
        string expectedNameArgument = $"\"{Parameters.functionName}\"";
        List<MethodDeclarationSyntax> matchingMethods = azureFunctionMethods
            .Where(md => HasAttributeWithFirstArgument(md, expectedNameArgument))
            .ToList();

        if (matchingMethods.Count == 0)
        {
            return new ResultStatus()
            {
                Success = false,
                ErrorMessage = SR.CouldntFindAzureFunction(Parameters.functionName, Parameters.filePath)
            };
        }

        if (matchingMethods.Count > 1)
        {
            return new ResultStatus()
            {
                Success = false,
                ErrorMessage = SR.MoreThanOneAzureFunctionWithName(Parameters.functionName, Parameters.filePath)
            };
        }

        MethodDeclarationSyntax azureFunction = matchingMethods[0];
        ParameterSyntax newParam = Parameters.bindingType == BindingType.input
            ? GenerateInputBinding()
            : GenerateOutputBinding();

        // Generate the updated method with the new parameter.
        // NormalizeWhitespace removes newline whitespace from the leading trivia, so the
        // method's original leading/trailing trivia is re-applied afterwards.
        MethodDeclarationSyntax updatedMethod = azureFunction
            .AddParameterListParameters(newParam)
            .NormalizeWhitespace()
            .WithLeadingTrivia(azureFunction.GetLeadingTrivia())
            .WithTrailingTrivia(azureFunction.GetTrailingTrivia());

        root = root.ReplaceNode(azureFunction, updatedMethod);

        if (Parameters.bindingType == BindingType.input)
        {
            // Input bindings need System.Collections.Generic; insert the using if it is missing.
            root = EnsureUsingDirective(root, GenericClass);
        }

        File.WriteAllText(Parameters.filePath, FormatSource(root.ToString()));
        return new ResultStatus()
        {
            Success = true
        };
    }
    catch (Exception ex)
    {
        // The operation failed and the exception is propagated, so log it as an error.
        Logger.Write(TraceEventType.Error, $"Failed to add sql binding. Error: {ex.Message}");
        throw;
    }
}

/// <summary>
/// Returns true when any attribute on <paramref name="method"/> has a first argument whose
/// source text equals <paramref name="expectedArgumentText"/> (ordinal comparison).
/// Attributes written without arguments (no parentheses, or empty <c>()</c>) are skipped
/// instead of throwing.
/// </summary>
private static bool HasAttributeWithFirstArgument(MethodDeclarationSyntax method, string expectedArgumentText)
{
    return method.AttributeLists
        .SelectMany(list => list.Attributes)
        .Any(attr => attr.ArgumentList != null
            && attr.ArgumentList.Arguments.Count > 0
            && string.Equals(attr.ArgumentList.Arguments[0].ToString(), expectedArgumentText, StringComparison.Ordinal));
}

/// <summary>
/// Returns <paramref name="root"/> with a <c>using <paramref name="namespaceName"/>;</c> directive
/// added at the top of the compilation unit. If a non-alias using of that namespace already exists
/// anywhere in the tree, <paramref name="root"/> is returned unchanged.
/// </summary>
private static CompilationUnitSyntax EnsureUsingDirective(CompilationUnitSyntax root, string namespaceName)
{
    // An alias directive (using X = Ns;) does not bring the namespace's types into scope,
    // so it does not count as the namespace already being imported.
    bool alreadyImported = root.DescendantNodes()
        .OfType<UsingDirectiveSyntax>()
        .Any(usingDirective => usingDirective.Alias == null
            && usingDirective.Name?.ToString() == namespaceName);
    if (alreadyImported)
    {
        return root;
    }

    return root.AddUsings(
        SyntaxFactory.UsingDirective(SyntaxFactory.ParseName(namespaceName))
            .NormalizeWhitespace()
            .WithTrailingTrivia(SyntaxFactory.ElasticCarriageReturnLineFeed));
}

/// <summary>
/// Reparses <paramref name="source"/> into a fresh syntax tree and returns the text produced by the
/// Roslyn formatter for that tree.
/// </summary>
private static string FormatSource(string source)
{
    // The workspace is only needed for the formatting call; dispose it afterwards.
    using (var workspace = new AdhocWorkspace())
    {
        SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source);
        SyntaxNode formattedNode = CodeAnalysis.Formatting.Formatter.Format(syntaxTree.GetRoot(), workspace);
        return formattedNode.ToString();
    }
}
/// <summary>
/// Generates a parameter for the SQL input binding that looks like
/// [Sql("select * from [dbo].[table1]", CommandType = System.Data.CommandType.Text, ConnectionStringSetting = "SqlConnectionString")] IEnumerable
/// </summary>
private ParameterSyntax GenerateInputBinding()
{
// Create arguments for the Sql Input Binding attribute
var argumentList = SyntaxFactory.AttributeArgumentList();
argumentList = argumentList.AddArguments(SyntaxFactory.AttributeArgument(SyntaxFactory.IdentifierName($"\"select * from {Parameters.objectName}\"")));
argumentList = argumentList.AddArguments(SyntaxFactory.AttributeArgument(SyntaxFactory.AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, SyntaxFactory.IdentifierName("CommandType"), SyntaxFactory.IdentifierName("System.Data.CommandType.Text"))));
argumentList = argumentList.AddArguments(SyntaxFactory.AttributeArgument(SyntaxFactory.AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, SyntaxFactory.IdentifierName("ConnectionStringSetting"), SyntaxFactory.IdentifierName($"\"{Parameters.connectionStringSetting}\""))));
// Create Sql Binding attribute
SyntaxList attributesList = new SyntaxList();
attributesList = attributesList.Add(SyntaxFactory.AttributeList(SyntaxFactory.SingletonSeparatedList(SyntaxFactory.Attribute(SyntaxFactory.IdentifierName("Sql")).WithArgumentList(argumentList))));
// Create new parameter
ParameterSyntax newParam = SyntaxFactory.Parameter(attributesList, new SyntaxTokenList(), SyntaxFactory.ParseTypeName("IEnumerable