mongo-charms-single-kernel 1.8.7__py3-none-any.whl → 1.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mongo-charms-single-kernel might be problematic. Click here for more details.

Files changed (31):
  1. {mongo_charms_single_kernel-1.8.7.dist-info → mongo_charms_single_kernel-1.8.9.dist-info}/METADATA +1 -1
  2. {mongo_charms_single_kernel-1.8.7.dist-info → mongo_charms_single_kernel-1.8.9.dist-info}/RECORD +30 -28
  3. single_kernel_mongo/config/literals.py +8 -3
  4. single_kernel_mongo/config/models.py +12 -0
  5. single_kernel_mongo/config/relations.py +2 -1
  6. single_kernel_mongo/config/statuses.py +127 -20
  7. single_kernel_mongo/core/operator.py +68 -1
  8. single_kernel_mongo/core/structured_config.py +2 -0
  9. single_kernel_mongo/core/workload.py +10 -4
  10. single_kernel_mongo/events/cluster.py +5 -0
  11. single_kernel_mongo/events/sharding.py +3 -1
  12. single_kernel_mongo/events/tls.py +183 -157
  13. single_kernel_mongo/exceptions.py +0 -8
  14. single_kernel_mongo/lib/charms/operator_libs_linux/v1/systemd.py +288 -0
  15. single_kernel_mongo/lib/charms/tls_certificates_interface/v4/tls_certificates.py +1995 -0
  16. single_kernel_mongo/managers/cluster.py +70 -28
  17. single_kernel_mongo/managers/config.py +14 -8
  18. single_kernel_mongo/managers/mongo.py +1 -1
  19. single_kernel_mongo/managers/mongodb_operator.py +53 -56
  20. single_kernel_mongo/managers/mongos_operator.py +18 -20
  21. single_kernel_mongo/managers/sharding.py +154 -127
  22. single_kernel_mongo/managers/tls.py +223 -206
  23. single_kernel_mongo/state/charm_state.py +39 -16
  24. single_kernel_mongo/state/cluster_state.py +8 -0
  25. single_kernel_mongo/state/config_server_state.py +9 -0
  26. single_kernel_mongo/state/tls_state.py +39 -12
  27. single_kernel_mongo/templates/enable-transparent-huge-pages.service.j2 +14 -0
  28. single_kernel_mongo/utils/helpers.py +4 -19
  29. single_kernel_mongo/lib/charms/tls_certificates_interface/v3/tls_certificates.py +0 -2123
  30. {mongo_charms_single_kernel-1.8.7.dist-info → mongo_charms_single_kernel-1.8.9.dist-info}/WHEEL +0 -0
  31. {mongo_charms_single_kernel-1.8.7.dist-info → mongo_charms_single_kernel-1.8.9.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1995 @@
1
+ # Copyright 2024 Canonical Ltd.
2
+ # See LICENSE file for licensing details.
3
+
4
+ """Charm library for managing TLS certificates (V4).
5
+
6
+ This library contains the Requires and Provides classes for handling the tls-certificates
7
+ interface.
8
+
9
+ Pre-requisites:
10
+ - Juju >= 3.0
11
+ - cryptography >= 43.0.0
12
+ - pydantic >= 1.0
13
+
14
+ Learn more on how-to use the TLS Certificates interface library by reading the documentation:
15
+ - https://charmhub.io/tls-certificates-interface/
16
+
17
+ """ # noqa: D214, D405, D411, D416
18
+
19
+ import copy
20
+ import ipaddress
21
+ import json
22
+ import logging
23
+ import uuid
24
+ from contextlib import suppress
25
+ from dataclasses import dataclass
26
+ from datetime import datetime, timedelta, timezone
27
+ from enum import Enum
28
+ from typing import FrozenSet, List, MutableMapping, Optional, Tuple, Union
29
+
30
+ import pydantic
31
+ from cryptography import x509
32
+ from cryptography.exceptions import InvalidSignature
33
+ from cryptography.hazmat.primitives import hashes, serialization
34
+ from cryptography.hazmat.primitives.asymmetric import rsa
35
+ from cryptography.x509.oid import ExtensionOID, NameOID
36
+ from ops import BoundEvent, CharmBase, CharmEvents, Secret, SecretExpiredEvent, SecretRemoveEvent
37
+ from ops.framework import EventBase, EventSource, Handle, Object
38
+ from ops.jujuversion import JujuVersion
39
+ from ops.model import Application, ModelError, Relation, SecretNotFoundError, Unit
40
+
41
# Charmhub library metadata.
# LIBID is the unique Charmhub identifier of this library; never change it.
LIBID = "afd8c2bccf834997afce12c2706d2ede"
# LIBAPI is the major API version; raise it only for breaking changes.
LIBAPI = 4
# LIBPATCH must be bumped before every `charmcraft publish-lib`, and reset to 0
# whenever LIBAPI is raised.
LIBPATCH = 22

# Python packages this library needs at runtime (see "Pre-requisites" above).
PYDEPS = ["cryptography>=43.0.0", "pydantic"]

# The databag models support both pydantic major versions; pick the code path
# from the leading component of the installed version string.
IS_PYDANTIC_V1 = int(pydantic.version.VERSION.split(".", maxsplit=1)[0]) < 2

logger = logging.getLogger(__name__)
58
+
59
+
60
class TLSCertificatesError(Exception):
    """Root of the exception hierarchy used by this library."""


class DataValidationError(TLSCertificatesError):
    """Signals that relation databag contents could not be decoded or validated."""
66
+
67
+
68
class _DatabagModel(pydantic.BaseModel):
    """Base databag model.

    By default each model field is stored in the databag as its own
    JSON-encoded string, keyed by the field alias (or name). When the
    `_NEST_UNDER` setting is non-empty, the whole model is instead stored as a
    single JSON string under that key.

    Supports both pydantic v1 and v2.
    """

    if IS_PYDANTIC_V1:

        class Config:
            """Pydantic config."""

            # Ignore any extra fields in the databag.
            extra = "ignore"
            # Allow instantiating this class by field name (instead of forcing alias).
            allow_population_by_field_name = True

        _NEST_UNDER = None

    # Pydantic v2 config.
    model_config = pydantic.ConfigDict(
        # tolerate additional keys in databag
        extra="ignore",
        # Allow instantiating this class by field name (instead of forcing alias).
        populate_by_name=True,
        # Custom config key: whether to nest the whole datastructure (as json)
        # under a field or spread it out at the toplevel.
        _NEST_UNDER=None,
    )  # type: ignore

    @classmethod
    def load(cls, databag: MutableMapping):
        """Load this model from a Juju databag.

        Args:
            databag: Mapping of keys to JSON-encoded values. Keys that do not
                correspond to a model field are ignored.

        Returns:
            An instance of the model.

        Raises:
            DataValidationError: (non-nested layout) if a relevant value is not
                valid JSON, or if the decoded data fails model validation.
        """
        if IS_PYDANTIC_V1:
            return cls._load_v1(databag)
        nest_under = cls.model_config.get("_NEST_UNDER")
        if nest_under:
            return cls.model_validate(json.loads(databag[nest_under]))

        # Build the set of accepted keys once, not once per databag entry.
        known_keys = {(f.alias or n) for n, f in cls.model_fields.items()}
        try:
            data = {
                k: json.loads(v)
                for k, v in databag.items()
                # Don't attempt to parse model-external values
                if k in known_keys
            }
        except json.JSONDecodeError as e:
            msg = f"invalid databag contents: expecting json. {databag}"
            logger.error(msg)
            raise DataValidationError(msg) from e

        try:
            return cls.model_validate_json(json.dumps(data))
        except pydantic.ValidationError as e:
            msg = f"failed to validate databag: {databag}"
            logger.debug(msg, exc_info=True)
            raise DataValidationError(msg) from e

    @classmethod
    def _load_v1(cls, databag: MutableMapping):
        """Load implementation for pydantic v1 (same contract as `load`)."""
        if cls._NEST_UNDER:
            return cls.parse_obj(json.loads(databag[cls._NEST_UNDER]))

        # Build the set of accepted keys once, not once per databag entry.
        known_keys = {f.alias for f in cls.__fields__.values()}
        try:
            data = {
                k: json.loads(v)
                for k, v in databag.items()
                # Don't attempt to parse model-external values
                if k in known_keys
            }
        except json.JSONDecodeError as e:
            msg = f"invalid databag contents: expecting json. {databag}"
            logger.error(msg)
            raise DataValidationError(msg) from e

        try:
            return cls.parse_raw(json.dumps(data))  # type: ignore
        except pydantic.ValidationError as e:
            msg = f"failed to validate databag: {databag}"
            logger.debug(msg, exc_info=True)
            raise DataValidationError(msg) from e

    def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True):
        """Write the contents of this model to Juju databag.

        Fields whose values equal their defaults are not written.

        Args:
            databag: The databag to write to; a new dict is used when None.
            clear: Whether to clear the databag before writing.

        Returns:
            MutableMapping: The databag.
        """
        if IS_PYDANTIC_V1:
            return self._dump_v1(databag, clear)
        if clear and databag:
            databag.clear()

        if databag is None:
            databag = {}
        nest_under = self.model_config.get("_NEST_UNDER")
        if nest_under:
            databag[nest_under] = self.model_dump_json(
                by_alias=True,
                # skip keys whose values are default
                exclude_defaults=True,
            )
            return databag

        dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True)
        databag.update({k: json.dumps(v) for k, v in dct.items()})
        return databag

    def _dump_v1(self, databag: Optional[MutableMapping] = None, clear: bool = True):
        """Dump implementation for pydantic v1 (same contract as `dump`)."""
        if clear and databag:
            databag.clear()

        if databag is None:
            databag = {}

        if self._NEST_UNDER:
            databag[self._NEST_UNDER] = self.json(by_alias=True, exclude_defaults=True)
            return databag

        dct = json.loads(self.json(by_alias=True, exclude_defaults=True))
        databag.update({k: json.dumps(v) for k, v in dct.items()})

        return databag
