skfolio 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skfolio/__init__.py +7 -7
- skfolio/cluster/__init__.py +2 -2
- skfolio/cluster/_hierarchical.py +2 -2
- skfolio/datasets/__init__.py +3 -3
- skfolio/datasets/_base.py +2 -2
- skfolio/datasets/data/__init__.py +1 -0
- skfolio/distance/__init__.py +4 -4
- skfolio/distance/_base.py +2 -2
- skfolio/distance/_distance.py +11 -10
- skfolio/distribution/__init__.py +56 -0
- skfolio/distribution/_base.py +203 -0
- skfolio/distribution/copula/__init__.py +35 -0
- skfolio/distribution/copula/_base.py +456 -0
- skfolio/distribution/copula/_clayton.py +539 -0
- skfolio/distribution/copula/_gaussian.py +407 -0
- skfolio/distribution/copula/_gumbel.py +560 -0
- skfolio/distribution/copula/_independent.py +196 -0
- skfolio/distribution/copula/_joe.py +609 -0
- skfolio/distribution/copula/_selection.py +111 -0
- skfolio/distribution/copula/_student_t.py +486 -0
- skfolio/distribution/copula/_utils.py +509 -0
- skfolio/distribution/multivariate/__init__.py +11 -0
- skfolio/distribution/multivariate/_base.py +241 -0
- skfolio/distribution/multivariate/_utils.py +632 -0
- skfolio/distribution/multivariate/_vine_copula.py +1254 -0
- skfolio/distribution/univariate/__init__.py +19 -0
- skfolio/distribution/univariate/_base.py +308 -0
- skfolio/distribution/univariate/_gaussian.py +136 -0
- skfolio/distribution/univariate/_johnson_su.py +152 -0
- skfolio/distribution/univariate/_normal_inverse_gaussian.py +153 -0
- skfolio/distribution/univariate/_selection.py +85 -0
- skfolio/distribution/univariate/_student_t.py +144 -0
- skfolio/exceptions.py +8 -8
- skfolio/measures/__init__.py +24 -24
- skfolio/measures/_enums.py +7 -7
- skfolio/measures/_measures.py +4 -7
- skfolio/metrics/__init__.py +2 -0
- skfolio/metrics/_scorer.py +4 -4
- skfolio/model_selection/__init__.py +4 -4
- skfolio/model_selection/_combinatorial.py +15 -12
- skfolio/model_selection/_validation.py +2 -2
- skfolio/model_selection/_walk_forward.py +3 -3
- skfolio/moments/__init__.py +11 -11
- skfolio/moments/covariance/__init__.py +6 -6
- skfolio/moments/covariance/_base.py +1 -1
- skfolio/moments/covariance/_denoise_covariance.py +3 -2
- skfolio/moments/covariance/_detone_covariance.py +3 -2
- skfolio/moments/covariance/_empirical_covariance.py +3 -2
- skfolio/moments/covariance/_ew_covariance.py +3 -2
- skfolio/moments/covariance/_gerber_covariance.py +3 -2
- skfolio/moments/covariance/_graphical_lasso_cv.py +1 -1
- skfolio/moments/covariance/_implied_covariance.py +3 -8
- skfolio/moments/covariance/_ledoit_wolf.py +1 -1
- skfolio/moments/covariance/_oas.py +1 -1
- skfolio/moments/covariance/_shrunk_covariance.py +1 -1
- skfolio/moments/expected_returns/__init__.py +2 -2
- skfolio/moments/expected_returns/_base.py +1 -1
- skfolio/moments/expected_returns/_empirical_mu.py +3 -2
- skfolio/moments/expected_returns/_equilibrium_mu.py +3 -2
- skfolio/moments/expected_returns/_ew_mu.py +3 -2
- skfolio/moments/expected_returns/_shrunk_mu.py +4 -3
- skfolio/optimization/__init__.py +12 -10
- skfolio/optimization/_base.py +2 -2
- skfolio/optimization/cluster/__init__.py +3 -1
- skfolio/optimization/cluster/_nco.py +10 -9
- skfolio/optimization/cluster/hierarchical/__init__.py +3 -1
- skfolio/optimization/cluster/hierarchical/_base.py +1 -2
- skfolio/optimization/cluster/hierarchical/_herc.py +4 -3
- skfolio/optimization/cluster/hierarchical/_hrp.py +4 -3
- skfolio/optimization/convex/__init__.py +5 -3
- skfolio/optimization/convex/_base.py +10 -9
- skfolio/optimization/convex/_distributionally_robust.py +8 -5
- skfolio/optimization/convex/_maximum_diversification.py +8 -6
- skfolio/optimization/convex/_mean_risk.py +10 -8
- skfolio/optimization/convex/_risk_budgeting.py +6 -4
- skfolio/optimization/ensemble/__init__.py +2 -0
- skfolio/optimization/ensemble/_base.py +2 -2
- skfolio/optimization/ensemble/_stacking.py +3 -3
- skfolio/optimization/naive/__init__.py +3 -1
- skfolio/optimization/naive/_naive.py +4 -3
- skfolio/population/__init__.py +2 -0
- skfolio/population/_population.py +34 -7
- skfolio/portfolio/__init__.py +1 -1
- skfolio/portfolio/_base.py +43 -8
- skfolio/portfolio/_multi_period_portfolio.py +3 -2
- skfolio/portfolio/_portfolio.py +5 -4
- skfolio/pre_selection/__init__.py +3 -1
- skfolio/pre_selection/_drop_correlated.py +3 -3
- skfolio/pre_selection/_select_complete.py +31 -30
- skfolio/pre_selection/_select_k_extremes.py +3 -3
- skfolio/pre_selection/_select_non_dominated.py +3 -3
- skfolio/pre_selection/_select_non_expiring.py +8 -6
- skfolio/preprocessing/__init__.py +2 -0
- skfolio/preprocessing/_returns.py +2 -2
- skfolio/prior/__init__.py +7 -3
- skfolio/prior/_base.py +2 -2
- skfolio/prior/_black_litterman.py +7 -4
- skfolio/prior/_empirical.py +5 -2
- skfolio/prior/_factor_model.py +10 -5
- skfolio/prior/_synthetic_data.py +239 -0
- skfolio/synthetic_returns/__init__.py +1 -0
- skfolio/typing.py +7 -7
- skfolio/uncertainty_set/__init__.py +7 -5
- skfolio/uncertainty_set/_base.py +5 -4
- skfolio/uncertainty_set/_bootstrap.py +1 -1
- skfolio/uncertainty_set/_empirical.py +1 -1
- skfolio/utils/__init__.py +1 -0
- skfolio/utils/bootstrap.py +2 -2
- skfolio/utils/equations.py +13 -10
- skfolio/utils/sorting.py +2 -2
- skfolio/utils/stats.py +15 -15
- skfolio/utils/tools.py +86 -22
- {skfolio-0.6.0.dist-info → skfolio-0.8.0.dist-info}/METADATA +122 -46
- skfolio-0.8.0.dist-info/RECORD +120 -0
- {skfolio-0.6.0.dist-info → skfolio-0.8.0.dist-info}/WHEEL +1 -1
- skfolio-0.6.0.dist-info/RECORD +0 -95
- {skfolio-0.6.0.dist-info → skfolio-0.8.0.dist-info/licenses}/LICENSE +0 -0
- {skfolio-0.6.0.dist-info → skfolio-0.8.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,632 @@
|
|
1
|
+
"""Utils module for multivariate distribution."""
|
2
|
+
|
3
|
+
# Copyright (c) 2025
|
4
|
+
# Author: Hugo Delatte <delatte.hugo@gmail.com>
|
5
|
+
# Credits: Matteo Manzi, Vincent Maladière, Carlo Nicolini
|
6
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
7
|
+
|
8
|
+
from abc import ABC, abstractmethod
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from enum import auto
|
11
|
+
from functools import cached_property
|
12
|
+
from itertools import combinations
|
13
|
+
from typing import Union
|
14
|
+
|
15
|
+
import numpy as np
|
16
|
+
import scipy.sparse.csgraph as ssc
|
17
|
+
import scipy.stats as st
|
18
|
+
import sklearn.feature_selection as sf
|
19
|
+
|
20
|
+
from skfolio.utils.tools import AutoEnum
|
21
|
+
|
22
|
+
|
23
|
+
class DependenceMethod(AutoEnum):
    """Bivariate dependence measures available to rank candidate R-vine edges.

    Attributes
    ----------
    KENDALL_TAU
        Kendall's tau rank correlation coefficient.

    MUTUAL_INFORMATION
        Mutual information, estimated with a k-nearest-neighbors method.

    WASSERSTEIN_DISTANCE
        Wasserstein (Earth Mover's) distance between the two margins.
    """

    # Keep the declaration order unchanged: depending on how AutoEnum generates
    # values, ``auto()`` may be order-dependent.
    KENDALL_TAU = auto()
    MUTUAL_INFORMATION = auto()
    WASSERSTEIN_DISTANCE = auto()
|
42
|
+
|
43
|
+
|
44
|
+
@dataclass
class EdgeCondSets:
    """Conditioned and conditioning variable sets of an R-vine edge.

    An edge ``(a, b) | D`` models the dependence between variables ``a`` and ``b``
    conditionally on the variables in ``D``.

    Attributes
    ----------
    conditioned : tuple[int, int]
        Indices of the two conditioned variables ``(a, b)``.

    conditioning : set[int]
        Indices of the conditioning variables ``D`` (empty for first-tree edges).
    """

    conditioned: tuple[int, int]
    conditioning: set[int]

    def to_set(self) -> set[int]:
        """Return the union of the conditioned and conditioning variable indices."""
        return set(self.conditioned) | self.conditioning

    def __add__(self, other: "EdgeCondSets") -> "EdgeCondSets":
        """Combine the sets of two adjacent edges into those of the next-tree edge.

        The new conditioning set is the intersection of both edges' variables. The
        new conditioned pair is made of the single variable only present in
        ``self`` followed by the single variable only present in ``other``.

        Parameters
        ----------
        other : EdgeCondSets
            Conditioning sets of the adjacent edge.

        Returns
        -------
        cond_sets : EdgeCondSets
            Conditioning sets of the edge joining ``self`` and ``other``.

        Raises
        ------
        TypeError
            If ``other`` is not an ``EdgeCondSets``.

        ValueError
            If each edge does not contribute exactly one variable absent from the
            other one (the two edges do not satisfy the proximity condition).
        """
        if not isinstance(other, self.__class__):
            raise TypeError(
                f"Cannot add a EdgeCondSets with an object of type {type(other)}"
            )
        s1 = self.to_set()
        s2 = other.to_set()
        only_self = s1 - s2
        only_other = s2 - s1
        # Adjacent edges share all their variables but one on each side. Picking
        # each side explicitly keeps the (self, other) order without relying on
        # set iteration order, and rejects malformed combinations early instead
        # of producing an IndexError or a conditioned "pair" of the wrong length.
        if len(only_self) != 1 or len(only_other) != 1:
            raise ValueError(
                "Cannot combine EdgeCondSets: each edge must contribute exactly one "
                f"variable not present in the other one, got {self} and {other}"
            )
        (first,) = only_self
        (second,) = only_other
        return self.__class__(conditioned=(first, second), conditioning=s1 & s2)

    def __repr__(self) -> str:
        """Return ``"(a, b) | {D}"``, or ``"(a, b)"`` when ``D`` is empty."""
        if self.conditioning:
            return f"{self.conditioned} | {self.conditioning}"
        return str(self.conditioned)
|
85
|
+
|
86
|
+
|
87
|
+
class BaseNode(ABC):
    """Abstract node of an R-vine tree.

    Parameters
    ----------
    ref : int or Edge
        What the node stands for: a variable index for a ``RootNode``, or the
        edge of the previous tree for a ``ChildNode``.

    Attributes
    ----------
    edges : set[Edge]
        Edges of the current tree that are attached to this node.

    tree : Tree
        Tree this node belongs to (``None`` until a Tree adopts the node).
    """

    def __init__(self, ref: Union[int, "Edge"]):
        # Back-pointer to the owning Tree; assigned by Tree.__init__.
        self.tree: Tree | None = None
        # Filled by Edge.ref_to_nodes once the tree edges are selected.
        self.edges: set[Edge] = set()
        self._ref = ref

    @property
    def ref(self) -> Union[int, "Edge"]:
        """Variable index or previous-tree edge the node stands for (read-only)."""
        return self._ref

    @abstractmethod
    def clear_cache(self, **kwargs):
        """Drop the values cached by the node."""

    def __repr__(self) -> str:
        """Return ``"Node(<ref>)"``."""
        return f"Node({self.ref})"
|
123
|
+
|
124
|
+
|
125
|
+
class RootNode(BaseNode):
    """Node of the first R-vine tree, standing for a single variable.

    Parameters
    ----------
    ref : int
        Index of the variable represented by the node.

    central : bool
        Whether the variable is flagged as central.

    pseudo_values : ndarray, optional
        Pseudo-observations of the variable.

    Attributes
    ----------
    edges : set[Edge]
        Edges of the tree that are attached to this node.

    tree : Tree
        Tree this node belongs to.
    """

    def __init__(
        self, ref: int, central: bool, pseudo_values: np.ndarray | None = None
    ):
        super().__init__(ref=ref)
        self.pseudo_values = pseudo_values
        self.central = central

    def clear_cache(self, **kwargs):
        """Release the stored pseudo-values.

        Keyword arguments (such as the ``clear_count`` passed by
        ``Tree.clear_cache``) are accepted for interface compatibility and ignored.
        """
        self.pseudo_values = None
|
158
|
+
|
159
|
+
|
160
|
+
class ChildNode(BaseNode):
    """Node of an R-vine tree other than the first one.

    A child node stands for an edge of the previous tree. Its two margins, ``u``
    and ``v``, are partial derivatives of that edge's copula evaluated on the
    edge data.

    Parameters
    ----------
    ref : Edge
        Edge of the previous tree that this node stands for.

    Attributes
    ----------
    edges : set[Edge]
        Edges of the current tree that are attached to this node.

    tree : Tree
        Tree this node belongs to.
    """

    def __init__(self, ref: "Edge"):
        super().__init__(ref=ref)
        # Back-pointer so the previous-tree edge knows the node built from it.
        ref.ref_node = self
        self._central: bool | None = None
        # Cached margins (a NaN placeholder while visits are being counted).
        self._u: np.ndarray | None = None
        self._v: np.ndarray | None = None
        # Reads since the cached margin was last released (normal mode).
        self._u_count: int = 0
        self._v_count: int = 0
        # Reads recorded while ``tree.is_count_visits`` was True.
        self._u_count_total: int = 0
        self._v_count_total: int = 0

    @property
    def central(self) -> bool:
        """Whether the node is central.

        Taken from the strong centrality of the referenced edge on first access
        and cached afterwards.

        Returns
        -------
        central : bool
            True if the node is central, False otherwise.
        """
        if self._central is None:
            self._central = self.ref.strongly_central
        return self._central

    @property
    def u(self) -> np.ndarray:
        """First margin: ``ref.copula.partial_derivative(X, first_margin=False)``.

        ``X`` is the referenced edge's data (``ref.get_X()``). When the owning tree
        has ``is_count_visits`` set, the read is only counted and a NaN placeholder
        is cached instead of evaluating the copula; ``ref.get_X()`` is still
        called, which may in turn read (and count) the margins of previous-tree
        child nodes. In normal mode, the cached value is released once the number
        of reads reaches the total recorded while counting.

        NOTE(review): a NaN placeholder left over from a counting pass is returned
        as-is unless the cache is cleared (e.g. ``clear_cache(clear_count=False)``)
        before the normal pass; presumably the caller does this — confirm.

        Returns
        -------
        u : ndarray
            First margin of the node.
        """
        counting = self.tree is not None and self.tree.is_count_visits
        if counting:
            self._u_count_total += 1
        else:
            self._u_count += 1

        if self._u is None:
            X = self.ref.get_X()
            self._u = (
                np.array([np.nan])
                if counting
                else self.ref.copula.partial_derivative(X, first_margin=False)
            )

        result = self._u
        all_reads_served = (
            not counting
            and self._u_count_total != 0
            and self._u_count == self._u_count_total
        )
        if all_reads_served:
            # Every counted read has now been served: drop the cached margin.
            self._u = None
            self._u_count = 0
        return result

    @u.setter
    def u(self, value: np.ndarray) -> None:
        self._u = value

    @property
    def v(self) -> np.ndarray:
        """Second margin: ``ref.copula.partial_derivative(X, first_margin=True)``.

        Caching and visit counting behave exactly as for ``u``, using the ``v``
        cache and counters.

        Returns
        -------
        v : ndarray
            Second margin of the node.
        """
        counting = self.tree is not None and self.tree.is_count_visits
        if counting:
            self._v_count_total += 1
        else:
            self._v_count += 1

        if self._v is None:
            X = self.ref.get_X()
            self._v = (
                np.array([np.nan])
                if counting
                else self.ref.copula.partial_derivative(X, first_margin=True)
            )

        result = self._v
        all_reads_served = (
            not counting
            and self._v_count_total != 0
            and self._v_count == self._v_count_total
        )
        if all_reads_served:
            # Every counted read has now been served: drop the cached margin.
            self._v = None
            self._v_count = 0
        return result

    @v.setter
    def v(self, value: np.ndarray):
        self._v = value

    def get_var(self, is_left: bool) -> int:
        """Return one of the two conditioned variables of the referenced edge.

        Parameters
        ----------
        is_left : bool
            If True, return the first (left) conditioned variable; otherwise the
            second (right) one.

        Returns
        -------
        var : int
            Index of the selected conditioned variable.

        Raises
        ------
        ValueError
            If ``is_left`` is None.
        """
        if is_left is None:
            raise ValueError("is_left cannot be None for Child Nodes")
        conditioned = self.ref.cond_sets.conditioned
        return conditioned[0] if is_left else conditioned[1]

    def clear_cache(self, clear_count: bool):
        """Drop the cached margins ``u`` and ``v``, optionally with the counters.

        Parameters
        ----------
        clear_count : bool
            If True, the visit counters (current and total) are reset as well.
        """
        self._u = self._v = None
        if clear_count:
            self._u_count = self._v_count = 0
            self._u_count_total = self._v_count_total = 0
|
325
|
+
|
326
|
+
|
327
|
+
class Edge:
    """Edge of an R-vine tree, joining two nodes of that tree.

    The edge holds its fitted pair-copula, the method used to measure the
    dependence of its data, and its conditioned/conditioning sets.

    Attributes
    ----------
    node1 : RootNode | ChildNode
        Left node of the edge.

    node2 : RootNode | ChildNode
        Right node of the edge.

    dependence_method : DependenceMethod
        Method used to measure the dependence between the two nodes.

    copula : object or None
        Copula fitted on the edge data; ``None`` until assigned.

    ref_node : Node or None
        Node of the next tree built from this edge; ``None`` until that node is
        created.
    """

    def __init__(
        self,
        node1: RootNode | ChildNode,
        node2: RootNode | ChildNode,
        dependence_method: DependenceMethod = DependenceMethod.KENDALL_TAU,
    ):
        self.node1, self.node2 = node1, node2
        self.dependence_method = dependence_method
        self.copula = None
        # Set by ChildNode.__init__ when the next tree is built from this edge.
        self.ref_node = None

    @cached_property
    def weakly_central(self) -> bool:
        """True if at least one of the two nodes is central (computed once)."""
        return self.node1.central or self.node2.central

    @cached_property
    def strongly_central(self) -> bool:
        """True if both nodes are central (computed once)."""
        return self.node1.central and self.node2.central

    @cached_property
    def dependence(self) -> float:
        """Dependence of the edge data, measured with ``dependence_method``.

        Computed on ``get_X()`` the first time it is accessed and cached.
        """
        return _dependence(self.get_X(), dependence_method=self.dependence_method)

    @cached_property
    def cond_sets(self) -> EdgeCondSets:
        """Conditioned and conditioning sets of the edge (computed once).

        Edges between child nodes combine the sets of the two previous-tree edges
        they stand for; first-tree edges simply condition the two variable indices
        on nothing.
        """
        if not isinstance(self.node1, RootNode):
            return self.node1.ref.cond_sets + self.node2.ref.cond_sets
        return EdgeCondSets(
            conditioned=(self.node1.ref, self.node2.ref), conditioning=set()
        )

    def ref_to_nodes(self):
        """Register this edge in the ``edges`` set of both of its nodes."""
        for node in (self.node1, self.node2):
            node.edges.add(self)

    def get_X(self) -> np.ndarray:
        """Return the bivariate pseudo-observations of the edge.

        For first-tree edges, the two nodes' pseudo-values are used. Otherwise,
        for each node, the margin ``v`` is used when the node shared by the two
        previous-tree edges is the left node of that node's edge, and ``u``
        otherwise.

        Returns
        -------
        X : ndarray of shape (n_observations, 2)
            Bivariate pseudo-observations of the edge.
        """
        if isinstance(self.node1, RootNode):
            left, right = self.node1.pseudo_values, self.node2.pseudo_values
        else:
            is_left1, is_left2 = self.node1.ref.shared_node_is_left(self.node2.ref)
            left = self.node1.v if is_left1 else self.node1.u
            right = self.node2.v if is_left2 else self.node2.u
        return np.stack([left, right]).T

    def shared_node_is_left(self, other: "Edge") -> tuple[bool, bool]:
        """Tell, for this edge and ``other``, whether their shared node is the left one.

        Parameters
        ----------
        other : Edge
            Edge to compare with.

        Returns
        -------
        is_left1, is_left2 : tuple[bool, bool]
            ``is_left1`` is True if the shared node is ``self.node1``;
            ``is_left2`` is True if the shared node is ``other.node1``.

        Raises
        ------
        ValueError
            If the edges share no node, or if the only shared node is
            ``self.node1`` appearing as ``other.node2`` (unsupported ordering).
        """
        checks = (
            (self.node1, other.node1, (True, True)),
            (self.node2, other.node1, (False, True)),
            (self.node2, other.node2, (False, False)),
        )
        for mine, theirs, answer in checks:
            if mine == theirs:
                return answer
        raise ValueError("Edges are not correctly ordered")

    def share_one_node(self, other: "Edge") -> bool:
        """Return True if this edge and ``other`` have exactly one node in common.

        Parameters
        ----------
        other : Edge
            Edge to compare with.

        Returns
        -------
        bool
            True if exactly one node is shared, False otherwise.
        """
        shared = {self.node1, self.node2}.intersection((other.node1, other.node2))
        return len(shared) == 1

    def __repr__(self) -> str:
        """Return the conditioning sets, plus the fitted copula when available."""
        if self.copula is not None:
            return f"Edge({self.cond_sets}, {self.copula.fitted_repr})"
        return f"Edge({self.cond_sets})"
|
480
|
+
|
481
|
+
|
482
|
+
class Tree:
    """One level (``k``) of an R-vine: a set of nodes and the edges joining them.

    Parameters
    ----------
    level : int
        Level ``k`` of the tree in the R-vine (0 for the first tree).

    nodes : list[RootNode | ChildNode]
        Nodes of the tree.

    Attributes
    ----------
    edges : list[Edge] or None
        Edges selected by ``set_edges_from_mst``; ``None`` before that call.

    is_count_visits : bool
        When True, reads of ``ChildNode.u`` / ``ChildNode.v`` on this tree's
        nodes only record visit counts instead of computing the margins.
    """

    def __init__(self, level: int, nodes: list[RootNode | ChildNode]):
        self.level = level
        self._nodes = nodes
        for node in nodes:
            node.tree = self  # back-pointer from each node to its tree
        self.edges = None
        self.is_count_visits: bool = False

    @property
    def nodes(self) -> list[RootNode | ChildNode]:
        """Nodes of the tree (read-only)."""
        return self._nodes

    def set_edges_from_mst(self, dependence_method: DependenceMethod) -> None:
        """Select the tree edges as a Maximum Spanning Tree of pairwise dependence.

        For the first tree (``level == 0``), every pair of nodes is a candidate
        edge; in later trees, only pairs whose previous-tree edges share exactly
        one node are candidates. Each candidate is weighted by
        ``|dependence| + 1e-5``. If at least one candidate is weakly central,
        weakly central candidates get ``2 * max_weight`` added and strongly
        central ones ``3 * max_weight`` so that they are preferred. The selected
        edges are attached to their nodes and stored in ``self.edges``.

        Parameters
        ----------
        dependence_method : DependenceMethod
            Method used to measure the dependence between nodes (e.g. Kendall's
            tau).

        Raises
        ------
        RuntimeError
            If any candidate's dependence measure is NaN.
        """
        n_nodes = len(self.nodes)
        weights = np.zeros((n_nodes, n_nodes))
        candidates: dict[tuple[int, int], Edge] = {}
        has_central = False
        for i, j in combinations(range(n_nodes), 2):
            node1, node2 = self.nodes[i], self.nodes[j]
            if self.level != 0 and not node1.ref.share_one_node(node2.ref):
                continue
            edge = Edge(node1=node1, node2=node2, dependence_method=dependence_method)
            has_central = has_central or edge.weakly_central
            # The small offset keeps zero-dependence pairs as edges: zero entries
            # of a dense csgraph input are treated as missing edges.
            weights[i, j] = abs(edge.dependence) + 1e-5
            candidates[(i, j)] = edge

        if np.isnan(weights).any():
            raise RuntimeError("dependence_matrix contains NaNs")

        if has_central:
            max_weight = weights.max()
            for (i, j), edge in candidates.items():
                if not edge.weakly_central:
                    continue
                multiplier = 3 if edge.strongly_central else 2
                weights[i, j] += multiplier * max_weight

        # A maximum spanning tree is the minimum spanning tree of the negated
        # weights.
        mst = ssc.minimum_spanning_tree(-weights, overwrite=True)

        selected = []
        rows, cols = mst.nonzero()
        for i, j in zip(rows, cols, strict=True):
            edge = candidates[(i, j)]
            edge.ref_to_nodes()
            selected.append(edge)
        self.edges = selected

    def clear_cache(self, clear_count: bool = True):
        """Clear the cached values of every node of the tree.

        Parameters
        ----------
        clear_count : bool, default=True
            Forwarded to each node's ``clear_cache``; child nodes then also reset
            their visit counters.
        """
        for node in self.nodes:
            node.clear_cache(clear_count=clear_count)

    def __repr__(self):
        """Return ``"Tree(level <k>)"``."""
        return f"Tree(level {self.level})"
|
592
|
+
|
593
|
+
|
594
|
+
def _dependence(X, dependence_method: DependenceMethod) -> float:
|
595
|
+
"""Compute the dependence between two variables in X using the specified method.
|
596
|
+
|
597
|
+
Parameters
|
598
|
+
----------
|
599
|
+
X : array-like of shape (n_observations, 2)
|
600
|
+
A 2D array of bivariate inputs (u, v), where u and v are assumed to lie in
|
601
|
+
[0, 1].
|
602
|
+
|
603
|
+
dependence_method : DependenceMethod
|
604
|
+
The method to use for measuring dependence. Options are:
|
605
|
+
- DependenceMethod.KENDALL_TAU
|
606
|
+
- DependenceMethod.MUTUAL_INFORMATION
|
607
|
+
- DependenceMethod.WASSERSTEIN_DISTANCE
|
608
|
+
|
609
|
+
Returns
|
610
|
+
-------
|
611
|
+
dependence : float
|
612
|
+
The computed dependence measure.
|
613
|
+
|
614
|
+
Raises
|
615
|
+
------
|
616
|
+
ValueError
|
617
|
+
If X does not have exactly 2 columns or if an unsupported dependence method is
|
618
|
+
provided.
|
619
|
+
"""
|
620
|
+
X = np.asarray(X)
|
621
|
+
if X.ndim != 2 or X.shape[1] != 2:
|
622
|
+
raise ValueError("X must be a 2D array with exactly 2 columns.")
|
623
|
+
match dependence_method:
|
624
|
+
case DependenceMethod.KENDALL_TAU:
|
625
|
+
dep = st.kendalltau(X[:, 0], X[:, 1]).statistic
|
626
|
+
case DependenceMethod.MUTUAL_INFORMATION:
|
627
|
+
dep = sf.mutual_info_regression(X[:, 0].reshape(-1, 1), X[:, 1])[0]
|
628
|
+
case DependenceMethod.WASSERSTEIN_DISTANCE:
|
629
|
+
dep = st.wasserstein_distance(X[:, 0], X[:, 1])
|
630
|
+
case _:
|
631
|
+
raise ValueError(f"Dependence method {dependence_method} not valid")
|
632
|
+
return dep
|