charmlibs-interfaces-tls-certificates 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2479 @@
1
+ # Copyright 2024 Canonical Ltd.
2
+ # See LICENSE file for licensing details.
3
+
4
+ """Source code of ``tls_certificates_interface.tls_certificates`` v4.22."""
5
+
6
+ from __future__ import annotations
7
+
8
+ import copy
9
+ import ipaddress
10
+ import json
11
+ import logging
12
+ import uuid
13
+ import warnings
14
+ from contextlib import suppress
15
+ from dataclasses import asdict, dataclass, field
16
+ from datetime import datetime, timedelta, timezone
17
+ from enum import Enum
18
+ from typing import (
19
+ TYPE_CHECKING,
20
+ TypeAlias,
21
+ )
22
+
23
+ import pydantic
24
+ from cryptography import x509
25
+ from cryptography.exceptions import InvalidSignature
26
+ from cryptography.hazmat.primitives import hashes, serialization
27
+ from cryptography.hazmat.primitives.asymmetric import rsa
28
+ from cryptography.x509.oid import ExtensionOID, NameOID
29
+ from ops import BoundEvent, CharmBase, CharmEvents, Secret, SecretExpiredEvent, SecretRemoveEvent
30
+ from ops.framework import EventBase, EventSource, Handle, Object
31
+ from ops.jujuversion import JujuVersion
32
+ from ops.model import Application, ModelError, Relation, SecretNotFoundError, Unit
33
+
34
+ if TYPE_CHECKING:
35
+ from collections.abc import Collection, MutableMapping
36
+
37
# Library ID of the legacy Charmhub-hosted lib. It is still used at runtime
# by this lib when building labels.
LIBID = "afd8c2bccf834997afce12c2706d2ede"


# True when the major component of the installed pydantic version is below 2.
IS_PYDANTIC_V1 = int(pydantic.version.VERSION.partition(".")[0]) < 2

# Module-level logger shared by the code in this file.
logger = logging.getLogger(__name__)

# Key under which the structured OWASP event dict is attached to a log
# record through the ``extra`` argument of the logging call.
NESTED_JSON_KEY = "owasp_event"

# Private key type accepted when issuing certificates (RSA only).
CertificateIssuerPrivateKeyTypes: TypeAlias = rsa.RSAPrivateKey
48
+
49
+
50
@dataclass
class _OWASPLogEvent:
    """Structured record describing one OWASP-style security log entry.

    ``labels`` is not serialized as a nested object: its entries are spread
    into the top level of the serialized event.
    """

    datetime: str
    event: str
    level: str
    description: str
    type: str = "security"
    labels: dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the event as a flat dict without ``None`` values.

        Label entries are merged over the core fields, so a label that shares
        its name with a core field replaces that field's value.
        """
        core_fields = asdict(self)
        flattened = dict(core_fields, **self.labels)
        # The nested mapping itself must not appear in the output.
        flattened.pop("labels", None)
        return {key: value for key, value in flattened.items() if value is not None}

    def to_json(self) -> str:
        """Return :meth:`to_dict` encoded as JSON, keeping non-ASCII characters as-is."""
        payload = self.to_dict()
        return json.dumps(payload, ensure_ascii=False)
68
+
69
+
70
class _OWASPLogger:
    """Emits security events in the OWASP-style structured format."""

    def __init__(self, application: str | None = None):
        """Remember the optional application name and bind this module's logger."""
        self._logger = logging.getLogger(__name__)
        self.application = application

    def log_event(self, event: str, level: int, description: str, **labels: str):
        """Log a single security event.

        Args:
            event: Identifier of the event.
            level: ``logging`` level number used both for the record and, via
                ``logging.getLevelName``, for the ``level`` field of the event.
            description: Human-readable description of the event.
            **labels: Extra key/value pairs merged into the event. When an
                application name is configured it is added under
                ``application`` unless the caller already passed that label.
        """
        if self.application:
            labels.setdefault("application", self.application)

        timestamp = datetime.now(timezone.utc).astimezone().isoformat()
        record = _OWASPLogEvent(
            datetime=timestamp,
            event=event,
            level=logging.getLevelName(level),
            description=description,
            labels=labels,
        )
        # The message is the JSON text; the dict form travels in ``extra``.
        structured = {NESTED_JSON_KEY: record.to_dict()}
        self._logger.log(level, record.to_json(), extra=structured)
88
+
89
+
90
class TLSCertificatesError(Exception):
    """Root of the exception hierarchy raised by this library."""


class DataValidationError(TLSCertificatesError):
    """Signals that data failed validation."""
96
+
97
+
98
class _DatabagModel(pydantic.BaseModel):
    """Base model for data exchanged through a Juju databag.

    Supports both pydantic v1 and v2; ``IS_PYDANTIC_V1`` selects which
    implementation of :meth:`load` and :meth:`dump` is used.

    Unless a nesting key is configured, each model field is stored in the
    databag under its own key as a JSON-encoded string.
    """

    if IS_PYDANTIC_V1:

        class Config:
            """Pydantic config."""

            # ignore any extra fields in the databag
            extra = "ignore"
            """Ignore any extra fields in the databag."""
            allow_population_by_field_name = True
            """Allow instantiating this class by field name (instead of forcing alias)."""

        _NEST_UNDER = None

    model_config = pydantic.ConfigDict(
        # tolerate additional keys in databag
        extra="ignore",
        # Allow instantiating this class by field name (instead of forcing alias).
        populate_by_name=True,
        # Custom config key: whether to nest the whole datastructure (as json)
        # under a field or spread it out at the toplevel.
        _NEST_UNDER=None,
    )  # type: ignore
    """Pydantic config."""

    @classmethod
    def load(cls, databag: MutableMapping):
        """Load this model from a Juju databag.

        Args:
            databag: Mapping whose values are JSON-encoded strings.

        Returns:
            An instance of this model.

        Raises:
            DataValidationError: If a value belonging to the model is not
                valid JSON, or the decoded data fails model validation.
        """
        if IS_PYDANTIC_V1:
            return cls._load_v1(databag)
        nest_under = cls.model_config.get("_NEST_UNDER")
        if nest_under:
            return cls.model_validate(json.loads(databag[nest_under]))

        # Keys that belong to the model. Built once up front instead of once
        # per databag entry inside the comprehension filter.
        model_keys = {(f.alias or n) for n, f in cls.model_fields.items()}
        try:
            data = {
                k: json.loads(v)
                for k, v in databag.items()
                # Don't attempt to parse model-external values
                if k in model_keys
            }
        except json.JSONDecodeError as e:
            msg = f"invalid databag contents: expecting json. {databag}"
            logger.error(msg)
            raise DataValidationError(msg) from e

        try:
            return cls.model_validate_json(json.dumps(data))
        except pydantic.ValidationError as e:
            msg = f"failed to validate databag: {databag}"
            logger.debug(msg, exc_info=True)
            raise DataValidationError(msg) from e

    @classmethod
    def _load_v1(cls, databag: MutableMapping):
        """Load implementation for pydantic v1.

        Raises:
            DataValidationError: If a value belonging to the model is not
                valid JSON, or the decoded data fails model validation.
        """
        if cls._NEST_UNDER:
            return cls.parse_obj(json.loads(databag[cls._NEST_UNDER]))

        # Aliases of the model's fields. Built once up front instead of once
        # per databag entry inside the comprehension filter.
        model_keys = {f.alias for f in cls.__fields__.values()}
        try:
            data = {
                k: json.loads(v)
                for k, v in databag.items()
                # Don't attempt to parse model-external values
                if k in model_keys
            }
        except json.JSONDecodeError as e:
            msg = f"invalid databag contents: expecting json. {databag}"
            logger.error(msg)
            raise DataValidationError(msg) from e

        try:
            return cls.parse_raw(json.dumps(data))  # type: ignore
        except pydantic.ValidationError as e:
            msg = f"failed to validate databag: {databag}"
            logger.debug(msg, exc_info=True)
            raise DataValidationError(msg) from e

    def dump(self, databag: MutableMapping | None = None, clear: bool = True):
        """Write the contents of this model to Juju databag.

        Fields whose value equals their default are not written.

        Args:
            databag: The databag to write to. A new dict is used when omitted.
            clear: Whether to clear the databag before writing.

        Returns:
            MutableMapping: The databag.
        """
        if IS_PYDANTIC_V1:
            return self._dump_v1(databag, clear)
        if clear and databag:
            databag.clear()

        if databag is None:
            databag = {}
        nest_under = self.model_config.get("_NEST_UNDER")
        if nest_under:
            databag[nest_under] = self.model_dump_json(
                by_alias=True,
                # skip keys whose values are default
                exclude_defaults=True,
            )
            return databag

        dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True)
        # Each top-level field becomes its own JSON-encoded databag value.
        databag.update({k: json.dumps(v) for k, v in dct.items()})
        return databag

    def _dump_v1(self, databag: MutableMapping | None = None, clear: bool = True):
        """Dump implementation for pydantic v1."""
        if clear and databag:
            databag.clear()

        if databag is None:
            databag = {}

        if self._NEST_UNDER:
            databag[self._NEST_UNDER] = self.json(by_alias=True, exclude_defaults=True)
            return databag

        dct = json.loads(self.json(by_alias=True, exclude_defaults=True))
        # Each top-level field becomes its own JSON-encoded databag value.
        databag.update({k: json.dumps(v) for k, v in dct.items()})

        return databag
227
+
228
+
229
class _Certificate(pydantic.BaseModel):
    """Databag representation of one issued certificate.

    Certificate material is held as strings; ``chain`` and ``revoked`` are
    optional and default to ``None``.
    """

    ca: str
    certificate_signing_request: str
    certificate: str
    chain: list[str] | None = None
    revoked: bool | None = None

    def to_provider_certificate(self, relation_id: int) -> ProviderCertificate:
        """Build a ``ProviderCertificate`` for ``relation_id`` from the stored strings.

        A missing or empty ``chain`` becomes an empty list.
        """
        # Parsed in the same order the fields are passed below.
        parsed_certificate = Certificate.from_string(self.certificate)
        parsed_csr = CertificateSigningRequest.from_string(self.certificate_signing_request)
        parsed_ca = Certificate.from_string(self.ca)
        parsed_chain = [Certificate.from_string(entry) for entry in self.chain or []]
        return ProviderCertificate(
            relation_id=relation_id,
            certificate=parsed_certificate,
            certificate_signing_request=parsed_csr,
            ca=parsed_ca,
            chain=parsed_chain,
            revoked=self.revoked,
        )
252
+
253
+
254
class _CertificateSigningRequest(pydantic.BaseModel):
    """Databag representation of one certificate signing request."""

    # The CSR as a string.
    certificate_signing_request: str
    # CA flag carried alongside the CSR; may be ``None``.
    ca: bool | None
259
+
260
+
261
class _ProviderApplicationData(_DatabagModel):
    """Databag model for the provider's application data."""

    # Certificates published by the provider; empty by default.
    certificates: list[_Certificate] = []
265
+
266
+
267
class _RequirerData(_DatabagModel):
    """Databag model for requirer data.

    Unit data and application data both use this model.
    """

    # CSRs published by the requirer; empty by default.
    certificate_signing_requests: list[_CertificateSigningRequest] = []
274
+
275
+
276
class Mode(Enum):
    """Selects who a certificate is requested for.

    UNIT (default): The certificate is requested for the unit. Every unit
        manages its own private key, certificate signing request and
        certificate.
    APP: The certificate is requested for the application. The private key,
        certificate signing request and certificate are managed by the
        leader unit only.
    """

    UNIT = 1
    APP = 2
