smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/chimeric_adata.py +1563 -0
- smftools/cli/helpers.py +18 -2
- smftools/cli/hmm_adata.py +18 -1
- smftools/cli/latent_adata.py +522 -67
- smftools/cli/load_adata.py +2 -2
- smftools/cli/preprocess_adata.py +32 -93
- smftools/cli/recipes.py +26 -0
- smftools/cli/spatial_adata.py +23 -109
- smftools/cli/variant_adata.py +423 -0
- smftools/cli_entry.py +41 -5
- smftools/config/conversion.yaml +0 -10
- smftools/config/deaminase.yaml +3 -0
- smftools/config/default.yaml +49 -13
- smftools/config/experiment_config.py +96 -3
- smftools/constants.py +4 -0
- smftools/hmm/call_hmm_peaks.py +1 -1
- smftools/informatics/binarize_converted_base_identities.py +2 -89
- smftools/informatics/converted_BAM_to_adata.py +53 -13
- smftools/informatics/h5ad_functions.py +83 -0
- smftools/informatics/modkit_extract_to_adata.py +4 -0
- smftools/plotting/__init__.py +26 -12
- smftools/plotting/autocorrelation_plotting.py +22 -4
- smftools/plotting/chimeric_plotting.py +1893 -0
- smftools/plotting/classifiers.py +28 -14
- smftools/plotting/general_plotting.py +58 -3362
- smftools/plotting/hmm_plotting.py +1586 -2
- smftools/plotting/latent_plotting.py +804 -0
- smftools/plotting/plotting_utils.py +243 -0
- smftools/plotting/position_stats.py +16 -8
- smftools/plotting/preprocess_plotting.py +281 -0
- smftools/plotting/qc_plotting.py +8 -3
- smftools/plotting/spatial_plotting.py +1134 -0
- smftools/plotting/variant_plotting.py +1231 -0
- smftools/preprocessing/__init__.py +3 -0
- smftools/preprocessing/append_base_context.py +1 -1
- smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
- smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
- smftools/preprocessing/append_variant_call_layer.py +480 -0
- smftools/preprocessing/flag_duplicate_reads.py +4 -4
- smftools/preprocessing/invert_adata.py +1 -0
- smftools/readwrite.py +109 -85
- smftools/tools/__init__.py +6 -0
- smftools/tools/calculate_knn.py +121 -0
- smftools/tools/calculate_nmf.py +18 -7
- smftools/tools/calculate_pca.py +180 -0
- smftools/tools/calculate_umap.py +70 -154
- smftools/tools/position_stats.py +4 -4
- smftools/tools/rolling_nn_distance.py +640 -3
- smftools/tools/sequence_alignment.py +140 -0
- smftools/tools/tensor_factorization.py +52 -4
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -840,6 +840,10 @@ class ExperimentConfig:
|
|
|
840
840
|
mismatch_frequency_range: Sequence[float] = field(default_factory=lambda: [0.05, 0.95])
|
|
841
841
|
mismatch_frequency_layer: str = "mismatch_integer_encoding"
|
|
842
842
|
mismatch_frequency_read_span_layer: str = "read_span_mask"
|
|
843
|
+
mismatch_base_frequency_exclude_mod_sites: bool = False
|
|
844
|
+
references_to_align_for_variant_annotation: List[Optional[str]] = field(
|
|
845
|
+
default_factory=lambda: [None, None]
|
|
846
|
+
)
|
|
843
847
|
|
|
844
848
|
# Spatial Analysis - Clustermap params
|
|
845
849
|
layer_for_clustermap_plotting: Optional[str] = "nan0_0minus1"
|
|
@@ -848,14 +852,45 @@ class ExperimentConfig:
|
|
|
848
852
|
clustermap_cmap_cpg: Optional[str] = "coolwarm"
|
|
849
853
|
clustermap_cmap_a: Optional[str] = "coolwarm"
|
|
850
854
|
spatial_clustermap_sortby: Optional[str] = "gpc"
|
|
855
|
+
overlay_variant_calls: bool = False
|
|
856
|
+
variant_overlay_seq1_color: str = "white"
|
|
857
|
+
variant_overlay_seq2_color: str = "black"
|
|
858
|
+
variant_overlay_marker_size: float = 4.0
|
|
851
859
|
rolling_nn_layer: Optional[str] = "nan0_0minus1"
|
|
852
860
|
rolling_nn_plot_layer: Optional[str] = "nan0_0minus1"
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
861
|
+
rolling_nn_plot_layers: List[str] = field(
|
|
862
|
+
default_factory=lambda: ["nan0_0minus1", "nan0_0minus1"]
|
|
863
|
+
)
|
|
864
|
+
rolling_nn_window: int = 10
|
|
865
|
+
rolling_nn_step: int = 1
|
|
866
|
+
rolling_nn_min_overlap: int = 8
|
|
856
867
|
rolling_nn_return_fraction: bool = True
|
|
857
868
|
rolling_nn_obsm_key: str = "rolling_nn_dist"
|
|
858
869
|
rolling_nn_site_types: Optional[List[str]] = None
|
|
870
|
+
rolling_nn_write_zero_pairs_csvs: bool = True
|
|
871
|
+
rolling_nn_zero_pairs_uns_key: Optional[str] = None
|
|
872
|
+
rolling_nn_zero_pairs_segments_key: Optional[str] = None
|
|
873
|
+
rolling_nn_zero_pairs_layer_key: Optional[str] = None
|
|
874
|
+
rolling_nn_zero_pairs_refine: bool = True
|
|
875
|
+
rolling_nn_zero_pairs_max_nan_run: Optional[int] = None
|
|
876
|
+
rolling_nn_zero_pairs_merge_gap: int = 0
|
|
877
|
+
rolling_nn_zero_pairs_max_segments_per_read: Optional[int] = None
|
|
878
|
+
rolling_nn_zero_pairs_max_overlap: Optional[int] = None
|
|
879
|
+
rolling_nn_zero_pairs_layer_overlap_mode: str = "binary"
|
|
880
|
+
rolling_nn_zero_pairs_layer_overlap_value: Optional[int] = None
|
|
881
|
+
rolling_nn_zero_pairs_keep_uns: bool = True
|
|
882
|
+
rolling_nn_zero_pairs_segments_keep_uns: bool = True
|
|
883
|
+
rolling_nn_zero_pairs_top_segments_per_read: Optional[int] = None
|
|
884
|
+
rolling_nn_zero_pairs_top_segments_max_overlap: Optional[int] = None
|
|
885
|
+
rolling_nn_zero_pairs_top_segments_min_span: Optional[float] = None
|
|
886
|
+
rolling_nn_zero_pairs_top_segments_write_csvs: bool = True
|
|
887
|
+
rolling_nn_zero_pairs_segment_histogram_bins: int = 30
|
|
888
|
+
|
|
889
|
+
# Cross-sample rolling NN analysis
|
|
890
|
+
cross_sample_analysis: bool = False
|
|
891
|
+
cross_sample_grouping_col: Optional[str] = None
|
|
892
|
+
cross_sample_random_seed: int = 42
|
|
893
|
+
delta_hamming_chimeric_span_threshold: int = 200
|
|
859
894
|
|
|
860
895
|
# Spatial Analysis - UMAP/Leiden params
|
|
861
896
|
layer_for_umap_plotting: Optional[str] = "nan_half"
|
|
@@ -1148,6 +1183,10 @@ class ExperimentConfig:
|
|
|
1148
1183
|
merged["mod_target_bases"] = _parse_list(merged["mod_target_bases"])
|
|
1149
1184
|
if "conversion_types" in merged:
|
|
1150
1185
|
merged["conversion_types"] = _parse_list(merged["conversion_types"])
|
|
1186
|
+
if "references_to_align_for_variant_annotation" in merged:
|
|
1187
|
+
merged["references_to_align_for_variant_annotation"] = _parse_list(
|
|
1188
|
+
merged["references_to_align_for_variant_annotation"]
|
|
1189
|
+
)
|
|
1151
1190
|
|
|
1152
1191
|
merged["filter_threshold"] = float(_parse_numeric(merged.get("filter_threshold", 0.8), 0.8))
|
|
1153
1192
|
merged["m6A_threshold"] = float(_parse_numeric(merged.get("m6A_threshold", 0.7), 0.7))
|
|
@@ -1360,14 +1399,65 @@ class ExperimentConfig:
|
|
|
1360
1399
|
clustermap_cmap_cpg=merged.get("clustermap_cmap_cpg", "coolwarm"),
|
|
1361
1400
|
clustermap_cmap_a=merged.get("clustermap_cmap_a", "coolwarm"),
|
|
1362
1401
|
spatial_clustermap_sortby=merged.get("spatial_clustermap_sortby", "gpc"),
|
|
1402
|
+
overlay_variant_calls=_parse_bool(merged.get("overlay_variant_calls", False)),
|
|
1403
|
+
variant_overlay_seq1_color=merged.get("variant_overlay_seq1_color", "white"),
|
|
1404
|
+
variant_overlay_seq2_color=merged.get("variant_overlay_seq2_color", "black"),
|
|
1405
|
+
variant_overlay_marker_size=float(merged.get("variant_overlay_marker_size", 4.0)),
|
|
1363
1406
|
rolling_nn_layer=merged.get("rolling_nn_layer", "nan0_0minus1"),
|
|
1364
1407
|
rolling_nn_plot_layer=merged.get("rolling_nn_plot_layer", "nan0_0minus1"),
|
|
1408
|
+
rolling_nn_plot_layers=merged.get(
|
|
1409
|
+
"rolling_nn_plot_layers", ["nan0_0minus1", "nan0_0minus1"]
|
|
1410
|
+
),
|
|
1365
1411
|
rolling_nn_window=merged.get("rolling_nn_window", 15),
|
|
1366
1412
|
rolling_nn_step=merged.get("rolling_nn_step", 2),
|
|
1367
1413
|
rolling_nn_min_overlap=merged.get("rolling_nn_min_overlap", 10),
|
|
1368
1414
|
rolling_nn_return_fraction=merged.get("rolling_nn_return_fraction", True),
|
|
1369
1415
|
rolling_nn_obsm_key=merged.get("rolling_nn_obsm_key", "rolling_nn_dist"),
|
|
1370
1416
|
rolling_nn_site_types=merged.get("rolling_nn_site_types", None),
|
|
1417
|
+
rolling_nn_write_zero_pairs_csvs=merged.get("rolling_nn_write_zero_pairs_csvs", True),
|
|
1418
|
+
rolling_nn_zero_pairs_uns_key=merged.get("rolling_nn_zero_pairs_uns_key", None),
|
|
1419
|
+
rolling_nn_zero_pairs_segments_key=merged.get(
|
|
1420
|
+
"rolling_nn_zero_pairs_segments_key", None
|
|
1421
|
+
),
|
|
1422
|
+
rolling_nn_zero_pairs_layer_key=merged.get("rolling_nn_zero_pairs_layer_key", None),
|
|
1423
|
+
rolling_nn_zero_pairs_refine=merged.get("rolling_nn_zero_pairs_refine", True),
|
|
1424
|
+
rolling_nn_zero_pairs_max_nan_run=merged.get("rolling_nn_zero_pairs_max_nan_run", None),
|
|
1425
|
+
rolling_nn_zero_pairs_merge_gap=merged.get("rolling_nn_zero_pairs_merge_gap", 0),
|
|
1426
|
+
rolling_nn_zero_pairs_max_segments_per_read=merged.get(
|
|
1427
|
+
"rolling_nn_zero_pairs_max_segments_per_read", None
|
|
1428
|
+
),
|
|
1429
|
+
rolling_nn_zero_pairs_max_overlap=merged.get("rolling_nn_zero_pairs_max_overlap", None),
|
|
1430
|
+
rolling_nn_zero_pairs_layer_overlap_mode=merged.get(
|
|
1431
|
+
"rolling_nn_zero_pairs_layer_overlap_mode", "binary"
|
|
1432
|
+
),
|
|
1433
|
+
rolling_nn_zero_pairs_layer_overlap_value=merged.get(
|
|
1434
|
+
"rolling_nn_zero_pairs_layer_overlap_value", None
|
|
1435
|
+
),
|
|
1436
|
+
rolling_nn_zero_pairs_keep_uns=merged.get("rolling_nn_zero_pairs_keep_uns", True),
|
|
1437
|
+
rolling_nn_zero_pairs_segments_keep_uns=merged.get(
|
|
1438
|
+
"rolling_nn_zero_pairs_segments_keep_uns", True
|
|
1439
|
+
),
|
|
1440
|
+
rolling_nn_zero_pairs_top_segments_per_read=merged.get(
|
|
1441
|
+
"rolling_nn_zero_pairs_top_segments_per_read", None
|
|
1442
|
+
),
|
|
1443
|
+
rolling_nn_zero_pairs_top_segments_max_overlap=merged.get(
|
|
1444
|
+
"rolling_nn_zero_pairs_top_segments_max_overlap", None
|
|
1445
|
+
),
|
|
1446
|
+
rolling_nn_zero_pairs_top_segments_min_span=merged.get(
|
|
1447
|
+
"rolling_nn_zero_pairs_top_segments_min_span", None
|
|
1448
|
+
),
|
|
1449
|
+
rolling_nn_zero_pairs_top_segments_write_csvs=merged.get(
|
|
1450
|
+
"rolling_nn_zero_pairs_top_segments_write_csvs", True
|
|
1451
|
+
),
|
|
1452
|
+
rolling_nn_zero_pairs_segment_histogram_bins=merged.get(
|
|
1453
|
+
"rolling_nn_zero_pairs_segment_histogram_bins", 30
|
|
1454
|
+
),
|
|
1455
|
+
cross_sample_analysis=merged.get("cross_sample_analysis", False),
|
|
1456
|
+
cross_sample_grouping_col=merged.get("cross_sample_grouping_col", None),
|
|
1457
|
+
cross_sample_random_seed=merged.get("cross_sample_random_seed", 42),
|
|
1458
|
+
delta_hamming_chimeric_span_threshold=merged.get(
|
|
1459
|
+
"delta_hamming_chimeric_span_threshold", 200
|
|
1460
|
+
),
|
|
1371
1461
|
layer_for_umap_plotting=merged.get("layer_for_umap_plotting", "nan_half"),
|
|
1372
1462
|
umap_layers_to_plot=merged.get(
|
|
1373
1463
|
"umap_layers_to_plot", ["mapped_length", "Raw_modification_signal"]
|
|
@@ -1531,6 +1621,9 @@ class ExperimentConfig:
|
|
|
1531
1621
|
force_redo_hmm_fit=merged.get("force_redo_hmm_fit", False),
|
|
1532
1622
|
bypass_hmm_apply=merged.get("bypass_hmm_apply", False),
|
|
1533
1623
|
force_redo_hmm_apply=merged.get("force_redo_hmm_apply", False),
|
|
1624
|
+
references_to_align_for_variant_annotation=merged.get(
|
|
1625
|
+
"references_to_align_for_variant_annotation", [None, None]
|
|
1626
|
+
),
|
|
1534
1627
|
config_source=config_source or "<var_dict>",
|
|
1535
1628
|
)
|
|
1536
1629
|
|
smftools/constants.py
CHANGED
|
@@ -44,7 +44,11 @@ PREPROCESS_DIR: Final[str] = "preprocess_adata_outputs"
|
|
|
44
44
|
SPATIAL_DIR: Final[str] = "spatial_adata_outputs"
|
|
45
45
|
HMM_DIR: Final[str] = "hmm_adata_outputs"
|
|
46
46
|
LATENT_DIR: Final[str] = "latent_adata_outputs"
|
|
47
|
+
VARIANT_DIR: Final[str] = "variant_adata_outputs"
|
|
48
|
+
CHIMERIC_DIR: Final[str] = "chimeric_adata_outputs"
|
|
49
|
+
|
|
47
50
|
LOGGING_DIR: Final[str] = "logs"
|
|
51
|
+
|
|
48
52
|
TRIM: Final[bool] = False
|
|
49
53
|
|
|
50
54
|
_private_conversions = ["unconverted"]
|
smftools/hmm/call_hmm_peaks.py
CHANGED
|
@@ -51,7 +51,7 @@ def call_hmm_peaks(
|
|
|
51
51
|
raise KeyError(f"obs column '{ref_column}' not found")
|
|
52
52
|
|
|
53
53
|
# Ensure categorical for predictable ref iteration
|
|
54
|
-
if not
|
|
54
|
+
if not isinstance(adata.obs[ref_column].dtype, pd.CategoricalDtype):
|
|
55
55
|
adata.obs[ref_column] = adata.obs[ref_column].astype("category")
|
|
56
56
|
|
|
57
57
|
# Optional: drop duplicate obs columns once to avoid Pandas/AnnData view quirks
|
|
@@ -5,20 +5,19 @@ def binarize_converted_base_identities(
|
|
|
5
5
|
base_identities,
|
|
6
6
|
strand,
|
|
7
7
|
modification_type,
|
|
8
|
-
bam,
|
|
9
|
-
device="cpu",
|
|
10
8
|
deaminase_footprinting=False,
|
|
11
9
|
mismatch_trend_per_read={},
|
|
12
10
|
on_missing="nan",
|
|
13
11
|
):
|
|
14
12
|
"""
|
|
15
13
|
Efficiently binarizes conversion SMF data within a sequence string using NumPy arrays.
|
|
14
|
+
For conversion modality, the strand parameter is used for mapping.
|
|
15
|
+
For deaminase modality, the mismatch_trend_per_read is used for mapping.
|
|
16
16
|
|
|
17
17
|
Parameters:
|
|
18
18
|
base_identities (dict): A dictionary returned by extract_base_identities. Keyed by read name. Points to a list of base identities.
|
|
19
19
|
strand (str): A string indicating which strand was converted in the experiment (options are 'top' and 'bottom').
|
|
20
20
|
modification_type (str): A string indicating the modification type of interest (options are '5mC' and '6mA').
|
|
21
|
-
bam (str): The bam file path
|
|
22
21
|
deaminase_footprinting (bool): Whether direct deaminase footprinting chemistry was used.
|
|
23
22
|
mismatch_trend_per_read (dict): For deaminase footprinting, indicates the type of conversion relative to the top strand reference for each read. (C->T or G->A if bottom strand was converted)
|
|
24
23
|
on_missing (str): Error handling if a read is missing
|
|
@@ -98,89 +97,3 @@ def binarize_converted_base_identities(
|
|
|
98
97
|
out[read_id] = res
|
|
99
98
|
|
|
100
99
|
return out
|
|
101
|
-
|
|
102
|
-
# if mismatch_trend_per_read is None:
|
|
103
|
-
# mismatch_trend_per_read = {}
|
|
104
|
-
|
|
105
|
-
# # If the modification type is 'unconverted', return NaN for all positions if the deaminase_footprinting strategy is not being used.
|
|
106
|
-
# if modification_type == "unconverted" and not deaminase_footprinting:
|
|
107
|
-
# #print(f"Skipping binarization for unconverted {strand} reads on bam: {bam}.")
|
|
108
|
-
# return {key: np.full(len(bases), np.nan) for key, bases in base_identities.items()}
|
|
109
|
-
|
|
110
|
-
# # Define mappings for binarization based on strand and modification type
|
|
111
|
-
# if deaminase_footprinting:
|
|
112
|
-
# binarization_maps = {
|
|
113
|
-
# ('C->T'): {'C': 0, 'T': 1},
|
|
114
|
-
# ('G->A'): {'G': 0, 'A': 1},
|
|
115
|
-
# }
|
|
116
|
-
|
|
117
|
-
# binarized_base_identities = {}
|
|
118
|
-
# for key, bases in base_identities.items():
|
|
119
|
-
# arr = np.array(bases, dtype='<U1')
|
|
120
|
-
# # Fetch the appropriate mapping
|
|
121
|
-
# conversion_type = mismatch_trend_per_read[key]
|
|
122
|
-
# base_map = binarization_maps.get(conversion_type, None)
|
|
123
|
-
# binarized = np.vectorize(lambda x: base_map.get(x, np.nan))(arr) # Apply mapping with fallback to NaN
|
|
124
|
-
# binarized_base_identities[key] = binarized
|
|
125
|
-
|
|
126
|
-
# return binarized_base_identities
|
|
127
|
-
|
|
128
|
-
# else:
|
|
129
|
-
# binarization_maps = {
|
|
130
|
-
# ('top', '5mC'): {'C': 1, 'T': 0},
|
|
131
|
-
# ('top', '6mA'): {'A': 1, 'G': 0},
|
|
132
|
-
# ('bottom', '5mC'): {'G': 1, 'A': 0},
|
|
133
|
-
# ('bottom', '6mA'): {'T': 1, 'C': 0}
|
|
134
|
-
# }
|
|
135
|
-
|
|
136
|
-
# if (strand, modification_type) not in binarization_maps:
|
|
137
|
-
# raise ValueError(f"Invalid combination of strand='{strand}' and modification_type='{modification_type}'")
|
|
138
|
-
|
|
139
|
-
# # Fetch the appropriate mapping
|
|
140
|
-
# base_map = binarization_maps[(strand, modification_type)]
|
|
141
|
-
|
|
142
|
-
# binarized_base_identities = {}
|
|
143
|
-
# for key, bases in base_identities.items():
|
|
144
|
-
# arr = np.array(bases, dtype='<U1')
|
|
145
|
-
# binarized = np.vectorize(lambda x: base_map.get(x, np.nan))(arr) # Apply mapping with fallback to NaN
|
|
146
|
-
# binarized_base_identities[key] = binarized
|
|
147
|
-
|
|
148
|
-
# return binarized_base_identities
|
|
149
|
-
# import torch
|
|
150
|
-
|
|
151
|
-
# # If the modification type is 'unconverted', return NaN for all positions
|
|
152
|
-
# if modification_type == "unconverted":
|
|
153
|
-
# print(f"Skipping binarization for unconverted {strand} reads on bam: {bam}.")
|
|
154
|
-
# return {key: torch.full((len(bases),), float('nan'), device=device) for key, bases in base_identities.items()}
|
|
155
|
-
|
|
156
|
-
# # Define mappings for binarization based on strand and modification type
|
|
157
|
-
# binarization_maps = {
|
|
158
|
-
# ('top', '5mC'): {'C': 1, 'T': 0},
|
|
159
|
-
# ('top', '6mA'): {'A': 1, 'G': 0},
|
|
160
|
-
# ('bottom', '5mC'): {'G': 1, 'A': 0},
|
|
161
|
-
# ('bottom', '6mA'): {'T': 1, 'C': 0}
|
|
162
|
-
# }
|
|
163
|
-
|
|
164
|
-
# if (strand, modification_type) not in binarization_maps:
|
|
165
|
-
# raise ValueError(f"Invalid combination of strand='{strand}' and modification_type='{modification_type}'")
|
|
166
|
-
|
|
167
|
-
# # Fetch the appropriate mapping
|
|
168
|
-
# base_map = binarization_maps[(strand, modification_type)]
|
|
169
|
-
|
|
170
|
-
# # Convert mapping to tensor
|
|
171
|
-
# base_keys = list(base_map.keys())
|
|
172
|
-
# base_values = torch.tensor(list(base_map.values()), dtype=torch.float32, device=device)
|
|
173
|
-
|
|
174
|
-
# # Create a lookup dictionary (ASCII-based for fast mapping)
|
|
175
|
-
# lookup_table = torch.full((256,), float('nan'), dtype=torch.float32, device=device)
|
|
176
|
-
# for k, v in zip(base_keys, base_values):
|
|
177
|
-
# lookup_table[ord(k)] = v
|
|
178
|
-
|
|
179
|
-
# # Process reads
|
|
180
|
-
# binarized_base_identities = {}
|
|
181
|
-
# for key, bases in base_identities.items():
|
|
182
|
-
# bases_tensor = torch.tensor([ord(c) for c in bases], dtype=torch.uint8, device=device) # Convert chars to ASCII
|
|
183
|
-
# binarized = lookup_table[bases_tensor] # Efficient lookup
|
|
184
|
-
# binarized_base_identities[key] = binarized
|
|
185
|
-
|
|
186
|
-
# return binarized_base_identities
|
|
@@ -272,6 +272,10 @@ def converted_BAM_to_adata(
|
|
|
272
272
|
consensus_sequence_list
|
|
273
273
|
)
|
|
274
274
|
|
|
275
|
+
from .h5ad_functions import append_reference_strand_quality_stats
|
|
276
|
+
|
|
277
|
+
append_reference_strand_quality_stats(final_adata)
|
|
278
|
+
|
|
275
279
|
if input_already_demuxed:
|
|
276
280
|
final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
|
|
277
281
|
final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
|
|
@@ -321,15 +325,17 @@ def process_conversion_sites(
|
|
|
321
325
|
conversion_types = conversions[1:]
|
|
322
326
|
|
|
323
327
|
# Process the unconverted sequence once
|
|
328
|
+
# modification_dict is keyed by mod type (i.e. unconverted, 5mC, 6mA)
|
|
329
|
+
# modification_dict[unconverted] points to a dictionary keyed by unconverted record.id keys.
|
|
330
|
+
# This then maps to [sequence_length, [], [], unconverted sequence, unconverted complement]
|
|
324
331
|
modification_dict[unconverted] = find_conversion_sites(
|
|
325
332
|
converted_FASTA, unconverted, conversions, deaminase_footprinting
|
|
326
333
|
)
|
|
327
|
-
# Above points to record_dict[record.id] = [sequence_length, [], [], sequence, complement] with only unconverted record.id keys
|
|
328
334
|
|
|
329
|
-
# Get
|
|
335
|
+
# Get max sequence length from unconverted records
|
|
330
336
|
max_reference_length = max(values[0] for values in modification_dict[unconverted].values())
|
|
331
337
|
|
|
332
|
-
# Add
|
|
338
|
+
# Add unconverted records to `record_FASTA_dict`
|
|
333
339
|
for record, values in modification_dict[unconverted].items():
|
|
334
340
|
sequence_length, top_coords, bottom_coords, sequence, complement = values
|
|
335
341
|
|
|
@@ -358,25 +364,34 @@ def process_conversion_sites(
|
|
|
358
364
|
)
|
|
359
365
|
|
|
360
366
|
# Process converted records
|
|
367
|
+
# For each conversion type (ie 5mC, 6mA), add the conversion type as a key to modification_dict.
|
|
368
|
+
# This points to a dictionary keyed by the unconverted record id key.
|
|
369
|
+
# This points to [sequence_length, top_strand_coordinates, bottom_strand_coordinates, unconverted sequence, unconverted complement]
|
|
361
370
|
for conversion in conversion_types:
|
|
362
371
|
modification_dict[conversion] = find_conversion_sites(
|
|
363
372
|
converted_FASTA, conversion, conversions, deaminase_footprinting
|
|
364
373
|
)
|
|
365
|
-
# Above points to record_dict[record.id] = [sequence_length, top_strand_coordinates, bottom_strand_coordinates, sequence, complement] with only unconverted record.id keys
|
|
366
374
|
|
|
375
|
+
# Iterate over the unconverted record ids in mod_dict, as well as the
|
|
376
|
+
# [sequence_length, top_strand_coordinates, bottom_strand_coordinates, unconverted sequence, unconverted complement] for the conversion type
|
|
367
377
|
for record, values in modification_dict[conversion].items():
|
|
368
378
|
sequence_length, top_coords, bottom_coords, sequence, complement = values
|
|
369
379
|
|
|
370
380
|
if not deaminase_footprinting:
|
|
371
|
-
|
|
381
|
+
# For conversion smf, make the chromosome name the base record name
|
|
382
|
+
chromosome = record.split(f"_{unconverted}_")[0]
|
|
372
383
|
else:
|
|
384
|
+
# For deaminase smf, make the chromosome and record name the same
|
|
373
385
|
chromosome = record
|
|
374
386
|
|
|
375
|
-
# Add
|
|
387
|
+
# Add both strands for converted records
|
|
376
388
|
for strand in ["top", "bottom"]:
|
|
389
|
+
# Generate converted/unconverted record names that are found in the converted FASTA
|
|
377
390
|
converted_name = f"{chromosome}_{conversion}_{strand}"
|
|
378
391
|
unconverted_name = f"{chromosome}_{unconverted}_top"
|
|
379
392
|
|
|
393
|
+
# Use the converted FASTA record names as keys to a dict that points to RecordFastaInfo objects.
|
|
394
|
+
# These objects will contain the unconverted sequence/complement.
|
|
380
395
|
record_FASTA_dict[converted_name] = RecordFastaInfo(
|
|
381
396
|
sequence=sequence + "N" * (max_reference_length - sequence_length),
|
|
382
397
|
complement=complement + "N" * (max_reference_length - sequence_length),
|
|
@@ -577,16 +592,19 @@ def process_single_bam(
|
|
|
577
592
|
"""
|
|
578
593
|
adata_list: list[ad.AnnData] = []
|
|
579
594
|
|
|
595
|
+
# Iterate over BAM records that passed filtering.
|
|
580
596
|
for record in records_to_analyze:
|
|
581
597
|
sample = bam.stem
|
|
582
598
|
record_info = record_FASTA_dict[record]
|
|
583
599
|
chromosome = record_info.chromosome
|
|
584
600
|
current_length = record_info.sequence_length
|
|
601
|
+
# Note: mod_type and strand are only loaded correctly for conversion SMF, not for deaminase SMF.
|
|
602
|
+
# However, these variables are only used for conversion SMF (not deaminase), so this works.
|
|
585
603
|
mod_type, strand = record_info.conversion, record_info.strand
|
|
586
604
|
non_converted_sequence = chromosome_FASTA_dict[chromosome][0]
|
|
587
605
|
record_sequence = converted_FASTA_record_seq_map[record][1]
|
|
588
606
|
|
|
589
|
-
# Extract Base Identities
|
|
607
|
+
# Extract Base Identities for forward and reverse mapped reads.
|
|
590
608
|
(
|
|
591
609
|
fwd_bases,
|
|
592
610
|
rev_bases,
|
|
@@ -615,13 +633,12 @@ def process_single_bam(
|
|
|
615
633
|
merged_bin = {}
|
|
616
634
|
|
|
617
635
|
# Binarize the Base Identities if they exist
|
|
636
|
+
# Note: for deaminase SMF, mod_type is currently always "unconverted" and strand is always "top"; this works for now.
|
|
618
637
|
if fwd_bases:
|
|
619
638
|
fwd_bin = binarize_converted_base_identities(
|
|
620
639
|
fwd_bases,
|
|
621
640
|
strand,
|
|
622
641
|
mod_type,
|
|
623
|
-
bam,
|
|
624
|
-
device,
|
|
625
642
|
deaminase_footprinting,
|
|
626
643
|
mismatch_trend_per_read,
|
|
627
644
|
)
|
|
@@ -632,8 +649,6 @@ def process_single_bam(
|
|
|
632
649
|
rev_bases,
|
|
633
650
|
strand,
|
|
634
651
|
mod_type,
|
|
635
|
-
bam,
|
|
636
|
-
device,
|
|
637
652
|
deaminase_footprinting,
|
|
638
653
|
mismatch_trend_per_read,
|
|
639
654
|
)
|
|
@@ -742,10 +757,35 @@ def process_single_bam(
|
|
|
742
757
|
adata.obs[REFERENCE] = [chromosome] * len(adata)
|
|
743
758
|
adata.obs[STRAND] = [strand] * len(adata)
|
|
744
759
|
adata.obs[DATASET] = [mod_type] * len(adata)
|
|
745
|
-
adata.obs[REFERENCE_DATASET_STRAND] = [f"{chromosome}_{mod_type}_{strand}"] * len(adata)
|
|
746
|
-
adata.obs[REFERENCE_STRAND] = [f"{chromosome}_{strand}"] * len(adata)
|
|
747
760
|
adata.obs[READ_MISMATCH_TREND] = adata.obs_names.map(mismatch_trend_series)
|
|
748
761
|
|
|
762
|
+
# Currently, deaminase footprinting uses mismatch trend to define the strand.
|
|
763
|
+
if deaminase_footprinting:
|
|
764
|
+
is_ct = adata.obs[READ_MISMATCH_TREND] == "C->T"
|
|
765
|
+
is_ga = adata.obs[READ_MISMATCH_TREND] == "G->A"
|
|
766
|
+
|
|
767
|
+
adata.obs.loc[is_ct, STRAND] = "top"
|
|
768
|
+
adata.obs.loc[is_ga, STRAND] = "bottom"
|
|
769
|
+
# Currently, conversion footprinting uses strand to define the mismatch trend.
|
|
770
|
+
else:
|
|
771
|
+
is_top = adata.obs[STRAND] == "top"
|
|
772
|
+
is_bottom = adata.obs[STRAND] == "bottom"
|
|
773
|
+
|
|
774
|
+
adata.obs.loc[is_top, READ_MISMATCH_TREND] = "C->T"
|
|
775
|
+
adata.obs.loc[is_bottom, READ_MISMATCH_TREND] = "G->A"
|
|
776
|
+
|
|
777
|
+
adata.obs[REFERENCE_DATASET_STRAND] = (
|
|
778
|
+
adata.obs[REFERENCE].astype(str)
|
|
779
|
+
+ "_"
|
|
780
|
+
+ adata.obs[DATASET].astype(str)
|
|
781
|
+
+ "_"
|
|
782
|
+
+ adata.obs[STRAND].astype(str)
|
|
783
|
+
)
|
|
784
|
+
|
|
785
|
+
adata.obs[REFERENCE_STRAND] = (
|
|
786
|
+
adata.obs[REFERENCE].astype(str) + "_" + adata.obs[STRAND].astype(str)
|
|
787
|
+
)
|
|
788
|
+
|
|
749
789
|
read_mapping_direction = []
|
|
750
790
|
for read_id in adata.obs_names:
|
|
751
791
|
if read_id in fwd_reads:
|
|
@@ -10,6 +10,7 @@ import numpy as np
|
|
|
10
10
|
import pandas as pd
|
|
11
11
|
import scipy.sparse as sp
|
|
12
12
|
|
|
13
|
+
from smftools.constants import BASE_QUALITY_SCORES, READ_SPAN_MASK, REFERENCE_STRAND
|
|
13
14
|
from smftools.logging_utils import get_logger
|
|
14
15
|
from smftools.optional_imports import require
|
|
15
16
|
|
|
@@ -84,6 +85,88 @@ def add_demux_type_annotation(
|
|
|
84
85
|
return adata
|
|
85
86
|
|
|
86
87
|
|
|
88
|
+
def _layer_rows_as_dense(layer, rows):
    """Return ``layer[rows]`` as a NumPy array, densifying scipy sparse matrices.

    ``np.asarray`` on a scipy sparse matrix yields a 0-d object array instead of
    the numeric matrix, so sparse subsets are converted with ``toarray()`` first.
    """
    subset = layer[rows]
    if sp.issparse(subset):
        subset = subset.toarray()
    return np.asarray(subset)


def append_reference_strand_quality_stats(
    adata,
    ref_column: str = REFERENCE_STRAND,
    quality_layer: str = BASE_QUALITY_SCORES,
    read_span_layer: str = READ_SPAN_MASK,
    uns_flag: str = "append_reference_strand_quality_stats_performed",
    force_redo: bool = False,
    bypass: bool = False,
) -> None:
    """Append per-position quality and error rate stats for each reference strand.

    For every group ``ref`` in ``adata.obs[ref_column]`` four columns are written
    to ``adata.var``:

    * ``{ref}_mean_base_quality`` and ``{ref}_std_base_quality``;
    * ``{ref}_mean_error_rate`` and ``{ref}_std_error_rate``, where the per-base
      error rate is computed from the quality score ``Q`` as ``10 ** (-Q / 10)``.

    Negative quality values are treated as missing. If ``read_span_layer`` is
    present, positions where it is not ``> 0`` are also treated as missing. If
    ``adata.var`` has a ``position_in_{ref}`` column, positions where that flag
    is False are set to NaN in the output. Groups with no reads get all-NaN
    columns. Sparse (scipy) layers are supported.

    Args:
        adata: AnnData object to annotate in-place.
        ref_column: Obs column defining reference strand groups.
        quality_layer: Layer containing base quality scores.
        read_span_layer: Optional layer marking covered positions (1=covered, 0=not covered).
        uns_flag: Flag in ``adata.uns`` indicating prior completion.
        force_redo: Whether to rerun even if ``uns_flag`` is set.
        bypass: Whether to skip this step.

    Returns:
        None. ``adata.var`` is updated in-place and ``adata.uns[uns_flag]`` is
        set to True once the stats have been written.
    """
    import warnings

    if bypass:
        return

    if bool(adata.uns.get(uns_flag, False)) and not force_redo:
        return

    if ref_column not in adata.obs:
        logger.debug("Reference column '%s' not found; skipping quality stats.", ref_column)
        return

    if quality_layer not in adata.layers:
        logger.debug("Quality layer '%s' not found; skipping quality stats.", quality_layer)
        return

    ref_values = adata.obs[ref_column]
    # Categorical columns iterate their declared categories (unused categories
    # simply produce all-NaN stats); any other dtype iterates its unique values.
    references = (
        ref_values.cat.categories if hasattr(ref_values, "cat") else pd.Index(pd.unique(ref_values))
    )
    n_vars = adata.shape[1]
    quality_data = adata.layers[quality_layer]
    span_data = adata.layers[read_span_layer] if read_span_layer in adata.layers else None

    for ref in references:
        ref_mask = np.asarray(ref_values == ref, dtype=bool)

        # Restrict output to positions that belong to this reference, when known.
        position_flags = adata.var.get(f"position_in_{ref}")
        if position_flags is None:
            position_mask = np.ones(n_vars, dtype=bool)
        else:
            position_mask = np.asarray(position_flags.astype(bool))

        # Insertion order here fixes the order in which the var columns are written.
        stats = {
            "mean_base_quality": np.full(n_vars, np.nan, dtype=float),
            "std_base_quality": np.full(n_vars, np.nan, dtype=float),
            "mean_error_rate": np.full(n_vars, np.nan, dtype=float),
            "std_error_rate": np.full(n_vars, np.nan, dtype=float),
        }

        if ref_mask.any():
            # astype(float) always copies, so the NaN assignments below never
            # mutate the stored layer.
            quality_matrix = _layer_rows_as_dense(quality_data, ref_mask).astype(float)
            quality_matrix[quality_matrix < 0] = np.nan
            if span_data is not None:
                covered = _layer_rows_as_dense(span_data, ref_mask) > 0
                quality_matrix = np.where(covered, quality_matrix, np.nan)

            error_matrix = np.power(10.0, -quality_matrix / 10.0)

            # Uncovered positions give all-NaN columns. That is expected here, but
            # nanmean/nanstd would otherwise warn once per call.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                stats["mean_base_quality"] = np.nanmean(quality_matrix, axis=0)
                stats["std_base_quality"] = np.nanstd(quality_matrix, axis=0)
                stats["mean_error_rate"] = np.nanmean(error_matrix, axis=0)
                stats["std_error_rate"] = np.nanstd(error_matrix, axis=0)

        for suffix, values in stats.items():
            adata.var[f"{ref}_{suffix}"] = pd.Series(
                np.where(position_mask, values, np.nan), index=adata.var.index
            )

    adata.uns[uns_flag] = True
|
|
168
|
+
|
|
169
|
+
|
|
87
170
|
def add_read_tag_annotations(
|
|
88
171
|
adata,
|
|
89
172
|
bam_files: Optional[List[str]] = None,
|
|
@@ -1881,6 +1881,10 @@ def modkit_extract_to_adata(
|
|
|
1881
1881
|
f"{record}_{strand}_{mapping_dir}_consensus_sequence_from_all_samples"
|
|
1882
1882
|
] = consensus_sequence_list
|
|
1883
1883
|
|
|
1884
|
+
from .h5ad_functions import append_reference_strand_quality_stats
|
|
1885
|
+
|
|
1886
|
+
append_reference_strand_quality_stats(final_adata)
|
|
1887
|
+
|
|
1884
1888
|
if input_already_demuxed:
|
|
1885
1889
|
final_adata.obs[DEMUX_TYPE] = ["already"] * final_adata.shape[0]
|
|
1886
1890
|
final_adata.obs[DEMUX_TYPE] = final_adata.obs[DEMUX_TYPE].astype("category")
|
smftools/plotting/__init__.py
CHANGED
|
@@ -3,18 +3,32 @@ from __future__ import annotations
|
|
|
3
3
|
from importlib import import_module
|
|
4
4
|
|
|
5
5
|
_LAZY_ATTRS = {
|
|
6
|
-
"combined_hmm_length_clustermap": "smftools.plotting.
|
|
7
|
-
"combined_hmm_raw_clustermap": "smftools.plotting.
|
|
8
|
-
"combined_raw_clustermap": "smftools.plotting.
|
|
9
|
-
"
|
|
10
|
-
"
|
|
11
|
-
"
|
|
12
|
-
"
|
|
13
|
-
"
|
|
14
|
-
"
|
|
15
|
-
"
|
|
16
|
-
"
|
|
17
|
-
"
|
|
6
|
+
"combined_hmm_length_clustermap": "smftools.plotting.hmm_plotting",
|
|
7
|
+
"combined_hmm_raw_clustermap": "smftools.plotting.hmm_plotting",
|
|
8
|
+
"combined_raw_clustermap": "smftools.plotting.spatial_plotting",
|
|
9
|
+
"plot_delta_hamming_summary": "smftools.plotting.chimeric_plotting",
|
|
10
|
+
"plot_hamming_span_trio": "smftools.plotting.chimeric_plotting",
|
|
11
|
+
"plot_rolling_nn_and_layer": "smftools.plotting.chimeric_plotting",
|
|
12
|
+
"plot_rolling_nn_and_two_layers": "smftools.plotting.chimeric_plotting",
|
|
13
|
+
"plot_segment_length_histogram": "smftools.plotting.chimeric_plotting",
|
|
14
|
+
"plot_span_length_distributions": "smftools.plotting.chimeric_plotting",
|
|
15
|
+
"plot_zero_hamming_pair_counts": "smftools.plotting.chimeric_plotting",
|
|
16
|
+
"plot_zero_hamming_span_and_layer": "smftools.plotting.chimeric_plotting",
|
|
17
|
+
"plot_hmm_layers_rolling_by_sample_ref": "smftools.plotting.hmm_plotting",
|
|
18
|
+
"plot_nmf_components": "smftools.plotting.latent_plotting",
|
|
19
|
+
"plot_pca_components": "smftools.plotting.latent_plotting",
|
|
20
|
+
"plot_cp_sequence_components": "smftools.plotting.latent_plotting",
|
|
21
|
+
"plot_embedding": "smftools.plotting.latent_plotting",
|
|
22
|
+
"plot_embedding_grid": "smftools.plotting.latent_plotting",
|
|
23
|
+
"plot_read_span_quality_clustermaps": "smftools.plotting.preprocess_plotting",
|
|
24
|
+
"plot_mismatch_base_frequency_by_position": "smftools.plotting.variant_plotting",
|
|
25
|
+
"plot_pca": "smftools.plotting.latent_plotting",
|
|
26
|
+
"plot_pca_grid": "smftools.plotting.latent_plotting",
|
|
27
|
+
"plot_pca_explained_variance": "smftools.plotting.latent_plotting",
|
|
28
|
+
"plot_sequence_integer_encoding_clustermaps": "smftools.plotting.variant_plotting",
|
|
29
|
+
"plot_variant_segment_clustermaps": "smftools.plotting.variant_plotting",
|
|
30
|
+
"plot_umap": "smftools.plotting.latent_plotting",
|
|
31
|
+
"plot_umap_grid": "smftools.plotting.latent_plotting",
|
|
18
32
|
"plot_bar_relative_risk": "smftools.plotting.position_stats",
|
|
19
33
|
"plot_positionwise_matrix": "smftools.plotting.position_stats",
|
|
20
34
|
"plot_positionwise_matrix_grid": "smftools.plotting.position_stats",
|