From cd25d32f30697b270fe151a485cb709747817c40 Mon Sep 17 00:00:00 2001
From: candrian <candrian@google.com>
Date: Thu, 21 Mar 2019 10:18:44 -0700
Subject: [PATCH] cc: Extra sanity checks on setPrimitive.

PiperOrigin-RevId: 239620317
GitOrigin-RevId: 4e5c8b8d8c88fbb97bbc7956b74240209e677bb1
---
 cc/keyset_handle.h    | 11 ++++---
 cc/primitive_set.h    | 34 ++++++++++++++------
 cc/util/validation.cc | 75 ++++++++++++++++++++++++++++++++++++++++---
 cc/util/validation.h  |  3 ++
 4 files changed, 104 insertions(+), 19 deletions(-)

diff --git a/cc/keyset_handle.h b/cc/keyset_handle.h
index 039201156..90a1e0da4 100644
--- a/cc/keyset_handle.h
+++ b/cc/keyset_handle.h
@@ -52,7 +52,7 @@ class KeysetHandle {
   // and writes the resulting EncryptedKeyset to the given |writer|,
   // which must be non-null.
   crypto::tink::util::Status Write(KeysetWriter* writer,
-      const Aead& master_key_aead);
+                                   const Aead& master_key_aead);
 
   // Writes the underlying keyset to |writer| only if the keyset does not
   // contain any secret key material.
@@ -114,8 +114,8 @@ class KeysetHandle {
   // The returned set is usually later "wrapped" into a class that
   // implements the corresponding Primitive-interface.
   template <class P>
-  crypto::tink::util::StatusOr<std::unique_ptr<PrimitiveSet<P>>>
-      GetPrimitives(const KeyManager<P>* custom_manager) const;
+  crypto::tink::util::StatusOr<std::unique_ptr<PrimitiveSet<P>>> GetPrimitives(
+      const KeyManager<P>* custom_manager) const;
 
   google::crypto::tink::Keyset keyset_;
 };
@@ -145,7 +145,9 @@ KeysetHandle::GetPrimitives(const KeyManager<P>* custom_manager) const {
       auto entry_result = primitives->AddPrimitive(std::move(primitive), key);
       if (!entry_result.ok()) return entry_result.status();
       if (key.key_id() == get_keyset().primary_key_id()) {
-        primitives->set_primary(entry_result.ValueOrDie());
+        auto primary_result =
+            primitives->set_primary(entry_result.ValueOrDie());
+        if (!primary_result.ok()) return primary_result;
       }
     }
   }
@@ -176,7 +178,6 @@ crypto::tink::util::StatusOr<std::unique_ptr<P>> KeysetHandle::GetPrimitive(
   return Registry::Wrap<P>(std::move(primitives_result.ValueOrDie()));
 }
 
-
 }  // namespace tink
 }  // namespace crypto
 
