mongo-charms-single-kernel 1.8.7__py3-none-any.whl → 1.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mongo-charms-single-kernel might be problematic. Click here for more details.
- {mongo_charms_single_kernel-1.8.7.dist-info → mongo_charms_single_kernel-1.8.9.dist-info}/METADATA +1 -1
- {mongo_charms_single_kernel-1.8.7.dist-info → mongo_charms_single_kernel-1.8.9.dist-info}/RECORD +30 -28
- single_kernel_mongo/config/literals.py +8 -3
- single_kernel_mongo/config/models.py +12 -0
- single_kernel_mongo/config/relations.py +2 -1
- single_kernel_mongo/config/statuses.py +127 -20
- single_kernel_mongo/core/operator.py +68 -1
- single_kernel_mongo/core/structured_config.py +2 -0
- single_kernel_mongo/core/workload.py +10 -4
- single_kernel_mongo/events/cluster.py +5 -0
- single_kernel_mongo/events/sharding.py +3 -1
- single_kernel_mongo/events/tls.py +183 -157
- single_kernel_mongo/exceptions.py +0 -8
- single_kernel_mongo/lib/charms/operator_libs_linux/v1/systemd.py +288 -0
- single_kernel_mongo/lib/charms/tls_certificates_interface/v4/tls_certificates.py +1995 -0
- single_kernel_mongo/managers/cluster.py +70 -28
- single_kernel_mongo/managers/config.py +14 -8
- single_kernel_mongo/managers/mongo.py +1 -1
- single_kernel_mongo/managers/mongodb_operator.py +53 -56
- single_kernel_mongo/managers/mongos_operator.py +18 -20
- single_kernel_mongo/managers/sharding.py +154 -127
- single_kernel_mongo/managers/tls.py +223 -206
- single_kernel_mongo/state/charm_state.py +39 -16
- single_kernel_mongo/state/cluster_state.py +8 -0
- single_kernel_mongo/state/config_server_state.py +9 -0
- single_kernel_mongo/state/tls_state.py +39 -12
- single_kernel_mongo/templates/enable-transparent-huge-pages.service.j2 +14 -0
- single_kernel_mongo/utils/helpers.py +4 -19
- single_kernel_mongo/lib/charms/tls_certificates_interface/v3/tls_certificates.py +0 -2123
- {mongo_charms_single_kernel-1.8.7.dist-info → mongo_charms_single_kernel-1.8.9.dist-info}/WHEEL +0 -0
- {mongo_charms_single_kernel-1.8.7.dist-info → mongo_charms_single_kernel-1.8.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1995 @@
|
|
|
1
|
+
# Copyright 2024 Canonical Ltd.
|
|
2
|
+
# See LICENSE file for licensing details.
|
|
3
|
+
|
|
4
|
+
"""Charm library for managing TLS certificates (V4).
|
|
5
|
+
|
|
6
|
+
This library contains the Requires and Provides classes for handling the tls-certificates
|
|
7
|
+
interface.
|
|
8
|
+
|
|
9
|
+
Pre-requisites:
|
|
10
|
+
- Juju >= 3.0
|
|
11
|
+
- cryptography >= 43.0.0
|
|
12
|
+
- pydantic >= 1.0
|
|
13
|
+
|
|
14
|
+
Learn more on how-to use the TLS Certificates interface library by reading the documentation:
|
|
15
|
+
- https://charmhub.io/tls-certificates-interface/
|
|
16
|
+
|
|
17
|
+
""" # noqa: D214, D405, D411, D416
|
|
18
|
+
|
|
19
|
+
import copy
|
|
20
|
+
import ipaddress
|
|
21
|
+
import json
|
|
22
|
+
import logging
|
|
23
|
+
import uuid
|
|
24
|
+
from contextlib import suppress
|
|
25
|
+
from dataclasses import dataclass
|
|
26
|
+
from datetime import datetime, timedelta, timezone
|
|
27
|
+
from enum import Enum
|
|
28
|
+
from typing import FrozenSet, List, MutableMapping, Optional, Tuple, Union
|
|
29
|
+
|
|
30
|
+
import pydantic
|
|
31
|
+
from cryptography import x509
|
|
32
|
+
from cryptography.exceptions import InvalidSignature
|
|
33
|
+
from cryptography.hazmat.primitives import hashes, serialization
|
|
34
|
+
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
35
|
+
from cryptography.x509.oid import ExtensionOID, NameOID
|
|
36
|
+
from ops import BoundEvent, CharmBase, CharmEvents, Secret, SecretExpiredEvent, SecretRemoveEvent
|
|
37
|
+
from ops.framework import EventBase, EventSource, Handle, Object
|
|
38
|
+
from ops.jujuversion import JujuVersion
|
|
39
|
+
from ops.model import Application, ModelError, Relation, SecretNotFoundError, Unit
|
|
40
|
+
|
|
41
|
+
# The unique Charmhub library identifier, never change it
LIBID = "afd8c2bccf834997afce12c2706d2ede"

# Increment this major API version when introducing breaking changes
LIBAPI = 4

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 22

# Third-party Python packages this library needs at runtime; presumably read by
# the charm packaging tooling when the library is bundled into a charm (confirm).
# NOTE(review): the module docstring asks for pydantic >= 1.0, but no lower bound
# is pinned here.
PYDEPS = [
    "cryptography>=43.0.0",
    "pydantic",
]
# True when the installed pydantic major version is below 2. _DatabagModel checks
# this flag to choose between the pydantic v1 and v2 APIs.
IS_PYDANTIC_V1 = int(pydantic.version.VERSION.split(".")[0]) < 2

# Module-level logger used throughout this library.
logger = logging.getLogger(__name__)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class TLSCertificatesError(Exception):
    """Root of the exception hierarchy for errors raised by this library."""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class DataValidationError(TLSCertificatesError):
    """Relation databag contents could not be decoded as JSON or failed validation."""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class _DatabagModel(pydantic.BaseModel):
    """Base databag model.

    Supports both pydantic v1 and v2.

    Unless a nesting key is configured, every model field is stored in the Juju
    databag under its own key (the field alias, or its name) with a JSON-encoded
    value. When a nesting key is configured, the whole model is stored as one
    JSON document under that key instead.
    """

    if IS_PYDANTIC_V1:

        class Config:
            """Pydantic config."""

            # ignore any extra fields in the databag
            extra = "ignore"
            """Ignore any extra fields in the databag."""
            allow_population_by_field_name = True
            """Allow instantiating this class by field name (instead of forcing alias)."""

        # pydantic v1 nesting key, read by _load_v1 / _dump_v1.
        _NEST_UNDER = None

    model_config = pydantic.ConfigDict(
        # tolerate additional keys in databag
        extra="ignore",
        # Allow instantiating this class by field name (instead of forcing alias).
        populate_by_name=True,
        # Custom config key: whether to nest the whole datastructure (as json)
        # under a field or spread it out at the toplevel.
        _NEST_UNDER=None,
    )  # type: ignore
    """Pydantic config."""

    @classmethod
    def load(cls, databag: MutableMapping):
        """Load this model from a Juju databag.

        Args:
            databag: The databag to read from; values are expected to be JSON strings.

        Returns:
            An instance of this model.

        Raises:
            DataValidationError: in the flat (non-nested) layout, when a field's
                value is not valid JSON or the decoded data fails validation.
                In the nested layout the underlying KeyError / JSON decode /
                pydantic validation errors propagate unchanged.
        """
        if IS_PYDANTIC_V1:
            return cls._load_v1(databag)
        nest_under = cls.model_config.get("_NEST_UNDER")
        if nest_under:
            return cls.model_validate(json.loads(databag[nest_under]))

        # Databag keys that map to model fields. Computed once here instead of
        # once per databag entry inside the comprehension condition.
        field_keys = {(field.alias or name) for name, field in cls.model_fields.items()}
        try:
            # Don't attempt to parse model-external values
            data = {k: json.loads(v) for k, v in databag.items() if k in field_keys}
        except json.JSONDecodeError as e:
            msg = f"invalid databag contents: expecting json. {databag}"
            logger.error(msg)
            raise DataValidationError(msg) from e

        try:
            return cls.model_validate_json(json.dumps(data))
        except pydantic.ValidationError as e:
            msg = f"failed to validate databag: {databag}"
            logger.debug(msg, exc_info=True)
            raise DataValidationError(msg) from e

    @classmethod
    def _load_v1(cls, databag: MutableMapping):
        """Load implementation for pydantic v1 (same contract as ``load``)."""
        if cls._NEST_UNDER:
            return cls.parse_obj(json.loads(databag[cls._NEST_UNDER]))

        # Databag keys that map to model fields, computed once.
        field_keys = {field.alias for field in cls.__fields__.values()}
        try:
            # Don't attempt to parse model-external values
            data = {k: json.loads(v) for k, v in databag.items() if k in field_keys}
        except json.JSONDecodeError as e:
            msg = f"invalid databag contents: expecting json. {databag}"
            logger.error(msg)
            raise DataValidationError(msg) from e

        try:
            return cls.parse_raw(json.dumps(data))  # type: ignore
        except pydantic.ValidationError as e:
            msg = f"failed to validate databag: {databag}"
            logger.debug(msg, exc_info=True)
            raise DataValidationError(msg) from e

    def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True):
        """Write the contents of this model to Juju databag.

        Fields whose values equal their defaults are not written.

        Args:
            databag: The databag to write to. A new dict is used when None.
            clear: Whether to clear the databag before writing.

        Returns:
            MutableMapping: The databag.
        """
        if IS_PYDANTIC_V1:
            return self._dump_v1(databag, clear)
        if clear and databag:
            databag.clear()

        if databag is None:
            databag = {}
        nest_under = self.model_config.get("_NEST_UNDER")
        if nest_under:
            databag[nest_under] = self.model_dump_json(
                by_alias=True,
                # skip keys whose values are default
                exclude_defaults=True,
            )
            return databag

        dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True)
        # Each field is JSON-encoded individually because databag values are strings.
        databag.update({k: json.dumps(v) for k, v in dct.items()})
        return databag

    def _dump_v1(self, databag: Optional[MutableMapping] = None, clear: bool = True):
        """Dump implementation for pydantic v1 (same contract as ``dump``)."""
        if clear and databag:
            databag.clear()

        if databag is None:
            databag = {}

        if self._NEST_UNDER:
            databag[self._NEST_UNDER] = self.json(by_alias=True, exclude_defaults=True)
            return databag

        dct = json.loads(self.json(by_alias=True, exclude_defaults=True))
        databag.update({k: json.dumps(v) for k, v in dct.items()})

        return databag
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class _Certificate(pydantic.BaseModel):
    """One certificate entry as published in the provider application databag.

    All certificate material is held as PEM text; ``to_provider_certificate``
    turns the entry into parsed objects.
    """

    # PEM text of the CA certificate.
    ca: str
    # PEM text of the certificate signing request.
    certificate_signing_request: str
    # PEM text of the issued certificate.
    certificate: str
    # Optional list of PEM certificates making up the chain.
    chain: Optional[List[str]] = None
    # Optional revocation flag.
    revoked: Optional[bool] = None

    def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate":
        """Parse this entry into a ProviderCertificate.

        Args:
            relation_id: ID of the relation the entry belongs to.

        Returns:
            ProviderCertificate: the parsed entry; its ``chain`` is an empty list
            when this entry carries no chain.
        """
        # Parse in the same order the values are handed to ProviderCertificate,
        # so the first unparsable PEM is the one reported.
        leaf = Certificate.from_string(self.certificate)
        request = CertificateSigningRequest.from_string(self.certificate_signing_request)
        issuer = Certificate.from_string(self.ca)
        parsed_chain = []
        for pem in self.chain or []:
            parsed_chain.append(Certificate.from_string(pem))
        return ProviderCertificate(
            relation_id=relation_id,
            certificate=leaf,
            certificate_signing_request=request,
            ca=issuer,
            chain=parsed_chain,
            revoked=self.revoked,
        )
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class _CertificateSigningRequest(pydantic.BaseModel):
    """One CSR entry as published in a requirer databag."""

    # PEM text of the certificate signing request.
    certificate_signing_request: str
    # Presumably whether the requirer asks for a CA certificate (confirm).
    # NOTE(review): there is no default, so under pydantic v2 this key must be
    # present (it may be null); pydantic v1 defaults an Optional field to None.
    ca: Union[bool, None]
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class _ProviderApplicationData(_DatabagModel):
    """Layout of the provider's application databag."""

    # Certificate entries published by the provider; empty by default. pydantic
    # copies field defaults per instance (per its docs), so this list literal is
    # not shared between instances.
    certificates: List[_Certificate] = []
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class _RequirerData(_DatabagModel):
    """Layout of a requirer databag.

    The same model describes both the unit databag and the application databag.
    """

    # CSR entries published by the requirer; empty by default.
    certificate_signing_requests: List[_CertificateSigningRequest] = []
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class Mode(Enum):
    """Scope at which certificates are requested.

    UNIT (default): each unit requests its own certificate and manages its own
        private key, certificate signing request and certificate.
    APP: one certificate is requested for the whole application; only the
        leader unit manages the private key, certificate signing request and
        certificate.
    """

    UNIT = 1
    APP = 2
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
@dataclass(frozen=True)
class PrivateKey:
    """A private key held as PEM text."""

    raw: str

    def __str__(self):
        """Return the private key as a string."""
        return self.raw

    @classmethod
    def from_string(cls, private_key: str) -> "PrivateKey":
        """Create a PrivateKey from PEM text, stripping surrounding whitespace."""
        return cls(raw=private_key.strip())

    def is_valid(self) -> bool:
        """Validate that the private key is PEM-formatted, RSA, and at least 2048 bits.

        Returns:
            bool: True if the key loads without a password, is an RSA key and is
            at least 2048 bits long; False otherwise (the reason is logged as a
            warning).
        """
        try:
            key = serialization.load_pem_private_key(
                self.raw.encode(),
                password=None,
            )
        except (ValueError, TypeError):
            # ValueError: the PEM data is malformed. TypeError: the key is
            # encrypted but no password was supplied. Neither is usable here,
            # so report invalid instead of letting the error escape.
            logger.warning("Invalid private key format")
            return False

        if not isinstance(key, rsa.RSAPrivateKey):
            logger.warning("Private key is not an RSA key")
            return False

        if key.key_size < 2048:
            logger.warning("RSA key size is less than 2048 bits")
            return False

        return True
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@dataclass(frozen=True)
class Certificate:
    """An X.509 certificate: its PEM text plus fields parsed out of it."""

    raw: str
    common_name: str
    expiry_time: datetime
    validity_start_time: datetime
    is_ca: bool = False
    sans_dns: Optional[FrozenSet[str]] = frozenset()
    sans_ip: Optional[FrozenSet[str]] = frozenset()
    sans_oid: Optional[FrozenSet[str]] = frozenset()
    email_address: Optional[str] = None
    organization: Optional[str] = None
    organizational_unit: Optional[str] = None
    country_name: Optional[str] = None
    state_or_province_name: Optional[str] = None
    locality_name: Optional[str] = None

    def __str__(self) -> str:
        """Return the certificate as a string."""
        return self.raw

    @classmethod
    def from_string(cls, certificate: str) -> "Certificate":
        """Create a Certificate object from PEM text.

        Args:
            certificate: PEM text of the certificate.

        Returns:
            Certificate: the parsed certificate; ``raw`` holds the stripped PEM.
            SAN OIDs are stored in dotted form (e.g. ``"1.2.3.4"``).

        Raises:
            TLSCertificatesError: if the PEM cannot be loaded or the subject has
                no common name.
        """
        try:
            certificate_object = x509.load_pem_x509_certificate(data=certificate.encode())
        except ValueError as e:
            logger.error("Could not load certificate: %s", e)
            raise TLSCertificatesError("Could not load certificate") from e

        subject = certificate_object.subject
        common_name = subject.get_attributes_for_oid(NameOID.COMMON_NAME)
        if not common_name:
            # common_name is a required field; report a library error instead of
            # letting an IndexError escape when it is read below.
            logger.error("Could not load certificate: subject has no common name")
            raise TLSCertificatesError("Could not load certificate: missing common name")
        country_name = subject.get_attributes_for_oid(NameOID.COUNTRY_NAME)
        state_or_province_name = subject.get_attributes_for_oid(NameOID.STATE_OR_PROVINCE_NAME)
        locality_name = subject.get_attributes_for_oid(NameOID.LOCALITY_NAME)
        organization_name = subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)
        organizational_unit = subject.get_attributes_for_oid(NameOID.ORGANIZATIONAL_UNIT_NAME)
        email_address = subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS)

        sans_dns: List[str] = []
        sans_ip: List[str] = []
        sans_oid: List[str] = []
        try:
            sans = certificate_object.extensions.get_extension_for_class(
                x509.SubjectAlternativeName
            ).value
        except x509.ExtensionNotFound:
            logger.debug("No SANs found in certificate")
        else:
            for san in sans:
                if isinstance(san, x509.DNSName):
                    sans_dns.append(san.value)
                elif isinstance(san, x509.IPAddress):
                    sans_ip.append(str(san.value))
                elif isinstance(san, x509.RegisteredID):
                    # Use the dotted OID string, as CertificateSigningRequest
                    # does; str() of an ObjectIdentifier falls back to its repr.
                    sans_oid.append(san.value.dotted_string)

        expiry_time = certificate_object.not_valid_after_utc
        validity_start_time = certificate_object.not_valid_before_utc
        is_ca = False
        try:
            is_ca = certificate_object.extensions.get_extension_for_oid(
                ExtensionOID.BASIC_CONSTRAINTS
            ).value.ca  # type: ignore[reportAttributeAccessIssue]
        except x509.ExtensionNotFound:
            # No BasicConstraints extension: treat as a non-CA certificate.
            pass

        return cls(
            raw=certificate.strip(),
            common_name=str(common_name[0].value),
            is_ca=is_ca,
            country_name=str(country_name[0].value) if country_name else None,
            state_or_province_name=str(state_or_province_name[0].value)
            if state_or_province_name
            else None,
            locality_name=str(locality_name[0].value) if locality_name else None,
            organization=str(organization_name[0].value) if organization_name else None,
            organizational_unit=str(organizational_unit[0].value) if organizational_unit else None,
            email_address=str(email_address[0].value) if email_address else None,
            sans_dns=frozenset(sans_dns),
            sans_ip=frozenset(sans_ip),
            sans_oid=frozenset(sans_oid),
            expiry_time=expiry_time,
            validity_start_time=validity_start_time,
        )

    def matches_private_key(self, private_key: PrivateKey) -> bool:
        """Check if this certificate matches a given private key.

        Only RSA keys are supported; any failure is logged and reported as False.

        Args:
            private_key (PrivateKey): The private key to validate against.

        Returns:
            bool: True if the certificate matches the private key, False otherwise.
        """
        try:
            cert_object = x509.load_pem_x509_certificate(self.raw.encode())
            key_object = serialization.load_pem_private_key(
                private_key.raw.encode(), password=None
            )

            cert_public_key = cert_object.public_key()
            key_public_key = key_object.public_key()

            if not isinstance(cert_public_key, rsa.RSAPublicKey):
                logger.warning("Certificate does not use RSA public key")
                return False

            if not isinstance(key_public_key, rsa.RSAPublicKey):
                logger.warning("Private key is not an RSA key")
                return False

            return cert_public_key.public_numbers() == key_public_key.public_numbers()
        except Exception as e:
            # Deliberate best-effort check: any load/parse failure means "no match".
            logger.warning("Failed to validate certificate and private key match: %s", e)
            return False
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
@dataclass(frozen=True)
class CertificateSigningRequest:
    """A certificate signing request: its PEM text plus fields parsed out of it."""

    raw: str
    common_name: str
    sans_dns: Optional[FrozenSet[str]] = None
    sans_ip: Optional[FrozenSet[str]] = None
    sans_oid: Optional[FrozenSet[str]] = None
    email_address: Optional[str] = None
    organization: Optional[str] = None
    organizational_unit: Optional[str] = None
    country_name: Optional[str] = None
    state_or_province_name: Optional[str] = None
    locality_name: Optional[str] = None
    has_unique_identifier: bool = False

    def __eq__(self, other: object) -> bool:
        """Check if two CertificateSigningRequest objects are equal.

        Only the PEM text is compared, ignoring surrounding whitespace.
        """
        if not isinstance(other, CertificateSigningRequest):
            return NotImplemented
        return self.raw.strip() == other.raw.strip()

    def __hash__(self) -> int:
        """Hash on the same key as ``__eq__`` so that equal CSRs hash equally.

        Without an explicit ``__hash__``, ``@dataclass(frozen=True)`` generates
        one over every field, which can differ for objects ``__eq__`` treats as
        equal (e.g. same PEM with different surrounding whitespace).
        """
        return hash(self.raw.strip())

    def __str__(self) -> str:
        """Return the CSR as a string."""
        return self.raw

    @classmethod
    def from_string(cls, csr: str) -> "CertificateSigningRequest":
        """Create a CertificateSigningRequest object from PEM text.

        Args:
            csr: PEM text of the CSR.

        Returns:
            CertificateSigningRequest: the parsed CSR; ``raw`` holds the stripped
            PEM. SAN fields are empty frozensets when the CSR has no SAN extension.

        Raises:
            TLSCertificatesError: if the PEM cannot be loaded or the subject has
                no common name.
        """
        try:
            csr_object = x509.load_pem_x509_csr(csr.encode())
        except ValueError as e:
            logger.error("Could not load CSR: %s", e)
            raise TLSCertificatesError("Could not load CSR") from e

        subject = csr_object.subject
        common_name = subject.get_attributes_for_oid(NameOID.COMMON_NAME)
        if not common_name:
            # common_name is a required field; report a library error instead of
            # letting an IndexError escape when it is read below.
            logger.error("Could not load CSR: subject has no common name")
            raise TLSCertificatesError("Could not load CSR: missing common name")
        country_name = subject.get_attributes_for_oid(NameOID.COUNTRY_NAME)
        state_or_province_name = subject.get_attributes_for_oid(NameOID.STATE_OR_PROVINCE_NAME)
        locality_name = subject.get_attributes_for_oid(NameOID.LOCALITY_NAME)
        organization_name = subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)
        organizational_unit = subject.get_attributes_for_oid(NameOID.ORGANIZATIONAL_UNIT_NAME)
        email_address = subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS)
        unique_identifier = subject.get_attributes_for_oid(NameOID.X500_UNIQUE_IDENTIFIER)

        try:
            sans = csr_object.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
        except x509.ExtensionNotFound:
            sans_dns = frozenset()
            sans_ip = frozenset()
            sans_oid = frozenset()
        else:
            sans_dns = frozenset(sans.get_values_for_type(x509.DNSName))
            sans_ip = frozenset([str(san) for san in sans.get_values_for_type(x509.IPAddress)])
            sans_oid = frozenset(
                [san.dotted_string for san in sans.get_values_for_type(x509.RegisteredID)]
            )

        return cls(
            raw=csr.strip(),
            common_name=str(common_name[0].value),
            country_name=str(country_name[0].value) if country_name else None,
            state_or_province_name=str(state_or_province_name[0].value)
            if state_or_province_name
            else None,
            locality_name=str(locality_name[0].value) if locality_name else None,
            organization=str(organization_name[0].value) if organization_name else None,
            organizational_unit=str(organizational_unit[0].value) if organizational_unit else None,
            email_address=str(email_address[0].value) if email_address else None,
            sans_dns=sans_dns,
            sans_ip=sans_ip,
            sans_oid=sans_oid,
            has_unique_identifier=bool(unique_identifier),
        )

    def matches_private_key(self, key: PrivateKey) -> bool:
        """Check if a CSR matches a private key.

        This function only works with RSA keys; the RSA moduli are compared.

        Args:
            key (PrivateKey): Private key

        Returns:
            bool: True/False depending on whether the CSR matches the private key.
        """
        try:
            csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8"))
            key_object = serialization.load_pem_private_key(
                data=key.raw.encode("utf-8"), password=None
            )
        except (ValueError, TypeError):
            # ValueError: malformed PEM data. TypeError: the key is encrypted but
            # no password was supplied.
            logger.warning("Could not load CSR or private key.")
            return False

        key_object_public_key = key_object.public_key()
        csr_object_public_key = csr_object.public_key()
        if not isinstance(key_object_public_key, rsa.RSAPublicKey):
            logger.warning("Key is not an RSA key")
            return False
        if not isinstance(csr_object_public_key, rsa.RSAPublicKey):
            logger.warning("CSR is not an RSA key")
            return False
        if csr_object_public_key.public_numbers().n != key_object_public_key.public_numbers().n:
            logger.warning("Public key numbers between CSR and key do not match")
            return False
        return True

    def matches_certificate(self, certificate: Certificate) -> bool:
        """Check if a CSR matches a certificate by comparing their public keys.

        Args:
            certificate (Certificate): Certificate

        Returns:
            bool: True/False depending on whether the CSR matches the certificate.
        """
        csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8"))
        cert_object = x509.load_pem_x509_certificate(certificate.raw.encode("utf-8"))
        return csr_object.public_key() == cert_object.public_key()

    def get_sha256_hex(self) -> str:
        """Return the hex-encoded SHA-256 digest of the CSR's PEM text."""
        digest = hashes.Hash(hashes.SHA256())
        digest.update(self.raw.encode())
        return digest.finalize().hex()
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
@dataclass(frozen=True)
class CertificateRequestAttributes:
    """Attributes a requirer charm wants in the certificate it requests.

    Requirer charms build one of these to describe the subject fields, SANs and
    CA flag of the requested certificate.
    """

    common_name: str
    sans_dns: Optional[FrozenSet[str]] = frozenset()
    sans_ip: Optional[FrozenSet[str]] = frozenset()
    sans_oid: Optional[FrozenSet[str]] = frozenset()
    email_address: Optional[str] = None
    organization: Optional[str] = None
    organizational_unit: Optional[str] = None
    country_name: Optional[str] = None
    state_or_province_name: Optional[str] = None
    locality_name: Optional[str] = None
    is_ca: bool = False
    add_unique_id_to_subject_name: bool = True

    def is_valid(self) -> bool:
        """Return True when the request has a non-empty common name."""
        return bool(self.common_name)

    def generate_csr(
        self,
        private_key: PrivateKey,
    ) -> CertificateSigningRequest:
        """Build a CSR carrying these attributes, using ``private_key``.

        ``is_ca`` is not passed to the CSR generator.

        Args:
            private_key (PrivateKey): Private key

        Returns:
            CertificateSigningRequest: CSR
        """
        return generate_csr(
            private_key=private_key,
            common_name=self.common_name,
            sans_dns=self.sans_dns,
            sans_ip=self.sans_ip,
            sans_oid=self.sans_oid,
            email_address=self.email_address,
            organization=self.organization,
            organizational_unit=self.organizational_unit,
            country_name=self.country_name,
            state_or_province_name=self.state_or_province_name,
            locality_name=self.locality_name,
            add_unique_id_to_subject_name=self.add_unique_id_to_subject_name,
        )

    @classmethod
    def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool):
        """Rebuild request attributes from an existing CSR.

        Args:
            csr: The parsed CSR to copy subject fields and SANs from.
            is_ca: CA flag to attach (not stored in the CSR object itself).

        Returns:
            CertificateRequestAttributes: attributes mirroring ``csr``; the unique
            identifier flag follows ``csr.has_unique_identifier``.
        """
        subject_fields = dict(
            common_name=csr.common_name,
            email_address=csr.email_address,
            organization=csr.organization,
            organizational_unit=csr.organizational_unit,
            country_name=csr.country_name,
            state_or_province_name=csr.state_or_province_name,
            locality_name=csr.locality_name,
        )
        return cls(
            sans_dns=csr.sans_dns,
            sans_ip=csr.sans_ip,
            sans_oid=csr.sans_oid,
            is_ca=is_ca,
            add_unique_id_to_subject_name=csr.has_unique_identifier,
            **subject_fields,
        )
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
@dataclass(frozen=True)
class ProviderCertificate:
    """A certificate issued by the TLS provider, with its CSR, CA and chain."""

    relation_id: int
    certificate: Certificate
    certificate_signing_request: CertificateSigningRequest
    ca: Certificate
    chain: List[Certificate]
    revoked: Optional[bool] = None

    def to_json(self) -> str:
        """Serialize the object to a JSON object string.

        Returns:
            str: JSON with keys ``csr``, ``certificate``, ``ca``, ``chain`` (a
            list of PEM strings) and ``revoked``, in that order.
        """
        payload = {
            "csr": str(self.certificate_signing_request),
            "certificate": str(self.certificate),
            "ca": str(self.ca),
            "chain": list(map(str, self.chain)),
            "revoked": self.revoked,
        }
        return json.dumps(payload)
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
@dataclass(frozen=True)
class RequirerCertificateRequest:
    """A certificate signing request submitted by one particular TLS requirer."""

    # ID of the relation this request belongs to.
    relation_id: int
    # The parsed CSR.
    certificate_signing_request: CertificateSigningRequest
    # Whether a CA certificate was requested.
    is_ca: bool
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
class CertificateAvailableEvent(EventBase):
    """Charm event emitted once a TLS certificate is available."""

    def __init__(
        self,
        handle: Handle,
        certificate: Certificate,
        certificate_signing_request: CertificateSigningRequest,
        ca: Certificate,
        chain: List[Certificate],
    ):
        super().__init__(handle)
        self.certificate = certificate
        self.certificate_signing_request = certificate_signing_request
        self.ca = ca
        self.chain = chain

    def snapshot(self) -> dict:
        """Serialize the event payload to strings.

        The chain is stored as a JSON-encoded list of PEM strings.
        """
        chain_pems = [str(cert) for cert in self.chain]
        return {
            "certificate": str(self.certificate),
            "certificate_signing_request": str(self.certificate_signing_request),
            "ca": str(self.ca),
            "chain": json.dumps(chain_pems),
        }

    def restore(self, snapshot: dict):
        """Rebuild the event payload from ``snapshot`` by re-parsing every PEM."""
        self.certificate = Certificate.from_string(snapshot["certificate"])
        self.certificate_signing_request = CertificateSigningRequest.from_string(
            snapshot["certificate_signing_request"]
        )
        self.ca = Certificate.from_string(snapshot["ca"])
        self.chain = [Certificate.from_string(pem) for pem in json.loads(snapshot["chain"])]

    def chain_as_pem(self) -> str:
        """Return the chain's PEM texts joined by blank lines."""
        return "\n\n".join(str(cert) for cert in self.chain)
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def generate_private_key(
    key_size: int = 2048,
    public_exponent: int = 65537,
) -> PrivateKey:
    """Create a fresh RSA private key.

    Args:
        key_size (int): Modulus size in bits; values below 2048 are rejected.
        public_exponent: RSA public exponent.

    Returns:
        PrivateKey: the key as unencrypted PEM (TraditionalOpenSSL format).

    Raises:
        ValueError: if ``key_size`` is below 2048.
    """
    if key_size < 2048:
        raise ValueError("Key size must be at least 2048 bits for RSA security")
    pem = rsa.generate_private_key(
        public_exponent=public_exponent,
        key_size=key_size,
    ).private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return PrivateKey.from_string(pem.decode())
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
def calculate_relative_datetime(target_time: datetime, fraction: float) -> datetime:
    """Calculate a datetime that lies a given fraction of the way from now to a target time.

    Args:
        target_time (datetime): The future datetime to interpolate towards. It must be
            timezone-aware, because it is subtracted from the current UTC time.
        fraction (float): Fraction of the interval from now to target_time, in the
            closed range [0.0, 1.0].
            1.0 means return target_time,
            0.9 means return the time after 90% of the interval has passed,
            and 0.0 means return now.

    Returns:
        datetime: ``now + (target_time - now) * fraction``, timezone-aware (UTC).

    Raises:
        ValueError: If ``fraction`` is outside [0.0, 1.0], including NaN.
    """
    # The chained comparison accepts 0.0 as documented. It also rejects NaN,
    # for which every comparison is False, so the old `<=`/`>` pair let it through.
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("Invalid fraction. Must be between 0.0 and 1.0")
    now = datetime.now(timezone.utc)
    return now + (target_time - now) * fraction
|
|
747
|
+
|
|
748
|
+
|
|
749
|
+
def chain_has_valid_order(chain: List[str]) -> bool:
    """Check if the chain has a valid order.

    Validates that each certificate in the chain is properly signed by the next certificate.
    The chain should be ordered from leaf to root, where each certificate is signed by
    the next one in the chain.

    Args:
        chain (List[str]): List of certificates in PEM format, ordered from leaf to root

    Returns:
        bool: True if the chain has a valid order, False otherwise. Chains with fewer
            than two entries are accepted without being parsed. A certificate that
            cannot be parsed, or a broken signature or issuer link, yields False.
    """
    if len(chain) < 2:
        return True

    try:
        # Parse each PEM once. Every inner certificate is used both as a subject
        # and as the issuer of its predecessor.
        certificates = [x509.load_pem_x509_certificate(pem.encode()) for pem in chain]
        for certificate, issuer in zip(certificates, certificates[1:]):
            certificate.verify_directly_issued_by(issuer)
    except (ValueError, TypeError, InvalidSignature):
        return False
    return True
|
|
773
|
+
|
|
774
|
+
|
|
775
|
+
def generate_csr(  # noqa: C901
    private_key: PrivateKey,
    common_name: str,
    sans_dns: Optional[FrozenSet[str]] = frozenset(),
    sans_ip: Optional[FrozenSet[str]] = frozenset(),
    sans_oid: Optional[FrozenSet[str]] = frozenset(),
    organization: Optional[str] = None,
    organizational_unit: Optional[str] = None,
    email_address: Optional[str] = None,
    country_name: Optional[str] = None,
    locality_name: Optional[str] = None,
    state_or_province_name: Optional[str] = None,
    add_unique_id_to_subject_name: bool = True,
) -> CertificateSigningRequest:
    """Build a certificate signing request and sign it with ``private_key`` (SHA-256).

    The subject always holds the common name. Optionally it also holds a random
    UUID as X.500 unique identifier, then each of organization, organizational
    unit, e-mail, country, state/province and locality that is set, in that order.
    OID, IP and DNS SANs are added as a non-critical SubjectAlternativeName
    extension when any are given. The e-mail address is not copied into the SAN.

    Args:
        private_key (PrivateKey): Private key
        common_name (str): Common name
        sans_dns (FrozenSet[str]): DNS Subject Alternative Names
        sans_ip (FrozenSet[str]): IP Subject Alternative Names
        sans_oid (FrozenSet[str]): OID Subject Alternative Names
        organization (Optional[str]): Organization name
        organizational_unit (Optional[str]): Organizational unit name
        email_address (Optional[str]): Email address
        country_name (Optional[str]): Country name
        state_or_province_name (Optional[str]): State or province name
        locality_name (Optional[str]): Locality name
        add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's
            subject name. Always leave to "True" when the CSR is used to request certificates
            using the tls-certificates relation.

    Returns:
        CertificateSigningRequest: CSR
    """
    signing_key = serialization.load_pem_private_key(str(private_key).encode(), password=None)

    attributes = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)]
    if add_unique_id_to_subject_name:
        attributes.append(
            x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(uuid.uuid4()))
        )
    # Optional subject attributes, listed in the order they are written.
    optional_attributes = (
        (x509.NameOID.ORGANIZATION_NAME, organization),
        (x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit),
        (x509.NameOID.EMAIL_ADDRESS, email_address),
        (x509.NameOID.COUNTRY_NAME, country_name),
        (x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name),
        (x509.NameOID.LOCALITY_NAME, locality_name),
    )
    attributes.extend(
        x509.NameAttribute(oid, value) for oid, value in optional_attributes if value
    )
    builder = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(attributes))

    general_names: List[x509.GeneralName] = [
        *(x509.RegisteredID(x509.ObjectIdentifier(oid)) for oid in sans_oid or ()),
        *(x509.IPAddress(ipaddress.ip_address(ip)) for ip in sans_ip or ()),
        *(x509.DNSName(dns_name) for dns_name in sans_dns or ()),
    ]
    if general_names:
        builder = builder.add_extension(
            x509.SubjectAlternativeName(set(general_names)), critical=False
        )
    signed_request = builder.sign(signing_key, hashes.SHA256())  # type: ignore[arg-type]
    pem = signed_request.public_bytes(serialization.Encoding.PEM).decode()
    return CertificateSigningRequest.from_string(pem)
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
def generate_ca(
    private_key: PrivateKey,
    validity: timedelta,
    common_name: str,
    sans_dns: Optional[FrozenSet[str]] = frozenset(),
    sans_ip: Optional[FrozenSet[str]] = frozenset(),
    sans_oid: Optional[FrozenSet[str]] = frozenset(),
    organization: Optional[str] = None,
    organizational_unit: Optional[str] = None,
    email_address: Optional[str] = None,
    country_name: Optional[str] = None,
    state_or_province_name: Optional[str] = None,
    locality_name: Optional[str] = None,
) -> Certificate:
    """Create a self-signed CA certificate for ``private_key``.

    Subject and issuer are the same name: the common name plus each optional
    attribute that is set. The certificate carries:
    - matching subject and authority key identifiers;
    - a critical KeyUsage (digital signature, key encipherment, certificate signing);
    - a critical BasicConstraints with ca=True;
    - a SubjectAlternativeName when an e-mail address or any SAN is given.
    It is valid from now for ``validity`` and is signed with SHA-256.

    Args:
        private_key (PrivateKey): Private key
        validity (timedelta): Certificate validity time
        common_name (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN).
        sans_dns (FrozenSet[str]): DNS Subject Alternative Names
        sans_ip (FrozenSet[str]): IP Subject Alternative Names
        sans_oid (FrozenSet[str]): OID Subject Alternative Names
        organization (Optional[str]): Organization name
        organizational_unit (Optional[str]): Organizational unit name
        email_address (Optional[str]): Email address
        country_name (str): Certificate Issuing country
        state_or_province_name (str): Certificate Issuing state or province
        locality_name (str): Certificate Issuing locality

    Returns:
        Certificate: CA Certificate.
    """
    ca_key = serialization.load_pem_private_key(str(private_key).encode(), password=None)
    # Only RSA keys are expected; the isinstance check also narrows the type.
    assert isinstance(ca_key, rsa.RSAPrivateKey)

    name_attributes = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)]
    for attribute_oid, attribute_value in (
        (x509.NameOID.ORGANIZATION_NAME, organization),
        (x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit),
        (x509.NameOID.EMAIL_ADDRESS, email_address),
        (x509.NameOID.COUNTRY_NAME, country_name),
        (x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name),
        (x509.NameOID.LOCALITY_NAME, locality_name),
    ):
        if attribute_value:
            name_attributes.append(x509.NameAttribute(attribute_oid, attribute_value))
    ca_name = x509.Name(name_attributes)

    # Self-signed: one identifier serves as both the SKI and the AKI.
    # NOTE(review): `public_bytes()` returns the DER encoding of the extension
    # value (an OCTET STRING wrapping the digest), not the bare `.digest`.
    # Confirm this is intended.
    key_identifier = x509.SubjectKeyIdentifier.from_public_key(
        ca_key.public_key()
    ).public_bytes()
    ca_key_usage = x509.KeyUsage(
        digital_signature=True,
        content_commitment=False,
        key_encipherment=True,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=True,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )
    authority_key_identifier = x509.AuthorityKeyIdentifier(
        key_identifier=key_identifier,
        authority_cert_issuer=None,
        authority_cert_serial_number=None,
    )

    builder = x509.CertificateBuilder()
    builder = builder.subject_name(ca_name)
    builder = builder.issuer_name(ca_name)
    builder = builder.public_key(ca_key.public_key())
    builder = builder.serial_number(x509.random_serial_number())
    builder = builder.not_valid_before(datetime.now(timezone.utc))
    builder = builder.not_valid_after(datetime.now(timezone.utc) + validity)
    builder = builder.add_extension(
        x509.SubjectKeyIdentifier(digest=key_identifier), critical=False
    )
    builder = builder.add_extension(authority_key_identifier, critical=False)
    builder = builder.add_extension(ca_key_usage, critical=True)
    builder = builder.add_extension(
        x509.BasicConstraints(ca=True, path_length=None), critical=True
    )

    san_extension = _san_extension(
        email_address=email_address,
        sans_dns=sans_dns,
        sans_ip=sans_ip,
        sans_oid=sans_oid,
    )
    if san_extension:
        builder = builder.add_extension(san_extension, critical=False)

    signed = builder.sign(ca_key, hashes.SHA256())  # type: ignore[arg-type]
    pem = signed.public_bytes(serialization.Encoding.PEM).decode().strip()
    return Certificate.from_string(pem)
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
def _san_extension(
    email_address: Optional[str] = None,
    sans_dns: Optional[FrozenSet[str]] = frozenset(),
    sans_ip: Optional[FrozenSet[str]] = frozenset(),
    sans_oid: Optional[FrozenSet[str]] = frozenset(),
) -> Optional[x509.SubjectAlternativeName]:
    """Assemble a SubjectAlternativeName, or return None when there is nothing to list.

    The order is: the e-mail address as an RFC 822 name (always included when
    given), then DNS names, IP addresses and registered OIDs.
    """
    general_names: List[x509.GeneralName] = (
        [x509.RFC822Name(email_address)] if email_address else []
    )
    general_names += [x509.DNSName(dns_name) for dns_name in sans_dns or ()]
    general_names += [x509.IPAddress(ipaddress.ip_address(ip)) for ip in sans_ip or ()]
    general_names += [
        x509.RegisteredID(x509.ObjectIdentifier(oid)) for oid in sans_oid or ()
    ]
    return x509.SubjectAlternativeName(general_names) if general_names else None
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
def generate_certificate(
    csr: CertificateSigningRequest,
    ca: Certificate,
    ca_private_key: PrivateKey,
    validity: timedelta,
    is_ca: bool = False,
) -> Certificate:
    """Generate a TLS certificate based on a CSR, signed by the given CA.

    The new certificate takes its subject and public key from the CSR. Its issuer
    is the CA certificate's subject. It is valid from now for ``validity`` and is
    signed with SHA-256 using ``ca_private_key``. Extensions come from
    ``_generate_certificate_request_extensions``; any the builder rejects with
    ValueError are logged and skipped.

    Args:
        csr (CertificateSigningRequest): CSR
        ca (Certificate): CA Certificate
        ca_private_key (PrivateKey): CA private key
        validity (timedelta): Certificate validity time
        is_ca (bool): Whether the certificate is a CA certificate

    Returns:
        Certificate: Certificate

    Raises:
        x509.ExtensionNotFound: If the CA certificate has no SubjectKeyIdentifier.
    """
    csr_object = x509.load_pem_x509_csr(str(csr).encode())
    ca_certificate = x509.load_pem_x509_certificate(str(ca).encode())
    private_key = serialization.load_pem_private_key(str(ca_private_key).encode(), password=None)

    # The issuer of the new certificate is the signing CA itself, i.e. the CA's
    # *subject*. The CA's own issuer only coincides with it for self-signed roots;
    # with an intermediate CA it produces a certificate whose issuer does not match
    # its signer, and chain validation fails.
    issuer = ca_certificate.subject
    # Evaluate the clock once so the validity window is exactly `validity`.
    not_valid_before = datetime.now(timezone.utc)

    certificate_builder = (
        x509.CertificateBuilder()
        .subject_name(csr_object.subject)
        .issuer_name(issuer)
        .public_key(csr_object.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(not_valid_before)
        .not_valid_after(not_valid_before + validity)
    )
    extensions = _generate_certificate_request_extensions(
        authority_key_identifier=ca_certificate.extensions.get_extension_for_class(
            x509.SubjectKeyIdentifier
        ).value.key_identifier,
        csr=csr_object,
        is_ca=is_ca,
    )
    for extension in extensions:
        try:
            certificate_builder = certificate_builder.add_extension(
                extval=extension.value,
                critical=extension.critical,
            )
        except ValueError as e:
            logger.warning("Failed to add extension %s: %s", extension.oid, e)

    cert = certificate_builder.sign(private_key, hashes.SHA256())  # type: ignore[arg-type]
    cert_bytes = cert.public_bytes(serialization.Encoding.PEM)
    return Certificate.from_string(cert_bytes.decode().strip())
|
|
1030
|
+
|
|
1031
|
+
|
|
1032
|
+
def _generate_certificate_request_extensions(
    authority_key_identifier: bytes,
    csr: x509.CertificateSigningRequest,
    is_ca: bool,
) -> List[x509.Extension]:
    """Work out the extensions of a certificate issued for ``csr``.

    Always included: a non-critical AuthorityKeyIdentifier (from
    ``authority_key_identifier``), a non-critical SubjectKeyIdentifier derived
    from the CSR's public key, and a critical BasicConstraints with
    ``ca=is_ca``. Added when applicable: the SAN extension built from the CSR,
    and, for CAs, a critical KeyUsage limited to certificate and CRL signing.
    Any other CSR extension is copied unless its OID is already in the list
    (a warning is logged). The CSR's own SAN extension is never copied as-is.

    Args:
        authority_key_identifier (bytes): Authority key identifier
        csr (x509.CertificateSigningRequest): CSR
        is_ca (bool): Whether the certificate is a CA certificate

    Returns:
        List[x509.Extension]: List of extensions
    """
    authority_key_id = x509.Extension(
        oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER,
        value=x509.AuthorityKeyIdentifier(
            key_identifier=authority_key_identifier,
            authority_cert_issuer=None,
            authority_cert_serial_number=None,
        ),
        critical=False,
    )
    subject_key_id = x509.Extension(
        oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER,
        value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()),
        critical=False,
    )
    basic_constraints = x509.Extension(
        oid=ExtensionOID.BASIC_CONSTRAINTS,
        critical=True,
        value=x509.BasicConstraints(ca=is_ca, path_length=None),
    )
    extensions: List[x509.Extension] = [authority_key_id, subject_key_id, basic_constraints]

    san = _generate_subject_alternative_name_extension(csr)
    if san:
        extensions.append(san)

    if is_ca:
        ca_key_usage = x509.KeyUsage(
            digital_signature=False,
            content_commitment=False,
            key_encipherment=False,
            data_encipherment=False,
            key_agreement=False,
            key_cert_sign=True,
            crl_sign=True,
            encipher_only=False,
            decipher_only=False,
        )
        extensions.append(x509.Extension(ExtensionOID.KEY_USAGE, critical=True, value=ca_key_usage))

    # Copy the remaining CSR extensions, but never let the CSR override an
    # extension that has already been decided above.
    managed_oids = {extension.oid for extension in extensions}
    for requested in csr.extensions:
        if requested.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME:
            continue
        if requested.oid in managed_oids:
            logger.warning("Extension %s is managed by the TLS provider, ignoring.", requested.oid)
            continue
        extensions.append(requested)

    return extensions
|
|
1100
|
+
|
|
1101
|
+
|
|
1102
|
+
def _generate_subject_alternative_name_extension(
    csr: x509.CertificateSigningRequest,
) -> Optional[x509.Extension]:
    """Derive the SubjectAlternativeName extension of an issued certificate from a CSR.

    Copies DNS names, IP addresses, registered IDs and RFC 822 names from the
    CSR's own SAN extension, in that order. Other general-name kinds are not
    carried over. An e-mail address in the CSR subject is also added as an
    RFC 822 name when it is not already present, to conform to RFC 5280.

    Returns:
        Optional[x509.Extension]: A non-critical SAN extension, or None when there
            is nothing to list.
    """
    general_names: List[x509.GeneralName] = []
    requested_sans: Optional[x509.SubjectAlternativeName] = None
    try:
        requested_sans = csr.extensions.get_extension_for_class(
            x509.SubjectAlternativeName
        ).value
    except x509.ExtensionNotFound:
        pass
    if requested_sans is not None:
        general_names += [
            x509.DNSName(value) for value in requested_sans.get_values_for_type(x509.DNSName)
        ]
        general_names += [
            x509.IPAddress(value) for value in requested_sans.get_values_for_type(x509.IPAddress)
        ]
        general_names += [
            x509.RegisteredID(value)
            for value in requested_sans.get_values_for_type(x509.RegisteredID)
        ]
        general_names += [
            x509.RFC822Name(value)
            for value in requested_sans.get_values_for_type(x509.RFC822Name)
        ]

    # An e-mail address in the subject must also appear in the SANs (RFC 5280).
    subject_emails = csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS)
    if subject_emails:
        subject_email = x509.RFC822Name(str(subject_emails[0].value))
        if subject_email not in general_names:
            general_names.append(subject_email)

    if not general_names:
        return None
    return x509.Extension(
        oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
        critical=False,
        value=x509.SubjectAlternativeName(general_names),
    )
