diff --git a/java/src/test/java/com/google/crypto/tink/streamingaead/AesGcmHkdfStreamingKeyManagerTest.java b/java/src/test/java/com/google/crypto/tink/streamingaead/AesGcmHkdfStreamingKeyManagerTest.java
index cccb83c67e7b75720bd93f4ed19e4b169dad3efe..1d73fa8e3d1733042ebd58711174a2e966e1face 100644
--- a/java/src/test/java/com/google/crypto/tink/streamingaead/AesGcmHkdfStreamingKeyManagerTest.java
+++ b/java/src/test/java/com/google/crypto/tink/streamingaead/AesGcmHkdfStreamingKeyManagerTest.java
@@ -16,11 +16,9 @@
 
 package com.google.crypto.tink.streamingaead;
 
-import static org.junit.Assert.assertEquals;
+import static com.google.common.truth.Truth.assertThat;
 import static org.junit.Assert.fail;
 
-import com.google.crypto.tink.KeyManager;
-import com.google.crypto.tink.KeyManagerImpl;
 import com.google.crypto.tink.StreamingAead;
 import com.google.crypto.tink.StreamingTestUtil;
 import com.google.crypto.tink.TestUtil;
@@ -28,13 +26,10 @@ import com.google.crypto.tink.proto.AesGcmHkdfStreamingKey;
 import com.google.crypto.tink.proto.AesGcmHkdfStreamingKeyFormat;
 import com.google.crypto.tink.proto.AesGcmHkdfStreamingParams;
 import com.google.crypto.tink.proto.HashType;
-import com.google.crypto.tink.proto.KeyData;
-import com.google.crypto.tink.subtle.Random;
-import com.google.protobuf.ByteString;
+import com.google.crypto.tink.proto.KeyData.KeyMaterialType;
 import java.security.GeneralSecurityException;
 import java.util.Set;
 import java.util.TreeSet;
