smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0

smftools/preprocessing/invert_adata.py

@@ -1,25 +1,40 @@
  ## invert_adata

- def invert_adata(adata, uns_flag='adata_positions_inverted', force_redo=False):
-     """
-     Inverts the AnnData object along the column (variable) axis.
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
+     import anndata as ad

-     Parameters:
-         adata (AnnData): An AnnData object.
+ logger = get_logger(__name__)
+
+
+ def invert_adata(
+     adata: "ad.AnnData",
+     uns_flag: str = "invert_adata_performed",
+     force_redo: bool = False,
+ ) -> "ad.AnnData":
+     """Invert the AnnData object along the column axis.
+
+     Args:
+         adata: AnnData object.
+         uns_flag: Flag in ``adata.uns`` indicating prior completion.
+         force_redo: Whether to rerun even if ``uns_flag`` is set.

      Returns:
-         AnnData: A new AnnData object with inverted column ordering.
+         anndata.AnnData: New AnnData object with inverted column ordering.
      """
-     import numpy as np
-     import anndata as ad

      # Only run if not already performed
      already = bool(adata.uns.get(uns_flag, False))
-     if (already and not force_redo):
+     if already and not force_redo:
          # QC already performed; nothing to do
          return adata

-     print("Inverting AnnData along the column axis...")
+     logger.info("Inverting AnnData along the column axis...")

      # Reverse the order of columns (variables)
      inverted_adata = adata[:, ::-1].copy()
@@ -33,5 +48,5 @@ def invert_adata(adata, uns_flag='adata_positions_inverted', force_redo=False):
      # mark as done
      inverted_adata.uns[uns_flag] = True

-     print("Inversion complete!")
+     logger.info("Inversion complete!")
      return inverted_adata
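
Based on the updated signature above, a minimal usage sketch; the toy AnnData is illustrative only, and note that the default uns flag changed from 'adata_positions_inverted' to 'invert_adata_performed':

    import anndata as ad
    import numpy as np

    from smftools.preprocessing.invert_adata import invert_adata

    # Toy object purely for illustration; real AnnData comes from the smftools loaders.
    adata = ad.AnnData(X=np.arange(6, dtype=float).reshape(2, 3))

    inverted = invert_adata(adata)                    # reverses the variable (column) order
    same = invert_adata(inverted)                     # no-op: "invert_adata_performed" is already set
    redone = invert_adata(inverted, force_redo=True)  # reruns despite the flag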

smftools/preprocessing/load_sample_sheet.py

@@ -1,21 +1,36 @@
- def load_sample_sheet(adata,
-                       sample_sheet_path,
-                       mapping_key_column='obs_names',
-                       as_category=True,
-                       uns_flag='sample_sheet_loaded',
-                       force_reload=True
-                       ):
-     """
-     Loads a sample sheet CSV and maps metadata into the AnnData object as categorical columns.
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
+     import anndata as ad
+
+ logger = get_logger(__name__)

-     Parameters:
-         adata (AnnData): The AnnData object to append sample information to.
-         sample_sheet_path (str): Path to the CSV file.
-         mapping_key_column (str): Column name in the CSV to map against adata.obs_names or an existing obs column.
-         as_category (bool): If True, added columns will be cast as pandas Categorical.
+
+ def load_sample_sheet(
+     adata: "ad.AnnData",
+     sample_sheet_path: str | Path,
+     mapping_key_column: str = "obs_names",
+     as_category: bool = True,
+     uns_flag: str = "load_sample_sheet_performed",
+     force_reload: bool = True,
+ ) -> "ad.AnnData":
+     """Load a sample sheet CSV and map metadata into ``adata.obs``.
+
+     Args:
+         adata: AnnData object to append sample information to.
+         sample_sheet_path: Path to the CSV file.
+         mapping_key_column: Column name to map against ``adata.obs_names`` or an obs column.
+         as_category: Whether to cast added columns as pandas Categoricals.
+         uns_flag: Flag in ``adata.uns`` indicating prior completion.
+         force_reload: Whether to reload even if ``uns_flag`` is set.

      Returns:
-         AnnData: Updated AnnData object.
+         anndata.AnnData: Updated AnnData object.
      """
      import pandas as pd

@@ -25,29 +40,32 @@ def load_sample_sheet(adata,
          # QC already performed; nothing to do
          return

-     print('Loading sample sheet...')
+     logger.info("Loading sample sheet...")
      df = pd.read_csv(sample_sheet_path)
      df[mapping_key_column] = df[mapping_key_column].astype(str)
-
+
      # If matching against obs_names directly
-     if mapping_key_column == 'obs_names':
+     if mapping_key_column == "obs_names":
          key_series = adata.obs_names.astype(str)
      else:
          key_series = adata.obs[mapping_key_column].astype(str)

      value_columns = [col for col in df.columns if col != mapping_key_column]
-
-     print(f'Appending metadata columns: {value_columns}')
+
+     logger.info("Appending metadata columns: %s", value_columns)
      df = df.set_index(mapping_key_column)

      for col in value_columns:
          mapped = key_series.map(df[col])
          if as_category:
-             mapped = mapped.astype('category')
+             mapped = mapped.astype("category")
          adata.obs[col] = mapped

      # mark as done
      adata.uns[uns_flag] = True

-     print('Sample sheet metadata successfully added as categories.' if as_category else 'Metadata added.')
+     if as_category:
+         logger.info("Sample sheet metadata successfully added as categories.")
+     else:
+         logger.info("Metadata added.")
      return adata
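
A minimal usage sketch of the updated signature; the CSV path and column values below are hypothetical, and the 'obs_names' column of the sheet must match adata.obs_names when the default mapping key is used:

    import anndata as ad
    import numpy as np
    import pandas as pd

    from smftools.preprocessing.load_sample_sheet import load_sample_sheet

    # Toy AnnData with two reads; obs_names are the default mapping key.
    adata = ad.AnnData(X=np.zeros((2, 3)))
    adata.obs_names = ["read_1", "read_2"]

    # Hypothetical sample sheet written for illustration.
    pd.DataFrame(
        {"obs_names": ["read_1", "read_2"], "Sample_names": ["ctrl", "treated"]}
    ).to_csv("sample_sheet.csv", index=False)

    adata = load_sample_sheet(adata, "sample_sheet.csv")
    print(adata.obs["Sample_names"])  # categorical column mapped from the CSV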

smftools/preprocessing/make_dirs.py

@@ -1,5 +1,10 @@
  ## make_dirs

+ from smftools.logging_utils import get_logger
+
+ logger = get_logger(__name__)
+
+
  # General
  def make_dirs(directories):
      """
@@ -7,7 +12,7 @@ def make_dirs(directories):

      Parameters:
          directories (list): A list of directories to make
-
+
      Returns:
          None
      """
@@ -16,6 +21,6 @@ def make_dirs(directories):
      for directory in directories:
          if not os.path.isdir(directory):
              os.mkdir(directory)
-             print(f"Directory '{directory}' created successfully.")
+             logger.info("Directory '%s' created successfully.", directory)
          else:
-             print(f"Directory '{directory}' already exists.")
+             logger.info("Directory '%s' already exists.", directory)
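
This print-to-logger change recurs across the release: each module now obtains its logger via smftools.logging_utils.get_logger(__name__) (logging_utils.py itself is not shown in this excerpt). A minimal sketch of surfacing these messages, assuming get_logger returns a standard logging.Logger that propagates to the root logger:

    import logging
    import tempfile

    from smftools.preprocessing.make_dirs import make_dirs

    # Show INFO-level records on the console; adjust if logging_utils installs its own handlers.
    logging.basicConfig(level=logging.INFO)

    with tempfile.TemporaryDirectory() as tmp:
        make_dirs([f"{tmp}/bams", f"{tmp}/plots"])  # "Directory ... created successfully."
        make_dirs([f"{tmp}/bams"])                  # "Directory ... already exists."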

smftools/preprocessing/min_non_diagonal.py

@@ -1,5 +1,6 @@
  ## min_non_diagonal

+
  def min_non_diagonal(matrix):
      """
      Takes a matrix and returns the smallest value from each row with the diagonal masked.
@@ -22,4 +23,4 @@ def min_non_diagonal(matrix):
          row = matrix[i, row_mask]
          # Find the minimum value in the row
          min_values.append(np.min(row))
-     return min_values
+     return min_values

smftools/preprocessing/recipes.py

@@ -1,6 +1,15 @@
  # recipes

- def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory, mapping_key_column='Sample', reference_column = 'Reference', sample_names_col='Sample_names', invert=True):
+
+ def recipe_1_Kissiov_and_McKenna_2025(
+     adata,
+     sample_sheet_path,
+     output_directory,
+     mapping_key_column="Sample",
+     reference_column="Reference",
+     sample_names_col="Sample_names",
+     invert=True,
+ ):
      """
      The first part of the preprocessing workflow applied to the smf.inform.pod_to_adata() output derived from Kissiov_and_McKenna_2025.

@@ -26,36 +35,38 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
      Returns:
          variables (dict): A dictionary of variables to append to the parent scope.
      """
-     import anndata as ad
-     import pandas as pd
-     import numpy as np
-     from .load_sample_sheet import load_sample_sheet
-     from .calculate_coverage import calculate_coverage
+
      from .append_C_context import append_C_context
-     from .calculate_converted_read_methylation_stats import calculate_converted_read_methylation_stats
-     from .invert_adata import invert_adata
+     from .calculate_converted_read_methylation_stats import (
+         calculate_converted_read_methylation_stats,
+     )
+     from .calculate_coverage import calculate_coverage
      from .calculate_read_length_stats import calculate_read_length_stats
      from .clean_NaN import clean_NaN
+     from .invert_adata import invert_adata
+     from .load_sample_sheet import load_sample_sheet

      # Clean up some of the Reference metadata and save variable names that point to sets of values in the column.
-     adata.obs[reference_column] = adata.obs[reference_column].astype('category')
+     adata.obs[reference_column] = adata.obs[reference_column].astype("category")
      references = adata.obs[reference_column].cat.categories
-     split_references = [(reference, reference.split('_')[0][1:]) for reference in references]
+     split_references = [(reference, reference.split("_")[0][1:]) for reference in references]
      reference_mapping = {k: v for k, v in split_references}
-     adata.obs[f'{reference_column}_short'] = adata.obs[reference_column].map(reference_mapping)
-     short_references = set(adata.obs[f'{reference_column}_short'])
+     adata.obs[f"{reference_column}_short"] = adata.obs[reference_column].map(reference_mapping)
+     short_references = set(adata.obs[f"{reference_column}_short"])
      binary_layers = list(adata.layers.keys())

      # load sample sheet metadata
      load_sample_sheet(adata, sample_sheet_path, mapping_key_column)

      # hold sample names set
-     adata.obs[sample_names_col] = adata.obs[sample_names_col].astype('category')
+     adata.obs[sample_names_col] = adata.obs[sample_names_col].astype("category")
      sample_names = adata.obs[sample_names_col].cat.categories

      # Add position level metadata
      calculate_coverage(adata, obs_column=reference_column)
-     adata.var['SNP_position'] = (adata.var[f'N_{reference_column}_with_position'] > 0) & (adata.var[f'N_{reference_column}_with_position'] < len(references)).astype(bool)
+     adata.var["SNP_position"] = (adata.var[f"N_{reference_column}_with_position"] > 0) & (
+         adata.var[f"N_{reference_column}_with_position"] < len(references)
+     ).astype(bool)

      # Append cytosine context to the reference positions based on the conversion strand.
      append_C_context(adata, obs_column=reference_column, use_consensus=False)
@@ -64,7 +75,9 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
      calculate_converted_read_methylation_stats(adata, reference_column, sample_names_col)

      # Calculate read length statistics
-     upper_bound, lower_bound = calculate_read_length_stats(adata, reference_column, sample_names_col)
+     upper_bound, lower_bound = calculate_read_length_stats(
+         adata, reference_column, sample_names_col
+     )

      # Invert the adata object (ie flip the strand orientation for visualization)
      if invert:
@@ -81,11 +94,19 @@
          "sample_names": sample_names,
          "upper_bound": upper_bound,
          "lower_bound": lower_bound,
-         "references": references
+         "references": references,
      }
      return variables

- def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers, distance_thresholds={}, reference_column = 'Reference', sample_names_col='Sample_names'):
+
+ def recipe_2_Kissiov_and_McKenna_2025(
+     adata,
+     output_directory,
+     binary_layers,
+     distance_thresholds={},
+     reference_column="Reference",
+     sample_names_col="Sample_names",
+ ):
      """
      The second part of the preprocessing workflow applied to the adata that has already been preprocessed by recipe_1_Kissiov_and_McKenna_2025.

@@ -107,20 +128,32 @@ def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers, di
          filtered_adata (AnnData): An AnnData object containing the filtered reads
          duplicates (AnnData): An AnnData object containing the duplicate reads
      """
-     import anndata as ad
-     import pandas as pd
-     import numpy as np
-     from .mark_duplicates import mark_duplicates
+
      from .calculate_complexity import calculate_complexity
+     from .mark_duplicates import mark_duplicates
      from .remove_duplicates import remove_duplicates

      # Add here a way to remove reads below a given read quality (based on nan content). Need to also add a way to pull from BAM files the read quality from each read

      # Duplicate detection using pairwise hamming distance across reads
-     mark_duplicates(adata, binary_layers, obs_column=reference_column, sample_col=sample_names_col, distance_thresholds=distance_thresholds, method='N_masked_distances')
+     mark_duplicates(
+         adata,
+         binary_layers,
+         obs_column=reference_column,
+         sample_col=sample_names_col,
+         distance_thresholds=distance_thresholds,
+         method="N_masked_distances",
+     )

      # Complexity analysis using the marked duplicates and the lander-watermann algorithm
-     calculate_complexity(adata, output_directory, obs_column=reference_column, sample_col=sample_names_col, plot=True, save_plot=False)
+     calculate_complexity(
+         adata,
+         output_directory,
+         obs_column=reference_column,
+         sample_col=sample_names_col,
+         plot=True,
+         save_plot=False,
+     )

      # Remove duplicate reads and store the duplicate reads in a new AnnData object named duplicates.
      filtered_adata, duplicates = remove_duplicates(adata)
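
The reference-shortening step in recipe_1_Kissiov_and_McKenna_2025 maps each reference category to the token before its first underscore with the leading character dropped. A standalone illustration of that transform with hypothetical reference names (real values come from the experiment's references):

    import pandas as pd

    # Hypothetical reference categories purely for illustration.
    references = pd.Series(["+chr1_fwd", "-chr1_rev"], dtype="category").cat.categories

    # Same transform as recipe_1: token before the first "_", with its first character stripped.
    reference_mapping = {ref: ref.split("_")[0][1:] for ref in references}
    print(reference_mapping)  # {'+chr1_fwd': 'chr1', '-chr1_rev': 'chr1'}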

smftools/preprocessing/reindex_references_adata.py (new file)

@@ -0,0 +1,103 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
+     import anndata as ad
+
+ logger = get_logger(__name__)
+
+
+ def reindex_references_adata(
+     adata: "ad.AnnData",
+     reference_col: str = "Reference_strand",
+     offsets: dict | None = None,
+     new_col: str = "reindexed",
+     uns_flag: str = "reindex_references_adata_performed",
+     force_redo: bool = False,
+ ) -> None:
+     """Reindex genomic coordinates by adding per-reference offsets.
+
+     Args:
+         adata: AnnData object.
+         reference_col: Obs column containing reference identifiers.
+         offsets: Mapping of reference to integer offset.
+         new_col: Suffix for generated reindexed columns.
+         uns_flag: Flag in ``adata.uns`` indicating prior completion.
+         force_redo: Whether to rerun even if ``uns_flag`` is set.
+
+     Notes:
+         If ``offsets`` is ``None`` or missing a reference, the new column mirrors
+         the existing ``var_names`` values.
+     """
+
+     import numpy as np
+
+     # ============================================================
+     # 1. Skip if already done
+     # ============================================================
+     already = bool(adata.uns.get(uns_flag, False))
+     if already and not force_redo:
+         logger.info("%s already set; skipping. Use force_redo=True to recompute.", uns_flag)
+         return None
+
+     # Normalize offsets
+     if offsets is None:
+         offsets = {}
+     elif not isinstance(offsets, dict):
+         raise TypeError("offsets must be a dict {ref: int} or None.")
+
+     # ============================================================
+     # 2. Ensure var_names are numeric
+     # ============================================================
+     try:
+         var_coords = adata.var_names.astype(int)
+     except Exception as e:
+         raise ValueError(
+             "reindex_references_adata requires adata.var_names to be integer-like."
+         ) from e
+
+     # ============================================================
+     # 3. Gather all references
+     # ============================================================
+     ref_series = adata.obs[reference_col]
+     references = ref_series.cat.categories if hasattr(ref_series, "cat") else ref_series.unique()
+
+     # ============================================================
+     # 4. Create reindexed columns
+     # ============================================================
+     for ref in references:
+         colname = f"{ref}_{new_col}"
+
+         # Case 1: No offset provided → identity mapping
+         if ref not in offsets:
+             logger.info("No offset for ref=%r; using identity positions.", ref)
+             adata.var[colname] = var_coords
+             continue
+
+         offset_value = offsets[ref]
+
+         # Case 2: offset explicitly None → identity mapping
+         if offset_value is None:
+             logger.info("Offset for ref=%r is None; using identity positions.", ref)
+             adata.var[colname] = var_coords
+             continue
+
+         # Case 3: real shift
+         if not isinstance(offset_value, (int, np.integer)):
+             raise TypeError(
+                 f"Offset for reference {ref!r} must be an integer or None. Got {offset_value!r}"
+             )
+
+         adata.var[colname] = var_coords + offset_value
+         logger.info("Added reindexed column '%s' (offset=%s).", colname, offset_value)
+
+     # ============================================================
+     # 5. Mark complete
+     # ============================================================
+     adata.uns[uns_flag] = True
+     logger.info("Reindexing complete!")
+
+     return None
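
Since reindex_references_adata is new in 0.2.5, a minimal usage sketch based on the signature above; the reference names and offsets here are hypothetical, and var_names must be integer-like positions:

    import anndata as ad
    import numpy as np
    import pandas as pd

    from smftools.preprocessing.reindex_references_adata import reindex_references_adata

    # Toy AnnData: 2 reads x 5 positions, with integer-like var_names as required.
    adata = ad.AnnData(X=np.zeros((2, 5)))
    adata.var_names = [str(i) for i in range(5)]
    adata.obs["Reference_strand"] = pd.Categorical(["refA_top", "refB_top"])

    # Shift refA_top coordinates by +1000; refB_top falls back to identity positions.
    reindex_references_adata(adata, offsets={"refA_top": 1000})
    print(adata.var[["refA_top_reindexed", "refB_top_reindexed"]])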

smftools/preprocessing/subsample_adata.py

@@ -1,19 +1,36 @@
- def subsample_adata(adata, obs_columns=None, max_samples=2000, random_seed=42):
-     """
-     Subsamples an AnnData object so that each unique combination of categories
-     in the given `obs_columns` has at most `max_samples` observations.
-     If `obs_columns` is None or empty, the function randomly subsamples the entire dataset.
-
-     Parameters:
-         adata (AnnData): The AnnData object to subsample.
-         obs_columns (list of str, optional): List of observation column names to group by.
-         max_samples (int): The maximum number of observations per category combination.
-         random_seed (int): Random seed for reproducibility.
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Sequence
+
+ from smftools.logging_utils import get_logger
+
+ if TYPE_CHECKING:
+     import anndata as ad
+
+ logger = get_logger(__name__)
+
+
+ def subsample_adata(
+     adata: "ad.AnnData",
+     obs_columns: Sequence[str] | None = None,
+     max_samples: int = 2000,
+     random_seed: int = 42,
+ ) -> "ad.AnnData":
+     """Subsample an AnnData object by observation categories.
+
+     Each unique combination of categories in ``obs_columns`` is capped at
+     ``max_samples`` observations. If ``obs_columns`` is ``None``, the function
+     randomly subsamples the entire dataset.
+
+     Args:
+         adata: AnnData object to subsample.
+         obs_columns: Observation column names to group by.
+         max_samples: Maximum observations per category combination.
+         random_seed: Random seed for reproducibility.

      Returns:
-         AnnData: A new AnnData object with subsampled observations.
+         anndata.AnnData: Subsampled AnnData object.
      """
-     import anndata as ad
      import numpy as np

      np.random.seed(random_seed) # Ensure reproducibility
@@ -23,7 +40,7 @@ def subsample_adata(adata, obs_columns=None, max_samples=2000, random_seed=42):
              sampled_indices = np.random.choice(adata.obs.index, max_samples, replace=False)
          else:
              sampled_indices = adata.obs.index # Keep all if fewer than max_samples
-
+
          return adata[sampled_indices].copy()

      sampled_indices = []
@@ -34,7 +51,7 @@ def subsample_adata(adata, obs_columns=None, max_samples=2000, random_seed=42):
      for _, row in unique_combinations.iterrows():
          # Build filter condition dynamically for multiple columns
          condition = (adata.obs[obs_columns] == row.values).all(axis=1)
-
+
          # Get indices for the current category combination
          subset_indices = adata.obs[condition].index.to_numpy()
@@ -48,7 +65,7 @@

      # ⚠ Handle backed mode detection
      if adata.isbacked:
-         print("Detected backed mode. Subset will be loaded fully into memory.")
+         logger.warning("Detected backed mode. Subset will be loaded fully into memory.")
          subset = adata[sampled_indices]
          subset = subset.to_memory()
      else:
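
A minimal usage sketch of the updated subsample_adata signature; the obs column and category values below are hypothetical:

    import anndata as ad
    import numpy as np
    import pandas as pd

    from smftools.preprocessing.subsample_adata import subsample_adata

    # Toy AnnData: 100 reads from two hypothetical samples.
    adata = ad.AnnData(X=np.zeros((100, 4)))
    adata.obs["Sample_names"] = pd.Categorical(["ctrl"] * 50 + ["treated"] * 50)

    # Cap each Sample_names category at 10 observations.
    small = subsample_adata(adata, obs_columns=["Sample_names"], max_samples=10)
    print(small.n_obs)  # 20

    # With no grouping columns, the whole object is randomly downsampled.
    tiny = subsample_adata(adata, max_samples=25)
    print(tiny.n_obs)  # 25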