|
|
1145
|
+
|
|
1146
|
+
|
|
1147
|
+
class CertificatesRequirerCharmEvents(CharmEvents):
    """Charm events exposed to charms on the requirer side of TLS Certificates."""

    # Event source for CertificateAvailableEvent instances.
    certificate_available = EventSource(CertificateAvailableEvent)
|
|
1151
|
+
|
|
1152
|
+
|
|
1153
|
+
class TLSCertificatesRequiresV4(Object):
|
|
1154
|
+
"""A class to manage the TLS certificates interface for a unit or app."""
|
|
1155
|
+
|
|
1156
|
+
on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType]
|
|
1157
|
+
|
|
1158
|
+
def __init__(
    self,
    charm: CharmBase,
    relationship_name: str,
    certificate_requests: List[CertificateRequestAttributes],
    mode: Mode = Mode.UNIT,
    refresh_events: Optional[List[BoundEvent]] = None,
    private_key: Optional[PrivateKey] = None,
    renewal_relative_time: float = 0.9,
):
    """Create a new instance of the TLSCertificatesRequiresV4 class.

    Args:
        charm (CharmBase): The charm instance to relate to.
        relationship_name (str): The name of the relation that provides the certificates.
        certificate_requests (List[CertificateRequestAttributes]):
            A list with the attributes of the certificate requests.
        mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT.
            In UNIT mode the requirer will place the csr in the unit relation data.
            Each unit will manage its private key,
            certificate signing request and certificate.
            UNIT mode is for use cases where each unit has its own identity.
            If you don't know which mode to use, you likely need UNIT.
            In APP mode the leader unit will place the csr in the app relation databag.
            APP mode is for use cases where the underlying application needs the certificate
            for example using it as an intermediate CA to sign other certificates.
            The certificate can only be accessed by the leader unit.
        refresh_events (Optional[List[BoundEvent]]): A list of events to trigger a refresh
            of the certificates. Defaults to no additional events.
        private_key (Optional[PrivateKey]): The private key to use for the certificates.
            If provided, it will be used instead of generating a new one.
            If the key is not valid an exception will be raised.
            Using this parameter is discouraged,
            having to pass around private keys manually can be a security concern.
            Allowing the library to generate and manage the key is the more secure approach.
        renewal_relative_time (float): The time to renew the certificate relative to its
            expiry.
            Default is 0.9, meaning 90% of the validity period.
            Accepted values range from 0.5 (50% of the validity period) to 1.0, inclusive.
            If an invalid value is provided, an exception will be raised.

    Raises:
        TLSCertificatesError: If the mode, a certificate request, the private key or
            renewal_relative_time is invalid.
    """
    super().__init__(charm, relationship_name)
    if not JujuVersion.from_environ().has_secrets:
        logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)")
    if not self._mode_is_valid(mode):
        raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP")
    for certificate_request in certificate_requests:
        if not certificate_request.is_valid():
            raise TLSCertificatesError("Invalid certificate request")
    self.charm = charm
    self.relationship_name = relationship_name
    self.certificate_requests = certificate_requests
    self.mode = mode
    if private_key and not private_key.is_valid():
        raise TLSCertificatesError("Invalid private key")
    # Inclusive lower bound, as documented. The old `<= 0.5` check rejected
    # 0.5 itself. The chained form also rejects NaN.
    if not 0.5 <= renewal_relative_time <= 1.0:
        raise TLSCertificatesError(
            "Invalid renewal relative time. Must be between 0.5 and 1.0"
        )
    self._private_key = private_key
    self.renewal_relative_time = renewal_relative_time
    self.framework.observe(charm.on[relationship_name].relation_created, self._configure)
    self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
    self.framework.observe(charm.on.secret_expired, self._on_secret_expired)
    self.framework.observe(charm.on.secret_remove, self._on_secret_remove)
    # `None` instead of a shared mutable `[]` default.
    for event in refresh_events or []:
        self.framework.observe(event, self._configure)
