skfolio 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. skfolio/__init__.py +7 -7
  2. skfolio/cluster/__init__.py +2 -2
  3. skfolio/cluster/_hierarchical.py +2 -2
  4. skfolio/datasets/__init__.py +3 -3
  5. skfolio/datasets/_base.py +2 -2
  6. skfolio/datasets/data/__init__.py +1 -0
  7. skfolio/distance/__init__.py +4 -4
  8. skfolio/distance/_base.py +2 -2
  9. skfolio/distance/_distance.py +11 -10
  10. skfolio/distribution/__init__.py +56 -0
  11. skfolio/distribution/_base.py +203 -0
  12. skfolio/distribution/copula/__init__.py +35 -0
  13. skfolio/distribution/copula/_base.py +456 -0
  14. skfolio/distribution/copula/_clayton.py +539 -0
  15. skfolio/distribution/copula/_gaussian.py +407 -0
  16. skfolio/distribution/copula/_gumbel.py +560 -0
  17. skfolio/distribution/copula/_independent.py +196 -0
  18. skfolio/distribution/copula/_joe.py +609 -0
  19. skfolio/distribution/copula/_selection.py +111 -0
  20. skfolio/distribution/copula/_student_t.py +486 -0
  21. skfolio/distribution/copula/_utils.py +509 -0
  22. skfolio/distribution/multivariate/__init__.py +11 -0
  23. skfolio/distribution/multivariate/_base.py +241 -0
  24. skfolio/distribution/multivariate/_utils.py +632 -0
  25. skfolio/distribution/multivariate/_vine_copula.py +1254 -0
  26. skfolio/distribution/univariate/__init__.py +19 -0
  27. skfolio/distribution/univariate/_base.py +308 -0
  28. skfolio/distribution/univariate/_gaussian.py +136 -0
  29. skfolio/distribution/univariate/_johnson_su.py +152 -0
  30. skfolio/distribution/univariate/_normal_inverse_gaussian.py +153 -0
  31. skfolio/distribution/univariate/_selection.py +85 -0
  32. skfolio/distribution/univariate/_student_t.py +144 -0
  33. skfolio/exceptions.py +8 -8
  34. skfolio/measures/__init__.py +24 -24
  35. skfolio/measures/_enums.py +7 -7
  36. skfolio/measures/_measures.py +4 -7
  37. skfolio/metrics/__init__.py +2 -0
  38. skfolio/metrics/_scorer.py +4 -4
  39. skfolio/model_selection/__init__.py +4 -4
  40. skfolio/model_selection/_combinatorial.py +15 -12
  41. skfolio/model_selection/_validation.py +2 -2
  42. skfolio/model_selection/_walk_forward.py +3 -3
  43. skfolio/moments/__init__.py +11 -11
  44. skfolio/moments/covariance/__init__.py +6 -6
  45. skfolio/moments/covariance/_base.py +1 -1
  46. skfolio/moments/covariance/_denoise_covariance.py +3 -2
  47. skfolio/moments/covariance/_detone_covariance.py +3 -2
  48. skfolio/moments/covariance/_empirical_covariance.py +3 -2
  49. skfolio/moments/covariance/_ew_covariance.py +3 -2
  50. skfolio/moments/covariance/_gerber_covariance.py +3 -2
  51. skfolio/moments/covariance/_graphical_lasso_cv.py +1 -1
  52. skfolio/moments/covariance/_implied_covariance.py +3 -8
  53. skfolio/moments/covariance/_ledoit_wolf.py +1 -1
  54. skfolio/moments/covariance/_oas.py +1 -1
  55. skfolio/moments/covariance/_shrunk_covariance.py +1 -1
  56. skfolio/moments/expected_returns/__init__.py +2 -2
  57. skfolio/moments/expected_returns/_base.py +1 -1
  58. skfolio/moments/expected_returns/_empirical_mu.py +3 -2
  59. skfolio/moments/expected_returns/_equilibrium_mu.py +3 -2
  60. skfolio/moments/expected_returns/_ew_mu.py +3 -2
  61. skfolio/moments/expected_returns/_shrunk_mu.py +4 -3
  62. skfolio/optimization/__init__.py +12 -10
  63. skfolio/optimization/_base.py +2 -2
  64. skfolio/optimization/cluster/__init__.py +3 -1
  65. skfolio/optimization/cluster/_nco.py +10 -9
  66. skfolio/optimization/cluster/hierarchical/__init__.py +3 -1
  67. skfolio/optimization/cluster/hierarchical/_base.py +1 -2
  68. skfolio/optimization/cluster/hierarchical/_herc.py +4 -3
  69. skfolio/optimization/cluster/hierarchical/_hrp.py +4 -3
  70. skfolio/optimization/convex/__init__.py +5 -3
  71. skfolio/optimization/convex/_base.py +10 -9
  72. skfolio/optimization/convex/_distributionally_robust.py +8 -5
  73. skfolio/optimization/convex/_maximum_diversification.py +8 -6
  74. skfolio/optimization/convex/_mean_risk.py +10 -8
  75. skfolio/optimization/convex/_risk_budgeting.py +6 -4
  76. skfolio/optimization/ensemble/__init__.py +2 -0
  77. skfolio/optimization/ensemble/_base.py +2 -2
  78. skfolio/optimization/ensemble/_stacking.py +3 -3
  79. skfolio/optimization/naive/__init__.py +3 -1
  80. skfolio/optimization/naive/_naive.py +4 -3
  81. skfolio/population/__init__.py +2 -0
  82. skfolio/population/_population.py +34 -7
  83. skfolio/portfolio/__init__.py +1 -1
  84. skfolio/portfolio/_base.py +43 -8
  85. skfolio/portfolio/_multi_period_portfolio.py +3 -2
  86. skfolio/portfolio/_portfolio.py +5 -4
  87. skfolio/pre_selection/__init__.py +3 -1
  88. skfolio/pre_selection/_drop_correlated.py +3 -3
  89. skfolio/pre_selection/_select_complete.py +31 -30
  90. skfolio/pre_selection/_select_k_extremes.py +3 -3
  91. skfolio/pre_selection/_select_non_dominated.py +3 -3
  92. skfolio/pre_selection/_select_non_expiring.py +8 -6
  93. skfolio/preprocessing/__init__.py +2 -0
  94. skfolio/preprocessing/_returns.py +2 -2
  95. skfolio/prior/__init__.py +7 -3
  96. skfolio/prior/_base.py +2 -2
  97. skfolio/prior/_black_litterman.py +7 -4
  98. skfolio/prior/_empirical.py +5 -2
  99. skfolio/prior/_factor_model.py +10 -5
  100. skfolio/prior/_synthetic_data.py +239 -0
  101. skfolio/synthetic_returns/__init__.py +1 -0
  102. skfolio/typing.py +7 -7
  103. skfolio/uncertainty_set/__init__.py +7 -5
  104. skfolio/uncertainty_set/_base.py +5 -4
  105. skfolio/uncertainty_set/_bootstrap.py +1 -1
  106. skfolio/uncertainty_set/_empirical.py +1 -1
  107. skfolio/utils/__init__.py +1 -0
  108. skfolio/utils/bootstrap.py +2 -2
  109. skfolio/utils/equations.py +13 -10
  110. skfolio/utils/sorting.py +2 -2
  111. skfolio/utils/stats.py +15 -15
  112. skfolio/utils/tools.py +86 -22
  113. {skfolio-0.6.0.dist-info → skfolio-0.8.0.dist-info}/METADATA +122 -46
  114. skfolio-0.8.0.dist-info/RECORD +120 -0
  115. {skfolio-0.6.0.dist-info → skfolio-0.8.0.dist-info}/WHEEL +1 -1
  116. skfolio-0.6.0.dist-info/RECORD +0 -95
  117. {skfolio-0.6.0.dist-info → skfolio-0.8.0.dist-info/licenses}/LICENSE +0 -0
  118. {skfolio-0.6.0.dist-info → skfolio-0.8.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,632 @@
1
+ """Utils module for multivariate distribution."""
2
+
3
+ # Copyright (c) 2025
4
+ # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
+ # Credits: Matteo Manzi, Vincent Maladière, Carlo Nicolini
6
+ # SPDX-License-Identifier: BSD-3-Clause
7
+
8
+ from abc import ABC, abstractmethod
9
+ from dataclasses import dataclass
10
+ from enum import auto
11
+ from functools import cached_property
12
+ from itertools import combinations
13
+ from typing import Union
14
+
15
+ import numpy as np
16
+ import scipy.sparse.csgraph as ssc
17
+ import scipy.stats as st
18
+ import sklearn.feature_selection as sf
19
+
20
+ from skfolio.utils.tools import AutoEnum
21
+
22
+
23
class DependenceMethod(AutoEnum):
    """Available measures of dependence between two variables.

    Attributes
    ----------
    KENDALL_TAU
        Kendall's tau correlation coefficient.

    MUTUAL_INFORMATION
        Mutual information, estimated with a k-nearest neighbors method.

    WASSERSTEIN_DISTANCE
        Wasserstein (Earth Mover's) distance.
    """

    KENDALL_TAU = auto()
    MUTUAL_INFORMATION = auto()
    WASSERSTEIN_DISTANCE = auto()
42
+
43
+
44
@dataclass
class EdgeCondSets:
    """
    Container for conditioning sets associated with an edge in an R-vine.

    Attributes
    ----------
    conditioned : tuple[int, int]
        A tuple of conditioned variable indices.

    conditioning : set[int]
        A set of conditioning variable indices.
    """

    conditioned: tuple[int, int]
    conditioning: set[int]

    def to_set(self) -> set[int]:
        """Union of conditioned and conditioning sets."""
        return set(self.conditioned) | self.conditioning

    def __add__(self, other: "EdgeCondSets") -> "EdgeCondSets":
        """Combine two EdgeCondSets, merging conditioned and conditioning sets.

        The new conditioning set is the intersection of the two complete sets
        and the new conditioned pair is their symmetric difference.

        Parameters
        ----------
        other : EdgeCondSets
            The conditioning sets to combine with.

        Returns
        -------
        EdgeCondSets
            The combined conditioning sets.

        Raises
        ------
        TypeError
            If `other` is not an EdgeCondSets.

        ValueError
            If the two complete sets do not differ by exactly two variables, in
            which case no valid conditioned pair can be formed.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(
                f"Cannot add a EdgeCondSets with an object of type {type(other)}"
            )
        s1 = self.to_set()
        s2 = other.to_set()
        conditioning = s1 & s2
        conditioned = tuple(s1 ^ s2)
        # Without this check an empty difference fails with an opaque IndexError
        # below and a larger one silently yields an invalid conditioned tuple.
        if len(conditioned) != 2:
            raise ValueError(
                "Cannot combine EdgeCondSets: expected exactly 2 conditioned "
                f"variables but got {len(conditioned)} ({self!r} and {other!r})"
            )
        # maintain order: set iteration order is arbitrary, so make sure the
        # variable belonging to `other.conditioned` comes last
        if conditioned[0] in other.conditioned:
            conditioned = conditioned[::-1]
        return self.__class__(conditioned=conditioned, conditioning=conditioning)

    def __repr__(self) -> str:
        """String representation of the EdgeCondSets."""
        if self.conditioning:
            return f"{self.conditioned} | {self.conditioning}"
        return str(self.conditioned)
85
+
86
+
87
class BaseNode(ABC):
    """Abstract parent of every node of an R-vine tree.

    Parameters
    ----------
    ref : int or Edge
        For RootNode: reference of the variable index.
        For ChildNode: reference of the edge in the previous tree.

    Attributes
    ----------
    edges : set[Edge]
        Edges attached to this node; empty at construction.

    tree : Tree or None
        The Tree containing this Node; None at construction.
    """

    def __init__(self, ref: Union[int, "Edge"]):
        # The reference is kept private and exposed read-only through `ref`.
        self._ref = ref
        self.tree: Tree | None = None  # Reference to the Tree containing this Node
        self.edges: set[Edge] = set()

    @property
    def ref(self) -> Union[int, "Edge"]:
        """Reference of this node (no setter is defined)."""
        return self._ref

    @abstractmethod
    def clear_cache(self, **kwargs):
        """Reset the values cached on the node; implemented by subclasses."""
        ...

    def __repr__(self) -> str:
        """Return ``Node(<ref>)``."""
        reference = self.ref
        return f"Node({reference})"
123
+
124
+
125
class RootNode(BaseNode):
    """Node of the first tree of the R-vine, standing for a single variable.

    Parameters
    ----------
    ref : int
        The reference variable index.

    central : bool
        True if the node is central; otherwise, False.

    pseudo_values : ndarray, optional
        The pseudo-values of the Root Node.

    Attributes
    ----------
    edges : set[Edge]
        The set of edges attached to this node.

    tree : Tree
        The Tree containing this Node.
    """

    def __init__(
        self,
        ref: int,
        central: bool,
        pseudo_values: np.ndarray | None = None,
    ):
        super().__init__(ref=ref)
        self.pseudo_values = pseudo_values
        self.central = central

    def clear_cache(self, **kwargs):
        """Drop the stored pseudo-values.

        Keyword arguments are accepted and ignored.
        """
        self.pseudo_values = None
158
+
159
+
160
class ChildNode(BaseNode):
    """Node of a tree above the first one.

    Such a node is built from an edge of the previous tree.

    Parameters
    ----------
    ref : Edge
        The reference edge in the previous tree.

    Attributes
    ----------
    edges : set[Edge]
        The set of edges attached to this node.

    tree : Tree
        The Tree containing this Node.
    """

    def __init__(self, ref: "Edge"):
        super().__init__(ref=ref)
        # Back-pointer: the edge records the node that was built from it.
        ref.ref_node = self
        self._central: bool | None = None
        # Cached margins.
        self._u: np.ndarray | None = None
        self._v: np.ndarray | None = None
        # Accesses made outside counting mode since the last cache release.
        self._u_count: int = 0
        self._v_count: int = 0
        # Accesses made while the tree was in counting mode.
        self._u_count_total: int = 0
        self._v_count_total: int = 0

    @property
    def central(self) -> bool:
        """Whether the node is central.

        The value is read once from the `strongly_central` attribute of the
        reference edge and then memoized.

        Returns
        -------
        central: bool
            True if the node is central; otherwise, False.
        """
        memo = self._central
        if memo is None:
            memo = self.ref.strongly_central
            self._central = memo
        return memo

    @property
    def u(self) -> np.ndarray:
        """First margin value (u) of the node.

        Outside counting mode it is the partial derivative of the reference
        edge's copula, called with ``first_margin=False``. In counting mode
        (``tree.is_count_visits``) only the access is recorded and a NaN
        placeholder is stored.

        Returns
        -------
        u : ndarray
            The u values for this node.
        """
        counting = self.tree is not None and self.tree.is_count_visits

        # Record the access in the counter matching the current mode.
        if counting:
            self._u_count_total += 1
        else:
            self._u_count += 1

        if self._u is None:
            # get_X is called before the mode is inspected, hence in both modes.
            X = self.ref.get_X()
            self._u = (
                np.array([np.nan])
                if counting
                else self.ref.copula.partial_derivative(X, first_margin=False)
            )

        result = self._u

        # Release the cache once the accesses recorded in counting mode have all
        # been replayed.
        all_replayed = (
            self._u_count_total != 0 and self._u_count == self._u_count_total
        )
        if all_replayed and not counting:
            self._u = None
            self._u_count = 0

        return result

    @u.setter
    def u(self, value: np.ndarray) -> None:
        self._u = value

    @property
    def v(self) -> np.ndarray:
        """Second margin value (v) of the node.

        Outside counting mode it is the partial derivative of the reference
        edge's copula, called with ``first_margin=True``. In counting mode
        (``tree.is_count_visits``) only the access is recorded and a NaN
        placeholder is stored.

        Returns
        -------
        v : ndarray
            The v values for this node.
        """
        counting = self.tree is not None and self.tree.is_count_visits

        # Record the access in the counter matching the current mode.
        if counting:
            self._v_count_total += 1
        else:
            self._v_count += 1

        if self._v is None:
            # get_X is called before the mode is inspected, hence in both modes.
            X = self.ref.get_X()
            self._v = (
                np.array([np.nan])
                if counting
                else self.ref.copula.partial_derivative(X, first_margin=True)
            )

        result = self._v

        # Release the cache once the accesses recorded in counting mode have all
        # been replayed.
        all_replayed = (
            self._v_count_total != 0 and self._v_count == self._v_count_total
        )
        if all_replayed and not counting:
            self._v = None
            self._v_count = 0

        return result

    @v.setter
    def v(self, value: np.ndarray):
        self._v = value

    def get_var(self, is_left: bool) -> int:
        """Variable index associated with this node.

        It is picked from the conditioned pair of the reference edge.

        Parameters
        ----------
        is_left : bool
            True selects the first conditioned variable, False the second.

        Returns
        -------
        var : int
            The variable index corresponding to this node.

        Raises
        ------
        ValueError
            If `is_left` is None.
        """
        if is_left is None:
            raise ValueError("is_left cannot be None for Child Nodes")
        position = 0 if is_left else 1
        return self.ref.cond_sets.conditioned[position]

    def clear_cache(self, clear_count: bool):
        """Drop the cached margins and, optionally, the visit counters.

        Parameters
        ----------
        clear_count : bool
            If True, the visit counts are also reset.
        """
        self._u = None
        self._v = None
        if not clear_count:
            return
        self._u_count = 0
        self._v_count = 0
        self._u_count_total = 0
        self._v_count_total = 0
325
+
326
+
327
class Edge:
    """
    Edge of an R-vine tree, linking two nodes of the same tree.

    Besides its two end nodes, an edge holds the dependence method used to
    weight it, the copula attached to it and the conditioning sets it induces.

    Attributes
    ----------
    node1 : RootNode | ChildNode
        The first node in the edge.

    node2 : RootNode | ChildNode
        The second node in the edge.

    dependence_method : DependenceMethod
        The method used to measure dependence between the two nodes.

    copula : object or None
        The copula of this edge; None at construction.

    ref_node : Node or None
        A pointer to the node in the next tree constructed from this edge;
        None at construction.
    """

    def __init__(
        self,
        node1: RootNode | ChildNode,
        node2: RootNode | ChildNode,
        dependence_method: DependenceMethod = DependenceMethod.KENDALL_TAU,
    ):
        self.node1 = node1
        self.node2 = node2
        self.dependence_method = dependence_method
        self.ref_node = None  # Pointer to the next tree Node
        self.copula = None

    @cached_property
    def weakly_central(self) -> bool:
        """Whether at least one of the two nodes is central."""
        return self.node1.central or self.node2.central

    @cached_property
    def strongly_central(self) -> bool:
        """Whether both nodes are central."""
        return self.node1.central and self.node2.central

    @cached_property
    def dependence(self) -> float:
        """Dependence measure between the two nodes.

        It is evaluated on the edge data (`get_X`) with the edge's dependence
        method.
        """
        return _dependence(self.get_X(), dependence_method=self.dependence_method)

    @cached_property
    def cond_sets(self) -> EdgeCondSets:
        """Conditioning sets of the edge.

        Between root nodes, the conditioned pair is made of the two variable
        indices and the conditioning set is empty. Otherwise, the conditioning
        sets of the two reference edges from the previous tree are added
        together.
        """
        first, second = self.node1, self.node2
        if not isinstance(first, RootNode):
            return first.ref.cond_sets + second.ref.cond_sets
        return EdgeCondSets(conditioned=(first.ref, second.ref), conditioning=set())

    def ref_to_nodes(self):
        """Register this edge in the edge set of both of its nodes."""
        for node in (self.node1, self.node2):
            node.edges.add(self)

    def get_X(self) -> np.ndarray:
        """Bivariate pseudo-observation data associated with the edge.

        Between root nodes, the two columns are the pseudo-values of node1 and
        node2. Otherwise, each column is the `u` or `v` margin of the matching
        node, selected from the position of the node shared by the two
        reference edges.

        Returns
        -------
        X : ndarray of shape (n_observations, 2)
            The bivariate pseudo-observation data corresponding to this edge.
        """
        if isinstance(self.node1, RootNode):
            columns = [self.node1.pseudo_values, self.node2.pseudo_values]
        else:
            is_left1, is_left2 = self.node1.ref.shared_node_is_left(self.node2.ref)
            # node1's margin is read before node2's, as the margin properties
            # update visit counters on access.
            first = self.node1.v if is_left1 else self.node1.u
            second = self.node2.v if is_left2 else self.node2.u
            columns = [first, second]
        return np.stack(columns).T

    def shared_node_is_left(self, other: "Edge") -> tuple[bool, bool]:
        """Position of the node shared by this edge and another edge.

        Parameters
        ----------
        other : Edge
            Another edge to compare with.

        Returns
        -------
        is_left1, is_left2 : tuple[bool, bool]
            is_left1 is True if the shared node is the first node of self and
            is_left2 is True if the shared node is the first node of other.

        Raises
        ------
        ValueError
            If none of the three supported configurations matches, which
            includes the case where the first node of self is the second node
            of other.
        """
        if self.node1 == other.node1:
            return True, True
        if self.node2 == other.node1:
            return False, True
        if self.node2 != other.node2:
            # Remaining possibility: self.node1 == other.node2
            raise ValueError("Edges are not correctly ordered")
        return False, False

    def share_one_node(self, other: "Edge") -> bool:
        """Check whether two edges share exactly one node.

        Parameters
        ----------
        other : Edge
            Another edge to compare with.

        Returns
        -------
        bool
            True if the two edges share exactly one node; otherwise, False.
        """
        own_nodes = {self.node1, self.node2}
        other_nodes = {other.node1, other.node2}
        return len(own_nodes & other_nodes) == 1

    def __repr__(self) -> str:
        """Conditioning sets, followed by the copula representation if any."""
        if self.copula is not None:
            return f"Edge({self.cond_sets}, {self.copula.fitted_repr})"
        return f"Edge({self.cond_sets})"
480
+
481
+
482
class Tree:
    """
    One level (k) of the R-vine structure: a list of nodes and the edges
    connecting them.

    Parameters
    ----------
    level : int
        The tree level (k) in the R-vine.

    nodes : list[Node]
        A list of Node objects representing the nodes in this tree.

    Attributes
    ----------
    edges : list[Edge]
        The list of edges in the Tree; None until `set_edges_from_mst` is
        called.

    is_count_visits : bool
        Whether to count the number of visit of each Node during sampling.
    """

    def __init__(self, level: int, nodes: list[RootNode | ChildNode]):
        self.level = level
        self._nodes = nodes
        # Give every node a pointer back to its tree.
        for member in nodes:
            member.tree = self
        self.edges = None
        self.is_count_visits: bool = False

    @property
    def nodes(self) -> list[RootNode | ChildNode]:
        """Nodes of the tree (no setter is defined)."""
        return self._nodes

    def set_edges_from_mst(self, dependence_method: DependenceMethod) -> None:
        """Select the tree edges with a Maximum Spanning Tree (MST) over the
        pairwise dependence of the nodes.

        At level 0 every pair of nodes is a candidate edge; at higher levels
        only the pairs whose reference edges share exactly one node are. When
        at least one candidate is (weakly) central, the weight of every
        central candidate is increased so that the MST favors it.

        Parameters
        ----------
        dependence_method : DependenceMethod
            The method used to compute the dependence measure between nodes (e.g.,
            Kendall's tau).

        Returns
        -------
        None

        Raises
        ------
        RuntimeError
            If a dependence measure is NaN.
        """
        members = self.nodes
        n_members = len(members)
        weights = np.zeros((n_members, n_members))
        candidates = {}
        has_central = False
        for i, j in combinations(range(n_members), 2):
            left = members[i]
            right = members[j]
            if self.level != 0 and not left.ref.share_one_node(right.ref):
                continue
            candidate = Edge(
                node1=left, node2=right, dependence_method=dependence_method
            )
            if not has_central and candidate.weakly_central:
                has_central = True
            # The small constant keeps the weight non-zero when the dependence
            # is 0, so that a valid MST can still be built.
            weights[i, j] = abs(candidate.dependence) + 1e-5
            candidates[(i, j)] = candidate

        if np.isnan(weights).any():
            raise RuntimeError("dependence_matrix contains NaNs")

        if has_central:
            top = np.max(weights)
            for (i, j), candidate in candidates.items():
                if not candidate.weakly_central:
                    continue
                multiplier = 3 if candidate.strongly_central else 2
                weights[i, j] += multiplier * top

        # The weights are negated so that the minimum spanning tree routine
        # returns the maximum spanning tree.
        mst = ssc.minimum_spanning_tree(-weights, overwrite=True)

        # The non-zero entries of the result are the selected edges.
        rows, cols = mst.nonzero()
        selected = []
        for i, j in zip(rows, cols, strict=True):
            chosen = candidates[(i, j)]
            # connect Nodes to Edges
            chosen.ref_to_nodes()
            selected.append(chosen)

        self.edges = selected

    def clear_cache(self, clear_count: bool = True):
        """Call `clear_cache` on every node, forwarding `clear_count`."""
        for member in self.nodes:
            member.clear_cache(clear_count=clear_count)

    def __repr__(self):
        """Return ``Tree(level <level>)``."""
        return f"Tree(level {self.level})"
592
+
593
+
594
def _dependence(X, dependence_method: DependenceMethod) -> float:
    """Measure the dependence between the two columns of X.

    Parameters
    ----------
    X : array-like of shape (n_observations, 2)
        A 2D array of bivariate inputs (u, v), where u and v are assumed to lie in
        [0, 1].

    dependence_method : DependenceMethod
        The method to use for measuring dependence. Options are:
        - DependenceMethod.KENDALL_TAU
        - DependenceMethod.MUTUAL_INFORMATION
        - DependenceMethod.WASSERSTEIN_DISTANCE

    Returns
    -------
    dependence : float
        The computed dependence measure.

    Raises
    ------
    ValueError
        If X does not have exactly 2 columns or if an unsupported dependence method is
        provided.
    """
    X = np.asarray(X)
    if not (X.ndim == 2 and X.shape[1] == 2):
        raise ValueError("X must be a 2D array with exactly 2 columns.")

    first, second = X[:, 0], X[:, 1]
    if dependence_method == DependenceMethod.KENDALL_TAU:
        return st.kendalltau(first, second).statistic
    if dependence_method == DependenceMethod.MUTUAL_INFORMATION:
        # The estimator expects a 2D feature matrix, hence the reshape.
        return sf.mutual_info_regression(first.reshape(-1, 1), second)[0]
    if dependence_method == DependenceMethod.WASSERSTEIN_DISTANCE:
        return st.wasserstein_distance(first, second)
    raise ValueError(f"Dependence method {dependence_method} not valid")