charmlibs-interfaces-tls-certificates 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of charmlibs-interfaces-tls-certificates might be problematic. Click here for more details.

@@ -0,0 +1,1981 @@
1
+ # Copyright 2024 Canonical Ltd.
2
+ # See LICENSE file for licensing details.
3
+
4
+ """Source code of ``tls_certificates_interface.tls_certificates`` v4.22."""
5
+
6
+ from __future__ import annotations
7
+
8
+ import copy
9
+ import ipaddress
10
+ import json
11
+ import logging
12
+ import uuid
13
+ from contextlib import suppress
14
+ from dataclasses import dataclass
15
+ from datetime import datetime, timedelta, timezone
16
+ from enum import Enum
17
+ from typing import TYPE_CHECKING
18
+
19
+ import pydantic
20
+ from cryptography import x509
21
+ from cryptography.exceptions import InvalidSignature
22
+ from cryptography.hazmat.primitives import hashes, serialization
23
+ from cryptography.hazmat.primitives.asymmetric import rsa
24
+ from cryptography.x509.oid import ExtensionOID, NameOID
25
+ from ops import BoundEvent, CharmBase, CharmEvents, Secret, SecretExpiredEvent, SecretRemoveEvent
26
+ from ops.framework import EventBase, EventSource, Handle, Object
27
+ from ops.jujuversion import JujuVersion
28
+ from ops.model import Application, ModelError, Relation, SecretNotFoundError, Unit
29
+
30
+ if TYPE_CHECKING:
31
+ from collections.abc import MutableMapping
32
+ from typing import Any
33
+
34
# Identifier of the legacy, Charmhub-hosted version of this library. It is still
# used at runtime when building labels.
LIBID = "afd8c2bccf834997afce12c2706d2ede"

# True when the installed pydantic major version is below 2.
IS_PYDANTIC_V1 = int(pydantic.version.VERSION.partition(".")[0]) < 2

# Logger shared by the whole module.
logger = logging.getLogger(__name__)
40
+
41
+
42
class TLSCertificatesError(Exception):
    """Root class of the custom exceptions defined by this library."""
44
+
45
+
46
class DataValidationError(TLSCertificatesError):
    """Signals that relation databag contents could not be decoded or validated."""
48
+
49
+
50
class _DatabagModel(pydantic.BaseModel):
    """Base model for data stored in Juju relation databags.

    Supports both pydantic v1 and v2. Unless ``_NEST_UNDER`` is set, every model
    field is stored under its own databag key as a JSON-encoded value.
    """

    if IS_PYDANTIC_V1:

        class Config:
            """Pydantic v1 config."""

            # Ignore any extra fields in the databag.
            extra = "ignore"
            # Allow instantiating this class by field name (instead of forcing alias).
            allow_population_by_field_name = True

        # Pydantic v1 counterpart of the ``_NEST_UNDER`` key in ``model_config``;
        # read by ``_load_v1`` / ``_dump_v1``.
        _NEST_UNDER = None

    # Config read by the pydantic v2 code paths (``load`` / ``dump``).
    model_config = pydantic.ConfigDict(
        # tolerate additional keys in databag
        extra="ignore",
        # Allow instantiating this class by field name (instead of forcing alias).
        populate_by_name=True,
        # Custom config key: whether to nest the whole datastructure (as json)
        # under a field or spread it out at the toplevel.
        _NEST_UNDER=None,
    )  # type: ignore

    @classmethod
    def load(cls, databag: MutableMapping[str, str]):
        """Load this model from a Juju databag.

        Args:
            databag: The databag to read from.

        Returns:
            An instance of this model.

        Raises:
            DataValidationError: When the model is not nested, if a value stored under
                a model field's key is not valid JSON or the decoded data fails
                validation.
        """
        if IS_PYDANTIC_V1:
            return cls._load_v1(databag)
        nest_under = cls.model_config.get("_NEST_UNDER")
        if nest_under:
            # NOTE: errors on this path (missing key, bad JSON, failed validation)
            # propagate as-is; they are not wrapped in DataValidationError.
            return cls.model_validate(json.loads(databag[nest_under]))

        # Databag keys that belong to this model (alias, falling back to the field
        # name). Built once instead of once per databag entry; other keys are
        # model-external and are not parsed.
        field_keys = {(f.alias or n) for n, f in cls.model_fields.items()}
        try:
            data = {k: json.loads(v) for k, v in databag.items() if k in field_keys}
        except json.JSONDecodeError as e:
            msg = f"invalid databag contents: expecting json. {databag}"
            logger.error(msg)
            raise DataValidationError(msg) from e

        try:
            return cls.model_validate_json(json.dumps(data))
        except pydantic.ValidationError as e:
            msg = f"failed to validate databag: {databag}"
            logger.debug(msg, exc_info=True)
            raise DataValidationError(msg) from e

    @classmethod
    def _load_v1(cls, databag: MutableMapping[str, str]):
        """Load implementation for pydantic v1 (same contract as ``load``)."""
        if cls._NEST_UNDER:
            return cls.parse_obj(json.loads(databag[cls._NEST_UNDER]))  # pyright: ignore[reportDeprecated]

        # Keys of this model's fields (by alias); built once, not per databag entry.
        field_keys = {f.alias for f in cls.__fields__.values()}  # pyright: ignore[reportDeprecated]
        try:
            data = {k: json.loads(v) for k, v in databag.items() if k in field_keys}
        except json.JSONDecodeError as e:
            msg = f"invalid databag contents: expecting json. {databag}"
            logger.error(msg)
            raise DataValidationError(msg) from e

        try:
            return cls.parse_raw(json.dumps(data))  # type: ignore
        except pydantic.ValidationError as e:
            msg = f"failed to validate databag: {databag}"
            logger.debug(msg, exc_info=True)
            raise DataValidationError(msg) from e

    def dump(self, databag: MutableMapping[str, str] | None = None, clear: bool = True):
        """Write the contents of this model to Juju databag.

        Values equal to their field default are skipped, and field aliases are used
        as keys.

        Args:
            databag: The databag to write to. A new dict is used when None.
            clear: Whether to clear the databag before writing.

        Returns:
            MutableMapping: The databag.
        """
        if IS_PYDANTIC_V1:
            return self._dump_v1(databag, clear)
        if clear and databag:
            databag.clear()

        if databag is None:
            databag = {}
        nest_under = self.model_config.get("_NEST_UNDER")
        if nest_under:
            # Whole model as one JSON document under a single key.
            databag[nest_under] = self.model_dump_json(
                by_alias=True,
                # skip keys whose values are default
                exclude_defaults=True,
            )
            return databag

        # One key per field, each value JSON-encoded on its own.
        dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True)
        databag.update({k: json.dumps(v) for k, v in dct.items()})
        return databag

    def _dump_v1(self, databag: MutableMapping[str, str] | None = None, clear: bool = True):
        """Dump implementation for pydantic v1 (same contract as ``dump``)."""
        if clear and databag:
            databag.clear()

        if databag is None:
            databag = {}

        if self._NEST_UNDER:
            databag[self._NEST_UNDER] = self.json(by_alias=True, exclude_defaults=True)  # pyright: ignore[reportDeprecated]
            return databag

        dct = json.loads(self.json(by_alias=True, exclude_defaults=True))  # pyright: ignore[reportDeprecated]
        databag.update({k: json.dumps(v) for k, v in dct.items()})

        return databag
179
+
180
+
181
class _Certificate(pydantic.BaseModel):
    """Wire model of one certificate entry published by the provider."""

    ca: str
    certificate_signing_request: str
    certificate: str
    chain: list[str] | None = None
    revoked: bool | None = None

    def to_provider_certificate(self, relation_id: int) -> ProviderCertificate:
        """Build the public ``ProviderCertificate`` for this entry.

        Args:
            relation_id: ID of the relation this entry was read from.

        Returns:
            ProviderCertificate: every PEM string parsed into its object form; a
            missing or empty chain becomes an empty list.
        """
        # Parse in the same order as the fields are passed below so that the
        # first failure reported is the same one.
        parsed_certificate = Certificate.from_string(self.certificate)
        parsed_csr = CertificateSigningRequest.from_string(self.certificate_signing_request)
        parsed_ca = Certificate.from_string(self.ca)
        parsed_chain = [Certificate.from_string(pem) for pem in self.chain or []]
        return ProviderCertificate(
            relation_id=relation_id,
            certificate=parsed_certificate,
            certificate_signing_request=parsed_csr,
            ca=parsed_ca,
            chain=parsed_chain,
            revoked=self.revoked,
        )
204
+
205
+
206
class _CertificateSigningRequest(pydantic.BaseModel):
    """Certificate signing request model.

    Wire format of one entry in ``_RequirerData.certificate_signing_requests``.
    """

    # The CSR as a string — presumably PEM, since the library's CSR parser
    # expects PEM; verify against the requirer code that fills it in.
    certificate_signing_request: str
    # CA flag for the request. NOTE(review): no default is given, so under
    # pydantic v2 the key must be present, while pydantic v1 treats Optional
    # fields as defaulting to None — confirm this asymmetry is intended.
    ca: bool | None
211
+
212
+
213
class _ProviderApplicationData(_DatabagModel):
    """Provider application data model: the list of published certificate entries."""

    # FIXME: mutable default? should it be Field(default_factory=list)?
    # NOTE(review): pydantic copies field defaults per model instance, so this list
    # is likely not shared between instances. Switching to default_factory could
    # also change what ``exclude_defaults`` drops in ``dump()`` — verify first.
    certificates: list[_Certificate] = []
218
+
219
+
220
class _RequirerData(_DatabagModel):
    """Requirer data model.

    The same model is used for the unit and application data.
    """

    # FIXME: mutable default? should it be Field(default_factory=list)?
    # NOTE(review): pydantic copies field defaults per model instance, so this list
    # is likely not shared between instances. Switching to default_factory could
    # also change what ``exclude_defaults`` drops in ``dump()`` — verify first.
    certificate_signing_requests: list[_CertificateSigningRequest] = []
228
+
229
+
230
class Mode(Enum):
    """Scope of a certificate request.

    Members:
        UNIT: the default. Every unit manages its own private key, certificate
            signing request and certificate.
        APP: the request is made for the whole application. Only the leader unit
            manages the private key, certificate signing request and certificate.
    """

    UNIT = 1
    APP = 2
