From 4c865a5df09e0b874545e42c8133d52f749db1cb Mon Sep 17 00:00:00 2001
From: przydatek <przydatek@google.com>
Date: Tue, 20 Aug 2019 13:36:04 -0700
Subject: [PATCH] Removing const from RandomAccessStream::size(), and changing
 the return value to StatusOr<int64_t>, to allow richer error reporting.

PiperOrigin-RevId: 264455697
---
 cc/BUILD.bazel                                |  1 +
 cc/CMakeLists.txt                             |  1 +
 cc/random_access_stream.h                     | 11 ++--
 cc/streamingaead/BUILD.bazel                  |  1 +
 cc/streamingaead/CMakeLists.txt               |  1 +
 .../decrypting_random_access_stream.cc        |  5 +-
 .../decrypting_random_access_stream.h         |  2 +-
 .../decrypting_random_access_stream_test.cc   |  5 +-
 .../shared_random_access_stream.h             |  3 +-
 .../shared_random_access_stream_test.cc       |  2 +-
 cc/util/BUILD.bazel                           |  3 +
 cc/util/CMakeLists.txt                        |  3 +
 cc/util/file_random_access_stream.cc          |  6 +-
 cc/util/file_random_access_stream.h           |  3 +-
 cc/util/file_random_access_stream_test.cc     |  4 +-
 cc/util/test_util.h                           | 63 +++++++++++--------
 16 files changed, 71 insertions(+), 43 deletions(-)

diff --git a/cc/BUILD.bazel b/cc/BUILD.bazel
index 4323bd464..a13346625 100644
--- a/cc/BUILD.bazel
+++ b/cc/BUILD.bazel
@@ -155,6 +155,7 @@ cc_library(
     deps = [
         "//cc/util:buffer",
         "//cc/util:status",
+        "//cc/util:statusor",
     ],
 )
 
diff --git a/cc/CMakeLists.txt b/cc/CMakeLists.txt
index bca0c6b6a..35d36ec13 100644
--- a/cc/CMakeLists.txt
+++ b/cc/CMakeLists.txt
@@ -152,6 +152,7 @@ tink_cc_library(
   DEPS
     tink::util::buffer
     tink::util::status
+    tink::util::statusor
 )
 
 tink_cc_library(
diff --git a/cc/random_access_stream.h b/cc/random_access_stream.h
index bd2cc2117..b40445ad2 100644
--- a/cc/random_access_stream.h
+++ b/cc/random_access_stream.h
@@ -19,6 +19,7 @@
 
 #include "tink/util/buffer.h"
 #include "tink/util/status.h"
+#include "tink/util/statusor.h"
 
 namespace crypto {
 namespace tink {
@@ -53,12 +54,12 @@ class RandomAccessStream {
       crypto::tink::util::Buffer* dest_buffer) = 0;
 
   // Returns the size of this stream in bytes, if available.
-  // It is the "logical" size of a stream (i.e. of a sequence of bytes),
-  // stating how many bytes are there in the sequence.
+  // If the size is not available, returns a non-Ok status.
+  // The returned value is the "logical" size of a stream, i.e. of
+  // a sequence of bytes), stating how many bytes are there in the sequence.
   // For a successful PRead-operation the starting position should be
-  // in the range 0..size()-1 (otherwise a non-Ok status is returned).
-  // If the size is not available, returns -1;
-  virtual int64_t size() const = 0;
+  // in the range 0..size()-1 (otherwise PRead may return a non-Ok status).
+  virtual crypto::tink::util::StatusOr<int64_t> size() = 0;
 };
 
 }  // namespace tink
diff --git a/cc/streamingaead/BUILD.bazel b/cc/streamingaead/BUILD.bazel
index f8884a928..2e559f1a0 100644
--- a/cc/streamingaead/BUILD.bazel
+++ b/cc/streamingaead/BUILD.bazel
@@ -149,6 +149,7 @@ cc_library(
         "//cc:random_access_stream",
         "//cc/util:buffer",
         "//cc/util:status",
+        "//cc/util:statusor",
     ],
 )
 
diff --git a/cc/streamingaead/CMakeLists.txt b/cc/streamingaead/CMakeLists.txt
index 556cfc757..d87a781e6 100644
--- a/cc/streamingaead/CMakeLists.txt
+++ b/cc/streamingaead/CMakeLists.txt
@@ -131,6 +131,7 @@ tink_cc_library(
     tink::core::random_access_stream
     tink::util::buffer
     tink::util::status
+    tink::util::statusor
 )
 
 tink_cc_library(
diff --git a/cc/streamingaead/decrypting_random_access_stream.cc b/cc/streamingaead/decrypting_random_access_stream.cc
index 3b2e052f5..055c07e76 100644
--- a/cc/streamingaead/decrypting_random_access_stream.cc
+++ b/cc/streamingaead/decrypting_random_access_stream.cc
@@ -120,12 +120,13 @@ util::Status DecryptingRandomAccessStream::PRead(
                 "Could not find a decrypter matching the ciphertext stream.");
 }
 