289
+
290
+
291
class PrivateKey:
    """Wrapper around a private key object with PEM helpers."""

    def __init__(
        self, raw: str | None = None, x509_object: rsa.RSAPrivateKey | None = None
    ) -> None:
        """Create the wrapper from a PEM string or an existing key object.

        ``x509_object`` wins when both arguments are given.

        Raises:
            ValueError: If neither argument is provided.
        """
        if x509_object:
            self._private_key = x509_object
            return
        if not raw:
            raise ValueError("Either raw private key string or x509_object must be provided")
        # The PEM is loaded without a password.
        self._private_key = serialization.load_pem_private_key(raw.encode(), password=None)

    @property
    def raw(self) -> str:
        """PEM-formatted text of the private key (same as ``str(self)``)."""
        return str(self)

    def __str__(self):
        """Return the key as unencrypted PEM text with surrounding whitespace stripped."""
        pem_bytes = self._private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        return pem_bytes.decode().strip()

    @classmethod
    def from_string(cls, private_key: str) -> PrivateKey:
        """Alternate constructor taking the private key as a PEM string."""
        return cls(raw=private_key)

    def is_valid(self) -> bool:
        """Return True when the key is an RSA key of at least 2048 bits.

        A warning is logged for each reason the key is rejected.
        """
        try:
            key = self._private_key
            if not isinstance(key, rsa.RSAPrivateKey):
                logger.warning("Private key is not an RSA key")
                return False

            if key.key_size < 2048:
                logger.warning("RSA key size is less than 2048 bits")
                return False
        except ValueError:
            logger.warning("Invalid private key format")
            return False
        return True

    @classmethod
    def generate(cls, key_size: int = 2048, public_exponent: int = 65537) -> PrivateKey:
        """Create a fresh RSA private key and log the event.

        Args:
            key_size: Key size in bits.
            public_exponent: RSA public exponent.

        Returns:
            PrivateKey: Wrapper around the newly generated key.
        """
        new_key = rsa.generate_private_key(
            public_exponent=public_exponent,
            key_size=key_size,
        )
        security_log = _OWASPLogger()
        security_log.log_event(
            event="private_key_generated",
            level=logging.INFO,
            description="Private key generated",
            key_size=str(key_size),
        )
        return PrivateKey(x509_object=new_key)

    def __eq__(self, other: object) -> bool:
        """Compare two PrivateKey objects by their PEM text."""
        if isinstance(other, PrivateKey):
            return self.raw == other.raw
        return NotImplemented
377
+
378
+
379
class Certificate:
    """Wrapper around an ``x509.Certificate`` exposing its fields and PEM form."""

    _cert: x509.Certificate

    def __init__(
        self,
        raw: str | None = None,  # Must remain first argument for backwards compatibility
        # Old Interface fields (ignored)
        common_name: str | None = None,
        expiry_time: datetime | None = None,
        validity_start_time: datetime | None = None,
        is_ca: bool | None = None,
        sans_dns: set[str] | None = None,
        sans_ip: set[str] | None = None,
        sans_oid: set[str] | None = None,
        email_address: str | None = None,
        organization: str | None = None,
        organizational_unit: str | None = None,
        country_name: str | None = None,
        state_or_province_name: str | None = None,
        locality_name: str | None = None,
        # End Old Interface fields
        x509_object: x509.Certificate | None = None,
    ) -> None:
        """Initialize the Certificate object.

        This initializer must maintain the old interface while also allowing
        instantiation from an existing x509_object. It ignores all fields
        other than raw and x509_object, preferring x509_object.

        Raises:
            ValueError: If neither ``raw`` nor ``x509_object`` is provided.
        """
        if x509_object:
            self._cert = x509_object
        elif raw:
            self._cert = x509.load_pem_x509_certificate(data=raw.encode())
        else:
            raise ValueError("Either raw certificate string or x509_object must be provided")

    def _subject_attribute(self, oid: x509.ObjectIdentifier) -> str | None:
        """Return the first subject attribute for ``oid`` as a string, or None if absent."""
        attributes = self._cert.subject.get_attributes_for_oid(oid)
        return str(attributes[0].value) if attributes else None

    @property
    def raw(self) -> str:
        """Return the PEM-formatted string representation of the certificate."""
        return str(self)

    @property
    def common_name(self) -> str:
        """Return the common name of the certificate."""
        # We maintain compatibility with the old interface by returning
        # an empty string if no common name is set.
        return self._subject_attribute(NameOID.COMMON_NAME) or ""

    @property
    def expiry_time(self) -> datetime:
        """Return the expiry time of the certificate."""
        return self._cert.not_valid_after_utc

    @property
    def validity_start_time(self) -> datetime:
        """Return the validity start time of the certificate."""
        return self._cert.not_valid_before_utc

    @property
    def is_ca(self) -> bool:
        """Return whether the certificate is a CA certificate.

        Returns False when the certificate has no basic constraints extension.
        """
        try:
            return self._cert.extensions.get_extension_for_oid(
                ExtensionOID.BASIC_CONSTRAINTS
            ).value.ca  # type: ignore[reportAttributeAccessIssue]
        except x509.ExtensionNotFound:
            return False

    @property
    def sans_dns(self) -> set[str] | None:
        """Return the DNS Subject Alternative Names, or None without a SAN extension."""
        with suppress(x509.ExtensionNotFound):
            sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
            return {str(san) for san in sans.get_values_for_type(x509.DNSName)}
        return None

    @property
    def sans_ip(self) -> set[str] | None:
        """Return the IP Subject Alternative Names, or None without a SAN extension."""
        with suppress(x509.ExtensionNotFound):
            sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
            return {str(san) for san in sans.get_values_for_type(x509.IPAddress)}
        return None

    @property
    def sans_oid(self) -> set[str] | None:
        """Return the OID Subject Alternative Names, or None without a SAN extension."""
        with suppress(x509.ExtensionNotFound):
            sans = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
            return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)}
        return None

    @property
    def email_address(self) -> str | None:
        """Return the email address of the certificate."""
        return self._subject_attribute(NameOID.EMAIL_ADDRESS)

    @property
    def organization(self) -> str | None:
        """Return the organization name of the certificate."""
        return self._subject_attribute(NameOID.ORGANIZATION_NAME)

    @property
    def organizational_unit(self) -> str | None:
        """Return the organizational unit name of the certificate."""
        return self._subject_attribute(NameOID.ORGANIZATIONAL_UNIT_NAME)

    @property
    def country_name(self) -> str | None:
        """Return the country name of the certificate."""
        return self._subject_attribute(NameOID.COUNTRY_NAME)

    @property
    def state_or_province_name(self) -> str | None:
        """Return the state or province name of the certificate."""
        return self._subject_attribute(NameOID.STATE_OR_PROVINCE_NAME)

    @property
    def locality_name(self) -> str | None:
        """Return the locality name of the certificate."""
        return self._subject_attribute(NameOID.LOCALITY_NAME)

    def __str__(self) -> str:
        """Return the certificate as PEM text with surrounding whitespace stripped."""
        return self._cert.public_bytes(serialization.Encoding.PEM).decode().strip()

    def __eq__(self, other: object) -> bool:
        """Check if two Certificate objects are equal (compared by PEM text)."""
        if not isinstance(other, Certificate):
            return NotImplemented
        return self.raw == other.raw

    @classmethod
    def from_string(cls, certificate: str) -> Certificate:
        """Create a Certificate object from a PEM certificate string.

        Raises:
            TLSCertificatesError: If the string cannot be loaded as a certificate.
        """
        try:
            certificate_object = x509.load_pem_x509_certificate(data=certificate.encode())
        except ValueError as e:
            logger.error("Could not load certificate: %s", e)
            # Chain the cause explicitly, as the other re-raise in this class does.
            raise TLSCertificatesError("Could not load certificate") from e

        return cls(x509_object=certificate_object)

    def matches_private_key(self, private_key: PrivateKey) -> bool:
        """Check if this certificate matches a given private key.

        Only RSA keys are supported; any other key type yields False.

        Args:
            private_key (PrivateKey): The private key to validate against.

        Returns:
            bool: True if the certificate matches the private key, False otherwise.
        """
        try:
            cert_public_key = self._cert.public_key()
            key_public_key = private_key._private_key.public_key()

            if not isinstance(cert_public_key, rsa.RSAPublicKey):
                logger.warning("Certificate does not use RSA public key")
                return False

            if not isinstance(key_public_key, rsa.RSAPublicKey):
                logger.warning("Private key is not an RSA key")
                return False

            return cert_public_key.public_numbers() == key_public_key.public_numbers()
        except Exception as e:
            # Deliberate best-effort: any failure is reported as "no match".
            logger.warning("Failed to validate certificate and private key match: %s", e)
            return False

    @classmethod
    def generate(
        cls,
        csr: CertificateSigningRequest,
        ca: Certificate,
        ca_private_key: PrivateKey,
        validity: timedelta,
        is_ca: bool = False,
    ) -> Certificate:
        """Generate a certificate from a CSR signed by the given CA and CA private key.

        Args:
            csr: The certificate signing request.
            ca: The CA certificate.
            ca_private_key: The CA private key.
            validity: The validity period of the certificate.
            is_ca: Whether the generated certificate is a CA certificate.

        Returns:
            Certificate: The generated certificate.

        Raises:
            TLSCertificatesError: If an extension cannot be added to the certificate.
        """
        # Ideally, this would be the constructor, but we can't add new
        # required parameters to the constructor without breaking backwards
        # compatibility.
        private_key = ca_private_key._private_key
        assert isinstance(private_key, CertificateIssuerPrivateKeyTypes)

        # One timestamp so the validity window is exactly ``validity`` long.
        now = datetime.now(timezone.utc)

        # Create a certificate builder
        cert_builder = x509.CertificateBuilder(
            subject_name=csr._csr.subject,
            # The issuer of the new certificate is the signing CA's *subject*.
            # For a self-signed CA subject and issuer are the same name, so
            # this only differs from using the CA's issuer when ``ca`` was
            # itself issued by another CA.
            issuer_name=ca._cert.subject,
            public_key=csr._csr.public_key(),
            serial_number=x509.random_serial_number(),
            not_valid_before=now,
            not_valid_after=now + validity,
        )
        extensions = _generate_certificate_request_extensions(
            authority_key_identifier=ca._cert.extensions.get_extension_for_class(
                x509.SubjectKeyIdentifier
            ).value.key_identifier,
            csr=csr._csr,
            is_ca=is_ca,
        )
        for extension in extensions:
            try:
                cert_builder = cert_builder.add_extension(extension.value, extension.critical)
            except ValueError as e:
                logger.error("Could not add extension to certificate: %s", e)
                raise TLSCertificatesError("Could not add extension to certificate") from e

        # Sign the certificate with the CA's private key
        cert = cert_builder.sign(private_key=private_key, algorithm=hashes.SHA256())
        _OWASPLogger().log_event(
            event="certificate_generated",
            level=logging.INFO,
            description="Certificate generated from CSR",
            common_name=csr.common_name,
            is_ca=str(is_ca),
            validity_days=str(validity.days),
        )

        return cls(x509_object=cert)

    @classmethod
    def generate_self_signed_ca(
        cls,
        attributes: CertificateRequestAttributes,
        private_key: PrivateKey,
        validity: timedelta,
    ) -> Certificate:
        """Generate a self-signed CA certificate.

        Args:
            attributes: The certificate request attributes.
            private_key: The private key to sign the CA certificate.
            validity: The validity period of the CA certificate.

        Returns:
            Certificate: The generated CA certificate.
        """
        assert isinstance(private_key._private_key, rsa.RSAPrivateKey)

        public_key = private_key._private_key.public_key()

        # One timestamp so the validity window is exactly ``validity`` long.
        now = datetime.now(timezone.utc)

        builder = x509.CertificateBuilder(
            public_key=public_key,
            serial_number=x509.random_serial_number(),
            not_valid_before=now,
            not_valid_after=now + validity,
        )

        # Self-signed: the same name is used as subject and issuer.
        if subject_name := _extract_subject_name_attributes(attributes):
            builder = builder.subject_name(subject_name).issuer_name(subject_name)

        builder = (
            builder.add_extension(
                x509.SubjectKeyIdentifier.from_public_key(public_key), critical=False
            )
            .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    key_encipherment=True,
                    key_cert_sign=True,
                    key_agreement=False,
                    content_commitment=False,
                    data_encipherment=False,
                    crl_sign=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
        )

        if san_extension := _san_extension(
            email_address=attributes.email_address,
            sans_dns=attributes.sans_dns,
            sans_ip=attributes.sans_ip,
            sans_oid=attributes.sans_oid,
        ):
            builder = builder.add_extension(san_extension, critical=False)

        cert = cls(x509_object=builder.sign(private_key._private_key, algorithm=hashes.SHA256()))

        _OWASPLogger().log_event(
            event="ca_certificate_generated",
            level=logging.INFO,
            description="CA certificate generated",
            common_name=cert.common_name,
            validity_days=str(validity.days),
        )

        return cert
699
+
700
+
701
+ class CertificateSigningRequest:
702
+ """A representation of the certificate signing request."""
703
+
704
+ _csr: x509.CertificateSigningRequest
705
+
706
+ def __init__(
707
+ self,
708
+ raw: str | None = None, # Must remain first argument for backwards compatibility
709
+ # Old Interface fields (ignored)
710
+ common_name: str | None = None,
711
+ sans_dns: set[str] | None = None,
712
+ sans_ip: set[str] | None = None,
713
+ sans_oid: set[str] | None = None,
714
+ email_address: str | None = None,
715
+ organization: str | None = None,
716
+ organizational_unit: str | None = None,
717
+ country_name: str | None = None,
718
+ state_or_province_name: str | None = None,
719
+ locality_name: str | None = None,
720
+ has_unique_identifier: bool | None = None,
721
+ # End Old Interface fields
722
+ x509_object: x509.CertificateSigningRequest | None = None,
723
+ ):
724
+ """Initialize the CertificateSigningRequest object.
725
+
726
+ This initializer must maintain the old interface while also allowing
727
+ instantiation from an existing x509_object. It ignores all fields
728
+ other than raw and x509_object, preferring x509_object.
729
+ """
730
+ if x509_object:
731
+ self._csr = x509_object
732
+ return
733
+ elif raw:
734
+ try:
735
+ self._csr = x509.load_pem_x509_csr(raw.encode())
736
+ except ValueError as e:
737
+ logger.error("Could not load CSR: %s", e)
738
+ raise TLSCertificatesError("Could not load CSR")
739
+ return
740
+ raise ValueError("Either raw CSR string or x509_object must be provided")
741
+
742
+ @property
743
+ def common_name(self) -> str:
744
+ """Return the common name of the CSR."""
745
+ common_name = self._csr.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
746
+ return str(common_name[0].value) if common_name else ""
747
+
748
+ @property
749
+ def sans_dns(self) -> set[str]:
750
+ """Return the DNS Subject Alternative Names of the CSR."""
751
+ with suppress(x509.ExtensionNotFound):
752
+ sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
753
+ return {str(san) for san in sans.get_values_for_type(x509.DNSName)}
754
+ return set()
755
+
756
+ @property
757
+ def sans_ip(self) -> set[str]:
758
+ """Return the IP Subject Alternative Names of the CSR."""
759
+ with suppress(x509.ExtensionNotFound):
760
+ sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
761
+ return {str(san) for san in sans.get_values_for_type(x509.IPAddress)}
762
+ return set()
763
+
764
+ @property
765
+ def sans_oid(self) -> set[str]:
766
+ """Return the OID Subject Alternative Names of the CSR."""
767
+ with suppress(x509.ExtensionNotFound):
768
+ sans = self._csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
769
+ return {str(san.dotted_string) for san in sans.get_values_for_type(x509.RegisteredID)}
770
+ return set()
771
+
772
+ @property
773
+ def email_address(self) -> str | None:
774
+ """Return the email address of the CSR."""
775
+ email_address = self._csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS)
776
+ return str(email_address[0].value) if email_address else None
777
+
778
+ @property
779
+ def organization(self) -> str | None:
780
+ """Return the organization name of the CSR."""
781
+ organization = self._csr.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)
782
+ return str(organization[0].value) if organization else None
783
+
784
+ @property
785
+ def organizational_unit(self) -> str | None:
786
+ """Return the organizational unit name of the CSR."""
787
+ organizational_unit = self._csr.subject.get_attributes_for_oid(
788
+ NameOID.ORGANIZATIONAL_UNIT_NAME
789
+ )
790
+ return str(organizational_unit[0].value) if organizational_unit else None
791
+
792
+ @property
793
+ def country_name(self) -> str | None:
794
+ """Return the country name of the CSR."""
795
+ country_name = self._csr.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME)
796
+ return str(country_name[0].value) if country_name else None
797
+
798
+ @property
799
+ def state_or_province_name(self) -> str | None:
800
+ """Return the state or province name of the CSR."""
801
+ state_or_province_name = self._csr.subject.get_attributes_for_oid(
802
+ NameOID.STATE_OR_PROVINCE_NAME
803
+ )
804
+ return str(state_or_province_name[0].value) if state_or_province_name else None
805
+
806
+ @property
807
+ def locality_name(self) -> str | None:
808
+ """Return the locality name of the CSR."""
809
+ locality_name = self._csr.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME)
810
+ return str(locality_name[0].value) if locality_name else None
811
+
812
+ @property
813
+ def has_unique_identifier(self) -> bool:
814
+ """Return whether the CSR has a unique identifier."""
815
+ unique_identifier = self._csr.subject.get_attributes_for_oid(
816
+ NameOID.X500_UNIQUE_IDENTIFIER
817
+ )
818
+ return bool(unique_identifier)
819
+
820
+ @property
821
+ def raw(self) -> str:
822
+ """Return the PEM-formatted string representation of the CSR."""
823
+ return self.__str__()
824
+
825
+ def __str__(self) -> str:
826
+ """Return the CSR as a string."""
827
+ return self._csr.public_bytes(serialization.Encoding.PEM).decode().strip()
828
+
829
+ @property
830
+ def additional_critical_extensions(self) -> list[x509.ExtensionType]:
831
+ """Return additional critical extensions present on the CSR (excluding SAN)."""
832
+ return [
833
+ extension.value
834
+ for extension in self._csr.extensions
835
+ if extension.critical and extension.oid != ExtensionOID.SUBJECT_ALTERNATIVE_NAME
836
+ ]
837
+
838
+ @classmethod
839
+ def from_string(cls, csr: str) -> CertificateSigningRequest:
840
+ """Create a CertificateSigningRequest object from a CSR."""
841
+ return cls(raw=csr)
842
+
843
+ @classmethod
844
+ def from_csr(cls, csr: x509.CertificateSigningRequest) -> CertificateSigningRequest:
845
+ """Create a CertificateSigningRequest object from a CSR."""
846
+ return cls(x509_object=csr)
847
+
848
+ def __eq__(self, other: object) -> bool:
849
+ """Check if two CertificateSigningRequest objects are equal."""
850
+ if not isinstance(other, CertificateSigningRequest):
851
+ return NotImplemented
852
+ return self.raw == other.raw
853
+
854
+ def matches_certificate(self, certificate: Certificate) -> bool:
855
+ """Check if this CSR matches a given certificate.
856
+
857
+ Args:
858
+ certificate (Certificate): The certificate to validate against.
859
+
860
+ Returns:
861
+ bool: True if the CSR matches the certificate, False otherwise.
862
+ """
863
+ return self._csr.public_key() == certificate._cert.public_key()
864
+
865
def matches_private_key(self, key: PrivateKey) -> bool:
    """Tell whether this CSR was built from the given private key.

    Only RSA keys are supported: a non-RSA key on either side yields False.
    The comparison is made on the RSA modulus (``n``) of both public keys.

    Args:
        key (PrivateKey): The private key to compare against.

    Returns:
        bool: True when both public keys are RSA keys with the same modulus,
        False otherwise (including when a ``ValueError`` is raised while
        reading either key).
    """
    try:
        key_public = key._private_key.public_key()
        csr_public = self._csr.public_key()
        if not isinstance(key_public, rsa.RSAPublicKey):
            logger.warning("Key is not an RSA key")
            return False
        if not isinstance(csr_public, rsa.RSAPublicKey):
            logger.warning("CSR is not an RSA key")
            return False
        same_modulus = csr_public.public_numbers().n == key_public.public_numbers().n
        if not same_modulus:
            logger.warning("Public key numbers between CSR and key do not match")
        return same_modulus
    except ValueError:
        logger.warning("Could not load certificate or CSR.")
        return False
894
+
895
def get_sha256_hex(self) -> str:
    """Return the SHA-256 digest of the raw CSR string in hexadecimal.

    The ``raw`` string is encoded with the default ``str.encode()`` codec
    before hashing.

    Returns:
        str: Hexadecimal representation of the digest.
    """
    hasher = hashes.Hash(hashes.SHA256())
    hasher.update(self.raw.encode())
    raw_digest = hasher.finalize()
    return raw_digest.hex()
900
+
901
def sign(
    self, ca: Certificate, ca_private_key: PrivateKey, validity: timedelta, is_ca: bool = False
) -> Certificate:
    """Issue a certificate for this CSR using the given CA.

    Thin convenience wrapper: all the work is delegated to
    ``Certificate.generate`` with this CSR as the ``csr`` argument.

    Args:
        ca: The CA certificate.
        ca_private_key: The CA private key.
        validity: The validity period of the certificate.
        is_ca: Whether the generated certificate is a CA certificate.

    Returns:
        Certificate: The signed certificate.
    """
    signed_certificate = Certificate.generate(
        csr=self,
        ca=ca,
        ca_private_key=ca_private_key,
        validity=validity,
        is_ca=is_ca,
    )
    return signed_certificate
922
+
923
@classmethod
def generate(
    cls,
    attributes: CertificateRequestAttributes,
    private_key: PrivateKey,
) -> CertificateSigningRequest:
    """Create and sign a CSR from request attributes and a private key.

    The subject name is built from the attributes, the OID / IP / DNS SANs
    are gathered into one non-critical SubjectAlternativeName extension, and
    every additional extension from the attributes is added as critical.
    The request is signed with SHA-256.

    Args:
        attributes (CertificateRequestAttributes): Certificate request attributes
        private_key (PrivateKey): Private key used to sign the request

    Returns:
        CertificateSigningRequest: CSR
    """
    signing_key = private_key._private_key
    assert isinstance(signing_key, CertificateIssuerPrivateKeyTypes)

    builder = x509.CertificateSigningRequestBuilder()
    subject = _extract_subject_name_attributes(attributes)
    if subject:
        builder = builder.subject_name(subject)

    alt_names: list[x509.GeneralName] = []
    alt_names.extend(
        x509.RegisteredID(x509.ObjectIdentifier(oid)) for oid in attributes.sans_oid or ()
    )
    alt_names.extend(
        x509.IPAddress(ipaddress.ip_address(ip)) for ip in attributes.sans_ip or ()
    )
    alt_names.extend(x509.DNSName(dns) for dns in attributes.sans_dns or ())
    if alt_names:
        # The names go through a set, so duplicates collapse before signing.
        san_extension = x509.SubjectAlternativeName(set(alt_names))
        builder = builder.add_extension(san_extension, critical=False)

    for extra_extension in attributes.additional_critical_extensions or ():
        builder = builder.add_extension(extra_extension, critical=True)

    return cls(x509_object=builder.sign(signing_key, hashes.SHA256()))
962
+
963
+
964
class CertificateRequestAttributes:
    """The attributes from which a certificate request is built.

    The values are stored on private attributes and exposed through
    read-only properties. SAN collections are stored as sets (or ``None``
    when empty / not given).
    """

    def __init__(
        self,
        common_name: str | None = None,
        sans_dns: Collection[str] | None = None,
        sans_ip: Collection[str] | None = None,
        sans_oid: Collection[str] | None = None,
        email_address: str | None = None,
        organization: str | None = None,
        organizational_unit: str | None = None,
        country_name: str | None = None,
        state_or_province_name: str | None = None,
        locality_name: str | None = None,
        is_ca: bool = False,
        add_unique_id_to_subject_name: bool = True,
        additional_critical_extensions: Collection[x509.ExtensionType] | None = None,
    ):
        """Store the request attributes.

        Raises:
            ValueError: If none of common_name, sans_dns, sans_ip and
                sans_oid is provided (or all of them are empty).
        """
        if not (common_name or sans_dns or sans_ip or sans_oid):
            raise ValueError(
                "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided"
            )
        self._common_name = common_name
        # Empty collections are normalised to None, non-empty ones to sets.
        self._sans_dns = set(sans_dns) if sans_dns else None
        self._sans_ip = set(sans_ip) if sans_ip else None
        self._sans_oid = set(sans_oid) if sans_oid else None
        self._email_address = email_address
        self._organization = organization
        self._organizational_unit = organizational_unit
        self._country_name = country_name
        self._state_or_province_name = state_or_province_name
        self._locality_name = locality_name
        self._is_ca = is_ca
        self._add_unique_id_to_subject_name = add_unique_id_to_subject_name
        self._additional_critical_extensions = list(additional_critical_extensions or [])

    @property
    def common_name(self) -> str:
        """Return the common name, or an empty string when it is not set."""
        # An empty string (not None) is returned for legacy interface compatibility.
        return self._common_name or ""

    @property
    def sans_dns(self) -> set[str] | None:
        """Return the DNS Subject Alternative Names."""
        return self._sans_dns

    @property
    def sans_ip(self) -> set[str] | None:
        """Return the IP Subject Alternative Names."""
        return self._sans_ip

    @property
    def sans_oid(self) -> set[str] | None:
        """Return the OID Subject Alternative Names."""
        return self._sans_oid

    @property
    def email_address(self) -> str | None:
        """Return the email address."""
        return self._email_address

    @property
    def organization(self) -> str | None:
        """Return the organization name."""
        return self._organization

    @property
    def organizational_unit(self) -> str | None:
        """Return the organizational unit name."""
        return self._organizational_unit

    @property
    def country_name(self) -> str | None:
        """Return the country name."""
        return self._country_name

    @property
    def state_or_province_name(self) -> str | None:
        """Return the state or province name."""
        return self._state_or_province_name

    @property
    def locality_name(self) -> str | None:
        """Return the locality name."""
        return self._locality_name

    @property
    def is_ca(self) -> bool:
        """Return whether a CA certificate is being requested."""
        return self._is_ca

    @property
    def add_unique_id_to_subject_name(self) -> bool:
        """Return whether a unique identifier is added to the subject name."""
        return self._add_unique_id_to_subject_name

    @property
    def additional_critical_extensions(self) -> list[x509.ExtensionType]:
        """Return the additional critical extensions to add to the CSR."""
        return self._additional_critical_extensions

    @classmethod
    def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool) -> CertificateRequestAttributes:
        """Build the attributes that describe an existing CSR.

        Args:
            csr: The CSR to extract attributes from.
            is_ca: Whether a CA certificate is being requested.

        Returns:
            CertificateRequestAttributes: The extracted attributes.
        """
        extracted = cls(
            common_name=csr.common_name,
            sans_dns=csr.sans_dns,
            sans_ip=csr.sans_ip,
            sans_oid=csr.sans_oid,
            email_address=csr.email_address,
            organization=csr.organization,
            organizational_unit=csr.organizational_unit,
            country_name=csr.country_name,
            state_or_province_name=csr.state_or_province_name,
            locality_name=csr.locality_name,
            is_ca=is_ca,
            add_unique_id_to_subject_name=csr.has_unique_identifier,
            additional_critical_extensions=csr.additional_critical_extensions,
        )
        return extracted

    def __eq__(self, other: object) -> bool:
        """Compare two CertificateRequestAttributes objects property by property."""
        if not isinstance(other, CertificateRequestAttributes):
            return NotImplemented
        compared_properties = (
            "common_name",
            "sans_dns",
            "sans_ip",
            "sans_oid",
            "email_address",
            "organization",
            "organizational_unit",
            "country_name",
            "state_or_province_name",
            "locality_name",
            "is_ca",
            "add_unique_id_to_subject_name",
            "additional_critical_extensions",
        )
        return all(
            getattr(self, name) == getattr(other, name) for name in compared_properties
        )

    def is_valid(self) -> bool:
        """Check that at least one identifying attribute is present.

        Returns:
            bool: True if a common name or any SAN is set, False otherwise
            (a warning is logged in that case).
        """
        if self.common_name or self.sans_dns or self.sans_ip or self.sans_oid:
            return True
        logger.warning(
            "At least one of common_name, sans_dns, sans_ip, or sans_oid must be provided"
        )
        return False

    def generate_csr(
        self,
        private_key: PrivateKey,
    ) -> CertificateSigningRequest:
        """Generate a CSR from these attributes, signed with a private key.

        Args:
            private_key (PrivateKey): Private key to sign the CSR.

        Returns:
            CertificateSigningRequest: The generated CSR.
        """
        return CertificateSigningRequest.generate(self, private_key)
1140
+
1141
+
1142
@dataclass(frozen=True)
class ProviderCertificate:
    """A certificate handed out by the TLS provider, with its context.

    Instances are immutable (frozen dataclass).
    """

    relation_id: int
    certificate: Certificate
    certificate_signing_request: CertificateSigningRequest
    ca: Certificate
    chain: list[Certificate]
    revoked: bool | None = None

    def to_json(self) -> str:
        """Serialize the object to a JSON string.

        Every certificate-like field is converted with ``str()``. The
        ``relation_id`` is not part of the output.

        Returns:
            str: JSON object with the keys ``csr``, ``certificate``, ``ca``,
            ``chain`` and ``revoked``.
        """
        payload = {
            "csr": str(self.certificate_signing_request),
            "certificate": str(self.certificate),
            "ca": str(self.ca),
            "chain": list(map(str, self.chain)),
            "revoked": self.revoked,
        }
        return json.dumps(payload)
1166
+
1167
+
1168
@dataclass(frozen=True)
class RequirerCertificateRequest:
    """A certificate signing request submitted by a specific TLS requirer.

    Instances are immutable (frozen dataclass).
    """

    # Identifier of the relation the request belongs to (presumably the Juju
    # relation ID - confirm against the code that builds these objects).
    relation_id: int
    # The CSR submitted by the requirer.
    certificate_signing_request: CertificateSigningRequest
    # Whether a CA certificate is being requested.
    is_ca: bool
1175
+
1176
+
1177
class CertificateAvailableEvent(EventBase):
    """Charm event emitted when a TLS certificate is available.

    The event carries the certificate, the CSR it answers, the CA
    certificate and the certificate chain.
    """

    def __init__(
        self,
        handle: Handle,
        certificate: Certificate,
        certificate_signing_request: CertificateSigningRequest,
        ca: Certificate,
        chain: list[Certificate],
    ):
        super().__init__(handle)
        self.certificate = certificate
        self.certificate_signing_request = certificate_signing_request
        self.ca = ca
        self.chain = chain

    def snapshot(self) -> dict:
        """Return the event payload as a dict of strings.

        Each object is converted with ``str()``; the chain is stored as a
        JSON-encoded list of such strings.
        """
        chain_strings = [str(link) for link in self.chain]
        return {
            "certificate": str(self.certificate),
            "certificate_signing_request": str(self.certificate_signing_request),
            "ca": str(self.ca),
            "chain": json.dumps(chain_strings),
        }

    def restore(self, snapshot: dict):
        """Rebuild the event payload from a dict produced by ``snapshot``."""
        self.certificate = Certificate.from_string(snapshot["certificate"])
        self.certificate_signing_request = CertificateSigningRequest.from_string(
            snapshot["certificate_signing_request"]
        )
        self.ca = Certificate.from_string(snapshot["ca"])
        self.chain = [
            Certificate.from_string(link_str) for link_str in json.loads(snapshot["chain"])
        ]

    def chain_as_pem(self) -> str:
        """Return the full chain as one string, entries separated by a blank line."""
        return "\n\n".join(str(link) for link in self.chain)
1216
+
1217
+
1218
def generate_private_key(
    key_size: int = 2048,
    public_exponent: int = 65537,
) -> PrivateKey:
    """Generate a private key (deprecated wrapper).

    Emits a ``DeprecationWarning`` and forwards both arguments to
    ``PrivateKey.generate()``, which should be called directly instead.

    Args:
        key_size (int): Key size in bits. The original documentation states
            it must be at least 2048 bits; any enforcement happens in
            ``PrivateKey.generate()``, not here.
        public_exponent: Public exponent.

    Returns:
        PrivateKey: Private Key
    """
    deprecation_message = (
        "generate_private_key() is deprecated. Use PrivateKey.generate() instead."
    )
    warnings.warn(deprecation_message, DeprecationWarning)  # noqa: B028
    new_key = PrivateKey.generate(key_size=key_size, public_exponent=public_exponent)
    return new_key
1236
+
1237
+
1238
def calculate_relative_datetime(target_time: datetime, fraction: float) -> datetime:
    """Calculate a datetime that is a given fraction of the way from now to a target time.

    Args:
        target_time (datetime): The datetime to interpolate towards. It is
            subtracted from a timezone-aware "now" (UTC), so it must be
            timezone-aware as well.
        fraction (float): Fraction of the interval from now to target_time,
            between 0.0 and 1.0 inclusive.
            1.0 means return target_time,
            0.9 means return the time after 90% of the interval has passed,
            and 0.0 means return now.

    Returns:
        datetime: ``now + (target_time - now) * fraction``, timezone-aware (UTC).

    Raises:
        ValueError: If fraction is outside the 0.0-1.0 range.
    """
    # Both bounds are inclusive: 0.0 is documented as "return now" and must
    # therefore be accepted, not rejected.
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("Invalid fraction. Must be between 0.0 and 1.0")
    now = datetime.now(timezone.utc)
    time_until_target = target_time - now
    return now + time_until_target * fraction
1253
+
1254
+
1255
def chain_has_valid_order(chain: list[str]) -> bool:
    """Check that each certificate of a chain is issued by the one that follows it.

    The chain is expected to go from leaf to root. Every adjacent pair is
    loaded from PEM and checked with ``verify_directly_issued_by``. A chain
    with fewer than two entries is considered valid without any parsing.

    Args:
        chain (List[str]): List of certificates in PEM format, ordered from leaf to root

    Returns:
        bool: True if the chain has a valid order, False otherwise (including
        when loading or verifying raises ValueError, TypeError or
        InvalidSignature).
    """
    if len(chain) < 2:
        return True

    try:
        for subject_pem, issuer_pem in zip(chain, chain[1:]):
            subject_cert = x509.load_pem_x509_certificate(subject_pem.encode())
            issuer_cert = x509.load_pem_x509_certificate(issuer_pem.encode())
            subject_cert.verify_directly_issued_by(issuer_cert)
    except (ValueError, TypeError, InvalidSignature):
        return False
    return True
1279
+
1280
+
1281
def generate_csr(
    private_key: PrivateKey,
    common_name: str,
    sans_dns: frozenset[str] | None = frozenset(),
    sans_ip: frozenset[str] | None = frozenset(),
    sans_oid: frozenset[str] | None = frozenset(),
    organization: str | None = None,
    organizational_unit: str | None = None,
    email_address: str | None = None,
    country_name: str | None = None,
    locality_name: str | None = None,
    state_or_province_name: str | None = None,
    add_unique_id_to_subject_name: bool = True,
) -> CertificateSigningRequest:
    """Generate a CSR from a private key and subject details (deprecated wrapper).

    Emits a ``DeprecationWarning``, packs the arguments into a
    ``CertificateRequestAttributes`` object and calls its ``generate_csr()``.

    Args:
        private_key (PrivateKey): Private key
        common_name (str): Common name
        sans_dns (FrozenSet[str]): DNS Subject Alternative Names
        sans_ip (FrozenSet[str]): IP Subject Alternative Names
        sans_oid (FrozenSet[str]): OID Subject Alternative Names
        organization (Optional[str]): Organization name
        organizational_unit (Optional[str]): Organizational unit name
        email_address (Optional[str]): Email address
        country_name (Optional[str]): Country name
        locality_name (Optional[str]): Locality name
        state_or_province_name (Optional[str]): State or province name
        add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's
            subject name. Always leave to "True" when the CSR is used to request certificates
            using the tls-certificates relation.

    Returns:
        CertificateSigningRequest: CSR
    """
    warnings.warn(  # noqa: B028
        "generate_csr() is deprecated. Use "
        "CertificateRequestAttributes.generate_csr() or "
        "CertificateSigningRequest.generate() instead.",
        DeprecationWarning,
    )
    request_attributes = CertificateRequestAttributes(
        common_name=common_name,
        sans_dns=sans_dns,
        sans_ip=sans_ip,
        sans_oid=sans_oid,
        organization=organization,
        organizational_unit=organizational_unit,
        email_address=email_address,
        country_name=country_name,
        state_or_province_name=state_or_province_name,
        locality_name=locality_name,
        add_unique_id_to_subject_name=add_unique_id_to_subject_name,
    )
    return request_attributes.generate_csr(private_key=private_key)