|
|
1225
|
+
|
|
1226
|
+
def _configure(self, _: Optional[EventBase] = None):
    """Reconcile the TLS certificates relation data.

    Runs for the relation's created/changed events, for any configured refresh
    events, and from ``sync``. It does nothing until the TLS relation exists.
    Otherwise it ensures a private key, then calls the request cleanup, request
    sending and certificate lookup helpers, in that order.
    """
    if self._tls_relation_created():
        self._ensure_private_key()
        self._cleanup_certificate_requests()
        self._send_certificate_requests()
        self._find_available_certificates()
    else:
        logger.debug("TLS relation not created yet.")
|
|
1241
|
+
|
|
1242
|
+
def _mode_is_valid(self, mode: Mode) -> bool:
    """Tell whether ``mode`` is one of the supported modes (UNIT or APP)."""
    supported_modes = (Mode.UNIT, Mode.APP)
    return mode in supported_modes
|
|
1244
|
+
|
|
1245
|
+
def _validate_secret_exists(self, secret: Secret) -> None:
    """Fail with ``SecretNotFoundError`` when ``secret`` no longer exists.

    ``get_info`` is called only for that failure mode; its result is discarded.
    """
    _ = secret.get_info()
|
|
1247
|
+
|
|
1248
|
+
def _on_secret_remove(self, event: SecretRemoveEvent) -> None:
    """Remove the secret revision named by a secret-remove event.

    If the secret no longer exists, a warning is logged and nothing else happens.
    """
    try:
        # Check existence first. Removing a revision of a secret that is already
        # gone could leave the unit in an error state; see the docstring of
        # `remove_revision` and https://github.com/juju/juju/issues/19036.
        self._validate_secret_exists(event.secret)
        event.secret.remove_revision(event.revision)
    except SecretNotFoundError:
        missing_secret = event.secret.label or event.secret.id
        logger.warning("No such secret %s, nothing to remove", missing_secret)
|
|
1263
|
+
|
|
1264
|
+
def _on_secret_expired(self, event: SecretExpiredEvent) -> None:
    """Handle Secret Expired Event.

    Only secrets whose label starts with ``f"{LIBID}-certificate"`` are handled.
    The CSR stored under the secret's ``csr`` key is renewed, then all revisions
    of the expired secret are removed. If the CSR cannot be read (the secret is
    inaccessible or has no ``csr`` key), an error is logged and the event skipped.
    """
    if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-certificate"):
        return
    try:
        csr_str = event.secret.get_content(refresh=True)["csr"]
    except (ModelError, KeyError):
        # KeyError: a secret without a "csr" key would otherwise crash the hook.
        logger.error("Failed to get CSR from secret - Skipping")
        return
    csr = CertificateSigningRequest.from_string(csr_str)
    self._renew_certificate_request(csr)
    event.secret.remove_all_revisions()
|
|
1279
|
+
|
|
1280
|
+
def sync(self) -> None:
    """Reconcile the TLS certificates relation data immediately.

    Runs the same processing as the observed events, so a charm does not have
    to wait for one of the refresh events to fire.
    """
    self._configure()
|
|
1287
|
+
|
|
1288
|
+
def renew_certificate(self, certificate: ProviderCertificate) -> None:
    """Request renewal of ``certificate``.

    Renewal proceeds only if the library's CSR secret for this certificate still
    holds exactly the certificate's CSR. Otherwise a warning is logged and
    nothing changes. After renewing, all revisions of that secret are removed.
    """
    csr = certificate.certificate_signing_request
    label = self._get_csr_secret_label(csr)
    try:
        csr_secret = self.model.get_secret(label=label)
    except SecretNotFoundError:
        logger.warning("No matching secret found - Skipping renewal")
        return
    stored_csr = csr_secret.get_content(refresh=True).get("csr", "")
    if stored_csr == str(csr):
        self._renew_certificate_request(csr)
        csr_secret.remove_all_revisions()
    else:
        logger.warning("No matching CSR found - Skipping renewal")
|
|
1303
|
+
|
|
1304
|
+
def _renew_certificate_request(self, csr: CertificateSigningRequest):
    """Replace ``csr`` in the relation data with a fresh request.

    The old CSR is removed from the requirer relation data first, then the
    certificate requests are sent again.
    """
    self._remove_requirer_csr_from_relation_data(csr)
    self._send_certificate_requests()
    logger.info("Renewed certificate request")
|
|
1309
|
+
|
|
1310
|
+
def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None:
    """Remove every entry for ``csr`` from this side's certificate-request relation data.

    The unit or app databag is chosen according to ``self.mode``. Nothing happens
    when the relation is missing, holds no CSRs, or its data fails validation.
    A ModelError while writing back is logged, not raised.
    """
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return
    if not self.get_csrs_from_requirer_relation_data():
        logger.info("No CSRs in relation data - Doing nothing")
        return
    app_or_unit = self._get_app_or_unit()
    try:
        requirer_relation_data = _RequirerData.load(relation.data[app_or_unit])
    except DataValidationError:
        logger.warning("Invalid relation data - Skipping removal of CSR")
        return
    csr_to_remove = str(csr).strip()
    # Build a filtered list instead of removing from the list being iterated.
    # Removing while iterating skips the element after each match, so duplicate
    # entries for the same CSR survived.
    new_relation_data = [
        requirer_csr
        for requirer_csr in requirer_relation_data.certificate_signing_requests
        if requirer_csr.certificate_signing_request.strip() != csr_to_remove
    ]
    try:
        _RequirerData(certificate_signing_requests=new_relation_data).dump(
            relation.data[app_or_unit]
        )
        logger.info("Removed CSR from relation data")
    except ModelError:
        logger.warning("Failed to update relation data")
|
|
1335
|
+
|
|
1336
|
+
def _get_app_or_unit(self) -> Union[Application, Unit]:
    """Return the owner of the relation data for the configured mode.

    APP mode uses the application; UNIT mode uses this unit.

    Raises:
        TLSCertificatesError: If ``self.mode`` is neither mode.
    """
    if self.mode == Mode.APP:
        return self.model.app
    if self.mode == Mode.UNIT:
        return self.model.unit
    raise TLSCertificatesError("Invalid mode")
|
|
1343
|
+
|
|
1344
|
+
@property
def private_key(self) -> Optional[PrivateKey]:
    """The private key in use, or None when none is available yet.

    A key passed by the charm through the ``private_key`` argument wins.
    Otherwise the library-generated key is read from its Juju secret,
    refreshing to the latest revision.
    """
    if self._private_key:
        return self._private_key
    if self._private_key_generated():
        key_secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
        pem = key_secret.get_content(refresh=True)["private-key"]
        return PrivateKey.from_string(pem)
    return None
