Edit Data: Better Formatting Errors (#562)

* Refactoring sql script formatting helpers into To and From helpers

* Updates to make error messages for formatting errors more useful

* Fixing dumb breaks in unit tests

* Addressing comments from PR

* Updates to the SR files...
This commit is contained in:
Benjamin Russell
2017-12-05 17:00:13 -08:00
committed by GitHub
parent e7b76a6dec
commit 64133d929e
37 changed files with 7869 additions and 17711 deletions

View File

@@ -0,0 +1,149 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Collections.Generic;
using System.Text;
using System.Text.RegularExpressions;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters
{
/// <summary>
/// Provides utilities for converting from SQL script syntax into POCOs.
/// </summary>
public static class FromSqlScript
{
    // Regex: optionally starts with N, captures string wrapped in single quotes.
    // Singleline makes '.' match newlines, so literals that span several lines are unwrapped too.
    private static readonly Regex StringRegex =
        new Regex("^N?'(.*)'$", RegexOptions.Compiled | RegexOptions.Singleline);

    /// <summary>
    /// Decodes a multipart identifier as used in a SQL script into an array of the multiple
    /// parts of the identifier. Implemented as a state machine that iterates over the
    /// characters of the multipart identifier.
    /// </summary>
    /// <param name="multipartIdentifier">Multipart identifier to decode (eg, "[dbo].[test]")</param>
    /// <returns>The parts of the multipart identifier in an array (eg, "dbo", "test")</returns>
    /// <exception cref="FormatException">
    /// Thrown if an invalid state transition is made, a bracket is left unclosed, or a part
    /// is empty, indicating that the multipart identifier is not valid.
    /// </exception>
    public static string[] DecodeMultipartIdentifier(string multipartIdentifier)
    {
        // Reject null up-front rather than failing later on .Length
        Validate.IsNotNull(nameof(multipartIdentifier), multipartIdentifier);

        StringBuilder sb = new StringBuilder();
        List<string> namedParts = new List<string>();
        bool insideBrackets = false;    // currently between '[' and its closing ']'
        bool bracketsClosed = false;    // a bracketed part just ended; only '.' may follow

        for (int i = 0; i < multipartIdentifier.Length; i++)
        {
            char iChar = multipartIdentifier[i];
            if (insideBrackets)
            {
                if (iChar != ']')
                {
                    // This is a standard character
                    sb.Append(iChar);
                }
                else if (HasNextCharacter(multipartIdentifier, ']', i))
                {
                    // This is an escaped ]; keep one and skip its twin
                    sb.Append(iChar);
                    i++;
                }
                else
                {
                    // This bracket closes the bracket we were in
                    insideBrackets = false;
                    bracketsClosed = true;
                }
                continue;
            }

            switch (iChar)
            {
                case '[':
                    if (bracketsClosed)
                    {
                        throw InvalidIdentifier(multipartIdentifier, i,
                            "'[' cannot directly follow a closed bracket");
                    }
                    // We're opening a set of brackets
                    insideBrackets = true;
                    break;
                case '.':
                    if (sb.Length == 0)
                    {
                        throw InvalidIdentifier(multipartIdentifier, i,
                            "an identifier part cannot be empty");
                    }
                    // We're splitting the identifier into a new part
                    namedParts.Add(sb.ToString());
                    sb = new StringBuilder();
                    bracketsClosed = false;
                    break;
                default:
                    if (bracketsClosed)
                    {
                        throw InvalidIdentifier(multipartIdentifier, i,
                            "only '.' may follow a closed bracket");
                    }
                    // This is a standard character
                    sb.Append(iChar);
                    break;
            }
        }

        if (insideBrackets)
        {
            // The input ended before the opening '[' was matched by an unescaped ']'
            throw InvalidIdentifier(multipartIdentifier, multipartIdentifier.Length,
                "'[' is never closed");
        }
        if (sb.Length == 0)
        {
            throw InvalidIdentifier(multipartIdentifier, multipartIdentifier.Length,
                "an identifier part cannot be empty");
        }
        namedParts.Add(sb.ToString());
        return namedParts.ToArray();
    }

    /// <summary>
    /// Converts a value from a script into a plain version by unwrapping literal wrappers
    /// and unescaping characters.
    /// </summary>
    /// <param name="literal">The value to unwrap (eg, "(N'foo''bar')")</param>
    /// <returns>The unwrapped/unescaped literal (eg, "foo'bar")</returns>
    public static string UnwrapLiteral(string literal)
    {
        Validate.IsNotNull(nameof(literal), literal);

        // Always remove parens
        // NOTE(review): Trim strips every leading/trailing paren, so "(getdate())" becomes
        // "getdate" - confirm whether callers rely on that before tightening it.
        literal = literal.Trim('(', ')');

        // Attempt to unwrap inverted commas around a string
        Match match = StringRegex.Match(literal);
        if (match.Success)
        {
            // Like: N'stuff' or 'stuff'
            return UnEscapeString(match.Groups[1].Value, '\'');
        }
        return literal;
    }

    #region Private Helpers

    /// <summary>
    /// Checks whether the character immediately after <paramref name="position"/> exists
    /// and equals <paramref name="needle"/>.
    /// </summary>
    private static bool HasNextCharacter(string haystack, char needle, int position)
    {
        return position + 1 < haystack.Length
               && haystack[position + 1] == needle;
    }

    /// <summary>
    /// Builds the exception thrown for a malformed multipart identifier.
    /// </summary>
    private static FormatException InvalidIdentifier(string identifier, int position, string reason)
    {
        // NOTE(review): this message is not localized; consider moving it into the SR resources
        return new FormatException(
            $"Invalid multipart identifier '{identifier}' at position {position}: {reason}");
    }

    /// <summary>
    /// Collapses every doubled <paramref name="escapeCharacter"/> into a single one.
    /// </summary>
    private static string UnEscapeString(string value, char escapeCharacter)
    {
        Validate.IsNotNull(nameof(value), value);

        // Replace 2x of the escape character with 1x of the escape character
        return value.Replace(new string(escapeCharacter, 2), escapeCharacter.ToString());
    }

    #endregion
}
}

View File

@@ -9,17 +9,16 @@ using System.Data.Common;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Utility
namespace Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters
{
/// <summary>
/// Provides utility for converting arbitrary objects into strings that are ready to be
/// inserted into SQL strings
/// </summary>
public class SqlScriptFormatter
public static class ToSqlScript
{
#region Constants
@@ -27,16 +26,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
private static readonly Dictionary<string, Func<object, DbColumn, string>> FormatFunctions =
new Dictionary<string, Func<object, DbColumn, string>>
{ // CLR Type --------
{ // CLR Type --------
{"bigint", (val, col) => SimpleFormatter(val)}, // long
{"bit", (val, col) => FormatBool(val)}, // bool
{"int", (val, col) => SimpleFormatter(val)}, // int
{"smallint", (val, col) => SimpleFormatter(val)}, // short
{"tinyint", (val, col) => SimpleFormatter(val)}, // byte
{"money", (val, col) => FormatMoney(val, "MONEY")}, // Decimal
{"smallmoney", (val, col) => FormatMoney(val, "SMALLMONEY")}, // Decimal
{"decimal", (val, col) => FormatPreciseNumeric(val, col, "DECIMAL")}, // Decimal
{"numeric", (val, col) => FormatPreciseNumeric(val, col, "NUMERIC")}, // Decimal
{"money", FormatDecimalLike}, // Decimal
{"smallmoney", FormatDecimalLike}, // Decimal
{"decimal", FormatDecimalLike}, // Decimal
{"numeric", FormatDecimalLike}, // Decimal
{"real", (val, col) => FormatFloat(val)}, // float
{"float", (val, col) => FormatDouble(val)}, // double
{"smalldatetime", (val, col) => FormatDateTime(val, "yyyy-MM-dd HH:mm:ss")}, // DateTime
@@ -65,21 +64,105 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
// sysname - it doesn't appear possible to insert a sysname column
};
// CLR types listed as numeric.
// NOTE(review): no consumer is visible in this view; the original order is kept in case it matters.
private static readonly Type[] NumericTypes = new Type[]
{
    typeof(byte), typeof(short), typeof(int), typeof(long),
    typeof(decimal), typeof(float), typeof(double)
};
// Matches a SQL string literal ('...' or N'...') and captures the text between the quotes.
// readonly: the shared instance is never reassigned. Singleline: '.' must also match
// newlines, otherwise literals spanning several lines are never recognised.
private static readonly Regex StringRegex = new Regex("^N?'(.*)'$", RegexOptions.Compiled | RegexOptions.Singleline);
#endregion
#region Public Methods
/// <summary>
/// Builds the script-ready type declaration for a column: the upper-cased data type name
/// followed, where the type takes them, by its length, precision and scale, or scale.
/// </summary>
/// <param name="column">Column whose data type should be rendered</param>
/// <returns>Type declaration such as NVARCHAR(MAX), DECIMAL(18, 2) or DATETIME2(7)</returns>
/// <seealso cref="Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel.SmoColumnCustomNodeHelper.GetTypeSpecifierLabel"/>
/// <exception cref="InvalidOperationException">
/// Thrown when the column does not report the size, precision or scale its type requires
/// </exception>
public static string FormatColumnType(DbColumn column)
{
    // TODO: This doesn't support UDTs at all.
    // TODO: It's unclear if this will work on a case-sensitive db collation
    string scriptType = column.DataTypeName.ToUpperInvariant();
    string typeKey = column.DataTypeName.ToLowerInvariant();

    bool takesLength = typeKey == "char" || typeKey == "nchar"
                       || typeKey == "varchar" || typeKey == "nvarchar"
                       || typeKey == "binary" || typeKey == "varbinary";
    bool takesPrecisionAndScale = typeKey == "numeric" || typeKey == "decimal";
    bool takesScaleOnly = typeKey == "datetime2" || typeKey == "datetimeoffset" || typeKey == "time";

    if (takesLength)
    {
        if (!column.ColumnSize.HasValue)
        {
            throw new InvalidOperationException(SR.SqlScriptFormatterLengthTypeMissingSize);
        }

        // A size of int.MaxValue is rendered as the MAX keyword
        int size = column.ColumnSize.Value;
        string length = size == int.MaxValue ? "MAX" : size.ToString();
        return scriptType + $"({length})";
    }

    if (takesPrecisionAndScale)
    {
        if (!(column.NumericPrecision.HasValue && column.NumericScale.HasValue))
        {
            throw new InvalidOperationException(SR.SqlScriptFormatterDecimalMissingPrecision);
        }
        return scriptType + $"({column.NumericPrecision}, {column.NumericScale})";
    }

    if (takesScaleOnly)
    {
        if (!column.NumericScale.HasValue)
        {
            throw new InvalidOperationException(SR.SqlScriptFormatterScalarTypeMissingScale);
        }
        return scriptType + $"({column.NumericScale})";
    }

    // Every other type is emitted as its bare upper-cased name
    return scriptType;
}
/// <summary>
/// Wraps a single identifier (table name, column name, ...) in square brackets after
/// passing it through EscapeString with ']' as the character to escape.
/// </summary>
/// <param name="identifier">The identifier to format</param>
/// <returns>Identifier formatted for use in a SQL script</returns>
public static string FormatIdentifier(string identifier)
{
    string escaped = EscapeString(identifier, ']');
    return "[" + escaped + "]";
}
/// <summary>
/// Escapes a dot-separated identifier by splitting it on '.' and escaping each part.
/// </summary>
/// <param name="identifier">The identifier to escape (eg, "dbo.test")</param>
/// <returns>The escaped identifier (eg, "[dbo].[test]")</returns>
public static string FormatMultipartIdentifier(string identifier)
{
    // Every '.' is treated as a part separator; the array overload does the escaping
    string[] parts = identifier.Split('.');
    return FormatMultipartIdentifier(parts);
}
/// <summary>
/// Escapes each part of a multipart identifier with FormatIdentifier and joins the
/// results with '.'.
/// </summary>
/// <param name="identifiers">The parts of the identifier to escape (eg, "dbo", "test")</param>
/// <returns>An escaped version of the multipart identifier (eg, "[dbo].[test]")</returns>
public static string FormatMultipartIdentifier(string[] identifiers)
{
    string[] bracketed = Array.ConvertAll<string, string>(identifiers, FormatIdentifier);
    return string.Join(".", bracketed);
}
/// <summary>
/// Converts an object into a string for SQL script
/// </summary>
@@ -107,7 +190,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
}
return FormatFunctions[dataType](value, column);
}
/// <summary>
/// Converts a cell value into a string for SQL script
/// </summary>
@@ -120,229 +203,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
return FormatValue(value.RawObject, column);
}
/// <summary>
/// Wraps a single identifier (table name, column name, ...) in square brackets after
/// passing it through EscapeString with ']' as the character to escape.
/// </summary>
/// <param name="identifier">The identifier to format</param>
/// <returns>Identifier formatted for use in a SQL script</returns>
public static string FormatIdentifier(string identifier)
{
    string escaped = EscapeString(identifier, ']');
    return "[" + escaped + "]";
}
/// <summary>
/// Escapes a dot-separated identifier by splitting it on '.' and escaping each part.
/// </summary>
/// <param name="identifier">The identifier to escape</param>
/// <returns>The escaped identifier</returns>
public static string FormatMultipartIdentifier(string identifier)
{
    // Every '.' is treated as a part separator; the array overload does the escaping
    string[] parts = identifier.Split('.');
    return FormatMultipartIdentifier(parts);
}
/// <summary>
/// Escapes each part of a multipart identifier with FormatIdentifier and joins the
/// results with '.'.
/// </summary>
/// <param name="identifiers">The parts of the identifier to escape</param>
/// <returns>An escaped version of the multipart identifier</returns>
public static string FormatMultipartIdentifier(string[] identifiers)
{
    string[] bracketed = Array.ConvertAll<string, string>(identifiers, FormatIdentifier);
    return string.Join(".", bracketed);
}
/// <summary>
/// Turns a scripted value into its plain form: surrounding parentheses are trimmed and,
/// when what remains is a quoted string ('...' or N'...'), the quotes are removed and
/// doubled single quotes are collapsed.
/// </summary>
/// <param name="literal">The value to unwrap</param>
/// <returns>The unwrapped/unescaped literal</returns>
public static string UnwrapLiteral(string literal)
{
    // Parentheses are stripped from both ends unconditionally
    char[] parens = { '(', ')' };
    string stripped = literal.Trim(parens);

    Match quoted = StringRegex.Match(stripped);
    if (!quoted.Success)
    {
        // Not a quoted string: hand back the paren-trimmed text as-is
        return stripped;
    }

    // Like: N'stuff' or 'stuff'
    return UnEscapeString(quoted.Groups[1].Value, '\'');
}
/// <summary>
/// Decodes a multipart identifier as used in a SQL script into its parts. Implemented as a
/// state machine that walks the characters of the identifier.
/// </summary>
/// <param name="multipartIdentifier">Multipart identifier to decode (eg, "[dbo].[test]")</param>
/// <returns>The parts of the multipart identifier in an array (eg, "dbo", "test")</returns>
/// <exception cref="ArgumentNullException">Thrown if the identifier is null</exception>
/// <exception cref="FormatException">
/// Thrown if an invalid state transition is made, a bracket is left unclosed, or a part
/// is empty, indicating that the multipart identifier is not valid.
/// </exception>
public static string[] DecodeMultipartIdenfitier(string multipartIdentifier)
{
    if (multipartIdentifier == null)
    {
        throw new ArgumentNullException(nameof(multipartIdentifier));
    }

    StringBuilder sb = new StringBuilder();
    List<string> namedParts = new List<string>();
    bool insideBrackets = false;    // currently between '[' and its closing ']'
    bool bracketsClosed = false;    // a bracketed part just ended; only '.' may follow

    for (int i = 0; i < multipartIdentifier.Length; i++)
    {
        char iChar = multipartIdentifier[i];
        if (insideBrackets)
        {
            if (iChar != ']')
            {
                // This is a standard character
                sb.Append(iChar);
            }
            else if (i + 1 < multipartIdentifier.Length && multipartIdentifier[i + 1] == ']')
            {
                // This is an escaped ]; keep one and skip its twin
                sb.Append(iChar);
                i++;
            }
            else
            {
                // This bracket closes the bracket we were in
                insideBrackets = false;
                bracketsClosed = true;
            }
            continue;
        }

        switch (iChar)
        {
            case '[':
                if (bracketsClosed)
                {
                    throw new FormatException(
                        $"Invalid multipart identifier '{multipartIdentifier}' at position {i}: '[' cannot directly follow a closed bracket");
                }
                // We're opening a set of brackets
                insideBrackets = true;
                break;
            case '.':
                if (sb.Length == 0)
                {
                    throw new FormatException(
                        $"Invalid multipart identifier '{multipartIdentifier}' at position {i}: an identifier part cannot be empty");
                }
                // We're splitting the identifier into a new part
                namedParts.Add(sb.ToString());
                sb = new StringBuilder();
                bracketsClosed = false;
                break;
            default:
                if (bracketsClosed)
                {
                    throw new FormatException(
                        $"Invalid multipart identifier '{multipartIdentifier}' at position {i}: only '.' may follow a closed bracket");
                }
                // This is a standard character
                sb.Append(iChar);
                break;
        }
    }

    if (insideBrackets)
    {
        // The input ended before the opening '[' was matched by an unescaped ']'
        throw new FormatException(
            $"Invalid multipart identifier '{multipartIdentifier}': '[' is never closed");
    }
    if (sb.Length == 0)
    {
        throw new FormatException(
            $"Invalid multipart identifier '{multipartIdentifier}': an identifier part cannot be empty");
    }
    namedParts.Add(sb.ToString());
    return namedParts.ToArray();
}
#endregion
#region Private Helpers
/// <summary>
/// Formats a value for a script using the invariant culture where the value supports it.
/// </summary>
/// <param name="value">The value to format; must not be null</param>
/// <returns>Culture-independent string form of the value</returns>
private static string SimpleFormatter(object value)
{
    // A plain ToString() follows the current culture, and some cultures define a
    // non-ASCII negative sign, which would produce text that is not a valid SQL number.
    IFormattable formattable = value as IFormattable;
    return formattable != null
        ? formattable.ToString(null, CultureInfo.InvariantCulture)
        : value.ToString();
}
/// <summary>
/// Renders the value's ToString() text as a quoted SQL string via EscapeQuotedSqlString.
/// </summary>
private static string SimpleStringFormatter(object value)
{
    string text = value.ToString();
    return EscapeQuotedSqlString(text);
}
/// <summary>
/// Renders a decimal value as a CAST to the given money type.
/// </summary>
/// <param name="value">Boxed decimal to format</param>
/// <param name="type">Target SQL type name placed after AS</param>
private static string FormatMoney(object value, string type)
{
    // Invariant culture guarantees a period as the decimal separator regardless of locale
    decimal amount = (decimal)value;
    return "CAST(" + amount.ToString(CultureInfo.InvariantCulture) + " AS " + type + ")";
}
/// <summary>
/// Formats a boxed float for a script using round-trip precision.
/// </summary>
private static string FormatFloat(object value)
{
    // The "R" formatting means "Round Trip", which preserves fidelity. The invariant
    // culture keeps '.' as the decimal separator whatever the current locale is.
    return ((float)value).ToString("R", CultureInfo.InvariantCulture);
}
/// <summary>
/// Formats a boxed double for a script using round-trip precision.
/// </summary>
private static string FormatDouble(object value)
{
    // The "R" formatting means "Round Trip", which preserves fidelity. The invariant
    // culture keeps '.' as the decimal separator whatever the current locale is.
    return ((double)value).ToString("R", CultureInfo.InvariantCulture);
}
/// <summary>
/// Maps a boxed bool to "1" (true) or "0" (false).
/// </summary>
private static string FormatBool(object value)
{
    // The direct cast throws if the value is not a bool
    if ((bool)value)
    {
        return "1";
    }
    return "0";
}
/// <summary>
/// Renders a decimal value as a CAST to the given type with the column's precision and scale.
/// </summary>
/// <param name="value">Boxed decimal to format</param>
/// <param name="column">Column supplying NumericPrecision and NumericScale</param>
/// <param name="type">SQL type name placed after AS</param>
/// <exception cref="InvalidOperationException">
/// Thrown when the column has no numeric precision or no numeric scale
/// </exception>
private static string FormatPreciseNumeric(object value, DbColumn column, string type)
{
    // Both precision and scale are required to build the type specifier
    bool hasPrecisionAndScale = column.NumericPrecision.HasValue && column.NumericScale.HasValue;
    if (!hasPrecisionAndScale)
    {
        throw new InvalidOperationException(SR.SqlScriptFormatterDecimalMissingPrecision);
    }

    decimal number = (decimal)value;
    string numberText = number.ToString(CultureInfo.InvariantCulture);
    int precision = column.NumericPrecision.Value;
    int scale = column.NumericScale.Value;
    return string.Format(CultureInfo.InvariantCulture, "CAST({0} AS {1}({2}, {3}))",
        numberText, type, precision, scale);
}
/// <summary>
/// Renders a boxed TimeSpan as a quoted SQL string.
/// </summary>
private static string FormatTimeSpan(object value)
{
    // "c" provides "HH:mm:ss.FFFFFFF", and time column accepts up to 7 precision
    TimeSpan span = (TimeSpan)value;
    return EscapeQuotedSqlString(span.ToString("c", CultureInfo.InvariantCulture));
}
/// <summary>
/// Renders a boxed DateTime as a quoted SQL string using the supplied format string
/// and the invariant culture.
/// </summary>
private static string FormatDateTime(object value, string format)
{
    DateTime moment = (DateTime)value;
    return EscapeQuotedSqlString(moment.ToString(format, CultureInfo.InvariantCulture));
}
/// <summary>
/// Renders a boxed DateTimeOffset as a quoted SQL string using the invariant culture's
/// default format.
/// </summary>
private static string FormatDateTimeOffset(object value)
{
    // NOTE(review): the default format appears to omit fractional seconds - confirm
    // whether sub-second precision matters for datetimeoffset columns.
    DateTimeOffset moment = (DateTimeOffset)value;
    return EscapeQuotedSqlString(moment.ToString(CultureInfo.InvariantCulture));
}
/// <summary>
/// Renders a byte array as a hexadecimal literal ("0x" followed by two upper-case hex
/// digits per byte). Anything that is not a byte[] yields the text "NULL".
/// </summary>
private static string FormatBinary(object value)
{
    byte[] data = value as byte[];
    if (data == null)
    {
        // Bypass processing if we can't turn this into a byte[]
        return "NULL";
    }

    StringBuilder hex = new StringBuilder("0x", 2 + data.Length * 2);
    foreach (byte b in data)
    {
        hex.Append(b.ToString("X2", CultureInfo.InvariantCulture));
    }
    return hex.ToString();
}
/// <summary>
/// Checks whether the character immediately after <paramref name="position"/> exists
/// and equals <paramref name="needle"/>.
/// </summary>
private static bool HasNextCharacter(string haystack, char needle, int position)
{
    int next = position + 1;
    if (next >= haystack.Length)
    {
        // Nothing follows the given position
        return false;
    }
    return haystack[next] == needle;
}
/// <summary>
/// Returns a valid SQL string packaged in single quotes with single quotes inside escaped
/// </summary>
@@ -352,7 +217,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
{
return $"N'{EscapeString(rawString, '\'')}'";
}
/// <summary>
/// Replaces all instances of <paramref name="escapeCharacter"/> with a duplicate of
/// <paramref name="escapeCharacter"/>. For example "can't" becomes "can''t"
@@ -375,15 +240,74 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
}
return sb.ToString();
}
private static string UnEscapeString(string value, char escapeCharacter)
private static string FormatBinary(object value)
{
Validate.IsNotNull(nameof(value), value);
byte[] bytes = value as byte[];
if (bytes == null)
{
// Bypass processing if we can't turn this into a byte[]
return "NULL";
}
// Replace 2x of the escape character with 1x of the escape character
return value.Replace(new string(escapeCharacter, 2), escapeCharacter.ToString());
return "0x" + BitConverter.ToString(bytes).Replace("-", string.Empty);
}
/// <summary>
/// Maps a boxed bool to "1" (true) or "0" (false).
/// </summary>
private static string FormatBool(object value)
{
    // The direct cast throws if the value is not a bool
    if ((bool)value)
    {
        return "1";
    }
    return "0";
}
/// <summary>
/// Renders a boxed DateTime as a quoted SQL string using the supplied format string
/// and the invariant culture.
/// </summary>
private static string FormatDateTime(object value, string format)
{
    DateTime moment = (DateTime)value;
    return EscapeQuotedSqlString(moment.ToString(format, CultureInfo.InvariantCulture));
}
/// <summary>
/// Renders a boxed DateTimeOffset as a quoted SQL string using the invariant culture's
/// default format.
/// </summary>
private static string FormatDateTimeOffset(object value)
{
    // NOTE(review): the default format appears to omit fractional seconds - confirm
    // whether sub-second precision matters for datetimeoffset columns.
    DateTimeOffset moment = (DateTimeOffset)value;
    return EscapeQuotedSqlString(moment.ToString(CultureInfo.InvariantCulture));
}
/// <summary>
/// Formats a boxed double for a script using round-trip precision.
/// </summary>
private static string FormatDouble(object value)
{
    // The "R" formatting means "Round Trip", which preserves fidelity. The invariant
    // culture keeps '.' as the decimal separator whatever the current locale is.
    return ((double)value).ToString("R", CultureInfo.InvariantCulture);
}
/// <summary>
/// Formats a boxed float for a script using round-trip precision.
/// </summary>
private static string FormatFloat(object value)
{
    // The "R" formatting means "Round Trip", which preserves fidelity. The invariant
    // culture keeps '.' as the decimal separator whatever the current locale is.
    return ((float)value).ToString("R", CultureInfo.InvariantCulture);
}
/// <summary>
/// Renders a decimal value as a CAST to the column's type, as produced by FormatColumnType.
/// </summary>
/// <param name="value">Boxed decimal to format</param>
/// <param name="column">Column whose type declaration is used as the cast target</param>
private static string FormatDecimalLike(object value, DbColumn column)
{
    // Unbox first so a non-decimal value fails before the column type is inspected
    decimal number = (decimal)value;
    string castTarget = FormatColumnType(column);
    return "CAST(" + number.ToString(CultureInfo.InvariantCulture) + " AS " + castTarget + ")";
}
/// <summary>
/// Renders a boxed TimeSpan as a quoted SQL string.
/// </summary>
private static string FormatTimeSpan(object value)
{
    // "c" provides "HH:mm:ss.FFFFFFF", and time column accepts up to 7 precision
    TimeSpan span = (TimeSpan)value;
    return EscapeQuotedSqlString(span.ToString("c", CultureInfo.InvariantCulture));
}
/// <summary>
/// Formats a value for a script using the invariant culture where the value supports it.
/// </summary>
/// <param name="value">The value to format; must not be null</param>
/// <returns>Culture-independent string form of the value</returns>
private static string SimpleFormatter(object value)
{
    // A plain ToString() follows the current culture, and some cultures define a
    // non-ASCII negative sign, which would produce text that is not a valid SQL number.
    IFormattable formattable = value as IFormattable;
    return formattable != null
        ? formattable.ToString(null, CultureInfo.InvariantCulture)
        : value.ToString();
}
/// <summary>
/// Renders the value's ToString() text as a quoted SQL string via EscapeQuotedSqlString.
/// </summary>
private static string SimpleStringFormatter(object value)
{
    string text = value.ToString();
    return EscapeQuotedSqlString(text);
}
#endregion
}
}
}