smftools 0.2.5-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. smftools/__init__.py +39 -7
  2. smftools/_settings.py +2 -0
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +2 -0
  7. smftools/cli/hmm_adata.py +7 -2
  8. smftools/cli/load_adata.py +130 -98
  9. smftools/cli/preprocess_adata.py +2 -0
  10. smftools/cli/spatial_adata.py +5 -1
  11. smftools/cli_entry.py +26 -1
  12. smftools/config/__init__.py +2 -0
  13. smftools/config/default.yaml +4 -1
  14. smftools/config/experiment_config.py +6 -0
  15. smftools/datasets/__init__.py +2 -0
  16. smftools/hmm/HMM.py +9 -3
  17. smftools/hmm/__init__.py +24 -13
  18. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  19. smftools/hmm/archived/calculate_distances.py +2 -0
  20. smftools/hmm/archived/call_hmm_peaks.py +2 -0
  21. smftools/hmm/archived/train_hmm.py +2 -0
  22. smftools/hmm/call_hmm_peaks.py +5 -2
  23. smftools/hmm/display_hmm.py +4 -1
  24. smftools/hmm/hmm_readwrite.py +7 -2
  25. smftools/hmm/nucleosome_hmm_refinement.py +2 -0
  26. smftools/informatics/__init__.py +53 -34
  27. smftools/informatics/archived/bam_conversion.py +2 -0
  28. smftools/informatics/archived/bam_direct.py +2 -0
  29. smftools/informatics/archived/basecall_pod5s.py +2 -0
  30. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  31. smftools/informatics/archived/conversion_smf.py +2 -0
  32. smftools/informatics/archived/deaminase_smf.py +1 -0
  33. smftools/informatics/archived/direct_smf.py +2 -0
  34. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  35. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  36. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
  37. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  38. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  39. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  40. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  41. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  42. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  43. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  44. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  45. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  46. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  47. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  48. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  49. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  50. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  51. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  52. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  53. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  54. smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
  55. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  56. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  57. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  58. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  59. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  60. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  61. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  62. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
  63. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  64. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  65. smftools/informatics/archived/print_bam_query_seq.py +2 -0
  66. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  67. smftools/informatics/archived/subsample_pod5.py +2 -0
  68. smftools/informatics/bam_functions.py +737 -170
  69. smftools/informatics/basecalling.py +2 -0
  70. smftools/informatics/bed_functions.py +271 -61
  71. smftools/informatics/binarize_converted_base_identities.py +3 -0
  72. smftools/informatics/complement_base_list.py +2 -0
  73. smftools/informatics/converted_BAM_to_adata.py +66 -22
  74. smftools/informatics/fasta_functions.py +94 -10
  75. smftools/informatics/h5ad_functions.py +8 -2
  76. smftools/informatics/modkit_extract_to_adata.py +16 -6
  77. smftools/informatics/modkit_functions.py +2 -0
  78. smftools/informatics/ohe.py +2 -0
  79. smftools/informatics/pod5_functions.py +3 -2
  80. smftools/machine_learning/__init__.py +22 -6
  81. smftools/machine_learning/data/__init__.py +2 -0
  82. smftools/machine_learning/data/anndata_data_module.py +18 -4
  83. smftools/machine_learning/data/preprocessing.py +2 -0
  84. smftools/machine_learning/evaluation/__init__.py +2 -0
  85. smftools/machine_learning/evaluation/eval_utils.py +2 -0
  86. smftools/machine_learning/evaluation/evaluators.py +14 -9
  87. smftools/machine_learning/inference/__init__.py +2 -0
  88. smftools/machine_learning/inference/inference_utils.py +2 -0
  89. smftools/machine_learning/inference/lightning_inference.py +6 -1
  90. smftools/machine_learning/inference/sklearn_inference.py +2 -0
  91. smftools/machine_learning/inference/sliding_window_inference.py +2 -0
  92. smftools/machine_learning/models/__init__.py +2 -0
  93. smftools/machine_learning/models/base.py +7 -2
  94. smftools/machine_learning/models/cnn.py +7 -2
  95. smftools/machine_learning/models/lightning_base.py +16 -11
  96. smftools/machine_learning/models/mlp.py +5 -1
  97. smftools/machine_learning/models/positional.py +7 -2
  98. smftools/machine_learning/models/rnn.py +5 -1
  99. smftools/machine_learning/models/sklearn_models.py +14 -9
  100. smftools/machine_learning/models/transformer.py +7 -2
  101. smftools/machine_learning/models/wrappers.py +6 -2
  102. smftools/machine_learning/training/__init__.py +2 -0
  103. smftools/machine_learning/training/train_lightning_model.py +13 -3
  104. smftools/machine_learning/training/train_sklearn_model.py +2 -0
  105. smftools/machine_learning/utils/__init__.py +2 -0
  106. smftools/machine_learning/utils/device.py +5 -1
  107. smftools/machine_learning/utils/grl.py +5 -1
  108. smftools/optional_imports.py +31 -0
  109. smftools/plotting/__init__.py +32 -31
  110. smftools/plotting/autocorrelation_plotting.py +9 -5
  111. smftools/plotting/classifiers.py +16 -4
  112. smftools/plotting/general_plotting.py +6 -3
  113. smftools/plotting/hmm_plotting.py +12 -2
  114. smftools/plotting/position_stats.py +15 -7
  115. smftools/plotting/qc_plotting.py +6 -1
  116. smftools/preprocessing/__init__.py +35 -37
  117. smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
  118. smftools/preprocessing/archived/calculate_complexity.py +2 -0
  119. smftools/preprocessing/archived/mark_duplicates.py +2 -0
  120. smftools/preprocessing/archived/preprocessing.py +2 -0
  121. smftools/preprocessing/archived/remove_duplicates.py +2 -0
  122. smftools/preprocessing/binary_layers_to_ohe.py +2 -1
  123. smftools/preprocessing/calculate_complexity_II.py +4 -1
  124. smftools/preprocessing/calculate_pairwise_differences.py +2 -0
  125. smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
  126. smftools/preprocessing/calculate_position_Youden.py +9 -2
  127. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
  128. smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
  129. smftools/preprocessing/flag_duplicate_reads.py +42 -54
  130. smftools/preprocessing/make_dirs.py +2 -1
  131. smftools/preprocessing/min_non_diagonal.py +2 -0
  132. smftools/preprocessing/recipes.py +2 -0
  133. smftools/tools/__init__.py +26 -18
  134. smftools/tools/archived/apply_hmm.py +2 -0
  135. smftools/tools/archived/classifiers.py +2 -0
  136. smftools/tools/archived/classify_methylated_features.py +2 -0
  137. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  138. smftools/tools/archived/subset_adata_v1.py +2 -0
  139. smftools/tools/archived/subset_adata_v2.py +2 -0
  140. smftools/tools/calculate_umap.py +3 -1
  141. smftools/tools/cluster_adata_on_methylation.py +7 -1
  142. smftools/tools/position_stats.py +17 -27
  143. {smftools-0.2.5.dist-info → smftools-0.3.0.dist-info}/METADATA +67 -33
  144. smftools-0.3.0.dist-info/RECORD +182 -0
  145. smftools-0.2.5.dist-info/RECORD +0 -181
  146. {smftools-0.2.5.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  147. {smftools-0.2.5.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  148. {smftools-0.2.5.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
--- a/smftools/informatics/fasta_functions.py
+++ b/smftools/informatics/fasta_functions.py
@@ -1,23 +1,93 @@
 from __future__ import annotations
 
 import gzip
+import shutil
+import subprocess
 from concurrent.futures import ProcessPoolExecutor
+from importlib.util import find_spec
 from pathlib import Path
-from typing import Dict, Iterable, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, Tuple
 
 import numpy as np
-import pysam
 from Bio import SeqIO
 from Bio.Seq import Seq
 from Bio.SeqRecord import SeqRecord
-from pyfaidx import Fasta
 
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 from ..readwrite import time_string
 
 logger = get_logger(__name__)
 
+if TYPE_CHECKING:
+    import pysam as pysam_module
+
+
+def _require_pysam() -> "pysam_module":
+    if pysam_types is not None:
+        return pysam_types
+    return require("pysam", extra="pysam", purpose="FASTA access")
+
+
+pysam_types = None
+if find_spec("pysam") is not None:
+    pysam_types = require("pysam", extra="pysam", purpose="FASTA access")
+
+
+def _resolve_fasta_backend() -> str:
+    """Resolve the backend to use for FASTA access."""
+    if pysam_types is not None:
+        return "python"
+    if shutil is not None and shutil.which("samtools"):
+        return "cli"
+    raise RuntimeError("FASTA access requires pysam or samtools in PATH.")
+
+
+def _ensure_fasta_index(fasta: Path) -> None:
+    fai = fasta.with_suffix(fasta.suffix + ".fai")
+    if fai.exists():
+        return
+    if subprocess is None or shutil is None or not shutil.which("samtools"):
+        pysam_mod = _require_pysam()
+        pysam_mod.faidx(str(fasta))
+        return
+    cp = subprocess.run(
+        ["samtools", "faidx", str(fasta)],
+        stdout=subprocess.DEVNULL,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+    if cp.returncode != 0:
+        raise RuntimeError(f"samtools faidx failed (exit {cp.returncode}):\n{cp.stderr}")
+
+
+def _bed_to_faidx_region(chrom: str, start: int, end: int) -> str:
+    """Convert 0-based half-open BED coords to samtools faidx region."""
+    start1 = start + 1
+    end1 = end
+    if start1 > end1:
+        start1, end1 = end1, start1
+    return f"{chrom}:{start1}-{end1}"
+
+
+def _fetch_sequence_with_samtools(fasta: Path, chrom: str, start: int, end: int) -> str:
+    if subprocess is None or shutil is None:
+        raise RuntimeError("samtools backend is unavailable.")
+    if not shutil.which("samtools"):
+        raise RuntimeError("samtools is required but not available in PATH.")
+    region = _bed_to_faidx_region(chrom, start, end)
+    cp = subprocess.run(
+        ["samtools", "faidx", str(fasta), region],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+    if cp.returncode != 0:
+        raise RuntimeError(f"samtools faidx failed (exit {cp.returncode}):\n{cp.stderr}")
+    lines = [line.strip() for line in cp.stdout.splitlines() if line and not line.startswith(">")]
+    return "".join(lines)
+
 
 def _convert_FASTA_record(
     record: SeqRecord,
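Note on the coordinate math in `_bed_to_faidx_region`: BED intervals are 0-based and half-open, while `samtools faidx` regions are 1-based and end-inclusive, so the two conventions name the same bases after adding 1 to the start only. A standalone sketch of the conversion (illustrative, not part of the package):

    # BED [100, 105) on chr1 covers the five 0-based positions 100..104.
    # The equivalent 1-based inclusive faidx region is chr1:101-105.
    def bed_to_faidx(chrom: str, start: int, end: int) -> str:
        return f"{chrom}:{start + 1}-{end}"

    assert bed_to_faidx("chr1", 100, 105) == "chr1:101-105"  # same five bases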
@@ -160,7 +230,7 @@ def index_fasta(fasta: str | Path, write_chrom_sizes: bool = True) -> Path:
         Path: Path to the index file or chromosome sizes file.
     """
     fasta = Path(fasta)
-    pysam.faidx(str(fasta))  # creates <fasta>.fai
+    _require_pysam().faidx(str(fasta))  # creates <fasta>.fai
 
     fai = fasta.with_suffix(fasta.suffix + ".fai")
     if write_chrom_sizes:
@@ -307,8 +377,13 @@ def subsample_fasta_from_bed(
     # Ensure output directory exists
     output_directory.mkdir(parents=True, exist_ok=True)
 
-    # Load the FASTA file using pyfaidx
-    fasta = Fasta(str(input_FASTA))  # pyfaidx requires string paths
+    backend = _resolve_fasta_backend()
+    _ensure_fasta_index(input_FASTA)
+
+    fasta_handle = None
+    if backend == "python":
+        pysam_mod = _require_pysam()
+        fasta_handle = pysam_mod.FastaFile(str(input_FASTA))
 
     # Open BED + output FASTA
     with input_bed.open("r") as bed, output_FASTA.open("w") as out_fasta:
@@ -319,15 +394,24 @@
             end = int(fields[2])  # BED is 0-based and end is exclusive
             desc = " ".join(fields[3:]) if len(fields) > 3 else ""
 
-            if chrom not in fasta:
+            if backend == "python":
+                assert fasta_handle is not None
+                if chrom not in fasta_handle.references:
+                    logger.warning(f"{chrom} not found in FASTA")
+                    continue
+                sequence = fasta_handle.fetch(chrom, start, end)
+            else:
+                sequence = _fetch_sequence_with_samtools(input_FASTA, chrom, start, end)
+
+            if not sequence:
                 logger.warning(f"{chrom} not found in FASTA")
                 continue
 
-            # pyfaidx is 1-based indexing internally, but [start:end] works with BED coords
-            sequence = fasta[chrom][start:end].seq
-
             header = f">{chrom}:{start}-{end}"
             if desc:
                 header += f" {desc}"
 
             out_fasta.write(f"{header}\n{sequence}\n")
+
+    if fasta_handle is not None:
+        fasta_handle.close()
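The two backends are interchangeable here because `pysam.FastaFile.fetch` consumes 0-based half-open coordinates directly, while the CLI path first converts them with `_bed_to_faidx_region`. A minimal equivalence check, assuming an indexed `ref.fa` with a `chr1` contig:

    import subprocess

    import pysam

    chrom, start, end = "chr1", 100, 105
    via_pysam = pysam.FastaFile("ref.fa").fetch(chrom, start, end)
    out = subprocess.run(
        ["samtools", "faidx", "ref.fa", f"{chrom}:{start + 1}-{end}"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    via_cli = "".join(line for line in out.splitlines() if not line.startswith(">"))
    assert via_pysam == via_cli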
--- a/smftools/informatics/h5ad_functions.py
+++ b/smftools/informatics/h5ad_functions.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import glob
 import os
 from concurrent.futures import ProcessPoolExecutor, as_completed
@@ -7,9 +9,9 @@ from typing import Dict, List, Optional, Union
 import numpy as np
 import pandas as pd
 import scipy.sparse as sp
-from pod5 import Reader
 
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 logger = get_logger(__name__)
 
@@ -90,6 +92,7 @@ def add_read_length_and_mapping_qc(
     extract_read_features_from_bam_callable=None,
     bypass: bool = False,
     force_redo: bool = True,
+    samtools_backend: str | None = "auto",
 ):
     """
     Populate adata.obs with read/mapping QC columns.
@@ -133,7 +136,7 @@
                 "No `read_metrics` provided and `extract_read_features_from_bam` not found."
             )
         for bam in bam_files:
-            bam_read_metrics = extractor(bam)
+            bam_read_metrics = extractor(bam, samtools_backend)
             if not isinstance(bam_read_metrics, dict):
                 raise ValueError(f"extract_read_features_from_bam returned non-dict for {bam}")
             read_metrics.update(bam_read_metrics)
@@ -228,6 +231,9 @@ def _collect_read_origins_from_pod5(pod5_path: str, target_ids: set[str]) -> dict[str, str]:
     Worker function: scan one POD5 file and return a mapping
     {read_id: pod5_basename} only for read_ids in `target_ids`.
     """
+    p5 = require("pod5", extra="ont", purpose="POD5 metadata")
+    Reader = p5.Reader
+
     basename = os.path.basename(pod5_path)
     mapping: dict[str, str] = {}
 
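Requiring pod5 inside the worker, rather than at module import, keeps the `ont` extra optional: the module imports cleanly without pod5 installed, and each `ProcessPoolExecutor` worker process resolves the dependency only when it actually runs. The same pattern in isolation, with a hypothetical optional dependency and API:

    from concurrent.futures import ProcessPoolExecutor
    from importlib import import_module

    def worker(path: str) -> int:
        # Deferred import: runs in the worker process, and only if called.
        mod = import_module("some_optional_dep")  # hypothetical package
        return mod.count_records(path)  # hypothetical API

    if __name__ == "__main__":
        with ProcessPoolExecutor() as ex:
            counts = list(ex.map(worker, ["a.pod5", "b.pod5"]))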
--- a/smftools/informatics/modkit_extract_to_adata.py
+++ b/smftools/informatics/modkit_extract_to_adata.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import concurrent.futures
 import gc
 import re
@@ -16,9 +18,11 @@ from .bam_functions import count_aligned_reads
 logger = get_logger(__name__)
 
 
-def filter_bam_records(bam, mapping_threshold):
+def filter_bam_records(bam, mapping_threshold, samtools_backend: str | None = "auto"):
     """Processes a single BAM file, counts reads, and determines records to analyze."""
-    aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(bam)
+    aligned_reads_count, unaligned_reads_count, record_counts_dict = count_aligned_reads(
+        bam, samtools_backend
+    )
 
     total_reads = aligned_reads_count + unaligned_reads_count
     percent_aligned = (aligned_reads_count * 100 / total_reads) if total_reads > 0 else 0
@@ -35,13 +39,16 @@ def filter_bam_records(bam, mapping_threshold):
     return set(records)
 
 
-def parallel_filter_bams(bam_path_list, mapping_threshold):
+def parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend: str | None = "auto"):
     """Parallel processing for multiple BAM files."""
     records_to_analyze = set()
 
     with concurrent.futures.ProcessPoolExecutor() as executor:
         results = executor.map(
-            filter_bam_records, bam_path_list, [mapping_threshold] * len(bam_path_list)
+            filter_bam_records,
+            bam_path_list,
+            [mapping_threshold] * len(bam_path_list),
+            [samtools_backend] * len(bam_path_list),
         )
 
         # Aggregate results
@@ -484,6 +491,7 @@ def modkit_extract_to_adata(
     delete_batch_hdfs=False,
     threads=None,
     double_barcoded_path=None,
+    samtools_backend: str | None = "auto",
 ):
     """
     Takes modkit extract outputs and organizes it into an adata object
@@ -591,7 +599,7 @@
 
     ######### Get Record names that have over a passed threshold of mapped reads #############
     # get all records that are above a certain mapping threshold in at least one sample bam
-    records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold)
+    records_to_analyze = parallel_filter_bams(bam_path_list, mapping_threshold, samtools_backend)
 
     ##########################################################################################
 
@@ -635,7 +643,9 @@
                 rev_base_identities,
                 mismatch_counts_per_read,
                 mismatch_trend_per_read,
-            ) = extract_base_identities(bam, record, positions, max_reference_length, ref_seq)
+            ) = extract_base_identities(
+                bam, record, positions, max_reference_length, ref_seq, samtools_backend
+            )
             # Store read names of fwd and rev mapped reads
             fwd_mapped_reads.update(fwd_base_identities.keys())
             rev_mapped_reads.update(rev_base_identities.keys())
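`ProcessPoolExecutor.map` takes one iterable per positional argument of the mapped function, which is why `parallel_filter_bams` above broadcasts the constant backend as `[samtools_backend] * len(bam_path_list)`. `itertools.repeat` expresses the same broadcast without materializing a list; a standalone sketch:

    from concurrent.futures import ProcessPoolExecutor
    from itertools import repeat

    def scale(x: int, factor: int) -> int:
        return x * factor

    if __name__ == "__main__":
        xs = [1, 2, 3]
        with ProcessPoolExecutor() as ex:
            # One iterable per argument; map zips them together.
            assert list(ex.map(scale, xs, repeat(10, len(xs)))) == [10, 20, 30]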
--- a/smftools/informatics/modkit_functions.py
+++ b/smftools/informatics/modkit_functions.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import subprocess
 
 from smftools.logging_utils import get_logger
--- a/smftools/informatics/ohe.py
+++ b/smftools/informatics/ohe.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import concurrent.futures
 import os
 
--- a/smftools/informatics/pod5_functions.py
+++ b/smftools/informatics/pod5_functions.py
@@ -5,9 +5,8 @@ import subprocess
 from pathlib import Path
 from typing import Iterable
 
-import pod5 as p5
-
 from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
 
 from ..config import LoadExperimentConfig
 from ..informatics.basecalling import canoncall, modcall
@@ -15,6 +14,8 @@ from ..readwrite import make_dirs
 
 logger = get_logger(__name__)
 
+p5 = require("pod5", extra="ont", purpose="POD5 IO")
+
 
 def basecall_pod5s(config_path: str | Path) -> None:
     """Basecall POD5 inputs using a configuration file.
--- a/smftools/machine_learning/__init__.py
+++ b/smftools/machine_learning/__init__.py
@@ -1,7 +1,23 @@
-from . import data, evaluation, inference, models, training, utils
+from __future__ import annotations
 
-__all__ = [
-    "calculate_relative_risk_on_activity",
-    "evaluate_models_by_subgroup",
-    "prepare_melted_model_data",
-]
+from importlib import import_module
+
+_LAZY_MODULES = {
+    "data": "smftools.machine_learning.data",
+    "evaluation": "smftools.machine_learning.evaluation",
+    "inference": "smftools.machine_learning.inference",
+    "models": "smftools.machine_learning.models",
+    "training": "smftools.machine_learning.training",
+    "utils": "smftools.machine_learning.utils",
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_MODULES:
+        module = import_module(_LAZY_MODULES[name])
+        globals()[name] = module
+        return module
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+__all__ = list(_LAZY_MODULES.keys())
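This is the PEP 562 module-level `__getattr__` hook: `import smftools.machine_learning` no longer eagerly imports the submodules (and, through them, torch and the rest of the ML stack); each submodule is imported on first attribute access and cached in `globals()`, so the hook only fires once per name. Usage sketch:

    import smftools.machine_learning as ml

    # No submodule has been imported yet.
    models = ml.models   # first access triggers import_module(...) and caches it
    models2 = ml.models  # later accesses hit the cached globals() entry
    assert models is models2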
--- a/smftools/machine_learning/data/__init__.py
+++ b/smftools/machine_learning/data/__init__.py
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
 from .anndata_data_module import AnnDataModule, build_anndata_loader
 from .preprocessing import random_fill_nans
--- a/smftools/machine_learning/data/anndata_data_module.py
+++ b/smftools/machine_learning/data/anndata_data_module.py
@@ -1,12 +1,26 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-import pytorch_lightning as pl
-import torch
-from sklearn.utils.class_weight import compute_class_weight
-from torch.utils.data import DataLoader, Dataset, Subset
+
+from smftools.optional_imports import require
 
 from .preprocessing import random_fill_nans
 
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning data modules")
+torch = require("torch", extra="ml-base", purpose="ML data loading")
+sklearn_class_weight = require(
+    "sklearn.utils.class_weight",
+    extra="ml-base",
+    purpose="class weighting",
+)
+torch_utils_data = require("torch.utils.data", extra="ml-base", purpose="ML data loading")
+
+compute_class_weight = sklearn_class_weight.compute_class_weight
+DataLoader = torch_utils_data.DataLoader
+Dataset = torch_utils_data.Dataset
+Subset = torch_utils_data.Subset
+
 
 class AnnDataDataset(Dataset):
     """
--- a/smftools/machine_learning/data/preprocessing.py
+++ b/smftools/machine_learning/data/preprocessing.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 
 
--- a/smftools/machine_learning/evaluation/__init__.py
+++ b/smftools/machine_learning/evaluation/__init__.py
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
 from .eval_utils import flatten_sliding_window_results
 from .evaluators import ModelEvaluator, PostInferenceModelEvaluator
--- a/smftools/machine_learning/evaluation/eval_utils.py
+++ b/smftools/machine_learning/evaluation/eval_utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import pandas as pd
 
 
--- a/smftools/machine_learning/evaluation/evaluators.py
+++ b/smftools/machine_learning/evaluation/evaluators.py
@@ -1,14 +1,19 @@
-import matplotlib.pyplot as plt
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-from sklearn.metrics import (
-    auc,
-    confusion_matrix,
-    f1_score,
-    precision_recall_curve,
-    roc_auc_score,
-    roc_curve,
-)
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="evaluation plots")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
 
 
 class ModelEvaluator:
--- a/smftools/machine_learning/inference/__init__.py
+++ b/smftools/machine_learning/inference/__init__.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from .lightning_inference import run_lightning_inference
 from .sklearn_inference import run_sklearn_inference
 from .sliding_window_inference import sliding_window_inference
--- a/smftools/machine_learning/inference/inference_utils.py
+++ b/smftools/machine_learning/inference/inference_utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import pandas as pd
 
 
--- a/smftools/machine_learning/inference/lightning_inference.py
+++ b/smftools/machine_learning/inference/lightning_inference.py
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
-import torch
+
+from smftools.optional_imports import require
 
 from .inference_utils import annotate_split_column
 
+torch = require("torch", extra="ml-base", purpose="Lightning inference")
+
 
 def run_lightning_inference(adata, model, datamodule, trainer, prefix="model", devices=1):
     """
--- a/smftools/machine_learning/inference/sklearn_inference.py
+++ b/smftools/machine_learning/inference/sklearn_inference.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
 
--- a/smftools/machine_learning/inference/sliding_window_inference.py
+++ b/smftools/machine_learning/inference/sliding_window_inference.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from ..data import AnnDataModule
 from ..evaluation import PostInferenceModelEvaluator
 from .lightning_inference import run_lightning_inference
--- a/smftools/machine_learning/models/__init__.py
+++ b/smftools/machine_learning/models/__init__.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from .base import BaseTorchModel
 from .cnn import CNNClassifier
 from .lightning_base import TorchClassifierWrapper
--- a/smftools/machine_learning/models/base.py
+++ b/smftools/machine_learning/models/base.py
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import numpy as np
-import torch
-import torch.nn as nn
+
+from smftools.optional_imports import require
 
 from ..utils.device import detect_device
 
+torch = require("torch", extra="ml-base", purpose="ML base models")
+nn = torch.nn
+
 
 class BaseTorchModel(nn.Module):
     """
--- a/smftools/machine_learning/models/cnn.py
+++ b/smftools/machine_learning/models/cnn.py
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 import numpy as np
-import torch
-import torch.nn as nn
+
+from smftools.optional_imports import require
 
 from .base import BaseTorchModel
 
+torch = require("torch", extra="ml-base", purpose="CNN models")
+nn = torch.nn
+
 
 class CNNClassifier(BaseTorchModel):
     def __init__(
--- a/smftools/machine_learning/models/lightning_base.py
+++ b/smftools/machine_learning/models/lightning_base.py
@@ -1,15 +1,20 @@
-import matplotlib.pyplot as plt
+from __future__ import annotations
+
 import numpy as np
-import pytorch_lightning as pl
-import torch
-from sklearn.metrics import (
-    auc,
-    confusion_matrix,
-    f1_score,
-    precision_recall_curve,
-    roc_auc_score,
-    roc_curve,
-)
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+pl = require("pytorch_lightning", extra="ml-extended", purpose="Lightning models")
+torch = require("torch", extra="ml-base", purpose="Lightning models")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
 
 
 class TorchClassifierWrapper(pl.LightningModule):
--- a/smftools/machine_learning/models/mlp.py
+++ b/smftools/machine_learning/models/mlp.py
@@ -1,7 +1,11 @@
-import torch.nn as nn
+from __future__ import annotations
+
+from smftools.optional_imports import require
 
 from .base import BaseTorchModel
 
+nn = require("torch.nn", extra="ml-base", purpose="MLP models")
+
 
 class MLPClassifier(BaseTorchModel):
     def __init__(
--- a/smftools/machine_learning/models/positional.py
+++ b/smftools/machine_learning/models/positional.py
@@ -1,6 +1,11 @@
+from __future__ import annotations
+
 import numpy as np
-import torch
-import torch.nn as nn
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="positional encoding")
+nn = torch.nn
 
 
 class PositionalEncoding(nn.Module):
--- a/smftools/machine_learning/models/rnn.py
+++ b/smftools/machine_learning/models/rnn.py
@@ -1,7 +1,11 @@
-import torch.nn as nn
+from __future__ import annotations
+
+from smftools.optional_imports import require
 
 from .base import BaseTorchModel
 
+nn = require("torch.nn", extra="ml-base", purpose="RNN models")
+
 
 class RNNClassifier(BaseTorchModel):
     def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
--- a/smftools/machine_learning/models/sklearn_models.py
+++ b/smftools/machine_learning/models/sklearn_models.py
@@ -1,13 +1,18 @@
-import matplotlib.pyplot as plt
+from __future__ import annotations
+
 import numpy as np
-from sklearn.metrics import (
-    auc,
-    confusion_matrix,
-    f1_score,
-    precision_recall_curve,
-    roc_auc_score,
-    roc_curve,
-)
+
+from smftools.optional_imports import require
+
+plt = require("matplotlib.pyplot", extra="plotting", purpose="model evaluation plots")
+sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model evaluation")
+
+auc = sklearn_metrics.auc
+confusion_matrix = sklearn_metrics.confusion_matrix
+f1_score = sklearn_metrics.f1_score
+precision_recall_curve = sklearn_metrics.precision_recall_curve
+roc_auc_score = sklearn_metrics.roc_auc_score
+roc_curve = sklearn_metrics.roc_curve
 
 
 class SklearnModelWrapper:
--- a/smftools/machine_learning/models/transformer.py
+++ b/smftools/machine_learning/models/transformer.py
@@ -1,11 +1,16 @@
+from __future__ import annotations
+
 import numpy as np
-import torch
-import torch.nn as nn
+
+from smftools.optional_imports import require
 
 from ..utils.grl import grad_reverse
 from .base import BaseTorchModel
 from .positional import PositionalEncoding
 
+torch = require("torch", extra="ml-base", purpose="Transformer models")
+nn = torch.nn
+
 
 class TransformerEncoderLayerWithAttn(nn.TransformerEncoderLayer):
     def __init__(self, *args, **kwargs):
--- a/smftools/machine_learning/models/wrappers.py
+++ b/smftools/machine_learning/models/wrappers.py
@@ -1,5 +1,9 @@
-import torch
-import torch.nn as nn
+from __future__ import annotations
+
+from smftools.optional_imports import require
+
+torch = require("torch", extra="ml-base", purpose="model wrappers")
+nn = torch.nn
 
 
 class ScaledModel(nn.Module):
--- a/smftools/machine_learning/training/__init__.py
+++ b/smftools/machine_learning/training/__init__.py
@@ -1,2 +1,4 @@
+from __future__ import annotations
+
 from .train_lightning_model import run_sliding_window_lightning_training, train_lightning_model
 from .train_sklearn_model import run_sliding_window_sklearn_training, train_sklearn_model
--- a/smftools/machine_learning/training/train_lightning_model.py
+++ b/smftools/machine_learning/training/train_lightning_model.py
@@ -1,10 +1,20 @@
-import torch
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from __future__ import annotations
+
+from smftools.optional_imports import require
 
 from ..data import AnnDataModule
 from ..models import TorchClassifierWrapper
 
+torch = require("torch", extra="ml-base", purpose="Lightning training")
+pytorch_lightning = require("pytorch_lightning", extra="ml-extended", purpose="Lightning training")
+pl_callbacks = require(
+    "pytorch_lightning.callbacks", extra="ml-extended", purpose="Lightning training"
+)
+
+Trainer = pytorch_lightning.Trainer
+EarlyStopping = pl_callbacks.EarlyStopping
+ModelCheckpoint = pl_callbacks.ModelCheckpoint
+
 
 def train_lightning_model(
     model,
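Since `Trainer`, `EarlyStopping`, and `ModelCheckpoint` are rebound at module scope, the function bodies below this point can keep idiomatic Lightning call sites. An illustrative use of those names (standard pytorch_lightning wiring, not the actual `train_lightning_model` body):

    callbacks = [
        EarlyStopping(monitor="val_loss", patience=5),
        ModelCheckpoint(monitor="val_loss", save_top_k=1),
    ]
    trainer = Trainer(max_epochs=50, callbacks=callbacks)
    # trainer.fit(model, datamodule=datamodule)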