<#@ template debug="false" hostspecific="true" language="C#" #>
<#@ output extension=".cs" #>
<#@ assembly name="System.Xml.dll" #>
<#@ import namespace="System" #>
<#@ import namespace="System.Globalization" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Xml" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ import namespace="System.IO" #>
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

#nullable disable

// This file was generated by a T4 Template. Do not modify directly, instead update the SmoQueryModelDefinition.xml file
// and re-run the T4 template. This can be done in Visual Studio by right-clicking the template and choosing "Run Custom Tool",
// or from the command-line on any platform by running "build.cmd -Target=CodeGen" or "build.sh -Target=CodeGen".

using System;
using System.Collections.Generic;
using System.Composition;
using System.Linq;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlServer.Management.Smo.Broker;
using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Nodes;
using Microsoft.SqlTools.Utility;
using Index = Microsoft.SqlServer.Management.Smo.Index;

namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel
{
<#
    // Template overview: for every Node element in SmoQueryModelDefinition.xml this emits an
    // exported "{NodeName}Querier" class deriving from SmoQuerier. The generated Query method
    // tries each allowed parent type in turn; for the first parent that matches and yields a
    // non-null navigation result it applies any runtime Filter elements declared on the
    // matching NavigationPath and returns the objects reached through that path.

    // The definition file lives next to this template (Host is available because hostspecific="true").
    var directory = Path.GetDirectoryName(Host.TemplateFile);
    string xmlFile = Path.Combine(directory, "SmoQueryModelDefinition.xml");

    /////////
    // Now generate all the Query methods
    /////////
    var allNodes = GetNodes(xmlFile);
    var indent = "    ";
    string queryBaseClass = "SmoQuerier";

    foreach (var nodeName in allNodes)
    {
        XmlElement nodeElement = GetNodeElement(xmlFile, nodeName);
        IList<string> parents = GetParents(nodeElement, xmlFile, nodeName);
        string nodeType = GetNodeType(nodeElement, nodeName);
        var validFor = nodeElement.GetAttribute("ValidFor");

        // Indent for the enclosing namespace.
        PushIndent(indent);
        WriteLine("");
        WriteLine("[Export(typeof({0}))]", queryBaseClass);
        WriteLine("internal partial class {0}Querier: {1}", nodeName, queryBaseClass);
        WriteLine("{");
        PushIndent(indent);

        // Supported Types
        WriteLine("Type[] supportedTypes = new Type[] { typeof(" + nodeType + ") };");
        if (!string.IsNullOrWhiteSpace(validFor))
        {
            WriteLine("");
            WriteLine("public override ValidForFlag ValidFor {{ get {{ return {0}; }} }}", GetValidForFlags(validFor));
            WriteLine("");
        }

        WriteLine("");
        WriteLine("public override Type[] SupportedObjectTypes { get { return supportedTypes; } }");
        WriteLine("");

        // Query impl
        WriteLine("public override IEnumerable<SqlSmoObject> Query(SmoQueryContext context, string filter, bool refresh, IEnumerable<string> extraProperties)");
        WriteLine("{");
        PushIndent(indent);
        WriteLine("Logger.Verbose(\"Begin query {0}\");", nodeType);

        // TODO Allow override of the navigation path
        foreach (var parentType in parents)
        {
            string parentVar = string.Format("parent{0}", parentType);
            WriteLine("{0} {1} = context.Parent as {0};", parentType, parentVar);
            WriteLine("if ({0} != null)", parentVar);
            WriteLine("{");
            PushIndent(indent);
            WriteLine("Logger.Verbose(\"Parent of type `{0}` found\");", parentType);

            XmlElement navPathElement = GetNavPathElement(xmlFile, nodeName, parentType);
            string navigationPath = GetNavigationPath(nodeElement, nodeName, navPathElement);
            string subField = GetNavPathAttribute(navPathElement, "SubField");
            string fieldType = GetNavPathAttribute(navPathElement, "FieldType");

            if (navPathElement != null)
            {
                /**
                 Adding runtime filters to the querier based on the property values available in the context.
                 The code below goes through all the Filter elements under the NavigationPath element and
                 adds them to the filter string received by the querier.
                **/
                XmlNodeList runtimeFilters = navPathElement.GetElementsByTagName("Filter");
                if (runtimeFilters.Count > 0)
                {
                    WriteLine("List<INodeFilter> filters = new List<INodeFilter>();");
                    foreach (XmlElement filter in runtimeFilters)
                    {
                        string filterName = filter.GetAttribute("Property");
                        string filterField = filter.GetAttribute("Field");
                        string filterType = filter.GetAttribute("Type");
                        string filterValidFor = filter.GetAttribute("ValidFor");

                        WriteLine("filters.Add(new NodePropertyFilter()");
                        WriteLine("{");
                        PushIndent(indent);
                        WriteLine("Property = nameof({0}),", filterName);
                        WriteLine("Type = typeof({0}),", filterType);
                        // The filter value is read from the parent object at query time.
                        WriteLine("Values = new List<object> {{ {0}.{1} }},", parentVar, filterField);
                        if (!string.IsNullOrWhiteSpace(filterValidFor))
                        {
                            WriteLine("ValidFor = {0},", GetValidForFlags(filterValidFor));
                        }
                        else
                        {
                            WriteLine("ValidFor = ValidForFlag.All,");
                        }
                        PopIndent();
                        WriteLine("});");
                    }
                    WriteLine("filter = INodeFilter.AddPropertyFilterToFilterString(filter, filters, this.GetType(), context.ValidFor);");
                }
            }

            WriteLine("var retValue = {0}.{1};", parentVar, navigationPath);
            WriteLine("if (retValue != null)");
            WriteLine("{");
            PushIndent(indent);
            if (IsCollection(nodeElement))
            {
                WriteLine("retValue.ClearAndInitialize(filter, extraProperties);");
                if (string.IsNullOrEmpty(subField))
                {
                    WriteLine("var ret = new SmoCollectionWrapper<{0}>(retValue).Where(c => PassesFinalFilters({1}, c));", nodeType, parentVar);
                }
                else
                {
                    // Project every item of the collection through its SubField, skipping null results.
                    WriteLine("List<{0}> subFieldResult = new List<{0}>();", nodeType);
                    WriteLine("foreach({0} field in retValue)", fieldType);
                    WriteLine("{");
                    PushIndent(indent);
                    WriteLine("{0} subField = field.{1};", nodeType, subField);
                    WriteLine("if (subField != null)");
                    WriteLine("{");
                    PushIndent(indent);
                    WriteLine("subFieldResult.Add(subField);");
                    PopIndent();
                    WriteLine("}");
                    PopIndent();
                    WriteLine("}");
                    WriteLine("var ret = subFieldResult.Where(c => PassesFinalFilters({0}, c));", parentVar);
                }
                WriteLine("Logger.Verbose(\"End query {0}\");", nodeType);
                WriteLine("return ret;");
            }
            else
            {
                // Single object: optionally refresh it, then return it as a one-element sequence.
                WriteLine("if (refresh)");
                WriteLine("{");
                PushIndent(indent);
                WriteLine("{0}.{1}.Refresh();", parentVar, navigationPath);
                PopIndent();
                WriteLine("}");
                WriteLine("Logger.Verbose(\"End query {0}\");", nodeType);
                WriteLine("return new SqlSmoObject[] { retValue };");
            }
            PopIndent();
            WriteLine("}"); // close "if (retValue != null)"
            PopIndent();
            WriteLine("}"); // close "if (parent != null)"
        }

        // No parent type matched (or every navigation result was null).
        WriteLine("return Enumerable.Empty<SqlSmoObject>();");
        PopIndent();
        WriteLine("}"); // close Query method
        PopIndent();
        WriteLine("}"); // close Class
        PopIndent();    // namespace indent
    }
#>
}
<#+
    /// <summary>
    /// Returns the Name attribute of every /SmoQueryModel/Node element, in document order.
    /// </summary>
    public static string[] GetNodes(string xmlFile)
    {
        List<string> typesList = new List<string>();
        XmlDocument doc = new XmlDocument();
        doc.Load(xmlFile);
        XmlNodeList treeTypes = doc.SelectNodes("/SmoQueryModel/Node");
        if (treeTypes != null)
        {
            foreach (XmlNode node in treeTypes)
            {
                XmlElement element = node as XmlElement;
                if (element != null)
                {
                    typesList.Add(element.GetAttribute("Name"));
                }
            }
        }
        return typesList.ToArray();
    }

    /// <summary>
    /// Returns the /SmoQueryModel/Node element whose Name equals <paramref name="nodeName"/>,
    /// or null when there is none.
    /// </summary>
    public static XmlElement GetNodeElement(string xmlFile, string nodeName)
    {
        XmlDocument doc = new XmlDocument();
        doc.Load(xmlFile);
        return (XmlElement)doc.SelectSingleNode(string.Format("/SmoQueryModel/Node[@Name='{0}']", nodeName));
    }

    /// <summary>
    /// Returns the NavigationPath child of node <paramref name="nodeName"/> whose Parent attribute
    /// equals <paramref name="parent"/>, or null when the node declares no such path.
    /// </summary>
    public static XmlElement GetNavPathElement(string xmlFile, string nodeName, string parent)
    {
        XmlDocument doc = new XmlDocument();
        doc.Load(xmlFile);
        XmlElement navPathElement = (XmlElement)doc.SelectSingleNode(
            string.Format("/SmoQueryModel/Node[@Name='{0}']/NavigationPath[@Parent='{1}']", nodeName, parent));
        return navPathElement;
    }

    /// <summary>
    /// Null-safe attribute read: returns null when <paramref name="navPathElement"/> is null,
    /// otherwise the attribute value (XmlElement.GetAttribute returns "" when it is absent).
    /// </summary>
    public static string GetNavPathAttribute(XmlElement navPathElement, string attributeName)
    {
        return navPathElement == null ? null : navPathElement.GetAttribute(attributeName);
    }

    /// <summary>
    /// Returns the member accessed on the parent to reach this node: the NavigationPath's Field
    /// attribute when present, otherwise the node type pluralized with "s" for collections, or the
    /// node type itself for single objects.
    /// </summary>
    public static string GetNavigationPath(XmlElement nodeElement, string nodeName, XmlElement navPathElement)
    {
        string navPathField = GetNavPathAttribute(navPathElement, "Field");
        if (!string.IsNullOrEmpty(navPathField))
        {
            return navPathField;
        }

        // else use pluralized type as this is the most common scenario
        string nodeType = GetNodeType(nodeElement, nodeName);
        return IsCollection(nodeElement) ? string.Format("{0}s", nodeType) : nodeType;
    }

    /// <summary>
    /// Returns the node's Type attribute when present; otherwise the node name with a leading
    /// "Sql" removed (the name unchanged when it has no such prefix).
    /// </summary>
    public static string GetNodeType(XmlElement nodeElement, string nodeName)
    {
        var type = nodeElement.GetAttribute("Type");
        if (!string.IsNullOrEmpty(type))
        {
            return type;
        }

        // Otherwise assume the type is the node name without "Sql" at the start.
        // Ordinal comparison: this is an identifier, not user-facing text.
        const string prefix = "Sql";
        return nodeName.StartsWith(prefix, StringComparison.Ordinal)
            ? nodeName.Substring(prefix.Length)
            : nodeName;
    }

    /// <summary>
    /// Parses the node's Collection attribute as a boolean; missing or unparsable values mean true.
    /// </summary>
    public static bool IsCollection(XmlElement nodeElement)
    {
        bool result;
        if (bool.TryParse(nodeElement.GetAttribute("Collection"), out result))
        {
            return result;
        }

        // Default is true
        return true;
    }

    /// <summary>
    /// Returns the SMO parent type names for a node, checked in this order: the node's Parent
    /// attribute, then the inner text of its Parent child elements, then "Database" as a default.
    /// </summary>
    /// <param name="parentName">
    /// Despite its name, this is the name of the node whose parents are requested.
    /// </param>
    public static IList<string> GetParents(XmlElement nodeElement, string xmlFile, string parentName)
    {
        var parentAttr = nodeElement.GetAttribute("Parent");
        if (!string.IsNullOrEmpty(parentAttr))
        {
            return new string[] { parentAttr };
        }

        var parentNodes = GetChildren(xmlFile, parentName, "Parent");
        if (parentNodes != null && parentNodes.Count > 0)
        {
            List<string> parents = new List<string>();
            foreach (var node in parentNodes)
            {
                parents.Add(node.InnerText);
            }
            return parents;
        }

        // default to assuming a type is under Database
        return new string[] { "Database" };
    }

    /// <summary>
    /// Returns the child elements named <paramref name="childNode"/> under the node named
    /// <paramref name="parentName"/>; empty when there are none.
    /// </summary>
    public static List<XmlElement> GetChildren(string xmlFile, string parentName, string childNode)
    {
        XmlDocument doc = new XmlDocument();
        doc.Load(xmlFile);

        List<XmlElement> retElements = new List<XmlElement>();
        XmlNodeList nodeList = doc.SelectNodes(
            string.Format("/SmoQueryModel/Node[@Name='{0}']/{1}", parentName, childNode));
        if (nodeList != null)
        {
            foreach (XmlNode item in nodeList)
            {
                XmlElement itemAsElement = item as XmlElement;
                if (itemAsElement != null)
                {
                    retElements.Add(itemAsElement);
                }
            }
        }
        return retElements;
    }

    /// <summary>
    /// Converts a ValidFor attribute value into a C# expression OR-ing ValidForFlag members,
    /// e.g. "Sql2016|AzureV12" becomes "ValidForFlag.Sql2016|ValidForFlag.AzureV12".
    /// </summary>
    /// <exception cref="ArgumentException">
    /// No known flag was recognized; emitting an empty expression would produce uncompilable code.
    /// </exception>
    public static string GetValidForFlags(string validForStr)
    {
        // Matched by substring, so composite values yield every listed flag.
        // The array order is the order of the emitted expression.
        string[] substringMatchedFlags = new string[]
        {
            "Sql2005", "Sql2008", "Sql2012", "Sql2014", "Sql2016", "Sql2017",
            "AzureV12", "AllOnPrem", "AllAzure", "NotSqlDw", "SqlOnDemand", "NotSqlDemand"
        };

        List<string> flags = new List<string>();
        foreach (string flag in substringMatchedFlags)
        {
            if (validForStr.Contains(flag))
            {
                flags.Add("ValidForFlag." + flag);
            }
        }

        // "All" must match exactly because it is a prefix of "AllOnPrem" and "AllAzure".
        if (validForStr == "All")
        {
            flags.Add("ValidForFlag.All");
        }

        if (flags.Count == 0)
        {
            throw new ArgumentException(
                string.Format("Unrecognized ValidFor value '{0}'.", validForStr), "validForStr");
        }

        return string.Join("|", flags);
    }
#>