Skip to content
Snippets Groups Projects
Commit cb7cdf8d authored by Thai Duong's avatar Thai Duong
Browse files

Fixing https://github.com/google/tink/issues/71.

Change-Id: If61fa55a7c67883fe8fb31b1ddbee3e0b6bed7a3
ORIGINAL_AUTHOR=Thai Duong <thaidn@google.com>
GitOrigin-RevId: aba017282b5d44672e94fb8753f38559da8d6ea4
parent 65293c74
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,7 @@ package com.google.crypto.tink.integration.awskms;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.services.kms.AWSKMS;
import com.amazonaws.services.kms.model.DecryptRequest;
import com.amazonaws.services.kms.model.DecryptResult;
import com.amazonaws.services.kms.model.EncryptRequest;
import com.amazonaws.util.BinaryUtils;
import com.google.crypto.tink.Aead;
......@@ -68,7 +69,11 @@ public final class AwsKmsAead implements Aead {
if (associatedData != null && associatedData.length != 0) {
req = req.addEncryptionContextEntry("associatedData", BinaryUtils.toHex(associatedData));
}
return kmsClient.decrypt(req).getPlaintext().array();
DecryptResult result = kmsClient.decrypt(req);
if (!result.getKeyId().equals(keyArn)) {
throw new GeneralSecurityException("decryption failed: wrong key id");
}
return result.getPlaintext().array();
} catch (AmazonServiceException e) {
throw new GeneralSecurityException("decryption failed", e);
}
......
......@@ -42,9 +42,8 @@ import org.mockito.runners.MockitoJUnitRunner;
*/
@RunWith(MockitoJUnitRunner.class)
public class AwsKmsAeadTest {
private static final String KEY_ID = "aws-kms://123";
@Mock private AWSKMS mockKms;
private final String keyId = "aws-kms://123";
@Test
public void testEncryptDecrypt() throws Exception {
......@@ -55,10 +54,11 @@ public class AwsKmsAeadTest {
when(mockKms.encrypt(isA(EncryptRequest.class)))
.thenReturn(mockEncryptResult);
Aead aead = new AwsKmsAead(mockKms, keyId);
Aead aead = new AwsKmsAead(mockKms, KEY_ID);
byte[] aad = Random.randBytes(20);
for (int messageSize = 0; messageSize < 75; messageSize++) {
byte[] message = Random.randBytes(messageSize);
when(mockDecryptResult.getKeyId()).thenReturn(KEY_ID);
when(mockDecryptResult.getPlaintext()).thenReturn(ByteBuffer.wrap(message));
when(mockEncryptResult.getCiphertextBlob()).thenReturn(ByteBuffer.wrap(message));
byte[] ciphertext = aead.encrypt(message, aad);
......@@ -68,12 +68,12 @@ public class AwsKmsAeadTest {
}
@Test
public void testEncrypt_shouldThrowExceptionIfRequestFailed() throws Exception {
public void testEncryptShouldThrowExceptionIfRequestFailed() throws Exception {
AmazonServiceException exception = mock(AmazonServiceException.class);
when(mockKms.encrypt(isA(EncryptRequest.class)))
.thenThrow(exception);
Aead aead = new AwsKmsAead(mockKms, keyId);
Aead aead = new AwsKmsAead(mockKms, KEY_ID);
byte[] aad = Random.randBytes(20);
byte[] message = Random.randBytes(20);
try {
......@@ -85,7 +85,7 @@ public class AwsKmsAeadTest {
}
@Test
public void testDecrypt_shouldThrowExceptionIfRequestFailed() throws Exception {
public void testDecryptShouldThrowExceptionIfRequestFailed() throws Exception {
EncryptResult mockEncryptResult = mock(EncryptResult.class);
when(mockKms.encrypt(isA(EncryptRequest.class)))
.thenReturn(mockEncryptResult);
......@@ -93,10 +93,33 @@ public class AwsKmsAeadTest {
when(mockKms.decrypt(isA(DecryptRequest.class)))
.thenThrow(exception);
Aead aead = new AwsKmsAead(mockKms, keyId);
Aead aead = new AwsKmsAead(mockKms, KEY_ID);
byte[] aad = Random.randBytes(20);
byte[] message = Random.randBytes(20);
when(mockEncryptResult.getCiphertextBlob()).thenReturn(ByteBuffer.wrap(message));
byte[] ciphertext = aead.encrypt(message, aad);
try {
aead.decrypt(ciphertext, aad);
fail("Expected GeneralSecurityException");
} catch (GeneralSecurityException e) {
// expected.
}
}
@Test
public void testDecryptShouldThrowExceptionIfKeyIdIsDifferent() throws Exception {
DecryptResult mockDecryptResult = mock(DecryptResult.class);
EncryptResult mockEncryptResult = mock(EncryptResult.class);
when(mockKms.decrypt(isA(DecryptRequest.class)))
.thenReturn(mockDecryptResult);
when(mockKms.encrypt(isA(EncryptRequest.class)))
.thenReturn(mockEncryptResult);
Aead aead = new AwsKmsAead(mockKms, KEY_ID);
byte[] aad = Random.randBytes(20);
byte[] message = Random.randBytes(20);
when(mockEncryptResult.getCiphertextBlob()).thenReturn(ByteBuffer.wrap(message));
when(mockDecryptResult.getKeyId()).thenReturn(KEY_ID + "1");
byte[] ciphertext = aead.encrypt(message, aad);
try {
aead.decrypt(ciphertext, aad);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment