drisk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. drisk/__init__.py +89 -0
  2. drisk/_style.py +5 -0
  3. drisk/arithmetic.py +111 -0
  4. drisk/copulas/__init__.py +7 -0
  5. drisk/copulas/base.py +95 -0
  6. drisk/copulas/gaussian.py +38 -0
  7. drisk/copulas/registry.py +26 -0
  8. drisk/copulas/student_t.py +51 -0
  9. drisk/correlations/__init__.py +5 -0
  10. drisk/correlations/matrix.py +151 -0
  11. drisk/decision/__init__.py +23 -0
  12. drisk/decision/dtree/__init__.py +17 -0
  13. drisk/decision/dtree/_coercion.py +48 -0
  14. drisk/decision/dtree/_plotting.py +259 -0
  15. drisk/decision/dtree/_sampling.py +66 -0
  16. drisk/decision/dtree/_types.py +7 -0
  17. drisk/decision/dtree/branches.py +6 -0
  18. drisk/decision/dtree/chance_branch.py +32 -0
  19. drisk/decision/dtree/decision_branch.py +31 -0
  20. drisk/decision/dtree/nodes/__init__.py +29 -0
  21. drisk/decision/dtree/nodes/base.py +50 -0
  22. drisk/decision/dtree/nodes/chance.py +104 -0
  23. drisk/decision/dtree/nodes/decision.py +105 -0
  24. drisk/decision/dtree/nodes/factory.py +27 -0
  25. drisk/decision/dtree/nodes/outcome.py +65 -0
  26. drisk/decision/dtree/tree.py +190 -0
  27. drisk/distributions/__init__.py +63 -0
  28. drisk/distributions/base.py +127 -0
  29. drisk/distributions/mixture.py +213 -0
  30. drisk/distributions/registry.py +57 -0
  31. drisk/distributions/types.py +18 -0
  32. drisk/distributions/univariate/__init__.py +53 -0
  33. drisk/distributions/univariate/base.py +52 -0
  34. drisk/distributions/univariate/continuous/__init__.py +32 -0
  35. drisk/distributions/univariate/continuous/base.py +111 -0
  36. drisk/distributions/univariate/continuous/beta.py +148 -0
  37. drisk/distributions/univariate/continuous/exponential.py +103 -0
  38. drisk/distributions/univariate/continuous/gamma.py +126 -0
  39. drisk/distributions/univariate/continuous/logitnormal.py +164 -0
  40. drisk/distributions/univariate/continuous/lognormal.py +137 -0
  41. drisk/distributions/univariate/continuous/normal.py +112 -0
  42. drisk/distributions/univariate/continuous/stretched_beta.py +216 -0
  43. drisk/distributions/univariate/discrete/__init__.py +19 -0
  44. drisk/distributions/univariate/discrete/base.py +108 -0
  45. drisk/distributions/univariate/discrete/bernoulli.py +98 -0
  46. drisk/distributions/univariate/discrete/binomial.py +131 -0
  47. drisk/distributions/univariate/discrete/geometric.py +116 -0
  48. drisk/distributions/univariate/discrete/negative_binomial.py +145 -0
  49. drisk/distributions/univariate/discrete/poisson.py +103 -0
  50. drisk/models/__init__.py +6 -0
  51. drisk/models/base.py +551 -0
  52. drisk/models/functions.py +10 -0
  53. drisk/models/py.typed +0 -0
  54. drisk/py.typed +0 -0
  55. drisk/random.py +34 -0
  56. drisk/sensitivity/__init__.py +5 -0
  57. drisk/sensitivity/_evaluate.py +52 -0
  58. drisk/sensitivity/_inputs.py +47 -0
  59. drisk/sensitivity/one_at_a_time.py +367 -0
  60. drisk/summary.py +101 -0
  61. drisk-0.1.0.dist-info/METADATA +171 -0
  62. drisk-0.1.0.dist-info/RECORD +64 -0
  63. drisk-0.1.0.dist-info/WHEEL +4 -0
  64. drisk-0.1.0.dist-info/licenses/LICENSE +21 -0
drisk/__init__.py ADDED
@@ -0,0 +1,89 @@
1
+ """Convenient tools for quick Monte Carlo modelling."""
2
+
3
+ from . import _style as _style
4
+ from .copulas import Copula, GaussianCopula, StudentTCopula
5
+ from .correlations import CorrelationMatrix
6
+ from .decision import (
7
+ ChanceBranch,
8
+ ChanceNode,
9
+ DecisionBranch,
10
+ DecisionNode,
11
+ DTree,
12
+ DTreeNode,
13
+ OutcomeNode,
14
+ )
15
+ from .distributions import (
16
+ PERT,
17
+ ArrayLike,
18
+ Bernoulli,
19
+ Beta,
20
+ Binomial,
21
+ DataFrameLike,
22
+ Distribution,
23
+ Exponential,
24
+ Gamma,
25
+ Geometric,
26
+ LogitNormal,
27
+ LogNormal,
28
+ NegativeBinomial,
29
+ Normal,
30
+ Poisson,
31
+ StretchedBeta,
32
+ UvBoundedContinuous,
33
+ UvContinuous,
34
+ UvCountDiscrete,
35
+ UvDiscrete,
36
+ UvDistribution,
37
+ UvFiniteDiscrete,
38
+ UvMixture,
39
+ UvPositiveContinuous,
40
+ UvRealContinuous,
41
+ UvUnitBoundedContinuous,
42
+ )
43
+ from .models import MCModel, MCOperation, where
44
+ from .sensitivity import OneAtATimeSensitivity, one_at_a_time
45
+
# Public API. Classes, type aliases and constants are listed case-insensitively
# alphabetically, followed by lowercase function names -- the same ordering
# used by ``drisk.decision.__all__``.
__all__ = [
    "ArrayLike",
    "Bernoulli",
    "Beta",
    "Binomial",
    "ChanceBranch",
    "ChanceNode",
    "Copula",
    "CorrelationMatrix",
    "DataFrameLike",
    "DecisionBranch",
    "DecisionNode",
    "Distribution",
    "DTree",
    "DTreeNode",
    "Exponential",
    "Gamma",
    "GaussianCopula",
    "Geometric",
    "LogitNormal",
    "LogNormal",
    "MCModel",
    "MCOperation",
    "NegativeBinomial",
    "Normal",
    "OneAtATimeSensitivity",
    "OutcomeNode",
    "PERT",
    "Poisson",
    "StretchedBeta",
    "StudentTCopula",
    "UvBoundedContinuous",
    "UvContinuous",
    "UvCountDiscrete",
    "UvDiscrete",
    "UvDistribution",
    "UvFiniteDiscrete",
    "UvMixture",
    "UvPositiveContinuous",
    "UvRealContinuous",
    "UvUnitBoundedContinuous",
    "one_at_a_time",
    "where",
]
drisk/_style.py ADDED
@@ -0,0 +1,5 @@
1
+ """Package-wide Matplotlib defaults."""
2
+
3
+ import matplotlib as mpl
4
+
# Turn on a faint background grid for every plot. This is applied once, when
# the module is imported.
# NOTE(review): this mutates the caller's global Matplotlib rcParams as an
# import side effect.
mpl.rcParams["axes.grid"] = True
mpl.rcParams["grid.alpha"] = 0.25
drisk/arithmetic.py ADDED
@@ -0,0 +1,111 @@
1
+ """Arithmetic helpers for composable Monte Carlo expressions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+
class ArithmeticMixin:
    """Mixin that turns arithmetic into lazy Monte Carlo model expressions.

    Each operator returns ``MCModel.from_operation(<MCOperation member>, ...)``
    instead of computing a value. Binary operators pass their operands in the
    order the Python expression reads them, so the reflected (``__r*__``)
    variants put ``other`` first. Unary operators pass only ``self``.
    """

    @staticmethod
    def _lazy_model(operation: str, *operands: Any) -> Any:
        """Return ``MCModel.from_operation(MCOperation.<operation>, *operands)``.

        ``drisk.models`` is imported at call time rather than at module import,
        presumably to avoid a circular import with classes using this mixin.
        """
        from drisk.models import MCModel, MCOperation

        return MCModel.from_operation(getattr(MCOperation, operation), *operands)

    # --- binary arithmetic -------------------------------------------------

    def __add__(self, other: Any) -> Any:
        """Create a lazy model for ``self + other``."""
        return self._lazy_model("ADD", self, other)

    def __radd__(self, other: Any) -> Any:
        """Create a lazy model for ``other + self``."""
        return self._lazy_model("ADD", other, self)

    def __sub__(self, other: Any) -> Any:
        """Create a lazy model for ``self - other``."""
        return self._lazy_model("SUBTRACT", self, other)

    def __rsub__(self, other: Any) -> Any:
        """Create a lazy model for ``other - self``."""
        return self._lazy_model("SUBTRACT", other, self)

    def __mul__(self, other: Any) -> Any:
        """Create a lazy model for ``self * other``."""
        return self._lazy_model("MULTIPLY", self, other)

    def __rmul__(self, other: Any) -> Any:
        """Create a lazy model for ``other * self``."""
        return self._lazy_model("MULTIPLY", other, self)

    def __truediv__(self, other: Any) -> Any:
        """Create a lazy model for ``self / other``."""
        return self._lazy_model("DIVIDE", self, other)

    def __rtruediv__(self, other: Any) -> Any:
        """Create a lazy model for ``other / self``."""
        return self._lazy_model("DIVIDE", other, self)

    def __pow__(self, other: Any) -> Any:
        """Create a lazy model for ``self ** other``."""
        return self._lazy_model("POWER", self, other)

    def __rpow__(self, other: Any) -> Any:
        """Create a lazy model for ``other ** self``."""
        return self._lazy_model("POWER", other, self)

    # --- comparisons (return lazy models, not booleans) --------------------

    def __lt__(self, other: Any) -> Any:
        """Create a lazy model for ``self < other``."""
        return self._lazy_model("LESS", self, other)

    def __le__(self, other: Any) -> Any:
        """Create a lazy model for ``self <= other``."""
        return self._lazy_model("LESS_EQUAL", self, other)

    def __gt__(self, other: Any) -> Any:
        """Create a lazy model for ``self > other``."""
        return self._lazy_model("GREATER", self, other)

    def __ge__(self, other: Any) -> Any:
        """Create a lazy model for ``self >= other``."""
        return self._lazy_model("GREATER_EQUAL", self, other)

    # --- unary operators ---------------------------------------------------

    def __neg__(self) -> Any:
        """Create a lazy model for ``-self``."""
        return self._lazy_model("NEGATIVE", self)

    def __pos__(self) -> Any:
        """Create a lazy model for ``+self``."""
        return self._lazy_model("POSITIVE", self)

    def __abs__(self) -> Any:
        """Create a lazy model for ``abs(self)``."""
        return self._lazy_model("ABS", self)
drisk/copulas/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Copula models for jointly sampling marginal distributions."""
2
+
3
+ from .base import Copula
4
+ from .gaussian import GaussianCopula
5
+ from .student_t import StudentTCopula
6
+
# Public copula API re-exported from the submodules imported above.
__all__ = [
    "Copula",
    "GaussianCopula",
    "StudentTCopula",
]
drisk/copulas/base.py ADDED
@@ -0,0 +1,95 @@
1
+ """Base interfaces for copula models."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Sequence
5
+ from inspect import isabstract
6
+ from typing import Any, Self
7
+
8
+ import numpy as np
9
+ from pydantic import BaseModel, ConfigDict, GetCoreSchemaHandler, model_validator
10
+ from pydantic_core import CoreSchema, core_schema
11
+
12
+ from drisk.correlations import CorrelationMatrix
13
+ from drisk.distributions.univariate import UvDistribution
14
+ from drisk.random import SeedLike
15
+
16
+
class Copula(BaseModel, ABC):
    """Abstract base for copulas that jointly sample marginal distributions.

    Subclasses implement :meth:`sample`. This base stores the marginals and
    their correlation matrix, checks that their sizes agree, and lets abstract
    ``Copula``-typed Pydantic fields validate as a union tagged by each
    registered subclass's ``copula_type`` default.
    """

    # One marginal per dimension of the joint sample.
    distributions: tuple[UvDistribution, ...]
    # Correlation structure; validated to be len(distributions) square.
    corr_matrix: CorrelationMatrix

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Use ``copula_type`` to validate abstract copula-typed fields.

        Falls back to the default schema when the class is not yet fully built
        or is concrete, when the registry cannot be imported, or when no
        registered subclass of ``cls`` exists.
        """
        is_complete = getattr(cls, "__pydantic_complete__", False)
        if not is_complete or not isabstract(cls):
            return handler(source_type)

        try:
            from drisk.copulas.registry import concrete_copula_types_for

            tagged = {
                subclass.model_fields["copula_type"].default: handler.generate_schema(
                    subclass
                )
                for subclass in concrete_copula_types_for(cls)
            }
        except ImportError:
            return handler(source_type)

        if not tagged:
            return handler(source_type)

        return core_schema.tagged_union_schema(
            choices=tagged,
            discriminator="copula_type",
            from_attributes=True,
        )

    @property
    def dims(self) -> int:
        """Number of marginal distributions."""
        return len(self.distributions)

    @model_validator(mode="after")
    def validate_dimensions(self) -> Self:
        """Ensure the correlation matrix dimension matches the marginals."""
        expected = len(self.distributions)
        actual = len(self.corr_matrix.matrix)
        if actual == expected:
            return self
        raise ValueError(
            f"Correlation matrix size ({actual}) does not match number of distributions ({expected})."
        )

    @classmethod
    def from_distributions_and_correlation(
        cls,
        distributions: Sequence[UvDistribution],
        correlation: float,
        **kwargs: object,
    ) -> Self:
        """Build a copula whose marginals all share one pairwise correlation.

        Extra keyword arguments are forwarded to the constructor.
        """
        shared = CorrelationMatrix.from_n_corr(len(distributions), correlation)
        return cls(distributions=distributions, corr_matrix=shared, **kwargs)

    @abstractmethod
    def sample(
        self, size: int | tuple[int, ...] = 1, *, seed: SeedLike = None
    ) -> np.ndarray:
        """Jointly sample marginals, returning an array shaped ``(dims, *size)``."""
        ...

    def rvs(
        self, size: int | tuple[int, ...] = 1, *, seed: SeedLike = None
    ) -> np.ndarray:
        """Alias for :meth:`sample` for users familiar with SciPy naming."""
        return self.sample(size=size, seed=seed)
drisk/copulas/gaussian.py ADDED
@@ -0,0 +1,38 @@
1
+ """Gaussian copula."""
2
+
3
+ from typing import Literal
4
+
5
+ import numpy as np
6
+ from scipy import stats
7
+
8
+ from drisk.random import SeedLike, get_rng
9
+
10
+ from .base import Copula
11
+
12
+
class GaussianCopula(Copula):
    """Sample marginal distributions with dependence induced by a Gaussian copula.

    Correlated standard normals (covariance = ``corr_matrix``) are mapped to
    uniforms with the standard normal CDF and then through each marginal's
    ``ppf``.
    """

    # Tag value used to discriminate copula subclasses during validation.
    copula_type: Literal["gaussian"] = "gaussian"

    def sample(
        self, size: int | tuple[int, ...] = 1, *, seed: SeedLike = None
    ) -> np.ndarray:
        """Jointly sample marginals, returning an array shaped ``(dims, *size)``."""
        shape = (size,) if isinstance(size, int) else size

        generator = get_rng(seed)
        latent = generator.multivariate_normal(
            mean=np.zeros(self.dims),
            cov=self.corr_matrix.to_numpy(),
            size=shape,
        )
        # multivariate_normal returns (*shape, dims); bring dims to the front so
        # row k holds the draws for marginal k, then map to uniforms.
        uniforms = stats.norm.cdf(np.moveaxis(latent, -1, 0))

        result = np.empty_like(uniforms)
        for index, marginal in enumerate(self.distributions):
            result[index, ...] = marginal.ppf(uniforms[index, ...])
        return result
drisk/copulas/registry.py ADDED
@@ -0,0 +1,26 @@
1
+ """Registry of concrete copula implementations for Pydantic polymorphism."""
2
+
3
+ from functools import cache
4
+ from typing import cast
5
+
6
+ from drisk.copulas.base import Copula
7
+
8
+
@cache
def concrete_copula_types() -> tuple[type[Copula], ...]:
    """Return all concrete copula classes supported by Drisk.

    The subclasses are imported inside the function (and the result cached by
    ``@cache``) rather than at module level -- presumably to avoid an import
    cycle with ``drisk.copulas.base``.
    """
    from drisk.copulas.gaussian import GaussianCopula
    from drisk.copulas.student_t import StudentTCopula

    registered: tuple[type[Copula], ...] = (GaussianCopula, StudentTCopula)
    return registered
16
+
17
+
18
+ def concrete_copula_types_for[CopulaT: Copula](
19
+ base_cls: type[CopulaT],
20
+ ) -> tuple[type[CopulaT], ...]:
21
+ """Return concrete registered copulas that are subclasses of ``base_cls``."""
22
+ return tuple(
23
+ cast(type[CopulaT], copula_cls)
24
+ for copula_cls in concrete_copula_types()
25
+ if issubclass(copula_cls, base_cls)
26
+ )
drisk/copulas/student_t.py ADDED
@@ -0,0 +1,51 @@
1
+ """Student-t copula."""
2
+
3
+ from typing import Literal
4
+
5
+ import numpy as np
6
+ from pydantic import field_validator
7
+ from scipy import stats
8
+
9
+ from drisk.random import SeedLike, get_rng
10
+
11
+ from .base import Copula
12
+
13
+
class StudentTCopula(Copula):
    """Sample marginals with dependence induced by a Student-t copula.

    Correlated standard normals are divided by ``sqrt(W / nu)`` with
    ``W ~ chi2(nu)`` drawn once per sample point (shared across dimensions).
    The result is mapped to uniforms with the t CDF and then through each
    marginal's ``ppf``.
    """

    # Tag value used to discriminate copula subclasses during validation.
    copula_type: Literal["student_t"] = "student_t"
    # Degrees of freedom of the t distribution.
    nu: float = 4.0

    @field_validator("nu")
    @classmethod
    def validate_nu(cls, nu: float) -> float:
        """Validate degrees of freedom.

        Raises:
            ValueError: If ``nu`` is not positive, or is NaN / infinite.
        """
        if nu <= 0:
            raise ValueError(f"nu must be positive, got {nu}")
        # ``nan <= 0`` is False, so NaN (and +inf) slip past the check above.
        # Either would feed invalid degrees of freedom into the chi-square
        # draw and the t CDF below.
        if not np.isfinite(nu):
            raise ValueError(f"nu must be finite, got {nu}")
        return nu

    def sample(
        self, size: int | tuple[int, ...] = 1, *, seed: SeedLike = None
    ) -> np.ndarray:
        """Jointly sample marginals, returning an array shaped ``(dims, *size)``."""
        if isinstance(size, int):
            size = (size,)

        rng = get_rng(seed)
        normal_samples = rng.multivariate_normal(
            mean=np.zeros(self.dims),
            cov=self.corr_matrix.to_numpy(),
            size=size,
        )
        # (*size, dims) -> (dims, *size) so row i belongs to marginal i.
        normal_samples = np.moveaxis(normal_samples, -1, 0)

        # One chi-square draw per sample point (shape ``size``); broadcasting
        # shares it across the leading ``dims`` axis.
        chi_square_samples = rng.chisquare(df=self.nu, size=size) / self.nu
        t_samples = normal_samples / np.sqrt(chi_square_samples)
        uniform_samples = stats.t.cdf(t_samples, df=self.nu)

        samples = np.empty_like(uniform_samples)
        for i, dist in enumerate(self.distributions):
            samples[i, ...] = dist.ppf(uniform_samples[i, ...])

        return samples
drisk/correlations/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Correlation structures for Monte Carlo modelling."""
2
+
3
+ from .matrix import CorrelationMatrix
4
+
# Public correlation API.
__all__ = [
    "CorrelationMatrix",
]
drisk/correlations/matrix.py ADDED
@@ -0,0 +1,151 @@
1
+ """Correlation matrix validation and helpers."""
2
+
3
+ from typing import Any, Self
4
+
5
+ import numpy as np
6
+ from pydantic import BaseModel, ConfigDict, field_validator
7
+
8
+
class CorrelationMatrix(BaseModel):
    """Represent and validate a numeric correlation matrix.

    The matrix is stored as nested rows and validated on construction. It must
    be non-empty, square and finite, have a unit diagonal, be symmetric, have
    every entry in ``[-1, 1]``, and be positive semidefinite.
    """

    # Nested rows; ``matrix[i][j]`` is the correlation between items i and j.
    matrix: list[list[float]]

    model_config = ConfigDict(extra="forbid")

    @field_validator("matrix")
    @classmethod
    def validate_matrix(cls, matrix: list[list[float]]) -> list[list[float]]:
        """Validate shape, finiteness, correlation bounds, symmetry, and PSD-ness.

        Args:
            matrix: Candidate matrix as nested rows.

        Returns:
            ``matrix`` unchanged when every check passes.

        Raises:
            ValueError: Naming the first failing check and, where applicable,
                the offending ``(row, column)`` position.
        """
        if not matrix:
            raise ValueError("Matrix cannot be empty")

        try:
            arr = np.asarray(matrix, dtype=float)
        except ValueError as err:
            # Rows of different lengths cannot form a 2-D float array.
            row_lengths = [len(row) for row in matrix]
            raise ValueError(
                f"Matrix must be square, got rows of lengths {row_lengths}"
            ) from err

        if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
            raise ValueError(f"Matrix must be square, got shape {arr.shape}")

        # Without this check, an off-diagonal NaN is reported as a bogus
        # "not symmetric" error, because NaN never compares close to itself.
        non_finite = ~np.isfinite(arr)
        if non_finite.any():
            i, j = np.argwhere(non_finite)[0]
            raise ValueError(
                f"Correlation value at ({i}, {j}) must be finite, got {arr[i, j]}"
            )

        diagonal = np.diag(arr)
        if not np.allclose(diagonal, 1.0):
            bad_indices = np.where(~np.isclose(diagonal, 1.0))[0]
            i = int(bad_indices[0])
            raise ValueError(
                f"Diagonal element at ({i}, {i}) must be 1.0, got {diagonal[i]}"
            )

        if not np.allclose(arr, arr.T):
            diff = np.abs(arr - arr.T)
            i, j = np.unravel_index(np.argmax(diff), diff.shape)
            # Always name the upper-triangle entry first in the message.
            if i > j:
                i, j = j, i
            raise ValueError(
                f"Matrix is not symmetric: ({i}, {j})={arr[i, j]} != ({j}, {i})={arr[j, i]}"
            )

        if not np.all((arr >= -1.0) & (arr <= 1.0)):
            bad_mask = (arr < -1.0) | (arr > 1.0)
            i, j = np.unravel_index(np.argmax(bad_mask), arr.shape)
            raise ValueError(
                f"Correlation value at ({i}, {j}) must be between -1 and 1, got {arr[i, j]}"
            )

        # Tiny negative eigenvalues (>= -1e-10) are tolerated as rounding noise.
        eigenvalues = np.linalg.eigvalsh(arr)
        min_eigenvalue = float(np.min(eigenvalues))
        if min_eigenvalue < -1e-10:
            raise ValueError(
                "Correlation matrix must be positive semidefinite; "
                f"minimum eigenvalue is {min_eigenvalue}"
            )

        return matrix

    @classmethod
    def from_n_corr(cls, n: int, corr: float) -> Self:
        """Create an ``n`` by ``n`` matrix with a shared off-diagonal correlation.

        Raises:
            ValueError: If ``n`` is not positive, ``corr`` is outside
                ``[-1, 1]``, or the resulting matrix fails validation.
        """
        if n <= 0:
            raise ValueError(f"n must be positive, got {n}")
        if not (-1.0 <= corr <= 1.0):
            raise ValueError(f"Correlation value must be between -1 and 1, got {corr}")

        matrix = [[1.0 if i == j else float(corr) for j in range(n)] for i in range(n)]
        return cls(matrix=matrix)

    @classmethod
    def from_numpy(cls, arr: np.ndarray) -> Self:
        """Create a correlation matrix from a NumPy array (or array-like).

        Raises:
            ValueError: If the input is not 2-dimensional or fails validation.
        """
        arr = np.asarray(arr)
        if arr.ndim != 2:
            raise ValueError(f"Array must be 2-dimensional, got {arr.ndim}")
        return cls(matrix=arr.tolist())

    def to_numpy(self) -> np.ndarray:
        """Return the correlation matrix as a float NumPy array."""
        return np.asarray(self.matrix, dtype=float)

    def plot(
        self,
        ax: Any = None,
        *,
        labels: list[str] | None = None,
        cmap: str = "Spectral",
        show: bool = False,
        colorbar: bool = True,
        **imshow_kwargs: Any,
    ) -> Any:
        """
        Plot the correlation matrix as an annotated heatmap.

        Returns the Matplotlib ``Axes`` object. Importing Matplotlib is deferred
        so non-plotting use stays lightweight. Extra keyword arguments are
        passed to ``imshow``.

        Raises:
            ValueError: If ``labels`` is given with a length other than the
                matrix size.
        """
        arr = self.to_numpy()
        n = arr.shape[0]

        # Validate before creating a figure so bad input leaves no empty figure.
        if labels is not None and len(labels) != n:
            raise ValueError(f"labels must have length {n}, got {len(labels)}")

        if ax is None:
            import matplotlib.pyplot as plt

            _, ax = plt.subplots()

        # Fixed colour limits so the scale always spans the full correlation range.
        image = ax.imshow(
            arr,
            cmap=cmap,
            vmin=-1,
            vmax=1,
            **imshow_kwargs,
        )

        tick_labels = labels if labels is not None else [str(i) for i in range(n)]
        ax.set_xticks(np.arange(n), labels=tick_labels, rotation=90)
        ax.tick_params(
            axis="x", bottom=True, labelbottom=True, top=False, labeltop=False
        )
        ax.set_yticks(np.arange(n), labels=tick_labels)
        ax.grid(False)
        ax.set_title("Correlation matrix")

        # Annotate each cell; switch to white text on strongly coloured cells.
        for i in range(n):
            for j in range(n):
                text_color = "white" if abs(arr[i, j]) > 0.5 else "black"
                ax.text(
                    j,
                    i,
                    f"{arr[i, j]:.2f}",
                    ha="center",
                    va="center",
                    color=text_color,
                )

        if colorbar:
            ax.figure.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

        ax.figure.tight_layout()

        if show:
            import matplotlib.pyplot as plt

            plt.show()

        return ax
drisk/decision/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ """Decision analysis support."""
2
+
3
+ from .dtree import (
4
+ ChanceBranch,
5
+ ChanceNode,
6
+ DecisionBranch,
7
+ DecisionNode,
8
+ DTree,
9
+ DTreeNode,
10
+ OutcomeNode,
11
+ as_node,
12
+ )
13
+
# Re-exported decision-tree API: classes first, then the ``as_node`` helper.
__all__ = [
    "ChanceBranch",
    "ChanceNode",
    "DecisionBranch",
    "DecisionNode",
    "DTree",
    "DTreeNode",
    "OutcomeNode",
    "as_node",
]
drisk/decision/dtree/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ """Decision tree support."""
2
+
3
+ from .chance_branch import ChanceBranch
4
+ from .decision_branch import DecisionBranch
5
+ from .nodes import ChanceNode, DecisionNode, DTreeNode, OutcomeNode, as_node
6
+ from .tree import DTree
7
+
# Public decision-tree API: classes first, then the ``as_node`` helper.
__all__ = [
    "ChanceBranch",
    "ChanceNode",
    "DecisionBranch",
    "DecisionNode",
    "DTree",
    "DTreeNode",
    "OutcomeNode",
    "as_node",
]
drisk/decision/dtree/_coercion.py ADDED
@@ -0,0 +1,48 @@
1
+ """Coercion helpers for decision tree inputs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from .chance_branch import ChanceBranch
8
+ from .decision_branch import DecisionBranch
9
+
10
+
def coerce_decision_branches(branches: Any) -> list[DecisionBranch]:
    """Normalise decision-node branch input into a list of ``DecisionBranch``.

    Accepted forms:

    * a ``dict`` of ``{name: value}``: the key is converted with ``str`` and
      the value is passed through ``as_node``;
    * any other iterable: ``DecisionBranch`` items are kept as-is, and every
      other item goes through ``DecisionBranch.model_validate``.
    """
    # Deferred import -- presumably to avoid a cycle with the nodes package.
    from .nodes.factory import as_node

    if not isinstance(branches, dict):
        coerced: list[DecisionBranch] = []
        for item in branches:
            if not isinstance(item, DecisionBranch):
                item = DecisionBranch.model_validate(item)
            coerced.append(item)
        return coerced

    return [
        DecisionBranch(name=str(key), node=as_node(target))
        for key, target in branches.items()
    ]
25
+
26
+
def coerce_chance_branches(branches: Any) -> list[ChanceBranch]:
    """Normalise chance-node branch input into a list of ``ChanceBranch``.

    Accepted forms:

    * a ``dict`` of ``{name: spec}``, where ``spec`` is either a
      ``(probability, value)`` tuple (``value`` goes through ``as_node``) or a
      mapping of the remaining ``ChanceBranch`` fields (a ``"name"`` entry in
      the mapping overrides the key);
    * any other iterable: ``ChanceBranch`` items are kept as-is, and every
      other item goes through ``ChanceBranch.model_validate``.

    In both ``dict`` forms the key is converted with ``str``.
    """
    # Deferred import -- presumably to avoid a cycle with the nodes package.
    from .nodes.factory import as_node

    if isinstance(branches, dict):
        coerced: list[ChanceBranch] = []
        for name, spec in branches.items():
            if isinstance(spec, tuple) and len(spec) == 2:
                probability, value = spec
                coerced.append(
                    ChanceBranch(
                        name=str(name), probability=probability, node=as_node(value)
                    )
                )
            else:
                # ``str(name)`` here too, so non-string keys are accepted in the
                # mapping form exactly as they are in the tuple form above.
                coerced.append(
                    ChanceBranch.model_validate({"name": str(name), **spec})
                )
        return coerced
    return [
        branch
        if isinstance(branch, ChanceBranch)
        else ChanceBranch.model_validate(branch)
        for branch in branches
    ]