sciv 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sciv/model/_core_.py CHANGED
@@ -1,15 +1,12 @@
1
1
  # -*- coding: UTF-8 -*-
2
2
 
3
3
  import os.path
4
- import shutil
5
4
  import time
6
- from typing import Optional, Union, Literal, Tuple
5
+ from typing import Optional, Union, Literal
7
6
 
8
7
  import numpy as np
9
8
 
10
- import anndata as ad
11
9
  import pandas as pd
12
- from tqdm import tqdm
13
10
  from anndata import AnnData
14
11
  from pandas import DataFrame
15
12
 
@@ -24,6 +21,8 @@ __name__: str = "model_core"
24
21
 
25
22
 
26
23
  def _run_random_walk_(random_walk: RandomWalk, is_ablation: bool, is_simple: bool) -> AnnData:
24
+ start_time = time.time()
25
+
27
26
  if not random_walk.is_run_core:
28
27
  random_walk.run_core()
29
28
 
@@ -58,11 +57,15 @@ def _run_random_walk_(random_walk: RandomWalk, is_ablation: bool, is_simple: boo
58
57
  if not random_walk.is_run_en_ablation_m_knn:
59
58
  random_walk.run_en_ablation_m_knn()
60
59
 
60
+ random_walk.elapsed_time += time.time() - start_time
61
+
61
62
  return random_walk.trs_adata
62
63
 
63
64
 
64
- def _check_and_run_two_step_(
65
+ def core(
65
66
  adata: AnnData,
67
+ variants: dict,
68
+ trait_info: DataFrame,
66
69
  cell_rate: Optional[float] = None,
67
70
  peak_rate: Optional[float] = None,
68
71
  max_epochs: int = 500,
@@ -76,7 +79,9 @@ def _check_and_run_two_step_(
76
79
  k: int = 30,
77
80
  or_k: int = 1,
78
81
  weight: float = 0.1,
79
- laplacian_gamma: Optional[float] = None,
82
+ kernel: Literal["laplacian", "gaussian"] = "gaussian",
83
+ local_k: int = 10,
84
+ kernel_gamma: Optional[Union[float, collection]] = None,
80
85
  epsilon: float = 1e-05,
81
86
  gamma: float = 0.05,
82
87
  enrichment_gamma: float = 0.05,
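For context on the new `kernel`, `local_k`, and `kernel_gamma` parameters introduced in this hunk: the docstring further down states that when `kernel_gamma` is None an adaptive bandwidth is derived from the `local_k` nearest neighbours. The following is only a minimal NumPy sketch of that general idea, not sciv's actual implementation; the per-cell bandwidth rule and the exact kernel forms are assumptions.

```python
import numpy as np

def similarity_kernel(dist, kernel="gaussian", local_k=10, gamma=None):
    """Toy adaptive kernel on a dense cell-cell distance matrix `dist`."""
    if gamma is None:
        # Assumed adaptive rule: one bandwidth per cell, taken from the
        # distance to its `local_k`-th nearest neighbour.
        knn_dist = np.sort(dist, axis=1)[:, local_k]
        gamma = 1.0 / np.maximum(knn_dist, 1e-12)
    gamma = np.asarray(gamma).reshape(-1, 1)
    if kernel == "gaussian":
        return np.exp(-gamma * dist ** 2)
    if kernel == "laplacian":
        return np.exp(-gamma * dist)
    raise ValueError("kernel must be 'laplacian' or 'gaussian'")

# Tiny example: 5 random points in 2-D
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))
d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
print(similarity_kernel(d, kernel="gaussian", local_k=3).round(3))
```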
@@ -94,9 +99,105 @@ def _check_and_run_two_step_(
94
99
  is_save_random_walk_model: bool = False,
95
100
  is_file_exist_loading: bool = False,
96
101
  filename_dict: Optional[dict] = None,
97
- single_chunk_size: int = 500,
98
102
  block_size: int = -1
99
- ) -> Tuple[dict, dict, str, bool, str, str, str, str, AnnData, AnnData, AnnData]:
103
+ ) -> AnnData:
104
+ """
105
+ The core algorithm of sciv: it runs the complete analysis flow, including plotting and saving data.
106
+ Throughout the algorithm the samples occupy the rows and the traits or diseases occupy the columns,
107
+ and the traits or diseases are processed independently of one another,
108
+ which keeps the results stable;
109
+ Meaning of main variables:
110
+ 1. `overlap_adata`, (obs: peaks, var: traits/diseases) Peaks-traits/diseases data obtained by overlaying variant
111
+ data with peaks.
112
+ 2. `da_peaks`, (obs: clusters (Leiden), var: peaks) Differential peak data of cell clustering, used for weight
113
+ correction of cells.
114
+ 3. `init_score`, (obs: cells, var: traits/diseases) This is the initial TRS data.
115
+ 4. `cc_data`, (obs: cells, var: cells) Cell similarity data.
116
+ 5. `random_walk`, RandomWalk class.
117
+ 6. `trs`, (obs: cells, var: traits/diseases) This is the final TRS data.
118
+ :param adata: scATAC-seq data;
119
+ :param variants: variant data; it is recommended to obtain this data with the `fl.read_variants` method.
120
+ :param trait_info: variant annotation file information;
121
+ :param cell_rate: Removing the percentage of cell count in total cell count only takes effect when the min_cells
122
+ parameter is None;
123
+ :param peak_rate: Removing the percentage of peak count in total peak count only takes effect when the min_peaks
124
+ parameter is None;
125
+ :param max_epochs: The maximum number of epochs for PoissonVI training;
126
+ :param lr: Learning rate for optimization;
127
+ :param batch_size: Minibatch size to use during training;
128
+ :param eps: Optimizer eps;
129
+ :param early_stopping: Whether to perform early stopping with respect to the validation set;
130
+ :param early_stopping_patience: How many epochs to wait for improvement before early stopping;
131
+ :param batch_key: Batch information in scATAC-seq data;
132
+ :param resolution: Resolution of the Leiden Cluster. The recommended values are any one of 0.4, 0.9, 1.3, 1.5;
133
+ :param k: When building an mKNN network, the number of nodes connected by each node (and operation);
134
+ :param or_k: When building an mKNN network, the number of nodes connected by each node (or operation);
135
+ :param weight: The weight assigned to connections added by the `or` operation;
136
+ :param local_k: The number of neighbors used to derive the adaptive kernel bandwidth;
137
+ :param kernel: The kernel function to use;
138
+ :param kernel_gamma: If None, an adaptive value is derived from the local neighborhood information given by
139
+ parameter `local_k`. Otherwise, it must be strictly positive;
140
+ :param epsilon: conditions for stopping in random walk;
141
+ :param gamma: reset weight for random walk;
142
+ :param enrichment_gamma: reset weight for random walk for enrichment;
143
+ :param p: Distance used for loss {1: Manhattan distance, 2: Euclidean distance};
144
+ :param n_jobs: The maximum number of concurrently running jobs;
145
+ :param min_seed_cell_rate: The minimum percentage of seed cells in all cells;
146
+ :param max_seed_cell_rate: The maximum percentage of seed cells in all cells;
147
+ :param credible_threshold: The threshold for determining the credibility of enriched cells in the context of
148
+ enrichment, i.e. the threshold for judging enriched cells;
149
+ :param diff_peak_value: Specify the correction value in peak correction of clustering type differences.
150
+ {'emp_effect', 'bayes_factor', 'emp_prob1'}
151
+ :param enrichment_threshold: Threshold applied to the standardized output TRS; only the portion above it is
152
+ kept in the enrichment results. Supported string values are {'golden', 'half', 'e', 'pi', 'none'},
153
+ or a valid floating-point value within the range of (0, log1p(1)).
154
+ :param is_ablation: True represents obtaining the results of the ablation experiment. This parameter is limited by
155
+ the `is_simple` parameter, and its effectiveness requires setting `is_simple` to `False`;
156
+ :param model_dir: The folder name saved by the training module;
157
+ It is worth noting that if the training model file (`model.pt`) exists in this path, it will be automatically
158
+ read and skip the training of `PoissonVI` model.
159
+ :param save_path: Save path for process files and result files;
160
+ :param is_simple: True means unnecessary intermediate variables are not stored and only the final result is added.
161
+ Note that when this is set to `True` the `is_ablation` parameter is ignored; `is_ablation` only takes
162
+ effect when `is_simple` is `False`;
163
+ :param is_save_random_walk_model: Default to `False`, do not save random walk model. When setting `True`, please
164
+ ensure sufficient storage as the saved `pkl` file is relatively large.
165
+ :param is_file_exist_loading: By default, the file will be overwritten. When set to `True`, if the file exists, the
166
+ process will be skipped and the file will be directly read as the result;
167
+ :param filename_dict: The name of the file that exists.
168
+ default: {
169
+ "sc_atac": "sc_atac.h5ad",
170
+ "da_peaks": "da_peaks.h5ad",
171
+ "atac_overlap": "atac_overlap.h5ad",
172
+ "init_score": "init_score.h5ad",
173
+ "cc_data": "cc_data.h5ad",
174
+ "random_walk": "random_walk.h5ad",
175
+ "trs": "trs.h5ad"
176
+ }
177
+ :param block_size: The block size used for block-wise matrix multiplication.
178
+ It trades some time and disk space to reduce peak memory consumption.
179
+ If the value is less than or equal to zero, no blocking is performed.
180
+ :return: `trs`, (obs: cells, var: traits/diseases) This is the final TRS data.
181
+ """
182
+
183
+ # start time
184
+ start_time = time.time()
185
+
186
+ if len(variants.keys()) == 0:
187
+ ul.log(__name__).error("The number of mutations is empty.")
188
+ raise ValueError("The number of mutations is empty.")
189
+
190
+ _trait_count_ = trait_info.shape[0]
191
+
192
+ if len(variants.keys()) != _trait_count_:
193
+ ul.log(__name__).error(
194
+ "The parameters `variants` and `trait_info` are inconsistent. "
195
+ "These two parameters can be obtained using method `fl.read_variants`."
196
+ )
197
+ raise ValueError(
198
+ "The parameters `variants` and `trait_info` are inconsistent. "
199
+ "These two parameters can be obtained using method `fl.read_variants`."
200
+ )
100
201
 
101
202
  if adata.shape[0] == 0:
102
203
  ul.log(__name__).error("The scATAC-seq data is empty.")
@@ -115,8 +216,12 @@ def _check_and_run_two_step_(
115
216
  )
116
217
 
117
218
  if batch_key is not None and batch_key not in adata.obs.columns:
118
- ul.log(__name__).error(f"The cells information {adata.obs.columns} in data `adata` must include the {batch_key} column.")
119
- raise ValueError(f"The cells information {adata.obs.columns} in data `adata` must include the {batch_key} column.")
219
+ ul.log(__name__).error(
220
+ f"The cells information {adata.obs.columns} in data `adata` must include the {batch_key} column."
221
+ )
222
+ raise ValueError(
223
+ f"The cells information {adata.obs.columns} in data `adata` must include the {batch_key} column."
224
+ )
120
225
 
121
226
  if cell_rate is not None:
122
227
 
@@ -130,10 +235,6 @@ def _check_and_run_two_step_(
130
235
  ul.log(__name__).error("The parameter of `peak_rate` should be between 0 and 1.")
131
236
  raise ValueError("The parameter of `peak_rate` should be between 0 and 1.")
132
237
 
133
- if single_chunk_size <= 0:
134
- ul.log(__name__).error("The parameter `single_chunk_size` must be greater than zero.")
135
- raise ValueError("The parameter `single_chunk_size` must be greater than zero.")
136
-
137
238
  if resolution <= 0:
138
239
  ul.log(__name__).error("The parameter `resolution` must be greater than zero.")
139
240
  raise ValueError("The parameter `resolution` must be greater than zero.")
@@ -152,6 +253,14 @@ def _check_and_run_two_step_(
152
253
  "which is highly likely to result in poor performance."
153
254
  )
154
255
 
256
+ if local_k <= 0:
257
+ ul.log(__name__).error("The `local_k` parameter must be a natural number greater than 0.")
258
+ raise ValueError("The `local_k` parameter must be a natural number greater than 0.")
259
+
260
+ if kernel not in ["laplacian", "gaussian"]:
261
+ ul.log(__name__).error("Parameter `kernel` only supports two values, `laplacian` and `gaussian`.")
262
+ raise ValueError("Parameter `kernel` only supports two values, `laplacian` and `gaussian`.")
263
+
155
264
  if weight < 0 or weight > 1:
156
265
  ul.log(__name__).error("The parameter of `weight` should be between 0 and 1.")
157
266
  raise ValueError("The parameter of `weight` should be between 0 and 1.")
@@ -190,36 +299,40 @@ def _check_and_run_two_step_(
190
299
  if isinstance(enrichment_threshold, float):
191
300
 
192
301
  if enrichment_threshold <= 0 or enrichment_threshold >= np.log1p(1):
193
- ul.log(__name__).warning("The `enrichment_threshold` parameter is not set within the range of (0, log1p(1)), this parameter will become invalid.")
194
- ul.log(__name__).warning("It is recommended to set the `enrichment_threshold` parameter to the 'golden' value.")
302
+ ul.log(__name__).warning(
303
+ "The `enrichment_threshold` parameter is not set within the range of (0, log1p(1)), this parameter "
304
+ "will become invalid."
305
+ )
306
+ ul.log(__name__).warning(
307
+ "It is recommended to set the `enrichment_threshold` parameter to the 'golden' value."
308
+ )
195
309
 
196
310
  elif enrichment_threshold not in ["golden", "half", "e", "pi", "none"]:
311
+ ul.log(__name__).error(
312
+ "Invalid enrichment settings. The string type in the `enrichment_threshold` parameter only supports the "
313
+ "following parameter 'golden', 'half', 'e', 'pi', 'none', Alternatively, input a floating-point type "
314
+ "value within the range of (0, log1p(1))"
315
+ )
197
316
  raise ValueError(
198
- f"Invalid enrichment settings. The string type in the `enrichment_threshold` parameter only supports the following parameter "
199
- f"'golden', 'half', 'e', 'pi', 'none', Alternatively, input a floating-point type value within the range of (0, log1p(1))"
317
+ "Invalid enrichment settings. The string type in the `enrichment_threshold` parameter only supports the "
318
+ "following parameter 'golden', 'half', 'e', 'pi', 'none', Alternatively, input a floating-point type "
319
+ "value within the range of (0, log1p(1))"
200
320
  )
201
321
 
202
322
  if diff_peak_value not in ['emp_effect', 'bayes_factor', 'emp_prob1', 'all']:
203
- ul.log(__name__).error("The `diff_peak_value` parameter only supports one of the {'emp_effect', 'bayes_factor', 'emp_prob1', 'all'} values.")
204
- raise ValueError("The `diff_peak_value` parameter only supports one of the {'emp_effect', 'bayes_factor', 'emp_prob1', 'all'} values.")
205
-
206
- # get cache path
207
- cache_path = str(ul.project_cache_path)
208
- ul.file_method(__name__).makedirs(cache_path)
209
-
210
- # Assign a name to the formed document
211
- cache_path_dict: dict = {
212
- "atac_overlap": os.path.join(cache_path, "atac_overlap"),
213
- "init_score": os.path.join(cache_path, "init_score"),
214
- "random_walk": os.path.join(cache_path, "random_walk"),
215
- "trs": os.path.join(cache_path, "trs")
216
- }
323
+ ul.log(__name__).error(
324
+ "The `diff_peak_value` parameter only supports one of the "
325
+ "{'emp_effect', 'bayes_factor', 'emp_prob1', 'all'} values."
326
+ )
327
+ raise ValueError(
328
+ "The `diff_peak_value` parameter only supports one of the "
329
+ "{'emp_effect', 'bayes_factor', 'emp_prob1', 'all'} values."
330
+ )
217
331
 
218
332
  # parameter information
219
333
  params: dict = {
220
334
  "cell_rate": cell_rate,
221
335
  "peak_rate": peak_rate,
222
- "single_chunk_size": single_chunk_size,
223
336
  "max_epochs": int(max_epochs),
224
337
  "lr": lr,
225
338
  "batch_size": batch_size,
@@ -231,7 +344,9 @@ def _check_and_run_two_step_(
231
344
  "k": k,
232
345
  "or_k": or_k,
233
346
  "weight": weight,
234
- "laplacian_gamma": laplacian_gamma,
347
+ "kernel": kernel,
348
+ "local_k": local_k,
349
+ "kernel_gamma": kernel_gamma,
235
350
  "epsilon": epsilon,
236
351
  "gamma": gamma,
237
352
  "enrichment_gamma": enrichment_gamma,
@@ -245,7 +360,6 @@ def _check_and_run_two_step_(
245
360
  "is_ablation": is_ablation,
246
361
  "model_dir": str(model_dir),
247
362
  "save_path": str(save_path),
248
- "cache_path": str(cache_path),
249
363
  "is_simple": is_simple,
250
364
  "is_save_random_walk_model": is_save_random_walk_model,
251
365
  "is_file_exist_loading": is_file_exist_loading,
@@ -274,14 +388,17 @@ def _check_and_run_two_step_(
274
388
  "trs": "trs.h5ad"
275
389
  }
276
390
 
391
+ def _get_file_(_id_: str) -> str:
392
+ return os.path.join(save_path, str(filename_dict.get(_id_, f"{_id_}.h5ad"))) if save_path else None
393
+
277
394
  # Assign a name to the formed document
278
- adata_save_file = os.path.join(save_path, "sc_atac.h5ad" if "sc_atac" not in filename_dict else str(filename_dict["sc_atac"])) if save_path is not None else None
279
- da_peaks_save_file = os.path.join(save_path, "da_peaks.h5ad" if "da_peaks" not in filename_dict else str(filename_dict["da_peaks"])) if save_path is not None else None
280
- atac_overlap_save_file = os.path.join(save_path, "atac_overlap.h5ad" if "atac_overlap" not in filename_dict else str(filename_dict["atac_overlap"])) if save_path is not None else None
281
- init_score_save_file = os.path.join(save_path, "init_score.h5ad" if "init_score" not in filename_dict else str(filename_dict["init_score"])) if save_path is not None else None
282
- cc_data_save_file = os.path.join(save_path, "cc_data.h5ad" if "cc_data" not in filename_dict else str(filename_dict["cc_data"])) if save_path is not None else None
283
- random_walk_save_file = os.path.join(save_path, "random_walk.pkl" if "random_walk" not in filename_dict else str(filename_dict["random_walk"])) if save_path is not None else None
284
- trs_save_file = os.path.join(save_path, "trs.h5ad" if "trs" not in filename_dict else str(filename_dict["trs"])) if save_path is not None else None
395
+ adata_save_file = _get_file_("sc_atac")
396
+ da_peaks_save_file = _get_file_("da_peaks")
397
+ atac_overlap_save_file = _get_file_("atac_overlap")
398
+ init_score_save_file = _get_file_("init_score")
399
+ cc_data_save_file = _get_file_("cc_data")
400
+ random_walk_save_file = _get_file_("random_walk")
401
+ trs_save_file = _get_file_("trs")
285
402
 
286
403
  """
287
404
  1. Filter scATAC-seq data, PoissonVI
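The `_get_file_` helper introduced in this hunk replaces the seven repetitive path expressions: it joins `save_path` with the name taken from `filename_dict`, falls back to the `"{_id_}.h5ad"` pattern, and yields None when no `save_path` is given. A standalone sketch of the same resolution logic, with hypothetical values for illustration only:

```python
import os

save_path = "results"                    # hypothetical
filename_dict = {"trs": "my_trs.h5ad"}   # user override for one key only

def _get_file_(_id_: str):
    return os.path.join(save_path, str(filename_dict.get(_id_, f"{_id_}.h5ad"))) if save_path else None

print(_get_file_("trs"))         # results/my_trs.h5ad   (taken from filename_dict)
print(_get_file_("init_score"))  # results/init_score.h5ad  (default pattern)
```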
@@ -291,6 +408,7 @@ def _check_and_run_two_step_(
291
408
  da_peaks_is_read: bool = False
292
409
 
293
410
  if is_file_exist_loading:
411
+
294
412
  if os.path.exists(adata_save_file):
295
413
  adata = read_h5ad(adata_save_file)
296
414
  adata_is_read = True
@@ -321,6 +439,7 @@ def _check_and_run_two_step_(
321
439
 
322
440
  else:
323
441
  filter_data(adata, cell_rate=cell_rate, peak_rate=peak_rate)
442
+
324
443
  da_peaks = poisson_vi(
325
444
  adata,
326
445
  max_epochs=max_epochs,
@@ -334,7 +453,10 @@ def _check_and_run_two_step_(
334
453
  model_dir=model_dir
335
454
  )
336
455
 
456
+ step1_time = adata.uns["elapsed_time"] + da_peaks.uns["elapsed_time"]
457
+
337
458
  if save_path is not None:
459
+
338
460
  if not adata_is_read:
339
461
  save_h5ad(adata, file=adata_save_file)
340
462
 
@@ -342,599 +464,140 @@ def _check_and_run_two_step_(
342
464
  save_h5ad(da_peaks, file=da_peaks_save_file)
343
465
 
344
466
  """
345
- 2. Calculate cell-cell correlation. Building a network between cells.
467
+ 2. Overlap the region (peak) data with the variant data and sum the PP values of all variants
468
+ within a region to obtain that region's value
346
469
  """
347
470
 
348
- cc_data_is_read: bool = is_file_exist_loading and os.path.exists(cc_data_save_file)
349
-
350
- if cc_data_is_read:
351
- cc_data: AnnData = read_h5ad(cc_data_save_file)
352
- else:
353
- # cell-cell network
354
- cc_data = obtain_cell_cell_network(adata=adata, k=k, or_k=or_k, weight=weight, gamma=laplacian_gamma, is_simple=is_simple)
355
-
356
- if save_path is not None and not cc_data_is_read:
357
- save_h5ad(cc_data, file=cc_data_save_file)
358
-
359
- return (
360
- cache_path_dict, params, save_path, is_file_exist_loading,
361
- atac_overlap_save_file, init_score_save_file, random_walk_save_file, trs_save_file,
362
- adata, da_peaks, cc_data
363
- )
364
-
471
+ # Determine whether it is necessary to read the file
472
+ overlap_is_read: bool = is_file_exist_loading and os.path.exists(atac_overlap_save_file)
365
473
 
366
- def core(
367
- adata: AnnData,
368
- variants: dict,
369
- trait_info: DataFrame,
370
- cell_rate: Optional[float] = None,
371
- peak_rate: Optional[float] = None,
372
- max_epochs: int = 500,
373
- lr: float = 1e-4,
374
- batch_size: int = 128,
375
- eps: float = 1e-08,
376
- early_stopping: bool = True,
377
- early_stopping_patience: int = 50,
378
- batch_key: Optional[str] = None,
379
- resolution: float = 0.5,
380
- k: int = 30,
381
- or_k: int = 1,
382
- weight: float = 0.1,
383
- laplacian_gamma: Optional[float] = None,
384
- epsilon: float = 1e-05,
385
- gamma: float = 0.05,
386
- enrichment_gamma: float = 0.05,
387
- p: int = 2,
388
- n_jobs: int = -1,
389
- min_seed_cell_rate: float = 0.01,
390
- max_seed_cell_rate: float = 0.05,
391
- credible_threshold: float = 0,
392
- diff_peak_value: difference_peak_optional = 'emp_effect',
393
- enrichment_threshold: Union[enrichment_optional, float] = 'golden',
394
- is_ablation: bool = False,
395
- model_dir: Optional[path] = None,
396
- save_path: Optional[path] = None,
397
- is_simple: bool = True,
398
- is_save_random_walk_model: bool = False,
399
- is_file_exist_loading: bool = False,
400
- filename_dict: Optional[dict] = None,
401
- single_chunk_size: int = 500,
402
- block_size: int = -1
403
- ) -> AnnData:
404
- """
405
- The core algorithm of sciv includes the flow of all algorithms, as well as drawing and saving data.
406
- In the entire algorithm, the samples are in the row position, and the traits or diseases are in the column position,
407
- while ensuring that there is no interaction between the traits or diseases,
408
- ensuring the stability of the results;
409
- Meaning of main variables:
410
- 1. `overlap_adata`, (obs: peaks, var: traits/diseases) Peaks-traits/diseases data obtained by overlaying variant
411
- data with peaks.
412
- 2. `da_peaks`, (obs: clusters (Leiden), var: peaks) Differential peak data of cell clustering, used for weight
413
- correction of cells.
414
- 3. `init_score`, (obs: cells, var: traits/diseases) This is the initial TRS data.
415
- 4. `cc_data`, (obs: cells, var: cells) Cell similarity data.
416
- 5. `random_walk`, RandomWalk class.
417
- 6. `trs`, (obs: cells, var: traits/diseases) This is the final TRS data.
418
- :param adata: scATAC-seq data;
419
- :param variants: variant data; This data is recommended to be obtained by executing the `fl.read_variants` method.
420
- :param trait_info: variant annotation file information;
421
- :param cell_rate: Removing the percentage of cell count in total cell count only takes effect when the min_cells
422
- parameter is None;
423
- :param peak_rate: Removing the percentage of peak count in total peak count only takes effect when the min_peaks
424
- parameter is None;
425
- :param max_epochs: The maximum number of epochs for PoissonVI training;
426
- :param lr: Learning rate for optimization;
427
- :param batch_size: Minibatch size to use during training;
428
- :param eps: Optimizer eps;
429
- :param early_stopping: Whether to perform early stopping with respect to the validation set;
430
- :param early_stopping_patience: How many epochs to wait for improvement before early stopping;
431
- :param batch_key: Batch information in scATAC-seq data;
432
- :param resolution: Resolution of the Leiden Cluster. The recommended values are any one of 0.4, 0.9, 1.3, 1.5;
433
- :param k: When building an mKNN network, the number of nodes connected by each node (and operation);
434
- :param or_k: When building an mKNN network, the number of nodes connected by each node (or operation);
435
- :param weight: The weight of interactions or operations;
436
- :param laplacian_gamma: If None, defaults to 1.0 / n_features. Otherwise, it should be strictly positive;
437
- :param epsilon: conditions for stopping in random walk;
438
- :param gamma: reset weight for random walk;
439
- :param enrichment_gamma: reset weight for random walk for enrichment;
440
- :param p: Distance used for loss {1: Manhattan distance, 2: Euclidean distance};
441
- :param n_jobs: The maximum number of concurrently running jobs;
442
- :param min_seed_cell_rate: The minimum percentage of seed cells in all cells;
443
- :param max_seed_cell_rate: The maximum percentage of seed cells in all cells;
444
- :param credible_threshold: The threshold for determining the credibility of enriched cells in the context of
445
- enrichment, i.e. the threshold for judging enriched cells;
446
- :param diff_peak_value: Specify the correction value in peak correction of clustering type differences.
447
- {'emp_effect', 'bayes_factor', 'emp_prob1'}
448
- :param enrichment_threshold: Only by setting a threshold for the standardized output TRS can a portion of the enrichment
449
- results be obtained. Parameters support string types {'golden', 'half', 'e', 'pi', 'none'}, or valid floating-point types
450
- within the range of (0, log1p(1)).
451
- :param is_ablation: True represents obtaining the results of the ablation experiment. This parameter is limited by
452
- the `is_simple` parameter, and its effectiveness requires setting `is_simple` to `False`;
453
- :param model_dir: The folder name saved by the training module;
454
- It is worth noting that if the training model file (`model.pt`) exists in this path, it will be automatically read and skip
455
- the training of `PoissonVI` model.
456
- :param save_path: Save path for process files and result files;
457
- :param is_simple: True represents not adding unnecessary intermediate variables, only adding the final result.
458
- It is worth noting that when set to `True`, the `is_ablation` parameter will become invalid, and when set to
459
- `False`, `is_ablation` will only take effect;
460
- :param is_save_random_walk_model: Default to `False`, do not save random walk model. When setting `True`, please
461
- ensure sufficient storage as the saved `pkl` file is relatively large.
462
- :param is_file_exist_loading: By default, the file will be overwritten. When set to `True`, if the file exists, the
463
- process will be skipped and the file will be directly read as the result;
464
- :param single_chunk_size: The size of a single chunk;
465
- :param filename_dict: The name of the file that exists.
466
- default: {
467
- "sc_atac": "sc_atac.h5ad",
468
- "da_peaks": "da_peaks.h5ad",
469
- "atac_overlap": "atac_overlap.h5ad",
470
- "init_score": "init_score.h5ad",
471
- "cc_data": "cc_data.h5ad",
472
- "random_walk": "random_walk.h5ad",
473
- "trs": "trs.h5ad"
474
- }
475
- :param block_size: The size of the segmentation stored in block wise matrix multiplication.
476
- By sacrificing time and space to reduce memory consumption to a certain extent.
477
- If the value is less than or equal to zero, no block operation will be performed.
478
- :return: `trs`, (obs: cells, var: traits/diseases) This is the final TRS data.
479
- """
474
+ if overlap_is_read:
475
+ overlap_adata: AnnData = read_h5ad(atac_overlap_save_file)
480
476
 
481
- # start time
482
- start_time = time.time()
477
+ if overlap_adata.var.shape[0] != _trait_count_:
478
+ ul.log(__name__).warning(
479
+ f"The number of diseases read from file `atac_overlap.h5ad` are inconsistent with the input ({overlap_adata.var.shape[0]} != {_trait_count_}). "
480
+ f"Please check and verify. If the verification is not as expected, file `atac_overlap.h5ad` needs to be moved or deleted."
481
+ )
483
482
 
484
- if len(variants.keys()) == 0:
485
- ul.log(__name__).error("The number of mutations is empty.")
486
- raise ValueError("The number of mutations is empty.")
483
+ else:
484
+ overlap_adata: AnnData = overlap_sum(adata, variants, trait_info)
487
485
 
488
- _trait_count_ = trait_info.shape[0]
486
+ del variants, trait_info
489
487
 
490
- if len(variants.keys()) != _trait_count_:
491
- ul.log(__name__).error(
492
- "The parameters `variants` and `trait_info` are inconsistent. "
493
- "These two parameters can be obtained using method `fl.read_variants`."
494
- )
495
- raise ValueError(
496
- "The parameters `variants` and `trait_info` are inconsistent. "
497
- "These two parameters can be obtained using method `fl.read_variants`."
498
- )
488
+ step2_time = overlap_adata.uns["elapsed_time"]
499
489
 
500
- (
501
- cache_path_dict, params, save_path, is_file_exist_loading,
502
- atac_overlap_save_file, init_score_save_file, random_walk_save_file, trs_save_file,
503
- adata, da_peaks, cc_data
504
- ) = _check_and_run_two_step_(
505
- adata=adata,
506
- cell_rate=cell_rate,
507
- peak_rate=peak_rate,
508
- max_epochs=max_epochs,
509
- lr=lr,
510
- batch_size=batch_size,
511
- eps=eps,
512
- early_stopping=early_stopping,
513
- early_stopping_patience=early_stopping_patience,
514
- batch_key=batch_key,
515
- resolution=resolution,
516
- k=k,
517
- or_k=or_k,
518
- weight=weight,
519
- laplacian_gamma=laplacian_gamma,
520
- epsilon=epsilon,
521
- gamma=gamma,
522
- p=p,
523
- n_jobs=n_jobs,
524
- min_seed_cell_rate=min_seed_cell_rate,
525
- max_seed_cell_rate=max_seed_cell_rate,
526
- credible_threshold=max_seed_cell_rate,
527
- enrichment_threshold=enrichment_threshold,
528
- diff_peak_value=diff_peak_value,
529
- is_ablation=is_ablation,
530
- model_dir=model_dir,
531
- save_path=save_path,
532
- is_simple=is_simple,
533
- is_save_random_walk_model=is_save_random_walk_model,
534
- is_file_exist_loading=is_file_exist_loading,
535
- filename_dict=filename_dict,
536
- single_chunk_size=single_chunk_size,
537
- block_size=block_size
538
- )
490
+ if save_path is not None and not overlap_is_read:
491
+ save_h5ad(overlap_adata, file=atac_overlap_save_file)
539
492
 
540
493
  """
541
- 3, 4, 5 steps
494
+ 3. Calculate the initial trait relevance scores for each cell
542
495
  """
543
- variants_key_list: list = list(trait_info["id"])
544
- # Quantity of traits
545
- trait_size: int = trait_info.shape[0]
546
496
 
547
- # Determine whether it is necessary to read the file
548
- overlap_is_read: bool = is_file_exist_loading and os.path.exists(atac_overlap_save_file)
549
497
  init_score_is_read: bool = is_file_exist_loading and os.path.exists(init_score_save_file)
550
498
 
551
- # Number of Blocks
552
- chunk_size: int = int(np.ceil(trait_size / single_chunk_size))
553
-
554
- overlap_adata: Union[AnnData, None] = None
555
- init_score: Union[AnnData, None] = None
499
+ if init_score_is_read:
500
+ init_score: AnnData = read_h5ad(init_score_save_file)
556
501
 
557
- # overlap
558
- if overlap_is_read:
559
- overlap_adata: AnnData = read_h5ad(atac_overlap_save_file)
502
+ if init_score.var.shape[0] != _trait_count_:
503
+ ul.log(__name__).warning(
504
+ f"The number of diseases read from file `init_score.h5ad` are inconsistent with the input ({init_score.var.shape[0]} != {_trait_count_}). "
505
+ f"Please check and verify. If the verification is not as expected, file `init_score.h5ad` needs to be moved or deleted."
506
+ )
560
507
 
561
- if chunk_size > 1:
508
+ else:
509
+ # intermediate score data, integration data
510
+ init_score: AnnData = calculate_init_score_weight(
511
+ adata=adata,
512
+ da_peaks_adata=da_peaks,
513
+ overlap_adata=overlap_adata,
514
+ diff_peak_value=diff_peak_value,
515
+ is_simple=is_simple,
516
+ block_size=block_size
517
+ )
562
518
 
563
- if overlap_adata.var.shape[0] != _trait_count_:
564
- ul.log(__name__).error(
565
- f"The number of diseases read from file `atac_overlap.h5ad` are inconsistent with the input ({overlap_adata.var.shape[0]} != {_trait_count_}). "
566
- f"Please check and verify. If the verification is not as expected, file `atac_overlap.h5ad` needs to be moved or deleted."
567
- )
568
- raise ValueError(
569
- f"The number of diseases read from file `atac_overlap.h5ad` are inconsistent with the input ({overlap_adata.var.shape[0]} != {_trait_count_}). "
570
- f"Please check and verify. If the verification is not as expected, file `atac_overlap.h5ad` needs to be moved or deleted."
571
- )
519
+ del da_peaks, overlap_adata
572
520
 
573
- else:
521
+ step3_time = init_score.uns["elapsed_time"]
574
522
 
575
- if overlap_adata.var.shape[0] != _trait_count_:
576
- ul.log(__name__).warning(
577
- f"The number of diseases read from file `atac_overlap.h5ad` are inconsistent with the input ({overlap_adata.var.shape[0]} != {_trait_count_}). "
578
- f"Please check and verify. If the verification is not as expected, file `atac_overlap.h5ad` needs to be moved or deleted."
579
- )
523
+ if save_path is not None and not init_score_is_read:
524
+ save_h5ad(init_score, file=init_score_save_file)
580
525
 
581
- if init_score_is_read:
582
- init_score: AnnData = read_h5ad(init_score_save_file)
526
+ """
527
+ 4. Calculate cell-cell correlation. Building a network between cells.
528
+ """
583
529
 
584
- if chunk_size > 1:
530
+ cc_data_is_read: bool = is_file_exist_loading and os.path.exists(cc_data_save_file)
585
531
 
586
- if init_score.var.shape[0] != _trait_count_:
587
- ul.log(__name__).warning(
588
- f"The number of diseases read from file `init_score.h5ad` are inconsistent with the input ({init_score.var.shape[0]} != {_trait_count_}). "
589
- f"Please check and verify. If the verification is not as expected, file `init_score.h5ad` needs to be moved or deleted."
590
- )
591
- raise ValueError(
592
- f"The number of diseases read from file `init_score.h5ad` are inconsistent with the input ({init_score.var.shape[0]} != {_trait_count_}). "
593
- f"Please check and verify. If the verification is not as expected, file `init_score.h5ad` needs to be moved or deleted."
594
- )
532
+ if cc_data_is_read:
533
+ cc_data: AnnData = read_h5ad(cc_data_save_file)
534
+ else:
535
+ # cell-cell network
536
+ cc_data = obtain_cell_cell_network(
537
+ adata=adata,
538
+ k=k,
539
+ or_k=or_k,
540
+ weight=weight,
541
+ kernel=kernel,
542
+ local_k=local_k,
543
+ gamma=kernel_gamma,
544
+ is_simple=is_simple
545
+ )
595
546
 
596
- else:
547
+ del adata
597
548
 
598
- if init_score.var.shape[0] != _trait_count_:
599
- ul.log(__name__).warning(
600
- f"The number of diseases read from file `init_score.h5ad` are inconsistent with the input ({init_score.var.shape[0]} != {_trait_count_}). "
601
- f"Please check and verify. If the verification is not as expected, file `init_score.h5ad` needs to be moved or deleted."
602
- )
603
-
604
- if chunk_size > 1:
605
-
606
- # Create cache container folder
607
- for _path_ in cache_path_dict.values():
608
- ul.file_method(__name__).makedirs(_path_)
609
-
610
- ul.log(__name__).info(f"Due to excessive traits/diseases, divide and conquer. A total of {chunk_size} blocks need to be processed, with {single_chunk_size} elements per block.")
611
- # Separate execution
612
- for chunk in range(chunk_size):
613
- # Index of the start and end of the traits obtained
614
- _start_ = chunk * single_chunk_size
615
- _end_ = _start_ + single_chunk_size if trait_size > _start_ + single_chunk_size else trait_size
616
- ul.log(__name__).info(f"Processing blocks from {_start_ + 1} to {_end_}")
617
-
618
- # chunk cache file
619
- _chunk_atac_overlap_save_file_ = os.path.join(cache_path_dict["atac_overlap"], f"atac_overlap_{chunk}.h5ad")
620
- _chunk_init_score_save_file_ = os.path.join(cache_path_dict["init_score"], f"init_score_{chunk}.h5ad")
621
- _chunk_random_walk_save_file_ = os.path.join(cache_path_dict["random_walk"], f"random_walk_{chunk}.pkl")
622
- _chunk_trs_save_file_ = os.path.join(cache_path_dict["trs"], f"trs_{chunk}.h5ad")
623
-
624
- # get variant info
625
- _chunk_variants_key_list_ = variants_key_list[_start_:_end_]
626
- _chunk_variants_: dict = {key: variants[key] for key in _chunk_variants_key_list_}
627
- _chunk_trait_info_: DataFrame = trait_info[trait_info["id"].isin(_chunk_variants_key_list_)]
628
- del _chunk_variants_key_list_
629
-
630
- # Determine whether the final result has been generated, and if it has, skip all intermediate calculation processes
631
- _chunk_overlap_is_read_: bool = is_file_exist_loading and os.path.exists(_chunk_atac_overlap_save_file_)
632
- _chunk_init_score_is_read_: bool = is_file_exist_loading and os.path.exists(_chunk_init_score_save_file_)
633
- _chunk_random_walk_is_read_: bool = is_file_exist_loading and os.path.exists(_chunk_random_walk_save_file_) and is_save_random_walk_model
634
- _chunk_trs_is_read_: bool = is_file_exist_loading and os.path.exists(_chunk_trs_save_file_)
635
-
636
- if _chunk_trs_is_read_:
637
- ul.log(__name__).warning(f"{_chunk_trs_save_file_} result file already exists, so skip this calculation process.")
638
- continue
639
-
640
- """
641
- 3. Overlap regional data and mutation data and sum the PP values of all mutations in a region
642
- as the values for that region
643
- """
644
-
645
- # overlap
646
- if overlap_is_read:
647
- _chunk_overlap_adata_: AnnData = overlap_adata[:, _start_:_end_]
648
- elif _chunk_overlap_is_read_:
649
- _chunk_overlap_adata_: AnnData = read_h5ad(_chunk_atac_overlap_save_file_)
650
-
651
- if _chunk_overlap_adata_.var.shape[0] != (_end_ - _start_):
652
- ul.log(__name__).warning(
653
- f"The number of diseases read from file `{_chunk_atac_overlap_save_file_}` are inconsistent with the input ({_chunk_overlap_adata_.var.shape[0]} != {_end_ - _start_}) (chunk: {chunk}). "
654
- f"Please check and verify. If the verification is not as expected, file `{_chunk_atac_overlap_save_file_}` needs to be moved or deleted."
655
- )
656
-
657
- else:
658
- _chunk_overlap_adata_: AnnData = overlap_sum(adata, _chunk_variants_, _chunk_trait_info_)
659
- save_h5ad(_chunk_overlap_adata_, file=_chunk_atac_overlap_save_file_)
660
-
661
- del _chunk_overlap_is_read_, _chunk_atac_overlap_save_file_
662
-
663
- """
664
- 4. Calculate the initial trait- or disease-related cell score with weight
665
- """
666
-
667
- # overlap
668
- if init_score_is_read:
669
- _chunk_init_score_: AnnData = init_score[:, _start_:_end_]
670
- elif _chunk_init_score_is_read_:
671
- _chunk_init_score_: AnnData = read_h5ad(_chunk_init_score_save_file_)
672
-
673
- if _chunk_init_score_.var.shape[0] != (_end_ - _start_):
674
- ul.log(__name__).warning(
675
- f"The number of diseases read from file `{_chunk_init_score_save_file_}` are inconsistent with the input ({_chunk_init_score_.var.shape[0]} != {_end_ - _start_}) (chunk: {chunk}). "
676
- f"Please check and verify. If the verification is not as expected, file `{_chunk_init_score_save_file_}` needs to be moved or deleted."
677
- )
678
-
679
- else:
680
- # intermediate score data, integration data
681
- _chunk_init_score_: AnnData = calculate_init_score_weight(
682
- adata=adata,
683
- da_peaks_adata=da_peaks,
684
- overlap_adata=_chunk_overlap_adata_,
685
- diff_peak_value=diff_peak_value,
686
- is_simple=is_simple,
687
- block_size=block_size
688
- )
689
- save_h5ad(_chunk_init_score_, file=_chunk_init_score_save_file_)
690
-
691
- del _chunk_overlap_adata_, _chunk_init_score_is_read_, _chunk_init_score_save_file_
692
-
693
- """
694
- 5. Random walk
695
- """
696
-
697
- if _chunk_random_walk_is_read_:
698
- _chunk_random_walk_: RandomWalk = read_pkl(_chunk_random_walk_save_file_)
699
- else:
700
- # random walk
701
- # noinspection DuplicatedCode
702
- _chunk_random_walk_: RandomWalk = RandomWalk(
703
- cc_adata=cc_data,
704
- init_status=_chunk_init_score_,
705
- epsilon=epsilon,
706
- gamma=gamma,
707
- enrichment_gamma=enrichment_gamma,
708
- p=p,
709
- n_jobs=n_jobs,
710
- min_seed_cell_rate=min_seed_cell_rate,
711
- max_seed_cell_rate=max_seed_cell_rate,
712
- credible_threshold=credible_threshold,
713
- enrichment_threshold=enrichment_threshold,
714
- is_ablation=is_ablation,
715
- is_simple=is_simple
716
- )
717
-
718
- if is_save_random_walk_model:
719
- save_pkl(_chunk_random_walk_, save_file=_chunk_random_walk_save_file_)
720
-
721
- del _chunk_init_score_, _chunk_random_walk_is_read_, _chunk_random_walk_save_file_
722
-
723
- if not _chunk_trs_is_read_:
724
- _chunk_trs_: AnnData = _run_random_walk_(_chunk_random_walk_, is_ablation, is_simple)
725
- _chunk_params_: dict = params.copy()
726
- _chunk_params_.update({"_start_": _start_})
727
- _chunk_params_.update({"_end_": _end_})
728
- # Save parameters
729
- _chunk_trs_.uns["params"] = _chunk_params_
730
- del _chunk_params_
731
- # save result
732
- save_h5ad(_chunk_trs_, file=_chunk_trs_save_file_)
733
- del _chunk_trs_
734
-
735
- del _chunk_trs_is_read_, _chunk_random_walk_, _chunk_trs_save_file_
736
-
737
- if save_path is not None:
738
-
739
- """
740
- (Merge) 3. Overlap regional data and mutation data and sum the PP values of all mutations in a region
741
- as the values for that region
742
- """
743
-
744
- _chunk_atac_overlap_adata_list_: list[AnnData] = []
745
- ul.log(__name__).info(f"Merge peak-trait/disease files.")
746
- for chunk in tqdm(range(chunk_size)):
747
- # chunk cache file
748
- _chunk_atac_overlap_save_file_ = os.path.join(cache_path_dict["atac_overlap"], f"atac_overlap_{chunk}.h5ad")
749
- _chunk_atac_overlap_adata_ = read_h5ad(_chunk_atac_overlap_save_file_, is_verbose=False)
750
- _chunk_atac_overlap_adata_list_.append(_chunk_atac_overlap_adata_)
751
- del _chunk_atac_overlap_save_file_, _chunk_atac_overlap_adata_
752
-
753
- # save atac_overlap
754
- _chunk_atac_overlap_adata_all_: AnnData = ad.concat(_chunk_atac_overlap_adata_list_, axis=1)
755
- del _chunk_atac_overlap_adata_list_
756
- _chunk_atac_overlap_adata_all_.var = trait_info.copy()
757
- _chunk_atac_overlap_adata_all_.uns["is_overlap"] = True
758
- save_h5ad(_chunk_atac_overlap_adata_all_, atac_overlap_save_file)
759
- del _chunk_atac_overlap_adata_all_
760
-
761
- # delete cache data
762
- ul.log(__name__).info(f"Clear cache file information: {cache_path_dict['atac_overlap']}")
763
- shutil.rmtree(cache_path_dict["atac_overlap"])
764
-
765
- """
766
- (Merge) 4. Calculate the initial trait- or disease-related cell score with weight
767
- """
768
-
769
- # merge init_score
770
- _chunk_init_score_adata_list_: list[AnnData] = []
771
- ul.log(__name__).info(f"Merge iTRS files.")
772
- for chunk in tqdm(range(chunk_size)):
773
- # chunk cache file
774
- _chunk_init_score_save_file_ = os.path.join(cache_path_dict["init_score"], f"init_score_{chunk}.h5ad")
775
- _chunk_init_score_adata_ = read_h5ad(_chunk_init_score_save_file_, is_verbose=False)
776
- _chunk_init_score_adata_list_.append(_chunk_init_score_adata_)
777
- del _chunk_init_score_save_file_, _chunk_init_score_adata_
778
-
779
- # save init_score
780
- _chunk_init_score_adata_all_: AnnData = ad.concat(_chunk_init_score_adata_list_, axis=1)
781
- del _chunk_init_score_adata_list_
782
- _chunk_init_score_adata_all_.obs = adata.obs.copy()
783
- _chunk_init_score_adata_all_.var = trait_info.copy()
784
- save_h5ad(_chunk_init_score_adata_all_, init_score_save_file)
785
- del _chunk_init_score_adata_all_
786
-
787
- # delete cache data
788
- ul.log(__name__).info(f"Clear cache file information: {cache_path_dict['init_score']}")
789
- shutil.rmtree(cache_path_dict["init_score"])
790
-
791
- """
792
- (Merge) 5. Random walk and result files
793
- """
794
-
795
- # merge trs
796
- _chunk_trs_adata_list_: list[AnnData] = []
797
- # Separate execution
798
- ul.log(__name__).info(f"Merge TRS files.")
799
- for chunk in tqdm(range(chunk_size)):
800
- # chunk cache file
801
- _chunk_trs_save_file_ = os.path.join(cache_path_dict["trs"], f"trs_{chunk}.h5ad")
802
- _chunk_trs_adata_ = read_h5ad(_chunk_trs_save_file_, is_verbose=False)
803
- _chunk_trs_adata_list_.append(_chunk_trs_adata_)
804
- del _chunk_trs_save_file_, _chunk_trs_adata_
805
-
806
- # save trs
807
- trs: AnnData = ad.concat(_chunk_trs_adata_list_, axis=1)
808
- del _chunk_trs_adata_list_
809
- trs.obs = adata.obs.copy()
810
- trs.var = trait_info.copy()
811
-
812
- # start time
813
- elapsed_time = time.time() - start_time
814
-
815
- params.update({"chunk_size": chunk_size})
816
- params.update({"elapsed_time": elapsed_time})
817
- # Save parameters
818
- trs.uns["params"] = params
819
- trs.uns["variants"] = variants
820
- trs.uns["trait_info"] = trait_info
821
- del params, variants, trait_info
822
-
823
- # delete cache data
824
- ul.log(__name__).info(f"Clear cache file information: {cache_path_dict['trs']}")
825
- shutil.rmtree(cache_path_dict["trs"])
826
-
827
- if save_path is not None:
828
- save_h5ad(trs, file=trs_save_file)
829
-
830
- if is_save_random_walk_model:
831
- _chunk_random_walk_dict_: dict = {}
832
- ul.log(__name__).info(f"Merge random walk model files.")
833
- for chunk in tqdm(range(chunk_size)):
834
- _start_ = chunk * single_chunk_size
835
- _end_ = min(_start_ + single_chunk_size, trait_size)
836
-
837
- # chunk cache file
838
- _chunk_random_walk_save_file_ = os.path.join(cache_path_dict["random_walk"], f"random_walk_{chunk}.pkl")
839
- _chunk_random_walk_data_ = read_pkl(_chunk_random_walk_save_file_, is_verbose=False)
840
- _chunk_random_walk_dict_.update({f"{_start_}_{_end_}": _chunk_random_walk_data_})
841
- del _chunk_random_walk_save_file_, _chunk_random_walk_data_
842
-
843
- save_pkl(_chunk_random_walk_dict_, save_file=random_walk_save_file)
844
-
845
- # delete cache data
846
- ul.log(__name__).info(f"Clear cache file information: {cache_path_dict['random_walk']}")
847
- shutil.rmtree(cache_path_dict["random_walk"])
848
-
849
- # Delete cache files
850
- for _path_ in cache_path_dict.values():
851
- if os.path.exists(_path_):
852
- shutil.rmtree(_path_)
549
+ step4_time = cc_data.uns["elapsed_time"]
853
550
 
854
- else:
855
-
856
- """
857
- 3. Overlap regional data and mutation data and sum the PP values of all mutations
858
- in a region as the values for that region
859
- """
860
-
861
- # overlap
862
- if not overlap_is_read:
863
- overlap_adata: AnnData = overlap_sum(adata, variants, trait_info)
864
-
865
- if save_path is not None and not overlap_is_read:
866
- save_h5ad(overlap_adata, file=atac_overlap_save_file)
867
-
868
- del overlap_is_read
869
-
870
- """
871
- 4. Calculate the initial trait relevance scores for each cell
872
- """
873
-
874
- if not init_score_is_read:
875
- # intermediate score data, integration data
876
- init_score: AnnData = calculate_init_score_weight(
877
- adata=adata,
878
- da_peaks_adata=da_peaks,
879
- overlap_adata=overlap_adata,
880
- diff_peak_value=diff_peak_value,
881
- is_simple=is_simple,
882
- block_size=block_size
883
- )
551
+ if save_path is not None and not cc_data_is_read:
552
+ save_h5ad(cc_data, file=cc_data_save_file)
884
553
 
885
- if save_path is not None and not init_score_is_read:
886
- save_h5ad(init_score, file=init_score_save_file)
554
+ """
555
+ 5. Random walk
556
+ """
887
557
 
888
- del init_score_is_read, da_peaks, overlap_adata
558
+ random_walk_is_read: bool = is_file_exist_loading and os.path.exists(random_walk_save_file) and is_save_random_walk_model
889
559
 
890
- """
891
- 5. Random walk
892
- """
560
+ if random_walk_is_read:
561
+ random_walk: RandomWalk = read_pkl(random_walk_save_file)
562
+ else:
563
+ # random walk
564
+ # noinspection DuplicatedCode
565
+ random_walk: RandomWalk = RandomWalk(
566
+ cc_adata=cc_data,
567
+ init_status=init_score,
568
+ epsilon=epsilon,
569
+ gamma=gamma,
570
+ enrichment_gamma=enrichment_gamma,
571
+ p=p,
572
+ n_jobs=n_jobs,
573
+ min_seed_cell_rate=min_seed_cell_rate,
574
+ max_seed_cell_rate=max_seed_cell_rate,
575
+ credible_threshold=credible_threshold,
576
+ enrichment_threshold=enrichment_threshold,
577
+ is_ablation=is_ablation,
578
+ is_simple=is_simple
579
+ )
893
580
 
894
- random_walk_is_read: bool = is_file_exist_loading and os.path.exists(random_walk_save_file) and is_save_random_walk_model
581
+ if save_path is not None and is_save_random_walk_model and not random_walk_is_read:
582
+ save_pkl(random_walk, save_file=random_walk_save_file)
895
583
 
896
- if random_walk_is_read:
897
- random_walk: RandomWalk = read_pkl(random_walk_save_file)
898
- else:
899
- # random walk
900
- # noinspection DuplicatedCode
901
- random_walk: RandomWalk = RandomWalk(
902
- cc_adata=cc_data,
903
- init_status=init_score,
904
- epsilon=epsilon,
905
- gamma=gamma,
906
- enrichment_gamma=enrichment_gamma,
907
- p=p,
908
- n_jobs=n_jobs,
909
- min_seed_cell_rate=min_seed_cell_rate,
910
- max_seed_cell_rate=max_seed_cell_rate,
911
- credible_threshold=credible_threshold,
912
- enrichment_threshold=enrichment_threshold,
913
- is_ablation=is_ablation,
914
- is_simple=is_simple
915
- )
584
+ del random_walk_is_read, init_score, cc_data
916
585
 
917
- if save_path is not None and random_walk_is_read:
918
- save_pkl(random_walk, save_file=random_walk_save_file)
586
+ trs = _run_random_walk_(random_walk, is_ablation, is_simple)
919
587
 
920
- del random_walk_is_read, init_score, cc_data
588
+ step5_time = random_walk.elapsed_time
921
589
 
922
- trs = _run_random_walk_(random_walk, is_ablation, is_simple)
590
+ # end time
591
+ elapsed_time = time.time() - start_time
592
+ step_time = step1_time + step2_time + step3_time + step4_time + step5_time
923
593
 
924
- # start time
925
- elapsed_time = time.time() - start_time
594
+ params.update({"elapsed_time": elapsed_time if elapsed_time > step_time else step_time})
595
+ trs.uns["params"] = params
926
596
 
927
- params.update({"chunk_size": chunk_size})
928
- params.update({"elapsed_time": elapsed_time})
929
- # Save parameters
930
- trs.uns["params"] = params
931
- trs.uns["variants"] = variants
932
- trs.uns["trait_info"] = trait_info
933
- del params, variants, trait_info
597
+ del params
934
598
 
935
- if save_path is not None:
936
- # save result
937
- save_h5ad(trs, file=trs_save_file)
599
+ if save_path is not None:
600
+ save_h5ad(trs, file=trs_save_file)
938
601
 
939
602
  return trs
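Once `core` returns, the `trs` AnnData documented above (obs = cells, var = traits/diseases) also carries the run parameters in `.uns["params"]`, including the new kernel settings and the recorded elapsed time. A short usage sketch, assuming `trs` is the object returned by `core`:

```python
# Assumes `trs` is the AnnData returned by core(...); shown as a usage sketch only.
def summarize_trs(trs):
    print(trs.shape)                          # (n_cells, n_traits/diseases)
    print(trs.var.head())                     # trait/disease-level annotation
    print(trs.uns["params"]["kernel"])        # e.g. "gaussian"
    print(trs.uns["params"]["elapsed_time"])  # larger of wall-clock time and summed step times
```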
940
603
 
@@ -981,6 +644,8 @@ def association_score(
981
644
  def knock(
982
645
  trs: AnnData,
983
646
  sc_atac: AnnData,
647
+ da_peaks: AnnData,
648
+ cc_data: AnnData,
984
649
  knock_trait: str,
985
650
  knock_info: dict[str, Union[str, collection]],
986
651
  knock_value: float = 0,
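With the two new required arguments (`da_peaks`, `cc_data`) added in this hunk, `knock` now reuses the intermediate AnnData objects saved by `core` instead of recomputing them. A hedged calling sketch follows; the file names come from the `filename_dict` defaults documented above, while the module path and the `knock_info` contents are placeholders.

```python
import anndata as ad
import sciv

# Intermediate results written by core(..., save_path="results/") -- default file names.
trs = ad.read_h5ad("results/trs.h5ad")
sc_atac = ad.read_h5ad("results/sc_atac.h5ad")
da_peaks = ad.read_h5ad("results/da_peaks.h5ad")
cc_data = ad.read_h5ad("results/cc_data.h5ad")

knock_trs = sciv.model.knock(              # module path is an assumption
    trs,
    sc_atac,
    da_peaks,
    cc_data,
    knock_trait="EXAMPLE_TRAIT",           # hypothetical trait id present in trs.var
    knock_info={"EXAMPLE_KEY": "chr1:100-200"},  # hypothetical placeholder; see the knock_info docs
    knock_value=0,
)
```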
@@ -1058,49 +723,12 @@ def knock(
1058
723
 
1059
724
  knock_trait_info.index = knock_trait_info["id"].astype(str)
1060
725
 
1061
- (
1062
- cache_path_dict, params, save_path, is_file_exist_loading, _, _, _, _, adata, da_peaks, cc_data
1063
- ) = _check_and_run_two_step_(
1064
- adata=sc_atac,
1065
- cell_rate=params["cell_rate"] if "cell_rate" in params else None,
1066
- peak_rate=params["peak_rate"] if "peak_rate" in params else None,
1067
- max_epochs=params["max_epochs"],
1068
- lr=params["lr"] if "lr" in params else None,
1069
- batch_size=params["batch_size"] if "batch_size" in params else None,
1070
- eps=params["eps"] if "eps" in params else None,
1071
- early_stopping=params["early_stopping"] if "early_stopping" in params else None,
1072
- early_stopping_patience=params["early_stopping_patience"] if "early_stopping_patience" in params else None,
1073
- batch_key=params["batch_key"] if "batch_key" in params else None,
1074
- resolution=params["resolution"],
1075
- k=params["k"],
1076
- or_k=params["or_k"],
1077
- weight=params["weight"],
1078
- laplacian_gamma=params["laplacian_gamma"] if "laplacian_gamma" in params else None,
1079
- epsilon=params["epsilon"],
1080
- gamma=params["gamma"],
1081
- p=params["p"],
1082
- n_jobs=params["n_jobs"] if "n_jobs" in params else -1,
1083
- min_seed_cell_rate=params["min_seed_cell_rate"],
1084
- max_seed_cell_rate=params["max_seed_cell_rate"],
1085
- credible_threshold=params["credible_threshold"],
1086
- enrichment_threshold=params["enrichment_threshold"],
1087
- diff_peak_value=params["diff_peak_value"],
1088
- is_ablation=False,
1089
- model_dir=params["model_dir"] if "model_dir" in params else None,
1090
- save_path=params["save_path"] if "save_path" in params else None,
1091
- is_simple=True,
1092
- is_save_random_walk_model=False,
1093
- is_file_exist_loading=True,
1094
- single_chunk_size=params["single_chunk_size"],
1095
- block_size=params["block_size"]
1096
- )
1097
-
1098
726
  ul.log(__name__).info(f"Run the process after knocking down or knocking out.")
1099
- knock_overlap_adata: AnnData = overlap_sum(adata, knock_variants, knock_trait_info)
727
+ knock_overlap_adata: AnnData = overlap_sum(sc_atac, knock_variants, knock_trait_info)
1100
728
 
1101
729
  # intermediate score data, integration data
1102
730
  knock_init_score: AnnData = calculate_init_score_weight(
1103
- adata=adata,
731
+ adata=sc_atac,
1104
732
  da_peaks_adata=da_peaks,
1105
733
  overlap_adata=knock_overlap_adata,
1106
734
  diff_peak_value=params["diff_peak_value"],
@@ -1155,7 +783,7 @@ def knock(
1155
783
  knock_trs.var["count"] = knock_trs.var["count"].astype(int)
1156
784
 
1157
785
  # save result
1158
- if save_path is not None:
1159
- save_h5ad(knock_trs, file=os.path.join(save_path, f"knock_trs_{knock_trait}.h5ad"))
786
+ if params["save_path"] is not None:
787
+ save_h5ad(knock_trs, file=os.path.join(params["save_path"], f"knock_trs_{knock_trait}.h5ad"))
1160
788
 
1161
789
  return knock_trs