197
+
198
+
199
class _Certificate(pydantic.BaseModel):
    """Wire format of one certificate entry in the provider's databag.

    All certificate material is carried as PEM text; `chain` and `revoked`
    are optional.
    """

    ca: str
    certificate_signing_request: str
    certificate: str
    chain: Optional[List[str]] = None
    revoked: Optional[bool] = None

    def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate":
        """Parse this entry's PEM strings into a ProviderCertificate.

        Args:
            relation_id: Identifier of the relation the entry was read from.

        Returns:
            ProviderCertificate: the parsed entry; a missing or empty chain
            becomes an empty list.
        """
        leaf = Certificate.from_string(self.certificate)
        csr = CertificateSigningRequest.from_string(self.certificate_signing_request)
        issuing_ca = Certificate.from_string(self.ca)
        parsed_chain = [Certificate.from_string(pem) for pem in self.chain] if self.chain else []
        return ProviderCertificate(
            relation_id=relation_id,
            certificate=leaf,
            certificate_signing_request=csr,
            ca=issuing_ca,
            chain=parsed_chain,
            revoked=self.revoked,
        )
222
+
223
+
224
class _CertificateSigningRequest(pydantic.BaseModel):
    """Wire format of one CSR entry in the requirer's databag.

    `ca` carries an explicit None default. Pydantic v1 treats a bare
    `Optional[...]` field as optional with a None default, whereas pydantic v2
    treats it as required. Without the default, an entry lacking "ca" would
    load under v1 but be rejected under v2.
    """

    certificate_signing_request: str
    ca: Optional[bool] = None
229
+
230
+
231
class _ProviderApplicationData(_DatabagModel):
    """Shape of the provider's application databag.

    Lists every certificate entry the provider publishes on the relation; an
    absent `certificates` key loads as an empty list.
    """

    # JSON-encoded list of certificate entries.
    certificates: List[_Certificate] = []
235
+
236
+
237
class _RequirerData(_DatabagModel):
    """Shape of the requirer's databag.

    Shared by the unit databag and the application databag; an absent
    `certificate_signing_requests` key loads as an empty list.
    """

    # JSON-encoded list of CSR entries.
    certificate_signing_requests: List[_CertificateSigningRequest] = []
244
+
245
+
246
class Mode(Enum):
    """Scope at which a certificate is requested.

    UNIT (the default): every unit manages its own private key, certificate
        signing request and certificate.
    APP: only the leader unit manages the private key, certificate signing
        request and certificate, on behalf of the application.
    """

    UNIT = 1  # per-unit certificate
    APP = 2  # application-wide certificate, managed by the leader
259
+
260
+
261
@dataclass(frozen=True)
class PrivateKey:
    """An immutable wrapper around a PEM-encoded private key."""

    raw: str

    def __str__(self):
        """Return the private key as a PEM string."""
        return self.raw

    @classmethod
    def from_string(cls, private_key: str) -> "PrivateKey":
        """Create a PrivateKey from PEM text, stripping surrounding whitespace."""
        return cls(raw=private_key.strip())

    def is_valid(self) -> bool:
        """Validate that the private key is PEM-formatted, RSA, and at least 2048 bits.

        Returns:
            bool: True if the key loads without a password, is an RSA key and
            is at least 2048 bits long; False otherwise (a warning is logged).
        """
        try:
            key = serialization.load_pem_private_key(
                self.raw.encode(),
                password=None,
            )
        except (ValueError, TypeError):
            # ValueError: the PEM data cannot be decoded.
            # TypeError: the key is password-protected; no password is
            # supplied here, so such a key is unusable.
            logger.warning("Invalid private key format")
            return False

        if not isinstance(key, rsa.RSAPrivateKey):
            logger.warning("Private key is not an RSA key")
            return False

        if key.key_size < 2048:
            logger.warning("RSA key size is less than 2048 bits")
            return False

        return True
296
+
297
+
298
@dataclass(frozen=True)
class Certificate:
    """A parsed X.509 certificate together with its PEM text."""

    raw: str
    common_name: str
    expiry_time: datetime
    validity_start_time: datetime
    is_ca: bool = False
    sans_dns: Optional[FrozenSet[str]] = frozenset()
    sans_ip: Optional[FrozenSet[str]] = frozenset()
    sans_oid: Optional[FrozenSet[str]] = frozenset()
    email_address: Optional[str] = None
    organization: Optional[str] = None
    organizational_unit: Optional[str] = None
    country_name: Optional[str] = None
    state_or_province_name: Optional[str] = None
    locality_name: Optional[str] = None

    def __str__(self) -> str:
        """Return the certificate as a PEM string."""
        return self.raw

    @classmethod
    def from_string(cls, certificate: str) -> "Certificate":
        """Create a Certificate object by parsing a PEM-encoded certificate.

        Args:
            certificate (str): PEM-encoded X.509 certificate.

        Returns:
            Certificate: the parsed certificate, with `raw` set to the stripped
            PEM text. Absent subject attributes are None, except
            `common_name`, which is "" when the subject has no CN. OID SANs
            are stored as dotted strings (e.g. "1.2.3.4").

        Raises:
            TLSCertificatesError: if the PEM data cannot be loaded.
        """
        try:
            certificate_object = x509.load_pem_x509_certificate(data=certificate.encode())
        except ValueError as e:
            logger.error("Could not load certificate: %s", e)
            raise TLSCertificatesError("Could not load certificate") from e

        subject = certificate_object.subject

        def _subject_value(oid) -> Optional[str]:
            # First value of the given subject attribute, or None if absent.
            attributes = subject.get_attributes_for_oid(oid)
            return str(attributes[0].value) if attributes else None

        sans_dns: List[str] = []
        sans_ip: List[str] = []
        sans_oid: List[str] = []
        try:
            sans = certificate_object.extensions.get_extension_for_class(
                x509.SubjectAlternativeName
            ).value
        except x509.ExtensionNotFound:
            logger.debug("No SANs found in certificate")
        else:
            for san in sans:
                if isinstance(san, x509.DNSName):
                    sans_dns.append(san.value)
                elif isinstance(san, x509.IPAddress):
                    sans_ip.append(str(san.value))
                elif isinstance(san, x509.RegisteredID):
                    # Use the dotted form, as CertificateSigningRequest does;
                    # str() of an ObjectIdentifier is its repr
                    # ("<ObjectIdentifier(oid=..., name=...)>"), which never
                    # matches the values requested in a CSR.
                    sans_oid.append(san.value.dotted_string)

        is_ca = False
        with suppress(x509.ExtensionNotFound):
            is_ca = certificate_object.extensions.get_extension_for_oid(
                ExtensionOID.BASIC_CONSTRAINTS
            ).value.ca  # type: ignore[reportAttributeAccessIssue]

        return cls(
            raw=certificate.strip(),
            # A certificate may identify its subject only through SANs; use ""
            # rather than failing with IndexError when there is no CN.
            common_name=_subject_value(NameOID.COMMON_NAME) or "",
            is_ca=is_ca,
            country_name=_subject_value(NameOID.COUNTRY_NAME),
            state_or_province_name=_subject_value(NameOID.STATE_OR_PROVINCE_NAME),
            locality_name=_subject_value(NameOID.LOCALITY_NAME),
            organization=_subject_value(NameOID.ORGANIZATION_NAME),
            organizational_unit=_subject_value(NameOID.ORGANIZATIONAL_UNIT_NAME),
            email_address=_subject_value(NameOID.EMAIL_ADDRESS),
            sans_dns=frozenset(sans_dns),
            sans_ip=frozenset(sans_ip),
            sans_oid=frozenset(sans_oid),
            expiry_time=certificate_object.not_valid_after_utc,
            validity_start_time=certificate_object.not_valid_before_utc,
        )

    def matches_private_key(self, private_key: "PrivateKey") -> bool:
        """Check if this certificate matches a given private key.

        Args:
            private_key (PrivateKey): The private key to validate against.

        Returns:
            bool: True if both are RSA and their public numbers are equal;
            False otherwise, including when either cannot be loaded.
        """
        try:
            cert_object = x509.load_pem_x509_certificate(self.raw.encode())
            key_object = serialization.load_pem_private_key(
                private_key.raw.encode(), password=None
            )

            cert_public_key = cert_object.public_key()
            key_public_key = key_object.public_key()

            if not isinstance(cert_public_key, rsa.RSAPublicKey):
                logger.warning("Certificate does not use RSA public key")
                return False

            if not isinstance(key_public_key, rsa.RSAPublicKey):
                logger.warning("Private key is not an RSA key")
                return False

            return cert_public_key.public_numbers() == key_public_key.public_numbers()
        except Exception as e:
            # Best effort: any load/parse failure simply means "no match".
            logger.warning("Failed to validate certificate and private key match: %s", e)
            return False