diff --git a/cc/primitive_set.h b/cc/primitive_set.h
index 15b654a93..e938179e6 100644
--- a/cc/primitive_set.h
+++ b/cc/primitive_set.h
@@ -67,12 +67,9 @@ class PrimitiveSet {
 
     const std::string& get_identifier() const { return identifier_; }
 
-    google::crypto::tink::KeyStatusType get_status() const {
-      return status_;
-    }
+    google::crypto::tink::KeyStatusType get_status() const { return status_; }
 
-    google::crypto::tink::OutputPrefixType get_output_prefix_type()
-        const {
+    google::crypto::tink::OutputPrefixType get_output_prefix_type() const {
       return output_prefix_type_;
     }
 
@@ -104,9 +101,8 @@ class PrimitiveSet {
     std::string identifier = identifier_result.ValueOrDie();
     absl::MutexLock lock(&primitives_mutex_);
     primitives_[identifier].push_back(
-        absl::make_unique<Entry<P>>(std::move(primitive),
-                                    identifier, key.status(),
-                                    key.output_prefix_type()));
+        absl::make_unique<Entry<P>>(std::move(primitive), identifier,
+                                    key.status(), key.output_prefix_type()));
     return primitives_[identifier].back().get();
   }
 
@@ -129,8 +125,26 @@ class PrimitiveSet {
     return get_primitives(CryptoFormat::kRawPrefix);
   }
 
-  // Sets the given 'primary' as as the primary primitive of this set.
-  void set_primary(Entry<P>* primary) { primary_ = primary; }
+  // Sets the given 'primary' as the primary primitive of this set.
+  crypto::tink::util::Status set_primary(Entry<P>* primary) {
+    if (!primary) {
+      return ToStatusF(crypto::tink::util::error::INVALID_ARGUMENT,
+                       "The primary primitive must be non-null.");
+    }
+    if (primary->get_status() != google::crypto::tink::KeyStatusType::ENABLED) {
+      return ToStatusF(crypto::tink::util::error::INVALID_ARGUMENT,
+                       "Primary has to be enabled.");
+    }
+    auto entries_result = get_primitives(primary->get_identifier());
+    if (!entries_result.ok()) {
+      return ToStatusF(crypto::tink::util::error::INVALID_ARGUMENT,
+                       "Primary cannot be set to an entry which is "
+                       "not held by this primitive set.");
+    }
+
+    primary_ = primary;
+    return crypto::tink::util::Status::OK;
+  }
 
   // Returns the entry with the primary primitive.
   const Entry<P>* get_primary() const { return primary_; }
diff --git a/cc/util/validation.cc b/cc/util/validation.cc
index 96a9bef1e..9906c18b2 100644
--- a/cc/util/validation.cc
+++ b/cc/util/validation.cc
@@ -20,12 +20,13 @@
 #include "tink/util/status.h"
 #include "proto/tink.pb.h"
 
+using google::crypto::tink::KeyData;
+using google::crypto::tink::Keyset;
+using google::crypto::tink::KeyStatusType;
 
 namespace crypto {
 namespace tink {
 
-// TODO(przydatek): add more validation checks
-
 util::Status ValidateAesKeySize(uint32_t key_size) {
   if (key_size != 16 && key_size != 32) {
     return ToStatusF(util::error::INVALID_ARGUMENT,
@@ -35,11 +36,78 @@ util::Status ValidateAesKeySize(uint32_t key_size) {
   return util::Status::OK;
 }
 
-util::Status ValidateKeyset(const google::crypto::tink::Keyset& keyset) {
+util::Status ValidateKey(const Keyset::Key& key) {
+  if (!key.has_key_data()) {
+    return ToStatusF(util::error::INVALID_ARGUMENT, "key %d, has no key data",
+                     key.key_id());
+  }
+
+  if (key.output_prefix_type() ==
+      google::crypto::tink::OutputPrefixType::UNKNOWN_PREFIX) {
+    return ToStatusF(util::error::INVALID_ARGUMENT, "key %d has unknown prefix",
+                     key.key_id());
+  }
+
+  if (key.status() == google::crypto::tink::KeyStatusType::UNKNOWN_STATUS) {
+    return ToStatusF(util::error::INVALID_ARGUMENT, "key %d has unknown status",
+                     key.key_id());
+  }
+  return util::Status::OK;
+}
+
+util::Status ValidateKeyset(const Keyset& keyset) {
   if (keyset.key_size() < 1) {
     return ToStatusF(util::error::INVALID_ARGUMENT,
                      "A valid keyset must contain at least one key.");
   }
+
+  int primary_key_id = keyset.primary_key_id();
+  bool has_primary_key = false;
+  bool contains_only_public_key_material = true;
+  int enabled_keys = 0;
+
+  for (int i = 0; i < keyset.key_size(); i++) {
+    const Keyset::Key& key = keyset.key(i);
+
+
+    if (key.status() != KeyStatusType::ENABLED) {
+      continue;
+    }
+    enabled_keys += 1;
+
+    auto validation_result = ValidateKey(key);
+    if (!validation_result.ok()) {
+      return validation_result;
+    }
+
+    if (key.status() == KeyStatusType::ENABLED &&
+        key.key_id() == primary_key_id) {
+      if (has_primary_key) {
+        return ToStatusF(util::error::INVALID_ARGUMENT,
+                         "keyset contains multiple primary keys");
+      }
+      has_primary_key = true;
+    }
+
+    if (key.key_data().key_material_type() !=
+        KeyData::KeyMaterialType::KeyData_KeyMaterialType_ASYMMETRIC_PUBLIC) {
+      contains_only_public_key_material = false;
+    }
+  }
+
+  if (enabled_keys == 0) {
+    return ToStatusF(util::error::INVALID_ARGUMENT,
+                     "keyset must contain at least one ENABLED key");
+  }
+
+  // A public key can be used for verification without being set as the primary
+  // key. Therefore, it is okay to have a keyset that contains public but
+  // doesn't have a primary key set.
+  if (!has_primary_key && !contains_only_public_key_material) {
+    return ToStatusF(util::error::INVALID_ARGUMENT,
+                     "keyset doesn't contain a valid primary key");
+  }
+
   return util::Status::OK;
 }
 
@@ -53,6 +121,5 @@ util::Status ValidateVersion(uint32_t candidate, uint32_t max_expected) {
   return util::Status::OK;
 }
 
-
 }  // namespace tink
 }  // namespace crypto
diff --git a/cc/util/validation.h b/cc/util/validation.h
index a6e360550..4580723b3 100644
--- a/cc/util/validation.h
+++ b/cc/util/validation.h
@@ -27,6 +27,9 @@ namespace tink {
 
 crypto::tink::util::Status ValidateAesKeySize(uint32_t key_size);
 
+crypto::tink::util::Status ValidateKey(
+    const google::crypto::tink::Keyset::Key& key);
+
 crypto::tink::util::Status ValidateKeyset(
     const google::crypto::tink::Keyset& keyset);
 
-- 
GitLab