Fix encoding for OSX keychain (#1939)

This commit is contained in:
Cheena Malhotra
2023-03-23 20:31:59 -07:00
committed by GitHub
parent c3444e5cf5
commit 3e4e0bc8c2
5 changed files with 152 additions and 45 deletions

View File

@@ -1,10 +1,8 @@
//
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using Microsoft.SqlTools.Credentials.Contracts;
namespace Microsoft.SqlTools.Credentials
@@ -30,7 +28,7 @@ namespace Microsoft.SqlTools.Credentials
/// <param name="credentialId">The name of the credential to find the password for. This is required</param>
/// <param name="password">Out value</param>
/// <returns>true if password was found, false otherwise</returns>
bool TryGetPassword(string credentialId, out string password);
bool TryGetPassword(string credentialId, out string? password);
/// <summary>
/// Deletes a password linked to a given credential

View File

@@ -3,8 +3,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System;
using System.Runtime.InteropServices;
using System.Text;
@@ -13,25 +11,47 @@ namespace Microsoft.SqlTools.Credentials
{
internal static class InteropUtils
{
/// <summary>
/// Gets the length in bytes for a Unicode string, for use in interop where length must be defined
/// Gets the length in bytes for an encoded string, for use in interop where length must be defined
/// </summary>
public static UInt32 GetLengthInBytes(string value)
/// <param name="value">String value</param>
/// <param name="encoding">Encoding of string provided.</param>
public static UInt32 GetLengthInBytes(string value, Encoding encoding)
{
return Convert.ToUInt32( (value != null ? Encoding.Unicode.GetByteCount(value) : 0) );
if (encoding != Encoding.Unicode && encoding != Encoding.UTF8)
{
throw new ArgumentException($"Encoding {encoding} not supported.");
}
return Convert.ToUInt32((value != null
? (encoding == Encoding.UTF8
? Encoding.UTF8.GetByteCount(value)
: Encoding.Unicode.GetByteCount(value))
: 0));
}
public static string CopyToString(IntPtr ptr, int length)
/// <summary>
/// Copies data of length <paramref name="length"/> from <paramref name="ptr"/>
/// pointer to a string of provided encoding.
/// </summary>
/// <param name="ptr">Pointer to data</param>
/// <param name="length">Length of data to be copied.</param>
/// <param name="encoding">Character encoding to be used to get string.</param>
/// <returns></returns>
public static string? CopyToString(IntPtr ptr, int length, Encoding encoding)
{
if (ptr == IntPtr.Zero || length == 0)
{
return null;
}
if (encoding != Encoding.Unicode && encoding != Encoding.UTF8)
{
throw new ArgumentException($"Encoding {encoding} not supported.");
}
byte[] pwdBytes = new byte[length];
Marshal.Copy(ptr, pwdBytes, 0, (int)length);
return Encoding.Unicode.GetString(pwdBytes, 0, (int)length);
return (encoding == Encoding.UTF8)
? Encoding.UTF8.GetString(pwdBytes, 0, (int)length)
: Encoding.Unicode.GetString(pwdBytes, 0, (int)length);
}
}

View File

@@ -1,4 +1,4 @@
//
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
@@ -14,8 +14,7 @@ namespace Microsoft.SqlTools.Credentials
{
internal partial class Security
{
[DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)]
[DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Auto, SetLastError = true)]
internal static extern OSStatus SecKeychainAddGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName,
UInt32 accountNameLength, string accountName, UInt32 passwordLength, IntPtr password, [Out] IntPtr itemRef);
@@ -43,7 +42,7 @@ namespace Microsoft.SqlTools.Credentials
/// Most attributes are optional; you should pass only as many as you need to narrow the search sufficiently for your application's intended use.
/// SecKeychainFindGenericPassword optionally returns a reference to the found item.
/// </remarks>
[DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)]
[DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Auto, SetLastError = true)]
internal static extern OSStatus SecKeychainFindGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName,
UInt32 accountNameLength, string accountName, out UInt32 passwordLength, out IntPtr password, out IntPtr itemRef);

View File

