stLENS 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
stlens-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,12 @@
1
+ Metadata-Version: 2.4
2
+ Name: stLENS
3
+ Version: 0.1.0
4
+ Requires-Python: ==3.10.*
5
+ Requires-Dist: ipykernel
6
+ Requires-Dist: torch
7
+ Requires-Dist: scanpy
8
+ Requires-Dist: dask
9
+ Requires-Dist: cupy
10
+ Requires-Dist: zarr
11
+ Requires-Dist: python-igraph==0.9.9
12
+ Requires-Dist: leidenalg
stlens-0.1.0/README.md ADDED
@@ -0,0 +1 @@
1
+ # stLENS
@@ -0,0 +1,19 @@
[project]
name = "stLENS"
version = "0.1.0"
# The Python version constraint belongs in `requires-python`, not in
# `dependencies`: "python" is not an installable distribution, so pip
# would fail trying to resolve it. `==3.10.*` preserves the original
# intent of pinning to the 3.10 series.
requires-python = "==3.10.*"

dependencies = [
    "ipykernel",
    "torch",
    "scanpy",
    "dask",
    "cupy",
    "zarr",
    "python-igraph==0.9.9",
    "leidenalg",
]

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
stlens-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
# Auto-generated by setuptools when the sdist was built; disables the
# build tag and date tag normally appended to the egg-info version.
[egg_info]
tag_build =
tag_date = 0
@@ -0,0 +1,309 @@
1
+ from .calc import Calc
2
+
3
+ import dask.array as da
4
+ import pandas as pd
5
+ import scanpy as sc
6
+ import torch
7
+ import numpy as np
8
+ import cupy as cp
9
+ from tqdm.auto import tqdm
10
+ import matplotlib.pyplot as plt
11
+ import seaborn as sns
12
+ from scipy import stats
13
+ import matplotlib.patches as mpatches
14
+ import matplotlib.lines as mlines
15
+ import gc
16
+
17
class PCA():
    """Random-matrix-theory based PCA.

    Separates signal from noise eigenvalues by comparing the spectrum of
    the data's Wishart matrix against the Marchenko-Pastur (MP) spectrum
    of a randomized (row-shuffled) copy of the data, thresholded at the
    Tracy-Widom critical eigenvalue.

    Parameters
    ----------
    device : str, optional
        'cpu' or 'gpu'. Automatically downgraded to 'cpu' when the
        eigendecomposition falls back to NumPy.
    data : anndata.AnnData, optional
        When given, `plot_mp` stores its figure in ``data.uns['mp_plot']``.
    """

    def __init__(self, device=None, data=None):
        self.device = device
        self.data = data

    @staticmethod
    def _free_gpu_memory():
        """Return all cached CuPy pool blocks to the driver and run GC."""
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()
        cp._default_memory_pool.free_all_blocks()
        gc.collect()

    def fit(self, X=None, eigen_solver='wishart'):
        """Eigendecompose X and a randomized copy, then locate the
        Tracy-Widom threshold ``lambda_c`` separating signal from noise.

        Parameters
        ----------
        X : array-like of shape (n_cells, n_genes)
            Preprocessed expression matrix (dense ndarray or dask array).
        eigen_solver : str
            Only 'wishart' is supported.

        Raises
        ------
        ValueError
            If ``eigen_solver`` is not 'wishart'.
        """
        calc = Calc()
        self.n_cells, self.n_genes = X.shape
        self.rmt_device = None

        if eigen_solver == 'wishart':
            self.L, self.V = self._get_eigen(X)
            Xr = self._random_matrix(X)

            # A NumPy result means _get_eigen fell back to the CPU even if
            # 'gpu' was requested; record that for the RMT computations.
            if isinstance(self.L, np.ndarray):
                self.device = 'cpu'
                self.rmt_device = 'cpu'
            self.Lr, self.Vr = self._get_eigen(Xr)

            if self.rmt_device != 'cpu' and isinstance(self.Lr, np.ndarray):
                # Randomized spectrum fell back to CPU: move the data
                # spectrum to the host so both live in the same space.
                self.L = self.L.get()
                self.V = self.V.get()
                self.rmt_device = 'cpu'
            elif self.rmt_device == 'cpu':
                self.rmt_device = 'cpu'
            else:
                self.rmt_device = 'gpu'

            del Xr
            self._free_gpu_memory()

            self.explained_variance_ = (self.L**2) / self.n_cells
            self.total_variance_ = self.explained_variance_.sum()

            calc.L = self.L
            self.L_mp = calc._mp_calculation(self.L, self.Lr, self.rmt_device)
            calc.L_mp = self.L_mp
            self.lambda_c = calc._tw(self.rmt_device)
            print("lambda_c:", self.lambda_c)
            self.peak = calc._mp_parameters(self.L_mp, self.rmt_device)['peak']
        else:
            raise ValueError("Invalid eigen_solver. Use 'wishart'.")

        # Signal components: eigenvalues above the Tracy-Widom threshold.
        self.Ls = self.L[self.L > self.lambda_c]
        self.Vs = self.V[:, self.L > self.lambda_c]

        # Noise band: between the fitted MP lower edge and the threshold.
        noise_boolean = ((self.L < self.lambda_c) & (self.L > calc.b_minus))
        self.Vn = self.V[:, noise_boolean]
        self.Ln = self.L[noise_boolean]

        self.n_components = len(self.Ls)
        print(f"Number of signal components: {self.n_components}")

        self._free_gpu_memory()

    def get_signal_components(self, n_components=0):
        """Return (eigenvalues, eigenvectors) of the signal components.

        ``n_components == 0`` returns all signal components; a positive
        value returns the first ``n_components``. Eigenvectors are stored
        column-wise, so the component axis is axis 1.

        Raises
        ------
        ValueError
            If ``n_components`` is negative.
        """
        if n_components == 0:
            return self.Ls, self.Vs
        if n_components >= 1:
            # BUG FIX: components are columns of Vs; the previous row
            # slice (Vs[:n_components]) truncated cells, not components.
            return self.Ls[:n_components], self.Vs[:, :n_components]
        raise ValueError('n_components must be positive')

    def _wishart_matrix(self, X):
        """Covariance-like matrix on the smaller side of X, normalized by
        the number of columns of X."""
        if X.shape[0] <= X.shape[1]:
            Y = (X @ X.T)
        else:
            Y = (X.T @ X)
        Y /= X.shape[1]
        return Y

    def to_gpu(self, Y):
        """Copy Y to the GPU in 10000-row blocks to bound peak memory
        during the host-to-device transfer."""
        chunk_size = (10000, Y.shape[1])
        if isinstance(Y, da.core.Array):
            Y_dask = Y
        else:
            Y_dask = da.from_array(Y, chunks=chunk_size)

        Y_gpu = cp.asarray(Y_dask.blocks[0])
        n_chunks = len(Y_dask.chunks[0])
        for i in range(1, n_chunks):
            block = cp.asarray(Y_dask.blocks[i])
            Y_gpu = cp.concatenate((Y_gpu, block), axis=0)
            # Drop the staging block immediately so only one extra block
            # is resident at a time.
            del block
            self._free_gpu_memory()

        del Y_dask, chunk_size, Y, n_chunks
        self._free_gpu_memory()
        return Y_gpu

    def _get_eigen(self, X):
        """Eigendecomposition of the Wishart matrix of X.

        On 'gpu' the computation falls back to NumPy when the device runs
        out of memory; the returned array type (cupy vs numpy) tells the
        caller where the result lives.

        Raises
        ------
        ValueError
            If ``self.device`` is neither 'cpu' nor 'gpu'.
        """
        Y = self._wishart_matrix(X)
        if self.device == 'gpu':
            try:
                Y = self.to_gpu(Y)
                L, V = cp.linalg.eigh(Y)
            except cp.cuda.memory.OutOfMemoryError:
                print('[Warning] GPU memory insufficient. Falling back to CPU computation.')
                if isinstance(Y, cp.ndarray):
                    Y = Y.get()
                self._free_gpu_memory()
                L, V = np.linalg.eigh(Y)
        elif self.device == 'cpu':
            L, V = np.linalg.eigh(Y)
        else:
            raise ValueError("The device must be either 'cpu' or 'gpu'.")

        del Y
        self._free_gpu_memory()
        return L, V

    def _random_matrix(self, X):
        """Return a copy of X with each row independently shuffled,
        destroying gene-gene correlations while keeping row contents."""
        if isinstance(X, da.core.Array):
            X_dask = X
        else:
            X_dask = da.from_array(X, chunks=(10000, X.shape[1]))

        def shuffle_block(block):
            # NOTE(review): shuffles rows of the block in place; dask may
            # hand over views of the caller's X, so X can be mutated as a
            # side effect — confirm this is intended.
            for row in block:
                np.random.shuffle(row)
            return block

        Xr = X_dask.map_blocks(shuffle_block, dtype=X_dask.dtype).compute()

        del X_dask
        self._free_gpu_memory()
        return Xr

    def plot_mp(self, comparison=False, path=False, info=True, bins=None, title=None):
        """Plot the eigenvalue histogram against the fitted MP density.

        Parameters
        ----------
        comparison : bool
            Also overlay the randomized-data spectrum and its MP curve.
        path : str or bool
            Currently unused (saving is disabled below).
        info : bool
            Annotate the figure with data/MP/KS statistics.
        bins : int, optional
            Histogram bin count (default 300).
        title : str, optional
            Axes title.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If ``self.device`` is neither 'cpu' nor 'gpu'.
        """
        calc = Calc()
        calc.style_mp_stat()

        if bins is None:
            bins = 300

        # Dense x-grid and MP density for the random part of the data.
        if self.device == 'gpu':
            x = np.linspace(0, int(cp.round(cp.max(self.L_mp) + 0.5)), 2000)
            y = calc._mp_pdf(x, self.L_mp, self.rmt_device).get()
            if comparison and self.Lr is not None:
                yr = calc._mp_pdf(x, self.Lr, self.rmt_device).get()
        elif self.device == 'cpu':
            x = np.linspace(0, int(np.round(np.max(self.L_mp) + 0.5)), 2000)
            y = calc._mp_pdf(x, self.L_mp, self.rmt_device)
            if comparison and self.Lr is not None:
                yr = calc._mp_pdf(x, self.Lr, self.rmt_device)
        else:
            raise ValueError("The device must be either 'cpu' or 'gpu'.")

        # BUG FIX: fig/ax are needed in both branches below; the old code
        # only created them when info=True (and had a `dip=100` typo), so
        # info=False crashed with a NameError on `ax` and `fig`.
        fig = plt.figure(dpi=100)
        ax = fig.add_subplot(111)

        if info:
            fig.set_layout_engine()

            dic = calc._mp_parameters(self.L_mp, self.rmt_device)
            info1 = (r'$\bf{Data Parameters}$' + '\n{0} cells\n{1} genes'
                     .format(self.n_cells, self.n_genes))
            info2 = ('\n' + r'$\bf{MP\ distribution\ in\ data}$'
                     + '\n$\gamma={:0.2f}$ \n$\sigma^2={:1.2f}$\
                     \n$b_-={:2.2f}$\n$b_+={:3.2f}$'
                     .format(dic['gamma'], dic['s'], dic['b_minus'],
                             dic['b_plus']))

            n_components = self.n_components if self.n_components is not None else 0
            info3 = ('\n' + r'$\bf{Analysis}$' +
                     '\n{0} eigenvalues > $\lambda_c (3 \sigma)$\
                     \n{1} noise eigenvalues'
                     .format(n_components, self.n_cells - n_components))

            # KS test of the noise eigenvalues against the fitted MP CDF.
            if isinstance(self.L_mp, np.ndarray):
                cdf_func = calc._call_mp_cdf(self.L_mp, dic)
                ks = stats.kstest(self.L_mp, cdf_func)
            else:
                cdf_func = calc._call_mp_cdf(self.L_mp.get(), dic)
                ks = stats.kstest(self.L_mp.get(), cdf_func)

            info4 = '\n' + r'$\bf{Statistics}$' + '\nKS distance ={0}'\
                .format(round(ks[0], 4))\
                + '\nKS test p-value={0}'\
                .format(round(ks[1], 2))

            infoT = info1 + info2 + info4 + info3

            ax.text(1.05, 1.02, infoT, transform=ax.transAxes, fontsize=10,
                    verticalalignment='top', horizontalalignment='left',
                    bbox=dict(facecolor='wheat', alpha=0.8, boxstyle='round,pad=0.5'))

        # sns.distplot is deprecated -> histplot is used instead.  Move the
        # spectra to the host first so seaborn gets NumPy arrays.
        if not isinstance(self.L, np.ndarray):
            self.L = self.L.get()
        if not isinstance(self.Lr, np.ndarray):
            self.Lr = self.Lr.get()

        sns.histplot(self.L, bins=bins, stat="density", kde=False,
                     color=sns.xkcd_rgb["cornflower blue"], alpha=0.85, ax=ax)

        # MP density curve for the random part of the data.
        ax.plot(x, y, color=sns.xkcd_rgb["pale red"], lw=2,
                label="MP for random part in data")

        # Legend handles.
        MP_data = mlines.Line2D([], [], color=sns.xkcd_rgb["pale red"],
                                label="MP for random part in data", linewidth=2)
        MP_rand = mlines.Line2D([], [], color=sns.xkcd_rgb["sap green"],
                                label="MP for randomized data", linewidth=1.5, linestyle='--')
        randomized = mpatches.Patch(color=sns.xkcd_rgb["apple green"],
                                    label="Randomized data", alpha=0.75, linewidth=3, fill=False)
        data_real = mpatches.Patch(color=sns.xkcd_rgb["cornflower blue"],
                                   label="Real data", alpha=0.85)

        # Overlay the randomized spectrum when a comparison is requested.
        if comparison:
            sns.histplot(self.Lr, bins=30, kde=False, stat="density",
                         color=sns.xkcd_rgb["apple green"], alpha=0.75,
                         linewidth=3, ax=ax)
            ax.plot(x, yr, sns.xkcd_rgb["sap green"], lw=1.5, ls='--')
            ax.legend(handles=[data_real, MP_data, randomized, MP_rand],
                      loc="upper right", frameon=True)
        else:
            ax.legend(handles=[data_real, MP_data], loc="upper right", frameon=True)

        # x-axis range covering both spectra.  BUG FIX: self.Lr was just
        # converted to NumPy above, so calling cp.max on it (old gpu
        # branch) would fail; dispatch on the actual array types instead.
        max_Lr = np.max(self.Lr) if self.Lr is not None else 0
        if self.L_mp is None:
            max_L_mp = 0
        elif isinstance(self.L_mp, np.ndarray):
            max_L_mp = np.max(self.L_mp)
        else:
            max_L_mp = float(cp.max(self.L_mp))

        ax.set_xlim([0, int(np.round(max(max_Lr, max_L_mp) + 1.5))])

        # Grid style.
        ax.grid(linestyle='--', lw=0.3)

        if title:
            ax.set_title(title)

        ax.set_xlabel('First cell eigenvalues normalized distribution')

        if self.data is not None and isinstance(self.data, sc.AnnData):
            self.data.uns['mp_plot'] = fig

        # if path:
        #     plt.savefig(path, bbox_inches="tight")
        return fig
@@ -0,0 +1 @@
1
+ from . import scLENS_py
@@ -0,0 +1,182 @@
1
+ import scanpy as sc
2
+ import cupy as cp
3
+ import pandas as pd
4
+ import numpy as np
5
+ from scipy import stats, linalg
6
+ import scipy
7
+
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib import rcParams
10
+ import matplotlib.lines as mlines
11
+ import matplotlib.gridspec as gridspec
12
+ from matplotlib.offsetbox import AnchoredText
13
+ import matplotlib.patches as mpatches
14
+ import matplotlib.lines as mlines
15
+
16
+ import seaborn as sns
17
+
18
+ import psutil
19
+ import os
20
+
21
class Calc():
    """Marchenko-Pastur (MP) / Tracy-Widom helper computations used by the
    PCA class.  ``rmt_device`` arguments select the array backend: 'gpu'
    uses CuPy, anything else uses NumPy.
    """

    def __init__(self, device=None, data=None, L=None, L_mp=None):
        # Full eigenvalue spectrum (set by the caller before _tw).
        self.L = []
        self.V = None
        # Noise (MP-distributed) part of the spectrum.
        self.L_mp = None
        self.explained_variance_ = []
        self.total_variance_ = []
        self.device = device
        self.data = data

    def style_mp_stat(self):
        """Apply the plotting style used for the MP statistics figures."""
        plt.style.use("ggplot")
        np.seterr(invalid='ignore')
        sns.set_style("white")
        sns.set_context("paper")
        sns.set_palette("deep")
        plt.rcParams['axes.linewidth'] = 0.5
        plt.rcParams['figure.dpi'] = 100

    def _tw(self, rmt_device):
        """Tracy-Widom critical eigenvalue (3-sigma upper bound on the
        noise spectrum).  Requires self.L and self.L_mp to be set.

        Raises
        ------
        ValueError
            If self.L or self.L_mp is empty/None.
        """
        if self.L is None or len(self.L) == 0:
            raise ValueError("self.L must not be empty or None.")
        if self.L_mp is None or len(self.L_mp) == 0:
            raise ValueError("self.L_mp must not be empty or None.")

        gamma = self._mp_parameters(self.L_mp, rmt_device)['gamma']
        p = len(self.L) / gamma

        # Same formula on either backend; only the array library differs.
        xp = np if rmt_device == 'cpu' else cp
        sigma = 1 / xp.power(p, 2/3) * xp.power(gamma, 5/6) * \
            xp.power((1 + xp.sqrt(gamma)), 4/3)
        lambda_c = xp.mean(self.L_mp) * (1 + xp.sqrt(gamma)) ** 2 + sigma
        return lambda_c

    def _mp_parameters(self, L, rmt_device):
        """Fit MP distribution parameters to the eigenvalues L.

        Returns a dict with the first two moments, shape parameter gamma,
        support edges (b_minus, b_plus), scale s, density peak location
        and sigma.  The cpu/gpu branches were identical except for the
        array library, so they are deduplicated via ``xp``.
        """
        xp = cp if rmt_device == 'gpu' else np
        moment_1 = xp.mean(L)
        moment_2 = xp.mean(xp.power(L, 2))
        gamma = moment_2 / float(moment_1**2) - 1
        s = moment_1
        sigma = moment_2
        b_plus = s * (1 + xp.sqrt(gamma))**2
        b_minus = s * (1 - xp.sqrt(gamma))**2
        x_peak = s * (1.0 - gamma)**2.0 / (1.0 + gamma)
        return {
            'moment_1': moment_1,
            'moment_2': moment_2,
            'gamma': gamma,
            'b_plus': b_plus,
            'b_minus': b_minus,
            's': s,
            'peak': x_peak,
            'sigma': sigma,
        }

    def _marchenko_pastur(self, x, dic):
        '''MP probability density at x for fitted parameters dic.'''
        pdf = np.sqrt((dic['b_plus'] - x) * (x - dic['b_minus']))\
            / float(2 * dic['s'] * np.pi * dic['gamma'] * x)
        return pdf

    def _mp_pdf(self, x, L, rmt_device):
        '''MP PDF evaluated on the grid x for eigenvalues L.'''
        dic = self._mp_parameters(L, rmt_device)
        y = np.empty_like(x) if rmt_device == 'cpu' else cp.empty_like(x)
        for i, xi in enumerate(x):
            y[i] = self._marchenko_pastur(xi, dic)
        return y

    def _mp_calculation(self, L, Lr, rmt_device, eta=1, eps=10**-6, max_iter=1000):
        """Iteratively fit the MP support to the data spectrum L, starting
        from the support of the randomized spectrum Lr.

        Side effects: stores the current support edges in self.b_plus and
        self.b_minus (read later by PCA.fit).  Returns the eigenvalues of
        L inside the fitted support — the 'noise' part of the spectrum.
        """
        converged = False
        n_iter = 0  # renamed from `iter`, which shadowed the builtin
        loss_history = []

        b_plus = self._mp_parameters(Lr, rmt_device)['b_plus']
        b_minus = self._mp_parameters(Lr, rmt_device)['b_minus']

        L_updated = L[(L > b_minus) & (L < b_plus)]
        new_b_plus = self._mp_parameters(L_updated, rmt_device)['b_plus']
        new_b_minus = self._mp_parameters(L_updated, rmt_device)['b_minus']

        # BUG FIX: make sure the attributes exist even when the loop
        # converges on the very first pass (callers read calc.b_minus).
        self.b_plus = new_b_plus
        self.b_minus = new_b_minus

        while not converged:
            loss = (1 - float(new_b_plus) / float(b_plus))**2
            loss_history.append(loss)
            n_iter += 1

            if loss <= eps:
                converged = True
            elif n_iter == max_iter:
                # message fixed: 'interactions' -> 'iterations'
                print('Max iterations exceeded!')
                converged = True
            else:
                gradient = new_b_plus - b_plus
                new_b_plus = b_plus + eta * gradient

                L_updated = L[(L > new_b_minus) & (L < new_b_plus)]
                self.b_plus = new_b_plus
                self.b_minus = new_b_minus

                new_b_plus = self._mp_parameters(L_updated, rmt_device)['b_plus']
                new_b_minus = self._mp_parameters(L_updated, rmt_device)['b_minus']

        xp = np if rmt_device == 'cpu' else cp
        indices = xp.where((L > new_b_minus) & (L < new_b_plus))
        return xp.array(L[indices])

    def _cdf_marchenko(self, x, dic):
        """Closed-form MP CDF at scalar x for fitted parameters dic."""
        if x <= dic['b_minus']:
            # FIX: the support edges now map to 0/1 exactly; previously
            # x == b_minus fell through to the final branch -> 1.0.
            return 0.0
        elif x < dic['b_plus']:
            return 1/float(2*dic['s']*np.pi*dic['gamma'])*\
                float(np.sqrt((dic['b_plus']-x)*(x-dic['b_minus']))+\
                (dic['b_plus']+dic['b_minus'])/2*np.arcsin((2*x-dic['b_plus']-\
                dic['b_minus'])/(dic['b_plus']-dic['b_minus']))-\
                np.sqrt(dic['b_plus']*dic['b_minus'])*np.arcsin(((dic['b_plus']+\
                dic['b_minus'])*x-2*dic['b_plus']*dic['b_minus'])/\
                ((dic['b_plus']-dic['b_minus'])*x)) )+np.arcsin(1)/np.pi
        else:
            return 1.0

    def _call_mp_cdf(self, L, dic):
        """Return a vectorized MP CDF: maps an iterable of x values to a
        list of CDF values (the callable shape scipy.stats.kstest expects)."""
        return lambda y: list(map(lambda x: self._cdf_marchenko(x, dic), y))