Newer
Older
# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module defines KeysetHandle."""
from __future__ import absolute_import
from __future__ import division
from __future__ import google_type_annotations
from __future__ import print_function

import random
from typing import Type, TypeVar

from google.protobuf import message

from tink.proto import tink_pb2
from tink.python.aead import aead
from tink.python.core import keyset_reader as reader
from tink.python.core import keyset_writer as writer
from tink.python.core import primitive_set
from tink.python.core import registry
from tink.python.core import tink_error
# Type variable tying KeysetHandle.primitive()'s return type to the class
# passed in as primitive_class.
P = TypeVar('P')
# Largest signed 32-bit integer (2^31 - 1); inclusive upper bound used when
# drawing random key ids.
MAX_INT32 = (1 << 31) - 1
class KeysetHandle(object):
  """A KeysetHandle provides abstracted access to Keyset.

  KeysetHandle limits the exposure of actual protocol buffers that hold
  sensitive key material. This class allows reading and writing encrypted
  keysets.

  Calling KeysetHandle() directly raises tink_error.TinkError; handles are
  obtained from generate_new(), read() or public_keyset_handle().
  """

  def __new__(cls):
    raise tink_error.TinkError(
        ('KeysetHandle cannot be instantiated directly.'))

  def __init__(self, keyset: tink_pb2.Keyset):
    # Reached only through _create(), which bypasses the raising __new__.
    self._keyset = keyset

  @classmethod
  def generate_new(cls, key_template: tink_pb2.KeyTemplate) -> 'KeysetHandle':
    """Return a new KeysetHandle.

    It contains a single fresh key generated according to key_template. The
    key is ENABLED, uses the template's output prefix type, and is the
    keyset's primary key.

    Args:
      key_template: A tink_pb2.KeyTemplate object.

    Returns:
      A new KeysetHandle.
    """
    keyset = tink_pb2.Keyset()
    key_data = registry.Registry.new_key_data(key_template)
    key_id = _generate_unused_key_id(keyset)
    key = keyset.key.add()
    key.key_data.CopyFrom(key_data)
    key.status = tink_pb2.ENABLED
    key.key_id = key_id
    key.output_prefix_type = key_template.output_prefix_type
    keyset.primary_key_id = key_id
    return cls._create(keyset)

  @classmethod
  def read(cls, keyset_reader: reader.KeysetReader,
           master_key_aead: aead.Aead) -> 'KeysetHandle':
    """Tries to create a KeysetHandle from an encrypted keyset.

    Args:
      keyset_reader: Supplies the encrypted keyset via read_encrypted().
      master_key_aead: Aead used to decrypt the keyset (with empty associated
        data).

    Returns:
      A KeysetHandle wrapping the decrypted keyset.

    Raises:
      tink.TinkError: e.g. if the encrypted keyset is empty, does not decrypt
        to a parseable Keyset, or decrypts to a Keyset with no keys.
    """
    encrypted_keyset = keyset_reader.read_encrypted()
    _assert_enough_encrypted_key_material(encrypted_keyset)
    return cls._create(_decrypt(encrypted_keyset, master_key_aead))

  @classmethod
  def _create(cls, keyset: tink_pb2.Keyset):
    """Builds an instance around keyset, bypassing the raising __new__."""
    o = object.__new__(cls)
    o.__init__(keyset)
    return o

  def keyset_info(self) -> tink_pb2.KeysetInfo:
    """Returns the KeysetInfo that doesn't contain actual key material."""
    return _keyset_info(self._keyset)

  def write(self, keyset_writer: writer.KeysetWriter,
            master_key_primitive: aead.Aead) -> None:
    """Serializes, encrypts with master_key_primitive and writes the keyset."""
    encrypted_keyset = _encrypt(self._keyset, master_key_primitive)
    keyset_writer.write_encrypted(encrypted_keyset)

  def public_keyset_handle(self) -> 'KeysetHandle':
    """Returns a new KeysetHandle for the corresponding public keys.

    Each key keeps its id, status and output prefix type; only its key_data
    is replaced by registry.Registry.public_key_data(key.key_data). The
    primary key id is carried over unchanged.

    Returns:
      A new KeysetHandle wrapping the public keyset.

    Raises:
      tink.TinkError: e.g. if a resulting public key fails validation.
    """
    public_keyset = tink_pb2.Keyset()
    for key in self._keyset.key:
      public_key = public_keyset.key.add()
      public_key.CopyFrom(key)
      public_key.key_data.CopyFrom(
          registry.Registry.public_key_data(key.key_data))
      _validate_key(public_key)
    public_keyset.primary_key_id = self._keyset.primary_key_id
    return self._create(public_keyset)

  def primitive(self, primitive_class: Type[P]) -> P:
    """Returns a wrapped primitive from this KeysetHandle.

    Uses the KeyManager and the PrimitiveWrapper objects in the global
    registry.Registry to create the primitive. This function is the most
    common way of creating a primitive.

    Only ENABLED keys contribute a primitive; the one whose id equals the
    keyset's primary_key_id is marked as primary in the primitive set.

    Args:
      primitive_class: The class of the primitive.

    Returns:
      The primitive.

    Raises:
      tink.TinkError if creation of the primitive fails, for example if
      primitive_class cannot be used with this KeysetHandle.
    """
    _validate_keyset(self._keyset)
    pset = primitive_set.PrimitiveSet(primitive_class)
    for key in self._keyset.key:
      if key.status == tink_pb2.ENABLED:
        primitive = registry.Registry.primitive(key.key_data, primitive_class)
        entry = pset.add_primitive(primitive, key)
        if key.key_id == self._keyset.primary_key_id:
          pset.set_primary(entry)
    return registry.Registry.wrap(pset)
