smftools 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +32 -6
- smftools/cli/hmm_adata.py +232 -31
- smftools/cli/latent_adata.py +318 -0
- smftools/cli/load_adata.py +77 -73
- smftools/cli/preprocess_adata.py +178 -53
- smftools/cli/spatial_adata.py +149 -101
- smftools/cli_entry.py +12 -0
- smftools/config/conversion.yaml +11 -1
- smftools/config/default.yaml +38 -1
- smftools/config/experiment_config.py +53 -1
- smftools/constants.py +65 -0
- smftools/hmm/HMM.py +88 -0
- smftools/informatics/__init__.py +6 -0
- smftools/informatics/bam_functions.py +358 -8
- smftools/informatics/converted_BAM_to_adata.py +584 -163
- smftools/informatics/h5ad_functions.py +115 -2
- smftools/informatics/modkit_extract_to_adata.py +1003 -425
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/metadata.py +1 -1
- smftools/plotting/__init__.py +9 -0
- smftools/plotting/general_plotting.py +2411 -628
- smftools/plotting/hmm_plotting.py +85 -7
- smftools/preprocessing/__init__.py +1 -0
- smftools/preprocessing/append_base_context.py +17 -17
- smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/readwrite.py +53 -17
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +4 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +119 -0
- smftools/tools/calculate_umap.py +91 -8
- smftools/tools/rolling_nn_distance.py +235 -0
- smftools/tools/tensor_factorization.py +169 -0
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/METADATA +8 -6
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/RECORD +42 -35
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -12,6 +12,7 @@ from smftools.constants import (
|
|
|
12
12
|
BAM_SUFFIX,
|
|
13
13
|
BARCODE_BOTH_ENDS,
|
|
14
14
|
CONVERSIONS,
|
|
15
|
+
LOAD_DIR,
|
|
15
16
|
MOD_LIST,
|
|
16
17
|
MOD_MAP,
|
|
17
18
|
REF_COL,
|
|
@@ -664,6 +665,8 @@ class ExperimentConfig:
|
|
|
664
665
|
# General I/O
|
|
665
666
|
input_data_path: Optional[str] = None
|
|
666
667
|
output_directory: Optional[str] = None
|
|
668
|
+
emit_log_file: Optional[bool] = True
|
|
669
|
+
log_level: Optional[str] = "INFO"
|
|
667
670
|
fasta: Optional[str] = None
|
|
668
671
|
bam_suffix: str = BAM_SUFFIX
|
|
669
672
|
recursive_input_search: bool = True
|
|
@@ -736,6 +739,7 @@ class ExperimentConfig:
|
|
|
736
739
|
aligner_args: Optional[List[str]] = None
|
|
737
740
|
make_bigwigs: bool = False
|
|
738
741
|
make_beds: bool = False
|
|
742
|
+
annotate_secondary_supplementary: bool = True
|
|
739
743
|
samtools_backend: str = "auto"
|
|
740
744
|
bedtools_backend: str = "auto"
|
|
741
745
|
bigwig_backend: str = "auto"
|
|
@@ -747,6 +751,9 @@ class ExperimentConfig:
|
|
|
747
751
|
# General Plotting
|
|
748
752
|
sample_name_col_for_plotting: Optional[str] = "Barcode"
|
|
749
753
|
rows_per_qc_histogram_grid: int = 12
|
|
754
|
+
clustermap_demux_types_to_plot: List[str] = field(
|
|
755
|
+
default_factory=lambda: ["single", "double", "already"]
|
|
756
|
+
)
|
|
750
757
|
|
|
751
758
|
# Preprocessing - Read length and quality filter params
|
|
752
759
|
read_coord_filter: Optional[Sequence[float]] = field(default_factory=lambda: [None, None])
|
|
@@ -816,6 +823,9 @@ class ExperimentConfig:
|
|
|
816
823
|
duplicate_detection_site_types: List[str] = field(
|
|
817
824
|
default_factory=lambda: ["GpC", "CpG", "ambiguous_GpC_CpG"]
|
|
818
825
|
)
|
|
826
|
+
duplicate_detection_demux_types_to_use: List[str] = field(
|
|
827
|
+
default_factory=lambda: ["single", "double", "already"]
|
|
828
|
+
)
|
|
819
829
|
duplicate_detection_distance_threshold: float = 0.07
|
|
820
830
|
hamming_vs_metric_keys: List[str] = field(default_factory=lambda: ["Fraction_C_site_modified"])
|
|
821
831
|
duplicate_detection_keep_best_metric: str = "read_quality"
|
|
@@ -827,6 +837,9 @@ class ExperimentConfig:
|
|
|
827
837
|
|
|
828
838
|
# Preprocessing - Position QC
|
|
829
839
|
position_max_nan_threshold: float = 0.1
|
|
840
|
+
mismatch_frequency_range: Sequence[float] = field(default_factory=lambda: [0.05, 0.95])
|
|
841
|
+
mismatch_frequency_layer: str = "mismatch_integer_encoding"
|
|
842
|
+
mismatch_frequency_read_span_layer: str = "read_span_mask"
|
|
830
843
|
|
|
831
844
|
# Spatial Analysis - Clustermap params
|
|
832
845
|
layer_for_clustermap_plotting: Optional[str] = "nan0_0minus1"
|
|
@@ -835,6 +848,14 @@ class ExperimentConfig:
|
|
|
835
848
|
clustermap_cmap_cpg: Optional[str] = "coolwarm"
|
|
836
849
|
clustermap_cmap_a: Optional[str] = "coolwarm"
|
|
837
850
|
spatial_clustermap_sortby: Optional[str] = "gpc"
|
|
851
|
+
rolling_nn_layer: Optional[str] = "nan0_0minus1"
|
|
852
|
+
rolling_nn_plot_layer: Optional[str] = "nan0_0minus1"
|
|
853
|
+
rolling_nn_window: int = 15
|
|
854
|
+
rolling_nn_step: int = 2
|
|
855
|
+
rolling_nn_min_overlap: int = 10
|
|
856
|
+
rolling_nn_return_fraction: bool = True
|
|
857
|
+
rolling_nn_obsm_key: str = "rolling_nn_dist"
|
|
858
|
+
rolling_nn_site_types: Optional[List[str]] = None
|
|
838
859
|
|
|
839
860
|
# Spatial Analysis - UMAP/Leiden params
|
|
840
861
|
layer_for_umap_plotting: Optional[str] = "nan_half"
|
|
@@ -883,11 +904,15 @@ class ExperimentConfig:
|
|
|
883
904
|
accessible_patches: Optional[bool] = True
|
|
884
905
|
cpg: Optional[bool] = False
|
|
885
906
|
hmm_feature_sets: Dict[str, Any] = field(default_factory=dict)
|
|
907
|
+
hmm_feature_colormaps: Dict[str, Any] = field(default_factory=dict)
|
|
886
908
|
hmm_merge_layer_features: Optional[List[Tuple]] = field(default_factory=lambda: [(None, 60)])
|
|
887
909
|
clustermap_cmap_hmm: Optional[str] = "coolwarm"
|
|
888
910
|
hmm_clustermap_feature_layers: List[str] = field(
|
|
889
911
|
default_factory=lambda: ["all_accessible_features"]
|
|
890
912
|
)
|
|
913
|
+
hmm_clustermap_length_layers: List[str] = field(
|
|
914
|
+
default_factory=lambda: ["all_accessible_features"]
|
|
915
|
+
)
|
|
891
916
|
hmm_clustermap_sortby: Optional[str] = "hmm"
|
|
892
917
|
hmm_peak_feature_configs: Dict[str, Any] = field(default_factory=dict)
|
|
893
918
|
|
|
@@ -906,6 +931,8 @@ class ExperimentConfig:
|
|
|
906
931
|
invert_adata: bool = False
|
|
907
932
|
bypass_append_binary_layer_by_base_context: bool = False
|
|
908
933
|
force_redo_append_binary_layer_by_base_context: bool = False
|
|
934
|
+
bypass_append_mismatch_frequency_sites: bool = False
|
|
935
|
+
force_redo_append_mismatch_frequency_sites: bool = False
|
|
909
936
|
bypass_calculate_read_modification_stats: bool = False
|
|
910
937
|
force_redo_calculate_read_modification_stats: bool = False
|
|
911
938
|
bypass_filter_reads_on_modification_thresholds: bool = False
|
|
@@ -1110,7 +1137,7 @@ class ExperimentConfig:
|
|
|
1110
1137
|
|
|
1111
1138
|
# Demultiplexing output path
|
|
1112
1139
|
split_dir = merged.get("split_dir", SPLIT_DIR)
|
|
1113
|
-
split_path = output_dir / split_dir
|
|
1140
|
+
split_path = output_dir / LOAD_DIR / split_dir
|
|
1114
1141
|
|
|
1115
1142
|
# final normalization
|
|
1116
1143
|
if "strands" in merged:
|
|
@@ -1197,6 +1224,9 @@ class ExperimentConfig:
|
|
|
1197
1224
|
# Final normalization of hmm_feature_sets and canonical local variables
|
|
1198
1225
|
merged["hmm_feature_sets"] = normalize_hmm_feature_sets(merged.get("hmm_feature_sets", {}))
|
|
1199
1226
|
hmm_feature_sets = merged.get("hmm_feature_sets", {})
|
|
1227
|
+
hmm_feature_colormaps = merged.get("hmm_feature_colormaps", {})
|
|
1228
|
+
if not isinstance(hmm_feature_colormaps, dict):
|
|
1229
|
+
hmm_feature_colormaps = {}
|
|
1200
1230
|
hmm_annotation_threshold = merged.get("hmm_annotation_threshold", 0.5)
|
|
1201
1231
|
hmm_batch_size = int(merged.get("hmm_batch_size", 1024))
|
|
1202
1232
|
hmm_use_viterbi = bool(merged.get("hmm_use_viterbi", False))
|
|
@@ -1211,6 +1241,9 @@ class ExperimentConfig:
|
|
|
1211
1241
|
hmm_clustermap_feature_layers = _parse_list(
|
|
1212
1242
|
merged.get("hmm_clustermap_feature_layers", "all_accessible_features")
|
|
1213
1243
|
)
|
|
1244
|
+
hmm_clustermap_length_layers = _parse_list(
|
|
1245
|
+
merged.get("hmm_clustermap_length_layers", hmm_clustermap_feature_layers)
|
|
1246
|
+
)
|
|
1214
1247
|
|
|
1215
1248
|
hmm_fit_strategy = str(merged.get("hmm_fit_strategy", "per_group")).strip()
|
|
1216
1249
|
hmm_shared_scope = _parse_list(merged.get("hmm_shared_scope", ["reference", "methbase"]))
|
|
@@ -1231,6 +1264,7 @@ class ExperimentConfig:
|
|
|
1231
1264
|
|
|
1232
1265
|
# instantiate dataclass
|
|
1233
1266
|
instance = cls(
|
|
1267
|
+
annotate_secondary_supplementary=merged.get("annotate_secondary_supplementary", True),
|
|
1234
1268
|
smf_modality=merged.get("smf_modality"),
|
|
1235
1269
|
input_data_path=input_data_path,
|
|
1236
1270
|
recursive_input_search=merged.get("recursive_input_search"),
|
|
@@ -1257,6 +1291,8 @@ class ExperimentConfig:
|
|
|
1257
1291
|
trim=merged.get("trim", TRIM),
|
|
1258
1292
|
input_already_demuxed=merged.get("input_already_demuxed", False),
|
|
1259
1293
|
threads=merged.get("threads"),
|
|
1294
|
+
emit_log_file=merged.get("emit_log_file", True),
|
|
1295
|
+
log_level=merged.get("log_level", "INFO"),
|
|
1260
1296
|
sample_sheet_path=merged.get("sample_sheet_path"),
|
|
1261
1297
|
sample_sheet_mapping_column=merged.get("sample_sheet_mapping_column"),
|
|
1262
1298
|
delete_intermediate_bams=merged.get("delete_intermediate_bams", False),
|
|
@@ -1313,6 +1349,9 @@ class ExperimentConfig:
|
|
|
1313
1349
|
),
|
|
1314
1350
|
reindexing_offsets=merged.get("reindexing_offsets", {None: None}),
|
|
1315
1351
|
reindexed_var_suffix=merged.get("reindexed_var_suffix", "reindexed"),
|
|
1352
|
+
clustermap_demux_types_to_plot=merged.get(
|
|
1353
|
+
"clustermap_demux_types_to_plot", ["single", "double", "already"]
|
|
1354
|
+
),
|
|
1316
1355
|
layer_for_clustermap_plotting=merged.get(
|
|
1317
1356
|
"layer_for_clustermap_plotting", "nan0_0minus1"
|
|
1318
1357
|
),
|
|
@@ -1321,6 +1360,14 @@ class ExperimentConfig:
|
|
|
1321
1360
|
clustermap_cmap_cpg=merged.get("clustermap_cmap_cpg", "coolwarm"),
|
|
1322
1361
|
clustermap_cmap_a=merged.get("clustermap_cmap_a", "coolwarm"),
|
|
1323
1362
|
spatial_clustermap_sortby=merged.get("spatial_clustermap_sortby", "gpc"),
|
|
1363
|
+
rolling_nn_layer=merged.get("rolling_nn_layer", "nan0_0minus1"),
|
|
1364
|
+
rolling_nn_plot_layer=merged.get("rolling_nn_plot_layer", "nan0_0minus1"),
|
|
1365
|
+
rolling_nn_window=merged.get("rolling_nn_window", 15),
|
|
1366
|
+
rolling_nn_step=merged.get("rolling_nn_step", 2),
|
|
1367
|
+
rolling_nn_min_overlap=merged.get("rolling_nn_min_overlap", 10),
|
|
1368
|
+
rolling_nn_return_fraction=merged.get("rolling_nn_return_fraction", True),
|
|
1369
|
+
rolling_nn_obsm_key=merged.get("rolling_nn_obsm_key", "rolling_nn_dist"),
|
|
1370
|
+
rolling_nn_site_types=merged.get("rolling_nn_site_types", None),
|
|
1324
1371
|
layer_for_umap_plotting=merged.get("layer_for_umap_plotting", "nan_half"),
|
|
1325
1372
|
umap_layers_to_plot=merged.get(
|
|
1326
1373
|
"umap_layers_to_plot", ["mapped_length", "Raw_modification_signal"]
|
|
@@ -1347,6 +1394,7 @@ class ExperimentConfig:
|
|
|
1347
1394
|
hmm_emission_adapt_tol=hmm_emission_adapt_tol,
|
|
1348
1395
|
hmm_dtype=merged.get("hmm_dtype", "float64"),
|
|
1349
1396
|
hmm_feature_sets=hmm_feature_sets,
|
|
1397
|
+
hmm_feature_colormaps=hmm_feature_colormaps,
|
|
1350
1398
|
hmm_annotation_threshold=hmm_annotation_threshold,
|
|
1351
1399
|
hmm_batch_size=hmm_batch_size,
|
|
1352
1400
|
hmm_use_viterbi=hmm_use_viterbi,
|
|
@@ -1355,6 +1403,7 @@ class ExperimentConfig:
|
|
|
1355
1403
|
hmm_merge_layer_features=hmm_merge_layer_features,
|
|
1356
1404
|
clustermap_cmap_hmm=merged.get("clustermap_cmap_hmm", "coolwarm"),
|
|
1357
1405
|
hmm_clustermap_feature_layers=hmm_clustermap_feature_layers,
|
|
1406
|
+
hmm_clustermap_length_layers=hmm_clustermap_length_layers,
|
|
1358
1407
|
hmm_clustermap_sortby=merged.get("hmm_clustermap_sortby", "hmm"),
|
|
1359
1408
|
hmm_peak_feature_configs=hmm_peak_feature_configs,
|
|
1360
1409
|
footprints=merged.get("footprints", None),
|
|
@@ -1390,6 +1439,9 @@ class ExperimentConfig:
|
|
|
1390
1439
|
duplicate_detection_site_types=merged.get(
|
|
1391
1440
|
"duplicate_detection_site_types", ["GpC", "CpG", "ambiguous_GpC_CpG"]
|
|
1392
1441
|
),
|
|
1442
|
+
duplicate_detection_demux_types_to_use=merged.get(
|
|
1443
|
+
"duplicate_detection_demux_types_to_use", ["single", "double", "already"]
|
|
1444
|
+
),
|
|
1393
1445
|
duplicate_detection_distance_threshold=merged.get(
|
|
1394
1446
|
"duplicate_detection_distance_threshold", 0.07
|
|
1395
1447
|
),
|
smftools/constants.py
CHANGED
|
@@ -21,7 +21,30 @@ BAM_SUFFIX: Final[str] = ".bam"
|
|
|
21
21
|
BARCODE_BOTH_ENDS: Final[bool] = False
|
|
22
22
|
REF_COL: Final[str] = "Reference_strand"
|
|
23
23
|
SAMPLE_COL: Final[str] = "Experiment_name_and_barcode"
|
|
24
|
+
SAMPLE: Final[str] = "Sample"
|
|
24
25
|
SPLIT_DIR: Final[str] = "demultiplexed_BAMs"
|
|
26
|
+
H5_DIR: Final[str] = "h5ads"
|
|
27
|
+
DEMUX_TYPE: Final[str] = "demux_type"
|
|
28
|
+
BARCODE: Final[str] = "Barcode"
|
|
29
|
+
REFERENCE: Final[str] = "Reference"
|
|
30
|
+
REFERENCE_STRAND: Final[str] = "Reference_strand"
|
|
31
|
+
REFERENCE_DATASET_STRAND: Final[str] = "Reference_dataset_strand"
|
|
32
|
+
STRAND: Final[str] = "Strand"
|
|
33
|
+
DATASET: Final[str] = "Dataset"
|
|
34
|
+
READ_MISMATCH_TREND: Final[str] = "Read_mismatch_trend"
|
|
35
|
+
READ_MAPPING_DIRECTION: Final[str] = "Read_mapping_direction"
|
|
36
|
+
SEQUENCE_INTEGER_ENCODING: Final[str] = "sequence_integer_encoding"
|
|
37
|
+
SEQUENCE_INTEGER_DECODING: Final[str] = "sequence_integer_decoding"
|
|
38
|
+
MISMATCH_INTEGER_ENCODING: Final[str] = "mismatch_integer_encoding"
|
|
39
|
+
BASE_QUALITY_SCORES: Final[str] = "base_quality_scores"
|
|
40
|
+
READ_SPAN_MASK: Final[str] = "read_span_mask"
|
|
41
|
+
|
|
42
|
+
LOAD_DIR: Final[str] = "load_adata_outputs"
|
|
43
|
+
PREPROCESS_DIR: Final[str] = "preprocess_adata_outputs"
|
|
44
|
+
SPATIAL_DIR: Final[str] = "spatial_adata_outputs"
|
|
45
|
+
HMM_DIR: Final[str] = "hmm_adata_outputs"
|
|
46
|
+
LATENT_DIR: Final[str] = "latent_adata_outputs"
|
|
47
|
+
LOGGING_DIR: Final[str] = "logs"
|
|
25
48
|
TRIM: Final[bool] = False
|
|
26
49
|
|
|
27
50
|
_private_conversions = ["unconverted"]
|
|
@@ -35,3 +58,45 @@ MOD_MAP: Final[Mapping[str, str]] = _deep_freeze(_private_mod_map)
|
|
|
35
58
|
|
|
36
59
|
_private_strands = ("bottom", "top")
|
|
37
60
|
STRANDS: Final[tuple[str, ...]] = _deep_freeze(_private_strands)
|
|
61
|
+
|
|
62
|
+
MODKIT_EXTRACT_TSV_COLUMN_CHROM: Final[str] = "chrom"
|
|
63
|
+
MODKIT_EXTRACT_TSV_COLUMN_REF_POSITION: Final[str] = "ref_position"
|
|
64
|
+
MODKIT_EXTRACT_TSV_COLUMN_MODIFIED_PRIMARY_BASE: Final[str] = "modified_primary_base"
|
|
65
|
+
MODKIT_EXTRACT_TSV_COLUMN_REF_STRAND: Final[str] = "ref_strand"
|
|
66
|
+
MODKIT_EXTRACT_TSV_COLUMN_READ_ID: Final[str] = "read_id"
|
|
67
|
+
MODKIT_EXTRACT_TSV_COLUMN_CALL_CODE: Final[str] = "call_code"
|
|
68
|
+
MODKIT_EXTRACT_TSV_COLUMN_CALL_PROB: Final[str] = "call_prob"
|
|
69
|
+
|
|
70
|
+
MODKIT_EXTRACT_MODIFIED_BASE_A: Final[str] = "A"
|
|
71
|
+
MODKIT_EXTRACT_MODIFIED_BASE_C: Final[str] = "C"
|
|
72
|
+
MODKIT_EXTRACT_REF_STRAND_PLUS: Final[str] = "+"
|
|
73
|
+
MODKIT_EXTRACT_REF_STRAND_MINUS: Final[str] = "-"
|
|
74
|
+
|
|
75
|
+
_private_modkit_extract_call_code_modified = ("a", "h", "m")
|
|
76
|
+
MODKIT_EXTRACT_CALL_CODE_MODIFIED: Final[tuple[str, ...]] = _deep_freeze(
|
|
77
|
+
_private_modkit_extract_call_code_modified
|
|
78
|
+
)
|
|
79
|
+
_private_modkit_extract_call_code_canonical = ("-",)
|
|
80
|
+
MODKIT_EXTRACT_CALL_CODE_CANONICAL: Final[tuple[str, ...]] = _deep_freeze(
|
|
81
|
+
_private_modkit_extract_call_code_canonical
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
MODKIT_EXTRACT_SEQUENCE_BASES: Final[tuple[str, ...]] = _deep_freeze(("A", "C", "G", "T", "N"))
|
|
85
|
+
MODKIT_EXTRACT_SEQUENCE_PADDING_BASE: Final[str] = "PAD"
|
|
86
|
+
_private_modkit_extract_base_to_int: Dict[str, int] = {
|
|
87
|
+
"A": 0,
|
|
88
|
+
"C": 1,
|
|
89
|
+
"G": 2,
|
|
90
|
+
"T": 3,
|
|
91
|
+
"N": 4,
|
|
92
|
+
"PAD": 5,
|
|
93
|
+
}
|
|
94
|
+
MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT: Final[Mapping[str, int]] = _deep_freeze(
|
|
95
|
+
_private_modkit_extract_base_to_int
|
|
96
|
+
)
|
|
97
|
+
_private_modkit_extract_int_to_base: Dict[int, str] = {
|
|
98
|
+
value: key for key, value in _private_modkit_extract_base_to_int.items()
|
|
99
|
+
}
|
|
100
|
+
MODKIT_EXTRACT_SEQUENCE_INT_TO_BASE: Final[Mapping[int, str]] = _deep_freeze(
|
|
101
|
+
_private_modkit_extract_int_to_base
|
|
102
|
+
)
|
smftools/hmm/HMM.py
CHANGED
|
@@ -144,6 +144,83 @@ def _safe_int_coords(var_names) -> Tuple[np.ndarray, bool]:
|
|
|
144
144
|
return np.arange(len(var_names), dtype=int), False
|
|
145
145
|
|
|
146
146
|
|
|
147
|
+
def mask_layers_outside_read_span(
|
|
148
|
+
adata,
|
|
149
|
+
layers: Sequence[str],
|
|
150
|
+
*,
|
|
151
|
+
start_key: str = "reference_start",
|
|
152
|
+
end_key: str = "reference_end",
|
|
153
|
+
use_original_var_names: bool = True,
|
|
154
|
+
) -> List[str]:
|
|
155
|
+
"""Mask layer values outside read reference spans with NaN.
|
|
156
|
+
|
|
157
|
+
This uses integer coordinate comparisons against either ``adata.var["Original_var_names"]``
|
|
158
|
+
(when present) or ``adata.var_names``. Values strictly less than ``start_key`` or greater
|
|
159
|
+
than ``end_key`` are set to NaN for each read.
|
|
160
|
+
|
|
161
|
+
Args:
|
|
162
|
+
adata: AnnData object to modify in-place.
|
|
163
|
+
layers: Layer names to mask.
|
|
164
|
+
start_key: obs column holding reference start positions.
|
|
165
|
+
end_key: obs column holding reference end positions.
|
|
166
|
+
use_original_var_names: Use ``adata.var["Original_var_names"]`` when available.
|
|
167
|
+
|
|
168
|
+
Returns:
|
|
169
|
+
List of layer names that were masked.
|
|
170
|
+
"""
|
|
171
|
+
if not layers:
|
|
172
|
+
return []
|
|
173
|
+
|
|
174
|
+
if start_key not in adata.obs or end_key not in adata.obs:
|
|
175
|
+
raise KeyError(f"Missing {start_key!r} or {end_key!r} in adata.obs.")
|
|
176
|
+
|
|
177
|
+
coord_source = adata.var_names
|
|
178
|
+
if use_original_var_names and "Original_var_names" in adata.var:
|
|
179
|
+
orig = np.asarray(adata.var["Original_var_names"])
|
|
180
|
+
if orig.size == adata.n_vars:
|
|
181
|
+
try:
|
|
182
|
+
orig_numeric = np.asarray(orig, dtype=float)
|
|
183
|
+
except (TypeError, ValueError):
|
|
184
|
+
orig_numeric = None
|
|
185
|
+
if orig_numeric is not None and np.isfinite(orig_numeric).any():
|
|
186
|
+
coord_source = orig
|
|
187
|
+
|
|
188
|
+
coords, _ = _safe_int_coords(coord_source)
|
|
189
|
+
if coords.shape[0] != adata.n_vars:
|
|
190
|
+
raise ValueError("Coordinate source length does not match adata.n_vars.")
|
|
191
|
+
|
|
192
|
+
try:
|
|
193
|
+
starts = np.asarray(adata.obs[start_key], dtype=float)
|
|
194
|
+
ends = np.asarray(adata.obs[end_key], dtype=float)
|
|
195
|
+
except (TypeError, ValueError) as exc:
|
|
196
|
+
raise ValueError("Start/end positions must be numeric.") from exc
|
|
197
|
+
|
|
198
|
+
masked = []
|
|
199
|
+
for layer in layers:
|
|
200
|
+
if layer not in adata.layers:
|
|
201
|
+
raise KeyError(f"Layer {layer!r} not found in adata.layers.")
|
|
202
|
+
|
|
203
|
+
arr = np.asarray(adata.layers[layer])
|
|
204
|
+
if not np.issubdtype(arr.dtype, np.floating):
|
|
205
|
+
arr = arr.astype(float, copy=True)
|
|
206
|
+
|
|
207
|
+
for i in range(adata.n_obs):
|
|
208
|
+
start = starts[i]
|
|
209
|
+
end = ends[i]
|
|
210
|
+
if not np.isfinite(start) or not np.isfinite(end):
|
|
211
|
+
continue
|
|
212
|
+
start_i = int(start)
|
|
213
|
+
end_i = int(end)
|
|
214
|
+
row_mask = (coords < start_i) | (coords > end_i)
|
|
215
|
+
if row_mask.any():
|
|
216
|
+
arr[i, row_mask] = np.nan
|
|
217
|
+
|
|
218
|
+
adata.layers[layer] = arr
|
|
219
|
+
masked.append(layer)
|
|
220
|
+
|
|
221
|
+
return masked
|
|
222
|
+
|
|
223
|
+
|
|
147
224
|
def _logsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
|
|
148
225
|
"""Compute log-sum-exp in a numerically stable way.
|
|
149
226
|
|
|
@@ -1064,6 +1141,8 @@ class BaseHMM(nn.Module):
|
|
|
1064
1141
|
uns_key: str = "hmm_appended_layers",
|
|
1065
1142
|
uns_flag: str = "hmm_annotated",
|
|
1066
1143
|
force_redo: bool = False,
|
|
1144
|
+
mask_to_read_span: bool = True,
|
|
1145
|
+
mask_use_original_var_names: bool = True,
|
|
1067
1146
|
device: Optional[Union[str, torch.device]] = None,
|
|
1068
1147
|
**kwargs,
|
|
1069
1148
|
):
|
|
@@ -1085,6 +1164,8 @@ class BaseHMM(nn.Module):
|
|
|
1085
1164
|
uns_key: .uns key to track appended layers.
|
|
1086
1165
|
uns_flag: .uns flag to mark annotations.
|
|
1087
1166
|
force_redo: Whether to overwrite existing layers.
|
|
1167
|
+
mask_to_read_span: Whether to mask appended layers outside read spans.
|
|
1168
|
+
mask_use_original_var_names: Use ``adata.var["Original_var_names"]`` when available.
|
|
1088
1169
|
device: Device specifier.
|
|
1089
1170
|
**kwargs: Additional parameters for specialized workflows.
|
|
1090
1171
|
|
|
@@ -1245,6 +1326,13 @@ class BaseHMM(nn.Module):
|
|
|
1245
1326
|
np.asarray(adata.layers[nm])
|
|
1246
1327
|
)
|
|
1247
1328
|
|
|
1329
|
+
if mask_to_read_span and appended:
|
|
1330
|
+
mask_layers_outside_read_span(
|
|
1331
|
+
adata,
|
|
1332
|
+
appended,
|
|
1333
|
+
use_original_var_names=mask_use_original_var_names,
|
|
1334
|
+
)
|
|
1335
|
+
|
|
1248
1336
|
adata.uns[uns_key] = appended
|
|
1249
1337
|
adata.uns[uns_flag] = True
|
|
1250
1338
|
return None
|
smftools/informatics/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ _LAZY_ATTRS = {
|
|
|
6
6
|
"_bed_to_bigwig": "smftools.informatics.bed_functions",
|
|
7
7
|
"_plot_bed_histograms": "smftools.informatics.bed_functions",
|
|
8
8
|
"add_demux_type_annotation": "smftools.informatics.h5ad_functions",
|
|
9
|
+
"add_read_tag_annotations": "smftools.informatics.h5ad_functions",
|
|
9
10
|
"add_read_length_and_mapping_qc": "smftools.informatics.h5ad_functions",
|
|
10
11
|
"align_and_sort_BAM": "smftools.informatics.bam_functions",
|
|
11
12
|
"bam_qc": "smftools.informatics.bam_functions",
|
|
@@ -18,6 +19,7 @@ _LAZY_ATTRS = {
|
|
|
18
19
|
"extract_base_identities": "smftools.informatics.bam_functions",
|
|
19
20
|
"extract_mods": "smftools.informatics.modkit_functions",
|
|
20
21
|
"extract_read_features_from_bam": "smftools.informatics.bam_functions",
|
|
22
|
+
"extract_read_tags_from_bam": "smftools.informatics.bam_functions",
|
|
21
23
|
"extract_read_lengths_from_bed": "smftools.informatics.bed_functions",
|
|
22
24
|
"extract_readnames_from_bam": "smftools.informatics.bam_functions",
|
|
23
25
|
"fast5_to_pod5": "smftools.informatics.pod5_functions",
|
|
@@ -30,6 +32,8 @@ _LAZY_ATTRS = {
|
|
|
30
32
|
"modQC": "smftools.informatics.modkit_functions",
|
|
31
33
|
"modcall": "smftools.informatics.basecalling",
|
|
32
34
|
"modkit_extract_to_adata": "smftools.informatics.modkit_extract_to_adata",
|
|
35
|
+
"decode_int_sequence": "smftools.informatics.sequence_encoding",
|
|
36
|
+
"encode_sequence_to_int": "smftools.informatics.sequence_encoding",
|
|
33
37
|
"ohe_batching": "smftools.informatics.ohe",
|
|
34
38
|
"ohe_layers_decode": "smftools.informatics.ohe",
|
|
35
39
|
"one_hot_decode": "smftools.informatics.ohe",
|
|
@@ -55,6 +59,8 @@ def __getattr__(name: str):
|
|
|
55
59
|
__all__ = [
|
|
56
60
|
"basecall_pod5s",
|
|
57
61
|
"converted_BAM_to_adata",
|
|
62
|
+
"decode_int_sequence",
|
|
63
|
+
"encode_sequence_to_int",
|
|
58
64
|
"subsample_fasta_from_bed",
|
|
59
65
|
"subsample_pod5",
|
|
60
66
|
"fast5_to_pod5",
|