stLENS 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stlens-0.1.0/PKG-INFO +12 -0
- stlens-0.1.0/README.md +1 -0
- stlens-0.1.0/pyproject.toml +19 -0
- stlens-0.1.0/setup.cfg +4 -0
- stlens-0.1.0/src/PCA.py +309 -0
- stlens-0.1.0/src/__init__.py +1 -0
- stlens-0.1.0/src/calc.py +182 -0
- stlens-0.1.0/src/find_optimal_pc.py +875 -0
- stlens-0.1.0/src/scLENS_py.py +927 -0
- stlens-0.1.0/src/stLENS.egg-info/PKG-INFO +12 -0
- stlens-0.1.0/src/stLENS.egg-info/SOURCES.txt +12 -0
- stlens-0.1.0/src/stLENS.egg-info/dependency_links.txt +1 -0
- stlens-0.1.0/src/stLENS.egg-info/requires.txt +9 -0
- stlens-0.1.0/src/stLENS.egg-info/top_level.txt +5 -0
stlens-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: stLENS
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Requires-Dist: python==3.10
|
|
5
|
+
Requires-Dist: ipykernel
|
|
6
|
+
Requires-Dist: torch
|
|
7
|
+
Requires-Dist: scanpy
|
|
8
|
+
Requires-Dist: dask
|
|
9
|
+
Requires-Dist: cupy
|
|
10
|
+
Requires-Dist: zarr
|
|
11
|
+
Requires-Dist: python-igraph==0.9.9
|
|
12
|
+
Requires-Dist: leidenalg
|
stlens-0.1.0/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# stLENS
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "stLENS"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
|
|
5
|
+
dependencies = [
|
|
6
|
+
# NOTE: "python" is not a pip-installable distribution; pin the interpreter
# with requires-python = "==3.10.*" under [project] instead of listing it here.
|
|
7
|
+
"ipykernel",
|
|
8
|
+
"torch",
|
|
9
|
+
"scanpy",
|
|
10
|
+
"dask",
|
|
11
|
+
"cupy",
|
|
12
|
+
"zarr",
|
|
13
|
+
"python-igraph==0.9.9",
|
|
14
|
+
"leidenalg"
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[build-system]
|
|
18
|
+
requires = ["setuptools>=61.0"]
|
|
19
|
+
build-backend = "setuptools.build_meta"
|
stlens-0.1.0/setup.cfg
ADDED
stlens-0.1.0/src/PCA.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
from .calc import Calc
|
|
2
|
+
|
|
3
|
+
import dask.array as da
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import scanpy as sc
|
|
6
|
+
import torch
|
|
7
|
+
import numpy as np
|
|
8
|
+
import cupy as cp
|
|
9
|
+
from tqdm.auto import tqdm
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
|
+
import seaborn as sns
|
|
12
|
+
from scipy import stats
|
|
13
|
+
import matplotlib.patches as mpatches
|
|
14
|
+
import matplotlib.lines as mlines
|
|
15
|
+
import gc
|
|
16
|
+
|
|
17
|
+
class PCA():
    """Random-matrix-theory based PCA: separates signal eigenvalues from
    Marchenko-Pastur noise using a randomized-data null spectrum."""

    def __init__(self, device=None, data=None):
        # Backend selector: 'cpu' or 'gpu' — TODO confirm expected values match _get_eigen.
        self.device = device
        # Optional AnnData-like object; plot_mp stashes its figure here.
        self.data = data
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def fit(self, X=None, eigen_solver='wishart'):
    """Eigendecompose X and a row-shuffled randomization of X, then keep
    eigenvalues above the Tracy-Widom threshold as signal components.

    Sets on the instance: L/V (data spectrum), Lr/Vr (randomized spectrum),
    L_mp (noise eigenvalues), lambda_c (TW cutoff), Ls/Vs (signal),
    Ln/Vn (noise), n_components, explained_variance_, total_variance_.

    Raises ValueError for any eigen_solver other than 'wishart'.
    """
    if eigen_solver != 'wishart':
        raise ValueError("Invalid eigen_solver. Use 'wishart'.")

    calc = Calc()
    self.n_cells, self.n_genes = X.shape
    self.rmt_device = None

    self.L, self.V = self._get_eigen(X)
    Xr = self._random_matrix(X)

    # _get_eigen may have silently fallen back to CPU; mirror that here.
    if isinstance(self.L, np.ndarray):
        self.device = 'cpu'
        self.rmt_device = 'cpu'
    self.Lr, self.Vr = self._get_eigen(Xr)

    # Keep both spectra on the same backend before any mixed arithmetic.
    if self.rmt_device != 'cpu' and isinstance(self.Lr, np.ndarray):
        # Randomized pass fell back to CPU -> pull the data spectrum down too.
        self.L = self.L.get()
        self.V = self.V.get()
        self.rmt_device = 'cpu'
    elif self.rmt_device == 'cpu':
        self.rmt_device = 'cpu'
    else:
        self.rmt_device = 'gpu'

    del Xr
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()
    cp._default_memory_pool.free_all_blocks()
    gc.collect()

    self.explained_variance_ = (self.L ** 2) / self.n_cells
    self.total_variance_ = self.explained_variance_.sum()

    calc.L = self.L
    self.L_mp = calc._mp_calculation(self.L, self.Lr, self.rmt_device)
    calc.L_mp = self.L_mp
    self.lambda_c = calc._tw(self.rmt_device)
    print("lambda_c:", self.lambda_c)
    self.peak = calc._mp_parameters(self.L_mp, self.rmt_device)['peak']

    # Signal = eigenvalues above the Tracy-Widom cutoff.
    signal_mask = self.L > self.lambda_c
    self.Ls = self.L[signal_mask]
    self.Vs = self.V[:, signal_mask]

    # Noise = eigenvalues inside the fitted MP bulk (calc.b_minus set by
    # _mp_calculation above).
    noise_mask = (self.L < self.lambda_c) & (self.L > calc.b_minus)
    self.Vn = self.V[:, noise_mask]
    self.Ln = self.L[noise_mask]

    self.n_components = len(self.Ls)
    print(f"Number of signal components: {self.n_components}")

    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()
    gc.collect()
|
|
80
|
+
|
|
81
|
+
def get_signal_components(self, n_components=0):
    """Return (eigenvalues, eigenvectors) of the signal components.

    Parameters
    ----------
    n_components : int
        0 (default) returns all signal components; n >= 1 returns the
        first n. Negative values raise ValueError.

    Returns
    -------
    tuple of (Ls, Vs) — eigenvalues and column-wise eigenvectors.
    """
    if n_components == 0:
        return self.Ls, self.Vs
    if n_components >= 1:
        # BUGFIX: eigenvectors are stored column-wise (Vs = V[:, mask] in
        # fit), so selecting components must slice columns; the original
        # sliced rows (self.Vs[:n_components]), truncating cells instead.
        return self.Ls[:n_components], self.Vs[:, :n_components]
    raise ValueError('n_components must be positive')
|
|
89
|
+
|
|
90
|
+
def _wishart_matrix(self, X):
|
|
91
|
+
if X.shape[0] <= X.shape[1]:
|
|
92
|
+
Y = (X @ X.T)
|
|
93
|
+
else:
|
|
94
|
+
Y = (X.T @ X)
|
|
95
|
+
Y /= X.shape[1]
|
|
96
|
+
return Y
|
|
97
|
+
|
|
98
|
+
def to_gpu(self, Y):
    """Transfer Y to GPU memory block-by-block to cap peak device usage.

    Y may be a dask array or any array accepted by da.from_array; the
    result is a single contiguous cupy array.
    """
    chunk_size = (10000, Y.shape[1])
    if isinstance(Y, da.core.Array):
        Y_dask = Y
    else:
        Y_dask = da.from_array(Y, chunks=chunk_size)

    # First block seeds the device array; the rest are appended row-wise.
    Y_gpu = cp.asarray(Y_dask.blocks[0])

    n_blocks = len(Y_dask.chunks[0])
    for i in range(1, n_blocks):
        part = cp.asarray(Y_dask.blocks[i])
        Y_gpu = cp.concatenate((Y_gpu, part), axis=0)

        # Aggressively release the staging block before the next transfer.
        del part
        gc.collect()
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()
        cp._default_memory_pool.free_all_blocks()

    del Y_dask, chunk_size, Y
    gc.collect()
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()
    cp._default_memory_pool.free_all_blocks()

    return Y_gpu
|
|
125
|
+
|
|
126
|
+
def _get_eigen(self, X):
    """Eigendecompose the Wishart matrix of X on the configured device.

    On GPU, an out-of-memory error triggers a transparent fallback to
    numpy on the CPU (callers detect this via the returned array type).
    Returns (L, V): eigenvalues and column-wise eigenvectors.
    """
    Y = self._wishart_matrix(X)

    if self.device == 'gpu':
        try:
            Y = self.to_gpu(Y)
            L, V = cp.linalg.eigh(Y)
        except cp.cuda.memory.OutOfMemoryError:
            print('[Warning] GPU memory insufficient. Falling back to CPU computation.')
            if isinstance(Y, cp.ndarray):
                Y = Y.get()
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            cp._default_memory_pool.free_all_blocks()
            L, V = np.linalg.eigh(Y)
    elif self.device == 'cpu':
        L, V = np.linalg.eigh(Y)
    else:
        raise ValueError("The device must be either 'cpu' or 'gpu'.")

    del Y
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()
    cp._default_memory_pool.free_all_blocks()
    gc.collect()
    return L, V
|
|
152
|
+
|
|
153
|
+
def _random_matrix(self, X):
    """Return a randomized copy of X with each row's entries shuffled.

    This destroys gene-gene correlations while preserving each cell's
    value distribution, giving the null spectrum used to fit the MP bulk.
    """
    if isinstance(X, da.core.Array):
        X_dask = X
    else:
        X_dask = da.from_array(X, chunks=(10000, X.shape[1]))

    def _shuffle_rows(block):
        # In-place shuffle of every row within the chunk.
        for row in block:
            np.random.shuffle(row)
        return block

    Xr = X_dask.map_blocks(_shuffle_rows, dtype=X_dask.dtype).compute()

    del X_dask
    gc.collect()
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()
    cp._default_memory_pool.free_all_blocks()
    return Xr
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def plot_mp(self, comparison=False, path=False, info=True, bins=None, title=None):
    """Plot the eigenvalue histogram against the fitted Marchenko-Pastur density.

    Parameters
    ----------
    comparison : bool
        Also overlay the randomized-data histogram and its MP curve.
    path : bool or str
        Unused (saving is commented out; kept for interface compatibility).
    info : bool
        Annotate the figure with data / MP / KS statistics.
    bins : int or None
        Histogram bin count (default 300).
    title : str or None
        Optional axes title.

    Returns
    -------
    matplotlib.figure.Figure
        The figure; also stored in ``self.data.uns['mp_plot']`` when
        ``self.data`` is an AnnData object.
    """
    calc = Calc()
    calc.style_mp_stat()

    if bins is None:
        bins = 300

    # Density grid for the fitted MP law (and optionally the randomized one).
    if self.device == 'gpu':
        x = np.linspace(0, int(cp.round(cp.max(self.L_mp) + 0.5)), 2000)
        y = calc._mp_pdf(x, self.L_mp, self.rmt_device).get()
        if comparison and self.Lr is not None:
            yr = calc._mp_pdf(x, self.Lr, self.rmt_device).get()
    elif self.device == 'cpu':
        x = np.linspace(0, int(np.round(np.max(self.L_mp) + 0.5)), 2000)
        y = calc._mp_pdf(x, self.L_mp, self.rmt_device)
        if comparison and self.Lr is not None:
            yr = calc._mp_pdf(x, self.Lr, self.rmt_device)
    else:
        raise ValueError("The device must be either 'cpu' or 'gpu'.")

    # BUGFIX: the original created fig/ax only when info=True; the
    # info=False branch called plt.figure(dip=100) (typo for dpi) and left
    # both 'fig' and 'ax' undefined, so every later ax.* call and the
    # final 'return fig' crashed. Create them unconditionally.
    fig = plt.figure(dpi=100)
    if info:
        fig.set_layout_engine()
    ax = fig.add_subplot(111)

    if info:
        dic = calc._mp_parameters(self.L_mp, self.rmt_device)
        info1 = (r'$\bf{Data Parameters}$' + '\n{0} cells\n{1} genes'
                 .format(self.n_cells, self.n_genes))
        info2 = ('\n' + r'$\bf{MP\ distribution\ in\ data}$'
                 + '\n$\gamma={:0.2f}$ \n$\sigma^2={:1.2f}$\n$b_-={:2.2f}$\n$b_+={:3.2f}$'
                 .format(dic['gamma'], dic['s'], dic['b_minus'], dic['b_plus']))

        n_components = self.n_components if self.n_components is not None else 0
        info3 = ('\n' + r'$\bf{Analysis}$'
                 + '\n{0} eigenvalues > $\lambda_c (3 \sigma)$\n{1} noise eigenvalues'
                 .format(n_components, self.n_cells - n_components))

        # KS goodness-of-fit of the noise eigenvalues against the MP CDF.
        if isinstance(self.L_mp, np.ndarray):
            cdf_func = calc._call_mp_cdf(self.L_mp, dic)
            ks = stats.kstest(self.L_mp, cdf_func)
        else:
            cdf_func = calc._call_mp_cdf(self.L_mp.get(), dic)
            ks = stats.kstest(self.L_mp.get(), cdf_func)

        info4 = ('\n' + r'$\bf{Statistics}$'
                 + '\nKS distance ={0}'.format(round(ks[0], 4))
                 + '\nKS test p-value={0}'.format(round(ks[1], 2)))

        infoT = info1 + info2 + info4 + info3

        ax.text(1.05, 1.02, infoT, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', horizontalalignment='left',
                bbox=dict(facecolor='wheat', alpha=0.8, boxstyle='round,pad=0.5'))

    # Histograms need host arrays; pull from GPU if necessary.
    if not isinstance(self.L, np.ndarray):
        self.L = self.L.get()
    if not isinstance(self.Lr, np.ndarray):
        self.Lr = self.Lr.get()

    sns.histplot(self.L, bins=bins, stat="density",
                 kde=False, color=sns.xkcd_rgb["cornflower blue"], alpha=0.85)

    # MP curve for the random part of the data.
    plt.plot(x, y, color=sns.xkcd_rgb["pale red"], lw=2, label="MP for random part in data")

    # Legend proxy artists.
    MP_data = mlines.Line2D([], [], color=sns.xkcd_rgb["pale red"], label="MP for random part in data", linewidth=2)
    MP_rand = mlines.Line2D([], [], color=sns.xkcd_rgb["sap green"], label="MP for randomized data", linewidth=1.5, linestyle='--')
    randomized = mpatches.Patch(color=sns.xkcd_rgb["apple green"], label="Randomized data", alpha=0.75, linewidth=3, fill=False)
    data_real = mpatches.Patch(color=sns.xkcd_rgb["cornflower blue"], label="Real data", alpha=0.85)

    if comparison:
        sns.histplot(self.Lr, bins=30, kde=False,
                     stat="density", color=sns.xkcd_rgb["apple green"], alpha=0.75, linewidth=3)
        ax.plot(x, yr, sns.xkcd_rgb["sap green"], lw=1.5, ls='--')
        ax.legend(handles=[data_real, MP_data, randomized, MP_rand], loc="upper right", frameon=True)
    else:
        ax.legend(handles=[data_real, MP_data], loc="upper right", frameon=True)

    # x-axis range from the largest eigenvalue on the active device.
    if self.device == 'cpu':
        max_Lr = np.max(self.Lr) if self.Lr is not None else 0
        max_L_mp = np.max(self.L_mp) if self.L_mp is not None else 0
    elif self.device == 'gpu':
        max_Lr = cp.max(self.Lr) if self.Lr is not None else 0
        max_L_mp = cp.max(self.L_mp) if self.L_mp is not None else 0
    else:
        raise ValueError("The device must be either 'cpu' or 'gpu'.")

    ax.set_xlim([0, int(np.round(max(max_Lr, max_L_mp) + 1.5))])
    ax.grid(linestyle='--', lw=0.3)

    if title:
        ax.set_title(title)
    ax.set_xlabel('First cell eigenvalues normalized distribution')

    if self.data is not None and isinstance(self.data, sc.AnnData):
        self.data.uns['mp_plot'] = fig

    # if path:
    #     plt.savefig(path, bbox_inches="tight")
    return fig
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from . import scLENS_py
|
stlens-0.1.0/src/calc.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
import scanpy as sc
|
|
2
|
+
import cupy as cp
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import numpy as np
|
|
5
|
+
from scipy import stats, linalg
|
|
6
|
+
import scipy
|
|
7
|
+
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
from matplotlib import rcParams
|
|
10
|
+
import matplotlib.lines as mlines
|
|
11
|
+
import matplotlib.gridspec as gridspec
|
|
12
|
+
from matplotlib.offsetbox import AnchoredText
|
|
13
|
+
import matplotlib.patches as mpatches
|
|
14
|
+
import matplotlib.lines as mlines
|
|
15
|
+
|
|
16
|
+
import seaborn as sns
|
|
17
|
+
|
|
18
|
+
import psutil
|
|
19
|
+
import os
|
|
20
|
+
|
|
21
|
+
class Calc():
    """Marchenko-Pastur / Tracy-Widom helper computations shared by PCA."""

    def __init__(self, device=None, data=None, L=None, L_mp=None):
        """Create a calculator.

        Parameters
        ----------
        device : str or None
            Backend hint ('cpu'/'gpu') — TODO confirm expected values.
        data : object or None
            Optional data handle, stored as-is.
        L, L_mp : array-like or None
            Optional eigenvalue arrays. BUGFIX: these parameters were
            accepted but silently ignored; they now seed the attributes
            (defaults unchanged: L -> [], L_mp -> None).
        """
        self.L = [] if L is None else L
        self.V = None
        self.L_mp = L_mp
        self.explained_variance_ = []
        self.total_variance_ = []
        self.device = device
        self.data = data
|
|
31
|
+
|
|
32
|
+
def style_mp_stat(self):
    """Apply the shared plotting style used by the MP diagnostic figures."""
    plt.style.use("ggplot")
    # Silence invalid-value warnings from sqrt on points outside MP support.
    np.seterr(invalid='ignore')
    sns.set_style("white")
    sns.set_context("paper")
    sns.set_palette("deep")
    plt.rcParams['axes.linewidth'] = 0.5
    plt.rcParams['figure.dpi'] = 100
|
|
41
|
+
|
|
42
|
+
def _tw(self, rmt_device):
|
|
43
|
+
'''Tracy-Widom critical eigenvalue'''
|
|
44
|
+
if self.L is None or len(self.L) == 0:
|
|
45
|
+
raise ValueError("self.L must not be empty or None.")
|
|
46
|
+
if self.L_mp is None or len(self.L_mp) == 0:
|
|
47
|
+
raise ValueError("self.L_mp must not be empty or None.")
|
|
48
|
+
|
|
49
|
+
gamma = self._mp_parameters(self.L_mp, rmt_device)['gamma']
|
|
50
|
+
p = len(self.L) / gamma
|
|
51
|
+
|
|
52
|
+
if rmt_device == 'cpu':
|
|
53
|
+
sigma = 1 / np.power(p, 2/3) * np.power(gamma, 5/6) * \
|
|
54
|
+
np.power((1 + np.sqrt(gamma)), 4/3)
|
|
55
|
+
lambda_c = np.mean(self.L_mp) * (1 + np.sqrt(gamma)) ** 2 + sigma
|
|
56
|
+
|
|
57
|
+
else:
|
|
58
|
+
sigma = 1 / cp.power(p, 2/3) * cp.power(gamma, 5/6) * \
|
|
59
|
+
cp.power((1 + cp.sqrt(gamma)), 4/3)
|
|
60
|
+
lambda_c = cp.mean(self.L_mp) * (1 + cp.sqrt(gamma)) ** 2 + sigma
|
|
61
|
+
|
|
62
|
+
return lambda_c
|
|
63
|
+
|
|
64
|
+
def _mp_parameters(self, L, rmt_device):
|
|
65
|
+
if rmt_device == 'gpu':
|
|
66
|
+
moment_1 = cp.mean(L)
|
|
67
|
+
moment_2 = cp.mean(cp.power(L, 2))
|
|
68
|
+
gamma = moment_2 / float(moment_1**2) - 1
|
|
69
|
+
s = moment_1
|
|
70
|
+
sigma = moment_2
|
|
71
|
+
b_plus = s * (1 + cp.sqrt(gamma))**2
|
|
72
|
+
b_minus = s * (1 - cp.sqrt(gamma))**2
|
|
73
|
+
x_peak = s * (1.0 - gamma)**2.0 / (1.0 + gamma)
|
|
74
|
+
return {
|
|
75
|
+
'moment_1': moment_1,
|
|
76
|
+
'moment_2': moment_2,
|
|
77
|
+
'gamma': gamma,
|
|
78
|
+
'b_plus': b_plus,
|
|
79
|
+
'b_minus': b_minus,
|
|
80
|
+
's': s,
|
|
81
|
+
'peak': x_peak,
|
|
82
|
+
'sigma': sigma
|
|
83
|
+
}
|
|
84
|
+
else:
|
|
85
|
+
moment_1 = np.mean(L)
|
|
86
|
+
moment_2 = np.mean(np.power(L, 2))
|
|
87
|
+
gamma = moment_2 / float(moment_1**2) - 1
|
|
88
|
+
s = moment_1
|
|
89
|
+
sigma = moment_2
|
|
90
|
+
b_plus = s * (1 + np.sqrt(gamma))**2
|
|
91
|
+
b_minus = s * (1 - np.sqrt(gamma))**2
|
|
92
|
+
x_peak = s * (1.0 - gamma)**2.0 / (1.0 + gamma)
|
|
93
|
+
return {
|
|
94
|
+
'moment_1': moment_1,
|
|
95
|
+
'moment_2': moment_2,
|
|
96
|
+
'gamma': gamma,
|
|
97
|
+
'b_plus': b_plus,
|
|
98
|
+
'b_minus': b_minus,
|
|
99
|
+
's': s,
|
|
100
|
+
'peak': x_peak,
|
|
101
|
+
'sigma': sigma
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _marchenko_pastur(self, x, dic):
|
|
107
|
+
'''Distribution of eigenvalues'''
|
|
108
|
+
pdf = np.sqrt((dic['b_plus'] - x) * (x-dic['b_minus']))\
|
|
109
|
+
/ float(2 * dic['s'] * np.pi * dic['gamma'] * x)
|
|
110
|
+
return pdf
|
|
111
|
+
|
|
112
|
+
def _mp_pdf(self, x, L, rmt_device):
|
|
113
|
+
'''Marchnko-Pastur PDF'''
|
|
114
|
+
dic = self._mp_parameters(L, rmt_device)
|
|
115
|
+
if rmt_device == 'cpu':
|
|
116
|
+
y = np.empty_like(x)
|
|
117
|
+
else:
|
|
118
|
+
y = cp.empty_like(x)
|
|
119
|
+
for i, xi in enumerate(x):
|
|
120
|
+
y[i] = self._marchenko_pastur(xi, dic)
|
|
121
|
+
return y
|
|
122
|
+
|
|
123
|
+
def _mp_calculation(self, L, Lr, rmt_device, eta=1, eps=10**-6, max_iter=1000):
|
|
124
|
+
converged = False
|
|
125
|
+
iter = 0
|
|
126
|
+
loss_history = []
|
|
127
|
+
|
|
128
|
+
b_plus = self._mp_parameters(Lr, rmt_device)['b_plus']
|
|
129
|
+
b_minus = self._mp_parameters(Lr, rmt_device)['b_minus']
|
|
130
|
+
|
|
131
|
+
L_updated = L[(L > b_minus) & (L < b_plus)]
|
|
132
|
+
new_b_plus = self._mp_parameters(L_updated, rmt_device)['b_plus']
|
|
133
|
+
new_b_minus = self._mp_parameters(L_updated, rmt_device)['b_minus']
|
|
134
|
+
|
|
135
|
+
while not converged:
|
|
136
|
+
loss = (1 - float(new_b_plus) / float(b_plus))**2
|
|
137
|
+
loss_history.append(loss)
|
|
138
|
+
iter += 1
|
|
139
|
+
|
|
140
|
+
if loss <= eps:
|
|
141
|
+
converged = True
|
|
142
|
+
elif iter == max_iter:
|
|
143
|
+
print('Max interactions exceeded!')
|
|
144
|
+
converged = True
|
|
145
|
+
else:
|
|
146
|
+
gradient = new_b_plus - b_plus
|
|
147
|
+
new_b_plus = b_plus + eta * gradient
|
|
148
|
+
|
|
149
|
+
L_updated = L[(L > new_b_minus) & (L < new_b_plus)]
|
|
150
|
+
self.b_plus = new_b_plus
|
|
151
|
+
self.b_minus = new_b_minus
|
|
152
|
+
|
|
153
|
+
new_b_plus = self._mp_parameters(L_updated, rmt_device)['b_plus']
|
|
154
|
+
new_b_minus = self._mp_parameters(L_updated, rmt_device)['b_minus']
|
|
155
|
+
|
|
156
|
+
if rmt_device == 'cpu':
|
|
157
|
+
indices = np.where((L > new_b_minus) & (L < new_b_plus))
|
|
158
|
+
L_mp = L[indices]
|
|
159
|
+
return np.array(L_mp)
|
|
160
|
+
else:
|
|
161
|
+
indices = cp.where((L > new_b_minus) & (L < new_b_plus))
|
|
162
|
+
L_mp = L[indices]
|
|
163
|
+
return cp.array(L_mp)
|
|
164
|
+
|
|
165
|
+
def _cdf_marchenko(self,x,dic):
|
|
166
|
+
if x < dic['b_minus']:
|
|
167
|
+
return 0.0
|
|
168
|
+
elif x>dic['b_minus'] and x<dic['b_plus']:
|
|
169
|
+
return 1/float(2*dic['s']*np.pi*dic['gamma'])*\
|
|
170
|
+
float(np.sqrt((dic['b_plus']-x)*(x-dic['b_minus']))+\
|
|
171
|
+
(dic['b_plus']+dic['b_minus'])/2*np.arcsin((2*x-dic['b_plus']-\
|
|
172
|
+
dic['b_minus'])/(dic['b_plus']-dic['b_minus']))-\
|
|
173
|
+
np.sqrt(dic['b_plus']*dic['b_minus'])*np.arcsin(((dic['b_plus']+\
|
|
174
|
+
dic['b_minus'])*x-2*dic['b_plus']*dic['b_minus'])/\
|
|
175
|
+
((dic['b_plus']-dic['b_minus'])*x)) )+np.arcsin(1)/np.pi
|
|
176
|
+
else:
|
|
177
|
+
return 1.0
|
|
178
|
+
|
|
179
|
+
def _call_mp_cdf(self,L,dic):
|
|
180
|
+
"CDF of Marchenko Pastur"
|
|
181
|
+
func= lambda y: list(map(lambda x: self._cdf_marchenko(x,dic), y))
|
|
182
|
+
return func
|