@@ -0,0 +1,52 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Runtime.InteropServices;
using static Microsoft.SqlTools.Credentials.Interop.Security;
namespace Microsoft.SqlTools.Credentials
{
internal partial class Interop
{
internal class SecurityOld
{
/// <summary>
/// Find a generic password based on the attributes passed (using Unicode encoding)
/// </summary>
/// <param name="keyChainRef">
/// A reference to an array of keychains to search, a single keychain, or NULL to search the user's default keychain search list.
/// </param>
/// <param name="serviceNameLength">The length of the buffer pointed to by serviceName.</param>
/// <param name="serviceName">A pointer to a string containing the service name.</param>
/// <param name="accountNameLength">The length of the buffer pointed to by accountName.</param>
/// <param name="accountName">A pointer to a string containing the account name.</param>
/// <param name="passwordLength">On return, the length of the buffer pointed to by passwordData.</param>
/// <param name="password">
/// On return, a pointer to a data buffer containing the password.
/// Your application must call SecKeychainItemFreeContent(NULL, passwordData)
/// to release this data buffer when it is no longer needed.Pass NULL if you are not interested in retrieving the password data at
/// this time, but simply want to find the item reference.
/// </param>
/// <param name="itemRef">On return, a reference to the keychain item which was found.</param>
/// <returns>A result code that should be in <see cref="OSStatus"/></returns>
/// <remarks>
/// The SecKeychainFindGenericPassword function finds the first generic password item which matches the attributes you provide.
/// Most attributes are optional; you should pass only as many as you need to narrow the search sufficiently for your application's intended use.
/// SecKeychainFindGenericPassword optionally returns a reference to the found item.
/// </remarks>
///
/// ***********************************************************************************************
/// This method is marked OBSOLETE as it used 'Unicode' Charset encoding to save credentials.
/// It has been replaced with respective API in the 'Security' class using 'Auto' Charset encoding.
/// ***********************************************************************************************
[Obsolete]
[DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)]
internal static extern OSStatus SecKeychainFindGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName,
UInt32 accountNameLength, string accountName, out UInt32 passwordLength, out IntPtr password, out IntPtr itemRef);
}
}
}

View File