243
+
244
+
245
@dataclass(frozen=True)
class PrivateKey:
    """This class represents a private key."""

    # The key text as given to ``from_string`` (surrounding whitespace stripped).
    raw: str

    def __str__(self):
        """Return the private key as a string."""
        return self.raw

    @classmethod
    def from_string(cls, private_key: str) -> PrivateKey:
        """Create a PrivateKey object from a private key string.

        Leading and trailing whitespace is stripped before the key is stored.
        """
        return cls(raw=private_key.strip())

    def is_valid(self) -> bool:
        """Validate that the private key is PEM-formatted, RSA, and at least 2048 bits.

        Returns:
            bool: True for an unencrypted PEM RSA key of at least 2048 bits, False
            otherwise. The reason is logged as a warning.
        """
        # Local import: only this predicate needs it.
        from cryptography.exceptions import UnsupportedAlgorithm

        try:
            key = serialization.load_pem_private_key(
                self.raw.encode(),
                password=None,
            )
        except (ValueError, TypeError, UnsupportedAlgorithm):
            # ValueError: undecodable PEM; TypeError: the key is encrypted but no
            # password is supplied; UnsupportedAlgorithm: unsupported key type.
            # Previously only ValueError was caught, so the last two escaped this
            # boolean check as exceptions.
            logger.warning("Invalid private key format")
            return False

        if not isinstance(key, rsa.RSAPrivateKey):
            logger.warning("Private key is not an RSA key")
            return False

        if key.key_size < 2048:
            logger.warning("RSA key size is less than 2048 bits")
            return False

        return True
280
+
281
+
282
@dataclass(frozen=True)
class Certificate:
    """This class represents a certificate."""

    raw: str
    common_name: str
    expiry_time: datetime
    validity_start_time: datetime
    is_ca: bool = False
    sans_dns: frozenset[str] | None = frozenset()
    sans_ip: frozenset[str] | None = frozenset()
    sans_oid: frozenset[str] | None = frozenset()
    email_address: str | None = None
    organization: str | None = None
    organizational_unit: str | None = None
    country_name: str | None = None
    state_or_province_name: str | None = None
    locality_name: str | None = None

    def __str__(self) -> str:
        """Return the certificate as a string."""
        return self.raw

    @classmethod
    def from_string(cls, certificate: str) -> Certificate:
        """Create a Certificate object from a PEM-encoded certificate.

        Args:
            certificate: The certificate in PEM format.

        Returns:
            Certificate: The parsed certificate. ``raw`` is the input with surrounding
            whitespace stripped. ``common_name`` is ``""`` when the subject has no
            common name. ``sans_oid`` holds dotted-string OIDs (e.g. ``"1.2.3.4"``).

        Raises:
            TLSCertificatesError: If the certificate cannot be loaded.
        """
        try:
            certificate_object = x509.load_pem_x509_certificate(data=certificate.encode())
        except ValueError as e:
            logger.error("Could not load certificate: %s", e)
            raise TLSCertificatesError("Could not load certificate") from e

        subject = certificate_object.subject

        def first_subject_value(oid: x509.ObjectIdentifier) -> str | None:
            """Return the first subject attribute with ``oid`` as a string, or None."""
            attributes = subject.get_attributes_for_oid(oid)
            return str(attributes[0].value) if attributes else None

        sans_dns: list[str] = []
        sans_ip: list[str] = []
        sans_oid: list[str] = []
        try:
            sans = certificate_object.extensions.get_extension_for_class(
                x509.SubjectAlternativeName
            ).value
        except x509.ExtensionNotFound:
            logger.debug("No SANs found in certificate")
        else:
            for san in sans:
                if isinstance(san, x509.DNSName):
                    sans_dns.append(san.value)
                elif isinstance(san, x509.IPAddress):
                    sans_ip.append(str(san.value))
                elif isinstance(san, x509.RegisteredID):
                    # str() of an ObjectIdentifier gives its repr, not the OID; use the
                    # dotted form, matching CertificateSigningRequest.from_string.
                    sans_oid.append(san.value.dotted_string)

        is_ca = False
        try:
            is_ca = certificate_object.extensions.get_extension_for_oid(  # pyright: ignore[reportUnknownVariableType]
                ExtensionOID.BASIC_CONSTRAINTS
            ).value.ca  # type: ignore[reportAttributeAccessIssue]
        except x509.ExtensionNotFound:
            pass

        return cls(
            raw=certificate.strip(),
            # Fall back to "" instead of raising IndexError when the subject has
            # no common name.
            common_name=first_subject_value(NameOID.COMMON_NAME) or "",
            is_ca=is_ca,  # pyright: ignore[reportUnknownArgumentType]
            country_name=first_subject_value(NameOID.COUNTRY_NAME),
            state_or_province_name=first_subject_value(NameOID.STATE_OR_PROVINCE_NAME),
            locality_name=first_subject_value(NameOID.LOCALITY_NAME),
            organization=first_subject_value(NameOID.ORGANIZATION_NAME),
            organizational_unit=first_subject_value(NameOID.ORGANIZATIONAL_UNIT_NAME),
            email_address=first_subject_value(NameOID.EMAIL_ADDRESS),
            sans_dns=frozenset(sans_dns),
            sans_ip=frozenset(sans_ip),
            sans_oid=frozenset(sans_oid),
            expiry_time=certificate_object.not_valid_after_utc,
            validity_start_time=certificate_object.not_valid_before_utc,
        )

    def matches_private_key(self, private_key: PrivateKey) -> bool:
        """Check if this certificate matches a given private key.

        Only RSA keys are supported; any other key type, or any failure while
        loading either object, yields False (and a logged warning).

        Args:
            private_key (PrivateKey): The private key to validate against.

        Returns:
            bool: True if the certificate matches the private key, False otherwise.
        """
        try:
            cert_object = x509.load_pem_x509_certificate(self.raw.encode())
            key_object = serialization.load_pem_private_key(
                private_key.raw.encode(), password=None
            )

            cert_public_key = cert_object.public_key()
            key_public_key = key_object.public_key()

            if not isinstance(cert_public_key, rsa.RSAPublicKey):
                logger.warning("Certificate does not use RSA public key")
                return False

            if not isinstance(key_public_key, rsa.RSAPublicKey):
                logger.warning("Private key is not an RSA key")
                return False

            return cert_public_key.public_numbers() == key_public_key.public_numbers()
        except Exception as e:
            logger.warning("Failed to validate certificate and private key match: %s", e)
            return False
405
+
406
+
407
@dataclass(frozen=True)
class CertificateSigningRequest:
    """This class represents a certificate signing request."""

    raw: str
    common_name: str
    sans_dns: frozenset[str] | None = None
    sans_ip: frozenset[str] | None = None
    sans_oid: frozenset[str] | None = None
    email_address: str | None = None
    organization: str | None = None
    organizational_unit: str | None = None
    country_name: str | None = None
    state_or_province_name: str | None = None
    locality_name: str | None = None
    has_unique_identifier: bool = False

    def __eq__(self, other: object) -> bool:
        """Check if two CertificateSigningRequest objects are equal.

        Only the CSR text, ignoring surrounding whitespace, is compared; the other
        fields are not.
        """
        if not isinstance(other, CertificateSigningRequest):
            return NotImplemented
        return self.raw.strip() == other.raw.strip()

    def __hash__(self) -> int:
        """Hash on the same value ``__eq__`` compares.

        Without this, ``@dataclass(frozen=True)`` generates a hash over every field,
        so two CSRs that compare equal could hash differently.
        """
        return hash(self.raw.strip())

    def __str__(self) -> str:
        """Return the CSR as a string."""
        return self.raw

    @classmethod
    def from_string(cls, csr: str) -> CertificateSigningRequest:
        """Create a CertificateSigningRequest object from a PEM-encoded CSR.

        Args:
            csr: The CSR in PEM format.

        Returns:
            CertificateSigningRequest: The parsed CSR. ``raw`` is the input with
            surrounding whitespace stripped. ``common_name`` is ``""`` when the subject
            has no common name. The SAN sets are empty when the CSR has no SAN
            extension.

        Raises:
            TLSCertificatesError: If the CSR cannot be loaded.
        """
        try:
            csr_object = x509.load_pem_x509_csr(csr.encode())
        except ValueError as e:
            logger.error("Could not load CSR: %s", e)
            raise TLSCertificatesError("Could not load CSR") from e

        subject = csr_object.subject

        def first_subject_value(oid: x509.ObjectIdentifier) -> str | None:
            """Return the first subject attribute with ``oid`` as a string, or None."""
            attributes = subject.get_attributes_for_oid(oid)
            return str(attributes[0].value) if attributes else None

        sans_dns: frozenset[str]
        sans_ip: frozenset[str]
        sans_oid: frozenset[str]
        try:
            sans = csr_object.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
        except x509.ExtensionNotFound:
            sans_dns = sans_ip = sans_oid = frozenset()
        else:
            sans_dns = frozenset(sans.get_values_for_type(x509.DNSName))
            sans_ip = frozenset(str(san) for san in sans.get_values_for_type(x509.IPAddress))
            sans_oid = frozenset(
                san.dotted_string for san in sans.get_values_for_type(x509.RegisteredID)
            )

        return cls(
            raw=csr.strip(),
            # Fall back to "" instead of raising IndexError when the subject has
            # no common name.
            common_name=first_subject_value(NameOID.COMMON_NAME) or "",
            country_name=first_subject_value(NameOID.COUNTRY_NAME),
            state_or_province_name=first_subject_value(NameOID.STATE_OR_PROVINCE_NAME),
            locality_name=first_subject_value(NameOID.LOCALITY_NAME),
            organization=first_subject_value(NameOID.ORGANIZATION_NAME),
            organizational_unit=first_subject_value(NameOID.ORGANIZATIONAL_UNIT_NAME),
            email_address=first_subject_value(NameOID.EMAIL_ADDRESS),
            sans_dns=sans_dns,
            sans_ip=sans_ip,
            sans_oid=sans_oid,
            has_unique_identifier=bool(
                subject.get_attributes_for_oid(NameOID.X500_UNIQUE_IDENTIFIER)
            ),
        )

    def matches_private_key(self, key: PrivateKey) -> bool:
        """Check if a CSR matches a private key.

        This function only works with RSA keys.

        Args:
            key (PrivateKey): Private key
        Returns:
            bool: True/False depending on whether the CSR matches the private key.
                Unloadable input (bad PEM, encrypted key, unsupported key type)
                yields False.
        """
        # Local import: only this method needs it.
        from cryptography.exceptions import UnsupportedAlgorithm

        try:
            csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8"))
            key_object = serialization.load_pem_private_key(
                data=key.raw.encode("utf-8"), password=None
            )
        except (ValueError, TypeError, UnsupportedAlgorithm):
            # TypeError is raised for an encrypted key (no password is supplied),
            # UnsupportedAlgorithm for unsupported key types; both used to escape.
            logger.warning("Could not load certificate or CSR.")
            return False
        key_object_public_key = key_object.public_key()
        csr_object_public_key = csr_object.public_key()
        if not isinstance(key_object_public_key, rsa.RSAPublicKey):
            logger.warning("Key is not an RSA key")
            return False
        if not isinstance(csr_object_public_key, rsa.RSAPublicKey):
            logger.warning("CSR is not an RSA key")
            return False
        if csr_object_public_key.public_numbers().n != key_object_public_key.public_numbers().n:
            logger.warning("Public key numbers between CSR and key do not match")
            return False
        return True

    def matches_certificate(self, certificate: Certificate) -> bool:
        """Check if a CSR matches a certificate.

        Both are re-parsed from their raw PEM text. Parsing errors are not caught
        and propagate to the caller.

        Args:
            certificate (Certificate): Certificate
        Returns:
            bool: True/False depending on whether the CSR matches the certificate.
        """
        csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8"))
        cert_object = x509.load_pem_x509_certificate(certificate.raw.encode("utf-8"))
        return csr_object.public_key() == cert_object.public_key()

    def get_sha256_hex(self) -> str:
        """Return the hex-encoded SHA-256 digest of the raw CSR text."""
        digest = hashes.Hash(hashes.SHA256())
        digest.update(self.raw.encode())
        return digest.finalize().hex()
