Fix encoding for OSX keychain (#1939)

This commit is contained in:
Cheena Malhotra
2023-03-23 20:31:59 -07:00
committed by GitHub
parent c3444e5cf5
commit 3e4e0bc8c2
5 changed files with 152 additions and 45 deletions

View File

@@ -1,10 +1,8 @@
// //
// Copyright (c) Microsoft. All rights reserved. // Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
// //
#nullable disable
using Microsoft.SqlTools.Credentials.Contracts; using Microsoft.SqlTools.Credentials.Contracts;
namespace Microsoft.SqlTools.Credentials namespace Microsoft.SqlTools.Credentials
@@ -30,7 +28,7 @@ namespace Microsoft.SqlTools.Credentials
/// <param name="credentialId">The name of the credential to find the password for. This is required</param> /// <param name="credentialId">The name of the credential to find the password for. This is required</param>
/// <param name="password">Out value</param> /// <param name="password">Out value</param>
/// <returns>true if password was found, false otherwise</returns> /// <returns>true if password was found, false otherwise</returns>
bool TryGetPassword(string credentialId, out string password); bool TryGetPassword(string credentialId, out string? password);
/// <summary> /// <summary>
/// Deletes a password linked to a given credential /// Deletes a password linked to a given credential

View File

@@ -3,8 +3,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
// //
#nullable disable
using System; using System;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
@@ -13,25 +11,47 @@ namespace Microsoft.SqlTools.Credentials
{ {
internal static class InteropUtils internal static class InteropUtils
{ {
/// <summary> /// <summary>
/// Gets the length in bytes for a Unicode string, for use in interop where length must be defined /// Gets the length in bytes for an encoded string, for use in interop where length must be defined
/// </summary> /// </summary>
public static UInt32 GetLengthInBytes(string value) /// <param name="value">String value</param>
/// <param name="encoding">Encoding of string provided.</param>
public static UInt32 GetLengthInBytes(string value, Encoding encoding)
{ {
if (encoding != Encoding.Unicode && encoding != Encoding.UTF8)
return Convert.ToUInt32( (value != null ? Encoding.Unicode.GetByteCount(value) : 0) ); {
throw new ArgumentException($"Encoding {encoding} not supported.");
}
return Convert.ToUInt32((value != null
? (encoding == Encoding.UTF8
? Encoding.UTF8.GetByteCount(value)
: Encoding.Unicode.GetByteCount(value))
: 0));
} }
public static string CopyToString(IntPtr ptr, int length) /// <summary>
/// Copies data of length <paramref name="length"/> from <paramref name="ptr"/>
/// pointer to a string of provided encoding.
/// </summary>
/// <param name="ptr">Pointer to data</param>
/// <param name="length">Length of data to be copied.</param>
/// <param name="encoding">Character encoding to be used to get string.</param>
/// <returns></returns>
public static string? CopyToString(IntPtr ptr, int length, Encoding encoding)
{ {
if (ptr == IntPtr.Zero || length == 0) if (ptr == IntPtr.Zero || length == 0)
{ {
return null; return null;
} }
if (encoding != Encoding.Unicode && encoding != Encoding.UTF8)
{
throw new ArgumentException($"Encoding {encoding} not supported.");
}
byte[] pwdBytes = new byte[length]; byte[] pwdBytes = new byte[length];
Marshal.Copy(ptr, pwdBytes, 0, (int)length); Marshal.Copy(ptr, pwdBytes, 0, (int)length);
return Encoding.Unicode.GetString(pwdBytes, 0, (int)length); return (encoding == Encoding.UTF8)
? Encoding.UTF8.GetString(pwdBytes, 0, (int)length)
: Encoding.Unicode.GetString(pwdBytes, 0, (int)length);
} }
} }

View File

@@ -1,4 +1,4 @@
// //
// Copyright (c) Microsoft. All rights reserved. // Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
// //
@@ -14,8 +14,7 @@ namespace Microsoft.SqlTools.Credentials
{ {
internal partial class Security internal partial class Security
{ {
[DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Auto, SetLastError = true)]
[DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)]
internal static extern OSStatus SecKeychainAddGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName, internal static extern OSStatus SecKeychainAddGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName,
UInt32 accountNameLength, string accountName, UInt32 passwordLength, IntPtr password, [Out] IntPtr itemRef); UInt32 accountNameLength, string accountName, UInt32 passwordLength, IntPtr password, [Out] IntPtr itemRef);
@@ -43,7 +42,7 @@ namespace Microsoft.SqlTools.Credentials
/// Most attributes are optional; you should pass only as many as you need to narrow the search sufficiently for your application's intended use. /// Most attributes are optional; you should pass only as many as you need to narrow the search sufficiently for your application's intended use.
/// SecKeychainFindGenericPassword optionally returns a reference to the found item. /// SecKeychainFindGenericPassword optionally returns a reference to the found item.
/// </remarks> /// </remarks>
[DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)] [DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Auto, SetLastError = true)]
internal static extern OSStatus SecKeychainFindGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName, internal static extern OSStatus SecKeychainFindGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName,
UInt32 accountNameLength, string accountName, out UInt32 passwordLength, out IntPtr password, out IntPtr itemRef); UInt32 accountNameLength, string accountName, out UInt32 passwordLength, out IntPtr password, out IntPtr itemRef);