-import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -42,182 +37,109 @@ import org.junit.runners.JUnit4;
 /** Test for AesGcmHkdfStreamingKeyManager. */
 @RunWith(JUnit4.class)
 public class AesGcmHkdfStreamingKeyManagerTest {
-  private static final int AES_KEY_SIZE = 16;
-  private AesGcmHkdfStreamingParams keyParams;
-  private KeyManager<StreamingAead> keyManager;
-
-  @Before
-  public void setUp() throws GeneralSecurityException {
-    keyParams =
-        AesGcmHkdfStreamingParams.newBuilder()
-            .setCiphertextSegmentSize(128)
-            .setDerivedKeySize(AES_KEY_SIZE)
-            .setHkdfHashType(HashType.SHA256)
-            .build();
-    keyManager = new KeyManagerImpl<>(new AesGcmHkdfStreamingKeyManager(), StreamingAead.class);
+  private final AesGcmHkdfStreamingKeyManager manager = new AesGcmHkdfStreamingKeyManager();
+  private final AesGcmHkdfStreamingKeyManager.KeyFactory<
+          AesGcmHkdfStreamingKeyFormat, AesGcmHkdfStreamingKey>
+      factory = manager.keyFactory();
+
+  private static AesGcmHkdfStreamingKeyFormat createKeyFormat(
+      int keySize, int derivedKeySize, HashType hashType, int segmentSize) {
+    return AesGcmHkdfStreamingKeyFormat.newBuilder()
+        .setKeySize(keySize)
+        .setParams(
+            AesGcmHkdfStreamingParams.newBuilder()
+                .setDerivedKeySize(derivedKeySize)
+                .setHkdfHashType(hashType)
+                .setCiphertextSegmentSize(segmentSize))
+        .build();
   }
 
   @Test
-  public void testBasic() throws Exception {
-    // Create primitive from a given key.
-    AesGcmHkdfStreamingKey key =
-        AesGcmHkdfStreamingKey.newBuilder()
-            .setVersion(0)
-            .setKeyValue(ByteString.copyFrom(Random.randBytes(20)))
-            .setParams(keyParams)
-            .build();
-    StreamingAead streamingAead = keyManager.getPrimitive(key);
-    StreamingTestUtil.testEncryptionAndDecryption(streamingAead);
-
-    // Create a key from KeyFormat, and use the key.
-    AesGcmHkdfStreamingKeyFormat keyFormat =
-        AesGcmHkdfStreamingKeyFormat.newBuilder().setParams(keyParams).setKeySize(16).build();
-    ByteString serializedKeyFormat = ByteString.copyFrom(keyFormat.toByteArray());
-    key = (AesGcmHkdfStreamingKey) keyManager.newKey(serializedKeyFormat);
-    streamingAead = keyManager.getPrimitive(key);
-    StreamingTestUtil.testEncryptionAndDecryption(streamingAead);
+  public void basics() throws Exception {
+    assertThat(manager.getKeyType())
+        .isEqualTo("type.googleapis.com/google.crypto.tink.AesGcmHkdfStreamingKey");
+    assertThat(manager.getVersion()).isEqualTo(0);
+    assertThat(manager.keyMaterialType()).isEqualTo(KeyMaterialType.SYMMETRIC);
   }
 
   @Test
-  public void testSkip() throws Exception {
-    AesGcmHkdfStreamingKey key =
-        AesGcmHkdfStreamingKey.newBuilder()
-            .setVersion(0)
-            .setKeyValue(ByteString.copyFrom(Random.randBytes(20)))
-            .setParams(keyParams)
-            .build();
-    StreamingAead streamingAead = keyManager.getPrimitive(key);
-    int offset = 0;
-    int plaintextSize = 1 << 16;
-    // Runs the test with different sizes for the chunks to skip.
-    StreamingTestUtil.testSkipWithStream(streamingAead, offset, plaintextSize, 1);
-    StreamingTestUtil.testSkipWithStream(streamingAead, offset, plaintextSize, 64);
-    StreamingTestUtil.testSkipWithStream(streamingAead, offset, plaintextSize, 300);
+  public void validateKeyFormat_empty_throws() throws Exception {
+    try {
+      factory.validateKeyFormat(AesGcmHkdfStreamingKeyFormat.getDefaultInstance());
+      fail();
+    } catch (GeneralSecurityException e) {
+      // expected
+    }
   }
 
   @Test
-  public void testNewKeyMultipleTimes() throws Exception {
-    AesGcmHkdfStreamingKeyFormat keyFormat =
-        AesGcmHkdfStreamingKeyFormat.newBuilder().setParams(keyParams).setKeySize(16).build();
-    ByteString serializedKeyFormat = ByteString.copyFrom(keyFormat.toByteArray());
-    Set<String> keys = new TreeSet<String>();
-    // Calls newKey multiple times and make sure that they generate different keys.
-    int numTests = 27;
-    for (int i = 0; i < numTests / 3; i++) {
-      AesGcmHkdfStreamingKey key = (AesGcmHkdfStreamingKey) keyManager.newKey(keyFormat);
-      keys.add(TestUtil.hexEncode(key.getKeyValue().toByteArray()));
-      assertEquals(16, key.getKeyValue().toByteArray().length);
-
-      key = (AesGcmHkdfStreamingKey) keyManager.newKey(serializedKeyFormat);
-      keys.add(TestUtil.hexEncode(key.getKeyValue().toByteArray()));
-      assertEquals(16, key.getKeyValue().toByteArray().length);
-
-      KeyData keyData = keyManager.newKeyData(serializedKeyFormat);
-      key = AesGcmHkdfStreamingKey.parseFrom(keyData.getValue());
-      keys.add(TestUtil.hexEncode(key.getKeyValue().toByteArray()));
-      assertEquals(16, key.getKeyValue().toByteArray().length);
-    }
-    assertEquals(numTests, keys.size());
+  public void validateKeyFormat_valid() throws Exception {
+    factory.validateKeyFormat(createKeyFormat(32, 32, HashType.SHA256, 1024));
   }
 
   @Test
-  public void testNewKeyWithBadFormat() throws Exception {
-    // key_size too small.
-    AesGcmHkdfStreamingKeyFormat keyFormat =
-        AesGcmHkdfStreamingKeyFormat.newBuilder().setParams(keyParams).setKeySize(15).build();
-    ByteString serializedKeyFormat = ByteString.copyFrom(keyFormat.toByteArray());
+  public void validateKeyFormat_unkownHash_throws() throws Exception {
     try {
-      keyManager.newKey(keyFormat);
-      fail("Bad format, should have thrown exception");
-    } catch (GeneralSecurityException expected) {
-      // Expected
-    }
-    try {
-      keyManager.newKeyData(serializedKeyFormat);
-      fail("Bad format, should have thrown exception");
-    } catch (GeneralSecurityException expected) {
-      // Expected
+      factory.validateKeyFormat(createKeyFormat(32, 32, HashType.UNKNOWN_HASH, 1024));
+      fail();
+    } catch (GeneralSecurityException e) {
+      // expected
     }
+  }
 
-    // Unknown HKDF HashType.
-    AesGcmHkdfStreamingParams badKeyParams =
-        AesGcmHkdfStreamingParams.newBuilder()
-            .setCiphertextSegmentSize(128)
-            .setDerivedKeySize(AES_KEY_SIZE)
-            .build();
-    keyFormat =
-        AesGcmHkdfStreamingKeyFormat.newBuilder().setParams(badKeyParams).setKeySize(16).build();
-    serializedKeyFormat = ByteString.copyFrom(keyFormat.toByteArray());
-    try {
-      keyManager.newKey(keyFormat);
-      fail("Bad format, should have thrown exception");
-    } catch (GeneralSecurityException expected) {
-      // Expected
-    }
+  @Test
+  public void validateKeyFormat_smallKey_throws() throws Exception {
     try {
-      keyManager.newKeyData(serializedKeyFormat);
-      fail("Bad format, should have thrown exception");
-    } catch (GeneralSecurityException expected) {
-      // Expected
+      // TODO(b/140161847): Also check (16,32,SHA256,1024)
+      factory.validateKeyFormat(createKeyFormat(15, 32, HashType.SHA256, 1024));
+      fail();
+    } catch (GeneralSecurityException e) {
+      // expected
     }
+  }
 
-    // derived_key_size too small.
-    badKeyParams =
-        AesGcmHkdfStreamingParams.newBuilder()
-            .setCiphertextSegmentSize(128)
-            .setDerivedKeySize(10)
-            .setHkdfHashType(HashType.SHA256)
-            .build();
-    keyFormat =
-        AesGcmHkdfStreamingKeyFormat.newBuilder().setParams(badKeyParams).setKeySize(16).build();
-    serializedKeyFormat = ByteString.copyFrom(keyFormat.toByteArray());
-    try {
-      keyManager.newKey(keyFormat);
-      fail("Bad format, should have thrown exception");
-    } catch (GeneralSecurityException expected) {
-      // Expected
-    }
+  @Test
+  public void validateKeyFormat_smallSegment_throws() throws Exception {
     try {
-      keyManager.newKeyData(serializedKeyFormat);
-      fail("Bad format, should have thrown exception");
-    } catch (GeneralSecurityException expected) {
-      // Expected
+      factory.validateKeyFormat(createKeyFormat(16, 32, HashType.SHA256, 45));
+      fail();
+    } catch (GeneralSecurityException e) {
+      // expected
     }
+  }
 
-    // ciphertext_segment_size too small.
-    badKeyParams =
-        AesGcmHkdfStreamingParams.newBuilder()
-            .setCiphertextSegmentSize(15)
-            .setDerivedKeySize(AES_KEY_SIZE)
-            .setHkdfHashType(HashType.SHA256)
-            .build();
-    keyFormat =
-        AesGcmHkdfStreamingKeyFormat.newBuilder().setParams(badKeyParams).setKeySize(16).build();
-    serializedKeyFormat = ByteString.copyFrom(keyFormat.toByteArray());
-    try {
-      keyManager.newKey(keyFormat);
-      fail("Bad format, should have thrown exception");
-    } catch (GeneralSecurityException expected) {
-      // Expected
-    }
-    try {
-      keyManager.newKeyData(serializedKeyFormat);
-      fail("Bad format, should have thrown exception");
-    } catch (GeneralSecurityException expected) {
-      // Expected
-    }
+  @Test
+  public void createKey_checkValues() throws Exception {
+    AesGcmHkdfStreamingKeyFormat format = createKeyFormat(32, 32, HashType.SHA256, 1024);
+
+    AesGcmHkdfStreamingKey key = factory.createKey(format);
+
+    assertThat(key.getParams()).isEqualTo(format.getParams());
+    assertThat(key.getVersion()).isEqualTo(0);
+    assertThat(key.getKeyValue()).hasSize(format.getKeySize());
+  }
+
+  @Test
+  public void testSkip() throws Exception {
+    AesGcmHkdfStreamingKey key = factory.createKey(createKeyFormat(32, 32, HashType.SHA256, 1024));
+    StreamingAead streamingAead = manager.getPrimitive(key, StreamingAead.class);
+    int offset = 0;
+    int plaintextSize = 1 << 16;
+    // Runs the test with different sizes for the chunks to skip.
+    StreamingTestUtil.testSkipWithStream(streamingAead, offset, plaintextSize, 1);
+    StreamingTestUtil.testSkipWithStream(streamingAead, offset, plaintextSize, 64);
+    StreamingTestUtil.testSkipWithStream(streamingAead, offset, plaintextSize, 300);
+  }
 
-    // All params good.
-    AesGcmHkdfStreamingParams goodKeyParams =
-        AesGcmHkdfStreamingParams.newBuilder()
-            .setCiphertextSegmentSize(130)
-            .setDerivedKeySize(AES_KEY_SIZE)
-            .setHkdfHashType(HashType.SHA256)
-            .build();
-    keyFormat =
-        AesGcmHkdfStreamingKeyFormat.newBuilder().setParams(goodKeyParams).setKeySize(16).build();
-    serializedKeyFormat = ByteString.copyFrom(keyFormat.toByteArray());
-    AesGcmHkdfStreamingKey unusedKey = (AesGcmHkdfStreamingKey) keyManager.newKey(keyFormat);
-    unusedKey = (AesGcmHkdfStreamingKey) keyManager.newKey(serializedKeyFormat);
+  @Test
+  public void testNewKeyMultipleTimes() throws Exception {
+    AesGcmHkdfStreamingKeyFormat keyFormat = createKeyFormat(32, 32, HashType.SHA256, 1024);
+    Set<String> keys = new TreeSet<>();
+    // Calls newKey multiple times and make sure that they generate different keys.
+    int numTests = 100;
+    for (int i = 0; i < numTests; i++) {
+      keys.add(TestUtil.hexEncode(factory.createKey(keyFormat).getKeyValue().toByteArray()));
+    }
+    assertThat(keys).hasSize(numTests);
   }
 }