smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/{archives → archived}/preprocessing.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 ## preprocessing
 from .. import readwrite
 
@@ -322,12 +324,14 @@ def min_non_diagonal(matrix):
         min_values.append(np.min(row))
     return min_values
 
-def lander_waterman(x, C0):
-    return C0 * (1 - np.exp(-x / C0))
+def lander_waterman(x, C0):
+    """Lander-Waterman curve for complexity estimation."""
+    return C0 * (1 - np.exp(-x / C0))
 
-def count_unique_reads(reads, depth):
-    subsample = np.random.choice(reads, depth, replace=False)
-    return len(np.unique(subsample))
+def count_unique_reads(reads, depth):
+    """Count unique reads in a subsample of the given depth."""
+    subsample = np.random.choice(reads, depth, replace=False)
+    return len(np.unique(subsample))
 
 
 def mark_duplicates(adata, layers, obs_column='Reference', sample_col='Sample_names'):
     """
@@ -611,4 +615,4 @@ def binarize_on_Youden(adata, obs_column='Reference'):
     # Pull back the new binarized layers into the original adata object
     adata.layers['binarized_methylation'] = temp_adata.layers['binarized_methylation']
 
-######################################################################################################
+######################################################################################################
smftools/preprocessing/{archives → archived}/remove_duplicates.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # remove_duplicates
 
 def remove_duplicates(adata):
smftools/preprocessing/binarize.py
@@ -1,9 +1,26 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import numpy as np
 
-def binarize_adata(adata, source="X", target_layer="binary", threshold=0.8):
-    """
-    Binarize a dense matrix and preserve NaN.
-    source: "X" or layer name
+if TYPE_CHECKING:
+    import anndata as ad
+
+
+def binarize_adata(
+    adata: "ad.AnnData",
+    source: str = "X",
+    target_layer: str = "binary",
+    threshold: float = 0.8,
+) -> None:
+    """Binarize a dense matrix and preserve NaNs.
+
+    Args:
+        adata: AnnData object with input matrix or layer.
+        source: ``"X"`` to use the main matrix or a layer name.
+        target_layer: Layer name to store the binarized values.
+        threshold: Threshold above which values are set to 1.
     """
     X = adata.X if source == "X" else adata.layers[source]
 
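The new signature above documents NaN-preserving thresholding; the function body is not shown in this hunk. The core idea can be sketched with plain NumPy (toy matrix, illustrative names):

    import numpy as np

    X = np.array([[0.9, 0.2, np.nan],
                  [0.7, 0.95, 0.1]])
    threshold = 0.8

    nan_mask = np.isnan(X)                  # remember missing calls
    binary = (X > threshold).astype(float)  # NaN > t evaluates False, i.e. 0
    binary[nan_mask] = np.nan               # so restore NaN afterwards
    print(binary)                           # [[ 1.  0. nan] [ 0.  1.  0.]]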
smftools/preprocessing/binarize_on_Youden.py
@@ -1,47 +1,143 @@
-def binarize_on_Youden(adata,
-                       ref_column='Reference_strand',
-                       output_layer_name='binarized_methylation'):
-    """
-    Binarize SMF values based on position thresholds determined by calculate_position_Youden.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
 
-    Parameters:
-        adata (AnnData): The anndata object to binarize. `calculate_position_Youden` must have been run first.
-        obs_column (str): The obs column to stratify on. Needs to match what was passed in `calculate_position_Youden`.
+logger = get_logger(__name__)
 
-    Modifies:
-        Adds a new layer to `adata.layers['binarized_methylation']` containing the binarized methylation matrix.
+
+def binarize_on_Youden(
+    adata: "ad.AnnData",
+    ref_column: str = "Reference_strand",
+    output_layer_name: str = "binarized_methylation",
+    mask_failed_positions: bool = True,
+) -> None:
+    """Binarize SMF values using thresholds from ``calculate_position_Youden``.
+
+    Args:
+        adata: AnnData object to binarize.
+        ref_column: Obs column denoting reference/strand categories.
+        output_layer_name: Layer in which to store the binarized matrix.
+        mask_failed_positions: If ``True``, positions that failed Youden QC are set to NaN;
+            otherwise all positions are binarized.
     """
+
     import numpy as np
-    import anndata as ad
 
-    # Initialize an empty matrix to store the binarized methylation values
-    binarized_methylation = np.full_like(adata.X, np.nan, dtype=float)  # Keeps same shape as adata.X
+    # Extract dense X once
+    X = adata.X
+    if hasattr(X, "toarray"):  # sparse → dense
+        X = X.toarray()
+
+    n_obs, n_var = X.shape
+    binarized = np.full((n_obs, n_var), np.nan, dtype=float)
 
-    # Get unique categories
     references = adata.obs[ref_column].cat.categories
+    ref_labels = adata.obs[ref_column].to_numpy()
 
     for ref in references:
-        # Select subset for this category
-        ref_mask = adata.obs[ref_column] == ref
-        ref_subset = adata[ref_mask]
+        logger.info("Binarizing on Youden statistics for %s", ref)
+
+        ref_mask = ref_labels == ref
+        if not np.any(ref_mask):
+            continue
+
+        X_block = X[ref_mask, :].astype(float, copy=True)
+
+        # thresholds: list of (threshold, J)
+        youden_stats = adata.var[f"{ref}_position_methylation_thresholding_Youden_stats"].to_numpy()
+
+        thresholds = np.array(
+            [t[0] if isinstance(t, (tuple, list)) else np.nan for t in youden_stats],
+            dtype=float,
+        )
+
+        # QC mask
+        qc_mask = adata.var[f"{ref}_position_passed_Youden_thresholding_QC"].to_numpy().astype(bool)
+
+        if mask_failed_positions:
+            # Only binarize positions passing QC
+            cols_to_binarize = np.where(qc_mask)[0]
+        else:
+            # Binarize all positions
+            cols_to_binarize = np.arange(n_var)
+
+        # Prepare result block
+        block_out = np.full_like(X_block, np.nan, dtype=float)
+
+        if len(cols_to_binarize) > 0:
+            sub_X = X_block[:, cols_to_binarize]
+            sub_thresh = thresholds[cols_to_binarize]
+
+            nan_mask = np.isnan(sub_X)
+
+            bin_sub = (sub_X > sub_thresh[None, :]).astype(float)
+            bin_sub[nan_mask] = np.nan
+
+            block_out[:, cols_to_binarize] = bin_sub
+
+        # Write into full output matrix
+        binarized[ref_mask, :] = block_out
+
+    adata.layers[output_layer_name] = binarized
+    logger.info(
+        "Finished binarization → stored in adata.layers['%s'] (mask_failed_positions=%s)",
+        output_layer_name,
+        mask_failed_positions,
+    )
+
+
+# def binarize_on_Youden(adata,
+#                        ref_column='Reference_strand',
+#                        output_layer_name='binarized_methylation'):
+#     """
+#     Binarize SMF values based on position thresholds determined by calculate_position_Youden.
+#
+#     Parameters:
+#         adata (AnnData): The anndata object to binarize. `calculate_position_Youden` must have been run first.
+#         obs_column (str): The obs column to stratify on. Needs to match what was passed in `calculate_position_Youden`.
+#
+#     Modifies:
+#         Adds a new layer to `adata.layers['binarized_methylation']` containing the binarized methylation matrix.
+#     """
+#     import numpy as np
+#     import anndata as ad
+#
+#     # Initialize an empty matrix to store the binarized methylation values
+#     binarized_methylation = np.full_like(adata.X, np.nan, dtype=float)  # Keeps same shape as adata.X
+#
+#     # Get unique categories
+#     references = adata.obs[ref_column].cat.categories
+#
+#     for ref in references:
+#         print(f"Binarizing adata on Youden statistics for {ref}")
+#         # Select subset for this category
+#         ref_mask = adata.obs[ref_column] == ref
+#         ref_subset = adata[ref_mask]
+#
+#         # Extract the probability matrix
+#         original_matrix = ref_subset.X.copy()
 
-        # Extract the probability matrix
-        original_matrix = ref_subset.X.copy()
+#         # Extract the thresholds for each position efficiently
+#         thresholds = np.array(ref_subset.var[f'{ref}_position_methylation_thresholding_Youden_stats'].apply(lambda x: x[0]))
 
-        # Extract the thresholds for each position efficiently
-        thresholds = np.array(ref_subset.var[f'{ref}_position_methylation_thresholding_Youden_stats'].apply(lambda x: x[0]))
+#         # Identify NaN values
+#         nan_mask = np.isnan(original_matrix)
 
-        # Identify NaN values
-        nan_mask = np.isnan(original_matrix)
+#         # Binarize based on threshold
+#         binarized_matrix = (original_matrix > thresholds).astype(float)
 
-        # Binarize based on threshold
-        binarized_matrix = (original_matrix > thresholds).astype(float)
+#         # Restore NaN values
+#         binarized_matrix[nan_mask] = np.nan
 
-        # Restore NaN values
-        binarized_matrix[nan_mask] = np.nan
+#         # Assign the binarized values back into the preallocated storage
+#         binarized_methylation[ref_subset, :] = binarized_matrix
 
-        # Assign the binarized values back into the preallocated storage
-        binarized_methylation[ref_subset, :] = binarized_matrix
+#         # Store the binarized matrix in a new layer
+#         adata.layers[output_layer_name] = binarized_methylation
 
-    # Store the binarized matrix in a new layer
-    adata.layers[output_layer_name] = binarized_methylation
+#     print(f"Finished binarizing adata on Youden statistics")
smftools/preprocessing/binary_layers_to_ohe.py
@@ -1,28 +1,35 @@
+from __future__ import annotations
+
 ## binary_layers_to_ohe
+from smftools.logging_utils import get_logger
+
+logger = get_logger(__name__)
 
-## Conversion SMF Specific
-def binary_layers_to_ohe(adata, binary_layers, stack='hstack'):
+
+## Conversion SMF Specific
+def binary_layers_to_ohe(adata, binary_layers, stack="hstack"):
     """
     Parameters:
         adata (AnnData): Anndata object.
-        binary_layers (list): a list of strings. Each string represents a layer in the adata object. The layer should encode a binary matrix.
+        binary_layers (list): a list of strings. Each string represents a layer in the adata object. The layer should encode a binary matrix.
        stack (str): Dimension to stack the one-hot-encoding. Options include 'hstack' and 'vstack'. Default is 'hstack', since this is more efficient.
-
+
     Returns:
        ohe_dict (dict): A dictionary keyed by obs_name that points to a stacked (hstack or vstack) one-hot encoding of the binary layers
     Input: An adata object and a list of layers containing a binary encoding.
     """
     import numpy as np
-    import anndata as ad
 
     # Ensure that the N layer is last!
     # Grab all binary layers that are not encoding N
-    ACGT_binary_layers = [layer for layer in binary_layers if 'binary' in layer and layer != 'N_binary_encoding']
+    ACGT_binary_layers = [
+        layer for layer in binary_layers if "binary" in layer and layer != "N_binary_encoding"
+    ]
     # If there is a binary layer encoding N, hold it in N_binary_layer
-    N_binary_layer = [layer for layer in binary_layers if layer == 'N_binary_encoding']
+    N_binary_layer = [layer for layer in binary_layers if layer == "N_binary_encoding"]
     # Add the N_binary_encoding layer to the end of the list of binary layers
     all_binary_layers = ACGT_binary_layers + N_binary_layer
-    print(f'Found {all_binary_layers} layers in adata')
+    logger.info("Found %s layers in adata", all_binary_layers)
 
     # Extract the layers
     layers = [adata.layers[layer_name] for layer_name in all_binary_layers]
@@ -33,8 +40,8 @@ def binary_layers_to_ohe(adata, binary_layers, stack='hstack'):
         for layer in layers:
             read_ohe.append(layer[i])
         read_name = adata.obs_names[i]
-        if stack == 'hstack':
+        if stack == "hstack":
            ohe_dict[read_name] = np.hstack(read_ohe)
-        elif stack == 'vstack':
+        elif stack == "vstack":
            ohe_dict[read_name] = np.vstack(read_ohe)
-    return ohe_dict
+    return ohe_dict
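
For the stack option above: hstack flattens each read's per-base rows into one feature vector, while vstack keeps a bases-by-positions matrix. A toy illustration with two invented base layers:

    import numpy as np

    a_row = np.array([1, 0])  # A calls at two positions for one read
    c_row = np.array([0, 1])  # C calls at the same positions

    print(np.hstack([a_row, c_row]))  # [1 0 0 1] -> flat feature vector
    print(np.vstack([a_row, c_row]))  # [[1 0]
                                      #  [0 1]] -> bases x positions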
smftools/preprocessing/calculate_complexity_II.py
@@ -1,42 +1,62 @@
-from typing import Optional
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional
+
+from smftools.optional_imports import require
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+
 def calculate_complexity_II(
-    adata,
-    output_directory='',
-    sample_col='Sample_names',
-    ref_col: Optional[str] = 'Reference_strand',
-    cluster_col='sequence__merged_cluster_id',
-    plot=True,
-    save_plot=False,
-    n_boot=30,
-    n_depths=12,
-    random_state=0,
-    csv_summary=True,
-    uns_flag='calculate_complexity_II_performed',
-    force_redo=False,
-    bypass=False
-):
-    """
-    Estimate and plot library complexity.
+    adata: "ad.AnnData",
+    output_directory: str | Path = "",
+    sample_col: str = "Sample_names",
+    ref_col: Optional[str] = "Reference_strand",
+    cluster_col: str = "sequence__merged_cluster_id",
+    plot: bool = True,
+    save_plot: bool = False,
+    n_boot: int = 30,
+    n_depths: int = 12,
+    random_state: int = 0,
+    csv_summary: bool = True,
+    uns_flag: str = "calculate_complexity_II_performed",
+    force_redo: bool = False,
+    bypass: bool = False,
+) -> None:
+    """Estimate and optionally plot library complexity.
 
-    If ref_col is None (default), behaves as before: one calculation per sample.
-    If ref_col is provided, computes complexity for each (sample, ref) pair.
+    If ``ref_col`` is ``None``, the calculation is performed per sample. If provided,
+    complexity is computed for each ``(sample, reference)`` pair.
 
-    Results:
-    - adata.uns['Library_complexity_results'] : dict keyed by (sample,) or (sample, ref) -> dict with fields
-      C0, n_reads, n_unique, depths, mean_unique, ci_low, ci_high
-    - Also stores per-entity record in adata.uns[f'Library_complexity_{sanitized_name}'] (backwards compatible)
-    - Optionally saves PNGs and CSVs (curve points + fit summary)
+    Args:
+        adata: AnnData object containing read metadata.
+        output_directory: Directory for output plots/CSVs.
+        sample_col: Obs column containing sample names.
+        ref_col: Obs column with reference/strand categories, or ``None``.
+        cluster_col: Obs column with merged cluster IDs.
+        plot: Whether to generate plots.
+        save_plot: Whether to save plots to disk.
+        n_boot: Number of bootstrap iterations per depth.
+        n_depths: Number of subsampling depths to evaluate.
+        random_state: Random seed for bootstrapping.
+        csv_summary: Whether to write CSV summary files.
+        uns_flag: Flag in ``adata.uns`` indicating prior completion.
+        force_redo: Whether to rerun even if ``uns_flag`` is present.
+        bypass: Whether to skip processing.
     """
     import os
+
     import numpy as np
     import pandas as pd
-    import matplotlib.pyplot as plt
     from scipy.optimize import curve_fit
-    from datetime import datetime
+
+    plt = require("matplotlib.pyplot", extra="plotting", purpose="complexity plots")
 
     # early exits
     already = bool(adata.uns.get(uns_flag, False))
-    if (already and not force_redo):
+    if already and not force_redo:
         return None
     if bypass:
         return None
@@ -44,9 +64,11 @@ def calculate_complexity_II(
     rng = np.random.default_rng(random_state)
 
     def lw(x, C0):
+        """Lander-Waterman curve for complexity estimation."""
         return C0 * (1.0 - np.exp(-x / C0))
 
     def sanitize(name: str) -> str:
+        """Sanitize a string for safe filenames."""
         return "".join(c if c.isalnum() or c in "-._" else "_" for c in str(name))
 
     # checks
@@ -77,7 +99,7 @@ def calculate_complexity_II(
     group_keys = []
     # iterate only pairs that exist in data to avoid empty processing
     for s in samples:
-        mask_s = (adata.obs[sample_col] == s)
+        mask_s = adata.obs[sample_col] == s
         # find references present for this sample
         ref_present = pd.Categorical(adata.obs.loc[mask_s, ref_col]).categories
         # Use intersection of known reference categories and those present for sample
@@ -109,7 +131,7 @@ def calculate_complexity_II(
                "ci_high": np.array([], dtype=float),
            }
            # also store back-compat key
-            adata.uns[f'Library_complexity_{sanitize(group_label)}'] = results[g]
+            adata.uns[f"Library_complexity_{sanitize(group_label)}"] = results[g]
            continue
 
        # cluster ids array for this group
@@ -175,39 +197,45 @@ def calculate_complexity_II(
        }
 
        # save per-group in adata.uns for backward compatibility
-        adata.uns[f'Library_complexity_{sanitize(group_label)}'] = results[g]
+        adata.uns[f"Library_complexity_{sanitize(group_label)}"] = results[g]
 
        # prepare curve and fit records for CSV
-        fit_records.append({
-            "sample": sample,
-            "reference": ref if ref_col is not None else "",
-            "C0": float(C0),
-            "n_reads": int(n_reads),
-            "n_unique_observed": int(observed_unique),
-        })
+        fit_records.append(
+            {
+                "sample": sample,
+                "reference": ref if ref_col is not None else "",
+                "C0": float(C0),
+                "n_reads": int(n_reads),
+                "n_unique_observed": int(observed_unique),
+            }
+        )
 
        x_fit = np.linspace(0, max(n_reads, int(depths[-1]) if depths.size else n_reads), 200)
        y_fit = lw(x_fit, C0)
        for d, mu, lo, hi in zip(depths, mean_unique, lo_ci, hi_ci):
-            curve_records.append({
-                "sample": sample,
-                "reference": ref if ref_col is not None else "",
-                "type": "bootstrap",
-                "depth": int(d),
-                "mean_unique": float(mu),
-                "ci_low": float(lo),
-                "ci_high": float(hi),
-            })
+            curve_records.append(
+                {
+                    "sample": sample,
+                    "reference": ref if ref_col is not None else "",
+                    "type": "bootstrap",
+                    "depth": int(d),
+                    "mean_unique": float(mu),
+                    "ci_low": float(lo),
+                    "ci_high": float(hi),
+                }
+            )
        for xf, yf in zip(x_fit, y_fit):
-            curve_records.append({
-                "sample": sample,
-                "reference": ref if ref_col is not None else "",
-                "type": "fit",
-                "depth": float(xf),
-                "mean_unique": float(yf),
-                "ci_low": np.nan,
-                "ci_high": np.nan,
-            })
+            curve_records.append(
+                {
+                    "sample": sample,
+                    "reference": ref if ref_col is not None else "",
+                    "type": "fit",
+                    "depth": float(xf),
+                    "mean_unique": float(yf),
+                    "ci_low": np.nan,
+                    "ci_high": np.nan,
+                }
+            )
 
        # plotting for this group
        if plot:
@@ -226,7 +254,9 @@ def calculate_complexity_II(
 
            if save_plot:
                fname = f"complexity_{sanitize(group_label)}.png"
-                plt.savefig(os.path.join(output_directory or ".", fname), dpi=160, bbox_inches="tight")
+                plt.savefig(
+                    os.path.join(output_directory or ".", fname), dpi=160, bbox_inches="tight"
+                )
                plt.close()
            else:
                plt.show()
@@ -242,7 +272,7 @@ def calculate_complexity_II(
        fit_df = pd.DataFrame(fit_records)
        curve_df = pd.DataFrame(curve_records)
        base = output_directory or "."
-        fit_df.to_csv(os.path.join(base, f"complexity_fit_summary.csv"), index=False)
-        curve_df.to_csv(os.path.join(base, f"complexity_curves.csv"), index=False)
+        fit_df.to_csv(os.path.join(base, "complexity_fit_summary.csv"), index=False)
+        curve_df.to_csv(os.path.join(base, "complexity_curves.csv"), index=False)
 
     return results
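
The flow above subsamples cluster IDs at increasing depths, averages unique counts over bootstrap replicates, and fits the Lander-Waterman curve to estimate C0. A self-contained sketch of that fit on synthetic IDs (all values invented, not package code):

    import numpy as np
    from scipy.optimize import curve_fit

    def lw(x, C0):
        # Lander-Waterman expectation, as in the diff above
        return C0 * (1.0 - np.exp(-x / C0))

    rng = np.random.default_rng(0)
    clusters = rng.integers(0, 5_000, size=20_000)  # duplicate-heavy toy library

    depths = np.linspace(500, 20_000, 10, dtype=int)
    mean_unique = [
        np.mean([np.unique(rng.choice(clusters, d, replace=False)).size
                 for _ in range(10)])               # bootstrap replicates per depth
        for d in depths
    ]

    (C0_hat,), _ = curve_fit(lw, depths, mean_unique, p0=[max(mean_unique)])
    print(f"estimated C0 ≈ {C0_hat:.0f}")           # near the ~5,000 true clusters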
smftools/preprocessing/calculate_consensus.py
@@ -1,19 +1,28 @@
 # calculate_consensus
 
-def calculate_consensus(adata, reference, sample=False, reference_column='Reference', sample_column='Sample'):
-    """
-    Takes an input AnnData object, the reference to subset on, and the sample name to subset on to calculate the consensus sequence of the read set.
-
-    Parameters:
-        adata (AnnData): The input adata to append consensus metadata to.
-        reference (str): The name of the reference to subset the adata on.
-        sample (bool | str): If False, uses all samples. If a string is passed, the adata is further subsetted to only analyze that sample.
-        reference_column (str): The name of the reference column (Default is 'Reference')
-        sample_column (str): The name of the sample column (Default is 'Sample)
-
-    Returns:
-        None
-
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+
+def calculate_consensus(
+    adata: "ad.AnnData",
+    reference: str,
+    sample: str | bool = False,
+    reference_column: str = "Reference",
+    sample_column: str = "Sample",
+) -> None:
+    """Calculate a consensus sequence for a reference (and optional sample).
+
+    Args:
+        adata: AnnData object to append consensus metadata to.
+        reference: Reference name to subset on.
+        sample: If ``False``, uses all samples. If a string is passed, subsets to that sample.
+        reference_column: Obs column with reference names.
+        sample_column: Obs column with sample names.
     """
     import numpy as np
 
@@ -25,11 +34,11 @@ def calculate_consensus(adata, reference, sample=False, reference_column='Refere
        pass
 
     # Grab layer names from the adata object that correspond to the binary encodings of the read sequences.
-    layers = [layer for layer in record_subset.layers if '_binary_' in layer]
+    layers = [layer for layer in record_subset.layers if "_binary_" in layer]
     layer_map, layer_counts = {}, []
     for i, layer in enumerate(layers):
        # Gives an integer mapping to access which sequence base the binary layer is encoding
-        layer_map[i] = layer.split('_')[0]
+        layer_map[i] = layer.split("_")[0]
        # Get the positional counts from all reads for the given base identity.
        layer_counts.append(np.sum(record_subset.layers[layer], axis=0))
     # Combine the positional counts array derived from each binary base layer into an ndarray
@@ -40,8 +49,8 @@ def calculate_consensus(adata, reference, sample=False, reference_column='Refere
     consensus_sequence_list = [layer_map[i] for i in nucleotide_indexes]
 
     if sample:
-        adata.var[f'{reference}_consensus_from_{sample}'] = consensus_sequence_list
+        adata.var[f"{reference}_consensus_from_{sample}"] = consensus_sequence_list
     else:
-        adata.var[f'{reference}_consensus_across_samples'] = consensus_sequence_list
+        adata.var[f"{reference}_consensus_across_samples"] = consensus_sequence_list
 
-    adata.uns[f'{reference}_consensus_sequence'] = consensus_sequence_list
+    adata.uns[f"{reference}_consensus_sequence"] = consensus_sequence_list
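
The unchanged body above reduces consensus calling to an argmax over per-base positional counts. The same reduction on toy counts (invented numbers, not package code):

    import numpy as np

    layer_map = {0: "A", 1: "C", 2: "G", 3: "T"}
    layer_counts = np.array([   # base counts at three positions
        [10, 1, 0],   # A
        [0, 8, 2],    # C
        [1, 0, 9],    # G
        [0, 2, 1],    # T
    ])

    nucleotide_indexes = np.argmax(layer_counts, axis=0)  # winning base per column
    print("".join(layer_map[i] for i in nucleotide_indexes))  # ACG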