View File

@@ -0,0 +1,52 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Runtime.InteropServices;
using static Microsoft.SqlTools.Credentials.Interop.Security;
namespace Microsoft.SqlTools.Credentials
{
internal partial class Interop
{
internal class SecurityOld
{
/// <summary>
/// Find a generic password based on the attributes passed (using Unicode encoding)
/// </summary>
/// <param name="keyChainRef">
/// A reference to an array of keychains to search, a single keychain, or NULL to search the user's default keychain search list.
/// </param>
/// <param name="serviceNameLength">The length of the buffer pointed to by serviceName.</param>
/// <param name="serviceName">A pointer to a string containing the service name.</param>
/// <param name="accountNameLength">The length of the buffer pointed to by accountName.</param>
/// <param name="accountName">A pointer to a string containing the account name.</param>
/// <param name="passwordLength">On return, the length of the buffer pointed to by passwordData.</param>
/// <param name="password">
/// On return, a pointer to a data buffer containing the password.
/// Your application must call SecKeychainItemFreeContent(NULL, passwordData)
/// to release this data buffer when it is no longer needed.Pass NULL if you are not interested in retrieving the password data at
/// this time, but simply want to find the item reference.
/// </param>
/// <param name="itemRef">On return, a reference to the keychain item which was found.</param>
/// <returns>A result code that should be in <see cref="OSStatus"/></returns>
/// <remarks>
/// The SecKeychainFindGenericPassword function finds the first generic password item which matches the attributes you provide.
/// Most attributes are optional; you should pass only as many as you need to narrow the search sufficiently for your application's intended use.
/// SecKeychainFindGenericPassword optionally returns a reference to the found item.
/// </remarks>
///
/// ***********************************************************************************************
/// This method is marked OBSOLETE as it used 'Unicode' Charset encoding to save credentials.
/// It has been replaced with respective API in the 'Security' class using 'Auto' Charset encoding.
/// ***********************************************************************************************
[Obsolete]
[DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)]
internal static extern OSStatus SecKeychainFindGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName,
UInt32 accountNameLength, string accountName, out UInt32 passwordLength, out IntPtr password, out IntPtr itemRef);
}
}
}

View File

