From 3e4e0bc8c2e74d59db8fd6e397af7f0fa4bb7785 Mon Sep 17 00:00:00 2001
From: Cheena Malhotra <13396919+cheenamalhotra@users.noreply.github.com>
Date: Thu, 23 Mar 2023 20:31:59 -0700
Subject: [PATCH] Fix encoding for OSX keychain (#1939)
---
.../Credentials/ICredentialStore.cs | 6 +-
.../Credentials/InteropUtils.cs | 38 ++++++--
.../Credentials/OSX/Interop.Security.cs | 7 +-
.../Credentials/OSX/Interop.Security.old.cs | 52 ++++++++++
.../Credentials/OSX/OSXCredentialStore.cs | 94 +++++++++++++------
5 files changed, 152 insertions(+), 45 deletions(-)
create mode 100644 src/Microsoft.SqlTools.Credentials/Credentials/OSX/Interop.Security.old.cs
diff --git a/src/Microsoft.SqlTools.Credentials/Credentials/ICredentialStore.cs b/src/Microsoft.SqlTools.Credentials/Credentials/ICredentialStore.cs
index 770da611..a8ba0bd2 100644
--- a/src/Microsoft.SqlTools.Credentials/Credentials/ICredentialStore.cs
+++ b/src/Microsoft.SqlTools.Credentials/Credentials/ICredentialStore.cs
@@ -1,10 +1,8 @@
-//
+//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-#nullable disable
-
using Microsoft.SqlTools.Credentials.Contracts;
namespace Microsoft.SqlTools.Credentials
@@ -30,7 +28,7 @@ namespace Microsoft.SqlTools.Credentials
/// The name of the credential to find the password for. This is required
/// Out value
/// true if password was found, false otherwise
- bool TryGetPassword(string credentialId, out string password);
+ bool TryGetPassword(string credentialId, out string? password);
///
/// Deletes a password linked to a given credential
diff --git a/src/Microsoft.SqlTools.Credentials/Credentials/InteropUtils.cs b/src/Microsoft.SqlTools.Credentials/Credentials/InteropUtils.cs
index 82d4f18a..5ee6f159 100644
--- a/src/Microsoft.SqlTools.Credentials/Credentials/InteropUtils.cs
+++ b/src/Microsoft.SqlTools.Credentials/Credentials/InteropUtils.cs
@@ -3,8 +3,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-#nullable disable
-
using System;
using System.Runtime.InteropServices;
using System.Text;
@@ -13,25 +11,47 @@ namespace Microsoft.SqlTools.Credentials
{
internal static class InteropUtils
{
-
///
- /// Gets the length in bytes for a Unicode string, for use in interop where length must be defined
+ /// Gets the length in bytes for an encoded string, for use in interop where length must be defined
///
- public static UInt32 GetLengthInBytes(string value)
+ /// String value
+ /// Encoding of string provided.
+ public static UInt32 GetLengthInBytes(string value, Encoding encoding)
{
-
- return Convert.ToUInt32( (value != null ? Encoding.Unicode.GetByteCount(value) : 0) );
+ if (encoding != Encoding.Unicode && encoding != Encoding.UTF8)
+ {
+ throw new ArgumentException($"Encoding {encoding} not supported.");
+ }
+ return Convert.ToUInt32((value != null
+ ? (encoding == Encoding.UTF8
+ ? Encoding.UTF8.GetByteCount(value)
+ : Encoding.Unicode.GetByteCount(value))
+ : 0));
}
- public static string CopyToString(IntPtr ptr, int length)
+ ///
+ /// Copies data of length from
+ /// pointer to a string of provided encoding.
+ ///
+ /// Pointer to data
+ /// Length of data to be copied.
+ /// Character encoding to be used to get string.
+ ///
+ public static string? CopyToString(IntPtr ptr, int length, Encoding encoding)
{
if (ptr == IntPtr.Zero || length == 0)
{
return null;
}
+ if (encoding != Encoding.Unicode && encoding != Encoding.UTF8)
+ {
+ throw new ArgumentException($"Encoding {encoding} not supported.");
+ }
byte[] pwdBytes = new byte[length];
Marshal.Copy(ptr, pwdBytes, 0, (int)length);
- return Encoding.Unicode.GetString(pwdBytes, 0, (int)length);
+ return (encoding == Encoding.UTF8)
+ ? Encoding.UTF8.GetString(pwdBytes, 0, (int)length)
+ : Encoding.Unicode.GetString(pwdBytes, 0, (int)length);
}
}
diff --git a/src/Microsoft.SqlTools.Credentials/Credentials/OSX/Interop.Security.cs b/src/Microsoft.SqlTools.Credentials/Credentials/OSX/Interop.Security.cs
index 40b44f46..b1ca5d73 100644
--- a/src/Microsoft.SqlTools.Credentials/Credentials/OSX/Interop.Security.cs
+++ b/src/Microsoft.SqlTools.Credentials/Credentials/OSX/Interop.Security.cs
@@ -1,4 +1,4 @@
-//
+//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
@@ -14,8 +14,7 @@ namespace Microsoft.SqlTools.Credentials
{
internal partial class Security
{
-
- [DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)]
+ [DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Auto, SetLastError = true)]
internal static extern OSStatus SecKeychainAddGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName,
UInt32 accountNameLength, string accountName, UInt32 passwordLength, IntPtr password, [Out] IntPtr itemRef);
@@ -43,7 +42,7 @@ namespace Microsoft.SqlTools.Credentials
/// Most attributes are optional; you should pass only as many as you need to narrow the search sufficiently for your application's intended use.
/// SecKeychainFindGenericPassword optionally returns a reference to the found item.
///
- [DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)]
+ [DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Auto, SetLastError = true)]
internal static extern OSStatus SecKeychainFindGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName,
UInt32 accountNameLength, string accountName, out UInt32 passwordLength, out IntPtr password, out IntPtr itemRef);
diff --git a/src/Microsoft.SqlTools.Credentials/Credentials/OSX/Interop.Security.old.cs b/src/Microsoft.SqlTools.Credentials/Credentials/OSX/Interop.Security.old.cs
new file mode 100644
index 00000000..7609cafd
--- /dev/null
+++ b/src/Microsoft.SqlTools.Credentials/Credentials/OSX/Interop.Security.old.cs
@@ -0,0 +1,52 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Runtime.InteropServices;
+using static Microsoft.SqlTools.Credentials.Interop.Security;
+
+namespace Microsoft.SqlTools.Credentials
+{
+ internal partial class Interop
+ {
+ internal class SecurityOld
+ {
+ ///
+ /// Find a generic password based on the attributes passed (using Unicode encoding)
+ ///
+ ///
+ /// A reference to an array of keychains to search, a single keychain, or NULL to search the user's default keychain search list.
+ ///
+ /// The length of the buffer pointed to by serviceName.
+ /// A pointer to a string containing the service name.
+ /// The length of the buffer pointed to by accountName.
+ /// A pointer to a string containing the account name.
+ /// On return, the length of the buffer pointed to by passwordData.
+ ///
+ /// On return, a pointer to a data buffer containing the password.
+ /// Your application must call SecKeychainItemFreeContent(NULL, passwordData)
+ /// to release this data buffer when it is no longer needed.Pass NULL if you are not interested in retrieving the password data at
+ /// this time, but simply want to find the item reference.
+ ///
+ /// On return, a reference to the keychain item which was found.
+ /// A result code that should be in
+ ///
+ /// The SecKeychainFindGenericPassword function finds the first generic password item which matches the attributes you provide.
+ /// Most attributes are optional; you should pass only as many as you need to narrow the search sufficiently for your application's intended use.
+ /// SecKeychainFindGenericPassword optionally returns a reference to the found item.
+ ///
+ ///
+ /// ***********************************************************************************************
+ /// This method is marked OBSOLETE as it used 'Unicode' Charset encoding to save credentials.
+ /// It has been replaced with respective API in the 'Security' class using 'Auto' Charset encoding.
+ /// ***********************************************************************************************
+ [Obsolete]
+ [DllImport(Libraries.SecurityLibrary, CharSet = CharSet.Unicode, SetLastError = true)]
+ internal static extern OSStatus SecKeychainFindGenericPassword(IntPtr keyChainRef, UInt32 serviceNameLength, string serviceName,
+ UInt32 accountNameLength, string accountName, out UInt32 passwordLength, out IntPtr password, out IntPtr itemRef);
+ }
+ }
+}
+
diff --git a/src/Microsoft.SqlTools.Credentials/Credentials/OSX/OSXCredentialStore.cs b/src/Microsoft.SqlTools.Credentials/Credentials/OSX/OSXCredentialStore.cs
index 70eb20f7..ab334dc4 100644
--- a/src/Microsoft.SqlTools.Credentials/Credentials/OSX/OSXCredentialStore.cs
+++ b/src/Microsoft.SqlTools.Credentials/Credentials/OSX/OSXCredentialStore.cs
@@ -3,10 +3,9 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-#nullable disable
-
using System;
using System.Runtime.InteropServices;
+using System.Text;
using Microsoft.SqlTools.Credentials.Contracts;
using Microsoft.SqlTools.Utility;
@@ -26,7 +25,7 @@ namespace Microsoft.SqlTools.Credentials.OSX
return DeletePasswordImpl(credentialId);
}
- public bool TryGetPassword(string credentialId, out string password)
+ public bool TryGetPassword(string credentialId, out string? password)
{
Validate.IsNotNullOrEmptyString("credentialId", credentialId);
return FindPassword(credentialId, out password);
@@ -48,15 +47,15 @@ namespace Microsoft.SqlTools.Credentials.OSX
}
private bool AddGenericPassword(Credential credential)
- {
- IntPtr passwordPtr = Marshal.StringToCoTaskMemUni(credential.Password);
+ {
+ IntPtr passwordPtr = Marshal.StringToCoTaskMemUTF8(credential.Password);
Interop.Security.OSStatus status = Interop.Security.SecKeychainAddGenericPassword(
- IntPtr.Zero,
- InteropUtils.GetLengthInBytes(credential.CredentialId),
+ IntPtr.Zero,
+ InteropUtils.GetLengthInBytes(credential.CredentialId, Encoding.UTF8),
credential.CredentialId,
- 0,
+ 0,
null,
- InteropUtils.GetLengthInBytes(credential.Password),
+ InteropUtils.GetLengthInBytes(credential.Password, Encoding.UTF8),
passwordPtr,
IntPtr.Zero);
@@ -66,12 +65,12 @@ namespace Microsoft.SqlTools.Credentials.OSX
///
/// Finds the first password matching this credential
///
- private bool FindPassword(string credentialId, out string password)
+ private bool FindPassword(string credentialId, out string? password)
{
password = null;
- using (KeyChainItemHandle handle = LookupKeyChainItem(credentialId))
+ using (KeyChainItemHandle? handle = LookupKeyChainItem(credentialId))
{
- if( handle == null)
+ if (handle == null)
{
return false;
}
@@ -81,14 +80,14 @@ namespace Microsoft.SqlTools.Credentials.OSX
return true;
}
- private KeyChainItemHandle LookupKeyChainItem(string credentialId)
+ private KeyChainItemHandle? LookupKeyChainItem(string credentialId)
{
UInt32 passwordLength;
IntPtr passwordPtr;
IntPtr item;
Interop.Security.OSStatus status = Interop.Security.SecKeychainFindGenericPassword(
IntPtr.Zero,
- InteropUtils.GetLengthInBytes(credentialId),
+ InteropUtils.GetLengthInBytes(credentialId, Encoding.UTF8),
credentialId,
0,
null,
@@ -96,9 +95,40 @@ namespace Microsoft.SqlTools.Credentials.OSX
out passwordPtr,
out item);
- if(status == Interop.Security.OSStatus.ErrSecSuccess)
+ if (status == Interop.Security.OSStatus.ErrSecSuccess)
{
- return new KeyChainItemHandle(item, passwordPtr, passwordLength);
+ return new KeyChainItemHandle(item, passwordPtr, passwordLength, Encoding.UTF8);
+ }
+ else
+ {
+#pragma warning disable 0612
+ // Intentional fallback to Unicode to retrieve old passwords before encoding shift
+ status = Interop.SecurityOld.SecKeychainFindGenericPassword(
+ IntPtr.Zero,
+ InteropUtils.GetLengthInBytes(credentialId, Encoding.Unicode),
+ credentialId,
+ 0,
+ null!,
+ out passwordLength,
+ out passwordPtr,
+ out item);
+#pragma warning restore 0612
+ if (status == Interop.Security.OSStatus.ErrSecSuccess)
+ {
+ using var handle = new KeyChainItemHandle(item, passwordPtr, passwordLength, Encoding.Unicode);
+ // Migrate credential to 'Auto' encoding.
+ if (handle != null)
+ {
+ var saveResult = this.AddGenericPassword(credential: new Credential(credentialId, handle.Password));
+ if (saveResult)
+ {
+ // Safe to delete old password now.
+ Interop.Security.SecKeychainItemDelete(handle);
+ }
+ // Lookup keychain again to fetch handle for new credential.
+ return LookupKeyChainItem(credentialId);
+ }
+ }
}
return null;
}
@@ -106,7 +136,7 @@ namespace Microsoft.SqlTools.Credentials.OSX
private bool DeletePasswordImpl(string credentialId)
{
// Find password, then Delete, then cleanup
- using (KeyChainItemHandle handle = LookupKeyChainItem(credentialId))
+ using (KeyChainItemHandle? handle = LookupKeyChainItem(credentialId))
{
if (handle == null)
{
@@ -114,39 +144,47 @@ namespace Microsoft.SqlTools.Credentials.OSX
}
Interop.Security.OSStatus status = Interop.Security.SecKeychainItemDelete(handle);
return status == Interop.Security.OSStatus.ErrSecSuccess;
- }
+ }
}
private class KeyChainItemHandle : SafeCreateHandle
{
private IntPtr passwordPtr;
private int passwordLength;
+ private Encoding encoding = Encoding.UTF8;
public KeyChainItemHandle() : base()
{
}
-
- public KeyChainItemHandle(IntPtr itemPtr) : this(itemPtr, IntPtr.Zero, 0)
+
+ public KeyChainItemHandle(IntPtr itemPtr) : this(itemPtr, IntPtr.Zero, 0, Encoding.UTF8)
{
-
+
}
- public KeyChainItemHandle(IntPtr itemPtr, IntPtr passwordPtr, UInt32 passwordLength)
+ public KeyChainItemHandle(IntPtr itemPtr, IntPtr passwordPtr, UInt32 passwordLength, Encoding encoding)
: base(itemPtr)
{
+ if (encoding != Encoding.UTF8 && encoding != Encoding.Unicode)
+ {
+ throw new ArgumentException($"Encoding {encoding} not supported.");
+ }
this.passwordPtr = passwordPtr;
- this.passwordLength = (int) passwordLength;
+ this.passwordLength = (int)passwordLength;
+ this.encoding = encoding;
}
-
- public string Password
- {
- get {
+
+ public string? Password
+ {
+ get
+ {
if (IsInvalid)
{
return null;
}
- return InteropUtils.CopyToString(passwordPtr, passwordLength);
+ return InteropUtils.CopyToString(passwordPtr, passwordLength,
+ this.encoding == Encoding.UTF8 ? Encoding.UTF8: Encoding.Unicode);
}
}
protected override bool ReleaseHandle()