def _keyset_info(keyset: tink_pb2.Keyset) -> tink_pb2.KeysetInfo:
  """Summarises keyset as a KeysetInfo that carries no key material.

  Args:
    keyset: The keyset to describe.

  Returns:
    A KeysetInfo with the same primary_key_id and one KeyInfo per key (in
    keyset order) holding that key's type_url, status, output_prefix_type
    and key_id.
  """
  info = tink_pb2.KeysetInfo(primary_key_id=keyset.primary_key_id)
  for k in keyset.key:
    info.key_info.add(
        type_url=k.key_data.type_url,
        status=k.status,
        output_prefix_type=k.output_prefix_type,
        key_id=k.key_id)
  return info
def _encrypt(keyset: tink_pb2.Keyset,
             master_key_primitive: aead.Aead) -> tink_pb2.EncryptedKeyset:
  """Encrypts a Keyset and returns an EncryptedKeyset.

  The serialized keyset is encrypted with master_key_primitive using empty
  associated data. The ciphertext is then decrypted and parsed again and
  compared with the input, so a faulty master key primitive is detected
  before the result is returned.

  Args:
    keyset: The keyset to encrypt.
    master_key_primitive: Aead used both to encrypt and to verify the result.

  Returns:
    An EncryptedKeyset holding the ciphertext plus a KeysetInfo (no key
    material) describing the keyset.

  Raises:
    tink_error.TinkError: if the ciphertext does not decrypt back to a
      parseable Keyset equal to the input.
  """
  encrypted_keyset = master_key_primitive.encrypt(keyset.SerializeToString(),
                                                  b'')
  # Check if we can decrypt, to detect errors.
  try:
    keyset2 = tink_pb2.Keyset.FromString(
        master_key_primitive.decrypt(encrypted_keyset, b''))
  except message.DecodeError as e:
    raise tink_error.TinkError(
        'invalid keyset, corrupted key material') from e
  if keyset != keyset2:
    # Do not interpolate the keysets: they contain secret key material and
    # exception messages routinely end up in logs.
    raise tink_error.TinkError(
        'cannot encrypt keyset: decrypted keyset does not match the original')
  return tink_pb2.EncryptedKeyset(
      encrypted_keyset=encrypted_keyset, keyset_info=_keyset_info(keyset))
def _decrypt(encrypted_keyset: tink_pb2.EncryptedKeyset,
             master_key_aead: aead.Aead) -> tink_pb2.Keyset:
  """Decrypts an EncryptedKeyset and returns a Keyset.

  Args:
    encrypted_keyset: Carries the ciphertext in its encrypted_keyset field.
    master_key_aead: Aead used to decrypt, with empty associated data.

  Returns:
    The parsed Keyset, which contains at least one key.

  Raises:
    tink_error.TinkError: if the plaintext is not a parseable Keyset, or the
      Keyset has no keys.
  """
  try:
    keyset = tink_pb2.Keyset.FromString(
        master_key_aead.decrypt(encrypted_keyset.encrypted_keyset, b''))
  except message.DecodeError as e:
    raise tink_error.TinkError(
        'invalid keyset, corrupted key material') from e
  # Check emptiness here too, in case the encrypted keys unwrapped to nothing.
  _assert_enough_key_material(keyset)
  return keyset