1337
+
1338
+
1339
def generate_ca(
    private_key: PrivateKey,
    validity: timedelta,
    common_name: str,
    sans_dns: frozenset[str] | None = frozenset(),
    sans_ip: frozenset[str] | None = frozenset(),
    sans_oid: frozenset[str] | None = frozenset(),
    organization: str | None = None,
    organizational_unit: str | None = None,
    email_address: str | None = None,
    country_name: str | None = None,
    state_or_province_name: str | None = None,
    locality_name: str | None = None,
) -> Certificate:
    """Generate a self-signed CA certificate (deprecated wrapper).

    Emits a ``DeprecationWarning``, packs the arguments into a
    ``CertificateRequestAttributes`` object with ``is_ca=True`` and calls
    ``Certificate.generate_self_signed_ca()``.

    Args:
        private_key: Private key
        validity: Certificate validity time
        common_name: Common Name that can be an IP or a Full Qualified Domain Name (FQDN).
        sans_dns: DNS Subject Alternative Names
        sans_ip: IP Subject Alternative Names
        sans_oid: OID Subject Alternative Names
        organization: Organization name
        organizational_unit: Organizational unit name
        email_address: Email address
        country_name: Certificate Issuing country
        state_or_province_name: Certificate Issuing state or province
        locality_name: Certificate Issuing locality

    Returns:
        CA Certificate.
    """
    deprecation_message = (
        "generate_ca() is deprecated. Use Certificate.generate_self_signed_ca() instead."
    )
    warnings.warn(deprecation_message, DeprecationWarning)  # noqa: B028
    ca_attributes = CertificateRequestAttributes(
        common_name=common_name,
        sans_dns=sans_dns,
        sans_ip=sans_ip,
        sans_oid=sans_oid,
        organization=organization,
        organizational_unit=organizational_unit,
        email_address=email_address,
        country_name=country_name,
        state_or_province_name=state_or_province_name,
        locality_name=locality_name,
        is_ca=True,
    )
    ca_certificate = Certificate.generate_self_signed_ca(ca_attributes, private_key, validity)
    return ca_certificate
1390
+
1391
+
1392
def _san_extension(
    email_address: str | None = None,
    sans_dns: Collection[str] | None = frozenset(),
    sans_ip: Collection[str] | None = frozenset(),
    sans_oid: Collection[str] | None = frozenset(),
) -> x509.SubjectAlternativeName | None:
    """Build a SubjectAlternativeName value from the given names.

    Names are added in this order: e-mail, DNS, IP, OID.

    Returns:
        x509.SubjectAlternativeName | None: The extension value, or None when
        no name at all was provided.
    """
    names: list[x509.GeneralName] = []
    if email_address:
        # If an e-mail address was provided, it should always be in the SAN
        names.append(x509.RFC822Name(email_address))
    names.extend(x509.DNSName(dns) for dns in sans_dns or ())
    names.extend(x509.IPAddress(ipaddress.ip_address(ip)) for ip in sans_ip or ())
    names.extend(x509.RegisteredID(x509.ObjectIdentifier(oid)) for oid in sans_oid or ())
    return x509.SubjectAlternativeName(names) if names else None
1411
+
1412
+
1413
def generate_certificate(
    csr: CertificateSigningRequest,
    ca: Certificate,
    ca_private_key: PrivateKey,
    validity: timedelta,
    is_ca: bool = False,
) -> Certificate:
    """Generate a TLS certificate based on a CSR (deprecated wrapper).

    Emits a ``DeprecationWarning`` and forwards every argument to
    ``Certificate.generate()``, which should be called directly instead.

    Args:
        csr (CertificateSigningRequest): CSR
        ca (Certificate): CA Certificate
        ca_private_key (PrivateKey): CA private key
        validity (timedelta): Certificate validity time
        is_ca (bool): Whether the certificate is a CA certificate

    Returns:
        Certificate: Certificate
    """
    deprecation_message = (
        "generate_certificate() is deprecated. Use Certificate.generate() instead."
    )
    warnings.warn(deprecation_message, DeprecationWarning)  # noqa: B028
    issued_certificate = Certificate.generate(
        csr=csr,
        ca=ca,
        ca_private_key=ca_private_key,
        validity=validity,
        is_ca=is_ca,
    )
    return issued_certificate
1443
+
1444
+
1445
def _extract_subject_name_attributes(
    attributes: CertificateRequestAttributes,
) -> x509.Name | None:
    """Build the X.509 subject name described by the request attributes.

    Only attributes with a truthy value are included, in this order: common
    name, unique identifier (a fresh UUID4, when requested), organization,
    organizational unit, e-mail address, country, state or province,
    locality.

    Returns:
        x509.Name | None: The subject name, or None when nothing was added.
    """
    name_attributes: list[x509.NameAttribute] = []

    def _add_if_set(oid: x509.ObjectIdentifier, value: str | None) -> None:
        # Skip attributes that are unset or empty.
        if value:
            name_attributes.append(x509.NameAttribute(oid, value))

    _add_if_set(x509.NameOID.COMMON_NAME, attributes.common_name)
    if attributes.add_unique_id_to_subject_name:
        name_attributes.append(
            x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(uuid.uuid4()))
        )
    _add_if_set(x509.NameOID.ORGANIZATION_NAME, attributes.organization)
    _add_if_set(x509.NameOID.ORGANIZATIONAL_UNIT_NAME, attributes.organizational_unit)
    _add_if_set(x509.NameOID.EMAIL_ADDRESS, attributes.email_address)
    _add_if_set(x509.NameOID.COUNTRY_NAME, attributes.country_name)
    _add_if_set(x509.NameOID.STATE_OR_PROVINCE_NAME, attributes.state_or_province_name)
    _add_if_set(x509.NameOID.LOCALITY_NAME, attributes.locality_name)

    return x509.Name(name_attributes) if name_attributes else None
1493
+
1494
+
1495
def _generate_certificate_request_extensions(
    authority_key_identifier: bytes,
    csr: x509.CertificateSigningRequest,
    is_ca: bool,
) -> list[x509.Extension]:
    """Assemble the extensions of a certificate issued for a CSR.

    The provider-managed extensions come first: authority key identifier,
    subject key identifier, basic constraints, the subject alternative name
    (when there is one) and, for CA certificates, a key usage limited to
    certificate and CRL signing. The remaining extensions of the CSR are
    then appended, except its SAN and any extension whose OID is already
    managed by the provider (a warning is logged for the latter).

    Args:
        authority_key_identifier (bytes): Authority key identifier
        csr (x509.CertificateSigningRequest): CSR
        is_ca (bool): Whether the certificate is a CA certificate

    Returns:
        List[x509.Extension]: List of extensions
    """
    authority_key_extension = x509.Extension(
        oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER,
        value=x509.AuthorityKeyIdentifier(
            key_identifier=authority_key_identifier,
            authority_cert_issuer=None,
            authority_cert_serial_number=None,
        ),
        critical=False,
    )
    subject_key_extension = x509.Extension(
        oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER,
        value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()),
        critical=False,
    )
    basic_constraints_extension = x509.Extension(
        oid=ExtensionOID.BASIC_CONSTRAINTS,
        critical=True,
        value=x509.BasicConstraints(ca=is_ca, path_length=None),
    )
    extensions: list[x509.Extension] = [
        authority_key_extension,
        subject_key_extension,
        basic_constraints_extension,
    ]

    san_extension = _generate_subject_alternative_name_extension(csr)
    if san_extension:
        extensions.append(san_extension)

    if is_ca:
        ca_key_usage = x509.KeyUsage(
            digital_signature=False,
            content_commitment=False,
            key_encipherment=False,
            data_encipherment=False,
            key_agreement=False,
            key_cert_sign=True,
            crl_sign=True,
            encipher_only=False,
            decipher_only=False,
        )
        extensions.append(
            x509.Extension(ExtensionOID.KEY_USAGE, critical=True, value=ca_key_usage)
        )

    # Snapshot of the OIDs set by the provider; requested extensions that
    # collide with one of them are dropped.
    managed_oids = {managed.oid for managed in extensions}
    for requested in csr.extensions:
        if requested.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME:
            # The SAN is rebuilt above rather than copied from the CSR.
            continue
        if requested.oid in managed_oids:
            logger.warning("Extension %s is managed by the TLS provider, ignoring.", requested.oid)
        else:
            extensions.append(requested)

    return extensions
1563
+
1564
+
1565
def _generate_subject_alternative_name_extension(
    csr: x509.CertificateSigningRequest,
) -> x509.Extension | None:
    """Build the non-critical SAN extension of a certificate issued for a CSR.

    The DNS, IP, registered-ID and RFC 822 (e-mail) names of the CSR's SAN
    extension are copied, in that order; other name types are not carried
    over. The first e-mail address of the CSR subject, if any, is appended
    when it is not already among the names.

    Returns:
        x509.Extension | None: The SAN extension, or None when there is no
        name to put in it.
    """
    names: list[x509.GeneralName] = []
    copied_name_types = (x509.DNSName, x509.IPAddress, x509.RegisteredID, x509.RFC822Name)
    try:
        requested_san = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName)
        for name_type in copied_name_types:
            names.extend(
                name_type(value) for value in requested_san.value.get_values_for_type(name_type)
            )
    except x509.ExtensionNotFound:
        # A CSR without SAN extension simply contributes no names.
        pass

    # If email is present in the CSR Subject, make sure it is also in the SANS
    # to conform to RFC 5280.
    subject_emails = csr.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS)
    if subject_emails:
        subject_email_name = x509.RFC822Name(str(subject_emails[0].value))
        if subject_email_name not in names:
            names.append(subject_email_name)

    if not names:
        return None
    return x509.Extension(
        oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
        critical=False,
        value=x509.SubjectAlternativeName(names),
    )
1604
+
1605
+
1606
class CertificatesRequirerCharmEvents(CharmEvents):
    """List of events that the TLS Certificates requirer charm can leverage."""

    # Event source for CertificateAvailableEvent.
    certificate_available = EventSource(CertificateAvailableEvent)