540
+
541
+
542
@dataclass(frozen=True)
class CertificateRequestAttributes:
    """Attributes a requirer charm wants in its certificate.

    Requirer charms build one of these to describe the requested certificate;
    ``generate_csr`` turns it into a CSR signed with a given private key.
    """

    common_name: str
    sans_dns: frozenset[str] | None = frozenset()
    sans_ip: frozenset[str] | None = frozenset()
    sans_oid: frozenset[str] | None = frozenset()
    email_address: str | None = None
    organization: str | None = None
    organizational_unit: str | None = None
    country_name: str | None = None
    state_or_province_name: str | None = None
    locality_name: str | None = None
    is_ca: bool = False
    add_unique_id_to_subject_name: bool = True

    def is_valid(self) -> bool:
        """Return True when a non-empty common name is set."""
        return bool(self.common_name)

    def generate_csr(
        self,
        private_key: PrivateKey,
    ) -> CertificateSigningRequest:
        """Build and sign a CSR carrying these attributes.

        ``is_ca`` is not part of the CSR and is therefore not passed on.

        Args:
            private_key (PrivateKey): Key used to sign the CSR.

        Returns:
            CertificateSigningRequest: The signed CSR.
        """
        # Delegates to the module-level ``generate_csr`` function.
        return generate_csr(
            private_key=private_key,
            common_name=self.common_name,
            add_unique_id_to_subject_name=self.add_unique_id_to_subject_name,
            sans_dns=self.sans_dns,
            sans_ip=self.sans_ip,
            sans_oid=self.sans_oid,
            organization=self.organization,
            organizational_unit=self.organizational_unit,
            email_address=self.email_address,
            country_name=self.country_name,
            locality_name=self.locality_name,
            state_or_province_name=self.state_or_province_name,
        )

    @classmethod
    def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool):
        """Rebuild request attributes from a parsed CSR.

        Args:
            csr: The CSR whose subject and SAN fields are copied.
            is_ca: Value for ``is_ca``, which a CSR does not carry.

        Returns:
            CertificateRequestAttributes: ``add_unique_id_to_subject_name`` mirrors
            ``csr.has_unique_identifier``.
        """
        copied_fields = {
            name: getattr(csr, name)
            for name in (
                "common_name",
                "sans_dns",
                "sans_ip",
                "sans_oid",
                "email_address",
                "organization",
                "organizational_unit",
                "country_name",
                "state_or_province_name",
                "locality_name",
            )
        }
        return cls(
            **copied_fields,
            is_ca=is_ca,
            add_unique_id_to_subject_name=csr.has_unique_identifier,
        )
613
+
614
+
615
@dataclass(frozen=True)
class ProviderCertificate:
    """A certificate issued by the TLS provider, with its CSR, CA and chain."""

    relation_id: int
    certificate: Certificate
    certificate_signing_request: CertificateSigningRequest
    ca: Certificate
    chain: list[Certificate]
    revoked: bool | None = None

    def to_json(self) -> str:
        """Serialize this object to a JSON string.

        Returns:
            str: A JSON object with the keys ``csr``, ``certificate``, ``ca``,
            ``chain`` (list of strings) and ``revoked``; ``relation_id`` is omitted.
        """
        payload = dict(
            csr=str(self.certificate_signing_request),
            certificate=str(self.certificate),
            ca=str(self.ca),
            chain=list(map(str, self.chain)),
            revoked=self.revoked,
        )
        return json.dumps(payload)
639
+
640
+
641
@dataclass(frozen=True)
class RequirerCertificateRequest:
    """This class represents a certificate signing request requested by a specific TLS requirer."""

    # ID of the relation the request was found on (presumably the Juju relation id).
    relation_id: int
    # The requested CSR.
    certificate_signing_request: CertificateSigningRequest
    # CA flag of the request — presumably mirrors the ``ca`` entry the requirer
    # published with this CSR; confirm against the provider code that builds it.
    is_ca: bool
648
+
649
+
650
class CertificateAvailableEvent(EventBase):
    """Charm event emitted once a TLS certificate is available.

    Carries the certificate, the CSR it answers, the CA certificate and the chain.
    """

    def __init__(
        self,
        handle: Handle,
        certificate: Certificate,
        certificate_signing_request: CertificateSigningRequest,
        ca: Certificate,
        chain: list[Certificate],
    ):
        super().__init__(handle)
        self.certificate = certificate
        self.certificate_signing_request = certificate_signing_request
        self.ca = ca
        self.chain = chain

    def snapshot(self) -> dict[str, str]:
        """Return the event payload as a dict of strings (the inverse of ``restore``).

        The chain is stored as a JSON list of PEM strings.
        """
        snapshot_data = {}
        snapshot_data["certificate"] = str(self.certificate)
        snapshot_data["certificate_signing_request"] = str(self.certificate_signing_request)
        snapshot_data["ca"] = str(self.ca)
        snapshot_data["chain"] = json.dumps(list(map(str, self.chain)))
        return snapshot_data

    def restore(self, snapshot: dict[str, str]):
        """Rebuild the event payload from a dict produced by ``snapshot``."""
        self.certificate = Certificate.from_string(snapshot["certificate"])
        self.certificate_signing_request = CertificateSigningRequest.from_string(
            snapshot["certificate_signing_request"]
        )
        self.ca = Certificate.from_string(snapshot["ca"])
        self.chain = list(map(Certificate.from_string, json.loads(snapshot["chain"])))

    def chain_as_pem(self) -> str:
        """Return the whole chain as one PEM string, certificates separated by a blank line."""
        return "\n\n".join(map(str, self.chain))
689
+
690
+
691
def generate_private_key(
    key_size: int = 2048,
    public_exponent: int = 65537,
) -> PrivateKey:
    """Create a new RSA private key.

    Args:
        key_size (int): Modulus size in bits; anything below 2048 is rejected.
        public_exponent: RSA public exponent.

    Returns:
        PrivateKey: The key as unencrypted PEM in the TraditionalOpenSSL format.

    Raises:
        ValueError: If ``key_size`` is below 2048.
    """
    minimum_key_size = 2048
    if key_size < minimum_key_size:
        raise ValueError("Key size must be at least 2048 bits for RSA security")
    rsa_key = rsa.generate_private_key(public_exponent=public_exponent, key_size=key_size)
    pem_text = rsa_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode()
    return PrivateKey.from_string(pem_text)
716
+
717
+
718
def calculate_relative_datetime(target_time: datetime, fraction: float) -> datetime:
    """Calculate a datetime that is a given fraction of the way from now to a target time.

    Args:
        target_time (datetime): The timezone-aware datetime to interpolate towards.
            It is subtracted from an aware "now", so a naive value raises TypeError.
        fraction (float): Fraction of the interval from now to target_time, in the
            half-open range (0.0, 1.0]. 1.0 means return target_time, and 0.9 means
            return the time after 90% of the interval has passed. 0.0 is rejected.

    Returns:
        datetime: ``now + (target_time - now) * fraction``, with ``now`` in UTC.

    Raises:
        ValueError: If ``fraction`` is not in (0.0, 1.0], including NaN.
    """
    # A chained comparison also rejects NaN (every comparison with NaN is False)
    # with this message instead of a later, unclear timedelta error.
    if not 0.0 < fraction <= 1.0:
        raise ValueError("Invalid fraction. Must be in the range (0.0, 1.0]")
    now = datetime.now(timezone.utc)
    return now + (target_time - now) * fraction
733
+
734
+
735
def chain_has_valid_order(chain: list[str]) -> bool:
    """Tell whether ``chain`` is ordered leaf-first, each certificate issued by the next.

    Every adjacent pair ``(chain[i], chain[i + 1])`` must pass
    ``verify_directly_issued_by``. Fewer than two certificates is trivially valid.

    Args:
        chain (List[str]): PEM certificates, expected in leaf-to-root order.

    Returns:
        bool: True when every adjacent pair verifies. False when a certificate
        cannot be parsed or a pair fails the issuer/signature check.
    """
    try:
        # zip() over the list and its tail yields no pairs for 0 or 1 certificates.
        for child_pem, parent_pem in zip(chain, chain[1:]):
            child = x509.load_pem_x509_certificate(child_pem.encode())
            parent = x509.load_pem_x509_certificate(parent_pem.encode())
            child.verify_directly_issued_by(parent)
    except (ValueError, TypeError, InvalidSignature):
        return False
    return True