@@ -3,10 +3,9 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
#nullable disable
using System;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.SqlTools.Credentials.Contracts;
using Microsoft.SqlTools.Utility;
@@ -26,7 +25,7 @@ namespace Microsoft.SqlTools.Credentials.OSX
return DeletePasswordImpl(credentialId);
}
public bool TryGetPassword(string credentialId, out string password)
public bool TryGetPassword(string credentialId, out string? password)
{
Validate.IsNotNullOrEmptyString("credentialId", credentialId);
return FindPassword(credentialId, out password);
@@ -49,14 +48,14 @@ namespace Microsoft.SqlTools.Credentials.OSX
private bool AddGenericPassword(Credential credential)
{
IntPtr passwordPtr = Marshal.StringToCoTaskMemUni(credential.Password);
IntPtr passwordPtr = Marshal.StringToCoTaskMemUTF8(credential.Password);
Interop.Security.OSStatus status = Interop.Security.SecKeychainAddGenericPassword(
IntPtr.Zero,
InteropUtils.GetLengthInBytes(credential.CredentialId),
InteropUtils.GetLengthInBytes(credential.CredentialId, Encoding.UTF8),
credential.CredentialId,
0,
null,
InteropUtils.GetLengthInBytes(credential.Password),
InteropUtils.GetLengthInBytes(credential.Password, Encoding.UTF8),
passwordPtr,
IntPtr.Zero);
@@ -66,12 +65,12 @@ namespace Microsoft.SqlTools.Credentials.OSX
/// <summary>
/// Finds the first password matching this credential
/// </summary>
private bool FindPassword(string credentialId, out string password)
private bool FindPassword(string credentialId, out string? password)
{
password = null;
using (KeyChainItemHandle handle = LookupKeyChainItem(credentialId))
using (KeyChainItemHandle? handle = LookupKeyChainItem(credentialId))
{
if( handle == null)
if (handle == null)
{
return false;
}
@@ -81,14 +80,14 @@ namespace Microsoft.SqlTools.Credentials.OSX
return true;
}
private KeyChainItemHandle LookupKeyChainItem(string credentialId)
private KeyChainItemHandle? LookupKeyChainItem(string credentialId)
{
UInt32 passwordLength;
IntPtr passwordPtr;
IntPtr item;
Interop.Security.OSStatus status = Interop.Security.SecKeychainFindGenericPassword(
IntPtr.Zero,
InteropUtils.GetLengthInBytes(credentialId),
InteropUtils.GetLengthInBytes(credentialId, Encoding.UTF8),
credentialId,
0,
null,
@@ -96,9 +95,40 @@ namespace Microsoft.SqlTools.Credentials.OSX
out passwordPtr,
out item);
if(status == Interop.Security.OSStatus.ErrSecSuccess)
if (status == Interop.Security.OSStatus.ErrSecSuccess)
{
return new KeyChainItemHandle(item, passwordPtr, passwordLength);
return new KeyChainItemHandle(item, passwordPtr, passwordLength, Encoding.UTF8);
}
else
{
#pragma warning disable 0612
// Intentional fallback to Unicode to retrieve old passwords before encoding shift
status = Interop.SecurityOld.SecKeychainFindGenericPassword(
IntPtr.Zero,
InteropUtils.GetLengthInBytes(credentialId, Encoding.Unicode),
credentialId,
0,
null!,
out passwordLength,
out passwordPtr,
out item);
#pragma warning restore 0612
if (status == Interop.Security.OSStatus.ErrSecSuccess)
{
using var handle = new KeyChainItemHandle(item, passwordPtr, passwordLength, Encoding.Unicode);
// Migrate credential to 'Auto' encoding.
if (handle != null)
{
var saveResult = this.AddGenericPassword(credential: new Credential(credentialId, handle.Password));
if (saveResult)
{
// Safe to delete old password now.
Interop.Security.SecKeychainItemDelete(handle);
}
// Lookup keychain again to fetch handle for new credential.
return LookupKeyChainItem(credentialId);
}
}
}
return null;
}
@@ -106,7 +136,7 @@ namespace Microsoft.SqlTools.Credentials.OSX
private bool DeletePasswordImpl(string credentialId)
{
// Find password, then Delete, then cleanup
using (KeyChainItemHandle handle = LookupKeyChainItem(credentialId))
using (KeyChainItemHandle? handle = LookupKeyChainItem(credentialId))
{
if (handle == null)
{
@@ -121,32 +151,40 @@ namespace Microsoft.SqlTools.Credentials.OSX
{
private IntPtr passwordPtr;
private int passwordLength;
private Encoding encoding = Encoding.UTF8;
public KeyChainItemHandle() : base()
{
}
public KeyChainItemHandle(IntPtr itemPtr) : this(itemPtr, IntPtr.Zero, 0)
public KeyChainItemHandle(IntPtr itemPtr) : this(itemPtr, IntPtr.Zero, 0, Encoding.UTF8)
{
}
public KeyChainItemHandle(IntPtr itemPtr, IntPtr passwordPtr, UInt32 passwordLength)
public KeyChainItemHandle(IntPtr itemPtr, IntPtr passwordPtr, UInt32 passwordLength, Encoding encoding)
: base(itemPtr)
{
if (encoding != Encoding.UTF8 && encoding != Encoding.Unicode)
{
throw new ArgumentException($"Encoding {encoding} not supported.");
}
this.passwordPtr = passwordPtr;
this.passwordLength = (int) passwordLength;
this.passwordLength = (int)passwordLength;
this.encoding = encoding;
}
public string Password
public string? Password
{
get {
get
{
if (IsInvalid)
{
return null;
}
return InteropUtils.CopyToString(passwordPtr, passwordLength);
return InteropUtils.CopyToString(passwordPtr, passwordLength,
this.encoding == Encoding.UTF8 ? Encoding.UTF8: Encoding.Unicode);
}
}
protected override bool ReleaseHandle()