|
|
1354
|
+
|
|
1355
|
+
def _ensure_private_key(self) -> None:
    """Make sure a private key is available.

    With a charm-supplied key, any library-generated key secret is deleted.
    Otherwise a new key is generated unless one is already stored.
    """
    if self._private_key:
        # The charm-supplied key takes precedence; drop the library-managed one.
        self._remove_private_key_secret()
    elif self._private_key_generated():
        logger.debug("Private key already generated")
    else:
        self._generate_private_key()
|
|
1370
|
+
|
|
1371
|
+
def regenerate_private_key(self) -> None:
    """Replace the library-generated private key with a new one.

    After generating the new key, old certificate requests are cleaned up and
    new ones are sent. If no library-generated key exists yet, only a warning
    is logged.

    Raises:
        TLSCertificatesError: If the private key is passed by the charm using the
            private_key parameter.
    """
    if self._private_key:
        raise TLSCertificatesError(
            "Private key is passed by the charm through the private_key parameter, "
            "this function can't be used"
        )
    if self._private_key_generated():
        self._generate_private_key()
        self._cleanup_certificate_requests()
        self._send_certificate_requests()
    else:
        logger.warning("No private key to regenerate")
|
|
1391
|
+
|
|
1392
|
+
def _generate_private_key(self) -> None:
    """Create a fresh library-managed private key and store it in a Juju secret.

    Only used when the charm did not supply its own key via ``private_key``.
    """
    new_key = generate_private_key()
    self._store_private_key_in_secret(new_key)
    logger.info("Private key generated")
|
|
1400
|
+
|
|
1401
|
+
def _private_key_generated(self) -> bool:
    """Report whether a library-generated private key is stored in a Juju secret.

    Such a secret should not exist when the charm supplies its own key through
    the ``private_key`` parameter.
    """
    try:
        key_secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
        key_secret.get_content(refresh=True)
    except SecretNotFoundError:
        return False
    return True