421
+
422
+
423
@dataclass(frozen=True)
class CertificateSigningRequest:
    """A parsed certificate signing request together with its PEM text."""

    raw: str
    common_name: str
    sans_dns: Optional[FrozenSet[str]] = None
    sans_ip: Optional[FrozenSet[str]] = None
    sans_oid: Optional[FrozenSet[str]] = None
    email_address: Optional[str] = None
    organization: Optional[str] = None
    organizational_unit: Optional[str] = None
    country_name: Optional[str] = None
    state_or_province_name: Optional[str] = None
    locality_name: Optional[str] = None
    has_unique_identifier: bool = False

    def __eq__(self, other: object) -> bool:
        """Two CSRs are equal when their whitespace-stripped PEM texts are equal."""
        if not isinstance(other, CertificateSigningRequest):
            return NotImplemented
        return self.raw.strip() == other.raw.strip()

    def __hash__(self) -> int:
        """Hash on the same key as __eq__.

        Without this, the dataclass would generate a hash over every field,
        so two CSRs that compare equal could hash differently.
        """
        return hash(self.raw.strip())

    def __str__(self) -> str:
        """Return the CSR as a PEM string."""
        return self.raw

    @classmethod
    def from_string(cls, csr: str) -> "CertificateSigningRequest":
        """Create a CertificateSigningRequest object by parsing a PEM-encoded CSR.

        Args:
            csr (str): PEM-encoded certificate signing request.

        Returns:
            CertificateSigningRequest: the parsed CSR, with `raw` set to the
            stripped PEM text. Absent subject attributes are None, except
            `common_name`, which is "" when the subject has no CN.

        Raises:
            TLSCertificatesError: if the PEM data cannot be loaded.
        """
        try:
            csr_object = x509.load_pem_x509_csr(csr.encode())
        except ValueError as e:
            logger.error("Could not load CSR: %s", e)
            raise TLSCertificatesError("Could not load CSR") from e

        subject = csr_object.subject

        def _subject_value(oid) -> Optional[str]:
            # First value of the given subject attribute, or None if absent.
            attributes = subject.get_attributes_for_oid(oid)
            return str(attributes[0].value) if attributes else None

        try:
            sans = csr_object.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
        except x509.ExtensionNotFound:
            sans_dns = frozenset()
            sans_ip = frozenset()
            sans_oid = frozenset()
        else:
            sans_dns = frozenset(sans.get_values_for_type(x509.DNSName))
            sans_ip = frozenset([str(san) for san in sans.get_values_for_type(x509.IPAddress)])
            sans_oid = frozenset(
                [san.dotted_string for san in sans.get_values_for_type(x509.RegisteredID)]
            )
        return cls(
            raw=csr.strip(),
            # A requirer may send a CSR whose subject has no CN; use "" rather
            # than failing with IndexError.
            common_name=_subject_value(NameOID.COMMON_NAME) or "",
            country_name=_subject_value(NameOID.COUNTRY_NAME),
            state_or_province_name=_subject_value(NameOID.STATE_OR_PROVINCE_NAME),
            locality_name=_subject_value(NameOID.LOCALITY_NAME),
            organization=_subject_value(NameOID.ORGANIZATION_NAME),
            organizational_unit=_subject_value(NameOID.ORGANIZATIONAL_UNIT_NAME),
            email_address=_subject_value(NameOID.EMAIL_ADDRESS),
            sans_dns=sans_dns,
            sans_ip=sans_ip,
            sans_oid=sans_oid,
            has_unique_identifier=bool(
                subject.get_attributes_for_oid(NameOID.X500_UNIQUE_IDENTIFIER)
            ),
        )

    def matches_private_key(self, key: "PrivateKey") -> bool:
        """Check if a CSR matches a private key.

        This function only works with RSA keys.

        Args:
            key (PrivateKey): Private key
        Returns:
            bool: True if both are RSA and their public numbers (modulus and
            exponent) are equal; False otherwise, including when either
            cannot be loaded.
        """
        try:
            csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8"))
            key_object = serialization.load_pem_private_key(
                data=key.raw.encode("utf-8"), password=None
            )
            key_object_public_key = key_object.public_key()
            csr_object_public_key = csr_object.public_key()
            if not isinstance(key_object_public_key, rsa.RSAPublicKey):
                logger.warning("Key is not an RSA key")
                return False
            if not isinstance(csr_object_public_key, rsa.RSAPublicKey):
                logger.warning("CSR is not an RSA key")
                return False
            # Compare the full public numbers, as Certificate.matches_private_key does.
            if csr_object_public_key.public_numbers() != key_object_public_key.public_numbers():
                logger.warning("Public key numbers between CSR and key do not match")
                return False
        except (ValueError, TypeError):
            # ValueError: malformed PEM. TypeError: the key is password-protected
            # (no password is supplied here).
            logger.warning("Could not load certificate or CSR.")
            return False
        return True

    def matches_certificate(self, certificate: "Certificate") -> bool:
        """Check if a CSR matches a certificate.

        Args:
            certificate (Certificate): Certificate
        Returns:
            bool: True if the CSR's public key equals the certificate's public key.
        """
        csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8"))
        cert_object = x509.load_pem_x509_certificate(certificate.raw.encode("utf-8"))
        return csr_object.public_key() == cert_object.public_key()

    def get_sha256_hex(self) -> str:
        """Return the hex-encoded SHA-256 digest of the raw PEM text."""
        digest = hashes.Hash(hashes.SHA256())
        digest.update(self.raw.encode())
        return digest.finalize().hex()
552
+
553
+
554
@dataclass(frozen=True)
class CertificateRequestAttributes:
    """Attributes a requirer charm asks for when requesting a certificate.

    A requirer builds one of these to describe the certificate it wants, and
    can turn it into a signed CSR with `generate_csr`.
    """

    common_name: str
    sans_dns: Optional[FrozenSet[str]] = frozenset()
    sans_ip: Optional[FrozenSet[str]] = frozenset()
    sans_oid: Optional[FrozenSet[str]] = frozenset()
    email_address: Optional[str] = None
    organization: Optional[str] = None
    organizational_unit: Optional[str] = None
    country_name: Optional[str] = None
    state_or_province_name: Optional[str] = None
    locality_name: Optional[str] = None
    is_ca: bool = False
    add_unique_id_to_subject_name: bool = True

    def is_valid(self) -> bool:
        """Return True when a non-empty common name has been set."""
        return bool(self.common_name)

    def generate_csr(
        self,
        private_key: PrivateKey,
    ) -> CertificateSigningRequest:
        """Build and sign a CSR carrying these attributes.

        Args:
            private_key (PrivateKey): Key used to sign the request.

        Returns:
            CertificateSigningRequest: The signed CSR.
        """
        return generate_csr(
            private_key=private_key,
            common_name=self.common_name,
            organization=self.organization,
            organizational_unit=self.organizational_unit,
            email_address=self.email_address,
            country_name=self.country_name,
            state_or_province_name=self.state_or_province_name,
            locality_name=self.locality_name,
            sans_dns=self.sans_dns,
            sans_ip=self.sans_ip,
            sans_oid=self.sans_oid,
            add_unique_id_to_subject_name=self.add_unique_id_to_subject_name,
        )

    @classmethod
    def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool):
        """Recover the requested attributes from an existing CSR.

        Args:
            csr: The parsed CSR to read subject attributes and SANs from.
            is_ca: Whether the request is for a CA certificate.
        """
        return cls(
            common_name=csr.common_name,
            organization=csr.organization,
            organizational_unit=csr.organizational_unit,
            email_address=csr.email_address,
            country_name=csr.country_name,
            state_or_province_name=csr.state_or_province_name,
            locality_name=csr.locality_name,
            sans_dns=csr.sans_dns,
            sans_ip=csr.sans_ip,
            sans_oid=csr.sans_oid,
            is_ca=is_ca,
            add_unique_id_to_subject_name=csr.has_unique_identifier,
        )
625
+
626
+
627
@dataclass(frozen=True)
class ProviderCertificate:
    """A certificate issued by the TLS provider on a given relation."""

    relation_id: int
    certificate: Certificate
    certificate_signing_request: CertificateSigningRequest
    ca: Certificate
    chain: List[Certificate]
    revoked: Optional[bool] = None

    def to_json(self) -> str:
        """Serialize this provider certificate to a JSON string.

        Returns:
            str: JSON object with the keys "csr", "certificate", "ca",
            "chain" (a list of PEM strings) and "revoked".
        """
        payload = {
            "csr": str(self.certificate_signing_request),
            "certificate": str(self.certificate),
            "ca": str(self.ca),
            "chain": list(map(str, self.chain)),
            "revoked": self.revoked,
        }
        return json.dumps(payload)
653
+
654
+
655
@dataclass(frozen=True)
class RequirerCertificateRequest:
    """A CSR submitted by one TLS requirer, together with where it came from."""

    # Relation the request was read from.
    relation_id: int
    # The parsed certificate signing request.
    certificate_signing_request: CertificateSigningRequest
    # Whether a CA certificate is being requested.
    is_ca: bool
662
+
663
+
664
class CertificateAvailableEvent(EventBase):
    """Charm event signalling that a TLS certificate is available."""

    def __init__(
        self,
        handle: Handle,
        certificate: Certificate,
        certificate_signing_request: CertificateSigningRequest,
        ca: Certificate,
        chain: List[Certificate],
    ):
        super().__init__(handle)
        self.certificate = certificate
        self.certificate_signing_request = certificate_signing_request
        self.ca = ca
        self.chain = chain

    def snapshot(self) -> dict:
        """Serialize the event payload to strings (the chain as a JSON list)."""
        serialized_chain = json.dumps([str(member) for member in self.chain])
        return {
            "certificate": str(self.certificate),
            "certificate_signing_request": str(self.certificate_signing_request),
            "ca": str(self.ca),
            "chain": serialized_chain,
        }

    def restore(self, snapshot: dict):
        """Rebuild the event payload from a dict produced by `snapshot`."""
        self.certificate = Certificate.from_string(snapshot["certificate"])
        self.certificate_signing_request = CertificateSigningRequest.from_string(
            snapshot["certificate_signing_request"]
        )
        self.ca = Certificate.from_string(snapshot["ca"])
        self.chain = [Certificate.from_string(pem) for pem in json.loads(snapshot["chain"])]

    def chain_as_pem(self) -> str:
        """Concatenate the chain's PEM certificates, separated by blank lines."""
        return "\n\n".join(str(member) for member in self.chain)
703
+
704
+
705
def generate_private_key(
    key_size: int = 2048,
    public_exponent: int = 65537,
) -> PrivateKey:
    """Create a fresh RSA private key.

    Args:
        key_size (int): Modulus size in bits; values below 2048 are rejected.
        public_exponent: RSA public exponent.

    Returns:
        PrivateKey: The key as unencrypted PEM in TraditionalOpenSSL format.

    Raises:
        ValueError: If key_size is smaller than 2048.
    """
    minimum_key_size = 2048
    if key_size < minimum_key_size:
        raise ValueError("Key size must be at least 2048 bits for RSA security")
    rsa_key = rsa.generate_private_key(public_exponent=public_exponent, key_size=key_size)
    pem_bytes = rsa_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return PrivateKey.from_string(pem_bytes.decode())
730
+
731
+
732
def calculate_relative_datetime(target_time: datetime, fraction: float) -> datetime:
    """Return the point `fraction` of the way from now to `target_time`.

    Args:
        target_time (datetime): Timezone-aware datetime to interpolate towards.
            It is subtracted from an aware UTC "now", so a naive datetime
            raises TypeError.
        fraction (float): Fraction of the interval from now to target_time,
            in the half-open range (0.0, 1.0]. 1.0 returns target_time, and
            0.9 returns the time after 90% of the interval has passed.

    Returns:
        datetime: now + (target_time - now) * fraction, with UTC tzinfo.

    Raises:
        ValueError: If fraction is not in (0.0, 1.0] (NaN included).
    """
    # A chained comparison also rejects NaN, which compares False with everything.
    if not 0.0 < fraction <= 1.0:
        raise ValueError(
            f"Invalid fraction {fraction!r}. Must be greater than 0.0 and at most 1.0"
        )
    now = datetime.now(timezone.utc)
    return now + (target_time - now) * fraction
747
+
748
+
749
def chain_has_valid_order(chain: List[str]) -> bool:
    """Tell whether every certificate in `chain` is directly issued by its successor.

    Args:
        chain (List[str]): PEM certificates expected in leaf-to-root order.

    Returns:
        bool: True when each certificate verifies as directly issued by the
        next one (trivially True for fewer than two certificates). False if
        any certificate cannot be loaded or any pair fails verification.
    """
    if len(chain) < 2:
        return True

    try:
        for child_pem, parent_pem in zip(chain, chain[1:]):
            child = x509.load_pem_x509_certificate(child_pem.encode())
            parent = x509.load_pem_x509_certificate(parent_pem.encode())
            child.verify_directly_issued_by(parent)
    except (ValueError, TypeError, InvalidSignature):
        return False
    return True
773
+
774
+
775
def generate_csr(  # noqa: C901
    private_key: PrivateKey,
    common_name: str,
    sans_dns: Optional[FrozenSet[str]] = frozenset(),
    sans_ip: Optional[FrozenSet[str]] = frozenset(),
    sans_oid: Optional[FrozenSet[str]] = frozenset(),
    organization: Optional[str] = None,
    organizational_unit: Optional[str] = None,
    email_address: Optional[str] = None,
    country_name: Optional[str] = None,
    locality_name: Optional[str] = None,
    state_or_province_name: Optional[str] = None,
    add_unique_id_to_subject_name: bool = True,
) -> CertificateSigningRequest:
    """Build a CSR for the given subject and sign it with `private_key`.

    Args:
        private_key (PrivateKey): Unencrypted PEM key used to sign the CSR.
        common_name (str): Subject common name.
        sans_dns (FrozenSet[str]): DNS Subject Alternative Names.
        sans_ip (FrozenSet[str]): IP Subject Alternative Names.
        sans_oid (FrozenSet[str]): Dotted-string OID Subject Alternative Names.
        organization (Optional[str]): Organization name.
        organizational_unit (Optional[str]): Organizational unit name.
        email_address (Optional[str]): Email address.
        country_name (Optional[str]): Country name.
        locality_name (Optional[str]): Locality name.
        state_or_province_name (Optional[str]): State or province name.
        add_unique_id_to_subject_name (bool): Whether to add a random UUID as
            an X.500 unique identifier to the subject. Always leave as True
            when the CSR is used to request certificates over the
            tls-certificates relation.

    Returns:
        CertificateSigningRequest: The signed CSR, SHA-256 signed. The SAN
        extension is added (non-critical) only when at least one SAN is given.
    """
    signing_key = serialization.load_pem_private_key(str(private_key).encode(), password=None)

    name_attributes = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)]
    if add_unique_id_to_subject_name:
        name_attributes.append(
            x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(uuid.uuid4()))
        )
    # Optional subject attributes, listed in the order they enter the subject;
    # falsy values are skipped.
    optional_attributes = (
        (x509.NameOID.ORGANIZATION_NAME, organization),
        (x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit),
        (x509.NameOID.EMAIL_ADDRESS, email_address),
        (x509.NameOID.COUNTRY_NAME, country_name),
        (x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name),
        (x509.NameOID.LOCALITY_NAME, locality_name),
    )
    name_attributes.extend(
        x509.NameAttribute(oid, value) for oid, value in optional_attributes if value
    )
    builder = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(name_attributes))

    general_names: List[x509.GeneralName] = [
        *(x509.RegisteredID(x509.ObjectIdentifier(oid)) for oid in sans_oid or ()),
        *(x509.IPAddress(ipaddress.ip_address(address)) for address in sans_ip or ()),
        *(x509.DNSName(hostname) for hostname in sans_dns or ()),
    ]
    if general_names:
        builder = builder.add_extension(
            x509.SubjectAlternativeName(set(general_names)), critical=False
        )
    signed_request = builder.sign(signing_key, hashes.SHA256())  # type: ignore[arg-type]
    pem_text = signed_request.public_bytes(serialization.Encoding.PEM).decode()
    return CertificateSigningRequest.from_string(pem_text)
847
+
848
+
849
def generate_ca(
    private_key: PrivateKey,
    validity: timedelta,
    common_name: str,
    sans_dns: Optional[FrozenSet[str]] = frozenset(),
    sans_ip: Optional[FrozenSet[str]] = frozenset(),
    sans_oid: Optional[FrozenSet[str]] = frozenset(),
    organization: Optional[str] = None,
    organizational_unit: Optional[str] = None,
    email_address: Optional[str] = None,
    country_name: Optional[str] = None,
    state_or_province_name: Optional[str] = None,
    locality_name: Optional[str] = None,
) -> Certificate:
    """Generate a self signed CA Certificate.

    Args:
        private_key (PrivateKey): RSA private key used both as the CA key and to self-sign.
        validity (timedelta): Certificate validity time, counted from now.
        common_name (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN).
        sans_dns (FrozenSet[str]): DNS Subject Alternative Names
        sans_ip (FrozenSet[str]): IP Subject Alternative Names
        sans_oid (FrozenSet[str]): OID Subject Alternative Names
        organization (Optional[str]): Organization name
        organizational_unit (Optional[str]): Organizational unit name
        email_address (Optional[str]): Email address; also added to the SAN extension.
        country_name (Optional[str]): Certificate Issuing country
        state_or_province_name (Optional[str]): Certificate Issuing state or province
        locality_name (Optional[str]): Certificate Issuing locality

    Returns:
        Certificate: CA Certificate.

    Raises:
        TypeError: If ``private_key`` does not load as an RSA private key.
    """
    private_key_object = serialization.load_pem_private_key(
        str(private_key).encode(), password=None
    )
    # Explicit check instead of ``assert``: assertions are stripped under ``python -O``.
    if not isinstance(private_key_object, rsa.RSAPrivateKey):
        raise TypeError("generate_ca requires an RSA private key")

    subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)]
    if organization:
        subject_name.append(x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, organization))
    if organizational_unit:
        subject_name.append(
            x509.NameAttribute(x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit)
        )
    if email_address:
        subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address))
    if country_name:
        subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name))
    if state_or_province_name:
        subject_name.append(
            x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name)
        )
    if locality_name:
        subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name))
    subject = x509.Name(subject_name)

    public_key = private_key_object.public_key()
    # Use the raw key digest as identifier. ``SubjectKeyIdentifier.public_bytes()`` would
    # return the DER-encoded extension value (tag + length + digest), not the digest itself.
    # The CA is self-signed, so the same identifier is also its Authority Key Identifier.
    key_identifier = x509.SubjectKeyIdentifier.from_public_key(public_key).digest
    key_usage = x509.KeyUsage(
        digital_signature=True,
        key_encipherment=True,
        key_cert_sign=True,
        key_agreement=False,
        content_commitment=False,
        data_encipherment=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )

    # One timestamp so the validity window is exactly ``validity`` long.
    now = datetime.now(timezone.utc)
    builder = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(subject)
        .public_key(public_key)
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + validity)
        .add_extension(x509.SubjectKeyIdentifier(digest=key_identifier), critical=False)
        .add_extension(
            x509.AuthorityKeyIdentifier(
                key_identifier=key_identifier,
                authority_cert_issuer=None,
                authority_cert_serial_number=None,
            ),
            critical=False,
        )
        .add_extension(key_usage, critical=True)
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=None),
            critical=True,
        )
    )
    san_extension = _san_extension(
        email_address=email_address,
        sans_dns=sans_dns,
        sans_ip=sans_ip,
        sans_oid=sans_oid,
    )
    if san_extension:
        builder = builder.add_extension(san_extension, critical=False)
    cert = builder.sign(private_key_object, hashes.SHA256())  # type: ignore[arg-type]
    ca_cert_str = cert.public_bytes(serialization.Encoding.PEM).decode().strip()
    return Certificate.from_string(ca_cert_str)
954
+
955
+
956
def _san_extension(
    email_address: Optional[str] = None,
    sans_dns: Optional[FrozenSet[str]] = frozenset(),
    sans_ip: Optional[FrozenSet[str]] = frozenset(),
    sans_oid: Optional[FrozenSet[str]] = frozenset(),
) -> Optional[x509.SubjectAlternativeName]:
    """Build a SubjectAlternativeName value from the given names.

    The e-mail address (if any) comes first, followed by DNS names, IP addresses and OIDs.

    Returns:
        The SAN value, or None when no name at all was supplied.
    """
    # An e-mail address, when given, must always be part of the SAN.
    general_names: List[x509.GeneralName] = (
        [x509.RFC822Name(email_address)] if email_address else []
    )
    general_names += [x509.DNSName(dns_name) for dns_name in sans_dns or ()]
    general_names += [
        x509.IPAddress(ipaddress.ip_address(address)) for address in sans_ip or ()
    ]
    general_names += [
        x509.RegisteredID(x509.ObjectIdentifier(oid)) for oid in sans_oid or ()
    ]
    return x509.SubjectAlternativeName(general_names) if general_names else None
975
+
976
+
977
def generate_certificate(
    csr: CertificateSigningRequest,
    ca: Certificate,
    ca_private_key: PrivateKey,
    validity: timedelta,
    is_ca: bool = False,
) -> Certificate:
    """Generate a TLS certificate based on a CSR.

    Args:
        csr (CertificateSigningRequest): CSR
        ca (Certificate): Certificate of the signing CA
        ca_private_key (PrivateKey): CA private key, used to sign the new certificate
        validity (timedelta): Certificate validity time, counted from now
        is_ca (bool): Whether the certificate is a CA certificate

    Returns:
        Certificate: Certificate
    """
    csr_object = x509.load_pem_x509_csr(str(csr).encode())
    ca_pem = x509.load_pem_x509_certificate(str(ca).encode())
    private_key = serialization.load_pem_private_key(str(ca_private_key).encode(), password=None)

    # One timestamp so the validity window is exactly ``validity`` long.
    now = datetime.now(timezone.utc)
    certificate_builder = (
        x509.CertificateBuilder()
        .subject_name(csr_object.subject)
        # The issuer of the new certificate is the signing CA's *subject*. Using the CA's own
        # issuer is only correct for self-signed roots and breaks chains for intermediates.
        .issuer_name(ca_pem.subject)
        .public_key(csr_object.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + validity)
    )

    try:
        authority_key_identifier = ca_pem.extensions.get_extension_for_class(
            x509.SubjectKeyIdentifier
        ).value.key_identifier
    except x509.ExtensionNotFound:
        # CA certificate without a Subject Key Identifier: derive it from the CA public key.
        authority_key_identifier = x509.SubjectKeyIdentifier.from_public_key(
            ca_pem.public_key()  # type: ignore[arg-type]
        ).key_identifier

    extensions = _generate_certificate_request_extensions(
        authority_key_identifier=authority_key_identifier,
        csr=csr_object,
        is_ca=is_ca,
    )
    for extension in extensions:
        try:
            certificate_builder = certificate_builder.add_extension(
                extval=extension.value,
                critical=extension.critical,
            )
        except ValueError as e:
            # e.g. a duplicate extension; skip it rather than fail the whole certificate.
            logger.warning("Failed to add extension %s: %s", extension.oid, e)

    cert = certificate_builder.sign(private_key, hashes.SHA256())  # type: ignore[arg-type]
    cert_bytes = cert.public_bytes(serialization.Encoding.PEM)
    return Certificate.from_string(cert_bytes.decode().strip())
1030
+
1031
+
1032
def _generate_certificate_request_extensions(
    authority_key_identifier: bytes,
    csr: x509.CertificateSigningRequest,
    is_ca: bool,
) -> List[x509.Extension]:
    """Assemble the extensions of a certificate issued for ``csr``.

    The provider sets the Authority Key Identifier, Subject Key Identifier and Basic
    Constraints extensions, a SAN extension rebuilt from the CSR (if it has any names) and,
    for CA certificates, a Key Usage extension. Other CSR extensions are copied unchanged,
    except those whose OID the provider already set: these are dropped with a warning.

    Args:
        authority_key_identifier (bytes): Key identifier of the issuing CA.
        csr (x509.CertificateSigningRequest): CSR the certificate is issued for.
        is_ca (bool): Whether the issued certificate is a CA certificate.

    Returns:
        List[x509.Extension]: Provider-managed extensions first, then the copied ones.
    """
    authority_key_id = x509.Extension(
        oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER,
        value=x509.AuthorityKeyIdentifier(
            key_identifier=authority_key_identifier,
            authority_cert_issuer=None,
            authority_cert_serial_number=None,
        ),
        critical=False,
    )
    subject_key_id = x509.Extension(
        oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER,
        value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()),
        critical=False,
    )
    basic_constraints = x509.Extension(
        oid=ExtensionOID.BASIC_CONSTRAINTS,
        critical=True,
        value=x509.BasicConstraints(ca=is_ca, path_length=None),
    )
    extensions: List[x509.Extension] = [authority_key_id, subject_key_id, basic_constraints]

    san = _generate_subject_alternative_name_extension(csr)
    if san:
        extensions.append(san)

    if is_ca:
        ca_key_usage = x509.KeyUsage(
            digital_signature=False,
            content_commitment=False,
            key_encipherment=False,
            data_encipherment=False,
            key_agreement=False,
            key_cert_sign=True,
            crl_sign=True,
            encipher_only=False,
            decipher_only=False,
        )
        extensions.append(x509.Extension(ExtensionOID.KEY_USAGE, critical=True, value=ca_key_usage))

    provider_oids = {extension.oid for extension in extensions}
    for requested in csr.extensions:
        if requested.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME:
            # The SAN was already rebuilt from the CSR above.
            continue
        if requested.oid in provider_oids:
            logger.warning("Extension %s is managed by the TLS provider, ignoring.", requested.oid)
            continue
        extensions.append(requested)
    return extensions
1100
+
1101
+
1102
def _generate_subject_alternative_name_extension(
    csr: x509.CertificateSigningRequest,
) -> Optional[x509.Extension]:
    """Rebuild the SAN extension for a certificate from the names found in ``csr``.

    DNS names, IP addresses, registered IDs and RFC 822 names are taken from the CSR's SAN
    extension (if present). An e-mail address in the CSR subject is also added to the SANs
    when missing, to conform to RFC 5280.

    Returns:
        The non-critical SAN extension, or None when there is no name to include.
    """
    sans: List[x509.GeneralName] = []
    try:
        requested_names = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
    except x509.ExtensionNotFound:
        requested_names = None
    if requested_names is not None:
        for name_type in (x509.DNSName, x509.IPAddress, x509.RegisteredID, x509.RFC822Name):
            sans.extend(
                name_type(value)  # type: ignore[arg-type]
                for value in requested_names.get_values_for_type(name_type)  # type: ignore[arg-type]
            )

    subject_emails = csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS)
    if subject_emails:
        subject_email = x509.RFC822Name(str(subject_emails[0].value))
        if subject_email not in sans:
            sans.append(subject_email)

    if not sans:
        return None
    return x509.Extension(
        oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
        critical=False,
        value=x509.SubjectAlternativeName(sans),
    )
1145
+
1146
+
1147
class CertificatesRequirerCharmEvents(CharmEvents):
    """Events a charm using the TLS Certificates requirer side can observe."""

    # Event source for CertificateAvailableEvent.
    certificate_available = EventSource(CertificateAvailableEvent)
1151
+
1152
+
1153
+ class TLSCertificatesRequiresV4(Object):
1154
+ """A class to manage the TLS certificates interface for a unit or app."""
1155
+
1156
+ on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType]
1157
+
1158
def __init__(
    self,
    charm: CharmBase,
    relationship_name: str,
    certificate_requests: List[CertificateRequestAttributes],
    mode: Mode = Mode.UNIT,
    refresh_events: Optional[List[BoundEvent]] = None,
    private_key: Optional[PrivateKey] = None,
    renewal_relative_time: float = 0.9,
):
    """Create a new instance of the TLSCertificatesRequiresV4 class.

    Args:
        charm (CharmBase): The charm instance to relate to.
        relationship_name (str): The name of the relation that provides the certificates.
        certificate_requests (List[CertificateRequestAttributes]):
            A list with the attributes of the certificate requests.
        mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT.
            In UNIT mode the requirer will place the csr in the unit relation data.
            Each unit will manage its private key,
            certificate signing request and certificate.
            UNIT mode is for use cases where each unit has its own identity.
            If you don't know which mode to use, you likely need UNIT.
            In APP mode the leader unit will place the csr in the app relation databag.
            APP mode is for use cases where the underlying application needs the certificate
            for example using it as an intermediate CA to sign other certificates.
            The certificate can only be accessed by the leader unit.
        refresh_events (Optional[List[BoundEvent]]): Events that trigger a refresh of
            the certificates. Defaults to no extra events.
        private_key (Optional[PrivateKey]): The private key to use for the certificates.
            If provided, it will be used instead of generating a new one.
            If the key is not valid an exception will be raised.
            Using this parameter is discouraged,
            having to pass around private keys manually can be a security concern.
            Allowing the library to generate and manage the key is the more secure approach.
        renewal_relative_time (float): The time to renew the certificate relative to its
            expiry.
            Default is 0.9, meaning 90% of the validity period.
            Accepted values range from 0.5 (50% of the validity period) to 1.0.
            If an invalid value is provided, an exception will be raised.

    Raises:
        TLSCertificatesError: If the mode, a certificate request, the private key or the
            renewal relative time is invalid.
    """
    super().__init__(charm, relationship_name)
    if not JujuVersion.from_environ().has_secrets:
        logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)")
    # Validate every argument before storing any state on the instance.
    if not self._mode_is_valid(mode):
        raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP")
    for certificate_request in certificate_requests:
        if not certificate_request.is_valid():
            raise TLSCertificatesError("Invalid certificate request")
    if private_key and not private_key.is_valid():
        raise TLSCertificatesError("Invalid private key")
    # Chained comparison also rejects NaN, which the previous `<=`/`>` pair let through.
    if not 0.5 <= renewal_relative_time <= 1.0:
        raise TLSCertificatesError(
            "Invalid renewal relative time. Must be between 0.5 and 1.0"
        )
    self.charm = charm
    self.relationship_name = relationship_name
    self.certificate_requests = certificate_requests
    self.mode = mode
    self._private_key = private_key
    self.renewal_relative_time = renewal_relative_time
    self.framework.observe(charm.on[relationship_name].relation_created, self._configure)
    self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
    self.framework.observe(charm.on.secret_expired, self._on_secret_expired)
    self.framework.observe(charm.on.secret_remove, self._on_secret_remove)
    for event in refresh_events or []:
        self.framework.observe(event, self._configure)
1225
+
1226
def _configure(self, _: Optional[EventBase] = None):
    """Reconcile the TLS certificates relation data.

    This is the handler for the relation-created/changed events and the configured refresh
    events, and is also called by ``sync()``. Once the TLS relation exists, it makes sure a
    private key is available, runs the certificate-request cleanup, sends any missing
    certificate requests and looks for available certificates. Before that it only logs.
    """
    if self._tls_relation_created():
        self._ensure_private_key()
        self._cleanup_certificate_requests()
        self._send_certificate_requests()
        self._find_available_certificates()
    else:
        logger.debug("TLS relation not created yet.")
1241
+
1242
def _mode_is_valid(self, mode: Mode) -> bool:
    """Tell whether ``mode`` is one of the supported modes (Mode.UNIT or Mode.APP)."""
    supported_modes = (Mode.UNIT, Mode.APP)
    return mode in supported_modes
1244
+
1245
def _validate_secret_exists(self, secret: Secret) -> None:
    """Raise ``SecretNotFoundError`` if ``secret`` does not exist.

    The secret metadata is fetched only for its side effect; the result is discarded.
    """
    _ = secret.get_info()
1247
+
1248
def _on_secret_remove(self, event: SecretRemoveEvent) -> None:
    """Remove the secret revision named by a secret-remove event.

    A secret that no longer exists is not treated as an error: a warning is logged instead.
    """
    # Check the secret still exists before removing the revision, otherwise the unit
    # could be stuck in an error state. See the docstring of `remove_revision` and
    # https://github.com/juju/juju/issues/19036.
    try:
        self._validate_secret_exists(event.secret)
        event.secret.remove_revision(event.revision)
    except SecretNotFoundError:
        logger.warning(
            "No such secret %s, nothing to remove", event.secret.label or event.secret.id
        )
1263
+
1264
def _on_secret_expired(self, event: SecretExpiredEvent) -> None:
    """Handle Secret Expired Event.

    Only secrets whose label starts with ``f"{LIBID}-certificate"`` are handled. The CSR
    stored in the secret is renewed (removed from the relation data and requested again)
    and all revisions of the expired secret are removed.
    """
    label = event.secret.label
    if not label or not label.startswith(f"{LIBID}-certificate"):
        return
    try:
        csr_str = event.secret.get_content(refresh=True)["csr"]
    except (ModelError, KeyError):
        # Also skip secrets that lack a "csr" entry instead of crashing the hook.
        logger.error("Failed to get CSR from secret - Skipping")
        return
    csr = CertificateSigningRequest.from_string(csr_str)
    self._renew_certificate_request(csr)
    event.secret.remove_all_revisions()
1279
+
1280
def sync(self) -> None:
    """Reconcile the TLS certificates relation data on demand.

    Runs the same processing as the refresh events, without waiting for one of them to fire.
    """
    self._configure(None)
1287
+
1288
def renew_certificate(self, certificate: ProviderCertificate) -> None:
    """Request the renewal of the provided certificate.

    The renewal only happens when the secret labelled for the certificate's CSR exists and
    still stores that exact CSR. In that case the CSR is re-requested and the secret's
    revisions are removed; otherwise a warning is logged and nothing changes.
    """
    csr = certificate.certificate_signing_request
    label = self._get_csr_secret_label(csr)
    try:
        csr_secret = self.model.get_secret(label=label)
    except SecretNotFoundError:
        logger.warning("No matching secret found - Skipping renewal")
        return
    stored_csr = csr_secret.get_content(refresh=True).get("csr", "")
    if stored_csr == str(csr):
        self._renew_certificate_request(csr)
        csr_secret.remove_all_revisions()
    else:
        logger.warning("No matching CSR found - Skipping renewal")
1303
+
1304
def _renew_certificate_request(self, csr: CertificateSigningRequest):
    """Drop ``csr`` from the requirer relation data, then send any missing requests.

    Sending missing requests generates a new CSR for the certificate request that ``csr``
    was covering.
    """
    self._remove_requirer_csr_from_relation_data(csr)
    self._send_certificate_requests()
    logger.info("Renewed certificate request")
1309
+
1310
def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None:
    """Remove every entry matching ``csr`` from this side's requirer relation data.

    Nothing is written when the relation does not exist, when there is no CSR in the relation
    data, or when the stored relation data fails validation. A failure to write the updated
    data is logged, not raised.
    """
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return
    if not self.get_csrs_from_requirer_relation_data():
        logger.info("No CSRs in relation data - Doing nothing")
        return
    app_or_unit = self._get_app_or_unit()
    try:
        requirer_relation_data = _RequirerData.load(relation.data[app_or_unit])
    except DataValidationError:
        logger.warning("Invalid relation data - Skipping removal of CSR")
        return
    target_csr = str(csr).strip()
    # Build a filtered list rather than removing items from the list being iterated,
    # which silently skips the entry following each removed one.
    new_relation_data = [
        requirer_csr
        for requirer_csr in requirer_relation_data.certificate_signing_requests
        if requirer_csr.certificate_signing_request.strip() != target_csr
    ]
    try:
        _RequirerData(certificate_signing_requests=new_relation_data).dump(
            relation.data[app_or_unit]
        )
        logger.info("Removed CSR from relation data")
    except ModelError:
        logger.warning("Failed to update relation data")
1335
+
1336
def _get_app_or_unit(self) -> Union[Application, Unit]:
    """Pick the owner of the requirer relation data for the configured mode.

    Returns:
        The unit in Mode.UNIT, the application in Mode.APP.

    Raises:
        TLSCertificatesError: If the mode is neither of the two.
    """
    mode = self.mode
    if mode == Mode.UNIT:
        return self.model.unit
    if mode == Mode.APP:
        return self.model.app
    raise TLSCertificatesError("Invalid mode")
1343
+
1344
@property
def private_key(self) -> Optional[PrivateKey]:
    """The private key used for the certificate requests, if any.

    A key passed by the charm takes precedence. Otherwise the key stored in the
    library-managed secret is returned, or None when no such secret exists yet.
    """
    if not self._private_key:
        if not self._private_key_generated():
            return None
        key_secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
        return PrivateKey.from_string(key_secret.get_content(refresh=True)["private-key"])
    return self._private_key
1354
+
1355
def _ensure_private_key(self) -> None:
    """Make sure a private key is available for the certificate requests.

    If the charm passed its own key, any library-generated key secret is removed.
    Otherwise a key is generated and stored, unless one has already been generated.
    """
    if self._private_key:
        # The charm-provided key supersedes a key previously generated by the library.
        self._remove_private_key_secret()
    elif self._private_key_generated():
        logger.debug("Private key already generated")
    else:
        self._generate_private_key()
1370
+
1371
def regenerate_private_key(self) -> None:
    """Replace the library-generated private key and re-send the certificate requests.

    A new key is generated, the certificate-request cleanup runs and new requests are sent.
    If no library-generated key exists, a warning is logged and nothing changes.

    Raises:
        TLSCertificatesError: If the charm passed its own key through the private_key
            parameter.
    """
    if self._private_key:
        raise TLSCertificatesError(
            "Private key is passed by the charm through the private_key parameter, "
            "this function can't be used"
        )
    if self._private_key_generated():
        self._generate_private_key()
        self._cleanup_certificate_requests()
        self._send_certificate_requests()
    else:
        logger.warning("No private key to regenerate")
1391
+
1392
def _generate_private_key(self) -> None:
    """Create a fresh private key and persist it in the library-managed secret.

    Used only when the charm did not pass its own key through the private_key parameter.
    """
    new_key = generate_private_key()
    self._store_private_key_in_secret(new_key)
    logger.info("Private key generated")
1400
+
1401
def _private_key_generated(self) -> bool:
    """Tell whether the library-managed private key secret exists.

    Such a secret holds a key generated by the library. It should not exist when the charm
    passes its own key through the private_key parameter.
    """
    try:
        key_secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
        key_secret.get_content(refresh=True)
    except SecretNotFoundError:
        return False
    else:
        return True
1414
+
1415
def _store_private_key_in_secret(self, private_key: PrivateKey) -> None:
    """Persist ``private_key`` in the library-managed secret.

    The existing secret is updated when found; otherwise a new unit secret is created.
    """
    payload = {"private-key": str(private_key)}
    try:
        key_secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
        key_secret.set_content(payload)
        key_secret.get_content(refresh=True)
    except SecretNotFoundError:
        self.charm.unit.add_secret(
            content=payload,
            label=self._get_private_key_secret_label(),
        )
1425
+
1426
def _remove_private_key_secret(self) -> None:
    """Delete all revisions of the library-managed private key secret, if it exists."""
    label = None
    try:
        label = self._get_private_key_secret_label()
        self.charm.model.get_secret(label=label).remove_all_revisions()
    except SecretNotFoundError:
        logger.warning("Private key secret not found, nothing to remove")
1433
+
1434
def _csr_matches_certificate_request(
    self, certificate_signing_request: CertificateSigningRequest, is_ca: bool
) -> bool:
    """Tell whether the CSR's attributes equal one of the configured certificate requests."""
    return any(
        requested == CertificateRequestAttributes.from_csr(certificate_signing_request, is_ca)
        for requested in self.certificate_requests
    )
1444
+
1445
def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool:
    """Tell whether ``certificate_request`` was already requested with the current key.

    Returns False when there is no private key, when no CSR in the relation data has the
    requested attributes, or when the matching CSR was made with a different private key.
    """
    # Read the key once: the ``private_key`` property fetches the secret on every access.
    private_key = self.private_key
    if not private_key:
        return False
    csr = self._certificate_requested_for_attributes(certificate_request)
    if not csr:
        return False
    return bool(csr.certificate_signing_request.matches_private_key(key=private_key))
1454
+
1455
def _certificate_requested_for_attributes(
    self,
    certificate_request: CertificateRequestAttributes,
) -> Optional[RequirerCertificateRequest]:
    """Return the first requirer CSR whose attributes equal ``certificate_request``, or None."""
    matching = (
        requirer_csr
        for requirer_csr in self.get_csrs_from_requirer_relation_data()
        if certificate_request
        == CertificateRequestAttributes.from_csr(
            requirer_csr.certificate_signing_request,
            requirer_csr.is_ca,
        )
    )
    return next(matching, None)
1466
+
1467
def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]:
    """Return list of requirer's CSRs from relation data.

    The list is empty on a non-leader unit in APP mode, when the relation does not exist,
    or when the stored relation data fails validation.
    """
    if self.mode == Mode.APP and not self.model.unit.is_leader():
        logger.debug("Not a leader unit - Skipping")
        return []
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return []
    owner = self._get_app_or_unit()
    try:
        stored = _RequirerData.load(relation.data[owner])
    except DataValidationError:
        logger.warning("Invalid relation data")
        return []
    return [
        RequirerCertificateRequest(
            relation_id=relation.id,
            certificate_signing_request=CertificateSigningRequest.from_string(
                entry.certificate_signing_request
            ),
            # A missing/None "ca" flag means a non-CA request.
            is_ca=entry.ca or False,
        )
        for entry in stored.certificate_signing_requests
    ]
1494
+
1495
def get_provider_certificates(self) -> List[ProviderCertificate]:
    """Return the certificates published in the provider application's relation data."""
    certificates = self._load_provider_certificates()
    return certificates
1498
+
1499
def _load_provider_certificates(self) -> List[ProviderCertificate]:
    """Parse the provider application's relation data into ProviderCertificate objects.

    The list is empty when there is no relation, no remote application, or when the
    provider data fails validation.
    """
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return []
    remote_app = relation.app
    if not remote_app:
        logger.debug("No remote app in relation: %s", self.relationship_name)
        return []
    try:
        provider_data = _ProviderApplicationData.load(relation.data[remote_app])
    except DataValidationError:
        logger.warning("Invalid relation data")
        return []
    relation_id = relation.id
    return [
        published.to_provider_certificate(relation_id=relation_id)
        for published in provider_data.certificates
    ]
1516
+
1517
def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None:
    """Append ``csr`` to this side's requirer relation data.

    Skipped on non-leader units in APP mode and when the relation does not exist. If the
    existing relation data fails validation, it is replaced by a list holding only the new
    request. A failure to write the relation data is logged, not raised.
    """
    if self.mode == Mode.APP and not self.model.unit.is_leader():
        logger.debug("Not a leader unit - Skipping")
        return
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return
    pending = _CertificateSigningRequest(certificate_signing_request=str(csr).strip(), ca=is_ca)
    databag = relation.data[self._get_app_or_unit()]
    try:
        existing_requests = _RequirerData.load(databag).certificate_signing_requests
    except DataValidationError:
        existing_requests = []
    updated_requests = [*copy.deepcopy(existing_requests), pending]
    try:
        _RequirerData(certificate_signing_requests=updated_requests).dump(databag)
        logger.info("Certificate signing request added to relation data.")
    except ModelError:
        logger.warning("Failed to update relation data")
1545
+
1546
def _send_certificate_requests(self):
    """Send a CSR for every configured certificate request not yet requested.

    CSRs are generated with the current private key; nothing is sent until a key exists.
    A request whose CSR cannot be generated is logged and skipped.
    """
    # Read the key once: the ``private_key`` property fetches the secret on every access.
    private_key = self.private_key
    if not private_key:
        logger.debug("Private key not generated yet.")
        return
    for certificate_request in self.certificate_requests:
        if self._certificate_requested(certificate_request):
            continue
        csr = certificate_request.generate_csr(private_key=private_key)
        if not csr:
            logger.warning("Failed to generate CSR")
            continue
        self._request_certificate(csr=csr, is_ca=certificate_request.is_ca)
1559
+
1560
def get_assigned_certificate(
    self, certificate_request: CertificateRequestAttributes
) -> Tuple[Optional[ProviderCertificate], Optional[PrivateKey]]:
    """Get the certificate that was assigned to the given certificate request.

    Returns:
        For the first requirer CSR whose attributes equal ``certificate_request``: the
        matching provider certificate (None if none qualifies yet) and the private key.
        (None, None) when no such CSR is in the relation data.
    """
    requirer_csr = self._certificate_requested_for_attributes(certificate_request)
    if requirer_csr is None:
        return None, None
    return self._find_certificate_in_relation_data(requirer_csr), self.private_key
1571
+
1572
def get_assigned_certificates(
    self,
) -> Tuple[List[ProviderCertificate], Optional[PrivateKey]]:
    """Get the certificates that were assigned to this unit or app.

    Returns:
        The valid provider certificates matching the requirer CSRs, and the private key.
    """
    found = map(
        self._find_certificate_in_relation_data,
        self.get_csrs_from_requirer_relation_data(),
    )
    assigned_certificates = [certificate for certificate in found if certificate]
    return assigned_certificates, self.private_key
1581
+
1582
def _find_certificate_in_relation_data(
    self, csr: RequirerCertificateRequest
) -> Optional[ProviderCertificate]:
    """Return the certificate that matches the given CSR, validated against the private key.

    A provider certificate qualifies when its CSR equals ``csr``'s, its CA flag matches the
    requested one, and it matches the current private key. Mismatches are logged and skipped.
    Returns None when there is no private key or no certificate qualifies.
    """
    # Read the key once: the ``private_key`` property fetches the secret on every access.
    private_key = self.private_key
    if not private_key:
        return None
    for provider_certificate in self.get_provider_certificates():
        if provider_certificate.certificate_signing_request != csr.certificate_signing_request:
            continue
        certificate = provider_certificate.certificate
        if certificate.is_ca and not csr.is_ca:
            logger.warning("Non CA certificate requested, got a CA certificate, ignoring")
            continue
        if not certificate.is_ca and csr.is_ca:
            logger.warning("CA certificate requested, got a non CA certificate, ignoring")
            continue
        if not certificate.matches_private_key(private_key):
            logger.warning(
                "Certificate does not match the private key. Ignoring invalid certificate."
            )
            continue
        return provider_certificate
    return None
1603
+
1604
def _find_available_certificates(self):
    """Store certificates issued for this requirer's CSRs in secrets and emit events.

    Only provider certificates whose CSR is among the CSRs this requirer published in
    relation data are considered. For each of them:

    - If the certificate is revoked, the secret labelled for its CSR is removed. A
      missing secret is ignored.
    - Otherwise, if the CSR still matches a configured certificate request, the
      certificate and CSR are written to that secret. The secret is created if it does
      not exist, and its expiry is derived from the certificate's expiry time and
      ``renewal_relative_time``. ``certificate_available`` is then emitted.
    - If the existing secret already holds the same certificate, nothing is written and
      no event is emitted.
    """
    requirer_csrs = self.get_csrs_from_requirer_relation_data()
    csrs = [csr.certificate_signing_request for csr in requirer_csrs]
    provider_certificates = self.get_provider_certificates()
    for provider_certificate in provider_certificates:
        # Only handle certificates answering a CSR this requirer published.
        if provider_certificate.certificate_signing_request in csrs:
            secret_label = self._get_csr_secret_label(
                provider_certificate.certificate_signing_request
            )
            if provider_certificate.revoked:
                # A missing secret is not an error here.
                with suppress(SecretNotFoundError):
                    logger.debug(
                        "Removing secret with label %s",
                        secret_label,
                    )
                    secret = self.model.get_secret(label=secret_label)
                    secret.remove_all_revisions()
            else:
                # Skip certificates whose CSR no longer matches a configured request.
                if not self._csr_matches_certificate_request(
                    certificate_signing_request=provider_certificate.certificate_signing_request,
                    is_ca=provider_certificate.certificate.is_ca,
                ):
                    logger.debug("Certificate requested for different attributes - Skipping")
                    continue
                try:
                    secret = self.model.get_secret(label=secret_label)
                    logger.debug("Setting secret with label %s", secret_label)
                    # Juju < 3.6 will create a new revision even if the content is the same
                    if secret.get_content(refresh=True).get("certificate", "") == str(
                        provider_certificate.certificate
                    ):
                        logger.debug(
                            "Secret %s with correct certificate already exists", secret_label
                        )
                        continue
                    secret.set_content(
                        content={
                            "certificate": str(provider_certificate.certificate),
                            "csr": str(provider_certificate.certificate_signing_request),
                        }
                    )
                    secret.set_info(
                        expire=calculate_relative_datetime(
                            target_time=provider_certificate.certificate.expiry_time,
                            fraction=self.renewal_relative_time,
                        ),
                    )
                    # refresh=True makes this unit track the latest revision (ops Secret API).
                    secret.get_content(refresh=True)
                except SecretNotFoundError:
                    # First certificate for this CSR: create the secret on this unit.
                    logger.debug("Creating new secret with label %s", secret_label)
                    secret = self.charm.unit.add_secret(
                        content={
                            "certificate": str(provider_certificate.certificate),
                            "csr": str(provider_certificate.certificate_signing_request),
                        },
                        label=secret_label,
                        expire=calculate_relative_datetime(
                            target_time=provider_certificate.certificate.expiry_time,
                            fraction=self.renewal_relative_time,
                        ),
                    )
                self.on.certificate_available.emit(
                    certificate_signing_request=provider_certificate.certificate_signing_request,
                    certificate=provider_certificate.certificate,
                    ca=provider_certificate.ca,
                    chain=provider_certificate.chain,
                )
1677
+
1678
+ def _cleanup_certificate_requests(self):
1679
+ """Clean up certificate requests.
1680
+
1681
+ Remove any certificate requests that falls into one of the following categories:
1682
+ - The CSR attributes do not match any of the certificate requests defined in
1683
+ the charm's certificate_requests attribute.
1684
+ - The CSR public key does not match the private key.
1685
+ """
1686
+ for requirer_csr in self.get_csrs_from_requirer_relation_data():
1687
+ if not self._csr_matches_certificate_request(
1688
+ certificate_signing_request=requirer_csr.certificate_signing_request,
1689
+ is_ca=requirer_csr.is_ca,
1690
+ ):
1691
+ self._remove_requirer_csr_from_relation_data(
1692
+ requirer_csr.certificate_signing_request
1693
+ )
1694
+ logger.info(
1695
+ "Removed CSR from relation data because it did not match any certificate request" # noqa: E501
1696
+ )
1697
+ elif (
1698
+ self.private_key
1699
+ and not requirer_csr.certificate_signing_request.matches_private_key(
1700
+ self.private_key
1701
+ )
1702
+ ):
1703
+ self._remove_requirer_csr_from_relation_data(
1704
+ requirer_csr.certificate_signing_request
1705
+ )
1706
+ logger.info(
1707
+ "Removed CSR from relation data because it did not match the private key" # noqa: E501
1708
+ )
1709
+
1710
+ def _tls_relation_created(self) -> bool:
1711
+ relation = self.model.get_relation(self.relationship_name)
1712
+ if not relation:
1713
+ return False
1714
+ return True
1715
+
1716
def _get_private_key_secret_label(self) -> str:
    """Return the label of the secret holding the private key.

    In ``Mode.UNIT`` the label includes this unit's number; in ``Mode.APP`` it does not.

    Raises:
        TLSCertificatesError: if ``self.mode`` is neither ``Mode.UNIT`` nor ``Mode.APP``.
    """
    if self.mode == Mode.UNIT:
        scope = f"{self._get_unit_number()}-"
    elif self.mode == Mode.APP:
        scope = ""
    else:
        raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.")
    return f"{LIBID}-private-key-{scope}{self.relationship_name}"
1723
+
1724
def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str:
    """Return the label of the secret storing the certificate issued for ``csr``.

    The label ends with the CSR's SHA-256 hex digest. In ``Mode.UNIT`` it also includes
    this unit's number.

    Raises:
        TLSCertificatesError: if ``self.mode`` is neither ``Mode.UNIT`` nor ``Mode.APP``.
    """
    digest = csr.get_sha256_hex()
    if self.mode == Mode.UNIT:
        prefix = f"{LIBID}-certificate-{self._get_unit_number()}"
    elif self.mode == Mode.APP:
        prefix = f"{LIBID}-certificate"
    else:
        raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.")
    return f"{prefix}-{digest}"
1732
+
1733
+ def _get_unit_number(self) -> str:
1734
+ return self.model.unit.name.split("/")[1]
1735
+
1736
+
1737
class TLSCertificatesProvidesV4(Object):
    """TLS certificates provider class to be instantiated by TLS certificates providers.

    Reads certificate signing requests (CSRs) from the requirer unit and application
    databags of ``relationship_name`` relations, and publishes issued certificates in
    this application's databag. The public write paths and ``_configure`` only act on
    the leader unit.
    """

    def __init__(self, charm: CharmBase, relationship_name: str):
        super().__init__(charm, relationship_name)
        # Prune certificates whose CSR disappeared on relation changes and periodically.
        self.framework.observe(charm.on[relationship_name].relation_joined, self._configure)
        self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
        self.framework.observe(charm.on.update_status, self._configure)
        self.charm = charm
        self.relationship_name = relationship_name

    def _configure(self, _: EventBase) -> None:
        """Handle update-status, relation-joined and relation-changed events.

        On the leader only, remove provider certificates for which no CSR exists anymore.
        """
        if not self.model.unit.is_leader():
            return
        self._remove_certificates_for_which_no_csr_exists()

    def _remove_certificates_for_which_no_csr_exists(self) -> None:
        """Remove every provider certificate whose CSR no requirer currently requests."""
        provider_certificates = self.get_provider_certificates()
        requirer_csrs = [
            request.certificate_signing_request for request in self.get_certificate_requests()
        ]
        for provider_certificate in provider_certificates:
            if provider_certificate.certificate_signing_request not in requirer_csrs:
                tls_relation = self._get_tls_relations(
                    relation_id=provider_certificate.relation_id
                )
                self._remove_provider_certificate(
                    certificate=provider_certificate.certificate,
                    relation=tls_relation[0],
                )

    def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation]:
        """Return this interface's relations, restricted to ``relation_id`` when given."""
        return (
            [
                relation
                for relation in self.model.relations[self.relationship_name]
                if relation.id == relation_id
            ]
            if relation_id is not None
            else self.model.relations.get(self.relationship_name, [])
        )

    def get_certificate_requests(
        self, relation_id: Optional[int] = None
    ) -> List[RequirerCertificateRequest]:
        """Load certificate requests from the relation data.

        Both the remote units' databags and the remote application's databag of each
        selected relation are read.
        """
        relations = self._get_tls_relations(relation_id)
        requirer_csrs: List[RequirerCertificateRequest] = []
        for relation in relations:
            for unit in relation.units:
                requirer_csrs.extend(self._load_requirer_databag(relation, unit))
            requirer_csrs.extend(self._load_requirer_databag(relation, relation.app))
        return requirer_csrs

    def _load_requirer_databag(
        self, relation: Relation, unit_or_app: Union[Application, Unit]
    ) -> List[RequirerCertificateRequest]:
        """Parse one requirer databag into certificate requests.

        Returns an empty list (logged at debug level) when the databag fails validation.
        """
        try:
            requirer_relation_data = _RequirerData.load(relation.data.get(unit_or_app, {}))
        except DataValidationError:
            logger.debug("Invalid requirer relation data for %s", unit_or_app.name)
            return []
        return [
            RequirerCertificateRequest(
                relation_id=relation.id,
                certificate_signing_request=CertificateSigningRequest.from_string(
                    csr.certificate_signing_request
                ),
                is_ca=csr.ca if csr.ca else False,
            )
            for csr in requirer_relation_data.certificate_signing_requests
        ]

    def _add_provider_certificate(
        self,
        relation: Relation,
        provider_certificate: ProviderCertificate,
    ) -> None:
        """Append ``provider_certificate`` to the relation's provider data if not present.

        Warnings are logged (but the certificate is still stored) when the chain does not
        start with the leaf certificate or is otherwise mis-ordered.
        """
        chain = [str(certificate) for certificate in provider_certificate.chain]
        # Guard against an empty chain: indexing chain[0] would raise IndexError.
        if chain and chain[0] != str(provider_certificate.certificate):
            logger.warning(
                "The order of the chain from the TLS Certificates Provider is incorrect. "
                "The leaf certificate should be the first element of the chain."
            )
        elif chain and not chain_has_valid_order(chain):
            logger.warning(
                "The order of the chain from the TLS Certificates Provider is partially incorrect."
            )
        new_certificate = _Certificate(
            certificate=str(provider_certificate.certificate),
            certificate_signing_request=str(provider_certificate.certificate_signing_request),
            ca=str(provider_certificate.ca),
            chain=chain,
        )
        provider_certificates = self._load_provider_certificates(relation)
        if new_certificate in provider_certificates:
            logger.info("Certificate already in relation data - Doing nothing")
            return
        provider_certificates.append(new_certificate)
        self._dump_provider_certificates(relation=relation, certificates=provider_certificates)

    def _load_provider_certificates(self, relation: Relation) -> List[_Certificate]:
        """Return a deep copy of the certificates in this app's databag for ``relation``.

        Returns an empty list (logged at debug level) when the databag fails validation.
        """
        try:
            provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app])
        except DataValidationError:
            logger.debug("Invalid provider relation data")
            return []
        # Deep copy so callers can mutate the list/entries before dumping them back.
        return copy.deepcopy(provider_relation_data.certificates)

    def _dump_provider_certificates(
        self, relation: Relation, certificates: List[_Certificate]
    ) -> None:
        """Write ``certificates`` to this app's databag; a ModelError is logged, not raised."""
        try:
            _ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app])
            logger.info("Certificate relation data updated")
        except ModelError:
            logger.warning("Failed to update relation data")

    def _remove_provider_certificate(
        self,
        relation: Relation,
        certificate: Optional[Certificate] = None,
        certificate_signing_request: Optional[CertificateSigningRequest] = None,
    ) -> None:
        """Remove certificate based on certificate or certificate signing request.

        Every stored entry whose certificate equals ``certificate``, or whose CSR equals
        ``certificate_signing_request``, is dropped. The remaining entries are always
        written back to the relation.
        """
        certificate_str = str(certificate) if certificate else None
        csr_str = str(certificate_signing_request) if certificate_signing_request else None
        # Build a new list instead of calling list.remove() while iterating. In-place
        # removal skips the element after each removed one, and an entry matching both
        # criteria would be removed twice, which can raise ValueError.
        remaining_certificates = [
            provider_certificate
            for provider_certificate in self._load_provider_certificates(relation)
            if not (
                (
                    certificate_str is not None
                    and provider_certificate.certificate == certificate_str
                )
                or (
                    csr_str is not None
                    and provider_certificate.certificate_signing_request == csr_str
                )
            )
        ]
        self._dump_provider_certificates(relation=relation, certificates=remaining_certificates)

    def revoke_all_certificates(self) -> None:
        """Revoke all certificates of this provider.

        This method is meant to be used when the Root CA has changed. Every stored
        certificate in every relation is marked ``revoked``. Non-leaders only log a warning.
        """
        if not self.model.unit.is_leader():
            logger.warning("Unit is not a leader - will not set relation data")
            return
        relations = self._get_tls_relations()
        for relation in relations:
            provider_certificates = self._load_provider_certificates(relation)
            for certificate in provider_certificates:
                certificate.revoked = True
            self._dump_provider_certificates(relation=relation, certificates=provider_certificates)

    def set_relation_certificate(
        self,
        provider_certificate: ProviderCertificate,
    ) -> None:
        """Add certificates to relation data.

        Any existing entry for the same CSR is replaced. Non-leaders only log a warning.

        Args:
            provider_certificate (ProviderCertificate): ProviderCertificate object

        Raises:
            TLSCertificatesError: if the relation identified by
                ``provider_certificate.relation_id`` does not exist.

        Returns:
            None
        """
        if not self.model.unit.is_leader():
            logger.warning("Unit is not a leader - will not set relation data")
            return
        certificates_relation = self.model.get_relation(
            relation_name=self.relationship_name, relation_id=provider_certificate.relation_id
        )
        if not certificates_relation:
            raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist")
        self._remove_provider_certificate(
            relation=certificates_relation,
            certificate_signing_request=provider_certificate.certificate_signing_request,
        )
        self._add_provider_certificate(
            relation=certificates_relation,
            provider_certificate=provider_certificate,
        )

    def get_issued_certificates(
        self, relation_id: Optional[int] = None
    ) -> List[ProviderCertificate]:
        """Return a List of issued (non revoked) certificates.

        Non-leaders get an empty list and a warning is logged.

        Returns:
            List: List of ProviderCertificate objects
        """
        if not self.model.unit.is_leader():
            logger.warning("Unit is not a leader - will not read relation data")
            return []
        provider_certificates = self.get_provider_certificates(relation_id=relation_id)
        return [certificate for certificate in provider_certificates if not certificate.revoked]

    def get_provider_certificates(
        self, relation_id: Optional[int] = None
    ) -> List[ProviderCertificate]:
        """Return all certificates stored in provider relation data, revoked ones included.

        Relations without a remote application are skipped with a warning.
        """
        certificates: List[ProviderCertificate] = []
        relations = self._get_tls_relations(relation_id)
        for relation in relations:
            if not relation.app:
                logger.warning("Relation %s does not have an application", relation.id)
                continue
            for certificate in self._load_provider_certificates(relation):
                certificates.append(certificate.to_provider_certificate(relation_id=relation.id))
        return certificates

    def get_unsolicited_certificates(
        self, relation_id: Optional[int] = None
    ) -> List[ProviderCertificate]:
        """Return provider certificates for which no certificate requests exists.

        Those certificates should be revoked.
        """
        unsolicited_certificates: List[ProviderCertificate] = []
        provider_certificates = self.get_provider_certificates(relation_id=relation_id)
        requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
        list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs]
        for certificate in provider_certificates:
            if certificate.certificate_signing_request not in list_of_csrs:
                unsolicited_certificates.append(certificate)
        return unsolicited_certificates

    def get_outstanding_certificate_requests(
        self, relation_id: Optional[int] = None
    ) -> List[RequirerCertificateRequest]:
        """Return CSR's for which no certificate has been issued.

        Args:
            relation_id (int): Relation id

        Returns:
            list: List of RequirerCertificateRequest objects.
        """
        requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
        outstanding_csrs: List[RequirerCertificateRequest] = []
        for relation_csr in requirer_csrs:
            if not self._certificate_issued_for_csr(
                csr=relation_csr.certificate_signing_request,
                relation_id=relation_id,
            ):
                outstanding_csrs.append(relation_csr)
        return outstanding_csrs

    def _certificate_issued_for_csr(
        self, csr: CertificateSigningRequest, relation_id: Optional[int]
    ) -> bool:
        """Check whether a certificate has been issued for a given CSR.

        The first non-revoked certificate whose CSR equals ``csr`` decides the result:
        ``True`` only if that certificate matches the CSR.
        """
        issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)
        for issued_certificate in issued_certificates_per_csr:
            if issued_certificate.certificate_signing_request == csr:
                return csr.matches_certificate(issued_certificate.certificate)
        return False