skfolio 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skfolio/cluster/_hierarchical.py +1 -1
- skfolio/population/_population.py +1 -3
- skfolio/portfolio/_base.py +3 -7
- skfolio/portfolio/_multi_period_portfolio.py +0 -2
- skfolio/portfolio/_portfolio.py +0 -2
- skfolio/preprocessing/_returns.py +17 -6
- skfolio/utils/stats.py +1 -1
- {skfolio-0.2.1.dist-info → skfolio-0.2.2.dist-info}/METADATA +2 -2
- {skfolio-0.2.1.dist-info → skfolio-0.2.2.dist-info}/RECORD +12 -14
- skfolio/utils/fixes/__init__.py +0 -3
- skfolio/utils/fixes/_dendrogram.py +0 -391
- {skfolio-0.2.1.dist-info → skfolio-0.2.2.dist-info}/LICENSE +0 -0
- {skfolio-0.2.1.dist-info → skfolio-0.2.2.dist-info}/WHEEL +0 -0
- {skfolio-0.2.1.dist-info → skfolio-0.2.2.dist-info}/top_level.txt +0 -0
skfolio/cluster/_hierarchical.py
CHANGED
@@ -13,8 +13,8 @@ import scipy.cluster.hierarchy as sch
|
|
13
13
|
import scipy.spatial.distance as scd
|
14
14
|
import sklearn.base as skb
|
15
15
|
import sklearn.utils.validation as skv
|
16
|
+
from plotly.figure_factory import create_dendrogram
|
16
17
|
|
17
|
-
from skfolio.utils.fixes import create_dendrogram
|
18
18
|
from skfolio.utils.stats import assert_is_distance, compute_optimal_n_clusters
|
19
19
|
from skfolio.utils.tools import AutoEnum, default_asset_names
|
20
20
|
|
@@ -20,8 +20,6 @@ from skfolio.portfolio import BasePortfolio, MultiPeriodPortfolio
|
|
20
20
|
from skfolio.utils.sorting import non_denominated_sort
|
21
21
|
from skfolio.utils.tools import deduplicate_names
|
22
22
|
|
23
|
-
pd.options.plotting.backend = "plotly"
|
24
|
-
|
25
23
|
|
26
24
|
class Population(list):
|
27
25
|
"""Population Class.
|
@@ -616,7 +614,7 @@ class Population(list):
|
|
616
614
|
df = pd.concat(cumulative_returns, axis=1).iloc[:, idx]
|
617
615
|
df.columns = deduplicate_names(names)
|
618
616
|
|
619
|
-
fig = df.plot()
|
617
|
+
fig = df.plot(backend="plotly")
|
620
618
|
fig.update_layout(
|
621
619
|
title=title,
|
622
620
|
xaxis_title="Observations",
|
skfolio/portfolio/_base.py
CHANGED
@@ -61,10 +61,6 @@ from skfolio.utils.tools import (
|
|
61
61
|
format_measure,
|
62
62
|
)
|
63
63
|
|
64
|
-
# TODO: remove and use plotly express
|
65
|
-
pd.options.plotting.backend = "plotly"
|
66
|
-
|
67
|
-
|
68
64
|
_ZERO_THRESHOLD = 1e-5
|
69
65
|
_MEASURES = {
|
70
66
|
e for enu in [PerfMeasure, RiskMeasure, ExtraRiskMeasure, RatioMeasure] for e in enu
|
@@ -988,7 +984,7 @@ class BasePortfolio:
|
|
988
984
|
yaxis_title = title
|
989
985
|
title = f"{title} (non-compounded)"
|
990
986
|
|
991
|
-
fig = df.plot()
|
987
|
+
fig = df.plot(backend="plotly")
|
992
988
|
fig.update_layout(
|
993
989
|
title=title,
|
994
990
|
xaxis_title="Observations",
|
@@ -1019,7 +1015,7 @@ class BasePortfolio:
|
|
1019
1015
|
"""
|
1020
1016
|
if idx is None:
|
1021
1017
|
idx = slice(None)
|
1022
|
-
fig = self.returns_df.iloc[idx].plot()
|
1018
|
+
fig = self.returns_df.iloc[idx].plot(backend="plotly")
|
1023
1019
|
fig.update_layout(
|
1024
1020
|
title="Returns",
|
1025
1021
|
xaxis_title="Observations",
|
@@ -1050,7 +1046,7 @@ class BasePortfolio:
|
|
1050
1046
|
"""
|
1051
1047
|
rolling = self.rolling_measure(measure=measure, window=window)
|
1052
1048
|
rolling.name = f"{measure} {window} observations"
|
1053
|
-
fig = rolling.plot()
|
1049
|
+
fig = rolling.plot(backend="plotly")
|
1054
1050
|
fig.add_hline(
|
1055
1051
|
y=getattr(self, measure.value),
|
1056
1052
|
line_width=1,
|
@@ -18,8 +18,6 @@ from skfolio.portfolio._base import BasePortfolio
|
|
18
18
|
from skfolio.portfolio._portfolio import Portfolio
|
19
19
|
from skfolio.utils.tools import deduplicate_names
|
20
20
|
|
21
|
-
pd.options.plotting.backend = "plotly"
|
22
|
-
|
23
21
|
|
24
22
|
class MultiPeriodPortfolio(BasePortfolio):
|
25
23
|
r"""Multi-Period Portfolio class.
|
skfolio/portfolio/_portfolio.py
CHANGED
@@ -4,6 +4,8 @@
|
|
4
4
|
# Author: Hugo Delatte <delatte.hugo@gmail.com>
|
5
5
|
# License: BSD 3 clause
|
6
6
|
|
7
|
+
from typing import Literal
|
8
|
+
|
7
9
|
import numpy as np
|
8
10
|
import pandas as pd
|
9
11
|
|
@@ -13,7 +15,8 @@ def prices_to_returns(
|
|
13
15
|
y: pd.DataFrame | None = None,
|
14
16
|
log_returns: bool = False,
|
15
17
|
nan_threshold: float = 1,
|
16
|
-
join:
|
18
|
+
join: Literal["left", "right", "inner", "outer", "cross"] = "outer",
|
19
|
+
drop_inceptions_nan: bool = True,
|
17
20
|
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
|
18
21
|
r"""Transforms a DataFrame of prices to linear or logarithmic returns.
|
19
22
|
|
@@ -53,13 +56,19 @@ def prices_to_returns(
|
|
53
56
|
log_returns : bool, default=True
|
54
57
|
If this is set to True, logarithmic returns are used instead of simple returns.
|
55
58
|
|
56
|
-
join : str, default=
|
59
|
+
join : str, default="outer"
|
57
60
|
The join method between `X` and `y` when `y` is provided.
|
58
61
|
|
59
62
|
nan_threshold : float, default=1.0
|
60
63
|
Drop observations (rows) that have a percentage of missing assets prices above
|
61
64
|
this threshold. The default (`1.0`) is to keep all the observations.
|
62
65
|
|
66
|
+
drop_inceptions_nan : bool, default=True
|
67
|
+
If this is set to True, observations at the beginning are dropped if any of
|
68
|
+
the asset values are missing, otherwise we keep the NaNs. This is useful when
|
69
|
+
you work with a large universe of assets with different inception dates coupled
|
70
|
+
with a pre-selection Transformer.
|
71
|
+
|
63
72
|
Returns
|
64
73
|
-------
|
65
74
|
X : DataFrame
|
@@ -82,7 +91,7 @@ def prices_to_returns(
|
|
82
91
|
else:
|
83
92
|
if not isinstance(y, pd.DataFrame):
|
84
93
|
raise TypeError("`y` must be a DataFrame")
|
85
|
-
df = X.join(y, how=
|
94
|
+
df = X.join(y, how=join)
|
86
95
|
|
87
96
|
n_observations, n_assets = X.shape
|
88
97
|
|
@@ -98,18 +107,20 @@ def prices_to_returns(
|
|
98
107
|
|
99
108
|
# Forward fill missing values
|
100
109
|
df.ffill(inplace=True)
|
101
|
-
# Drop rows
|
102
|
-
|
110
|
+
# Drop rows according to drop_inceptions_nan
|
111
|
+
# noinspection PyTypeChecker
|
112
|
+
df.dropna(how="any" if drop_inceptions_nan else "all", inplace=True)
|
103
113
|
# Drop column if all its values are missing
|
104
114
|
df.dropna(axis=1, how="all", inplace=True)
|
105
115
|
|
106
116
|
# returns
|
107
|
-
all_returns = df.pct_change().
|
117
|
+
all_returns = df.pct_change().iloc[1:]
|
108
118
|
if log_returns:
|
109
119
|
all_returns = np.log1p(all_returns)
|
110
120
|
|
111
121
|
if y is None:
|
112
122
|
return all_returns
|
123
|
+
|
113
124
|
returns = all_returns[[x for x in X.columns if x in df.columns]]
|
114
125
|
factor_returns = all_returns[[x for x in y.columns if x in df.columns]]
|
115
126
|
return returns, factor_returns
|
skfolio/utils/stats.py
CHANGED
@@ -448,7 +448,7 @@ def compute_optimal_n_clusters(distance: np.ndarray, linkage_matrix: np.ndarray)
|
|
448
448
|
"""
|
449
449
|
cut_tree = sch.cut_tree(linkage_matrix)
|
450
450
|
n = cut_tree.shape[1]
|
451
|
-
max_clusters = max(8, round(np.sqrt(n)))
|
451
|
+
max_clusters = min(n, max(8, round(np.sqrt(n))))
|
452
452
|
dispersion = []
|
453
453
|
for k in range(max_clusters):
|
454
454
|
level = cut_tree[:, n - k - 1]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: skfolio
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.2
|
4
4
|
Summary: Portfolio optimization built on top of scikit-learn
|
5
5
|
Author-email: Hugo Delatte <delatte.hugo@gmail.com>
|
6
6
|
Maintainer-email: Hugo Delatte <delatte.hugo@gmail.com>
|
@@ -62,7 +62,7 @@ Requires-Dist: pandas >=1.4.1
|
|
62
62
|
Requires-Dist: cvxpy >=1.4.1
|
63
63
|
Requires-Dist: scikit-learn >=1.3.2
|
64
64
|
Requires-Dist: joblib >=1.3.2
|
65
|
-
Requires-Dist: plotly >=5.
|
65
|
+
Requires-Dist: plotly >=5.22.0
|
66
66
|
Provides-Extra: docs
|
67
67
|
Requires-Dist: Sphinx ; extra == 'docs'
|
68
68
|
Requires-Dist: sphinx-gallery ; extra == 'docs'
|
@@ -2,7 +2,7 @@ skfolio/__init__.py,sha256=5pn5LpTz6v2j2sxGkY97cVRrSPsN3Yav9b6Uw08boEI,618
|
|
2
2
|
skfolio/exceptions.py,sha256=-XniKql9QHgfitMgHsE9UXWVPdjWpNGO2dVk2SsdPWE,662
|
3
3
|
skfolio/typing.py,sha256=yEZiCZ6UIyfYUqtfj9Kf2KA9mrjUbmxyzpH9uqVboJs,1378
|
4
4
|
skfolio/cluster/__init__.py,sha256=4g-PFB_ld9BhiQ1ZPvvAorpFbRwd_p_DkeRlulDv2Hk,251
|
5
|
-
skfolio/cluster/_hierarchical.py,sha256=
|
5
|
+
skfolio/cluster/_hierarchical.py,sha256=16INBe5HB7ALODO3RNI8ZjOYALtMZa3U_7EP1aEIxp8,12819
|
6
6
|
skfolio/datasets/__init__.py,sha256=9Tpf0Uj8wgr-g7xqvqQP4S4TUYDUNmNmK8t6lqBw2Fs,407
|
7
7
|
skfolio/datasets/_base.py,sha256=laAj8vE8evEKv6NAAJdypLrsfmlhqOJ27aP_AcpcxVQ,13952
|
8
8
|
skfolio/datasets/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -48,15 +48,15 @@ skfolio/optimization/ensemble/_stacking.py,sha256=2PodXf_JahToShSAbwQRXosEer3ESO
|
|
48
48
|
skfolio/optimization/naive/__init__.py,sha256=Dkr55R48urC-jfYN007NTbei16N91Na_EDYLVqzhGgQ,147
|
49
49
|
skfolio/optimization/naive/_naive.py,sha256=lXBUK0-JZd_30xUqnwv3IN1-yvKNVrub2A4ZVdfjmnY,5570
|
50
50
|
skfolio/population/__init__.py,sha256=rsPPMUv95aTK7vmpPeQwF8NzFuBwk6RDo5g4HNaPzNM,80
|
51
|
-
skfolio/population/_population.py,sha256=
|
51
|
+
skfolio/population/_population.py,sha256=TqyzJbz82asy9OwJmFDJqGqn9eFOC4Xn4ompGicTyJI,29124
|
52
52
|
skfolio/portfolio/__init__.py,sha256=YYtcAPmA2zeCxFGTXegg2FXcA7py6CxOX7IMTdYuXl0,586
|
53
|
-
skfolio/portfolio/_base.py,sha256=
|
54
|
-
skfolio/portfolio/_multi_period_portfolio.py,sha256=
|
55
|
-
skfolio/portfolio/_portfolio.py,sha256=
|
53
|
+
skfolio/portfolio/_base.py,sha256=XbvCdlhwP-mgnbVppoZxv8gr_IgCv8csx71LigqZ0-M,38282
|
54
|
+
skfolio/portfolio/_multi_period_portfolio.py,sha256=Zt2khkaeVwZjUKvL0NAk5kLJtfO19hup3YxCGdgk5Mk,22719
|
55
|
+
skfolio/portfolio/_portfolio.py,sha256=sSDX2HzK2KgatCgQakEhENOLE-3jSfwI_xgbiVrCGUY,31609
|
56
56
|
skfolio/pre_selection/__init__.py,sha256=VtUtDn-U-Mn_xR2k7yfld0Yb0rPhLakEAiBwUyi-4Z8,189
|
57
57
|
skfolio/pre_selection/_pre_selection.py,sha256=w84T14nKmzkgzbw5CW_AIlci741lXYxKUwB5pBjhTTI,12163
|
58
58
|
skfolio/preprocessing/__init__.py,sha256=15A1bzfPsbfxxXgGP1gstf4R0E_347Wn18z5W5jH-hk,94
|
59
|
-
skfolio/preprocessing/_returns.py,sha256=
|
59
|
+
skfolio/preprocessing/_returns.py,sha256=oo1Mm-UCHwq4ECjfmsRxWzzK1EPsuv-EEtnimvv_nXo,4345
|
60
60
|
skfolio/prior/__init__.py,sha256=jql8NTiWlykPKJUXTOPdqm531mP8Pul1QAR6hXTXA6c,446
|
61
61
|
skfolio/prior/_base.py,sha256=Dx6rX0X6ymDiieFOI-ik3xMNNFhYEtwLSXOdajf5wZY,1927
|
62
62
|
skfolio/prior/_black_litterman.py,sha256=sVx8113xeP4B6LA4rICKp0cgw7w3F46aQzIQY_34QwQ,9400
|
@@ -70,12 +70,10 @@ skfolio/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
70
70
|
skfolio/utils/bootstrap.py,sha256=3zY2kO_GQURKEcQMCasJOSByde9Mt2IAi3KJH0_a4mk,3550
|
71
71
|
skfolio/utils/equations.py,sha256=w0HsYjA7cS0mHYsI9MpixHLkof3HN26nc14ZfqFrHlE,11047
|
72
72
|
skfolio/utils/sorting.py,sha256=lSjMvH2L-sSj-06B3MlwBrH1rtjCeGEe4hG894W7TE0,3504
|
73
|
-
skfolio/utils/stats.py,sha256=
|
73
|
+
skfolio/utils/stats.py,sha256=KFzYvaa_F-q7bHZxEewrUDdFSnLX0VR3rcVQ5J7t_fw,13140
|
74
74
|
skfolio/utils/tools.py,sha256=xa42f7U3Ki8-CJS6g8w7bKCLI_QMJ8D6LxLBjlEM7Ok,15374
|
75
|
-
skfolio/
|
76
|
-
skfolio/
|
77
|
-
skfolio-0.2.
|
78
|
-
skfolio-0.2.
|
79
|
-
skfolio-0.2.
|
80
|
-
skfolio-0.2.1.dist-info/top_level.txt,sha256=NXEaoS9Ms7t32gxkb867nV0OKlU0KmssL7IJBVo0fJs,8
|
81
|
-
skfolio-0.2.1.dist-info/RECORD,,
|
75
|
+
skfolio-0.2.2.dist-info/LICENSE,sha256=F6Gi-ZJX5BlVzYK8R9NcvAkAsKa7KO29xB1OScbrH6Q,1526
|
76
|
+
skfolio-0.2.2.dist-info/METADATA,sha256=wd7HlJM1U8tEiQMLgSInMLiDhlcZaQAmhjij_HuXdcc,19585
|
77
|
+
skfolio-0.2.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
78
|
+
skfolio-0.2.2.dist-info/top_level.txt,sha256=NXEaoS9Ms7t32gxkb867nV0OKlU0KmssL7IJBVo0fJs,8
|
79
|
+
skfolio-0.2.2.dist-info/RECORD,,
|
skfolio/utils/fixes/__init__.py
DELETED
@@ -1,391 +0,0 @@
|
|
1
|
-
"""Compatibility fixes for plotly"""
|
2
|
-
|
3
|
-
# Fixes of the create_dendrogram plotly function until
|
4
|
-
# https://github.com/plotly/plotly.py/pull/4487 is corrected
|
5
|
-
|
6
|
-
from collections import OrderedDict
|
7
|
-
|
8
|
-
import numpy as np
|
9
|
-
import scipy.cluster.hierarchy as sch
|
10
|
-
import scipy.spatial as scs
|
11
|
-
from plotly import exceptions
|
12
|
-
from plotly.graph_objs import graph_objs
|
13
|
-
|
14
|
-
|
15
|
-
def create_dendrogram(
|
16
|
-
X,
|
17
|
-
orientation="bottom",
|
18
|
-
labels=None,
|
19
|
-
colorscale=None,
|
20
|
-
distfun=None,
|
21
|
-
linkagefun=lambda x: sch.linkage(x, "complete"),
|
22
|
-
hovertext=None,
|
23
|
-
color_threshold=None,
|
24
|
-
):
|
25
|
-
"""
|
26
|
-
Function that returns a dendrogram Plotly figure object. This is a thin
|
27
|
-
wrapper around scipy.cluster.hierarchy.dendrogram.
|
28
|
-
|
29
|
-
See also https://dash.plot.ly/dash-bio/clustergram.
|
30
|
-
|
31
|
-
:param (ndarray) X: Matrix of observations as array of arrays
|
32
|
-
:param (str) orientation: 'top', 'right', 'bottom', or 'left'
|
33
|
-
:param (list) labels: List of axis category labels(observation labels)
|
34
|
-
:param (list) colorscale: Optional colorscale for the dendrogram tree.
|
35
|
-
Requires 8 colors to be specified, the 7th of
|
36
|
-
which is ignored. With scipy>=1.5.0, the 2nd, 3rd
|
37
|
-
and 6th are used twice as often as the others.
|
38
|
-
Given a shorter list, the missing values are
|
39
|
-
replaced with defaults and with a longer list the
|
40
|
-
extra values are ignored.
|
41
|
-
:param (function) distfun: Function to compute the pairwise distance from
|
42
|
-
the observations
|
43
|
-
:param (function) linkagefun: Function to compute the linkage matrix from
|
44
|
-
the pairwise distances
|
45
|
-
:param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
|
46
|
-
clusters
|
47
|
-
:param (double) color_threshold: Value at which the separation of clusters will be made
|
48
|
-
|
49
|
-
Example 1: Simple bottom oriented dendrogram
|
50
|
-
|
51
|
-
>>> from plotly.figure_factory import create_dendrogram
|
52
|
-
|
53
|
-
>>> import numpy as np
|
54
|
-
|
55
|
-
>>> X = np.random.rand(10,10)
|
56
|
-
>>> fig = create_dendrogram(X)
|
57
|
-
>>> fig.show()
|
58
|
-
|
59
|
-
Example 2: Dendrogram to put on the left of the heatmap
|
60
|
-
|
61
|
-
>>> from plotly.figure_factory import create_dendrogram
|
62
|
-
|
63
|
-
>>> import numpy as np
|
64
|
-
|
65
|
-
>>> X = np.random.rand(5,5)
|
66
|
-
>>> names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
|
67
|
-
>>> dendro = create_dendrogram(X, orientation='right', labels=names)
|
68
|
-
>>> dendro.update_layout({'width':700, 'height':500}) # doctest: +SKIP
|
69
|
-
>>> dendro.show()
|
70
|
-
|
71
|
-
Example 3: Dendrogram with Pandas
|
72
|
-
|
73
|
-
>>> from plotly.figure_factory import create_dendrogram
|
74
|
-
|
75
|
-
>>> import numpy as np
|
76
|
-
>>> import pandas as pd
|
77
|
-
|
78
|
-
>>> Index= ['A','B','C','D','E','F','G','H','I','J']
|
79
|
-
>>> df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
|
80
|
-
>>> fig = create_dendrogram(df, labels=Index)
|
81
|
-
>>> fig.show()
|
82
|
-
"""
|
83
|
-
|
84
|
-
s = X.shape
|
85
|
-
if len(s) != 2:
|
86
|
-
exceptions.PlotlyError("X should be 2-dimensional array.")
|
87
|
-
|
88
|
-
if distfun is None:
|
89
|
-
distfun = scs.distance.pdist
|
90
|
-
|
91
|
-
dendrogram = _Dendrogram(
|
92
|
-
X,
|
93
|
-
orientation,
|
94
|
-
labels,
|
95
|
-
colorscale,
|
96
|
-
distfun=distfun,
|
97
|
-
linkagefun=linkagefun,
|
98
|
-
hovertext=hovertext,
|
99
|
-
color_threshold=color_threshold,
|
100
|
-
)
|
101
|
-
|
102
|
-
return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
|
103
|
-
|
104
|
-
|
105
|
-
class _Dendrogram:
|
106
|
-
"""Refer to FigureFactory.create_dendrogram() for docstring."""
|
107
|
-
|
108
|
-
def __init__(
|
109
|
-
self,
|
110
|
-
X,
|
111
|
-
orientation="bottom",
|
112
|
-
labels=None,
|
113
|
-
colorscale=None,
|
114
|
-
width=np.inf,
|
115
|
-
height=np.inf,
|
116
|
-
xaxis="xaxis",
|
117
|
-
yaxis="yaxis",
|
118
|
-
distfun=None,
|
119
|
-
linkagefun=lambda x: sch.linkage(x, "complete"),
|
120
|
-
hovertext=None,
|
121
|
-
color_threshold=None,
|
122
|
-
):
|
123
|
-
self.orientation = orientation
|
124
|
-
self.labels = labels
|
125
|
-
self.xaxis = xaxis
|
126
|
-
self.yaxis = yaxis
|
127
|
-
self.data = []
|
128
|
-
self.leaves = []
|
129
|
-
self.sign = {self.xaxis: 1, self.yaxis: 1}
|
130
|
-
self.layout = {self.xaxis: {}, self.yaxis: {}}
|
131
|
-
|
132
|
-
if self.orientation in ["left", "bottom"]:
|
133
|
-
self.sign[self.xaxis] = 1
|
134
|
-
else:
|
135
|
-
self.sign[self.xaxis] = -1
|
136
|
-
|
137
|
-
if self.orientation in ["right", "bottom"]:
|
138
|
-
self.sign[self.yaxis] = 1
|
139
|
-
else:
|
140
|
-
self.sign[self.yaxis] = -1
|
141
|
-
|
142
|
-
if distfun is None:
|
143
|
-
distfun = scs.distance.pdist
|
144
|
-
|
145
|
-
(dd_traces, xvals, yvals, ordered_labels, leaves) = self.get_dendrogram_traces(
|
146
|
-
X, colorscale, distfun, linkagefun, hovertext, color_threshold
|
147
|
-
)
|
148
|
-
|
149
|
-
self.labels = ordered_labels
|
150
|
-
self.leaves = leaves
|
151
|
-
yvals_flat = yvals.flatten()
|
152
|
-
xvals_flat = xvals.flatten()
|
153
|
-
|
154
|
-
self.zero_vals = []
|
155
|
-
|
156
|
-
for i in range(len(yvals_flat)):
|
157
|
-
if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
|
158
|
-
self.zero_vals.append(xvals_flat[i])
|
159
|
-
|
160
|
-
if len(self.zero_vals) > len(yvals) + 1:
|
161
|
-
# If the length of zero_vals is larger than the length of yvals,
|
162
|
-
# it means that there are wrong vals because of the identicial samples.
|
163
|
-
# Three and more identicial samples will make the yvals of spliting
|
164
|
-
# center into 0 and it will accidentally take it as leaves.
|
165
|
-
l_border = int(min(self.zero_vals))
|
166
|
-
r_border = int(max(self.zero_vals))
|
167
|
-
correct_leaves_pos = range(
|
168
|
-
l_border, r_border + 1, int((r_border - l_border) / len(yvals))
|
169
|
-
)
|
170
|
-
# Regenerating the leaves pos from the self.zero_vals with equally intervals.
|
171
|
-
self.zero_vals = [v for v in correct_leaves_pos]
|
172
|
-
|
173
|
-
self.zero_vals.sort()
|
174
|
-
self.layout = self.set_figure_layout(width, height)
|
175
|
-
self.data = dd_traces
|
176
|
-
|
177
|
-
def get_color_dict(self, colorscale):
|
178
|
-
"""
|
179
|
-
Returns colorscale used for dendrogram tree clusters.
|
180
|
-
|
181
|
-
:param (list) colorscale: Colors to use for the plot in rgb format.
|
182
|
-
:rtype (dict): A dict of default colors mapped to the user colorscale.
|
183
|
-
|
184
|
-
"""
|
185
|
-
|
186
|
-
# These are the color codes returned for dendrograms
|
187
|
-
# We're replacing them with nicer colors
|
188
|
-
# This list is the colors that can be used by dendrogram, which were
|
189
|
-
# determined as the combination of the default above_threshold_color and
|
190
|
-
# the default color palette (see scipy/cluster/hierarchy.py)
|
191
|
-
d = {
|
192
|
-
"r": "red",
|
193
|
-
"g": "green",
|
194
|
-
"b": "blue",
|
195
|
-
"c": "cyan",
|
196
|
-
"m": "magenta",
|
197
|
-
"y": "yellow",
|
198
|
-
"k": "black",
|
199
|
-
# palette in scipy/cluster/hierarchy.py
|
200
|
-
"w": "white",
|
201
|
-
}
|
202
|
-
default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
|
203
|
-
|
204
|
-
if colorscale is None:
|
205
|
-
rgb_colorscale = [
|
206
|
-
"rgb(0,116,217)", # blue
|
207
|
-
"rgb(35,205,205)", # cyan
|
208
|
-
"rgb(61,153,112)", # green
|
209
|
-
"rgb(40,35,35)", # black
|
210
|
-
"rgb(133,20,75)", # magenta
|
211
|
-
"rgb(255,65,54)", # red
|
212
|
-
"rgb(255,255,255)", # white
|
213
|
-
"rgb(255,220,0)", # yellow
|
214
|
-
]
|
215
|
-
else:
|
216
|
-
rgb_colorscale = colorscale
|
217
|
-
|
218
|
-
for i in range(len(default_colors.keys())):
|
219
|
-
k = list(default_colors.keys())[i] # PY3 won't index keys
|
220
|
-
if i < len(rgb_colorscale):
|
221
|
-
default_colors[k] = rgb_colorscale[i]
|
222
|
-
|
223
|
-
# add support for cyclic format colors as introduced in scipy===1.5.0
|
224
|
-
# before this, the colors were named 'r', 'b', 'y' etc., now they are
|
225
|
-
# named 'C0', 'C1', etc. To keep the colors consistent regardless of the
|
226
|
-
# scipy version, we try as much as possible to map the new colors to the
|
227
|
-
# old colors
|
228
|
-
# this mapping was found by inpecting scipy/cluster/hierarchy.py (see
|
229
|
-
# comment above).
|
230
|
-
new_old_color_map = [
|
231
|
-
("C0", "b"),
|
232
|
-
("C1", "g"),
|
233
|
-
("C2", "r"),
|
234
|
-
("C3", "c"),
|
235
|
-
("C4", "m"),
|
236
|
-
("C5", "y"),
|
237
|
-
("C6", "k"),
|
238
|
-
("C7", "g"),
|
239
|
-
("C8", "r"),
|
240
|
-
("C9", "c"),
|
241
|
-
]
|
242
|
-
for nc, oc in new_old_color_map:
|
243
|
-
try:
|
244
|
-
default_colors[nc] = default_colors[oc]
|
245
|
-
except KeyError:
|
246
|
-
# it could happen that the old color isn't found (if a custom
|
247
|
-
# colorscale was specified), in this case we set it to an
|
248
|
-
# arbitrary default.
|
249
|
-
default_colors[nc] = "rgb(0,116,217)"
|
250
|
-
|
251
|
-
return default_colors
|
252
|
-
|
253
|
-
def set_axis_layout(self, axis_key):
|
254
|
-
"""
|
255
|
-
Sets and returns default axis object for dendrogram figure.
|
256
|
-
|
257
|
-
:param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
|
258
|
-
:rtype (dict): An axis_key dictionary with set parameters.
|
259
|
-
|
260
|
-
"""
|
261
|
-
axis_defaults = {
|
262
|
-
"type": "linear",
|
263
|
-
"ticks": "outside",
|
264
|
-
"mirror": "allticks",
|
265
|
-
"rangemode": "tozero",
|
266
|
-
"showticklabels": True,
|
267
|
-
"zeroline": False,
|
268
|
-
"showgrid": False,
|
269
|
-
"showline": True,
|
270
|
-
}
|
271
|
-
|
272
|
-
if len(self.labels) != 0:
|
273
|
-
axis_key_labels = self.xaxis
|
274
|
-
if self.orientation in ["left", "right"]:
|
275
|
-
axis_key_labels = self.yaxis
|
276
|
-
if axis_key_labels not in self.layout:
|
277
|
-
self.layout[axis_key_labels] = {}
|
278
|
-
self.layout[axis_key_labels]["tickvals"] = [
|
279
|
-
zv * self.sign[axis_key] for zv in self.zero_vals
|
280
|
-
]
|
281
|
-
self.layout[axis_key_labels]["ticktext"] = self.labels
|
282
|
-
self.layout[axis_key_labels]["tickmode"] = "array"
|
283
|
-
|
284
|
-
self.layout[axis_key].update(axis_defaults)
|
285
|
-
|
286
|
-
return self.layout[axis_key]
|
287
|
-
|
288
|
-
def set_figure_layout(self, width, height):
|
289
|
-
"""
|
290
|
-
Sets and returns default layout object for dendrogram figure.
|
291
|
-
|
292
|
-
"""
|
293
|
-
self.layout.update(
|
294
|
-
{
|
295
|
-
"showlegend": False,
|
296
|
-
"autosize": False,
|
297
|
-
"hovermode": "closest",
|
298
|
-
"width": width,
|
299
|
-
"height": height,
|
300
|
-
}
|
301
|
-
)
|
302
|
-
|
303
|
-
self.set_axis_layout(self.xaxis)
|
304
|
-
self.set_axis_layout(self.yaxis)
|
305
|
-
|
306
|
-
return self.layout
|
307
|
-
|
308
|
-
def get_dendrogram_traces(
|
309
|
-
self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
|
310
|
-
):
|
311
|
-
"""
|
312
|
-
Calculates all the elements needed for plotting a dendrogram.
|
313
|
-
|
314
|
-
:param (ndarray) X: Matrix of observations as array of arrays
|
315
|
-
:param (list) colorscale: Color scale for dendrogram tree clusters
|
316
|
-
:param (function) distfun: Function to compute the pairwise distance
|
317
|
-
from the observations
|
318
|
-
:param (function) linkagefun: Function to compute the linkage matrix
|
319
|
-
from the pairwise distances
|
320
|
-
:param (list) hovertext: List of hovertext for constituent traces of dendrogram
|
321
|
-
:rtype (tuple): Contains all the traces in the following order:
|
322
|
-
(a) trace_list: List of Plotly trace objects for dendrogram tree
|
323
|
-
(b) icoord: All X points of the dendrogram tree as array of arrays
|
324
|
-
with length 4
|
325
|
-
(c) dcoord: All Y points of the dendrogram tree as array of arrays
|
326
|
-
with length 4
|
327
|
-
(d) ordered_labels: leaf labels in the order they are going to
|
328
|
-
appear on the plot
|
329
|
-
(e) P['leaves']: left-to-right traversal of the leaves
|
330
|
-
|
331
|
-
"""
|
332
|
-
d = distfun(X)
|
333
|
-
Z = linkagefun(d)
|
334
|
-
P = sch.dendrogram(
|
335
|
-
Z,
|
336
|
-
orientation=self.orientation,
|
337
|
-
labels=self.labels,
|
338
|
-
no_plot=True,
|
339
|
-
color_threshold=color_threshold,
|
340
|
-
)
|
341
|
-
|
342
|
-
icoord = np.array(P["icoord"])
|
343
|
-
dcoord = np.array(P["dcoord"])
|
344
|
-
ordered_labels = np.array(P["ivl"])
|
345
|
-
color_list = np.array(P["color_list"])
|
346
|
-
colors = self.get_color_dict(colorscale)
|
347
|
-
|
348
|
-
trace_list = []
|
349
|
-
|
350
|
-
for i in range(len(icoord)):
|
351
|
-
# xs and ys are arrays of 4 points that make up the '∩' shapes
|
352
|
-
# of the dendrogram tree
|
353
|
-
if self.orientation in ["top", "bottom"]:
|
354
|
-
xs = icoord[i]
|
355
|
-
else:
|
356
|
-
xs = dcoord[i]
|
357
|
-
|
358
|
-
if self.orientation in ["top", "bottom"]:
|
359
|
-
ys = dcoord[i]
|
360
|
-
else:
|
361
|
-
ys = icoord[i]
|
362
|
-
color_key = color_list[i]
|
363
|
-
hovertext_label = None
|
364
|
-
if hovertext:
|
365
|
-
hovertext_label = hovertext[i]
|
366
|
-
trace = dict(
|
367
|
-
type="scatter",
|
368
|
-
x=np.multiply(self.sign[self.xaxis], xs),
|
369
|
-
y=np.multiply(self.sign[self.yaxis], ys),
|
370
|
-
mode="lines",
|
371
|
-
marker=dict(color=colors[color_key]),
|
372
|
-
text=hovertext_label,
|
373
|
-
hoverinfo="text",
|
374
|
-
)
|
375
|
-
|
376
|
-
try:
|
377
|
-
x_index = int(self.xaxis[-1])
|
378
|
-
except ValueError:
|
379
|
-
x_index = ""
|
380
|
-
|
381
|
-
try:
|
382
|
-
y_index = int(self.yaxis[-1])
|
383
|
-
except ValueError:
|
384
|
-
y_index = ""
|
385
|
-
|
386
|
-
trace["xaxis"] = "x" + x_index
|
387
|
-
trace["yaxis"] = "y" + y_index
|
388
|
-
|
389
|
-
trace_list.append(trace)
|
390
|
-
|
391
|
-
return trace_list, icoord, dcoord, ordered_labels, P["leaves"]
|
File without changes
|
File without changes
|
File without changes
|