diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/AesImplementation.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/AesImplementation.cs index 9729a1b9ab8a67..c485530281a0a2 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/AesImplementation.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/AesImplementation.cs @@ -9,13 +9,28 @@ namespace System.Security.Cryptography internal sealed partial class AesImplementation : Aes { private FixedMemoryKeyBox? _keyBox; + private ILiteSymmetricCipher? _encryptEcbCipher; + private ILiteSymmetricCipher? _decryptEcbCipher; + private ILiteSymmetricCipher? _encryptCbcCipher; + private ILiteSymmetricCipher? _decryptCbcCipher; + private ConcurrencyBlock _block; private FixedMemoryKeyBox GetKey() { if (_keyBox is null) { - GenerateKey(); - Debug.Assert(_keyBox is not null); + Span key = stackalloc byte[KeySize / BitsPerByte]; + + try + { + RandomNumberGenerator.Fill(key); + SetKeyCoreUnchecked(key); + Debug.Assert(_keyBox is not null); + } + finally + { + CryptographicOperations.ZeroMemory(key); + } } return _keyBox; @@ -32,9 +47,13 @@ public override int KeySize get => base.KeySize; set { - base.KeySize = value; - _keyBox?.Dispose(); - _keyBox = null; + using (ConcurrencyBlock.Enter(ref _block)) + { + base.KeySize = value; + ClearCachedCiphers(); + _keyBox?.Dispose(); + _keyBox = null; + } } } @@ -70,14 +89,23 @@ public sealed override void GenerateIV() public sealed override unsafe void GenerateKey() { Span key = stackalloc byte[KeySize / BitsPerByte]; - RandomNumberGenerator.Fill(key); - SetKeyCore(key); + + try + { + RandomNumberGenerator.Fill(key); + SetKeyCore(key); + } + finally + { + CryptographicOperations.ZeroMemory(key); + } } protected sealed override void Dispose(bool disposing) { if (disposing) { + ClearCachedCiphers(); _keyBox?.Dispose(); _keyBox = null; } @@ -87,9 +115,10 @@ protected sealed override void Dispose(bool disposing) protected override void SetKeyCore(ReadOnlySpan key) { - KeySizeValue = checked(BitsPerByte * key.Length); - _keyBox?.Dispose(); - _keyBox = new FixedMemoryKeyBox(key); + using (ConcurrencyBlock.Enter(ref _block)) + { + SetKeyCoreUnchecked(key); + } } protected override bool TryDecryptEcbCore( @@ -98,19 +127,14 @@ protected override bool TryDecryptEcbCore( PaddingMode paddingMode, out int bytesWritten) { - ILiteSymmetricCipher cipher = GetKey().UseKey( - BlockSize / BitsPerByte, - static (blockSizeBytes, key) => CreateLiteCipher( + using (ConcurrencyBlock.Enter(ref _block)) + { + ILiteSymmetricCipher cipher = GetOrCreateCachedLiteCipher( + ref _decryptEcbCipher, CipherMode.ECB, - key, iv: default, - blockSize: blockSizeBytes, - paddingSize: blockSizeBytes, - 0, /*feedback size */ - encrypting: false)); + encrypting: false); - using (cipher) - { return UniversalCryptoOneShot.OneShotDecrypt(cipher, paddingMode, ciphertext, destination, out bytesWritten); } } @@ -121,19 +145,14 @@ protected override bool TryEncryptEcbCore( PaddingMode paddingMode, out int bytesWritten) { - ILiteSymmetricCipher cipher = GetKey().UseKey( - BlockSize / BitsPerByte, - static (blockSizeBytes, key) => CreateLiteCipher( + using (ConcurrencyBlock.Enter(ref _block)) + { + ILiteSymmetricCipher cipher = GetOrCreateCachedLiteCipher( + ref _encryptEcbCipher, CipherMode.ECB, - key, iv: default, - blockSize: blockSizeBytes, - paddingSize: blockSizeBytes, - 0, /*feedback size */ - encrypting: true)); + encrypting: true); - using (cipher) - { return UniversalCryptoOneShot.OneShotEncrypt(cipher, paddingMode, plaintext, destination, out bytesWritten); } } @@ -145,20 +164,14 @@ protected override bool TryEncryptCbcCore( PaddingMode paddingMode, out int bytesWritten) { - ILiteSymmetricCipher cipher = GetKey().UseKey( - iv, - BlockSize / BitsPerByte, - static (iv, blockSizeBytes, key) => CreateLiteCipher( + using (ConcurrencyBlock.Enter(ref _block)) + { + ILiteSymmetricCipher cipher = GetOrCreateCachedLiteCipher( + ref _encryptCbcCipher, CipherMode.CBC, - key, iv, - blockSize: blockSizeBytes, - paddingSize: blockSizeBytes, - 0, /*feedback size */ - encrypting: true)); + encrypting: true); - using (cipher) - { return UniversalCryptoOneShot.OneShotEncrypt(cipher, paddingMode, plaintext, destination, out bytesWritten); } } @@ -170,20 +183,14 @@ protected override bool TryDecryptCbcCore( PaddingMode paddingMode, out int bytesWritten) { - ILiteSymmetricCipher cipher = GetKey().UseKey( - iv, - BlockSize / BitsPerByte, - static (iv, blockSizeBytes, key) => CreateLiteCipher( + using (ConcurrencyBlock.Enter(ref _block)) + { + ILiteSymmetricCipher cipher = GetOrCreateCachedLiteCipher( + ref _decryptCbcCipher, CipherMode.CBC, - key, iv, - blockSize: blockSizeBytes, - paddingSize: blockSizeBytes, - 0, /*feedback size */ - encrypting: false)); + encrypting: false); - using (cipher) - { return UniversalCryptoOneShot.OneShotDecrypt(cipher, paddingMode, ciphertext, destination, out bytesWritten); } } @@ -198,21 +205,24 @@ protected override bool TryDecryptCfbCore( { ValidateCFBFeedbackSize(feedbackSizeInBits); - ILiteSymmetricCipher cipher = GetKey().UseKey( - iv, - (BlockSizeBytes: BlockSize / BitsPerByte, FeedbackSizeBytes: feedbackSizeInBits / BitsPerByte), - static (iv, state, key) => CreateLiteCipher( - CipherMode.CFB, - key, - iv: iv, - blockSize: state.BlockSizeBytes, - paddingSize: state.FeedbackSizeBytes, - state.FeedbackSizeBytes, - encrypting: false)); - - using (cipher) + using (ConcurrencyBlock.Enter(ref _block)) { - return UniversalCryptoOneShot.OneShotDecrypt(cipher, paddingMode, ciphertext, destination, out bytesWritten); + ILiteSymmetricCipher cipher = GetKey().UseKey( + iv, + (BlockSizeBytes: BlockSize / BitsPerByte, FeedbackSizeBytes: feedbackSizeInBits / BitsPerByte), + static (iv, state, key) => CreateLiteCipher( + CipherMode.CFB, + key, + iv: iv, + blockSize: state.BlockSizeBytes, + paddingSize: state.FeedbackSizeBytes, + state.FeedbackSizeBytes, + encrypting: false)); + + using (cipher) + { + return UniversalCryptoOneShot.OneShotDecrypt(cipher, paddingMode, ciphertext, destination, out bytesWritten); + } } } @@ -226,21 +236,24 @@ protected override bool TryEncryptCfbCore( { ValidateCFBFeedbackSize(feedbackSizeInBits); - ILiteSymmetricCipher cipher = GetKey().UseKey( - iv, - (BlockSizeBytes: BlockSize / BitsPerByte, FeedbackSizeBytes: feedbackSizeInBits / BitsPerByte), - static (iv, state, key) => CreateLiteCipher( - CipherMode.CFB, - key, - iv, - blockSize: state.BlockSizeBytes, - paddingSize: state.FeedbackSizeBytes, - state.FeedbackSizeBytes, - encrypting: true)); - - using (cipher) + using (ConcurrencyBlock.Enter(ref _block)) { - return UniversalCryptoOneShot.OneShotEncrypt(cipher, paddingMode, plaintext, destination, out bytesWritten); + ILiteSymmetricCipher cipher = GetKey().UseKey( + iv, + (BlockSizeBytes: BlockSize / BitsPerByte, FeedbackSizeBytes: feedbackSizeInBits / BitsPerByte), + static (iv, state, key) => CreateLiteCipher( + CipherMode.CFB, + key, + iv, + blockSize: state.BlockSizeBytes, + paddingSize: state.FeedbackSizeBytes, + state.FeedbackSizeBytes, + encrypting: true)); + + using (cipher) + { + return UniversalCryptoOneShot.OneShotEncrypt(cipher, paddingMode, plaintext, destination, out bytesWritten); + } } } @@ -291,6 +304,66 @@ private static void ValidateCFBFeedbackSize(int feedback) } } + private ILiteSymmetricCipher GetOrCreateCachedLiteCipher( + ref ILiteSymmetricCipher? cipher, + CipherMode cipherMode, + ReadOnlySpan iv, + bool encrypting) + { + Debug.Assert(cipherMode is CipherMode.ECB or CipherMode.CBC); + + if (cipher is not null) + { + try + { + cipher.Reset(iv); + return cipher; + } + catch + { + cipher.Dispose(); + cipher = null; // Null-out the cipher field passed by reference. + throw; + } + } + + int blockSizeBytes = BlockSize / BitsPerByte; + cipher = GetKey().UseKey( + iv, + (BlockSizeBytes: blockSizeBytes, CipherMode: cipherMode, Encrypting: encrypting), + static (iv, state, key) => CreateLiteCipher( + state.CipherMode, + key, + iv, + blockSize: state.BlockSizeBytes, + paddingSize: state.BlockSizeBytes, + 0, /* feedback size */ + encrypting: state.Encrypting)); + + return cipher; + } + + private void SetKeyCoreUnchecked(ReadOnlySpan key) + { + KeySizeValue = checked(BitsPerByte * key.Length); + FixedMemoryKeyBox keyBox = new FixedMemoryKeyBox(key); + ClearCachedCiphers(); + _keyBox?.Dispose(); + _keyBox = keyBox; + } + + private void ClearCachedCiphers() + { + _encryptEcbCipher?.Dispose(); + _encryptEcbCipher = null; + _decryptEcbCipher?.Dispose(); + _decryptEcbCipher = null; + _encryptCbcCipher?.Dispose(); + _encryptCbcCipher = null; + _decryptCbcCipher?.Dispose(); + _decryptCbcCipher = null; + } + private const int BitsPerByte = 8; } }