smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +18 -2
  4. smftools/cli/hmm_adata.py +18 -1
  5. smftools/cli/latent_adata.py +522 -67
  6. smftools/cli/load_adata.py +2 -2
  7. smftools/cli/preprocess_adata.py +32 -93
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +23 -109
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +41 -5
  12. smftools/config/conversion.yaml +0 -10
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +49 -13
  15. smftools/config/experiment_config.py +96 -3
  16. smftools/constants.py +4 -0
  17. smftools/hmm/call_hmm_peaks.py +1 -1
  18. smftools/informatics/binarize_converted_base_identities.py +2 -89
  19. smftools/informatics/converted_BAM_to_adata.py +53 -13
  20. smftools/informatics/h5ad_functions.py +83 -0
  21. smftools/informatics/modkit_extract_to_adata.py +4 -0
  22. smftools/plotting/__init__.py +26 -12
  23. smftools/plotting/autocorrelation_plotting.py +22 -4
  24. smftools/plotting/chimeric_plotting.py +1893 -0
  25. smftools/plotting/classifiers.py +28 -14
  26. smftools/plotting/general_plotting.py +58 -3362
  27. smftools/plotting/hmm_plotting.py +1586 -2
  28. smftools/plotting/latent_plotting.py +804 -0
  29. smftools/plotting/plotting_utils.py +243 -0
  30. smftools/plotting/position_stats.py +16 -8
  31. smftools/plotting/preprocess_plotting.py +281 -0
  32. smftools/plotting/qc_plotting.py +8 -3
  33. smftools/plotting/spatial_plotting.py +1134 -0
  34. smftools/plotting/variant_plotting.py +1231 -0
  35. smftools/preprocessing/__init__.py +3 -0
  36. smftools/preprocessing/append_base_context.py +1 -1
  37. smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
  38. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  39. smftools/preprocessing/append_variant_call_layer.py +480 -0
  40. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  41. smftools/preprocessing/invert_adata.py +1 -0
  42. smftools/readwrite.py +109 -85
  43. smftools/tools/__init__.py +6 -0
  44. smftools/tools/calculate_knn.py +121 -0
  45. smftools/tools/calculate_nmf.py +18 -7
  46. smftools/tools/calculate_pca.py +180 -0
  47. smftools/tools/calculate_umap.py +70 -154
  48. smftools/tools/position_stats.py +4 -4
  49. smftools/tools/rolling_nn_distance.py +640 -3
  50. smftools/tools/sequence_alignment.py +140 -0
  51. smftools/tools/tensor_factorization.py +52 -4
  52. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
  53. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
  54. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  55. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  56. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -840,6 +840,10 @@ class ExperimentConfig:
     mismatch_frequency_range: Sequence[float] = field(default_factory=lambda: [0.05, 0.95])
     mismatch_frequency_layer: str = "mismatch_integer_encoding"
     mismatch_frequency_read_span_layer: str = "read_span_mask"
+    mismatch_base_frequency_exclude_mod_sites: bool = False
+    references_to_align_for_variant_annotation: List[Optional[str]] = field(
+        default_factory=lambda: [None, None]
+    )
 
     # Spatial Analysis - Clustermap params
     layer_for_clustermap_plotting: Optional[str] = "nan0_0minus1"
@@ -848,14 +852,45 @@ class ExperimentConfig:
     clustermap_cmap_cpg: Optional[str] = "coolwarm"
     clustermap_cmap_a: Optional[str] = "coolwarm"
     spatial_clustermap_sortby: Optional[str] = "gpc"
+    overlay_variant_calls: bool = False
+    variant_overlay_seq1_color: str = "white"
+    variant_overlay_seq2_color: str = "black"
+    variant_overlay_marker_size: float = 4.0
     rolling_nn_layer: Optional[str] = "nan0_0minus1"
     rolling_nn_plot_layer: Optional[str] = "nan0_0minus1"
-    rolling_nn_window: int = 15
-    rolling_nn_step: int = 2
-    rolling_nn_min_overlap: int = 10
+    rolling_nn_plot_layers: List[str] = field(
+        default_factory=lambda: ["nan0_0minus1", "nan0_0minus1"]
+    )
+    rolling_nn_window: int = 10
+    rolling_nn_step: int = 1
+    rolling_nn_min_overlap: int = 8
     rolling_nn_return_fraction: bool = True
     rolling_nn_obsm_key: str = "rolling_nn_dist"
     rolling_nn_site_types: Optional[List[str]] = None
+    rolling_nn_write_zero_pairs_csvs: bool = True
+    rolling_nn_zero_pairs_uns_key: Optional[str] = None
+    rolling_nn_zero_pairs_segments_key: Optional[str] = None
+    rolling_nn_zero_pairs_layer_key: Optional[str] = None
+    rolling_nn_zero_pairs_refine: bool = True
+    rolling_nn_zero_pairs_max_nan_run: Optional[int] = None
+    rolling_nn_zero_pairs_merge_gap: int = 0
+    rolling_nn_zero_pairs_max_segments_per_read: Optional[int] = None
+    rolling_nn_zero_pairs_max_overlap: Optional[int] = None
+    rolling_nn_zero_pairs_layer_overlap_mode: str = "binary"
+    rolling_nn_zero_pairs_layer_overlap_value: Optional[int] = None
+    rolling_nn_zero_pairs_keep_uns: bool = True
+    rolling_nn_zero_pairs_segments_keep_uns: bool = True
+    rolling_nn_zero_pairs_top_segments_per_read: Optional[int] = None
+    rolling_nn_zero_pairs_top_segments_max_overlap: Optional[int] = None
+    rolling_nn_zero_pairs_top_segments_min_span: Optional[float] = None
+    rolling_nn_zero_pairs_top_segments_write_csvs: bool = True
+    rolling_nn_zero_pairs_segment_histogram_bins: int = 30
+
+    # Cross-sample rolling NN analysis
+    cross_sample_analysis: bool = False
+    cross_sample_grouping_col: Optional[str] = None
+    cross_sample_random_seed: int = 42
+    delta_hamming_chimeric_span_threshold: int = 200
 
     # Spatial Analysis - UMAP/Leiden params
     layer_for_umap_plotting: Optional[str] = "nan_half"
@@ -1148,6 +1183,10 @@ class ExperimentConfig:
         merged["mod_target_bases"] = _parse_list(merged["mod_target_bases"])
         if "conversion_types" in merged:
             merged["conversion_types"] = _parse_list(merged["conversion_types"])
+        if "references_to_align_for_variant_annotation" in merged:
+            merged["references_to_align_for_variant_annotation"] = _parse_list(
+                merged["references_to_align_for_variant_annotation"]
+            )
 
         merged["filter_threshold"] = float(_parse_numeric(merged.get("filter_threshold", 0.8), 0.8))
         merged["m6A_threshold"] = float(_parse_numeric(merged.get("m6A_threshold", 0.7), 0.7))
@@ -1360,14 +1399,65 @@ class ExperimentConfig:
             clustermap_cmap_cpg=merged.get("clustermap_cmap_cpg", "coolwarm"),
             clustermap_cmap_a=merged.get("clustermap_cmap_a", "coolwarm"),
             spatial_clustermap_sortby=merged.get("spatial_clustermap_sortby", "gpc"),
+            overlay_variant_calls=_parse_bool(merged.get("overlay_variant_calls", False)),
+            variant_overlay_seq1_color=merged.get("variant_overlay_seq1_color", "white"),
+            variant_overlay_seq2_color=merged.get("variant_overlay_seq2_color", "black"),
+            variant_overlay_marker_size=float(merged.get("variant_overlay_marker_size", 4.0)),
             rolling_nn_layer=merged.get("rolling_nn_layer", "nan0_0minus1"),
             rolling_nn_plot_layer=merged.get("rolling_nn_plot_layer", "nan0_0minus1"),
+            rolling_nn_plot_layers=merged.get(
+                "rolling_nn_plot_layers", ["nan0_0minus1", "nan0_0minus1"]
+            ),
             rolling_nn_window=merged.get("rolling_nn_window", 15),
             rolling_nn_step=merged.get("rolling_nn_step", 2),
             rolling_nn_min_overlap=merged.get("rolling_nn_min_overlap", 10),
             rolling_nn_return_fraction=merged.get("rolling_nn_return_fraction", True),
             rolling_nn_obsm_key=merged.get("rolling_nn_obsm_key", "rolling_nn_dist"),
             rolling_nn_site_types=merged.get("rolling_nn_site_types", None),
+            rolling_nn_write_zero_pairs_csvs=merged.get("rolling_nn_write_zero_pairs_csvs", True),
+            rolling_nn_zero_pairs_uns_key=merged.get("rolling_nn_zero_pairs_uns_key", None),
+            rolling_nn_zero_pairs_segments_key=merged.get(
+                "rolling_nn_zero_pairs_segments_key", None
+            ),
+            rolling_nn_zero_pairs_layer_key=merged.get("rolling_nn_zero_pairs_layer_key", None),
+            rolling_nn_zero_pairs_refine=merged.get("rolling_nn_zero_pairs_refine", True),
+            rolling_nn_zero_pairs_max_nan_run=merged.get("rolling_nn_zero_pairs_max_nan_run", None),
+            rolling_nn_zero_pairs_merge_gap=merged.get("rolling_nn_zero_pairs_merge_gap", 0),
+            rolling_nn_zero_pairs_max_segments_per_read=merged.get(
+                "rolling_nn_zero_pairs_max_segments_per_read", None
+            ),
+            rolling_nn_zero_pairs_max_overlap=merged.get("rolling_nn_zero_pairs_max_overlap", None),
+            rolling_nn_zero_pairs_layer_overlap_mode=merged.get(
+                "rolling_nn_zero_pairs_layer_overlap_mode", "binary"
+            ),
+            rolling_nn_zero_pairs_layer_overlap_value=merged.get(
+                "rolling_nn_zero_pairs_layer_overlap_value", None
+            ),
+            rolling_nn_zero_pairs_keep_uns=merged.get("rolling_nn_zero_pairs_keep_uns", True),
+            rolling_nn_zero_pairs_segments_keep_uns=merged.get(
+                "rolling_nn_zero_pairs_segments_keep_uns", True
+            ),
+            rolling_nn_zero_pairs_top_segments_per_read=merged.get(
+                "rolling_nn_zero_pairs_top_segments_per_read", None
+            ),
+            rolling_nn_zero_pairs_top_segments_max_overlap=merged.get(
+                "rolling_nn_zero_pairs_top_segments_max_overlap", None
+            ),
+            rolling_nn_zero_pairs_top_segments_min_span=merged.get(
+                "rolling_nn_zero_pairs_top_segments_min_span", None
+            ),
+            rolling_nn_zero_pairs_top_segments_write_csvs=merged.get(
+                "rolling_nn_zero_pairs_top_segments_write_csvs", True
+            ),
+            rolling_nn_zero_pairs_segment_histogram_bins=merged.get(
+                "rolling_nn_zero_pairs_segment_histogram_bins", 30
+            ),
+            cross_sample_analysis=merged.get("cross_sample_analysis", False),
+            cross_sample_grouping_col=merged.get("cross_sample_grouping_col", None),
+            cross_sample_random_seed=merged.get("cross_sample_random_seed", 42),
+            delta_hamming_chimeric_span_threshold=merged.get(
+                "delta_hamming_chimeric_span_threshold", 200
+            ),
             layer_for_umap_plotting=merged.get("layer_for_umap_plotting", "nan_half"),
             umap_layers_to_plot=merged.get(
                 "umap_layers_to_plot", ["mapped_length", "Raw_modification_signal"]
@@ -1531,6 +1621,9 @@ class ExperimentConfig:
             force_redo_hmm_fit=merged.get("force_redo_hmm_fit", False),
             bypass_hmm_apply=merged.get("bypass_hmm_apply", False),
             force_redo_hmm_apply=merged.get("force_redo_hmm_apply", False),
+            references_to_align_for_variant_annotation=merged.get(
+                "references_to_align_for_variant_annotation", [None, None]
+            ),
             config_source=config_source or "<var_dict>",
         )
 
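For orientation, here is a minimal sketch of how the new 0.3.2 options could be set when constructing the config directly in Python. The field names and defaults come from the diff above; the import path and the assumption that the remaining ExperimentConfig fields all have defaults are untested assumptions. Note also that the dataclass defaults for rolling_nn_window/step/min_overlap changed to 10/1/8 while the merged.get() fallbacks above still read 15/2/10.

    # Hypothetical direct construction; real configs are normally merged from YAML.
    from smftools.config.experiment_config import ExperimentConfig  # assumed import path

    cfg = ExperimentConfig(
        rolling_nn_window=10,                   # new 0.3.2 dataclass default (was 15)
        rolling_nn_step=1,                      # was 2
        rolling_nn_min_overlap=8,               # was 10
        rolling_nn_write_zero_pairs_csvs=True,  # emit zero-pair CSV outputs
        cross_sample_analysis=True,             # enable cross-sample rolling NN
        cross_sample_grouping_col="sample",     # hypothetical obs column name
    )
    print(cfg.delta_hamming_chimeric_span_threshold)  # 200, the new default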
smftools/constants.py CHANGED
@@ -44,7 +44,11 @@ PREPROCESS_DIR: Final[str] = "preprocess_adata_outputs"
 SPATIAL_DIR: Final[str] = "spatial_adata_outputs"
 HMM_DIR: Final[str] = "hmm_adata_outputs"
 LATENT_DIR: Final[str] = "latent_adata_outputs"
+VARIANT_DIR: Final[str] = "variant_adata_outputs"
+CHIMERIC_DIR: Final[str] = "chimeric_adata_outputs"
+
 LOGGING_DIR: Final[str] = "logs"
+
 TRIM: Final[bool] = False
 
 _private_conversions = ["unconverted"]
@@ -51,7 +51,7 @@ def call_hmm_peaks(
         raise KeyError(f"obs column '{ref_column}' not found")
 
     # Ensure categorical for predictable ref iteration
-    if not pd.api.types.is_categorical_dtype(adata.obs[ref_column]):
+    if not isinstance(adata.obs[ref_column].dtype, pd.CategoricalDtype):
         adata.obs[ref_column] = adata.obs[ref_column].astype("category")
 
     # Optional: drop duplicate obs columns once to avoid Pandas/AnnData view quirks
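The one-line change above tracks a pandas API shift: pd.api.types.is_categorical_dtype is deprecated in pandas 2.x in favor of an isinstance check against pd.CategoricalDtype. A minimal self-contained illustration of the new-style check:

    import pandas as pd

    s = pd.Series(["refA", "refB", "refA"])
    # New-style check: inspect the dtype object instead of the deprecated helper.
    if not isinstance(s.dtype, pd.CategoricalDtype):
        s = s.astype("category")
    print(s.dtype)  # category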
@@ -5,20 +5,19 @@ def binarize_converted_base_identities(
     base_identities,
     strand,
     modification_type,
-    bam,
-    device="cpu",
     deaminase_footprinting=False,
     mismatch_trend_per_read={},
     on_missing="nan",
 ):
     """
     Efficiently binarizes conversion SMF data within a sequence string using NumPy arrays.
+    For the conversion modality, the strand parameter is used for mapping.
+    For the deaminase modality, mismatch_trend_per_read is used for mapping.
 
     Parameters:
         base_identities (dict): A dictionary returned by extract_base_identities. Keyed by read name. Points to a list of base identities.
         strand (str): A string indicating which strand was converted in the experiment (options are 'top' and 'bottom').
         modification_type (str): A string indicating the modification type of interest (options are '5mC' and '6mA').
-        bam (str): The bam file path
         deaminase_footprinting (bool): Whether direct deaminase footprinting chemistry was used.
         mismatch_trend_per_read (dict): For deaminase footprinting, indicates the type of conversion relative to the top strand reference for each read. (C->T or G->A if bottom strand was converted)
         on_missing (str): Error handling if a read is missing
@@ -98,89 +97,3 @@ def binarize_converted_base_identities(
         out[read_id] = res
 
     return out
-
-    # if mismatch_trend_per_read is None:
-    #     mismatch_trend_per_read = {}
-
-    # # If the modification type is 'unconverted', return NaN for all positions if the deaminase_footprinting strategy is not being used.
-    # if modification_type == "unconverted" and not deaminase_footprinting:
-    #     #print(f"Skipping binarization for unconverted {strand} reads on bam: {bam}.")
-    #     return {key: np.full(len(bases), np.nan) for key, bases in base_identities.items()}
-
-    # # Define mappings for binarization based on strand and modification type
-    # if deaminase_footprinting:
-    #     binarization_maps = {
-    #         ('C->T'): {'C': 0, 'T': 1},
-    #         ('G->A'): {'G': 0, 'A': 1},
-    #     }
-
-    #     binarized_base_identities = {}
-    #     for key, bases in base_identities.items():
-    #         arr = np.array(bases, dtype='<U1')
-    #         # Fetch the appropriate mapping
-    #         conversion_type = mismatch_trend_per_read[key]
-    #         base_map = binarization_maps.get(conversion_type, None)
-    #         binarized = np.vectorize(lambda x: base_map.get(x, np.nan))(arr)  # Apply mapping with fallback to NaN
-    #         binarized_base_identities[key] = binarized
-
-    #     return binarized_base_identities
-
-    # else:
-    #     binarization_maps = {
-    #         ('top', '5mC'): {'C': 1, 'T': 0},
-    #         ('top', '6mA'): {'A': 1, 'G': 0},
-    #         ('bottom', '5mC'): {'G': 1, 'A': 0},
-    #         ('bottom', '6mA'): {'T': 1, 'C': 0}
-    #     }
-
-    #     if (strand, modification_type) not in binarization_maps:
-    #         raise ValueError(f"Invalid combination of strand='{strand}' and modification_type='{modification_type}'")
-
-    #     # Fetch the appropriate mapping
-    #     base_map = binarization_maps[(strand, modification_type)]
-
-    #     binarized_base_identities = {}
-    #     for key, bases in base_identities.items():
-    #         arr = np.array(bases, dtype='<U1')
-    #         binarized = np.vectorize(lambda x: base_map.get(x, np.nan))(arr)  # Apply mapping with fallback to NaN
-    #         binarized_base_identities[key] = binarized
-
-    #     return binarized_base_identities
-    # import torch
-
-    # # If the modification type is 'unconverted', return NaN for all positions
-    # if modification_type == "unconverted":
-    #     print(f"Skipping binarization for unconverted {strand} reads on bam: {bam}.")
-    #     return {key: torch.full((len(bases),), float('nan'), device=device) for key, bases in base_identities.items()}
-
-    # # Define mappings for binarization based on strand and modification type
-    # binarization_maps = {
-    #     ('top', '5mC'): {'C': 1, 'T': 0},
-    #     ('top', '6mA'): {'A': 1, 'G': 0},
-    #     ('bottom', '5mC'): {'G': 1, 'A': 0},
-    #     ('bottom', '6mA'): {'T': 1, 'C': 0}
-    # }
-
-    # if (strand, modification_type) not in binarization_maps:
-    #     raise ValueError(f"Invalid combination of strand='{strand}' and modification_type='{modification_type}'")
-
-    # # Fetch the appropriate mapping
-    # base_map = binarization_maps[(strand, modification_type)]
-
-    # # Convert mapping to tensor
-    # base_keys = list(base_map.keys())
-    # base_values = torch.tensor(list(base_map.values()), dtype=torch.float32, device=device)
-
-    # # Create a lookup dictionary (ASCII-based for fast mapping)
-    # lookup_table = torch.full((256,), float('nan'), dtype=torch.float32, device=device)
-    # for k, v in zip(base_keys, base_values):
-    #     lookup_table[ord(k)] = v
-
-    # # Process reads
-    # binarized_base_identities = {}
-    # for key, bases in base_identities.items():
-    #     bases_tensor = torch.tensor([ord(c) for c in bases], dtype=torch.uint8, device=device)  # Convert chars to ASCII
-    #     binarized = lookup_table[bases_tensor]  # Efficient lookup
-    #     binarized_base_identities[key] = binarized
-
-    # return binarized_base_identities
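The mapping logic that survives this cleanup is summarized in the lookup tables visible in the deleted comments above. A minimal sketch of the conversion-modality case (the live function additionally handles deaminase footprinting and the on_missing policy):

    import numpy as np

    # Strand/mod-type lookup tables, taken directly from the code above.
    binarization_maps = {
        ("top", "5mC"): {"C": 1, "T": 0},
        ("top", "6mA"): {"A": 1, "G": 0},
        ("bottom", "5mC"): {"G": 1, "A": 0},
        ("bottom", "6mA"): {"T": 1, "C": 0},
    }

    base_map = binarization_maps[("top", "5mC")]
    bases = np.array(list("CCTGA"), dtype="<U1")
    # Bases outside the map (here G and A) fall back to NaN.
    binarized = np.array([base_map.get(b, np.nan) for b in bases], dtype=float)
    print(binarized)  # [ 1.  1.  0. nan nan]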
@@ -272,6 +272,10 @@ def converted_BAM_to_adata(
         consensus_sequence_list
     )
 
+    from .h5ad_functions import append_reference_strand_quality_stats
+
+    append_reference_strand_quality_stats(final_adata)
+
     if input_already_demuxed:
         final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
         final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
@@ -321,15 +325,17 @@ def process_conversion_sites(
     conversion_types = conversions[1:]
 
     # Process the unconverted sequence once
+    # modification_dict is keyed by mod type (i.e. unconverted, 5mC, 6mA).
+    # modification_dict[unconverted] points to a dictionary keyed by unconverted record.id keys.
+    # This then maps to [sequence_length, [], [], unconverted sequence, unconverted complement].
     modification_dict[unconverted] = find_conversion_sites(
         converted_FASTA, unconverted, conversions, deaminase_footprinting
     )
-    # Above points to record_dict[record.id] = [sequence_length, [], [], sequence, complement] with only unconverted record.id keys
 
-    # Get **max sequence length** from unconverted records
+    # Get max sequence length from unconverted records
     max_reference_length = max(values[0] for values in modification_dict[unconverted].values())
-    # Add **unconverted records** to `record_FASTA_dict`
+    # Add unconverted records to `record_FASTA_dict`
     for record, values in modification_dict[unconverted].items():
         sequence_length, top_coords, bottom_coords, sequence, complement = values
 
@@ -358,25 +364,34 @@ def process_conversion_sites(
     )
 
     # Process converted records
+    # For each conversion type (i.e. 5mC, 6mA), add the conversion type as a key to modification_dict.
+    # This points to a dictionary keyed by the unconverted record id key.
+    # That in turn points to [sequence_length, top_strand_coordinates, bottom_strand_coordinates, unconverted sequence, unconverted complement].
     for conversion in conversion_types:
         modification_dict[conversion] = find_conversion_sites(
             converted_FASTA, conversion, conversions, deaminase_footprinting
         )
-        # Above points to record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement] with only unconverted record.id keys
 
+        # Iterate over the unconverted record ids in modification_dict, along with the
+        # [sequence_length, top_strand_coordinates, bottom_strand_coordinates, unconverted sequence, unconverted complement] for the conversion type.
         for record, values in modification_dict[conversion].items():
             sequence_length, top_coords, bottom_coords, sequence, complement = values
 
             if not deaminase_footprinting:
-                chromosome = record.split(f"_{unconverted}_")[0]  # Extract chromosome name
+                # For conversion SMF, make the chromosome name the base record name.
+                chromosome = record.split(f"_{unconverted}_")[0]
             else:
+                # For deaminase SMF, make the chromosome and record name the same.
                 chromosome = record
 
-            # Add **both strands** for converted records
+            # Add both strands for converted records
             for strand in ["top", "bottom"]:
+                # Generate converted/unconverted record names that are found in the converted FASTA.
                 converted_name = f"{chromosome}_{conversion}_{strand}"
                 unconverted_name = f"{chromosome}_{unconverted}_top"
 
+                # Use the converted FASTA record names as keys to a dict that points to RecordFastaInfo objects.
+                # These objects will contain the unconverted sequence/complement.
                 record_FASTA_dict[converted_name] = RecordFastaInfo(
                     sequence=sequence + "N" * (max_reference_length - sequence_length),
                     complement=complement + "N" * (max_reference_length - sequence_length),
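The record naming convention these comments describe is "<chromosome>_<conversion>_<strand>". A tiny illustration of the key generation, using a hypothetical reference name:

    # "amplicon1" is a hypothetical chromosome/record name for illustration.
    chromosome, conversion, unconverted = "amplicon1", "5mC", "unconverted"
    for strand in ["top", "bottom"]:
        converted_name = f"{chromosome}_{conversion}_{strand}"
        unconverted_name = f"{chromosome}_{unconverted}_top"
        print(converted_name, "->", unconverted_name)
    # amplicon1_5mC_top -> amplicon1_unconverted_top
    # amplicon1_5mC_bottom -> amplicon1_unconverted_top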
@@ -577,16 +592,19 @@ def process_single_bam(
     """
     adata_list: list[ad.AnnData] = []
 
+    # Iterate over BAM records that passed filtering.
     for record in records_to_analyze:
         sample = bam.stem
         record_info = record_FASTA_dict[record]
         chromosome = record_info.chromosome
         current_length = record_info.sequence_length
+        # Note: mod_type and strand are only loaded correctly for conversion SMF, not deaminase.
+        # However, these variables are only used for conversion SMF, so this works.
         mod_type, strand = record_info.conversion, record_info.strand
         non_converted_sequence = chromosome_FASTA_dict[chromosome][0]
         record_sequence = converted_FASTA_record_seq_map[record][1]
 
-        # Extract Base Identities
+        # Extract base identities for forward- and reverse-mapped reads.
         (
             fwd_bases,
             rev_bases,
@@ -615,13 +633,12 @@ def process_single_bam(
         merged_bin = {}
 
         # Binarize the Base Identities if they exist
+        # Note: mod_type is currently always "unconverted" and strand always "top" for deaminase SMF; this works for now.
         if fwd_bases:
             fwd_bin = binarize_converted_base_identities(
                 fwd_bases,
                 strand,
                 mod_type,
-                bam,
-                device,
                 deaminase_footprinting,
                 mismatch_trend_per_read,
             )
@@ -632,8 +649,6 @@ def process_single_bam(
                 rev_bases,
                 strand,
                 mod_type,
-                bam,
-                device,
                 deaminase_footprinting,
                 mismatch_trend_per_read,
             )
@@ -742,10 +757,35 @@ def process_single_bam(
         adata.obs[REFERENCE] = [chromosome] * len(adata)
         adata.obs[STRAND] = [strand] * len(adata)
         adata.obs[DATASET] = [mod_type] * len(adata)
-        adata.obs[REFERENCE_DATASET_STRAND] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
-        adata.obs[REFERENCE_STRAND] = [f"{chromosome}_{strand}"] * len(adata)
         adata.obs[READ_MISMATCH_TREND] = adata.obs_names.map(mismatch_trend_series)
 
+        # Currently, deaminase footprinting uses the mismatch trend to define the strand.
+        if deaminase_footprinting:
+            is_ct = adata.obs[READ_MISMATCH_TREND] == "C->T"
+            is_ga = adata.obs[READ_MISMATCH_TREND] == "G->A"
+
+            adata.obs.loc[is_ct, STRAND] = "top"
+            adata.obs.loc[is_ga, STRAND] = "bottom"
+        # Currently, conversion footprinting uses the strand to define the mismatch trend.
+        else:
+            is_top = adata.obs[STRAND] == "top"
+            is_bottom = adata.obs[STRAND] == "bottom"
+
+            adata.obs.loc[is_top, READ_MISMATCH_TREND] = "C->T"
+            adata.obs.loc[is_bottom, READ_MISMATCH_TREND] = "G->A"
+
+        adata.obs[REFERENCE_DATASET_STRAND] = (
+            adata.obs[REFERENCE].astype(str)
+            + "_"
+            + adata.obs[DATASET].astype(str)
+            + "_"
+            + adata.obs[STRAND].astype(str)
+        )
+
+        adata.obs[REFERENCE_STRAND] = (
+            adata.obs[REFERENCE].astype(str) + "_" + adata.obs[STRAND].astype(str)
+        )
+
         read_mapping_direction = []
         for read_id in adata.obs_names:
             if read_id in fwd_reads:
@@ -10,6 +10,7 @@ import numpy as np
 import pandas as pd
 import scipy.sparse as sp
 
+from smftools.constants import BASE_QUALITY_SCORES, READ_SPAN_MASK, REFERENCE_STRAND
 from smftools.logging_utils import get_logger
 from smftools.optional_imports import require
 
@@ -84,6 +85,88 @@ def add_demux_type_annotation(
     return adata
 
 
+def append_reference_strand_quality_stats(
+    adata,
+    ref_column: str = REFERENCE_STRAND,
+    quality_layer: str = BASE_QUALITY_SCORES,
+    read_span_layer: str = READ_SPAN_MASK,
+    uns_flag: str = "append_reference_strand_quality_stats_performed",
+    force_redo: bool = False,
+    bypass: bool = False,
+) -> None:
+    """Append per-position quality and error rate stats for each reference strand.
+
+    Args:
+        adata: AnnData object to annotate in-place.
+        ref_column: Obs column defining reference strand groups.
+        quality_layer: Layer containing base quality scores.
+        read_span_layer: Optional layer marking covered positions (1=covered, 0=not covered).
+        uns_flag: Flag in ``adata.uns`` indicating prior completion.
+        force_redo: Whether to rerun even if ``uns_flag`` is set.
+        bypass: Whether to skip this step.
+    """
+    if bypass:
+        return
+
+    already = bool(adata.uns.get(uns_flag, False))
+    if already and not force_redo:
+        return
+
+    if ref_column not in adata.obs:
+        logger.debug("Reference column '%s' not found; skipping quality stats.", ref_column)
+        return
+
+    if quality_layer not in adata.layers:
+        logger.debug("Quality layer '%s' not found; skipping quality stats.", quality_layer)
+        return
+
+    ref_values = adata.obs[ref_column]
+    references = (
+        ref_values.cat.categories if hasattr(ref_values, "cat") else pd.Index(pd.unique(ref_values))
+    )
+    n_vars = adata.shape[1]
+    has_span_mask = read_span_layer in adata.layers
+
+    for ref in references:
+        ref_mask = ref_values == ref
+        ref_position_mask = adata.var.get(f"position_in_{ref}")
+        if ref_position_mask is None:
+            ref_position_mask = pd.Series(np.ones(n_vars, dtype=bool), index=adata.var.index)
+        else:
+            ref_position_mask = ref_position_mask.astype(bool)
+
+        mean_quality = np.full(n_vars, np.nan, dtype=float)
+        std_quality = np.full(n_vars, np.nan, dtype=float)
+        mean_error = np.full(n_vars, np.nan, dtype=float)
+        std_error = np.full(n_vars, np.nan, dtype=float)
+
+        if ref_mask.sum() > 0:
+            quality_matrix = np.asarray(adata.layers[quality_layer][ref_mask]).astype(float)
+            quality_matrix[quality_matrix < 0] = np.nan
+            if has_span_mask:
+                coverage_mask = np.asarray(adata.layers[read_span_layer][ref_mask]) > 0
+                quality_matrix = np.where(coverage_mask, quality_matrix, np.nan)
+
+            mean_quality = np.nanmean(quality_matrix, axis=0)
+            std_quality = np.nanstd(quality_matrix, axis=0)
+
+            error_matrix = np.power(10.0, -quality_matrix / 10.0)
+            mean_error = np.nanmean(error_matrix, axis=0)
+            std_error = np.nanstd(error_matrix, axis=0)
+
+        mean_quality = np.where(ref_position_mask.values, mean_quality, np.nan)
+        std_quality = np.where(ref_position_mask.values, std_quality, np.nan)
+        mean_error = np.where(ref_position_mask.values, mean_error, np.nan)
+        std_error = np.where(ref_position_mask.values, std_error, np.nan)
+
+        adata.var[f"{ref}_mean_base_quality"] = pd.Series(mean_quality, index=adata.var.index)
+        adata.var[f"{ref}_std_base_quality"] = pd.Series(std_quality, index=adata.var.index)
+        adata.var[f"{ref}_mean_error_rate"] = pd.Series(mean_error, index=adata.var.index)
+        adata.var[f"{ref}_std_error_rate"] = pd.Series(std_error, index=adata.var.index)
+
+    adata.uns[uns_flag] = True
+
+
 def add_read_tag_annotations(
     adata,
     bam_files: Optional[List[str]] = None,
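The error-rate columns this helper writes follow the standard Phred relationship error = 10^(-Q/10), exactly as in the error_matrix line above. A worked example with hypothetical quality values:

    import numpy as np

    qualities = np.array([10.0, 20.0, 30.0, np.nan])  # NaN = uncovered position
    errors = np.power(10.0, -qualities / 10.0)
    print(errors)              # [0.1   0.01  0.001   nan]
    print(np.nanmean(errors))  # ~0.037, mean expected error over covered positions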
@@ -1881,6 +1881,10 @@ def modkit_extract_to_adata(
             f"{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples"
         ] = consensus_sequence_list
 
+    from .h5ad_functions import append_reference_strand_quality_stats
+
+    append_reference_strand_quality_stats(final_adata)
+
     if input_already_demuxed:
         final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
         final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
@@ -3,18 +3,32 @@ from __future__ import annotations
 from importlib import import_module
 
 _LAZY_ATTRS = {
-    "combined_hmm_length_clustermap": "smftools.plotting.general_plotting",
-    "combined_hmm_raw_clustermap": "smftools.plotting.general_plotting",
-    "combined_raw_clustermap": "smftools.plotting.general_plotting",
-    "plot_rolling_nn_and_layer": "smftools.plotting.general_plotting",
-    "plot_hmm_layers_rolling_by_sample_ref": "smftools.plotting.general_plotting",
-    "plot_nmf_components": "smftools.plotting.general_plotting",
-    "plot_cp_sequence_components": "smftools.plotting.general_plotting",
-    "plot_embedding": "smftools.plotting.general_plotting",
-    "plot_read_span_quality_clustermaps": "smftools.plotting.general_plotting",
-    "plot_pca": "smftools.plotting.general_plotting",
-    "plot_sequence_integer_encoding_clustermaps": "smftools.plotting.general_plotting",
-    "plot_umap": "smftools.plotting.general_plotting",
+    "combined_hmm_length_clustermap": "smftools.plotting.hmm_plotting",
+    "combined_hmm_raw_clustermap": "smftools.plotting.hmm_plotting",
+    "combined_raw_clustermap": "smftools.plotting.spatial_plotting",
+    "plot_delta_hamming_summary": "smftools.plotting.chimeric_plotting",
+    "plot_hamming_span_trio": "smftools.plotting.chimeric_plotting",
+    "plot_rolling_nn_and_layer": "smftools.plotting.chimeric_plotting",
+    "plot_rolling_nn_and_two_layers": "smftools.plotting.chimeric_plotting",
+    "plot_segment_length_histogram": "smftools.plotting.chimeric_plotting",
+    "plot_span_length_distributions": "smftools.plotting.chimeric_plotting",
+    "plot_zero_hamming_pair_counts": "smftools.plotting.chimeric_plotting",
+    "plot_zero_hamming_span_and_layer": "smftools.plotting.chimeric_plotting",
+    "plot_hmm_layers_rolling_by_sample_ref": "smftools.plotting.hmm_plotting",
+    "plot_nmf_components": "smftools.plotting.latent_plotting",
+    "plot_pca_components": "smftools.plotting.latent_plotting",
+    "plot_cp_sequence_components": "smftools.plotting.latent_plotting",
+    "plot_embedding": "smftools.plotting.latent_plotting",
+    "plot_embedding_grid": "smftools.plotting.latent_plotting",
+    "plot_read_span_quality_clustermaps": "smftools.plotting.preprocess_plotting",
+    "plot_mismatch_base_frequency_by_position": "smftools.plotting.variant_plotting",
+    "plot_pca": "smftools.plotting.latent_plotting",
+    "plot_pca_grid": "smftools.plotting.latent_plotting",
+    "plot_pca_explained_variance": "smftools.plotting.latent_plotting",
+    "plot_sequence_integer_encoding_clustermaps": "smftools.plotting.variant_plotting",
+    "plot_variant_segment_clustermaps": "smftools.plotting.variant_plotting",
+    "plot_umap": "smftools.plotting.latent_plotting",
+    "plot_umap_grid": "smftools.plotting.latent_plotting",
     "plot_bar_relative_risk": "smftools.plotting.position_stats",
     "plot_positionwise_matrix": "smftools.plotting.position_stats",
    "plot_positionwise_matrix_grid": "smftools.plotting.position_stats",