def _validate_keyset(keyset: tink_pb2.Keyset):
  """Checks the structural invariants of keyset.

  Every non-DESTROYED key must pass _validate_key. The keyset must have at
  least one non-DESTROYED key and at most one ENABLED key whose id equals
  primary_key_id. A primary may be missing only when no key (of any status)
  holds key material other than ASYMMETRIC_PUBLIC.

  Raises:
    tink_error.TinkError: if any of the above does not hold.
  """
  live = 0
  not_public = 0
  primaries = 0
  for k in keyset.key:
    if k.status != tink_pb2.DESTROYED:
      _validate_key(k)
      live += 1
    if k.key_data.key_material_type != tink_pb2.KeyData.ASYMMETRIC_PUBLIC:
      not_public += 1
    if k.status == tink_pb2.ENABLED and k.key_id == keyset.primary_key_id:
      primaries += 1

  if not live:
    raise tink_error.TinkError('empty keyset')
  if primaries > 1:
    raise tink_error.TinkError('keyset contains multiple primary keys')
  if not primaries and not_public:
    raise tink_error.TinkError('keyset does not contain a valid primary key')
def _validate_key(key: tink_pb2.Keyset.Key):
  """Checks that a single key is well formed.

  Raises:
    tink_error.TinkError: if the key has no key_data, or its output prefix
      type is UNKNOWN_PREFIX, or its status is UNKNOWN_STATUS (checked in
      that order; the first failure is reported).
  """
  template = None
  if not key.HasField('key_data'):
    template = 'key {} has no key data'
  elif key.output_prefix_type == tink_pb2.UNKNOWN_PREFIX:
    template = 'key {} has unknown prefix'
  elif key.status == tink_pb2.UNKNOWN_STATUS:
    template = 'key {} has unknown status'
  if template is not None:
    raise tink_error.TinkError(template.format(key.key_id))
def _assert_no_secret_key_material(keyset: tink_pb2.Keyset):
  """Rejects keysets holding key material that must stay secret.

  Raises:
    tink_error.TinkError: if any key's material type is UNKNOWN_KEYMATERIAL,
      SYMMETRIC or ASYMMETRIC_PRIVATE.
  """
  secret_types = (tink_pb2.KeyData.UNKNOWN_KEYMATERIAL,
                  tink_pb2.KeyData.SYMMETRIC,
                  tink_pb2.KeyData.ASYMMETRIC_PRIVATE)
  if any(k.key_data.key_material_type in secret_types for k in keyset.key):
    raise tink_error.TinkError('keyset contains secret key material')
def _assert_enough_key_material(keyset: tink_pb2.Keyset):
  """Raises tink_error.TinkError('empty keyset') unless keyset has keys.

  A falsy keyset argument is rejected the same way as one whose key list is
  empty.
  """
  if keyset and keyset.key:
    return
  raise tink_error.TinkError('empty keyset')
def _assert_enough_encrypted_key_material(
    encrypted_keyset: tink_pb2.EncryptedKeyset):
  """Raises tink_error.TinkError('empty keyset') if there is no ciphertext.

  Both a falsy encrypted_keyset argument and an empty encrypted_keyset field
  are rejected.
  """
  if encrypted_keyset and encrypted_keyset.encrypted_keyset:
    return
  raise tink_error.TinkError('empty keyset')
def _generate_unused_key_id(keyset: tink_pb2.Keyset) -> int:
  """Returns a random key id not already used in keyset.

  Args:
    keyset: The keyset whose existing key ids must be avoided.

  Returns:
    An int drawn uniformly from [1, MAX_INT32] that differs from every
    key.key_id in keyset.
  """
  # Collect existing ids once instead of rebuilding the set on every retry.
  used_ids = {key.key_id for key in keyset.key}
  while True:
    key_id = random.randint(1, MAX_INT32)
    if key_id not in used_ids:
      return key_id