1610
+
1611
+
1612
+ class TLSCertificatesRequiresV4(Object):
1613
+ """A class to manage the TLS certificates interface for a unit or app."""
1614
+
1615
+ on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType]
1616
+
1617
def __init__(
    self,
    charm: CharmBase,
    relationship_name: str,
    certificate_requests: list[CertificateRequestAttributes],
    mode: Mode = Mode.UNIT,
    refresh_events: list[BoundEvent] | None = None,
    private_key: PrivateKey | None = None,
    renewal_relative_time: float = 0.9,
):
    """Create a new instance of the TLSCertificatesRequiresV4 class.

    Args:
        charm (CharmBase): The charm instance to relate to.
        relationship_name (str): The name of the relation that provides the certificates.
        certificate_requests (List[CertificateRequestAttributes]):
            A list with the attributes of the certificate requests.
        mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT.
            In UNIT mode the requirer will place the csr in the unit relation data.
            Each unit will manage its private key,
            certificate signing request and certificate.
            UNIT mode is for use cases where each unit has its own identity.
            If you don't know which mode to use, you likely need UNIT.
            In APP mode the leader unit will place the csr in the app relation databag.
            APP mode is for use cases where the underlying application needs the certificate
            for example using it as an intermediate CA to sign other certificates.
            The certificate can only be accessed by the leader unit.
        refresh_events (List[BoundEvent]): A list of events to trigger a refresh of
            the certificates.
        private_key (Optional[PrivateKey]): The private key to use for the certificates.
            If provided, it will be used instead of generating a new one.
            If the key is not valid an exception will be raised.
            Using this parameter is discouraged,
            having to pass around private keys manually can be a security concern.
            Allowing the library to generate and manage the key is the more secure approach.
        renewal_relative_time (float): The time to renew the certificate relative to its
            expiry.
            Default is 0.9, meaning 90% of the validity period.
            The minimum value is 0.5, meaning 50% of the validity period.
            The maximum value is 1.0. Both bounds are accepted.
            If an invalid value is provided, an exception will be raised.

    Raises:
        TLSCertificatesError: If the mode, one of the certificate requests,
            the private key or the renewal relative time is invalid.
    """
    if refresh_events is None:
        refresh_events = []
    super().__init__(charm, relationship_name)
    if not JujuVersion.from_environ().has_secrets:
        # Only a warning: construction carries on without Juju secrets support.
        logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)")
    if not self._mode_is_valid(mode):
        raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP")
    for certificate_request in certificate_requests:
        if not certificate_request.is_valid():
            raise TLSCertificatesError("Invalid certificate request")
    self.charm = charm
    self.relationship_name = relationship_name
    self.certificate_requests = certificate_requests
    self.mode = mode
    if private_key and not private_key.is_valid():
        raise TLSCertificatesError("Invalid private key")
    # 0.5 is the documented minimum and therefore a valid value: only values
    # strictly below it (or above 1.0) are rejected.
    if renewal_relative_time < 0.5 or renewal_relative_time > 1.0:
        raise TLSCertificatesError(
            "Invalid renewal relative time. Must be between 0.5 and 1.0"
        )
    self._private_key = private_key
    self.renewal_relative_time = renewal_relative_time
    self.framework.observe(charm.on[relationship_name].relation_created, self._configure)
    self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
    self.framework.observe(charm.on.secret_expired, self._on_secret_expired)
    self.framework.observe(charm.on.secret_remove, self._on_secret_remove)
    for event in refresh_events:
        self.framework.observe(event, self._configure)
    self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}")
1687
+
1688
def _configure(self, _: EventBase | None = None):
    """Handle TLS Certificates Relation Data.

    This is the handler observed for the relation-created, relation-changed
    and refresh events. When the TLS relation exists it, in order:
    ensures a private key is available, cleans up certificate requests,
    sends the certificate requests and looks for available certificates.
    When the relation does not exist yet it only logs a debug message.
    """
    if self._tls_relation_created():
        self._ensure_private_key()
        self._cleanup_certificate_requests()
        self._send_certificate_requests()
        self._find_available_certificates()
    else:
        logger.debug("TLS relation not created yet.")
1703
+
1704
def _mode_is_valid(self, mode: Mode) -> bool:
    """Return whether ``mode`` is one of ``Mode.UNIT`` or ``Mode.APP``."""
    supported_modes = (Mode.UNIT, Mode.APP)
    return mode in supported_modes
1706
+
1707
def _validate_secret_exists(self, secret: Secret) -> None:
    """Check that ``secret`` exists by requesting its info.

    The returned info is discarded; the call is made only for the error it
    raises when the secret is missing (see the inline note).
    """
    secret.get_info()  # Will raise `SecretNotFoundError` if the secret does not exist
1709
+
1710
def _on_secret_remove(self, event: SecretRemoveEvent) -> None:
    """Handle Secret Removed Event.

    Removes the revision named by the event. A secret that no longer exists
    is not an error: a warning is logged and nothing else happens.
    """
    try:
        # Ensure the secret exists before trying to remove it, otherwise
        # the unit could be stuck in an error state. See the docstring of
        # `remove_revision` and the below issue for more information.
        # https://github.com/juju/juju/issues/19036
        self._validate_secret_exists(event.secret)
        event.secret.remove_revision(event.revision)
    except SecretNotFoundError:
        missing_secret_name = event.secret.label or event.secret.id
        logger.warning("No such secret %s, nothing to remove", missing_secret_name)
1725
+
1726
def _on_secret_expired(self, event: SecretExpiredEvent) -> None:
    """Handle Secret Expired Event.

    Renews certificate requests and removes the expired secret.

    Only secrets whose label starts with ``"<LIBID>-certificate"`` are
    handled; any other secret is ignored. When the CSR cannot be read from
    the secret (the content cannot be fetched, or it has no ``csr`` entry),
    an error is logged and nothing is renewed or removed.
    """
    secret_label = event.secret.label
    if not secret_label or not secret_label.startswith(f"{LIBID}-certificate"):
        return
    try:
        csr_str = event.secret.get_content(refresh=True)["csr"]
    except (ModelError, KeyError):
        # A secret without a "csr" entry is skipped like an unreadable one,
        # instead of failing the hook with an unhandled KeyError.
        logger.error("Failed to get CSR from secret - Skipping")
        return
    csr = CertificateSigningRequest.from_string(csr_str)
    self._renew_certificate_request(csr)
    event.secret.remove_all_revisions()
1741
+
1742
def sync(self) -> None:
    """Run the relation reconciliation immediately.

    Lets the requirer sync the TLS certificates relation data on demand
    instead of waiting for one of the refresh events to fire. It simply
    delegates to ``_configure`` without an event.
    """
    self._configure()
1749
+
1750
def renew_certificate(self, certificate: ProviderCertificate) -> None:
    """Request the renewal of the provided certificate.

    The renewal only happens when a secret labelled for the certificate's
    CSR exists and the CSR stored in it equals the certificate's CSR.
    In that case the request is renewed and all revisions of the secret
    are removed; otherwise a warning is logged and nothing changes.
    """
    csr = certificate.certificate_signing_request
    label = self._get_csr_secret_label(csr)
    try:
        secret = self.model.get_secret(label=label)
    except SecretNotFoundError:
        logger.warning("No matching secret found - Skipping renewal")
        return
    stored_csr = secret.get_content(refresh=True).get("csr", "")
    if stored_csr == str(csr):
        self._renew_certificate_request(csr)
        secret.remove_all_revisions()
        return
    logger.warning("No matching CSR found - Skipping renewal")
1765
+
1766
def _renew_certificate_request(self, csr: CertificateSigningRequest):
    """Replace ``csr`` in relation data with a newly sent request.

    The existing CSR is removed from the requirer's relation data first,
    then the certificate requests are sent again.
    """
    self._remove_requirer_csr_from_relation_data(csr)
    self._send_certificate_requests()
    logger.info("Renewed certificate request")
1771
+
1772
def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None:
    """Remove every entry for ``csr`` from this requirer's relation databag.

    Does nothing (beyond logging) when the relation does not exist, when
    no CSRs are present in the requirer relation data, or when the
    databag content fails validation.

    Args:
        csr: The certificate signing request to drop. Entries are matched
            on their whitespace-stripped string form.
    """
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return
    if not self.get_csrs_from_requirer_relation_data():
        logger.info("No CSRs in relation data - Doing nothing")
        return
    app_or_unit = self._get_app_or_unit()
    try:
        requirer_relation_data = _RequirerData.load(relation.data[app_or_unit])
    except DataValidationError:
        logger.warning("Invalid relation data - Skipping removal of CSR")
        return
    csr_text = str(csr).strip()
    # Build a filtered list instead of removing from the list being iterated:
    # removing during iteration skips the element that follows each removal,
    # so consecutive duplicate entries would be left behind.
    remaining_requests = [
        requirer_csr
        for requirer_csr in requirer_relation_data.certificate_signing_requests
        if requirer_csr.certificate_signing_request.strip() != csr_text
    ]
    try:
        _RequirerData(certificate_signing_requests=remaining_requests).dump(
            relation.data[app_or_unit]
        )
        logger.info("Removed CSR from relation data")
    except ModelError:
        logger.warning("Failed to update relation data")
1797
+
1798
def _get_app_or_unit(self) -> Application | Unit:
    """Return the model's unit in UNIT mode or its app in APP mode.

    Raises:
        TLSCertificatesError: If the mode is neither of the two.
    """
    mode = self.mode
    if mode == Mode.UNIT:
        return self.model.unit
    if mode == Mode.APP:
        return self.model.app
    raise TLSCertificatesError("Invalid mode")
1805
+
1806
@property
def private_key(self) -> PrivateKey | None:
    """Return the private key in use, or None when there is none yet.

    A key held in ``self._private_key`` takes precedence. Otherwise the
    key is read from the private key secret, if that secret exists.
    """
    if self._private_key:
        return self._private_key
    if self._private_key_generated():
        label = self._get_private_key_secret_label()
        content = self.charm.model.get_secret(label=label).get_content(refresh=True)
        return PrivateKey.from_string(content["private-key"])
    return None
1816
+
1817
+ def _ensure_private_key(self) -> None:
1818
+ """Make sure there is a private key to be used.
1819
+
1820
+ It will make sure there is a private key passed by the charm using the private_key
1821
+ parameter or generate a new one otherwise.
1822
+ """
1823
+ # Remove the generated private key
1824
+ # if one has been passed by the charm using the private_key parameter
1825
+ if self._private_key:
1826
+ self._remove_private_key_secret()
1827
+ return
1828
+ if self._private_key_generated():
1829
+ logger.debug("Private key already generated")
1830
+ return
1831
+ if self.mode == Mode.APP and not self.model.unit.is_leader():
1832
+ logger.debug("Not leader, skipping private key generation in APP mode")
1833
+ return
1834
+ self._generate_private_key()
1835
+
1836
def regenerate_private_key(self) -> None:
    """Regenerate the private key.

    Generates a new private key, cleans up old certificate requests and
    sends new ones. In APP mode only the leader may do this, and there
    must already be a generated key; otherwise a warning is logged and
    nothing changes.

    Raises:
        TLSCertificatesError: If the private key is passed by the charm
            using the private_key parameter.
    """
    if self._private_key:
        raise TLSCertificatesError(
            "Private key is passed by the charm through the private_key parameter, "
            "this function can't be used"
        )
    is_app_non_leader = self.mode == Mode.APP and not self.model.unit.is_leader()
    if is_app_non_leader:
        logger.warning("Only the leader can regenerate the private key in APP mode")
    elif not self._private_key_generated():
        logger.warning("No private key to regenerate")
    else:
        self._generate_private_key()
        self._cleanup_certificate_requests()
        self._send_certificate_requests()
1859
+
1860
def _generate_private_key(self) -> None:
    """Generate a new private key and store it in a secret.

    Used when the private key is generated by the library and not passed
    by the charm using the private_key parameter.
    """
    new_key = generate_private_key()
    self._store_private_key_in_secret(new_key)
    logger.info("Private key generated")
1868
+
1869
+ def _private_key_generated(self) -> bool:
1870
+ """Check if a private key is stored in a secret.
1871
+
1872
+ This is the case when the private key used is generated by the library.
1873
+ This should not exist when the private key used
1874
+ is passed by the charm using the private_key parameter.
1875
+ """
1876
+ try:
1877
+ secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
1878
+ secret.get_content(refresh=True)
1879
+ return True
1880
+ except SecretNotFoundError:
1881
+ return False
1882
+
1883
+ def _store_private_key_in_secret(self, private_key: PrivateKey) -> None:
1884
+ try:
1885
+ secret = self.charm.model.get_secret(label=self._get_private_key_secret_label())
1886
+ secret.set_content({"private-key": str(private_key)})
1887
+ secret.get_content(refresh=True)
1888
+ except SecretNotFoundError:
1889
+ app_or_unit = self._get_app_or_unit()
1890
+ app_or_unit.add_secret(
1891
+ content={"private-key": str(private_key)},
1892
+ label=self._get_private_key_secret_label(),
1893
+ )
1894
+
1895
def _remove_private_key_secret(self) -> None:
    """Remove all revisions of the private key secret, if it exists.

    In APP mode a non-leader unit does nothing, since it cannot remove
    the app-owned secret. A missing secret is logged as a warning.
    """
    non_leader_in_app_mode = self.mode == Mode.APP and not self.model.unit.is_leader()
    if non_leader_in_app_mode:
        logger.debug("Not leader, cannot remove app owned private key secret")
        return
    try:
        label = self._get_private_key_secret_label()
        self.charm.model.get_secret(label=label).remove_all_revisions()
    except SecretNotFoundError:
        logger.warning("Private key secret not found, nothing to remove")