759
+
760
+
761
def generate_csr(
    private_key: PrivateKey,
    common_name: str,
    sans_dns: frozenset[str] | None = frozenset(),
    sans_ip: frozenset[str] | None = frozenset(),
    sans_oid: frozenset[str] | None = frozenset(),
    organization: str | None = None,
    organizational_unit: str | None = None,
    email_address: str | None = None,
    country_name: str | None = None,
    locality_name: str | None = None,
    state_or_province_name: str | None = None,
    add_unique_id_to_subject_name: bool = True,
) -> CertificateSigningRequest:
    """Build a CSR for the given subject and SANs, signed with ``private_key`` (SHA-256).

    The subject always contains the common name. It is followed, when requested or
    set, by a random X500 unique identifier, then organization, organizational
    unit, email address, country, state or province, and locality (in that order).

    Args:
        private_key (PrivateKey): Unencrypted PEM key used to sign the CSR.
        common_name (str): Common name.
        sans_dns (FrozenSet[str]): DNS Subject Alternative Names.
        sans_ip (FrozenSet[str]): IP Subject Alternative Names.
        sans_oid (FrozenSet[str]): OID Subject Alternative Names (dotted strings).
        organization (Optional[str]): Organization name.
        organizational_unit (Optional[str]): Organizational unit name.
        email_address (Optional[str]): Email address.
        country_name (Optional[str]): Country name.
        locality_name (Optional[str]): Locality name.
        state_or_province_name (Optional[str]): State or province name.
        add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the
            CSR's subject name. Always leave to "True" when the CSR is used to request
            certificates using the tls-certificates relation.

    Returns:
        CertificateSigningRequest: The signed CSR, parsed back from its PEM form.
    """
    signing_key = serialization.load_pem_private_key(str(private_key).encode(), password=None)

    subject_attributes = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)]
    unique_id = str(uuid.uuid4()) if add_unique_id_to_subject_name else None
    optional_attributes = (
        (x509.NameOID.X500_UNIQUE_IDENTIFIER, unique_id),
        (x509.NameOID.ORGANIZATION_NAME, organization),
        (x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit),
        (x509.NameOID.EMAIL_ADDRESS, email_address),
        (x509.NameOID.COUNTRY_NAME, country_name),
        (x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name),
        (x509.NameOID.LOCALITY_NAME, locality_name),
    )
    # Optional attributes with a falsy value are left out of the subject.
    subject_attributes.extend(
        x509.NameAttribute(oid, value) for oid, value in optional_attributes if value
    )
    builder = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_attributes))

    general_names: list[x509.GeneralName] = [
        *(x509.RegisteredID(x509.ObjectIdentifier(oid)) for oid in sans_oid or ()),
        *(x509.IPAddress(ipaddress.ip_address(address)) for address in sans_ip or ()),
        *(x509.DNSName(name) for name in sans_dns or ()),
    ]
    if general_names:
        builder = builder.add_extension(
            x509.SubjectAlternativeName(set(general_names)), critical=False
        )

    signed_request = builder.sign(signing_key, hashes.SHA256())  # type: ignore[arg-type]
    pem_text = signed_request.public_bytes(serialization.Encoding.PEM).decode()
    return CertificateSigningRequest.from_string(pem_text)
833
+
834
+
835
def generate_ca(
    private_key: PrivateKey,
    validity: timedelta,
    common_name: str,
    sans_dns: frozenset[str] | None = frozenset(),
    sans_ip: frozenset[str] | None = frozenset(),
    sans_oid: frozenset[str] | None = frozenset(),
    organization: str | None = None,
    organizational_unit: str | None = None,
    email_address: str | None = None,
    country_name: str | None = None,
    state_or_province_name: str | None = None,
    locality_name: str | None = None,
) -> Certificate:
    """Generate a self-signed CA certificate.

    The certificate uses the same name as subject and issuer, is valid from now for
    ``validity``, carries CA basic constraints plus key usages allowing certificate
    signing, and is signed with ``private_key`` using SHA-256.

    Args:
        private_key (PrivateKey): Private key (only RSA keys are accepted)
        validity (timedelta): Certificate validity time
        common_name (str): Common Name that can be an IP or a Fully Qualified Domain Name (FQDN).
        sans_dns (FrozenSet[str]): DNS Subject Alternative Names
        sans_ip (FrozenSet[str]): IP Subject Alternative Names
        sans_oid (FrozenSet[str]): OID Subject Alternative Names
        organization (Optional[str]): Organization name
        organizational_unit (Optional[str]): Organizational unit name
        email_address (Optional[str]): Email address; also added to the SAN as an
            RFC 822 name
        country_name (Optional[str]): Certificate Issuing country
        state_or_province_name (Optional[str]): Certificate Issuing state or province
        locality_name (Optional[str]): Certificate Issuing locality

    Returns:
        Certificate: CA Certificate.
    """
    private_key_object = serialization.load_pem_private_key(
        str(private_key).encode(), password=None
    )
    # load_pem_private_key can return any key type; only RSA keys are accepted here.
    assert isinstance(private_key_object, rsa.RSAPrivateKey)
    subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)]
    if organization:
        subject_name.append(x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, organization))
    if organizational_unit:
        subject_name.append(
            x509.NameAttribute(x509.NameOID.ORGANIZATIONAL_UNIT_NAME, organizational_unit)
        )
    if email_address:
        subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address))
    if country_name:
        subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name))
    if state_or_province_name:
        subject_name.append(
            x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name)
        )
    if locality_name:
        subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name))

    # Key identifier per RFC 5280 section 4.2.1.2 (SHA-1 digest of the public key).
    # ``SubjectKeyIdentifier.public_bytes()`` must not be used for this: it returns the
    # DER-encoded extension value (OCTET STRING tag and length followed by the digest),
    # which would put two extra bytes into the identifier and make it differ from the
    # identifiers computed for leaf certificates via ``from_public_key``.
    subject_key_identifier = x509.SubjectKeyIdentifier.from_public_key(
        private_key_object.public_key()
    )
    key_usage = x509.KeyUsage(
        digital_signature=True,
        key_encipherment=True,
        key_cert_sign=True,
        key_agreement=False,
        content_commitment=False,
        data_encipherment=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )

    # Read the clock once so the validity window is exactly ``validity`` long.
    now = datetime.now(timezone.utc)
    builder = (
        x509.CertificateBuilder()
        .subject_name(x509.Name(subject_name))
        .issuer_name(x509.Name(subject_name))
        .public_key(private_key_object.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + validity)
        .add_extension(subject_key_identifier, critical=False)
        .add_extension(
            # Self-signed: the authority key is the certificate's own key.
            x509.AuthorityKeyIdentifier(
                key_identifier=subject_key_identifier.digest,
                authority_cert_issuer=None,
                authority_cert_serial_number=None,
            ),
            critical=False,
        )
        .add_extension(key_usage, critical=True)
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=None),
            critical=True,
        )
    )
    san_extension = _san_extension(
        email_address=email_address,
        sans_dns=sans_dns,
        sans_ip=sans_ip,
        sans_oid=sans_oid,
    )
    if san_extension:
        builder = builder.add_extension(san_extension, critical=False)
    cert = builder.sign(private_key_object, hashes.SHA256())
    ca_cert_str = cert.public_bytes(serialization.Encoding.PEM).decode().strip()
    return Certificate.from_string(ca_cert_str)
940
+
941
+
942
+ def _san_extension(
943
+ email_address: str | None = None,
944
+ sans_dns: frozenset[str] | None = frozenset(),
945
+ sans_ip: frozenset[str] | None = frozenset(),
946
+ sans_oid: frozenset[str] | None = frozenset(),
947
+ ) -> x509.SubjectAlternativeName | None:
948
+ sans: list[x509.GeneralName] = []
949
+ if email_address:
950
+ # If an e-mail address was provided, it should always be in the SAN
951
+ sans.append(x509.RFC822Name(email_address))
952
+ if sans_dns:
953
+ sans.extend([x509.DNSName(san) for san in sans_dns])
954
+ if sans_ip:
955
+ sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip])
956
+ if sans_oid:
957
+ sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid])
958
+ if not sans:
959
+ return None
960
+ return x509.SubjectAlternativeName(sans)
961
+
962
+
963
def generate_certificate(
    csr: CertificateSigningRequest,
    ca: Certificate,
    ca_private_key: PrivateKey,
    validity: timedelta,
    is_ca: bool = False,
) -> Certificate:
    """Generate a TLS certificate based on a CSR, signed by the given CA.

    The certificate takes its subject and public key from the CSR. Its issuer is
    the CA certificate's *subject*, and it is valid from now for ``validity``.
    Extensions are derived from the CSR (see
    ``_generate_certificate_request_extensions``). Extensions that cannot be added
    (ValueError) are logged and skipped.

    Args:
        csr (CertificateSigningRequest): CSR
        ca (Certificate): CA Certificate acting as the issuer
        ca_private_key (PrivateKey): CA private key used to sign the certificate
        validity (timedelta): Certificate validity time
        is_ca (bool): Whether the certificate is a CA certificate

    Returns:
        Certificate: Certificate
    """
    csr_object = x509.load_pem_x509_csr(str(csr).encode())
    ca_pem = x509.load_pem_x509_certificate(str(ca).encode())
    private_key = serialization.load_pem_private_key(str(ca_private_key).encode(), password=None)

    # The issuer of the new certificate is the CA's subject (RFC 5280 4.1.2.4).
    # The CA's own issuer only coincides with it for self-signed roots; using it
    # would break chains issued by an intermediate CA.
    issuer = ca_pem.subject
    # Read the clock once so the validity window is exactly ``validity`` long.
    now = datetime.now(timezone.utc)
    certificate_builder = (
        x509.CertificateBuilder()
        .subject_name(csr_object.subject)
        .issuer_name(issuer)
        .public_key(csr_object.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + validity)
    )

    try:
        authority_key_identifier = ca_pem.extensions.get_extension_for_class(
            x509.SubjectKeyIdentifier
        ).value.key_identifier
    except x509.ExtensionNotFound:
        # CA certificate without a SubjectKeyIdentifier: derive the identifier
        # from the CA public key instead of failing.
        authority_key_identifier = x509.SubjectKeyIdentifier.from_public_key(
            ca_pem.public_key()  # type: ignore[arg-type]
        ).key_identifier

    extensions = _generate_certificate_request_extensions(
        authority_key_identifier=authority_key_identifier,
        csr=csr_object,
        is_ca=is_ca,
    )
    for extension in extensions:
        try:
            certificate_builder = certificate_builder.add_extension(
                extval=extension.value,
                critical=extension.critical,
            )
        except ValueError as e:
            # e.g. the same extension type added twice
            logger.warning("Failed to add extension %s: %s", extension.oid, e)

    cert = certificate_builder.sign(private_key, hashes.SHA256())  # type: ignore[arg-type]
    cert_bytes = cert.public_bytes(serialization.Encoding.PEM)
    return Certificate.from_string(cert_bytes.decode().strip())
