scdataloader 1.1.3__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scdataloader/VERSION CHANGED
@@ -1 +1 @@
- 1.1.3
+ 1.2.1
scdataloader/__init__.py CHANGED
@@ -1,4 +1,4 @@
+ from .collator import Collator
  from .data import Dataset, SimpleAnnDataset
  from .datamodule import DataModule
  from .preprocess import Preprocessor
- from .collator import Collator
scdataloader/__main__.py CHANGED
@@ -1,11 +1,13 @@
  import argparse
+ from typing import Optional, Union
+
+ import lamindb as ln
+
  from scdataloader.preprocess import (
  LaminPreprocessor,
- additional_preprocess,
  additional_postprocess,
+ additional_preprocess,
  )
- import lamindb as ln
- from typing import Optional, Union


  # scdataloader --instance="laminlabs/cellxgene" --name="cellxgene-census" --version="2023-12-15" --description="preprocessed for scprint" --new_name="scprint main" --start_at=39
scdataloader/collator.py CHANGED
@@ -1,7 +1,9 @@
+ from typing import Optional
+
  import numpy as np
- from .utils import load_genes, downsample_profile
  from torch import Tensor, long
- from typing import Optional
+
+ from .utils import downsample_profile, load_genes


  class Collator:
scdataloader/data.py CHANGED
@@ -1,18 +1,20 @@
+ import warnings
+ from collections import Counter
  from dataclasses import dataclass, field
-
- import lamindb as ln
+ from functools import reduce
+ from typing import Literal, Optional, Union

  # ln.connect("scprint")
-
  import bionty as bt
+ import lamindb as ln
+ import numpy as np
  import pandas as pd
- from torch.utils.data import Dataset as torchDataset
- from typing import Union, Optional, Literal
- from scdataloader.mapped import MappedCollection
- import warnings
-
  from anndata import AnnData
+ from lamindb.core import MappedCollection
+ from lamindb.core._mapped_collection import _Connect
+ from lamindb.core.storage._anndata_accessor import _safer_read_index
  from scipy.sparse import issparse
+ from torch.utils.data import Dataset as torchDataset

  from scdataloader.utils import get_ancestry_mapping, load_genes

@@ -110,7 +112,16 @@ class Dataset(torchDataset):
  self.genedf = load_genes(self.organisms)

  self.genedf.columns = self.genedf.columns.astype(str)
- self.mapped_dataset._check_aligned_vars(self.genedf.index.tolist())
+ self.check_aligned_vars()
+
+ def check_aligned_vars(self):
+ vars = self.genedf.index.tolist()
+ i = 0
+ for storage in self.mapped_dataset.storages:
+ with _Connect(storage) as store:
+ if len(set(_safer_read_index(store["var"]).tolist()) - set(vars)) == 0:
+ i += 1
+ print("{}% are aligned".format(i * 100 / len(self.mapped_dataset.storages)))

  def __len__(self, **kwargs):
  return self.mapped_dataset.__len__(**kwargs)
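
The new `check_aligned_vars` replaces the removed `MappedCollection._check_aligned_vars` call: it only reports the share of backing stores whose `var` index is fully covered by `genedf`, rather than enforcing alignment. A standalone sketch of the same check over a list of `.h5ad` files (the paths and gene list are placeholders, not part of the package):

    import anndata as ad

    def report_aligned_vars(h5ad_paths, genes):
        """Print the percentage of files whose var index is a subset of `genes`."""
        genes = set(genes)
        aligned = 0
        for path in h5ad_paths:
            # backed="r" avoids loading X into memory; only the var index is used
            var_index = ad.read_h5ad(path, backed="r").var.index
            if set(var_index) <= genes:
                aligned += 1
        print("{}% are aligned".format(aligned * 100 / len(h5ad_paths)))
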
@@ -145,14 +156,27 @@ class Dataset(torchDataset):
  )
  )

- def get_label_weights(self, *args, **kwargs):
- """
- get_label_weights is a wrapper around mappedDataset.get_label_weights
+ def get_label_weights(self, obs_keys: str | list[str], scaler: int = 10):
+ """Get all weights for the given label keys."""
+ if isinstance(obs_keys, str):
+ obs_keys = [obs_keys]
+ labels_list = []
+ for label_key in obs_keys:
+ labels_to_str = (
+ self.mapped_dataset.get_merged_labels(label_key).astype(str).astype("O")
+ )
+ labels_list.append(labels_to_str)
+ if len(labels_list) > 1:
+ labels = reduce(lambda a, b: a + b, labels_list)
+ else:
+ labels = labels_list[0]

- Returns:
- dict: dictionary of weights for each label
- """
- return self.mapped_dataset.get_label_weights(*args, **kwargs)
+ counter = Counter(labels)  # type: ignore
+ rn = {n: i for i, n in enumerate(counter.keys())}
+ labels = np.array([rn[label] for label in labels])
+ counter = np.array(list(counter.values()))
+ weights = scaler / (counter + scaler)
+ return weights, labels

  def get_unseen_mapped_dataset_elements(self, idx: int):
  """
@@ -236,7 +260,7 @@ class Dataset(torchDataset):
  clss
  )
  )
- cats = self.mapped_dataset.get_merged_categories(clss)
+ cats = set(self.mapped_dataset.get_merged_categories(clss))
  addition = set(LABELS_TOADD.get(clss, {}).values())
  cats |= addition
  groupings, _, leaf_labels = get_ancestry_mapping(cats, parentdf)
scdataloader/datamodule.py CHANGED
@@ -1,21 +1,20 @@
+ from typing import Optional, Sequence, Union
+
+ import lamindb as ln
+ import lightning as L
  import numpy as np
  import pandas as pd
- import lamindb as ln
-
+ import torch
+ from torch.utils.data import DataLoader, Sampler
  from torch.utils.data.sampler import (
- WeightedRandomSampler,
- SubsetRandomSampler,
- SequentialSampler,
  RandomSampler,
+ SequentialSampler,
+ SubsetRandomSampler,
+ WeightedRandomSampler,
  )
- import torch
- from torch.utils.data import DataLoader, Sampler
- import lightning as L
-
- from typing import Optional, Union, Sequence

- from .data import Dataset
  from .collator import Collator
+ from .data import Dataset
  from .utils import getBiomartTable


@@ -110,7 +109,8 @@ class DataModule(L.LightningDataModule):
  "need to provide your own table as this automated function only works for humans for now"
  )
  biomart = getBiomartTable(
- attributes=["start_position", "chromosome_name"]
+ attributes=["start_position", "chromosome_name"],
+ useCache=True,
  ).set_index("ensembl_gene_id")
  biomart = biomart.loc[~biomart.index.duplicated(keep="first")]
  biomart = biomart.sort_values(by=["chromosome_name", "start_position"])
@@ -129,7 +129,7 @@ class DataModule(L.LightningDataModule):
  prev_chromosome = r["chromosome_name"]
  print(f"reduced the size to {len(set(c))/len(biomart)}")
  biomart["pos"] = c
- mdataset.genedf = biomart.loc[mdataset.genedf.index]
+ mdataset.genedf = mdataset.genedf.join(biomart, how="inner")
  self.gene_pos = mdataset.genedf["pos"].astype(int).tolist()

  if gene_embeddings != "":
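
Replacing `biomart.loc[mdataset.genedf.index]` with `mdataset.genedf.join(biomart, how="inner")` changes how genes absent from the biomart table are handled: `.loc` raises a `KeyError` on missing ids, whereas the inner join silently drops them (and any extra biomart rows). A small pandas sketch of the difference, with made-up gene ids:

    import pandas as pd

    genedf = pd.DataFrame(index=["ENSG01", "ENSG02", "ENSG03"])
    biomart = pd.DataFrame({"pos": [0, 1]}, index=["ENSG01", "ENSG03"])  # ENSG02 missing

    # old behaviour: raises KeyError because ENSG02 is not in the biomart index
    # biomart.loc[genedf.index]

    # new behaviour: keeps only the genes present in both tables
    merged = genedf.join(biomart, how="inner")
    print(merged.index.tolist())  # ['ENSG01', 'ENSG03']
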
scdataloader/preprocess.py CHANGED
@@ -177,11 +177,18 @@ class Preprocessor:
  # # cleanup and dropping low expressed genes and unexpressed cells
  prevsize = adata.shape[0]
  adata.obs["nnz"] = np.array(np.sum(adata.X != 0, axis=1).flatten())[0]
- adata = adata[(adata.obs["nnz"] > self.min_nnz_genes)]
  if self.filter_gene_by_counts:
  sc.pp.filter_genes(adata, min_counts=self.filter_gene_by_counts)
  if self.filter_cell_by_counts:
- sc.pp.filter_cells(adata, min_counts=self.filter_cell_by_counts)
+ sc.pp.filter_cells(
+ adata,
+ min_counts=self.filter_cell_by_counts,
+ )
+ if self.min_nnz_genes:
+ sc.pp.filter_cells(
+ adata,
+ min_genes=self.min_nnz_genes,
+ )
  # if lost > 50% of the dataset, drop dataset
  # load the genes
  genesdf = data_utils.load_genes(adata.obs.organism_ontology_term_id.iloc[0])
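
The manual `nnz` mask is gone; cells with too few expressed genes are now removed through `sc.pp.filter_cells(min_genes=...)`, which also records an `n_genes` column in `obs`. A hedged sketch of the roughly equivalent call on a toy AnnData (the threshold and data are placeholders):

    import numpy as np
    import scanpy as sc
    from anndata import AnnData

    adata = AnnData(np.random.poisson(0.5, size=(100, 500)).astype(np.float32))
    min_nnz_genes = 20  # placeholder threshold for the toy data

    # close equivalent of the removed mask adata[adata.obs["nnz"] > min_nnz_genes]:
    # drop cells expressing fewer than `min_nnz_genes` genes (in place)
    sc.pp.filter_cells(adata, min_genes=min_nnz_genes)
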
@@ -297,7 +304,7 @@ class Preprocessor:
  # https://rapids-singlecell.readthedocs.io/en/latest/api/generated/rapids_singlecell.pp.pca.html#rapids_singlecell.pp.pca
  if self.do_postp:
  print("normalize")
- adata.layers["clean"] = sc.pp.log1p(
+ adata.layers["norm"] = sc.pp.log1p(
  sc.pp.normalize_total(
  adata, target_sum=self.normalize_sum, inplace=False
  )["X"]
@@ -306,20 +313,34 @@ class Preprocessor:
  if self.subset_hvg:
  sc.pp.highly_variable_genes(
  adata,
- layer="clean",
  n_top_genes=self.subset_hvg,
  batch_key=self.batch_key,
  flavor=self.hvg_flavor,
  subset=False,
  )
- adata.obsm["clean_pca"] = sc.pp.pca(
- adata.layers["clean"],
- n_comps=300 if adata.shape[0] > 300 else adata.shape[0] - 2,
+ sc.pp.log1p(adata, layer="norm")
+ sc.pp.pca(
+ adata,
+ layer="norm",
+ n_comps=200 if adata.shape[0] > 200 else adata.shape[0] - 2,
  )
- sc.pp.neighbors(adata, use_rep="clean_pca")
- sc.tl.leiden(adata, key_added="leiden_3", resolution=3.0)
+ sc.pp.neighbors(adata, use_rep="X_pca")
  sc.tl.leiden(adata, key_added="leiden_2", resolution=2.0)
  sc.tl.leiden(adata, key_added="leiden_1", resolution=1.0)
+ sc.tl.leiden(adata, key_added="leiden_0.5", resolution=0.5)
+ batches = [
+ "assay_ontology_term_id",
+ "self_reported_ethnicity_ontology_term_id",
+ "sex_ontology_term_id",
+ "development_stage_ontology_term_id",
+ ]
+ if "donor_id" in adata.obs.columns:
+ batches.append("donor_id")
+ if "suspension_type" in adata.obs.columns:
+ batches.append("suspension_type")
+ adata.obs["batches"] = adata.obs[batches].apply(
+ lambda x: ",".join(x.dropna().astype(str)), axis=1
+ )
  sc.tl.umap(adata)
  # additional
  if self.additional_postprocess is not None:
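
The new `batches` column concatenates several obs covariates into a single string key (used later by the harmony call in `additional_postprocess`). A toy pandas sketch of the same row-wise join, with fabricated obs values:

    import pandas as pd

    obs = pd.DataFrame(
        {
            "assay_ontology_term_id": ["EFO:0009922", "EFO:0009922"],
            "sex_ontology_term_id": ["PATO:0000384", "PATO:0000383"],
            "donor_id": ["d1", "d2"],
        }
    )
    batch_cols = ["assay_ontology_term_id", "sex_ontology_term_id", "donor_id"]

    # one composite batch label per cell, skipping missing values
    obs["batches"] = obs[batch_cols].apply(
        lambda row: ",".join(row.dropna().astype(str)), axis=1
    )
    print(obs["batches"].tolist())
    # ['EFO:0009922,PATO:0000384,d1', 'EFO:0009922,PATO:0000383,d2']
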
@@ -379,14 +400,12 @@ class LaminPreprocessor(Preprocessor):
  def __init__(
  self,
  *args,
- erase_prev_dataset: bool = False,
  cache: bool = True,
  stream: bool = False,
  keep_files: bool = True,
  **kwargs,
  ):
  super().__init__(*args, **kwargs)
- self.erase_prev_dataset = erase_prev_dataset
  self.cache = cache
  self.stream = stream
  self.keep_files = keep_files
@@ -418,14 +437,17 @@ class LaminPreprocessor(Preprocessor):
  elif isinstance(data, ln.Collection):
  for i, file in enumerate(data.artifacts.all()[start_at:]):
  # use the counts matrix
- print(i)
+ print(i + start_at)
  if file.stem_uid in all_ready_processed_keys:
  print(f"{file.stem_uid} is already processed... not preprocessing")
  continue
  print(file)
- backed = file.backed()
+ backed = file.open()
  if backed.obs.is_primary_data.sum() == 0:
  print(f"{file.key} only contains non primary cells.. dropping")
+ # Save the stem_uid to a file to avoid loading it again
+ with open("nonprimary.txt", "a") as f:
+ f.write(f"{file.stem_uid}\n")
  continue
  if backed.shape[1] < 1000:
  print(
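
Artifacts that contain only non-primary cells now get their `stem_uid` appended to `nonprimary.txt`, so later runs can skip them without re-opening the file. A small sketch of reading that list back into the set of already-processed keys (the file name comes from the diff; the merging step is illustrative):

    import os

    all_ready_processed_keys = set()
    if os.path.exists("nonprimary.txt"):
        with open("nonprimary.txt") as f:
            all_ready_processed_keys |= {line.strip() for line in f if line.strip()}
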
@@ -449,17 +471,17 @@ class LaminPreprocessor(Preprocessor):
  (np.ceil(badata.shape[0] / 30_000) * 30_000) // num_blocks
  )
  print("num blocks ", num_blocks)
- for i in range(num_blocks):
- start_index = i * block_size
- end_index = min((i + 1) * block_size, badata.shape[0])
+ for j in range(num_blocks):
+ start_index = j * block_size
+ end_index = min((j + 1) * block_size, badata.shape[0])
  block = badata[start_index:end_index].to_memory()
  print(block)
  block = super().__call__(block)
- myfile = ln.Artifact(
+ myfile = ln.from_anndata(
  block,
- is_new_version_of=file,
+ revises=file,
  description=description,
- version=str(version) + "_s" + str(i),
+ version=str(version) + "_s" + str(j),
  )
  myfile.save()
  if self.keep_files:
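
Large artifacts are still split into roughly 30 000-cell blocks, but the loop variable was renamed to `j` so it no longer shadows the outer artifact index `i`, and each block is saved with `revises=file` instead of the removed `is_new_version_of` argument. A standalone sketch of the same blocking pattern on a backed AnnData (the path is a placeholder, and the preprocessing step is elided):

    import anndata as ad
    import numpy as np

    badata = ad.read_h5ad("some_large_dataset.h5ad", backed="r")  # placeholder path

    num_blocks = int(np.ceil(badata.shape[0] / 30_000))
    block_size = int((np.ceil(badata.shape[0] / 30_000) * 30_000) // num_blocks)

    for j in range(num_blocks):
        start, end = j * block_size, min((j + 1) * block_size, badata.shape[0])
        block = badata[start:end].to_memory()  # only this slice is loaded into RAM
        # ... preprocess and save `block` here ...
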
@@ -470,9 +492,13 @@ class LaminPreprocessor(Preprocessor):

  else:
  adata = super().__call__(adata)
- myfile = ln.Artifact(
+ try:
+ sc.pl.umap(adata, color=["cell_type"])
+ except Exception:
+ sc.pl.umap(adata, color=["cell_type_ontology_term_id"])
+ myfile = ln.from_anndata(
  adata,
- is_new_version_of=file,
+ revises=file,
  description=description,
  version=str(version),
  )
@@ -646,46 +672,35 @@ def additional_preprocess(adata):


  def additional_postprocess(adata):
+ import palantir
+
  # define the "up to" 10 neighbors for each cells and add to obs
  # compute neighbors
  # need to be connectivities and same labels [cell type, assay, dataset, disease]
  # define the "neighbor" up to 10(N) cells and add to obs
  # define the "next time point" up to 5(M) cells and add to obs # step 1: filter genes
+ del adata.obsp["connectivities"]
+ del adata.obsp["distances"]
+ sc.external.pp.harmony_integrate(adata, key="batches")
+ sc.pp.neighbors(adata, use_rep="X_pca_harmony")
+ sc.tl.umap(adata)
+ sc.pl.umap(
+ adata,
+ color=["cell_type", "batches"],
+ )
+ palantir.utils.run_diffusion_maps(adata, n_components=20)
+ palantir.utils.determine_multiscale_space(adata)
+ terminal_states = palantir.utils.find_terminal_states(
+ adata,
+ celltypes=adata.obs.cell_type_ontology_term_id.unique(),
+ celltype_column="cell_type_ontology_term_id",
+ )
  sc.tl.diffmap(adata)
- # create a meta group
- adata.obs["dpt_group"] = (
- adata.obs["leiden_1"].astype(str)
- + "_"
- + adata.obs["disease_ontology_term_id"].astype(str)
- + "_"
- + adata.obs["cell_type_ontology_term_id"].astype(str)
- + "_"
- + adata.obs["tissue_ontology_term_id"].astype(str)
- ) # + "_" + adata.obs['dataset_id'].astype(str)
-
- # if group is too small
- okgroup = [i for i, j in adata.obs["dpt_group"].value_counts().items() if j >= 10]
- not_okgroup = [i for i, j in adata.obs["dpt_group"].value_counts().items() if j < 3]
- # set the group to empty
- adata.obs.loc[adata.obs["dpt_group"].isin(not_okgroup), "dpt_group"] = ""
- adata.obs["heat_diff"] = np.nan
- # for each group
- for val in set(okgroup):
- if val == "":
- continue
- # get the best root cell
- eq = adata.obs.dpt_group == val
- loc = np.where(eq)[0]
-
- root_ixs = loc[adata.obsm["X_diffmap"][eq, 0].argmin()]
- adata.uns["iroot"] = root_ixs
- # compute the diffusion pseudo time from it
+ adata.obs["heat_diff"] = 1
+ for terminal_state in terminal_states.index.tolist():
+ adata.uns["iroot"] = np.where(adata.obs.index == terminal_state)[0][0]
  sc.tl.dpt(adata)
- adata.obs.loc[eq, "heat_diff"] = adata.obs.loc[eq, "dpt_pseudotime"]
- adata.obs.drop(columns=["dpt_pseudotime"], inplace=True)
-
- # sort so that the next time points are aligned for all groups
- adata = adata[adata.obs.sort_values(["dpt_group", "heat_diff"]).index]
- # to query N next time points we just get the N elements below and check they are in the group
- # to query the N nearest neighbors we just get the N elements above and N below and check they are in the group
+ adata.obs["heat_diff"] = np.minimum(
+ adata.obs["heat_diff"], adata.obs["dpt_pseudotime"]
+ )
  return adata
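
The group-based `dpt_group` pseudotime is gone: `additional_postprocess` now integrates over `batches` with harmony, lets palantir propose terminal states, and keeps, for every cell, the minimum diffusion pseudotime over all roots in `heat_diff`. A toy numpy sketch of just that min-over-roots aggregation (the pseudotime values are fabricated; the palantir and scanpy calls are as in the diff):

    import numpy as np

    n_cells = 5
    heat_diff = np.ones(n_cells)  # initialised to 1, as in the diff

    # one dpt_pseudotime vector per root / terminal state (fabricated values)
    pseudotimes_per_root = [
        np.array([0.9, 0.1, 0.5, 0.7, 0.3]),
        np.array([0.2, 0.8, 0.4, 0.6, 0.9]),
    ]
    for dpt_pseudotime in pseudotimes_per_root:
        heat_diff = np.minimum(heat_diff, dpt_pseudotime)

    print(heat_diff)  # [0.2 0.1 0.4 0.6 0.3]
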
scdataloader/utils.py CHANGED
@@ -1,23 +1,21 @@
  import io
  import os
  import urllib
+ from collections import Counter
+ from functools import lru_cache
+ from typing import List, Optional, Union

  import bionty as bt
  import lamindb as ln
  import numpy as np
  import pandas as pd
+ import torch
+ from anndata import AnnData
  from biomart import BiomartServer
  from django.db import IntegrityError
  from scipy.sparse import csr_matrix
  from scipy.stats import median_abs_deviation
- from functools import lru_cache
- from collections import Counter
  from torch import Tensor
- import torch
-
- from typing import Union, List, Optional
-
- from anndata import AnnData


  def downsample_profile(mat: Tensor, dropout: float):
@@ -92,7 +90,7 @@ def _fetchFromServer(


  def getBiomartTable(
- ensemble_server: str = "http://jul2023.archive.ensembl.org/biomart",
+ ensemble_server: str = "http://may2024.archive.ensembl.org/biomart",
  useCache: bool = False,
  cache_folder: str = "/tmp/biomart/",
  attributes: List[str] = [],
@@ -102,7 +100,7 @@ def getBiomartTable(
  """generate a genelist dataframe from ensembl's biomart

  Args:
- ensemble_server (str, optional): the biomart server. Defaults to "http://jul2023.archive.ensembl.org/biomart".
+ ensemble_server (str, optional): the biomart server. Defaults to "http://may2023.archive.ensembl.org/biomart".
  useCache (bool, optional): whether to use the cache or not. Defaults to False.
  cache_folder (str, optional): the cache folder. Defaults to "/tmp/biomart/".
  attributes (List[str], optional): the attributes to fetch. Defaults to [].
@@ -143,7 +141,6 @@ def getBiomartTable(
  raise ValueError("should be a dataframe")
  res = res[~(res["ensembl_gene_id"].isna())]
  if "hgnc_symbol" in res.columns:
- res = res[res["hgnc_symbol"].isna()]
  res.loc[res[res.hgnc_symbol.isna()].index, "hgnc_symbol"] = res[
  res.hgnc_symbol.isna()
  ]["ensembl_gene_id"]
@@ -371,10 +368,9 @@ def load_genes(organisms: Union[str, list] = "NCBITaxon:9606"):  # "NCBITaxon:10
  genesdf["organism"] = organism
  organismdf.append(genesdf)
  organismdf = pd.concat(organismdf)
- organismdf.drop(
- columns=["source_id", "run_id", "created_by_id", "updated_at", "stable_id"],
- inplace=True,
- )
+ for col in ["source_id", "run_id", "created_by_id", "updated_at", "stable_id", "created_at"]:
+ if col in organismdf.columns:
+ organismdf.drop(columns=[col], inplace=True)
  return organismdf


@@ -387,6 +383,7 @@ def populate_my_ontology(
  tissues: List[str] = [],
  diseases: List[str] = [],
  dev_stages: List[str] = [],
+ organism_clade: str = "vertebrates",
  ):
  """
  creates a local version of the lamin ontologies and add the required missing values in base ontologies
@@ -397,7 +394,7 @@ def populate_my_ontology(


  add whatever value you need afterward like it is done here with:
- `bt.$ontology(name="ddd", ontology_id="ddddd").save()`
+ `bt.$ontology(name="ddd", ontolbogy_id="ddddd").save()`

  `df["assay_ontology_term_id"].unique()`

@@ -414,89 +411,111 @@ def populate_my_ontology(
  """
  # cell type
  if celltypes is not None:
- names = bt.CellType.public().df().index if not celltypes else celltypes
- records = bt.CellType.from_values(names, field="ontology_id")
- ln.save(records)
+ if len(celltypes) == 0:
+ bt.CellType.import_from_source()
+ else:
+ names = bt.CellType.public().df().index if not celltypes else celltypes
+ records = bt.CellType.from_values(names, field="ontology_id")
+ ln.save(records)
  bt.CellType(name="unknown", ontology_id="unknown").save()
  # Organism
  if organisms is not None:
- names = bt.Organism.public().df().index if not organisms else organisms
+ names = (
+ bt.Organism.public(organism=organism_clade).df().index
+ if not organisms
+ else organisms
+ )
+ source = bt.PublicSource.filter(name="ensembl", organism=organism_clade).last()
  records = [
  i[0] if type(i) is list else i
- for i in [bt.Organism.from_source(ontology_id=i) for i in names]
+ for i in [
+ bt.Organism.from_source(ontology_id=i, source=source) for i in names
+ ]
  ]
  ln.save(records)
  bt.Organism(name="unknown", ontology_id="unknown").save()
- organism_names = names
  # Phenotype
  if sex is not None:
  names = bt.Phenotype.public().df().index if not sex else sex
+ source = bt.PublicSource.filter(name="pato").first()
  records = [
- bt.Phenotype.from_source(
- ontology_id=i, source=bt.PublicSource.filter(name="pato").first()
- )
- for i in names
+ bt.Phenotype.from_source(ontology_id=i, source=source) for i in names
  ]
  ln.save(records)
  bt.Phenotype(name="unknown", ontology_id="unknown").save()
  # ethnicity
  if ethnicities is not None:
- names = bt.Ethnicity.public().df().index if not ethnicities else ethnicities
- records = bt.Ethnicity.from_values(names, field="ontology_id")
- ln.save(records)
+ if len(ethnicities) == 0:
+ bt.Ethnicity.import_from_source()
+ else:
+ names = bt.Ethnicity.public().df().index if not ethnicities else ethnicities
+ records = bt.Ethnicity.from_values(names, field="ontology_id")
+ ln.save(records)
  bt.Ethnicity(
  name="unknown", ontology_id="unknown"
  ).save()  # multi ethnic will have to get renamed
  # ExperimentalFactor
  if assays is not None:
- names = bt.ExperimentalFactor.public().df().index if not assays else assays
- records = bt.ExperimentalFactor.from_values(names, field="ontology_id")
- ln.save(records)
+ if len(assays) == 0:
+ bt.ExperimentalFactor.import_from_source()
+ else:
+ names = bt.ExperimentalFactor.public().df().index if not assays else assays
+ records = bt.ExperimentalFactor.from_values(names, field="ontology_id")
+ ln.save(records)
  bt.ExperimentalFactor(name="unknown", ontology_id="unknown").save()
  # lookup = bt.ExperimentalFactor.lookup()
  # lookup.smart_seq_v4.parents.add(lookup.smart_like)
  # Tissue
  if tissues is not None:
- names = bt.Tissue.public().df().index if not tissues else tissues
- records = bt.Tissue.from_values(names, field="ontology_id")
- ln.save(records)
+ if len(tissues) == 0:
+ bt.Tissue.import_from_source()
+ else:
+ names = bt.Tissue.public().df().index if not tissues else tissues
+ records = bt.Tissue.from_values(names, field="ontology_id")
+ ln.save(records)
  bt.Tissue(name="unknown", ontology_id="unknown").save()
  # DevelopmentalStage
  if dev_stages is not None:
- names = (
- bt.DevelopmentalStage.public().df().index if not dev_stages else dev_stages
- )
- records = bt.DevelopmentalStage.from_values(names, field="ontology_id")
- ln.save(records)
+ if len(dev_stages) == 0:
+ bt.DevelopmentalStage.import_from_source()
+ source = bt.PublicSource.filter(organism="mouse", name="mmusdv").last()
+ bt.DevelopmentalStage.import_from_source(source=source)
+ else:
+ names = (
+ bt.DevelopmentalStage.public().df().index
+ if not dev_stages
+ else dev_stages
+ )
+ records = bt.DevelopmentalStage.from_values(names, field="ontology_id")
+ ln.save(records)
  bt.DevelopmentalStage(name="unknown", ontology_id="unknown").save()

- names = bt.DevelopmentalStage.public(organism="mouse").df().index
- records = [
- bt.DevelopmentalStage.from_source(
- ontology_id=i,
- source=bt.PublicSource.filter(organism="mouse", name="mmusdv").first(),
- )
- for i in names.tolist()
- ]
- ln.save(records)
  # Disease
  if diseases is not None:
- names = bt.Disease.public().df().index if not diseases else diseases
- records = bt.Disease.from_values(names, field="ontology_id")
- ln.save(records)
+ if len(diseases) == 0:
+ bt.Disease.import_from_source()
+ else:
+ names = bt.Disease.public().df().index if not diseases else diseases
+ records = bt.Disease.from_values(names, field="ontology_id")
+ ln.save(records)
  bt.Disease(name="normal", ontology_id="PATO:0000461").save()
  bt.Disease(name="unknown", ontology_id="unknown").save()
  # genes
- for organism in organism_names:
+ for organism in ["NCBITaxon:10090", "NCBITaxon:9606"]:
  # convert onto to name
  organism = bt.Organism.filter(ontology_id=organism).one().name
  names = bt.Gene.public(organism=organism).df()["ensembl_gene_id"]
- records = bt.Gene.from_values(
- names,
- field="ensembl_gene_id",
- organism=organism,
- )
- ln.save(records)
+
+ # Process names in blocks of 10,000 elements
+ block_size = 10000
+ for i in range(0, len(names), block_size):
+ block = names[i : i + block_size]
+ records = bt.Gene.from_values(
+ block,
+ field="ensembl_gene_id",
+ organism=organism,
+ )
+ ln.save(records)


  def is_outlier(adata: AnnData, metric: str, nmads: int):
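
`populate_my_ontology` now distinguishes between an empty list (bulk `import_from_source`) and an explicit list of ontology ids (`from_values` plus `ln.save`), takes an `organism_clade`, and registers genes in blocks of 10 000 records so individual saves stay small. A hedged usage sketch based on the keyword names visible in the diff (the ontology ids are examples):

    from scdataloader.utils import populate_my_ontology

    populate_my_ontology(
        organisms=["NCBITaxon:9606", "NCBITaxon:10090"],  # explicit ids -> from_values + ln.save
        sex=["PATO:0000384", "PATO:0000383"],
        celltypes=[],          # empty list -> full bt.CellType.import_from_source()
        diseases=None,         # None skips this ontology (the body checks `is not None`)
        organism_clade="vertebrates",
    )
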