|
|
1414
|
+
|
|
1415
|
+
def _store_private_key_in_secret(self, private_key: PrivateKey) -> None:
|
|
1416
|
+
try:
|
|
1417
|
+
secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
|
|
1418
|
+
secret.set_content({"private-key": str(private_key)})
|
|
1419
|
+
secret.get_content(refresh=True)
|
|
1420
|
+
except SecretNotFoundError:
|
|
1421
|
+
self.charm.unit.add_secret(
|
|
1422
|
+
content={"private-key": str(private_key)},
|
|
1423
|
+
label=self._get_private_key_secret_label(),
|
|
1424
|
+
)
|
|
1425
|
+
|
|
1426
|
+
def _remove_private_key_secret(self) -> None:
|
|
1427
|
+
"""Remove the private key secret."""
|
|
1428
|
+
try:
|
|
1429
|
+
secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
|
|
1430
|
+
secret.remove_all_revisions()
|
|
1431
|
+
except SecretNotFoundError:
|
|
1432
|
+
logger.warning("Private key secret not found, nothing to remove")
|
|
1433
|
+
|
|
1434
|
+
def _csr_matches_certificate_request(
|
|
1435
|
+
self, certificate_signing_request: CertificateSigningRequest, is_ca: bool
|
|
1436
|
+
) -> bool:
|
|
1437
|
+
for certificate_request in self.certificate_requests:
|
|
1438
|
+
if certificate_request == CertificateRequestAttributes.from_csr(
|
|
1439
|
+
certificate_signing_request,
|
|
1440
|
+
is_ca,
|
|
1441
|
+
):
|
|
1442
|
+
return True
|
|
1443
|
+
return False
|
|
1444
|
+
|
|
1445
|
+
def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool:
|
|
1446
|
+
if not self.private_key:
|
|
1447
|
+
return False
|
|
1448
|
+
csr = self._certificate_requested_for_attributes(certificate_request)
|
|
1449
|
+
if not csr:
|
|
1450
|
+
return False
|
|
1451
|
+
if not csr.certificate_signing_request.matches_private_key(key=self.private_key):
|
|
1452
|
+
return False
|
|
1453
|
+
return True
|
|
1454
|
+
|
|
1455
|
+
def _certificate_requested_for_attributes(
|
|
1456
|
+
self,
|
|
1457
|
+
certificate_request: CertificateRequestAttributes,
|
|
1458
|
+
) -> Optional[RequirerCertificateRequest]:
|
|
1459
|
+
for requirer_csr in self.get_csrs_from_requirer_relation_data():
|
|
1460
|
+
if certificate_request == CertificateRequestAttributes.from_csr(
|
|
1461
|
+
requirer_csr.certificate_signing_request,
|
|
1462
|
+
requirer_csr.is_ca,
|
|
1463
|
+
):
|
|
1464
|
+
return requirer_csr
|
|
1465
|
+
return None
|
|
1466
|
+
|
|
1467
|
+
def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]:
|
|
1468
|
+
"""Return list of requirer's CSRs from relation data."""
|
|
1469
|
+
if self.mode == Mode.APP and not self.model.unit.is_leader():
|
|
1470
|
+
logger.debug("Not a leader unit - Skipping")
|
|
1471
|
+
return []
|
|
1472
|
+
relation = self.model.get_relation(self.relationship_name)
|
|
1473
|
+
if not relation:
|
|
1474
|
+
logger.debug("No relation: %s", self.relationship_name)
|
|
1475
|
+
return []
|
|
1476
|
+
app_or_unit = self._get_app_or_unit()
|
|
1477
|
+
try:
|
|
1478
|
+
requirer_relation_data = _RequirerData.load(relation.data[app_or_unit])
|
|
1479
|
+
except DataValidationError:
|
|
1480
|
+
logger.warning("Invalid relation data")
|
|
1481
|
+
return []
|
|
1482
|
+
requirer_csrs = []
|
|
1483
|
+
for csr in requirer_relation_data.certificate_signing_requests:
|
|
1484
|
+
requirer_csrs.append(
|
|
1485
|
+
RequirerCertificateRequest(
|
|
1486
|
+
relation_id=relation.id,
|
|
1487
|
+
certificate_signing_request=CertificateSigningRequest.from_string(
|
|
1488
|
+
csr.certificate_signing_request
|
|
1489
|
+
),
|
|
1490
|
+
is_ca=csr.ca if csr.ca else False,
|
|
1491
|
+
)
|
|
1492
|
+
)
|
|
1493
|
+
return requirer_csrs
|
|
1494
|
+
|
|
1495
|
+
def get_provider_certificates(self) -> List[ProviderCertificate]:
|
|
1496
|
+
"""Return list of certificates from the provider's relation data."""
|
|
1497
|
+
return self._load_provider_certificates()
|
|
1498
|
+
|
|
1499
|
+
def _load_provider_certificates(self) -> List[ProviderCertificate]:
|
|
1500
|
+
relation = self.model.get_relation(self.relationship_name)
|
|
1501
|
+
if not relation:
|
|
1502
|
+
logger.debug("No relation: %s", self.relationship_name)
|
|
1503
|
+
return []
|
|
1504
|
+
if not relation.app:
|
|
1505
|
+
logger.debug("No remote app in relation: %s", self.relationship_name)
|
|
1506
|
+
return []
|
|
1507
|
+
try:
|
|
1508
|
+
provider_relation_data = _ProviderApplicationData.load(relation.data[relation.app])
|
|
1509
|
+
except DataValidationError:
|
|
1510
|
+
logger.warning("Invalid relation data")
|
|
1511
|
+
return []
|
|
1512
|
+
return [
|
|
1513
|
+
certificate.to_provider_certificate(relation_id=relation.id)
|
|
1514
|
+
for certificate in provider_relation_data.certificates
|
|
1515
|
+
]
|
|
1516
|
+
|
|
1517
|
+
def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None:
|
|
1518
|
+
"""Add CSR to relation data."""
|
|
1519
|
+
if self.mode == Mode.APP and not self.model.unit.is_leader():
|
|
1520
|
+
logger.debug("Not a leader unit - Skipping")
|
|
1521
|
+
return
|
|
1522
|
+
relation = self.model.get_relation(self.relationship_name)
|
|
1523
|
+
if not relation:
|
|
1524
|
+
logger.debug("No relation: %s", self.relationship_name)
|
|
1525
|
+
return
|
|
1526
|
+
new_csr = _CertificateSigningRequest(
|
|
1527
|
+
certificate_signing_request=str(csr).strip(), ca=is_ca
|
|
1528
|
+
)
|
|
1529
|
+
app_or_unit = self._get_app_or_unit()
|
|
1530
|
+
try:
|
|
1531
|
+
requirer_relation_data = _RequirerData.load(relation.data[app_or_unit])
|
|
1532
|
+
except DataValidationError:
|
|
1533
|
+
requirer_relation_data = _RequirerData(
|
|
1534
|
+
certificate_signing_requests=[],
|
|
1535
|
+
)
|
|
1536
|
+
new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests)
|
|
1537
|
+
new_relation_data.append(new_csr)
|
|
1538
|
+
try:
|
|
1539
|
+
_RequirerData(certificate_signing_requests=new_relation_data).dump(
|
|
1540
|
+
relation.data[app_or_unit]
|
|
1541
|
+
)
|
|
1542
|
+
logger.info("Certificate signing request added to relation data.")
|
|
1543
|
+
except ModelError:
|
|
1544
|
+
logger.warning("Failed to update relation data")
|
|
1545
|
+
|
|
1546
|
+
def _send_certificate_requests(self):
|
|
1547
|
+
if not self.private_key:
|
|
1548
|
+
logger.debug("Private key not generated yet.")
|
|
1549
|
+
return
|
|
1550
|
+
for certificate_request in self.certificate_requests:
|
|
1551
|
+
if not self._certificate_requested(certificate_request):
|
|
1552
|
+
csr = certificate_request.generate_csr(
|
|
1553
|
+
private_key=self.private_key,
|
|
1554
|
+
)
|
|
1555
|
+
if not csr:
|
|
1556
|
+
logger.warning("Failed to generate CSR")
|
|
1557
|
+
continue
|
|
1558
|
+
self._request_certificate(csr=csr, is_ca=certificate_request.is_ca)
|
|
1559
|
+
|
|
1560
|
+
def get_assigned_certificate(
|
|
1561
|
+
self, certificate_request: CertificateRequestAttributes
|
|
1562
|
+
) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]:
|
|
1563
|
+
"""Get the certificate that was assigned to the given certificate request."""
|
|
1564
|
+
for requirer_csr in self.get_csrs_from_requirer_relation_data():
|
|
1565
|
+
if certificate_request == CertificateRequestAttributes.from_csr(
|
|
1566
|
+
requirer_csr.certificate_signing_request,
|
|
1567
|
+
requirer_csr.is_ca,
|
|
1568
|
+
):
|
|
1569
|
+
return self._find_certificate_in_relation_data(requirer_csr), self.private_key
|
|
1570
|
+
return None, None
|
|
1571
|
+
|
|
1572
|
+
def get_assigned_certificates(
|
|
1573
|
+
self,
|
|
1574
|
+
) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]:
|
|
1575
|
+
"""Get a list of certificates that were assigned to this or app."""
|
|
1576
|
+
assigned_certificates = []
|
|
1577
|
+
for requirer_csr in self.get_csrs_from_requirer_relation_data():
|
|
1578
|
+
if cert := self._find_certificate_in_relation_data(requirer_csr):
|
|
1579
|
+
assigned_certificates.append(cert)
|
|
1580
|
+
return assigned_certificates, self.private_key
|
|
1581
|
+
|
|
1582
|
+
def _find_certificate_in_relation_data(
|
|
1583
|
+
self, csr: RequirerCertificateRequest
|
|
1584
|
+
) -> Optional[ProviderCertificate]:
|
|
1585
|
+
"""Return the certificate that matches the given CSR, validated against the private key."""
|
|
1586
|
+
if not self.private_key:
|
|
1587
|
+
return None
|
|
1588
|
+
for provider_certificate in self.get_provider_certificates():
|
|
1589
|
+
if provider_certificate.certificate_signing_request == csr.certificate_signing_request:
|
|
1590
|
+
if provider_certificate.certificate.is_ca and not csr.is_ca:
|
|
1591
|
+
logger.warning("Non CA certificate requested, got a CA certificate, ignoring")
|
|
1592
|
+
continue
|
|
1593
|
+
elif not provider_certificate.certificate.is_ca and csr.is_ca:
|
|
1594
|
+
logger.warning("CA certificate requested, got a non CA certificate, ignoring")
|
|
1595
|
+
continue
|
|
1596
|
+
if not provider_certificate.certificate.matches_private_key(self.private_key):
|
|
1597
|
+
logger.warning(
|
|
1598
|
+
"Certificate does not match the private key. Ignoring invalid certificate."
|
|
1599
|
+
)
|
|
1600
|
+
continue
|
|
1601
|
+
return provider_certificate
|
|
1602
|
+
return None
|
|
1603
|
+
|
|
1604
|
+
def _find_available_certificates(self):
|
|
1605
|
+
"""Find available certificates and emit events.
|
|
1606
|
+
|
|
1607
|
+
This method will find certificates that are available for the requirer's CSRs.
|
|
1608
|
+
If a certificate is found, it will be set as a secret and an event will be emitted.
|
|
1609
|
+
If a certificate is revoked, the secret will be removed and an event will be emitted.
|
|
1610
|
+
"""
|
|
1611
|
+
requirer_csrs = self.get_csrs_from_requirer_relation_data()
|
|
1612
|
+
csrs = [csr.certificate_signing_request for csr in requirer_csrs]
|
|
1613
|
+
provider_certificates = self.get_provider_certificates()
|
|
1614
|
+
for provider_certificate in provider_certificates:
|
|
1615
|
+
if provider_certificate.certificate_signing_request in csrs:
|
|
1616
|
+
secret_label = self._get_csr_secret_label(
|
|
1617
|
+
provider_certificate.certificate_signing_request
|
|
1618
|
+
)
|
|
1619
|
+
if provider_certificate.revoked:
|
|
1620
|
+
with suppress(SecretNotFoundError):
|
|
1621
|
+
logger.debug(
|
|
1622
|
+
"Removing secret with label %s",
|
|
1623
|
+
secret_label,
|
|
1624
|
+
)
|
|
1625
|
+
secret = self.model.get_secret(label=secret_label)
|
|
1626
|
+
secret.remove_all_revisions()
|
|
1627
|
+
else:
|
|
1628
|
+
if not self._csr_matches_certificate_request(
|
|
1629
|
+
certificate_signing_request=provider_certificate.certificate_signing_request,
|
|
1630
|
+
is_ca=provider_certificate.certificate.is_ca,
|
|
1631
|
+
):
|
|
1632
|
+
logger.debug("Certificate requested for different attributes - Skipping")
|
|
1633
|
+
continue
|
|
1634
|
+
try:
|
|
1635
|
+
secret = self.model.get_secret(label=secret_label)
|
|
1636
|
+
logger.debug("Setting secret with label %s", secret_label)
|
|
1637
|
+
# Juju < 3.6 will create a new revision even if the content is the same
|
|
1638
|
+
if secret.get_content(refresh=True).get("certificate", "") == str(
|
|
1639
|
+
provider_certificate.certificate
|
|
1640
|
+
):
|
|
1641
|
+
logger.debug(
|
|
1642
|
+
"Secret %s with correct certificate already exists", secret_label
|
|
1643
|
+
)
|
|
1644
|
+
continue
|
|
1645
|
+
secret.set_content(
|
|
1646
|
+
content={
|
|
1647
|
+
"certificate": str(provider_certificate.certificate),
|
|
1648
|
+
"csr": str(provider_certificate.certificate_signing_request),
|
|
1649
|
+
}
|
|
1650
|
+
)
|
|
1651
|
+
secret.set_info(
|
|
1652
|
+
expire=calculate_relative_datetime(
|
|
1653
|
+
target_time=provider_certificate.certificate.expiry_time,
|
|
1654
|
+
fraction=self.renewal_relative_time,
|
|
1655
|
+
),
|
|
1656
|
+
)
|
|
1657
|
+
secret.get_content(refresh=True)
|
|
1658
|
+
except SecretNotFoundError:
|
|
1659
|
+
logger.debug("Creating new secret with label %s", secret_label)
|
|
1660
|
+
secret = self.charm.unit.add_secret(
|
|
1661
|
+
content={
|
|
1662
|
+
"certificate": str(provider_certificate.certificate),
|
|
1663
|
+
"csr": str(provider_certificate.certificate_signing_request),
|
|
1664
|
+
},
|
|
1665
|
+
label=secret_label,
|
|
1666
|
+
expire=calculate_relative_datetime(
|
|
1667
|
+
target_time=provider_certificate.certificate.expiry_time,
|
|
1668
|
+
fraction=self.renewal_relative_time,
|
|
1669
|
+
),
|
|
1670
|
+
)
|
|
1671
|
+
self.on.certificate_available.emit(
|
|
1672
|
+
certificate_signing_request=provider_certificate.certificate_signing_request,
|
|
1673
|
+
certificate=provider_certificate.certificate,
|
|
1674
|
+
ca=provider_certificate.ca,
|
|
1675
|
+
chain=provider_certificate.chain,
|
|
1676
|
+
)
|
|
1677
|
+
|
|
1678
|
+
def _cleanup_certificate_requests(self):
|
|
1679
|
+
"""Clean up certificate requests.
|
|
1680
|
+
|
|
1681
|
+
Remove any certificate requests that falls into one of the following categories:
|
|
1682
|
+
- The CSR attributes do not match any of the certificate requests defined in
|
|
1683
|
+
the charm's certificate_requests attribute.
|
|
1684
|
+
- The CSR public key does not match the private key.
|
|
1685
|
+
"""
|
|
1686
|
+
for requirer_csr in self.get_csrs_from_requirer_relation_data():
|
|
1687
|
+
if not self._csr_matches_certificate_request(
|
|
1688
|
+
certificate_signing_request=requirer_csr.certificate_signing_request,
|
|
1689
|
+
is_ca=requirer_csr.is_ca,
|
|
1690
|
+
):
|
|
1691
|
+
self._remove_requirer_csr_from_relation_data(
|
|
1692
|
+
requirer_csr.certificate_signing_request
|
|
1693
|
+
)
|
|
1694
|
+
logger.info(
|
|
1695
|
+
"Removed CSR from relation data because it did not match any certificate request" # noqa: E501
|
|
1696
|
+
)
|
|
1697
|
+
elif (
|
|
1698
|
+
self.private_key
|
|
1699
|
+
and not requirer_csr.certificate_signing_request.matches_private_key(
|
|
1700
|
+
self.private_key
|
|
1701
|
+
)
|
|
1702
|
+
):
|
|
1703
|
+
self._remove_requirer_csr_from_relation_data(
|
|
1704
|
+
requirer_csr.certificate_signing_request
|
|
1705
|
+
)
|
|
1706
|
+
logger.info(
|
|
1707
|
+
"Removed CSR from relation data because it did not match the private key" # noqa: E501
|
|
1708
|
+
)
|
|
1709
|
+
|
|
1710
|
+
def _tls_relation_created(self) -> bool:
|
|
1711
|
+
relation = self.model.get_relation(self.relationship_name)
|
|
1712
|
+
if not relation:
|
|
1713
|
+
return False
|
|
1714
|
+
return True
|
|
1715
|
+
|
|
1716
|
+
def _get_private_key_secret_label(self) -> str:
|
|
1717
|
+
if self.mode == Mode.UNIT:
|
|
1718
|
+
return f"{LIBID}-private-key-{self._get_unit_number()}-{self.relationship_name}"
|
|
1719
|
+
elif self.mode == Mode.APP:
|
|
1720
|
+
return f"{LIBID}-private-key-{self.relationship_name}"
|
|
1721
|
+
else:
|
|
1722
|
+
raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.")
|
|
1723
|
+
|
|
1724
|
+
def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str:
|
|
1725
|
+
csr_in_sha256_hex = csr.get_sha256_hex()
|
|
1726
|
+
if self.mode == Mode.UNIT:
|
|
1727
|
+
return f"{LIBID}-certificate-{self._get_unit_number()}-{csr_in_sha256_hex}"
|
|
1728
|
+
elif self.mode == Mode.APP:
|
|
1729
|
+
return f"{LIBID}-certificate-{csr_in_sha256_hex}"
|
|
1730
|
+
else:
|
|
1731
|
+
raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.")
|
|
1732
|
+
|
|
1733
|
+
def _get_unit_number(self) -> str:
|
|
1734
|
+
return self.model.unit.name.split("/")[1]
|
|
1735
|
+
|
|
1736
|
+
|
|
1737
|
+
class TLSCertificatesProvidesV4(Object):
|
|
1738
|
+
"""TLS certificates provider class to be instantiated by TLS certificates providers."""
|
|
1739
|
+
|
|
1740
|
+
def __init__(self, charm: CharmBase, relationship_name: str):
|
|
1741
|
+
super().__init__(charm, relationship_name)
|
|
1742
|
+
self.framework.observe(charm.on[relationship_name].relation_joined, self._configure)
|
|
1743
|
+
self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
|
|
1744
|
+
self.framework.observe(charm.on.update_status, self._configure)
|
|
1745
|
+
self.charm = charm
|
|
1746
|
+
self.relationship_name = relationship_name
|
|
1747
|
+
|
|
1748
|
+
def _configure(self, _: EventBase) -> None:
|
|
1749
|
+
"""Handle update status and tls relation changed events.
|
|
1750
|
+
|
|
1751
|
+
This is a common hook triggered on a regular basis.
|
|
1752
|
+
|
|
1753
|
+
Revoke certificates for which no csr exists
|
|
1754
|
+
"""
|
|
1755
|
+
if not self.model.unit.is_leader():
|
|
1756
|
+
return
|
|
1757
|
+
self._remove_certificates_for_which_no_csr_exists()
|
|
1758
|
+
|
|
1759
|
+
def _remove_certificates_for_which_no_csr_exists(self) -> None:
|
|
1760
|
+
provider_certificates = self.get_provider_certificates()
|
|
1761
|
+
requirer_csrs = [
|
|
1762
|
+
request.certificate_signing_request for request in self.get_certificate_requests()
|
|
1763
|
+
]
|
|
1764
|
+
for provider_certificate in provider_certificates:
|
|
1765
|
+
if provider_certificate.certificate_signing_request not in requirer_csrs:
|
|
1766
|
+
tls_relation = self._get_tls_relations(
|
|
1767
|
+
relation_id=provider_certificate.relation_id
|
|
1768
|
+
)
|
|
1769
|
+
self._remove_provider_certificate(
|
|
1770
|
+
certificate=provider_certificate.certificate,
|
|
1771
|
+
relation=tls_relation[0],
|
|
1772
|
+
)
|
|
1773
|
+
|
|
1774
|
+
def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation]:
|
|
1775
|
+
return (
|
|
1776
|
+
[
|
|
1777
|
+
relation
|
|
1778
|
+
for relation in self.model.relations[self.relationship_name]
|
|
1779
|
+
if relation.id == relation_id
|
|
1780
|
+
]
|
|
1781
|
+
if relation_id is not None
|
|
1782
|
+
else self.model.relations.get(self.relationship_name, [])
|
|
1783
|
+
)
|
|
1784
|
+
|
|
1785
|
+
def get_certificate_requests(
|
|
1786
|
+
self, relation_id: Optional[int] = None
|
|
1787
|
+
) -> List[RequirerCertificateRequest]:
|
|
1788
|
+
"""Load certificate requests from the relation data."""
|
|
1789
|
+
relations = self._get_tls_relations(relation_id)
|
|
1790
|
+
requirer_csrs: List[RequirerCertificateRequest] = []
|
|
1791
|
+
for relation in relations:
|
|
1792
|
+
for unit in relation.units:
|
|
1793
|
+
requirer_csrs.extend(self._load_requirer_databag(relation, unit))
|
|
1794
|
+
requirer_csrs.extend(self._load_requirer_databag(relation, relation.app))
|
|
1795
|
+
return requirer_csrs
|
|
1796
|
+
|
|
1797
|
+
def _load_requirer_databag(
|
|
1798
|
+
self, relation: Relation, unit_or_app: Union[Application, Unit]
|
|
1799
|
+
) -> List[RequirerCertificateRequest]:
|
|
1800
|
+
try:
|
|
1801
|
+
requirer_relation_data = _RequirerData.load(relation.data.get(unit_or_app, {}))
|
|
1802
|
+
except DataValidationError:
|
|
1803
|
+
logger.debug("Invalid requirer relation data for %s", unit_or_app.name)
|
|
1804
|
+
return []
|
|
1805
|
+
return [
|
|
1806
|
+
RequirerCertificateRequest(
|
|
1807
|
+
relation_id=relation.id,
|
|
1808
|
+
certificate_signing_request=CertificateSigningRequest.from_string(
|
|
1809
|
+
csr.certificate_signing_request
|
|
1810
|
+
),
|
|
1811
|
+
is_ca=csr.ca if csr.ca else False,
|
|
1812
|
+
)
|
|
1813
|
+
for csr in requirer_relation_data.certificate_signing_requests
|
|
1814
|
+
]
|
|
1815
|
+
|
|
1816
|
+
def _add_provider_certificate(
|
|
1817
|
+
self,
|
|
1818
|
+
relation: Relation,
|
|
1819
|
+
provider_certificate: ProviderCertificate,
|
|
1820
|
+
) -> None:
|
|
1821
|
+
chain = [str(certificate) for certificate in provider_certificate.chain]
|
|
1822
|
+
if chain[0] != str(provider_certificate.certificate):
|
|
1823
|
+
logger.warning(
|
|
1824
|
+
"The order of the chain from the TLS Certificates Provider is incorrect. "
|
|
1825
|
+
"The leaf certificate should be the first element of the chain."
|
|
1826
|
+
)
|
|
1827
|
+
elif not chain_has_valid_order(chain):
|
|
1828
|
+
logger.warning(
|
|
1829
|
+
"The order of the chain from the TLS Certificates Provider is partially incorrect."
|
|
1830
|
+
)
|
|
1831
|
+
new_certificate = _Certificate(
|
|
1832
|
+
certificate=str(provider_certificate.certificate),
|
|
1833
|
+
certificate_signing_request=str(provider_certificate.certificate_signing_request),
|
|
1834
|
+
ca=str(provider_certificate.ca),
|
|
1835
|
+
chain=chain,
|
|
1836
|
+
)
|
|
1837
|
+
provider_certificates = self._load_provider_certificates(relation)
|
|
1838
|
+
if new_certificate in provider_certificates:
|
|
1839
|
+
logger.info("Certificate already in relation data - Doing nothing")
|
|
1840
|
+
return
|
|
1841
|
+
provider_certificates.append(new_certificate)
|
|
1842
|
+
self._dump_provider_certificates(relation=relation, certificates=provider_certificates)
|
|
1843
|
+
|
|
1844
|
+
def _load_provider_certificates(self, relation: Relation) -> List[_Certificate]:
|
|
1845
|
+
try:
|
|
1846
|
+
provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app])
|
|
1847
|
+
except DataValidationError:
|
|
1848
|
+
logger.debug("Invalid provider relation data")
|
|
1849
|
+
return []
|
|
1850
|
+
return copy.deepcopy(provider_relation_data.certificates)
|
|
1851
|
+
|
|
1852
|
+
def _dump_provider_certificates(self, relation: Relation, certificates: List[_Certificate]):
|
|
1853
|
+
try:
|
|
1854
|
+
_ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app])
|
|
1855
|
+
logger.info("Certificate relation data updated")
|
|
1856
|
+
except ModelError:
|
|
1857
|
+
logger.warning("Failed to update relation data")
|
|
1858
|
+
|
|
1859
|
+
def _remove_provider_certificate(
|
|
1860
|
+
self,
|
|
1861
|
+
relation: Relation,
|
|
1862
|
+
certificate: Optional[Certificate] = None,
|
|
1863
|
+
certificate_signing_request: Optional[CertificateSigningRequest] = None,
|
|
1864
|
+
) -> None:
|
|
1865
|
+
"""Remove certificate based on certificate or certificate signing request."""
|
|
1866
|
+
provider_certificates = self._load_provider_certificates(relation)
|
|
1867
|
+
for provider_certificate in provider_certificates:
|
|
1868
|
+
if certificate and provider_certificate.certificate == str(certificate):
|
|
1869
|
+
provider_certificates.remove(provider_certificate)
|
|
1870
|
+
if (
|
|
1871
|
+
certificate_signing_request
|
|
1872
|
+
and provider_certificate.certificate_signing_request
|
|
1873
|
+
== str(certificate_signing_request)
|
|
1874
|
+
):
|
|
1875
|
+
provider_certificates.remove(provider_certificate)
|
|
1876
|
+
self._dump_provider_certificates(relation=relation, certificates=provider_certificates)
|
|
1877
|
+
|
|
1878
|
+
def revoke_all_certificates(self) -> None:
|
|
1879
|
+
"""Revoke all certificates of this provider.
|
|
1880
|
+
|
|
1881
|
+
This method is meant to be used when the Root CA has changed.
|
|
1882
|
+
"""
|
|
1883
|
+
if not self.model.unit.is_leader():
|
|
1884
|
+
logger.warning("Unit is not a leader - will not set relation data")
|
|
1885
|
+
return
|
|
1886
|
+
relations = self._get_tls_relations()
|
|
1887
|
+
for relation in relations:
|
|
1888
|
+
provider_certificates = self._load_provider_certificates(relation)
|
|
1889
|
+
for certificate in provider_certificates:
|
|
1890
|
+
certificate.revoked = True
|
|
1891
|
+
self._dump_provider_certificates(relation=relation, certificates=provider_certificates)
|
|
1892
|
+
|
|
1893
|
+
def set_relation_certificate(
|
|
1894
|
+
self,
|
|
1895
|
+
provider_certificate: ProviderCertificate,
|
|
1896
|
+
) -> None:
|
|
1897
|
+
"""Add certificates to relation data.
|
|
1898
|
+
|
|
1899
|
+
Args:
|
|
1900
|
+
provider_certificate (ProviderCertificate): ProviderCertificate object
|
|
1901
|
+
|
|
1902
|
+
Returns:
|
|
1903
|
+
None
|
|
1904
|
+
"""
|
|
1905
|
+
if not self.model.unit.is_leader():
|
|
1906
|
+
logger.warning("Unit is not a leader - will not set relation data")
|
|
1907
|
+
return
|
|
1908
|
+
certificates_relation = self.model.get_relation(
|
|
1909
|
+
relation_name=self.relationship_name, relation_id=provider_certificate.relation_id
|
|
1910
|
+
)
|
|
1911
|
+
if not certificates_relation:
|
|
1912
|
+
raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist")
|
|
1913
|
+
self._remove_provider_certificate(
|
|
1914
|
+
relation=certificates_relation,
|
|
1915
|
+
certificate_signing_request=provider_certificate.certificate_signing_request,
|
|
1916
|
+
)
|
|
1917
|
+
self._add_provider_certificate(
|
|
1918
|
+
relation=certificates_relation,
|
|
1919
|
+
provider_certificate=provider_certificate,
|
|
1920
|
+
)
|
|
1921
|
+
|
|
1922
|
+
def get_issued_certificates(
|
|
1923
|
+
self, relation_id: Optional[int] = None
|
|
1924
|
+
) -> List[ProviderCertificate]:
|
|
1925
|
+
"""Return a List of issued (non revoked) certificates.
|
|
1926
|
+
|
|
1927
|
+
Returns:
|
|
1928
|
+
List: List of ProviderCertificate objects
|
|
1929
|
+
"""
|
|
1930
|
+
if not self.model.unit.is_leader():
|
|
1931
|
+
logger.warning("Unit is not a leader - will not read relation data")
|
|
1932
|
+
return []
|
|
1933
|
+
provider_certificates = self.get_provider_certificates(relation_id=relation_id)
|
|
1934
|
+
return [certificate for certificate in provider_certificates if not certificate.revoked]
|
|
1935
|
+
|
|
1936
|
+
def get_provider_certificates(
|
|
1937
|
+
self, relation_id: Optional[int] = None
|
|
1938
|
+
) -> List[ProviderCertificate]:
|
|
1939
|
+
"""Return a List of issued certificates."""
|
|
1940
|
+
certificates: List[ProviderCertificate] = []
|
|
1941
|
+
relations = self._get_tls_relations(relation_id)
|
|
1942
|
+
for relation in relations:
|
|
1943
|
+
if not relation.app:
|
|
1944
|
+
logger.warning("Relation %s does not have an application", relation.id)
|
|
1945
|
+
continue
|
|
1946
|
+
for certificate in self._load_provider_certificates(relation):
|
|
1947
|
+
certificates.append(certificate.to_provider_certificate(relation_id=relation.id))
|
|
1948
|
+
return certificates
|
|
1949
|
+
|
|
1950
|
+
def get_unsolicited_certificates(
|
|
1951
|
+
self, relation_id: Optional[int] = None
|
|
1952
|
+
) -> List[ProviderCertificate]:
|
|
1953
|
+
"""Return provider certificates for which no certificate requests exists.
|
|
1954
|
+
|
|
1955
|
+
Those certificates should be revoked.
|
|
1956
|
+
"""
|
|
1957
|
+
unsolicited_certificates: List[ProviderCertificate] = []
|
|
1958
|
+
provider_certificates = self.get_provider_certificates(relation_id=relation_id)
|
|
1959
|
+
requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
|
|
1960
|
+
list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs]
|
|
1961
|
+
for certificate in provider_certificates:
|
|
1962
|
+
if certificate.certificate_signing_request not in list_of_csrs:
|
|
1963
|
+
unsolicited_certificates.append(certificate)
|
|
1964
|
+
return unsolicited_certificates
|
|
1965
|
+
|
|
1966
|
+
def get_outstanding_certificate_requests(
|
|
1967
|
+
self, relation_id: Optional[int] = None
|
|
1968
|
+
) -> List[RequirerCertificateRequest]:
|
|
1969
|
+
"""Return CSR's for which no certificate has been issued.
|
|
1970
|
+
|
|
1971
|
+
Args:
|
|
1972
|
+
relation_id (int): Relation id
|
|
1973
|
+
|
|
1974
|
+
Returns:
|
|
1975
|
+
list: List of RequirerCertificateRequest objects.
|
|
1976
|
+
"""
|
|
1977
|
+
requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
|
|
1978
|
+
outstanding_csrs: List[RequirerCertificateRequest] = []
|
|
1979
|
+
for relation_csr in requirer_csrs:
|
|
1980
|
+
if not self._certificate_issued_for_csr(
|
|
1981
|
+
csr=relation_csr.certificate_signing_request,
|
|
1982
|
+
relation_id=relation_id,
|
|
1983
|
+
):
|
|
1984
|
+
outstanding_csrs.append(relation_csr)
|
|
1985
|
+
return outstanding_csrs
|
|
1986
|
+
|
|
1987
|
+
def _certificate_issued_for_csr(
|
|
1988
|
+
self, csr: CertificateSigningRequest, relation_id: Optional[int]
|
|
1989
|
+
) -> bool:
|
|
1990
|
+
"""Check whether a certificate has been issued for a given CSR."""
|
|
1991
|
+
issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)
|
|
1992
|
+
for issued_certificate in issued_certificates_per_csr:
|
|
1993
|
+
if issued_certificate.certificate_signing_request == csr:
|
|
1994
|
+
return csr.matches_certificate(issued_certificate.certificate)
|
|
1995
|
+
return False
|