1905
+
1906
+ def _csr_matches_certificate_request(
1907
+ self, certificate_signing_request: CertificateSigningRequest, is_ca: bool
1908
+ ) -> bool:
1909
+ for certificate_request in self.certificate_requests:
1910
+ if certificate_request == CertificateRequestAttributes.from_csr(
1911
+ certificate_signing_request,
1912
+ is_ca,
1913
+ ):
1914
+ return True
1915
+ return False
1916
+
1917
+ def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool:
1918
+ if not self.private_key:
1919
+ return False
1920
+ csr = self._certificate_requested_for_attributes(certificate_request)
1921
+ if not csr:
1922
+ return False
1923
+ if not csr.certificate_signing_request.matches_private_key(key=self.private_key):
1924
+ return False
1925
+ return True
1926
+
1927
+ def _certificate_requested_for_attributes(
1928
+ self,
1929
+ certificate_request: CertificateRequestAttributes,
1930
+ ) -> RequirerCertificateRequest | None:
1931
+ for requirer_csr in self.get_csrs_from_requirer_relation_data():
1932
+ if certificate_request == CertificateRequestAttributes.from_csr(
1933
+ requirer_csr.certificate_signing_request,
1934
+ requirer_csr.is_ca,
1935
+ ):
1936
+ return requirer_csr
1937
+ return None
1938
+
1939
def get_csrs_from_requirer_relation_data(self) -> list[RequirerCertificateRequest]:
    """Return list of requirer's CSRs from relation data.

    An empty list is returned for a non-leader unit in APP mode, when the
    relation does not exist, or when the databag fails validation.
    """
    if self.mode == Mode.APP and not self.model.unit.is_leader():
        logger.debug("Not a leader unit - Skipping")
        return []
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return []
    databag_owner = self._get_app_or_unit()
    try:
        requirer_data = _RequirerData.load(relation.data[databag_owner])
    except DataValidationError:
        logger.warning("Invalid relation data")
        return []
    loaded_requests: list[RequirerCertificateRequest] = []
    for csr in requirer_data.certificate_signing_requests:
        loaded_requests.append(
            RequirerCertificateRequest(
                relation_id=relation.id,
                certificate_signing_request=CertificateSigningRequest.from_string(
                    csr.certificate_signing_request
                ),
                # A falsy ``ca`` value is stored as a plain False.
                is_ca=csr.ca or False,
            )
        )
    return loaded_requests
1964
+
1965
def get_provider_certificates(self) -> list[ProviderCertificate]:
    """Return the certificates found in the provider's relation data."""
    certificates = self._load_provider_certificates()
    return certificates
1968
+
1969
def _load_provider_certificates(self) -> list[ProviderCertificate]:
    """Load the provider's certificates from the remote app databag.

    Returns an empty list when the relation is missing, when it has no
    remote app, or when the databag fails validation.
    """
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return []
    if not relation.app:
        logger.debug("No remote app in relation: %s", self.relationship_name)
        return []
    try:
        provider_data = _ProviderApplicationData.load(relation.data[relation.app])
    except DataValidationError:
        logger.warning("Invalid relation data")
        return []
    loaded: list[ProviderCertificate] = []
    for certificate in provider_data.certificates:
        loaded.append(certificate.to_provider_certificate(relation_id=relation.id))
    return loaded
1986
+
1987
def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None:
    """Append ``csr`` to the CSRs in this requirer's relation databag.

    Skipped for a non-leader unit in APP mode and when the relation does
    not exist. If the current databag content fails validation, the new
    request list starts out empty.
    """
    if self.mode == Mode.APP and not self.model.unit.is_leader():
        logger.debug("Not a leader unit - Skipping")
        return
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return
    pending_request = _CertificateSigningRequest(
        certificate_signing_request=str(csr).strip(), ca=is_ca
    )
    databag_owner = self._get_app_or_unit()
    try:
        current_data = _RequirerData.load(relation.data[databag_owner])
    except DataValidationError:
        current_data = _RequirerData(
            certificate_signing_requests=[],
        )
    updated_requests = [
        *copy.deepcopy(current_data.certificate_signing_requests),
        pending_request,
    ]
    try:
        _RequirerData(certificate_signing_requests=updated_requests).dump(
            relation.data[databag_owner]
        )
        logger.info("Certificate signing request added to relation data.")
    except ModelError:
        logger.warning("Failed to update relation data")
2015
+
2016
+ def _send_certificate_requests(self):
2017
+ if not self.private_key:
2018
+ logger.debug("Private key not generated yet.")
2019
+ return
2020
+ for certificate_request in self.certificate_requests:
2021
+ if not self._certificate_requested(certificate_request):
2022
+ csr = certificate_request.generate_csr(
2023
+ private_key=self.private_key,
2024
+ )
2025
+ if not csr:
2026
+ logger.warning("Failed to generate CSR")
2027
+ continue
2028
+ self._request_certificate(csr=csr, is_ca=certificate_request.is_ca)
2029
+
2030
def get_assigned_certificate(
    self, certificate_request: CertificateRequestAttributes
) -> tuple[ProviderCertificate | None, PrivateKey | None]:
    """Get the certificate that was assigned to the given certificate request.

    Returns:
        For the first requirer CSR whose attributes equal
        ``certificate_request``: the matching certificate found in the
        relation data (possibly None) and the private key. ``(None, None)``
        when no requirer CSR matches.
    """
    for requirer_csr in self.get_csrs_from_requirer_relation_data():
        csr_attributes = CertificateRequestAttributes.from_csr(
            requirer_csr.certificate_signing_request,
            requirer_csr.is_ca,
        )
        if not (certificate_request == csr_attributes):
            continue
        assigned = self._find_certificate_in_relation_data(requirer_csr)
        return assigned, self.private_key
    return None, None
2041
+
2042
def get_assigned_certificates(
    self,
) -> tuple[list[ProviderCertificate], PrivateKey | None]:
    """Get the certificates that were assigned to this unit or app.

    Returns:
        The certificates found for each requirer CSR (CSRs without a
        certificate are left out), together with the private key.
    """
    assigned_certificates = []
    for requirer_csr in self.get_csrs_from_requirer_relation_data():
        found = self._find_certificate_in_relation_data(requirer_csr)
        if found:
            assigned_certificates.append(found)
    return assigned_certificates, self.private_key
2052
+
2053
+ def _find_certificate_in_relation_data(
2054
+ self, csr: RequirerCertificateRequest
2055
+ ) -> ProviderCertificate | None:
2056
+ """Return the certificate that matches the given CSR, validated against the private key."""
2057
+ if not self.private_key:
2058
+ return None
2059
+ for provider_certificate in self.get_provider_certificates():
2060
+ if provider_certificate.certificate_signing_request == csr.certificate_signing_request:
2061
+ if provider_certificate.certificate.is_ca and not csr.is_ca:
2062
+ logger.warning("Non CA certificate requested, got a CA certificate, ignoring")
2063
+ continue
2064
+ elif not provider_certificate.certificate.is_ca and csr.is_ca:
2065
+ logger.warning("CA certificate requested, got a non CA certificate, ignoring")
2066
+ continue
2067
+ if not provider_certificate.certificate.matches_private_key(self.private_key):
2068
+ logger.warning(
2069
+ "Certificate does not match the private key. Ignoring invalid certificate."
2070
+ )
2071
+ continue
2072
+ return provider_certificate
2073
+ return None
2074
+
2075
def _find_available_certificates(self):
    """Find available certificates and emit events.

    This method will find certificates that are available for the requirer's CSRs.
    If a certificate is found, it will be set as a secret and a
    ``certificate_available`` event will be emitted.
    If a certificate is revoked, the secret will be removed.

    NOTE(review): the nesting below is reconstructed from a rendering of the
    source that lost its indentation. The ``certificate_available`` emit is
    placed in the non-revoked branch, so no event is emitted here for revoked
    certificates - confirm against the original file.
    """
    requirer_csrs = self.get_csrs_from_requirer_relation_data()
    csrs = [csr.certificate_signing_request for csr in requirer_csrs]
    provider_certificates = self.get_provider_certificates()
    for provider_certificate in provider_certificates:
        # Only certificates answering one of our own CSRs are considered.
        if provider_certificate.certificate_signing_request in csrs:
            secret_label = self._get_csr_secret_label(
                provider_certificate.certificate_signing_request
            )
            if provider_certificate.revoked:
                # Revoked: drop the stored secret; a missing secret is fine.
                with suppress(SecretNotFoundError):
                    logger.debug(
                        "Removing secret with label %s",
                        secret_label,
                    )
                    secret = self.model.get_secret(label=secret_label)
                    secret.remove_all_revisions()
            else:
                if not self._csr_matches_certificate_request(
                    certificate_signing_request=provider_certificate.certificate_signing_request,
                    is_ca=provider_certificate.certificate.is_ca,
                ):
                    logger.debug("Certificate requested for different attributes - Skipping")
                    continue
                try:
                    secret = self.model.get_secret(label=secret_label)
                    logger.debug("Setting secret with label %s", secret_label)
                    # Juju < 3.6 will create a new revision even if the content is the same
                    if secret.get_content(refresh=True).get("certificate", "") == str(
                        provider_certificate.certificate
                    ):
                        logger.debug(
                            "Secret %s with correct certificate already exists", secret_label
                        )
                        # Unchanged certificate: skip both the update and the event.
                        continue
                    secret.set_content(
                        content={
                            "certificate": str(provider_certificate.certificate),
                            "csr": str(provider_certificate.certificate_signing_request),
                        }
                    )
                    # The secret expiry is derived from the certificate expiry
                    # and the configured renewal fraction.
                    secret.set_info(
                        expire=calculate_relative_datetime(
                            target_time=provider_certificate.certificate.expiry_time,
                            fraction=self.renewal_relative_time,
                        ),
                    )
                    secret.get_content(refresh=True)
                except SecretNotFoundError:
                    # No secret for this CSR yet: create one on the unit.
                    logger.debug("Creating new secret with label %s", secret_label)
                    secret = self.charm.unit.add_secret(
                        content={
                            "certificate": str(provider_certificate.certificate),
                            "csr": str(provider_certificate.certificate_signing_request),
                        },
                        label=secret_label,
                        expire=calculate_relative_datetime(
                            target_time=provider_certificate.certificate.expiry_time,
                            fraction=self.renewal_relative_time,
                        ),
                    )
                self.on.certificate_available.emit(
                    certificate_signing_request=provider_certificate.certificate_signing_request,
                    certificate=provider_certificate.certificate,
                    ca=provider_certificate.ca,
                    chain=provider_certificate.chain,
                )
2148
+
2149
+ def _cleanup_certificate_requests(self):
2150
+ """Clean up certificate requests.
2151
+
2152
+ Remove any certificate requests that falls into one of the following categories:
2153
+ - The CSR attributes do not match any of the certificate requests defined in
2154
+ the charm's certificate_requests attribute.
2155
+ - The CSR public key does not match the private key.
2156
+ """
2157
+ for requirer_csr in self.get_csrs_from_requirer_relation_data():
2158
+ if not self._csr_matches_certificate_request(
2159
+ certificate_signing_request=requirer_csr.certificate_signing_request,
2160
+ is_ca=requirer_csr.is_ca,
2161
+ ):
2162
+ self._remove_requirer_csr_from_relation_data(
2163
+ requirer_csr.certificate_signing_request
2164
+ )
2165
+ logger.info(
2166
+ "Removed CSR from relation data because it did not match any certificate request" # noqa: E501
2167
+ )
2168
+ elif (
2169
+ self.private_key
2170
+ and not requirer_csr.certificate_signing_request.matches_private_key(
2171
+ self.private_key
2172
+ )
2173
+ ):
2174
+ self._remove_requirer_csr_from_relation_data(
2175
+ requirer_csr.certificate_signing_request
2176
+ )
2177
+ logger.info(
2178
+ "Removed CSR from relation data because it did not match the private key"
2179
+ )
2180
+
2181
+ def _tls_relation_created(self) -> bool:
2182
+ relation = self.model.get_relation(self.relationship_name)
2183
+ if not relation:
2184
+ return False
2185
+ return True
2186
+
2187
def _get_private_key_secret_label(self) -> str:
    """Build the label of the private key secret for the current mode.

    In UNIT mode the label includes the unit number; in APP mode it does not.

    Raises:
        TLSCertificatesError: If the mode is neither UNIT nor APP.
    """
    mode = self.mode
    if mode == Mode.UNIT:
        return f"{LIBID}-private-key-{self._get_unit_number()}-{self.relationship_name}"
    if mode == Mode.APP:
        return f"{LIBID}-private-key-{self.relationship_name}"
    raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.")