1016
+
1017
+
1018
def _generate_certificate_request_extensions(
    authority_key_identifier: bytes,
    csr: x509.CertificateSigningRequest,
    is_ca: bool,
) -> list[x509.Extension[Any]]:
    """Assemble the extensions of a certificate issued for ``csr``.

    Provider-managed extensions come first, in this order: AuthorityKeyIdentifier,
    SubjectKeyIdentifier (from the CSR public key), critical BasicConstraints, the
    rebuilt SubjectAlternativeName (if any names exist), and, for CA certificates, a
    critical KeyUsage allowing certificate and CRL signing. The CSR's remaining
    extensions are appended after them. The CSR's SAN is dropped because it was
    already rebuilt. Any CSR extension whose OID collides with a provider-managed
    one is dropped with a warning.

    Args:
        authority_key_identifier (bytes): Authority key identifier
        csr (x509.CertificateSigningRequest): CSR
        is_ca (bool): Whether the certificate is a CA certificate

    Returns:
        List[x509.Extension]: List of extensions
    """
    authority_key_id_value = x509.AuthorityKeyIdentifier(
        key_identifier=authority_key_identifier,
        authority_cert_issuer=None,
        authority_cert_serial_number=None,
    )
    extensions: list[x509.Extension[Any]] = [
        x509.Extension(
            oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER,
            critical=False,
            value=authority_key_id_value,
        ),
        x509.Extension(
            oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER,
            critical=False,
            value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()),
        ),
        x509.Extension(
            oid=ExtensionOID.BASIC_CONSTRAINTS,
            critical=True,
            value=x509.BasicConstraints(ca=is_ca, path_length=None),
        ),
    ]

    san_extension = _generate_subject_alternative_name_extension(csr)
    if san_extension:
        extensions.append(san_extension)

    if is_ca:
        ca_key_usage = x509.KeyUsage(
            digital_signature=False,
            content_commitment=False,
            key_encipherment=False,
            data_encipherment=False,
            key_agreement=False,
            key_cert_sign=True,
            crl_sign=True,
            encipher_only=False,
            decipher_only=False,
        )
        extensions.append(
            x509.Extension(ExtensionOID.KEY_USAGE, critical=True, value=ca_key_usage)
        )

    provider_managed_oids = {managed.oid for managed in extensions}
    for requested in csr.extensions:
        # The SAN was already rebuilt above; skip it silently.
        if requested.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME:
            continue
        if requested.oid in provider_managed_oids:
            logger.warning("Extension %s is managed by the TLS provider, ignoring.", requested.oid)
            continue
        extensions.append(requested)

    return extensions
1086
+
1087
+
1088
def _generate_subject_alternative_name_extension(
    csr: x509.CertificateSigningRequest,
) -> x509.Extension[Any] | None:
    """Rebuild the SubjectAlternativeName extension for a certificate from ``csr``.

    Only DNS names, IP addresses, registered IDs and RFC 822 (e-mail) names are
    carried over from the CSR's SAN extension, in that order. If the CSR subject
    contains an e-mail address that is not already listed, it is appended to the
    SANs to conform to RFC 5280.

    Returns:
        A non-critical SAN extension, or None when no names remain.
    """
    general_names: list[x509.GeneralName] = []
    try:
        requested_sans = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
    except x509.ExtensionNotFound:
        requested_sans = None

    if requested_sans is not None:
        general_names += [
            x509.DNSName(value) for value in requested_sans.get_values_for_type(x509.DNSName)
        ]
        general_names += [
            x509.IPAddress(value) for value in requested_sans.get_values_for_type(x509.IPAddress)
        ]
        general_names += [
            x509.RegisteredID(value)
            for value in requested_sans.get_values_for_type(x509.RegisteredID)
        ]
        general_names += [
            x509.RFC822Name(value)
            for value in requested_sans.get_values_for_type(x509.RFC822Name)
        ]

    subject_emails = csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS)
    if subject_emails:
        subject_email = x509.RFC822Name(str(subject_emails[0].value))
        if subject_email not in general_names:
            general_names.append(subject_email)

    if not general_names:
        return None
    return x509.Extension(
        oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
        critical=False,
        value=x509.SubjectAlternativeName(general_names),
    )
1127
+
1128
+
1129
class CertificatesRequirerCharmEvents(CharmEvents):
    """Events the TLS Certificates requirer library emits for the charm to observe.

    Attributes:
        certificate_available: Emitted when a certificate issued for one of the
            requirer's CSRs is found in the provider's relation data.
    """

    # Emitted by the requirer once a matching certificate has been stored.
    certificate_available = EventSource(CertificateAvailableEvent)
1133
+
1134
+
1135
+ class TLSCertificatesRequiresV4(Object):
1136
+ """A class to manage the TLS certificates interface for a unit or app."""
1137
+
1138
+ on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType]
1139
+
1140
def __init__(
    self,
    charm: CharmBase,
    relationship_name: str,
    certificate_requests: list[CertificateRequestAttributes],
    mode: Mode = Mode.UNIT,
    refresh_events: list[BoundEvent] | None = None,
    private_key: PrivateKey | None = None,
    renewal_relative_time: float = 0.9,
):
    """Create a new instance of the TLSCertificatesRequiresV4 class.

    Args:
        charm (CharmBase): The charm instance to relate to.
        relationship_name (str): The name of the relation that provides the certificates.
        certificate_requests (List[CertificateRequestAttributes]):
            A list with the attributes of the certificate requests.
        mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT.
            In UNIT mode the requirer will place the csr in the unit relation data.
            Each unit will manage its private key,
            certificate signing request and certificate.
            UNIT mode is for use cases where each unit has its own identity.
            If you don't know which mode to use, you likely need UNIT.
            In APP mode the leader unit will place the csr in the app relation databag.
            APP mode is for use cases where the underlying application needs the certificate
            for example using it as an intermediate CA to sign other certificates.
            The certificate can only be accessed by the leader unit.
        refresh_events (List[BoundEvent]): A list of events to trigger a refresh of
            the certificates.
        private_key (Optional[PrivateKey]): The private key to use for the certificates.
            If provided, it will be used instead of generating a new one.
            If the key is not valid an exception will be raised.
            Using this parameter is discouraged,
            having to pass around private keys manually can be a security concern.
            Allowing the library to generate and manage the key is the more secure approach.
        renewal_relative_time (float): The time to renew the certificate relative to its
            expiry.
            Default is 0.9, meaning 90% of the validity period.
            Accepted values range from 0.5 (50% of the validity period) to 1.0, inclusive.
            If an invalid value is provided, an exception will be raised.

    Raises:
        TLSCertificatesError: If the mode, a certificate request, the private key or
            the renewal relative time is invalid.
    """
    if refresh_events is None:
        refresh_events = []
    super().__init__(charm, relationship_name)
    if not JujuVersion.from_environ().has_secrets:
        logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)")
    if not self._mode_is_valid(mode):
        raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP")
    for certificate_request in certificate_requests:
        if not certificate_request.is_valid():
            raise TLSCertificatesError("Invalid certificate request")
    self.charm = charm
    self.relationship_name = relationship_name
    self.certificate_requests = certificate_requests
    self.mode = mode
    if private_key and not private_key.is_valid():
        raise TLSCertificatesError("Invalid private key")
    # Inclusive range matching the documented minimum of 0.5. The chained comparison
    # also rejects NaN, which the previous `<= 0.5 or > 1.0` check accepted.
    if not 0.5 <= renewal_relative_time <= 1.0:
        raise TLSCertificatesError(
            "Invalid renewal relative time. Must be between 0.5 and 1.0"
        )
    self._private_key = private_key
    self.renewal_relative_time = renewal_relative_time
    self.framework.observe(charm.on[relationship_name].relation_created, self._configure)
    self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
    self.framework.observe(charm.on.secret_expired, self._on_secret_expired)
    self.framework.observe(charm.on.secret_remove, self._on_secret_remove)
    for event in refresh_events:
        self.framework.observe(event, self._configure)
1209
+
1210
def _configure(self, _: EventBase | None = None):
    """Reconcile the TLS relation data with the charm's certificate requests.

    Observed for relation-created/changed events and any configured refresh
    events, and also called by ``sync()``. Once the TLS relation exists, it runs
    these steps in order: ensure a private key is available, drop stale
    certificate requests, send missing ones, then store and announce any
    certificates the provider has issued.
    """
    if self._tls_relation_created():
        self._ensure_private_key()
        self._cleanup_certificate_requests()
        self._send_certificate_requests()
        self._find_available_certificates()
        return
    logger.debug("TLS relation not created yet.")
1225
+
1226
def _mode_is_valid(self, mode: Mode) -> bool:
    """Return whether ``mode`` is a supported mode (Mode.UNIT or Mode.APP)."""
    supported_modes = (Mode.UNIT, Mode.APP)
    return mode in supported_modes
1228
+
1229
def _validate_secret_exists(self, secret: Secret) -> None:
    """Check that ``secret`` still exists in the model.

    Raises:
        SecretNotFoundError: If the secret no longer exists.
    """
    # Only the side effect matters: get_info() raises when the secret is gone.
    _ = secret.get_info()
1231
+
1232
def _on_secret_remove(self, event: SecretRemoveEvent) -> None:
    """Handle the secret-remove event by deleting the revision named in the event.

    A secret that no longer exists is logged and otherwise ignored.
    """
    try:
        # Probe for the secret before removing the revision; removing a revision
        # of a missing secret could leave the unit stuck in an error state.
        # See the docstring of `remove_revision` and
        # https://github.com/juju/juju/issues/19036
        self._validate_secret_exists(event.secret)
        event.secret.remove_revision(event.revision)
    except SecretNotFoundError:
        logger.warning(
            "No such secret %s, nothing to remove",
            event.secret.label or event.secret.id,
        )
1247
+
1248
def _on_secret_expired(self, event: SecretExpiredEvent) -> None:
    """Handle the secret-expired event for this library's certificate secrets.

    Secrets without a ``{LIBID}-certificate`` label prefix are ignored. For the
    others, the CSR stored in the secret is renewed and all of the secret's
    revisions are removed. If the secret content cannot be read, an error is
    logged and nothing else happens.
    """
    secret_label = event.secret.label
    owned_by_library = bool(secret_label) and secret_label.startswith(f"{LIBID}-certificate")
    if not owned_by_library:
        return
    try:
        stored_content = event.secret.get_content(refresh=True)
        csr_str = stored_content["csr"]
    except ModelError:
        logger.error("Failed to get CSR from secret - Skipping")
        return
    self._renew_certificate_request(CertificateSigningRequest.from_string(csr_str))
    event.secret.remove_all_revisions()
1263
+
1264
def sync(self) -> None:
    """Reconcile the TLS certificates relation data on demand.

    Runs the same reconciliation that relation and refresh events trigger,
    without waiting for one of those events to fire.
    """
    self._configure(None)
1271
+
1272
def renew_certificate(self, certificate: ProviderCertificate) -> None:
    """Request the renewal of the provided certificate.

    Renewal only happens if the unit holds a secret for the certificate's CSR
    and the CSR stored in that secret equals it. Otherwise a warning is logged
    and nothing changes. On renewal the CSR is replaced in the relation data and
    all revisions of the secret are removed.
    """
    csr = certificate.certificate_signing_request
    secret_label = self._get_csr_secret_label(csr)
    try:
        secret = self.model.get_secret(label=secret_label)
    except SecretNotFoundError:
        logger.warning("No matching secret found - Skipping renewal")
        return
    stored_csr = secret.get_content(refresh=True).get("csr", "")
    if stored_csr == str(csr):
        self._renew_certificate_request(csr)
        secret.remove_all_revisions()
    else:
        logger.warning("No matching CSR found - Skipping renewal")
1287
+
1288
def _renew_certificate_request(self, csr: CertificateSigningRequest):
    """Replace ``csr`` in the relation data with a newly generated request."""
    # With the old CSR gone, the matching certificate request no longer counts as
    # "requested", so the send step generates and publishes a fresh CSR for it.
    self._remove_requirer_csr_from_relation_data(csr)
    self._send_certificate_requests()
    logger.info("Renewed certificate request")
1293
+
1294
def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None:
    """Remove every occurrence of ``csr`` from this requirer's relation databag.

    Does nothing when the relation does not exist, when no CSRs are visible to
    this unit (for example a non-leader unit in APP mode), or when the stored
    relation data fails validation. A failure to write the databag is logged.
    """
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return
    if not self.get_csrs_from_requirer_relation_data():
        logger.info("No CSRs in relation data - Doing nothing")
        return
    app_or_unit = self._get_app_or_unit()
    try:
        requirer_relation_data = _RequirerData.load(relation.data[app_or_unit])
    except DataValidationError:
        logger.warning("Invalid relation data - Skipping removal of CSR")
        return
    csr_to_remove = str(csr).strip()
    # Build a filtered list rather than calling list.remove() while iterating over
    # the same list: that skips the element following each removed one, so e.g.
    # adjacent duplicate CSRs would not all be removed.
    remaining_csrs = [
        requirer_csr
        for requirer_csr in requirer_relation_data.certificate_signing_requests
        if requirer_csr.certificate_signing_request.strip() != csr_to_remove
    ]
    try:
        _RequirerData(certificate_signing_requests=remaining_csrs).dump(
            relation.data[app_or_unit]
        )
        logger.info("Removed CSR from relation data")
    except ModelError:
        logger.warning("Failed to update relation data")
1319
+
1320
def _get_app_or_unit(self) -> Application | Unit:
    """Return the entity whose relation databag holds this requirer's data.

    Returns the application in Mode.APP and the unit in Mode.UNIT.

    Raises:
        TLSCertificatesError: If ``self.mode`` is neither mode.
    """
    if self.mode == Mode.APP:
        return self.model.app
    if self.mode == Mode.UNIT:
        return self.model.unit
    raise TLSCertificatesError("Invalid mode")
1327
+
1328
@property
def private_key(self) -> PrivateKey | None:
    """The private key used for certificate requests, or None if there is none yet.

    A key passed in by the charm takes precedence. Otherwise the key generated
    by the library is read from its Juju secret on every access.
    """
    if not self._private_key:
        if not self._private_key_generated():
            return None
        key_secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
        key_pem = key_secret.get_content(refresh=True)["private-key"]
        return PrivateKey.from_string(key_pem)
    return self._private_key
1338
+
1339
def _ensure_private_key(self) -> None:
    """Make sure a private key is available.

    If the charm passed a key through the ``private_key`` parameter, any key
    secret generated earlier by the library is removed. Otherwise a key is
    generated and stored unless one already exists.
    """
    if self._private_key:
        # The charm-supplied key wins; drop any library-generated key.
        self._remove_private_key_secret()
    elif self._private_key_generated():
        logger.debug("Private key already generated")
    else:
        self._generate_private_key()
1354
+
1355
def regenerate_private_key(self) -> None:
    """Regenerate the private key.

    Generate a new private key, remove old certificate requests and send new ones.
    If no library-generated key exists yet, a warning is logged and nothing happens.

    Raises:
        TLSCertificatesError: If the private key is passed by the charm using the
            private_key parameter.
    """
    if self._private_key:
        raise TLSCertificatesError(
            "Private key is passed by the charm through the private_key parameter, "
            "this function can't be used"
        )
    if self._private_key_generated():
        self._generate_private_key()
        # CSRs built from the previous key no longer match it: replace them.
        self._cleanup_certificate_requests()
        self._send_certificate_requests()
        return
    logger.warning("No private key to regenerate")
1375
+
1376
def _generate_private_key(self) -> None:
    """Create a fresh private key and store it in the library's key secret.

    Used only when the key is managed by the library, not when the charm passes
    one through the ``private_key`` parameter.
    """
    new_private_key = generate_private_key()
    self._store_private_key_in_secret(new_private_key)
    logger.info("Private key generated")
1384
+
1385
def _private_key_generated(self) -> bool:
    """Return whether a library-generated private key secret exists.

    The secret exists when the library manages the key itself. It should be
    absent when the charm passes its own key through the ``private_key``
    parameter.
    """
    try:
        key_secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
        key_secret.get_content(refresh=True)
    except SecretNotFoundError:
        return False
    return True
1398
+
1399
def _store_private_key_in_secret(self, private_key: PrivateKey) -> None:
    """Write ``private_key`` to the private key secret, creating the secret if needed."""
    key_content = {"private-key": str(private_key)}
    try:
        key_secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
        key_secret.set_content(key_content)
        # Refresh right after writing, presumably so this unit tracks the new
        # revision.
        key_secret.get_content(refresh=True)
    except SecretNotFoundError:
        self.charm.unit.add_secret(
            content=key_content,
            label=self._get_private_key_secret_label(),
        )
1409
+
1410
def _remove_private_key_secret(self) -> None:
    """Delete every revision of the library's private key secret, if it exists."""
    try:
        self.charm.model.get_secret(
            label=self._get_private_key_secret_label()
        ).remove_all_revisions()
    except SecretNotFoundError:
        logger.warning("Private key secret not found, nothing to remove")
1417
+
1418
def _csr_matches_certificate_request(
    self, certificate_signing_request: CertificateSigningRequest, is_ca: bool
) -> bool:
    """Return whether the CSR's attributes equal one of the charm's certificate requests."""
    return any(
        certificate_request
        == CertificateRequestAttributes.from_csr(certificate_signing_request, is_ca)
        for certificate_request in self.certificate_requests
    )
1428
+
1429
def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool:
    """Return whether ``certificate_request`` has already been sent.

    A request counts as sent when the relation data holds a CSR with the same
    attributes that was generated with the current private key.
    """
    # Read the key once: every `private_key` access may perform Juju secret
    # lookups, and the key cannot change within this call.
    private_key = self.private_key
    if not private_key:
        return False
    requirer_csr = self._certificate_requested_for_attributes(certificate_request)
    if not requirer_csr:
        return False
    return bool(requirer_csr.certificate_signing_request.matches_private_key(key=private_key))
1438
+
1439
def _certificate_requested_for_attributes(
    self,
    certificate_request: CertificateRequestAttributes,
) -> RequirerCertificateRequest | None:
    """Return the first requirer CSR whose attributes equal ``certificate_request``.

    Returns None when no CSR in the relation data matches.
    """
    matching_csrs = (
        requirer_csr
        for requirer_csr in self.get_csrs_from_requirer_relation_data()
        if certificate_request
        == CertificateRequestAttributes.from_csr(
            requirer_csr.certificate_signing_request,
            requirer_csr.is_ca,
        )
    )
    return next(matching_csrs, None)
1450
+
1451
def get_csrs_from_requirer_relation_data(self) -> list[RequirerCertificateRequest]:
    """Return the CSRs this requirer has published in the relation data.

    Returns an empty list for a non-leader unit in APP mode, when the relation
    does not exist, or when the databag fails validation.
    """
    if self.mode == Mode.APP and not self.model.unit.is_leader():
        logger.debug("Not a leader unit - Skipping")
        return []
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return []
    databag_owner = self._get_app_or_unit()
    try:
        stored_data = _RequirerData.load(relation.data[databag_owner])
    except DataValidationError:
        logger.warning("Invalid relation data")
        return []
    return [
        RequirerCertificateRequest(
            relation_id=relation.id,
            certificate_signing_request=CertificateSigningRequest.from_string(
                stored_csr.certificate_signing_request
            ),
            # A missing/falsy `ca` flag means "not a CA request".
            is_ca=stored_csr.ca or False,
        )
        for stored_csr in stored_data.certificate_signing_requests
    ]
1477
+
1478
def get_provider_certificates(self) -> list[ProviderCertificate]:
    """Return every certificate the provider published in its application databag."""
    provider_certificates = self._load_provider_certificates()
    return provider_certificates
1481
+
1482
def _load_provider_certificates(self) -> list[ProviderCertificate]:
    """Parse the provider's application databag into ProviderCertificate objects.

    Returns an empty list when the relation or the remote application is
    missing, or when the databag fails validation.
    """
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return []
    provider_app = relation.app
    if not provider_app:
        logger.debug("No remote app in relation: %s", self.relationship_name)
        return []
    try:
        provider_data = _ProviderApplicationData.load(relation.data[provider_app])
    except DataValidationError:
        logger.warning("Invalid relation data")
        return []
    return [
        published.to_provider_certificate(relation_id=relation.id)
        for published in provider_data.certificates
    ]