@@ -3,10 +3,9 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
// //
#nullable disable
using System; using System;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text;
using Microsoft.SqlTools.Credentials.Contracts; using Microsoft.SqlTools.Credentials.Contracts;
using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Utility;
@@ -26,7 +25,7 @@ namespace Microsoft.SqlTools.Credentials.OSX
return DeletePasswordImpl(credentialId); return DeletePasswordImpl(credentialId);
} }
public bool TryGetPassword(string credentialId, out string password) public bool TryGetPassword(string credentialId, out string? password)
{ {
Validate.IsNotNullOrEmptyString("credentialId", credentialId); Validate.IsNotNullOrEmptyString("credentialId", credentialId);
return FindPassword(credentialId, out password); return FindPassword(credentialId, out password);
@@ -48,15 +47,15 @@ namespace Microsoft.SqlTools.Credentials.OSX
} }
private bool AddGenericPassword(Credential credential) private bool AddGenericPassword(Credential credential)
{ {
IntPtr passwordPtr = Marshal.StringToCoTaskMemUni(credential.Password); IntPtr passwordPtr = Marshal.StringToCoTaskMemUTF8(credential.Password);
Interop.Security.OSStatus status = Interop.Security.SecKeychainAddGenericPassword( Interop.Security.OSStatus status = Interop.Security.SecKeychainAddGenericPassword(
IntPtr.Zero, IntPtr.Zero,
InteropUtils.GetLengthInBytes(credential.CredentialId), InteropUtils.GetLengthInBytes(credential.CredentialId, Encoding.UTF8),
credential.CredentialId, credential.CredentialId,
0, 0,
null, null,
InteropUtils.GetLengthInBytes(credential.Password), InteropUtils.GetLengthInBytes(credential.Password, Encoding.UTF8),
passwordPtr, passwordPtr,
IntPtr.Zero); IntPtr.Zero);
@@ -66,12 +65,12 @@ namespace Microsoft.SqlTools.Credentials.OSX
/// <summary> /// <summary>
/// Finds the first password matching this credential /// Finds the first password matching this credential
/// </summary> /// </summary>
private bool FindPassword(string credentialId, out string password) private bool FindPassword(string credentialId, out string? password)
{ {
password = null; password = null;
using (KeyChainItemHandle handle = LookupKeyChainItem(credentialId)) using (KeyChainItemHandle? handle = LookupKeyChainItem(credentialId))
{ {
if( handle == null) if (handle == null)
{ {
return false; return false;
} }
@@ -81,14 +80,14 @@ namespace Microsoft.SqlTools.Credentials.OSX
return true; return true;
} }
private KeyChainItemHandle LookupKeyChainItem(string credentialId) private KeyChainItemHandle? LookupKeyChainItem(string credentialId)
{ {
UInt32 passwordLength; UInt32 passwordLength;
IntPtr passwordPtr; IntPtr passwordPtr;
IntPtr item; IntPtr item;
Interop.Security.OSStatus status = Interop.Security.SecKeychainFindGenericPassword( Interop.Security.OSStatus status = Interop.Security.SecKeychainFindGenericPassword(
IntPtr.Zero, IntPtr.Zero,
InteropUtils.GetLengthInBytes(credentialId), InteropUtils.GetLengthInBytes(credentialId, Encoding.UTF8),
credentialId, credentialId,
0, 0,
null, null,
@@ -96,9 +95,40 @@ namespace Microsoft.SqlTools.Credentials.OSX
out passwordPtr, out passwordPtr,
out item); out item);
if(status == Interop.Security.OSStatus.ErrSecSuccess) if (status == Interop.Security.OSStatus.ErrSecSuccess)
{ {
return new KeyChainItemHandle(item, passwordPtr, passwordLength); return new KeyChainItemHandle(item, passwordPtr, passwordLength, Encoding.UTF8);
}
else
{
#pragma warning disable 0612
// Intentional fallback to Unicode to retrieve old passwords before encoding shift
status = Interop.SecurityOld.SecKeychainFindGenericPassword(
IntPtr.Zero,
InteropUtils.GetLengthInBytes(credentialId, Encoding.Unicode),
credentialId,
0,
null!,
out passwordLength,
out passwordPtr,
out item);
#pragma warning restore 0612
if (status == Interop.Security.OSStatus.ErrSecSuccess)
{
using var handle = new KeyChainItemHandle(item, passwordPtr, passwordLength, Encoding.Unicode);
// Migrate credential to 'Auto' encoding.
if (handle != null)
{
var saveResult = this.AddGenericPassword(credential: new Credential(credentialId, handle.Password));
if (saveResult)
{
// Safe to delete old password now.
Interop.Security.SecKeychainItemDelete(handle);
}
// Lookup keychain again to fetch handle for new credential.
return LookupKeyChainItem(credentialId);
}
}
} }
return null; return null;
} }
@@ -106,7 +136,7 @@ namespace Microsoft.SqlTools.Credentials.OSX
private bool DeletePasswordImpl(string credentialId) private bool DeletePasswordImpl(string credentialId)
{ {
// Find password, then Delete, then cleanup // Find password, then Delete, then cleanup
using (KeyChainItemHandle handle = LookupKeyChainItem(credentialId)) using (KeyChainItemHandle? handle = LookupKeyChainItem(credentialId))
{ {
if (handle == null) if (handle == null)
{ {
@@ -114,39 +144,47 @@ namespace Microsoft.SqlTools.Credentials.OSX
} }
Interop.Security.OSStatus status = Interop.Security.SecKeychainItemDelete(handle); Interop.Security.OSStatus status = Interop.Security.SecKeychainItemDelete(handle);
return status == Interop.Security.OSStatus.ErrSecSuccess; return status == Interop.Security.OSStatus.ErrSecSuccess;
} }
} }
private class KeyChainItemHandle : SafeCreateHandle private class KeyChainItemHandle : SafeCreateHandle
{ {
private IntPtr passwordPtr; private IntPtr passwordPtr;
private int passwordLength; private int passwordLength;
private Encoding encoding = Encoding.UTF8;
public KeyChainItemHandle() : base() public KeyChainItemHandle() : base()
{ {
} }
public KeyChainItemHandle(IntPtr itemPtr) : this(itemPtr, IntPtr.Zero, 0) public KeyChainItemHandle(IntPtr itemPtr) : this(itemPtr, IntPtr.Zero, 0, Encoding.UTF8)
{ {
} }
public KeyChainItemHandle(IntPtr itemPtr, IntPtr passwordPtr, UInt32 passwordLength) public KeyChainItemHandle(IntPtr itemPtr, IntPtr passwordPtr, UInt32 passwordLength, Encoding encoding)
: base(itemPtr) : base(itemPtr)
{ {
if (encoding != Encoding.UTF8 && encoding != Encoding.Unicode)
{
throw new ArgumentException($"Encoding {encoding} not supported.");
}
this.passwordPtr = passwordPtr; this.passwordPtr = passwordPtr;
this.passwordLength = (int) passwordLength; this.passwordLength = (int)passwordLength;
this.encoding = encoding;
} }
public string Password public string? Password
{ {
get { get
{
if (IsInvalid) if (IsInvalid)
{ {
return null; return null;
} }
return InteropUtils.CopyToString(passwordPtr, passwordLength); return InteropUtils.CopyToString(passwordPtr, passwordLength,
this.encoding == Encoding.UTF8 ? Encoding.UTF8: Encoding.Unicode);
} }
} }
protected override bool ReleaseHandle() protected override bool ReleaseHandle()