2194
+
2195
def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str:
    """Build the label of the certificate secret for ``csr``.

    The label embeds the CSR's SHA-256 hex digest and, in UNIT mode,
    the unit number as well.

    Raises:
        TLSCertificatesError: If the mode is neither UNIT nor APP.
    """
    digest = csr.get_sha256_hex()
    mode = self.mode
    if mode == Mode.UNIT:
        return f"{LIBID}-certificate-{self._get_unit_number()}-{digest}"
    if mode == Mode.APP:
        return f"{LIBID}-certificate-{digest}"
    raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.")
2203
+
2204
+ def _get_unit_number(self) -> str:
2205
+ return self.model.unit.name.split("/")[1]
2206
+
2207
+
2208
class TLSCertificatesProvidesV4(Object):
    """TLS certificates provider class to be instantiated by TLS certificates providers."""

    def __init__(self, charm: CharmBase, relationship_name: str):
        """Observe the relation and update-status events that trigger reconciliation.

        Args:
            charm: The charm instantiating this object.
            relationship_name: Name of the TLS certificates relation.
        """
        super().__init__(charm, relationship_name)
        self.framework.observe(charm.on[relationship_name].relation_joined, self._configure)
        self.framework.observe(charm.on[relationship_name].relation_changed, self._configure)
        self.framework.observe(charm.on.update_status, self._configure)
        self.charm = charm
        self.relationship_name = relationship_name
        self._security_logger = _OWASPLogger(application=f"tls-certificates-{charm.app.name}")

    def _configure(self, _: EventBase) -> None:
        """Handle update status and tls relation changed events.

        This is a common hook triggered on a regular basis.

        On the leader unit, remove certificates for which no CSR exists.
        Non-leader units do nothing.
        """
        if not self.model.unit.is_leader():
            return
        self._remove_certificates_for_which_no_csr_exists()

    def _remove_certificates_for_which_no_csr_exists(self) -> None:
        """Remove every provider certificate whose CSR no requirer lists anymore."""
        provider_certificates = self.get_provider_certificates()
        requirer_csrs = [
            request.certificate_signing_request for request in self.get_certificate_requests()
        ]
        for provider_certificate in provider_certificates:
            if provider_certificate.certificate_signing_request not in requirer_csrs:
                tls_relation = self._get_tls_relations(
                    relation_id=provider_certificate.relation_id
                )
                self._remove_provider_certificate(
                    certificate=provider_certificate.certificate,
                    relation=tls_relation[0],
                )

    def _get_tls_relations(self, relation_id: int | None = None) -> list[Relation]:
        """Return the TLS relations, optionally restricted to ``relation_id``."""
        return (
            [
                relation
                for relation in self.model.relations[self.relationship_name]
                if relation.id == relation_id
            ]
            if relation_id is not None
            else self.model.relations.get(self.relationship_name, [])
        )

    def get_certificate_requests(
        self, relation_id: int | None = None
    ) -> list[RequirerCertificateRequest]:
        """Load certificate requests from the relation data.

        Both the unit databags and the app databag of each relation are read.
        """
        relations = self._get_tls_relations(relation_id)
        requirer_csrs: list[RequirerCertificateRequest] = []
        for relation in relations:
            for unit in relation.units:
                requirer_csrs.extend(self._load_requirer_databag(relation, unit))
            requirer_csrs.extend(self._load_requirer_databag(relation, relation.app))
        return requirer_csrs

    def _load_requirer_databag(
        self, relation: Relation, unit_or_app: Application | Unit
    ) -> list[RequirerCertificateRequest]:
        """Parse the CSRs stored in one requirer databag.

        Returns an empty list when the databag content fails validation.
        """
        try:
            requirer_relation_data = _RequirerData.load(relation.data.get(unit_or_app, {}))
        except DataValidationError:
            logger.debug("Invalid requirer relation data for %s", unit_or_app.name)
            return []
        return [
            RequirerCertificateRequest(
                relation_id=relation.id,
                certificate_signing_request=CertificateSigningRequest.from_string(
                    csr.certificate_signing_request
                ),
                is_ca=csr.ca if csr.ca else False,
            )
            for csr in requirer_relation_data.certificate_signing_requests
        ]

    def _add_provider_certificate(
        self,
        relation: Relation,
        provider_certificate: ProviderCertificate,
    ) -> None:
        """Append ``provider_certificate`` to the provider databag of ``relation``.

        A warning is logged when the chain does not start with the leaf
        certificate or is otherwise out of order; the certificate is still
        written. Nothing is written when an equal entry is already present.
        """
        chain = [str(certificate) for certificate in provider_certificate.chain]
        if chain[0] != str(provider_certificate.certificate):
            logger.warning(
                "The order of the chain from the TLS Certificates Provider is incorrect. "
                "The leaf certificate should be the first element of the chain."
            )
        elif not chain_has_valid_order(chain):
            logger.warning(
                "The order of the chain from the TLS Certificates Provider is partially incorrect."
            )
        new_certificate = _Certificate(
            certificate=str(provider_certificate.certificate),
            certificate_signing_request=str(provider_certificate.certificate_signing_request),
            ca=str(provider_certificate.ca),
            chain=chain,
        )
        provider_certificates = self._load_provider_certificates(relation)
        if new_certificate in provider_certificates:
            logger.info("Certificate already in relation data - Doing nothing")
            return
        provider_certificates.append(new_certificate)
        self._dump_provider_certificates(relation=relation, certificates=provider_certificates)

    def _load_provider_certificates(self, relation: Relation) -> list[_Certificate]:
        """Return a deep copy of the certificates in this app's databag.

        Returns an empty list when the databag content fails validation.
        """
        try:
            provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app])
        except DataValidationError:
            logger.debug("Invalid provider relation data")
            return []
        return copy.deepcopy(provider_relation_data.certificates)

    def _dump_provider_certificates(self, relation: Relation, certificates: list[_Certificate]):
        """Write ``certificates`` to this app's databag, logging on ``ModelError``."""
        try:
            _ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app])
            logger.info("Certificate relation data updated")
        except ModelError:
            logger.warning("Failed to update relation data")

    def _remove_provider_certificate(
        self,
        relation: Relation,
        certificate: Certificate | None = None,
        certificate_signing_request: CertificateSigningRequest | None = None,
    ) -> None:
        """Remove certificate based on certificate or certificate signing request.

        Every stored entry whose certificate equals ``certificate`` or whose
        CSR equals ``certificate_signing_request`` is dropped, and the
        databag is then rewritten.
        """
        provider_certificates = self._load_provider_certificates(relation)
        certificate_text = str(certificate) if certificate else None
        csr_text = str(certificate_signing_request) if certificate_signing_request else None
        # Filter into a new list rather than removing while iterating: removal
        # during iteration skips the following entry, and an entry matching
        # both criteria would be removed twice, raising ValueError.
        remaining_certificates = [
            provider_certificate
            for provider_certificate in provider_certificates
            if not (
                (
                    certificate_text is not None
                    and provider_certificate.certificate == certificate_text
                )
                or (
                    csr_text is not None
                    and provider_certificate.certificate_signing_request == csr_text
                )
            )
        ]
        self._dump_provider_certificates(relation=relation, certificates=remaining_certificates)

    def revoke_all_certificates(self) -> None:
        """Revoke all certificates of this provider.

        This method is meant to be used when the Root CA has changed.
        Only the leader unit performs the update; a security event is
        logged afterwards.
        """
        if not self.model.unit.is_leader():
            logger.warning("Unit is not a leader - will not set relation data")
            return
        relations = self._get_tls_relations()
        for relation in relations:
            provider_certificates = self._load_provider_certificates(relation)
            for certificate in provider_certificates:
                certificate.revoked = True
            self._dump_provider_certificates(relation=relation, certificates=provider_certificates)
        self._security_logger.log_event(
            event="all_certificates_revoked",
            level=logging.WARNING,
            description="All certificates revoked",
        )

    def set_relation_certificate(
        self,
        provider_certificate: ProviderCertificate,
    ) -> None:
        """Add certificates to relation data.

        Any existing entry for the same CSR is removed first. Only the
        leader unit writes relation data.

        Args:
            provider_certificate (ProviderCertificate): ProviderCertificate object

        Raises:
            TLSCertificatesError: If the relation does not exist.
        """
        if not self.model.unit.is_leader():
            logger.warning("Unit is not a leader - will not set relation data")
            return
        certificates_relation = self.model.get_relation(
            relation_name=self.relationship_name, relation_id=provider_certificate.relation_id
        )
        if not certificates_relation:
            raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist")
        self._remove_provider_certificate(
            relation=certificates_relation,
            certificate_signing_request=provider_certificate.certificate_signing_request,
        )
        self._add_provider_certificate(
            relation=certificates_relation,
            provider_certificate=provider_certificate,
        )
        self._security_logger.log_event(
            event="certificate_provided",
            level=logging.INFO,
            description="Certificate provided to requirer",
            relation_id=str(provider_certificate.relation_id),
            common_name=provider_certificate.certificate.common_name,
        )

    def get_issued_certificates(self, relation_id: int | None = None) -> list[ProviderCertificate]:
        """Return a List of issued (non revoked) certificates.

        Returns:
            List: List of ProviderCertificate objects. Empty on non-leader units.
        """
        if not self.model.unit.is_leader():
            logger.warning("Unit is not a leader - will not read relation data")
            return []
        provider_certificates = self.get_provider_certificates(relation_id=relation_id)
        return [certificate for certificate in provider_certificates if not certificate.revoked]

    def get_provider_certificates(
        self, relation_id: int | None = None
    ) -> list[ProviderCertificate]:
        """Return a List of issued certificates.

        Relations without a remote application are skipped with a warning.
        """
        certificates: list[ProviderCertificate] = []
        relations = self._get_tls_relations(relation_id)
        for relation in relations:
            if not relation.app:
                logger.warning("Relation %s does not have an application", relation.id)
                continue
            certificates.extend([
                certificate.to_provider_certificate(relation_id=relation.id)
                for certificate in self._load_provider_certificates(relation)
            ])
        return certificates

    def get_unsolicited_certificates(
        self, relation_id: int | None = None
    ) -> list[ProviderCertificate]:
        """Return provider certificates for which no certificate requests exists.

        Those certificates should be revoked.
        """
        provider_certificates = self.get_provider_certificates(relation_id=relation_id)
        requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
        list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs]
        return [
            certificate
            for certificate in provider_certificates
            if certificate.certificate_signing_request not in list_of_csrs
        ]

    def get_outstanding_certificate_requests(
        self, relation_id: int | None = None
    ) -> list[RequirerCertificateRequest]:
        """Return CSR's for which no certificate has been issued.

        Args:
            relation_id (int): Relation id

        Returns:
            list: List of RequirerCertificateRequest objects.
        """
        requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
        return [
            relation_csr
            for relation_csr in requirer_csrs
            if not self._certificate_issued_for_csr(
                csr=relation_csr.certificate_signing_request,
                relation_id=relation_id,
            )
        ]

    def _certificate_issued_for_csr(
        self, csr: CertificateSigningRequest, relation_id: int | None
    ) -> bool:
        """Check whether a certificate has been issued for a given CSR.

        The first issued certificate carrying the same CSR decides the
        result, based on whether the CSR matches that certificate.
        """
        issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)
        for issued_certificate in issued_certificates_per_csr:
            if issued_certificate.certificate_signing_request == csr:
                return csr.matches_certificate(issued_certificate.certificate)
        return False