1499
+
1500
def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None:
    """Append ``csr`` to this requirer's CSRs in the relation databag.

    Skipped for a non-leader unit in APP mode and when the relation does not
    exist. If the existing databag content fails validation, it is replaced by a
    list containing only the new CSR. A failure to write is logged.
    """
    if self.mode == Mode.APP and not self.model.unit.is_leader():
        logger.debug("Not a leader unit - Skipping")
        return
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return
    csr_entry = _CertificateSigningRequest(certificate_signing_request=str(csr).strip(), ca=is_ca)
    databag_owner = self._get_app_or_unit()
    try:
        existing_csrs = _RequirerData.load(
            relation.data[databag_owner]
        ).certificate_signing_requests
    except DataValidationError:
        existing_csrs = []
    updated_csrs = [*copy.deepcopy(existing_csrs), csr_entry]
    try:
        _RequirerData(certificate_signing_requests=updated_csrs).dump(
            relation.data[databag_owner]
        )
        logger.info("Certificate signing request added to relation data.")
    except ModelError:
        logger.warning("Failed to update relation data")
1528
+
1529
def _send_certificate_requests(self):
    """Publish a CSR for every certificate request that has not been sent yet.

    Does nothing until a private key is available. Requests whose CSR cannot be
    generated are logged and skipped.
    """
    # Read the key once for the whole loop: every `private_key` access may perform
    # Juju secret lookups, and nothing in this loop changes the key.
    private_key = self.private_key
    if not private_key:
        logger.debug("Private key not generated yet.")
        return
    for certificate_request in self.certificate_requests:
        if self._certificate_requested(certificate_request):
            continue
        csr = certificate_request.generate_csr(private_key=private_key)
        if not csr:
            logger.warning("Failed to generate CSR")
            continue
        self._request_certificate(csr=csr, is_ca=certificate_request.is_ca)
1542
+
1543
def get_assigned_certificate(
    self, certificate_request: CertificateRequestAttributes
) -> tuple[ProviderCertificate | None, PrivateKey | None]:
    """Get the certificate assigned to ``certificate_request`` and the private key.

    Returns ``(None, None)`` when no CSR in the relation data matches the
    request. If a CSR matches but no valid certificate has been issued for it
    yet, the certificate element is None.
    """
    matching_csr = self._certificate_requested_for_attributes(certificate_request)
    if matching_csr is None:
        return None, None
    return self._find_certificate_in_relation_data(matching_csr), self.private_key
1554
+
1555
def get_assigned_certificates(
    self,
) -> tuple[list[ProviderCertificate], PrivateKey | None]:
    """Get every certificate assigned to this unit or app, plus the private key.

    Only CSRs with a valid issued certificate contribute an entry.
    """
    candidates = (
        self._find_certificate_in_relation_data(requirer_csr)
        for requirer_csr in self.get_csrs_from_requirer_relation_data()
    )
    assigned_certificates = [certificate for certificate in candidates if certificate]
    return assigned_certificates, self.private_key
1565
+
1566
def _find_certificate_in_relation_data(
    self, csr: RequirerCertificateRequest
) -> ProviderCertificate | None:
    """Return the certificate that matches the given CSR, validated against the private key.

    A candidate whose CA flag differs from the one requested, or that does not
    match the private key, is skipped with a warning. Returns None when there is
    no private key or no acceptable certificate.
    """
    # Read the key once: every `private_key` access may perform Juju secret
    # lookups, and it would otherwise be re-read for each candidate certificate.
    private_key = self.private_key
    if not private_key:
        return None
    for provider_certificate in self.get_provider_certificates():
        if not provider_certificate.certificate_signing_request == csr.certificate_signing_request:
            continue
        certificate_is_ca = provider_certificate.certificate.is_ca
        if certificate_is_ca and not csr.is_ca:
            logger.warning("Non CA certificate requested, got a CA certificate, ignoring")
            continue
        if not certificate_is_ca and csr.is_ca:
            logger.warning("CA certificate requested, got a non CA certificate, ignoring")
            continue
        if not provider_certificate.certificate.matches_private_key(private_key):
            logger.warning(
                "Certificate does not match the private key. Ignoring invalid certificate."
            )
            continue
        return provider_certificate
    return None
1587
+
1588
def _find_available_certificates(self):
    """Store newly issued certificates in secrets and notify the charm.

    Only provider certificates whose CSR is one of this requirer's CSRs are
    considered:

    - A revoked certificate has its per-CSR secret (if any) removed.
    - Otherwise, if the CSR still matches one of the charm's certificate
      requests, the certificate and CSR are written to a per-CSR secret. The
      secret expires at ``renewal_relative_time`` of the certificate's lifetime,
      and ``certificate_available`` is emitted. Nothing is written or emitted
      when the secret already holds the same certificate.
    """
    own_csrs = [
        requirer_csr.certificate_signing_request
        for requirer_csr in self.get_csrs_from_requirer_relation_data()
    ]
    for provider_certificate in self.get_provider_certificates():
        issued_for = provider_certificate.certificate_signing_request
        if issued_for not in own_csrs:
            continue
        secret_label = self._get_csr_secret_label(issued_for)

        if provider_certificate.revoked:
            with suppress(SecretNotFoundError):
                logger.debug(
                    "Removing secret with label %s",
                    secret_label,
                )
                self.model.get_secret(label=secret_label).remove_all_revisions()
            continue

        if not self._csr_matches_certificate_request(
            certificate_signing_request=issued_for,
            is_ca=provider_certificate.certificate.is_ca,
        ):
            logger.debug("Certificate requested for different attributes - Skipping")
            continue

        try:
            certificate_secret = self.model.get_secret(label=secret_label)
            logger.debug("Setting secret with label %s", secret_label)
            stored_certificate = certificate_secret.get_content(refresh=True).get(
                "certificate", ""
            )
            # Juju < 3.6 will create a new revision even if the content is the same
            if stored_certificate == str(provider_certificate.certificate):
                logger.debug("Secret %s with correct certificate already exists", secret_label)
                continue
            certificate_secret.set_content(
                content={
                    "certificate": str(provider_certificate.certificate),
                    "csr": str(issued_for),
                }
            )
            certificate_secret.set_info(
                expire=calculate_relative_datetime(
                    target_time=provider_certificate.certificate.expiry_time,
                    fraction=self.renewal_relative_time,
                ),
            )
            certificate_secret.get_content(refresh=True)
        except SecretNotFoundError:
            logger.debug("Creating new secret with label %s", secret_label)
            self.charm.unit.add_secret(
                content={
                    "certificate": str(provider_certificate.certificate),
                    "csr": str(issued_for),
                },
                label=secret_label,
                expire=calculate_relative_datetime(
                    target_time=provider_certificate.certificate.expiry_time,
                    fraction=self.renewal_relative_time,
                ),
            )

        self.on.certificate_available.emit(
            certificate_signing_request=issued_for,
            certificate=provider_certificate.certificate,
            ca=provider_certificate.ca,
            chain=provider_certificate.chain,
        )
1661
+
1662
+ def _cleanup_certificate_requests(self):
1663
+ """Clean up certificate requests.
1664
+
1665
+ Remove any certificate requests that falls into one of the following categories:
1666
+ - The CSR attributes do not match any of the certificate requests defined in
1667
+ the charm's certificate_requests attribute.
1668
+ - The CSR public key does not match the private key.
1669
+ """
1670
+ for requirer_csr in self.get_csrs_from_requirer_relation_data():
1671
+ if not self._csr_matches_certificate_request(
1672
+ certificate_signing_request=requirer_csr.certificate_signing_request,
1673
+ is_ca=requirer_csr.is_ca,
1674
+ ):
1675
+ self._remove_requirer_csr_from_relation_data(
1676
+ requirer_csr.certificate_signing_request
1677
+ )
1678
+ logger.info(
1679
+ "Removed CSR from relation data because it did not match any certificate request" # noqa: E501
1680
+ )
1681
+ elif (
1682
+ self.private_key
1683
+ and not requirer_csr.certificate_signing_request.matches_private_key(
1684
+ self.private_key
1685
+ )
1686
+ ):
1687
+ self._remove_requirer_csr_from_relation_data(
1688
+ requirer_csr.certificate_signing_request
1689
+ )
1690
+ logger.info(
1691
+ "Removed CSR from relation data because it did not match the private key"
1692
+ )
1693
+
1694
+ def _tls_relation_created(self) -> bool:
1695
+ relation = self.model.get_relation(self.relationship_name)
1696
+ if not relation:
1697
+ return False
1698
+ return True
1699
+
1700
+ def _get_private_key_secret_label(self) -> str:
1701
+ if self.mode == Mode.UNIT:
1702
+ return f"{LIBID}-private-key-{self._get_unit_number()}-{self.relationship_name}"
1703
+ elif self.mode == Mode.APP:
1704
+ return f"{LIBID}-private-key-{self.relationship_name}"
1705
+ else:
1706
+ raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.")
1707
+
1708
+ def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str:
1709
+ csr_in_sha256_hex = csr.get_sha256_hex()
1710
+ if self.mode == Mode.UNIT:
1711
+ return f"{LIBID}-certificate-{self._get_unit_number()}-{csr_in_sha256_hex}"
1712
+ elif self.mode == Mode.APP:
1713
+ return f"{LIBID}-certificate-{csr_in_sha256_hex}"
1714
+ else:
1715
+ raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.")
1716
+
1717
+ def _get_unit_number(self) -> str:
1718
+ return self.model.unit.name.split("/")[1]
1719
+
1720
+
1721
+ class TLSCertificatesProvidesV4(Object):
1722
+ """TLS certificates provider class to be instantiated by TLS certificates providers."""
1723
+
1724
+ def __init__(self, charm: CharmBase, relationship_name: str):
1725
+ super().__init__(charm, relationship_name)
1726
+ self.framework.observe(charm.on[relationship_name].relation_joined, self._configure)
1727
+ self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
1728
+ self.framework.observe(charm.on.update_status, self._configure)
1729
+ self.charm = charm
1730
+ self.relationship_name = relationship_name
1731
+
1732
+ def _configure(self, _: EventBase) -> None:
1733
+ """Handle update status and tls relation changed events.
1734
+
1735
+ This is a common hook triggered on a regular basis.
1736
+
1737
+ Revoke certificates for which no csr exists
1738
+ """
1739
+ if not self.model.unit.is_leader():
1740
+ return
1741
+ self._remove_certificates_for_which_no_csr_exists()
1742
+
1743
+ def _remove_certificates_for_which_no_csr_exists(self) -> None:
1744
+ provider_certificates = self.get_provider_certificates()
1745
+ requirer_csrs = [
1746
+ request.certificate_signing_request for request in self.get_certificate_requests()
1747
+ ]
1748
+ for provider_certificate in provider_certificates:
1749
+ if provider_certificate.certificate_signing_request not in requirer_csrs:
1750
+ tls_relation = self._get_tls_relations(
1751
+ relation_id=provider_certificate.relation_id
1752
+ )
1753
+ self._remove_provider_certificate(
1754
+ certificate=provider_certificate.certificate,
1755
+ relation=tls_relation[0],
1756
+ )
1757
+
1758
+ def _get_tls_relations(self, relation_id: int | None = None) -> list[Relation]:
1759
+ return (
1760
+ [
1761
+ relation
1762
+ for relation in self.model.relations[self.relationship_name]
1763
+ if relation.id == relation_id
1764
+ ]
1765
+ if relation_id is not None
1766
+ else self.model.relations.get(self.relationship_name, [])
1767
+ )
1768
+
1769
+ def get_certificate_requests(
1770
+ self, relation_id: int | None = None
1771
+ ) -> list[RequirerCertificateRequest]:
1772
+ """Load certificate requests from the relation data."""
1773
+ relations = self._get_tls_relations(relation_id)
1774
+ requirer_csrs: list[RequirerCertificateRequest] = []
1775
+ for relation in relations:
1776
+ for unit in relation.units:
1777
+ requirer_csrs.extend(self._load_requirer_databag(relation, unit))
1778
+ requirer_csrs.extend(self._load_requirer_databag(relation, relation.app))
1779
+ return requirer_csrs
1780
+
1781
+ def _load_requirer_databag(
1782
+ self, relation: Relation, unit_or_app: Application | Unit
1783
+ ) -> list[RequirerCertificateRequest]:
1784
+ try:
1785
+ requirer_relation_data = _RequirerData.load(relation.data.get(unit_or_app, {}))
1786
+ except DataValidationError:
1787
+ logger.debug("Invalid requirer relation data for %s", unit_or_app.name)
1788
+ return []
1789
+ return [
1790
+ RequirerCertificateRequest(
1791
+ relation_id=relation.id,
1792
+ certificate_signing_request=CertificateSigningRequest.from_string(
1793
+ csr.certificate_signing_request
1794
+ ),
1795
+ is_ca=csr.ca if csr.ca else False,
1796
+ )
1797
+ for csr in requirer_relation_data.certificate_signing_requests
1798
+ ]
1799
+
1800
+ def _add_provider_certificate(
1801
+ self,
1802
+ relation: Relation,
1803
+ provider_certificate: ProviderCertificate,
1804
+ ) -> None:
1805
+ chain = [str(certificate) for certificate in provider_certificate.chain]
1806
+ if chain[0] != str(provider_certificate.certificate):
1807
+ logger.warning(
1808
+ "The order of the chain from the TLS Certificates Provider is incorrect. "
1809
+ "The leaf certificate should be the first element of the chain."
1810
+ )
1811
+ elif not chain_has_valid_order(chain):
1812
+ logger.warning(
1813
+ "The order of the chain from the TLS Certificates Provider is partially incorrect."
1814
+ )
1815
+ new_certificate = _Certificate(
1816
+ certificate=str(provider_certificate.certificate),
1817
+ certificate_signing_request=str(provider_certificate.certificate_signing_request),
1818
+ ca=str(provider_certificate.ca),
1819
+ chain=chain,
1820
+ )
1821
+ provider_certificates = self._load_provider_certificates(relation)
1822
+ if new_certificate in provider_certificates:
1823
+ logger.info("Certificate already in relation data - Doing nothing")
1824
+ return
1825
+ provider_certificates.append(new_certificate)
1826
+ self._dump_provider_certificates(relation=relation, certificates=provider_certificates)
1827
+
1828
+ def _load_provider_certificates(self, relation: Relation) -> list[_Certificate]:
1829
+ try:
1830
+ provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app])
1831
+ except DataValidationError:
1832
+ logger.debug("Invalid provider relation data")
1833
+ return []
1834
+ return copy.deepcopy(provider_relation_data.certificates)
1835
+
1836
+ def _dump_provider_certificates(self, relation: Relation, certificates: list[_Certificate]):
1837
+ try:
1838
+ _ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app])
1839
+ logger.info("Certificate relation data updated")
1840
+ except ModelError:
1841
+ logger.warning("Failed to update relation data")
1842
+
1843
+ def _remove_provider_certificate(
1844
+ self,
1845
+ relation: Relation,
1846
+ certificate: Certificate | None = None,
1847
+ certificate_signing_request: CertificateSigningRequest | None = None,
1848
+ ) -> None:
1849
+ """Remove certificate based on certificate or certificate signing request."""
1850
+ provider_certificates = self._load_provider_certificates(relation)
1851
+ for provider_certificate in provider_certificates:
1852
+ if certificate and provider_certificate.certificate == str(certificate):
1853
+ provider_certificates.remove(provider_certificate)
1854
+ if (
1855
+ certificate_signing_request
1856
+ and provider_certificate.certificate_signing_request
1857
+ == str(certificate_signing_request)
1858
+ ):
1859
+ provider_certificates.remove(provider_certificate)
1860
+ self._dump_provider_certificates(relation=relation, certificates=provider_certificates)
1861
+
1862
+ def revoke_all_certificates(self) -> None:
1863
+ """Revoke all certificates of this provider.
1864
+
1865
+ This method is meant to be used when the Root CA has changed.
1866
+ """
1867
+ if not self.model.unit.is_leader():
1868
+ logger.warning("Unit is not a leader - will not set relation data")
1869
+ return
1870
+ relations = self._get_tls_relations()
1871
+ for relation in relations:
1872
+ provider_certificates = self._load_provider_certificates(relation)
1873
+ for certificate in provider_certificates:
1874
+ certificate.revoked = True
1875
+ self._dump_provider_certificates(relation=relation, certificates=provider_certificates)
1876
+
1877
+ def set_relation_certificate(
1878
+ self,
1879
+ provider_certificate: ProviderCertificate,
1880
+ ) -> None:
1881
+ """Add certificates to relation data.
1882
+
1883
+ Args:
1884
+ provider_certificate (ProviderCertificate): ProviderCertificate object
1885
+
1886
+ Returns:
1887
+ None
1888
+ """
1889
+ if not self.model.unit.is_leader():
1890
+ logger.warning("Unit is not a leader - will not set relation data")
1891
+ return
1892
+ certificates_relation = self.model.get_relation(
1893
+ relation_name=self.relationship_name, relation_id=provider_certificate.relation_id
1894
+ )
1895
+ if not certificates_relation:
1896
+ raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist")
1897
+ self._remove_provider_certificate(
1898
+ relation=certificates_relation,
1899
+ certificate_signing_request=provider_certificate.certificate_signing_request,
1900
+ )
1901
+ self._add_provider_certificate(
1902
+ relation=certificates_relation,
1903
+ provider_certificate=provider_certificate,
1904
+ )
1905
+
1906
+ def get_issued_certificates(self, relation_id: int | None = None) -> list[ProviderCertificate]:
1907
+ """Return a List of issued (non revoked) certificates.
1908
+
1909
+ Returns:
1910
+ List: List of ProviderCertificate objects
1911
+ """
1912
+ if not self.model.unit.is_leader():
1913
+ logger.warning("Unit is not a leader - will not read relation data")
1914
+ return []
1915
+ provider_certificates = self.get_provider_certificates(relation_id=relation_id)
1916
+ return [certificate for certificate in provider_certificates if not certificate.revoked]
1917
+
1918
+ def get_provider_certificates(
1919
+ self, relation_id: int | None = None
1920
+ ) -> list[ProviderCertificate]:
1921
+ """Return a List of issued certificates."""
1922
+ certificates: list[ProviderCertificate] = []
1923
+ relations = self._get_tls_relations(relation_id)
1924
+ for relation in relations:
1925
+ if not relation.app:
1926
+ logger.warning("Relation %s does not have an application", relation.id)
1927
+ continue
1928
+ certificates.extend(
1929
+ certificate.to_provider_certificate(relation_id=relation.id)
1930
+ for certificate in self._load_provider_certificates(relation)
1931
+ )
1932
+ return certificates
1933
+
1934
+ def get_unsolicited_certificates(
1935
+ self, relation_id: int | None = None
1936
+ ) -> list[ProviderCertificate]:
1937
+ """Return provider certificates for which no certificate requests exists.
1938
+
1939
+ Those certificates should be revoked.
1940
+ """
1941
+ provider_certificates = self.get_provider_certificates(relation_id=relation_id)
1942
+ requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
1943
+ list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs]
1944
+ unsolicited_certificates: list[ProviderCertificate] = [
1945
+ certificate
1946
+ for certificate in provider_certificates
1947
+ if certificate.certificate_signing_request not in list_of_csrs
1948
+ ]
1949
+ return unsolicited_certificates
1950
+
1951
+ def get_outstanding_certificate_requests(
1952
+ self, relation_id: int | None = None
1953
+ ) -> list[RequirerCertificateRequest]:
1954
+ """Return CSR's for which no certificate has been issued.
1955
+
1956
+ Args:
1957
+ relation_id (int): Relation id
1958
+
1959
+ Returns:
1960
+ list: List of RequirerCertificateRequest objects.
1961
+ """
1962
+ requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
1963
+ outstanding_csrs: list[RequirerCertificateRequest] = [
1964
+ relation_csr
1965
+ for relation_csr in requirer_csrs
1966
+ if not self._certificate_issued_for_csr(
1967
+ csr=relation_csr.certificate_signing_request,
1968
+ relation_id=relation_id,
1969
+ )
1970
+ ]
1971
+ return outstanding_csrs
1972
+
1973
+ def _certificate_issued_for_csr(
1974
+ self, csr: CertificateSigningRequest, relation_id: int | None
1975
+ ) -> bool:
1976
+ """Check whether a certificate has been issued for a given CSR."""
1977
+ issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)
1978
+ for issued_certificate in issued_certificates_per_csr:
1979
+ if issued_certificate.certificate_signing_request == csr:
1980
+ return csr.matches_certificate(issued_certificate.certificate)
1981
+ return False