-int64_t DecryptingRandomAccessStream::size() const {
+StatusOr<int64_t> DecryptingRandomAccessStream::size() {
   absl::ReaderMutexLock lock(&matching_mutex_);
   if (matching_stream_ != nullptr) {
     return matching_stream_->size();
   }
-  return -1;
+  // TODO(b/139722894): attempt matching here?
+  return Status(util::error::UNAVAILABLE, "no matching found yet");
 }
 
 }  // namespace streamingaead
diff --git a/cc/streamingaead/decrypting_random_access_stream.h b/cc/streamingaead/decrypting_random_access_stream.h
index 271b3961e..7471d5b05 100644
--- a/cc/streamingaead/decrypting_random_access_stream.h
+++ b/cc/streamingaead/decrypting_random_access_stream.h
@@ -51,7 +51,7 @@ class DecryptingRandomAccessStream : public crypto::tink::RandomAccessStream {
   ~DecryptingRandomAccessStream() override {}
   crypto::tink::util::Status PRead(int64_t position, int count,
       crypto::tink::util::Buffer* dest_buffer) override;
-  int64_t size() const override;
+  crypto::tink::util::StatusOr<int64_t> size() override;
 
  private:
   DecryptingRandomAccessStream(
diff --git a/cc/streamingaead/decrypting_random_access_stream_test.cc b/cc/streamingaead/decrypting_random_access_stream_test.cc
index 8cbb03aff..02b78f0a0 100644
--- a/cc/streamingaead/decrypting_random_access_stream_test.cc
+++ b/cc/streamingaead/decrypting_random_access_stream_test.cc
@@ -168,11 +168,12 @@ TEST(DecryptingRandomAccessStreamTest, BasicDecryption) {
         auto dec_stream_result =
             DecryptingRandomAccessStream::New(saead_set, std::move(ct), aad);
         EXPECT_THAT(dec_stream_result.status(), IsOk());
+        auto dec_stream = std::move(dec_stream_result.ValueOrDie());
         std::string decrypted;
-        auto status = ReadAll(dec_stream_result.ValueOrDie().get(),
-                              &decrypted);
+        auto status = ReadAll(dec_stream.get(), &decrypted);
         EXPECT_THAT(status, StatusIs(util::error::OUT_OF_RANGE,
                                      HasSubstr("EOF")));
+        EXPECT_EQ(pt_size, dec_stream->size().ValueOrDie());
         EXPECT_EQ(plaintext, decrypted);
       }
     }
diff --git a/cc/streamingaead/shared_random_access_stream.h b/cc/streamingaead/shared_random_access_stream.h
index 06aab3e3f..9e8a0a4ff 100644
--- a/cc/streamingaead/shared_random_access_stream.h
+++ b/cc/streamingaead/shared_random_access_stream.h
@@ -20,6 +20,7 @@
 #include "tink/random_access_stream.h"
 #include "tink/util/buffer.h"
 #include "tink/util/status.h"
+#include "tink/util/statusor.h"
 
 namespace crypto {
 namespace tink {
@@ -45,7 +46,7 @@ class SharedRandomAccessStream : public crypto::tink::RandomAccessStream {
     return random_access_stream_->PRead(position, count, dest_buffer);
   }
 
-  int64_t size() const override {
+  crypto::tink::util::StatusOr<int64_t> size() override {
     return random_access_stream_->size();
   }
 
diff --git a/cc/streamingaead/shared_random_access_stream_test.cc b/cc/streamingaead/shared_random_access_stream_test.cc
index 88d930acf..1aa473e51 100644
--- a/cc/streamingaead/shared_random_access_stream_test.cc
+++ b/cc/streamingaead/shared_random_access_stream_test.cc
@@ -66,7 +66,7 @@ TEST(SharedRandomAccessStreamTest, ReadingStreams) {
     EXPECT_EQ(util::error::OUT_OF_RANGE, status.error_code());
     EXPECT_EQ("EOF", status.error_message());
     EXPECT_EQ(file_contents, stream_contents);
-    EXPECT_EQ(stream_size, shared_stream.size());
+    EXPECT_EQ(stream_size, shared_stream.size().ValueOrDie());
   }
 }
 
diff --git a/cc/util/BUILD.bazel b/cc/util/BUILD.bazel
index da28ab80e..ce472c213 100644
--- a/cc/util/BUILD.bazel
+++ b/cc/util/BUILD.bazel
@@ -130,6 +130,7 @@ cc_library(
         ":buffer",
         ":errors",
         ":status",
+        ":statusor",
         "//cc:random_access_stream",
         "@com_google_absl//absl/memory",
     ],
@@ -202,8 +203,10 @@ cc_library(
         "//proto:ecies_aead_hkdf_cc_proto",
         "//proto:ed25519_cc_proto",
         "//proto:tink_cc_proto",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
diff --git a/cc/util/CMakeLists.txt b/cc/util/CMakeLists.txt
index e9badf7ae..abd44e48d 100644
--- a/cc/util/CMakeLists.txt
+++ b/cc/util/CMakeLists.txt
@@ -107,6 +107,7 @@ tink_cc_library(
     tink::util::buffer
     tink::util::errors
     tink::util::status
+    tink::util::statusor
     tink::core::random_access_stream
     absl::memory
 )
@@ -172,8 +173,10 @@ tink_cc_library(
     tink::proto::ed25519_cc_proto
     tink::proto::tink_cc_proto
     tink::util::buffer
+    absl::core_headers
     absl::memory
     absl::strings
+    absl::synchronization
 )
 
 tink_cc_library(
diff --git a/cc/util/file_random_access_stream.cc b/cc/util/file_random_access_stream.cc
index c6db78f68..f0ddcc6c9 100644
--- a/cc/util/file_random_access_stream.cc
+++ b/cc/util/file_random_access_stream.cc
@@ -25,12 +25,14 @@
 #include "tink/util/buffer.h"
 #include "tink/util/errors.h"
 #include "tink/util/status.h"
+#include "tink/util/statusor.h"
 
 namespace crypto {
 namespace tink {
 namespace util {
 
 using crypto::tink::util::Status;
+using crypto::tink::util::StatusOr;
 
 namespace {
 
@@ -91,10 +93,10 @@ FileRandomAccessStream::~FileRandomAccessStream() {
   close_ignoring_eintr(fd_);
 }
 
-int64_t FileRandomAccessStream::size() const {
+StatusOr<int64_t> FileRandomAccessStream::size() {
   struct stat s;
   if (fstat(fd_, &s) == -1) {
-    return -1;
+    return Status(util::error::UNAVAILABLE, "size unavailable");
   } else {
     return s.st_size;
   }
diff --git a/cc/util/file_random_access_stream.h b/cc/util/file_random_access_stream.h
index 33aa3b4ce..47bfc13d4 100644
--- a/cc/util/file_random_access_stream.h
+++ b/cc/util/file_random_access_stream.h
@@ -22,6 +22,7 @@
 #include "tink/random_access_stream.h"
 #include "tink/util/buffer.h"
 #include "tink/util/status.h"
+#include "tink/util/statusor.h"
 
 namespace crypto {
 namespace tink {
@@ -41,7 +42,7 @@ class FileRandomAccessStream : public crypto::tink::RandomAccessStream {
                                    int count,
                                    Buffer* dest_buffer) override;
 
-  int64_t size() const override;
+  crypto::tink::util::StatusOr<int64_t> size() override;
 
  private:
   int fd_;
diff --git a/cc/util/file_random_access_stream_test.cc b/cc/util/file_random_access_stream_test.cc
index 365a014eb..9fc592433 100644
--- a/cc/util/file_random_access_stream_test.cc
+++ b/cc/util/file_random_access_stream_test.cc
@@ -60,7 +60,7 @@ void ReadAndVerifyChunk(RandomAccessStream* ra_stream,
                             ", position = ", position,
                             ", count = ", count));
   auto buffer = std::move(Buffer::New(count).ValueOrDie());
-  int stream_size = ra_stream->size();
+  int stream_size = ra_stream->size().ValueOrDie();
   EXPECT_EQ(file_contents.size(), stream_size);
   auto status = ra_stream->PRead(position, count, buffer.get());
   EXPECT_TRUE(status.ok());
@@ -89,7 +89,7 @@ TEST(FileRandomAccessStreamTest, ReadingStreams) {
     EXPECT_EQ(util::error::OUT_OF_RANGE, status.error_code());
     EXPECT_EQ("EOF", status.error_message());
     EXPECT_EQ(file_contents, stream_contents);
-    EXPECT_EQ(stream_size, ra_stream->size());
+    EXPECT_EQ(stream_size, ra_stream->size().ValueOrDie());
   }
 }
 
diff --git a/cc/util/test_util.h b/cc/util/test_util.h
index 13ec54f34..a48fc8134 100644
--- a/cc/util/test_util.h
+++ b/cc/util/test_util.h
@@ -20,9 +20,11 @@
 #include <limits>
 #include <string>
 
+#include "absl/base/thread_annotations.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "absl/synchronization/mutex.h"
 #include "tink/aead.h"
 #include "tink/deterministic_aead.h"
 #include "tink/hybrid_decrypt.h"
@@ -389,52 +391,61 @@ class DummyStreamingAead : public StreamingAead {
         std::unique_ptr<crypto::tink::RandomAccessStream> ct_source,
         absl::string_view expected_header)
         : ct_source_(std::move(ct_source)), exp_header_(expected_header),
-          after_init_(false), status_(util::OkStatus()) {}
+          status_(util::Status(util::error::UNAVAILABLE, "not initialized")) {}
 
     crypto::tink::util::Status PRead(
         int64_t position, int count,
         crypto::tink::util::Buffer* dest_buffer) override {
-      if (!after_init_) {  // Try to initialize.
-        after_init_ = true;
-        status_ = dest_buffer->set_size(0);
+      {  // Initialize, if not initialized yet.
+        absl::MutexLock lock(&status_mutex_);
+        if (status_.error_code() == util::error::UNAVAILABLE) Initialize();
         if (!status_.ok()) return status_;
-        auto buf = std::move(
-            util::Buffer::New(exp_header_.size()).ValueOrDie());
-        status_ = ct_source_->PRead(0, exp_header_.size(), buf.get());
-        if (!status_.ok() &&
-            status_.error_code() != util::error::OUT_OF_RANGE) return status_;
-        if (buf->size() < exp_header_.size()) {
-          status_ = util::Status(
-              util::error::INVALID_ARGUMENT, "Could not read header");
-        } else if (memcmp(buf->get_mem_block(), exp_header_.data(),
-                          static_cast<int>(exp_header_.size()))) {
-          status_ = util::Status(
-              util::error::INVALID_ARGUMENT, "Corrupted header");
-        }
       }
-      if (!status_.ok()) return status_;
+      auto status = dest_buffer->set_size(0);
+      if (!status.ok()) return status;
       return ct_source_->PRead(
           position + exp_header_.size(), count, dest_buffer);
     }
 
-    int64_t size() const override {
-      if (after_init_ && status_.ok()) {
-        auto pt_size = ct_source_->size() - exp_header_.size();
-        if (pt_size >= 0) return pt_size;
+    util::StatusOr<int64_t> size() override {
+      {  // Initialize, if not initialized yet.
+        absl::MutexLock lock(&status_mutex_);
+        if (status_.error_code() == util::error::UNAVAILABLE) Initialize();
+        if (!status_.ok()) return status_;
       }
-      return -1;
+      auto ct_size_result = ct_source_->size();
+      if (!ct_size_result.ok()) return ct_size_result.status();
+      auto pt_size = ct_size_result.ValueOrDie() - exp_header_.size();
+      if (pt_size >= 0) return pt_size;
+      return util::Status(util::error::UNAVAILABLE, "size not available");
     }
 
    private:
+    void Initialize() EXCLUSIVE_LOCKS_REQUIRED(status_mutex_) {
+      auto buf = std::move(
+          util::Buffer::New(exp_header_.size()).ValueOrDie());
+      status_ = ct_source_->PRead(0, exp_header_.size(), buf.get());
+      if (!status_.ok() &&
+          status_.error_code() != util::error::OUT_OF_RANGE) return;
+      if (buf->size() < exp_header_.size()) {
+        status_ = util::Status(
+            util::error::INVALID_ARGUMENT, "Could not read header");
+      } else if (memcmp(buf->get_mem_block(), exp_header_.data(),
+                        static_cast<int>(exp_header_.size()))) {
+        status_ = util::Status(
+            util::error::INVALID_ARGUMENT, "Corrupted header");
+      }
+    }
+
     std::unique_ptr<crypto::tink::RandomAccessStream> ct_source_;
     std::string exp_header_;
-    bool after_init_;
-    util::Status status_;
+    mutable absl::Mutex status_mutex_;
+    util::Status status_ GUARDED_BY(status_mutex_);
   };  // class DummyDecryptingRandomAccessStream
 
  private:
   std::string streaming_aead_name_;
-};
+};  // class DummyStreamingAead
 
 // A dummy implementation of HybridEncrypt-interface.
 // An instance of DummyHybridEncrypt can be identified by a name specified
-- 
GitLab