smftools 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +29 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  5. smftools/datasets/F1_sample_sheet.csv +5 -0
  6. smftools/datasets/__init__.py +9 -0
  7. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  8. smftools/datasets/datasets.py +28 -0
  9. smftools/informatics/__init__.py +16 -0
  10. smftools/informatics/archived/bam_conversion.py +59 -0
  11. smftools/informatics/archived/bam_direct.py +63 -0
  12. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  13. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  14. smftools/informatics/basecall_pod5s.py +80 -0
  15. smftools/informatics/conversion_smf.py +132 -0
  16. smftools/informatics/direct_smf.py +137 -0
  17. smftools/informatics/fast5_to_pod5.py +21 -0
  18. smftools/informatics/helpers/LoadExperimentConfig.py +75 -0
  19. smftools/informatics/helpers/__init__.py +74 -0
  20. smftools/informatics/helpers/align_and_sort_BAM.py +59 -0
  21. smftools/informatics/helpers/aligned_BAM_to_bed.py +74 -0
  22. smftools/informatics/helpers/archived/informatics.py +260 -0
  23. smftools/informatics/helpers/archived/load_adata.py +516 -0
  24. smftools/informatics/helpers/bam_qc.py +66 -0
  25. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  26. smftools/informatics/helpers/binarize_converted_base_identities.py +79 -0
  27. smftools/informatics/helpers/canoncall.py +34 -0
  28. smftools/informatics/helpers/complement_base_list.py +21 -0
  29. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +55 -0
  30. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  31. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  32. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  33. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  34. smftools/informatics/helpers/extract_base_identities.py +44 -0
  35. smftools/informatics/helpers/extract_mods.py +83 -0
  36. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  37. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  38. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  39. smftools/informatics/helpers/find_conversion_sites.py +50 -0
  40. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  41. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  42. smftools/informatics/helpers/get_native_references.py +28 -0
  43. smftools/informatics/helpers/index_fasta.py +12 -0
  44. smftools/informatics/helpers/make_dirs.py +21 -0
  45. smftools/informatics/helpers/make_modbed.py +27 -0
  46. smftools/informatics/helpers/modQC.py +27 -0
  47. smftools/informatics/helpers/modcall.py +36 -0
  48. smftools/informatics/helpers/modkit_extract_to_adata.py +884 -0
  49. smftools/informatics/helpers/ohe_batching.py +76 -0
  50. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  51. smftools/informatics/helpers/one_hot_decode.py +27 -0
  52. smftools/informatics/helpers/one_hot_encode.py +57 -0
  53. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +53 -0
  54. smftools/informatics/helpers/run_multiqc.py +28 -0
  55. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  56. smftools/informatics/helpers/split_and_index_BAM.py +36 -0
  57. smftools/informatics/load_adata.py +182 -0
  58. smftools/informatics/readwrite.py +106 -0
  59. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  60. smftools/informatics/subsample_pod5.py +104 -0
  61. smftools/plotting/__init__.py +15 -0
  62. smftools/plotting/classifiers.py +355 -0
  63. smftools/plotting/general_plotting.py +205 -0
  64. smftools/plotting/position_stats.py +462 -0
  65. smftools/preprocessing/__init__.py +33 -0
  66. smftools/preprocessing/append_C_context.py +82 -0
  67. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  68. smftools/preprocessing/archives/preprocessing.py +614 -0
  69. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  70. smftools/preprocessing/binarize_on_Youden.py +45 -0
  71. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  72. smftools/preprocessing/calculate_complexity.py +72 -0
  73. smftools/preprocessing/calculate_consensus.py +47 -0
  74. smftools/preprocessing/calculate_converted_read_methylation_stats.py +94 -0
  75. smftools/preprocessing/calculate_coverage.py +42 -0
  76. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  77. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  78. smftools/preprocessing/calculate_position_Youden.py +115 -0
  79. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  80. smftools/preprocessing/clean_NaN.py +46 -0
  81. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  82. smftools/preprocessing/filter_converted_reads_on_methylation.py +44 -0
  83. smftools/preprocessing/filter_reads_on_length.py +51 -0
  84. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  85. smftools/preprocessing/invert_adata.py +30 -0
  86. smftools/preprocessing/load_sample_sheet.py +38 -0
  87. smftools/preprocessing/make_dirs.py +21 -0
  88. smftools/preprocessing/min_non_diagonal.py +25 -0
  89. smftools/preprocessing/recipes.py +127 -0
  90. smftools/preprocessing/subsample_adata.py +58 -0
  91. smftools/readwrite.py +198 -0
  92. smftools/tools/__init__.py +49 -0
  93. smftools/tools/apply_hmm.py +202 -0
  94. smftools/tools/apply_hmm_batched.py +241 -0
  95. smftools/tools/archived/classify_methylated_features.py +66 -0
  96. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  97. smftools/tools/archived/subset_adata_v1.py +32 -0
  98. smftools/tools/archived/subset_adata_v2.py +46 -0
  99. smftools/tools/calculate_distances.py +18 -0
  100. smftools/tools/calculate_umap.py +62 -0
  101. smftools/tools/call_hmm_peaks.py +105 -0
  102. smftools/tools/classifiers.py +787 -0
  103. smftools/tools/cluster_adata_on_methylation.py +105 -0
  104. smftools/tools/data/__init__.py +2 -0
  105. smftools/tools/data/anndata_data_module.py +90 -0
  106. smftools/tools/data/preprocessing.py +6 -0
  107. smftools/tools/display_hmm.py +18 -0
  108. smftools/tools/evaluation/__init__.py +0 -0
  109. smftools/tools/general_tools.py +69 -0
  110. smftools/tools/hmm_readwrite.py +16 -0
  111. smftools/tools/inference/__init__.py +1 -0
  112. smftools/tools/inference/lightning_inference.py +41 -0
  113. smftools/tools/models/__init__.py +9 -0
  114. smftools/tools/models/base.py +14 -0
  115. smftools/tools/models/cnn.py +34 -0
  116. smftools/tools/models/lightning_base.py +41 -0
  117. smftools/tools/models/mlp.py +17 -0
  118. smftools/tools/models/positional.py +17 -0
  119. smftools/tools/models/rnn.py +16 -0
  120. smftools/tools/models/sklearn_models.py +40 -0
  121. smftools/tools/models/transformer.py +133 -0
  122. smftools/tools/models/wrappers.py +20 -0
  123. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  124. smftools/tools/position_stats.py +239 -0
  125. smftools/tools/read_stats.py +70 -0
  126. smftools/tools/subset_adata.py +28 -0
  127. smftools/tools/train_hmm.py +78 -0
  128. smftools/tools/training/__init__.py +1 -0
  129. smftools/tools/training/train_lightning_model.py +47 -0
  130. smftools/tools/utils/__init__.py +2 -0
  131. smftools/tools/utils/device.py +10 -0
  132. smftools/tools/utils/grl.py +14 -0
  133. {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/METADATA +5 -2
  134. smftools-0.1.7.dist-info/RECORD +136 -0
  135. smftools-0.1.6.dist-info/RECORD +0 -4
  136. {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  137. {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/readwrite.py ADDED
@@ -0,0 +1,198 @@
+ ## readwrite ##
+
+ ######################################################################################################
+ ## Datetime functionality
+ def date_string():
+     """
+     Each time this is called, it returns the current date string.
+     """
+     from datetime import datetime
+     current_date = datetime.now()
+     date_string = current_date.strftime("%Y%m%d")
+     date_string = date_string[2:]
+     return date_string
+
+ def time_string():
+     """
+     Each time this is called, it returns the current time string.
+     """
+     from datetime import datetime
+     current_time = datetime.now()
+     return current_time.strftime("%H:%M:%S")
+ ######################################################################################################
+
+ ######################################################################################################
+ ## Numpy, Pandas, Anndata functionality
+
+ def adata_to_df(adata, layer=None):
+     """
+     Convert an AnnData object into a Pandas DataFrame.
+
+     Parameters:
+         adata (AnnData): The input AnnData object.
+         layer (str, optional): The layer to extract. If None, uses adata.X.
+
+     Returns:
+         pd.DataFrame: A DataFrame where rows are observations and columns are positions.
+     """
+     import pandas as pd
+     import anndata as ad
+     import numpy as np
+
+     # Validate that the requested layer exists
+     if layer and layer not in adata.layers:
+         raise ValueError(f"Layer '{layer}' not found in adata.layers.")
+
+     # Extract the data matrix
+     data_matrix = adata.layers.get(layer, adata.X)
+
+     # Ensure the matrix is dense (handle sparse formats)
+     if hasattr(data_matrix, "toarray"):
+         data_matrix = data_matrix.toarray()
+
+     # Ensure obs and var have unique indices
+     if adata.obs.index.duplicated().any():
+         raise ValueError("Duplicate values found in `adata.obs.index`. Ensure unique observation indices.")
+
+     if adata.var.index.duplicated().any():
+         raise ValueError("Duplicate values found in `adata.var.index`. Ensure unique variable indices.")
+
+     # Convert to DataFrame
+     df = pd.DataFrame(data_matrix, index=adata.obs.index, columns=adata.var.index)
+
+     return df
+
+
+ def save_matrix(matrix, save_name):
+     """
+     Input: A numpy matrix and a save_name.
+     Output: A txt file representation of the data matrix.
+     """
+     import numpy as np
+     np.savetxt(f'{save_name}.txt', matrix)
+
+ def concatenate_h5ads(output_file, file_suffix='h5ad.gz', delete_inputs=True):
+     """
+     Concatenate all h5ad files in the current working directory and optionally delete them after the final adata is written out.
+     Input: an output file path relative to the directory in which the function is called.
+     """
+     import os
+     import anndata as ad
+     # Suppress runtime warnings from anndata
+     import warnings
+     warnings.filterwarnings('ignore', category=UserWarning, module='anndata')
+     warnings.filterwarnings('ignore', category=FutureWarning, module='anndata')
+
+     # List all files in the current working directory
+     files = os.listdir(os.getcwd())
+     suffix = file_suffix
+     # Keep file names that contain the suffix
+     hdfs = [hdf for hdf in files if suffix in hdf]
+     # Sort the file list by name and report it
+     hdfs.sort()
+     print('{0} sample files found: {1}'.format(len(hdfs), hdfs))
+     # Iterate over the h5ad files and concatenate them
+     final_adata = None
+     for hdf in hdfs:
+         print('{0}: Reading in {1} hdf5 file'.format(time_string(), hdf))
+         temp_adata = ad.read_h5ad(hdf)
+         if final_adata is not None:
+             print('{0}: Concatenating final adata object with {1} hdf5 file'.format(time_string(), hdf))
+             final_adata = ad.concat([final_adata, temp_adata], join='outer', index_unique=None)
+         else:
+             print('{0}: Initializing final adata object with {1} hdf5 file'.format(time_string(), hdf))
+             final_adata = temp_adata
+     print('{0}: Writing final concatenated hdf5 file'.format(time_string()))
+     final_adata.write_h5ad(output_file, compression='gzip')
+
+     # Delete the individual h5ad files and only keep the final concatenated file
+     if delete_inputs:
+         files = os.listdir(os.getcwd())
+         hdfs = [hdf for hdf in files if suffix in hdf]
+         if output_file in hdfs:
+             hdfs.remove(output_file)
+         # Iterate over the files and delete them
+         for hdf in hdfs:
+             try:
+                 os.remove(hdf)
+                 print(f"Deleted file: {hdf}")
+             except OSError as e:
+                 print(f"Error deleting file {hdf}: {e}")
+     else:
+         print('Keeping input files')
+
+ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir="./"):
+     """
+     Saves an AnnData object safely by omitting problematic columns from .obs and .var.
+
+     Parameters:
+         adata (AnnData): The AnnData object to save.
+         path (str): Output .h5ad file path.
+         compression (str): Compression method for the h5ad file.
+         backup (bool): If True, saves problematic columns to CSV files.
+         backup_dir (str): Directory to store backups if backup=True.
+     """
+     import anndata as ad
+     import pandas as pd
+     import os
+
+     os.makedirs(backup_dir, exist_ok=True)
+
+     def filter_df(df, df_name):
+         # Flag object columns that hold anything other than strings or None
+         bad_cols = []
+         for col in df.columns:
+             if df[col].dtype == 'object':
+                 if not df[col].apply(lambda x: isinstance(x, (str, type(None)))).all():
+                     bad_cols.append(col)
+         if bad_cols:
+             print(f"⚠️ Skipping columns from {df_name}: {bad_cols}")
+             if backup:
+                 df[bad_cols].to_csv(os.path.join(backup_dir, f"{df_name}_skipped_columns.csv"))
+                 print(f"📝 Backed up skipped columns to {backup_dir}/{df_name}_skipped_columns.csv")
+         return df.drop(columns=bad_cols)
+
+     # Clean obs and var
+     obs_clean = filter_df(adata.obs, "obs")
+     var_clean = filter_df(adata.var, "var")
+
+     # Save the cleaned copy
+     adata_copy = ad.AnnData(
+         X=adata.X,
+         obs=obs_clean,
+         var=var_clean,
+         layers=adata.layers,
+         uns=adata.uns,
+         obsm=adata.obsm,
+         varm=adata.varm
+     )
+     adata_copy.write_h5ad(path, compression=compression)
+     print(f"✅ Saved safely to {path}")
+
+ def merge_barcoded_anndatas(adata_single, adata_double):
+     """
+     Merge single-barcode and double-barcode AnnData objects, keeping the
+     double-barcode copy of any read present in both.
+     """
+     import numpy as np
+     import anndata as ad
+
+     # Step 1: Identify overlapping read names
+     overlap = np.intersect1d(adata_single.obs_names, adata_double.obs_names)
+
+     # Step 2: Filter overlapping reads out of adata_single
+     adata_single_filtered = adata_single[~adata_single.obs_names.isin(overlap)].copy()
+
+     # Step 3: Add a source tag
+     adata_single_filtered.obs['source'] = 'single_barcode'
+     adata_double.obs['source'] = 'double_barcode'
+
+     # Step 4: Concatenate the components
+     adata_merged = ad.concat([
+         adata_single_filtered,
+         adata_double
+     ], join='outer', merge='same')  # merge='same' preserves matching layers, obsm, etc.
+
+     # Step 5: Merge `.uns`
+     adata_merged.uns = {**adata_single.uns, **adata_double.uns}
+
+     return adata_merged
+
+ ######################################################################################################
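For orientation, a minimal usage sketch of these readwrite helpers (not part of the wheel; the input file name, layer choice, and backup directory below are assumptions for illustration):

    import anndata as ad
    from smftools import readwrite

    adata = ad.read_h5ad("sample.h5ad.gz")                   # hypothetical input file
    df = readwrite.adata_to_df(adata)                        # dense obs x var DataFrame from adata.X
    readwrite.save_matrix(df.to_numpy(), f"matrix_{readwrite.date_string()}")
    readwrite.safe_write_h5ad(adata, "sample_clean.h5ad.gz", backup=True, backup_dir="./backups")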
smftools/tools/__init__.py ADDED
@@ -0,0 +1,49 @@
+ from .apply_hmm import apply_hmm
+ from .apply_hmm_batched import apply_hmm_batched
+ from .position_stats import calculate_relative_risk_on_activity, compute_positionwise_statistic
+ from .calculate_distances import calculate_distances
+ from .calculate_umap import calculate_umap
+ from .call_hmm_peaks import call_hmm_peaks
+ from .classifiers import run_training_loop, run_inference, evaluate_models_by_subgroup, prepare_melted_model_data, sliding_window_train_test
+ from .cluster_adata_on_methylation import cluster_adata_on_methylation
+ from .display_hmm import display_hmm
+ from .general_tools import create_nan_mask_from_X, combine_layers, create_nan_or_non_gpc_mask
+ from .hmm_readwrite import load_hmm, save_hmm
+ from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
+ from .read_stats import calculate_row_entropy
+ from .subset_adata import subset_adata
+ from .train_hmm import train_hmm
+
+ from . import models
+ from . import data
+ from . import utils
+ from . import evaluation
+ from . import inference
+ from . import training
+
+ __all__ = [
+     "apply_hmm",
+     "apply_hmm_batched",
+     "calculate_distances",
+     "compute_positionwise_statistic",
+     "calculate_row_entropy",
+     "calculate_umap",
+     "calculate_relative_risk_on_activity",
+     "call_hmm_peaks",
+     "cluster_adata_on_methylation",
+     "create_nan_mask_from_X",
+     "create_nan_or_non_gpc_mask",
+     "combine_layers",
+     "display_hmm",
+     "evaluate_models_by_subgroup",
+     "load_hmm",
+     "prepare_melted_model_data",
+     "refine_nucleosome_calls",
+     "infer_nucleosomes_in_large_bound",
+     "run_training_loop",
+     "run_inference",
+     "save_hmm",
+     "sliding_window_train_test",
+     "subset_adata",
+     "train_hmm"
+ ]
smftools/tools/apply_hmm.py ADDED
@@ -0,0 +1,202 @@
+ import numpy as np
+ import pandas as pd
+ import torch
+ from tqdm import tqdm
+
+ def apply_hmm(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=["GpC", "CpG", "A"], device="cpu", threshold=0.7):
+     """
+     Applies an HMM model to an AnnData object using tensor-based sequence inputs.
+     If multiple methbases are passed, a combined feature set is also generated.
+     """
+     model.to(device)
+
+     # --- Feature Definitions ---
+     feature_sets = {}
+     if footprints:
+         feature_sets["footprint"] = {
+             "features": {
+                 "small_bound_stretch": [0, 30],
+                 "medium_bound_stretch": [30, 80],
+                 "putative_nucleosome": [80, 200],
+                 "large_bound_stretch": [200, np.inf]
+             },
+             "state": "Non-Methylated"
+         }
+     if accessible_patches:
+         feature_sets["accessible"] = {
+             "features": {
+                 "small_accessible_patch": [0, 30],
+                 "mid_accessible_patch": [30, 80],
+                 "large_accessible_patch": [80, np.inf]
+             },
+             "state": "Methylated"
+         }
+     if cpg:
+         feature_sets["cpg"] = {
+             "features": {
+                 "cpg_patch": [0, np.inf]
+             },
+             "state": "Methylated"
+         }
+
+     # --- Init columns ---
+     all_features = []
+     combined_prefix = "Combined"
+     for key, fs in feature_sets.items():
+         if key == 'cpg':
+             all_features += [f"CpG_{f}" for f in fs["features"]]
+             all_features.append(f"CpG_all_{key}_features")
+         else:
+             for methbase in methbases:
+                 all_features += [f"{methbase}_{f}" for f in fs["features"]]
+                 all_features.append(f"{methbase}_all_{key}_features")
+             all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
+             all_features.append(f"{combined_prefix}_all_{key}_features")
+
+     for feature in all_features:
+         adata.obs[feature] = pd.Series([[] for _ in range(adata.shape[0])], dtype=object, index=adata.obs.index)
+         adata.obs[f"{feature}_distances"] = pd.Series([None] * adata.shape[0], index=adata.obs.index)
+         adata.obs[f"n_{feature}"] = -1
+
+     # --- Main loop ---
+     references = adata.obs[obs_column].cat.categories
+
+     for ref in tqdm(references, desc="Processing References"):
+         ref_subset = adata[adata.obs[obs_column] == ref]
+
+         # Process each methbase and build a combined mask along the way
+         combined_mask = None
+         for methbase in methbases:
+             mask = {
+                 "a": ref_subset.var[f"{ref}_strand_FASTA_base"] == "A",
+                 "gpc": ref_subset.var[f"{ref}_GpC_site"] == True,
+                 "cpg": ref_subset.var[f"{ref}_CpG_site"] == True
+             }[methbase.lower()]
+             combined_mask = mask if combined_mask is None else combined_mask | mask
+
+             methbase_subset = ref_subset[:, mask]
+             matrix = methbase_subset.layers[layer] if layer else methbase_subset.X
+
+             for i, raw_read in enumerate(matrix):
+                 # Randomly impute missing calls so the HMM sees a complete binary sequence
+                 read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
+                 tensor_read = torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)
+                 coords = methbase_subset.var_names
+
+                 for key, fs in feature_sets.items():
+                     if key == 'cpg':
+                         continue
+                     state_target = fs["state"]
+                     feature_map = fs["features"]
+
+                     classifications = classify_features(tensor_read, model, coords, feature_map, target_state=state_target)
+                     idx = methbase_subset.obs.index[i]
+
+                     for start, length, label, prob in classifications:
+                         adata.obs.at[idx, f"{methbase}_{label}"].append([start, length, prob])
+                         adata.obs.at[idx, f"{methbase}_all_{key}_features"].append([start, length, prob])
+
+         # Combined methbase subset
+         combined_subset = ref_subset[:, combined_mask]
+         combined_matrix = combined_subset.layers[layer] if layer else combined_subset.X
+
+         for i, raw_read in enumerate(combined_matrix):
+             read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
+             tensor_read = torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)
+             coords = combined_subset.var_names
+
+             for key, fs in feature_sets.items():
+                 if key == 'cpg':
+                     continue
+                 state_target = fs["state"]
+                 feature_map = fs["features"]
+
+                 classifications = classify_features(tensor_read, model, coords, feature_map, target_state=state_target)
+                 idx = combined_subset.obs.index[i]
+
+                 for start, length, label, prob in classifications:
+                     adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
+                     adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
+
+     # --- Special handling for CpG ---
+     if cpg:
+         for ref in tqdm(references, desc="Processing CpG"):
+             ref_subset = adata[adata.obs[obs_column] == ref]
+             mask = (ref_subset.var[f"{ref}_CpG_site"] == True)
+             cpg_subset = ref_subset[:, mask]
+             matrix = cpg_subset.layers[layer] if layer else cpg_subset.X
+
+             for i, raw_read in enumerate(matrix):
+                 read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
+                 tensor_read = torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)
+                 coords = cpg_subset.var_names
+                 fs = feature_sets['cpg']
+                 state_target = fs["state"]
+                 feature_map = fs["features"]
+                 classifications = classify_features(tensor_read, model, coords, feature_map, target_state=state_target)
+                 idx = cpg_subset.obs.index[i]
+                 for start, length, label, prob in classifications:
+                     adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
+                     adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
+
+     # --- Binarization + Distance ---
+     for feature in tqdm(all_features, desc="Finalizing Layers"):
+         bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+         counts = np.zeros(adata.shape[0], dtype=int)
+         for row_idx, intervals in enumerate(adata.obs[feature]):
+             if not isinstance(intervals, list):
+                 intervals = []
+             for start, length, prob in intervals:
+                 if prob > threshold:
+                     bin_matrix[row_idx, start:start+length] = 1
+                     counts[row_idx] += 1
+         adata.layers[f"{feature}"] = bin_matrix
+         adata.obs[f"n_{feature}"] = counts
+         adata.obs[f"{feature}_distances"] = adata.obs[feature].apply(lambda x: calculate_distances(x, threshold))
+
+ def calculate_distances(intervals, threshold=0.9):
+     """Calculates distances between consecutive features in a read."""
+     intervals = sorted([iv for iv in intervals if iv[2] > threshold], key=lambda x: x[0])
+     distances = [(intervals[i + 1][0] - (intervals[i][0] + intervals[i][1]))
+                  for i in range(len(intervals) - 1)]
+     return distances
+
+
+ def classify_features(sequence, model, coordinates, classification_mapping={}, target_state="Methylated"):
+     """
+     Classifies regions based on HMM state.
+
+     Parameters:
+         sequence (torch.Tensor): Tensor of binarized data [batch_size, seq_len, 1]
+         model: Trained pomegranate HMM
+         coordinates (list): Genomic coordinates for the sequence
+         classification_mapping (dict): Mapping for feature labeling
+         target_state (str): The state to classify ("Methylated" or "Non-Methylated")
+     """
+     predicted_states = model.predict(sequence).squeeze(-1).squeeze(0).cpu().numpy()
+     probabilities = model.predict_proba(sequence).squeeze(0).cpu().numpy()
+     state_labels = ["Non-Methylated", "Methylated"]
+
+     classifications, current_start, current_length, current_probs = [], None, 0, []
+
+     for i, state_index in enumerate(predicted_states):
+         state_name = state_labels[state_index]
+         state_prob = probabilities[i][state_index]
+
+         if state_name == target_state:
+             if current_start is None:
+                 current_start = i
+             current_length += 1
+             current_probs.append(state_prob)
+         elif current_start is not None:
+             classifications.append((current_start, current_length, np.mean(current_probs)))
+             current_start, current_length, current_probs = None, 0, []
+
+     if current_start is not None:
+         classifications.append((current_start, current_length, np.mean(current_probs)))
+
+     final = []
+     for start, length, prob in classifications:
+         feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+         label = next((ftype for ftype, rng in classification_mapping.items() if rng[0] <= feature_length < rng[1]), target_state)
+         final.append((int(coordinates[start]) + 1, feature_length, label, prob))
+     return final
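A hedged sketch of how apply_hmm might be invoked on a prepared AnnData object (the file names, obs column, and layer name below are assumptions for illustration; the model is a trained pomegranate-style HMM loaded via load_hmm from this subpackage):

    import anndata as ad
    from smftools.tools import apply_hmm, load_hmm

    adata = ad.read_h5ad("experiment.h5ad.gz")        # hypothetical preprocessed input
    hmm = load_hmm("footprint_hmm.pt")                # hypothetical saved-model path
    apply_hmm(
        adata, hmm,
        obs_column="Reference",                       # assumed categorical reference column
        layer="binarized_methylation",                # assumed binary methylation layer
        footprints=True, accessible_patches=True,
        methbases=["GpC"], device="cpu", threshold=0.7,
    )
    # Per-read feature intervals land in adata.obs (e.g. "GpC_putative_nucleosome"),
    # and a binarized per-feature mask is written to adata.layers.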
smftools/tools/apply_hmm_batched.py ADDED
@@ -0,0 +1,241 @@
+ import numpy as np
+ import pandas as pd
+ import torch
+ from tqdm import tqdm
+
+ def apply_hmm_batched(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=["GpC", "CpG", "A"], device="cpu", threshold=0.7):
+     """
+     Applies an HMM model to an AnnData object using batched tensor-based sequence inputs.
+     If multiple methbases are passed, a combined feature set is also generated.
+     """
+     model.to(device)
+
+     # --- Feature Definitions ---
+     feature_sets = {}
+     if footprints:
+         feature_sets["footprint"] = {
+             "features": {
+                 "small_bound_stretch": [0, 20],
+                 "medium_bound_stretch": [20, 50],
+                 "putative_nucleosome": [50, 200],
+                 "large_bound_stretch": [200, np.inf]
+             },
+             "state": "Non-Methylated"
+         }
+     if accessible_patches:
+         feature_sets["accessible"] = {
+             "features": {
+                 "small_accessible_patch": [0, 20],
+                 "mid_accessible_patch": [20, 80],
+                 "large_accessible_patch": [80, np.inf]
+             },
+             "state": "Methylated"
+         }
+     if cpg:
+         feature_sets["cpg"] = {
+             "features": {
+                 "cpg_patch": [0, np.inf]
+             },
+             "state": "Methylated"
+         }
+
+     # --- Init columns ---
+     all_features = []
+     combined_prefix = "Combined"
+     for key, fs in feature_sets.items():
+         if key == 'cpg':
+             all_features += [f"CpG_{f}" for f in fs["features"]]
+             all_features.append(f"CpG_all_{key}_features")
+         else:
+             for methbase in methbases:
+                 all_features += [f"{methbase}_{f}" for f in fs["features"]]
+                 all_features.append(f"{methbase}_all_{key}_features")
+             if len(methbases) > 1:
+                 all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
+                 all_features.append(f"{combined_prefix}_all_{key}_features")
+
+     for feature in all_features:
+         adata.obs[feature] = [[] for _ in range(adata.shape[0])]
+         adata.obs[f"{feature}_distances"] = [None] * adata.shape[0]
+         adata.obs[f"n_{feature}"] = -1
+
+     # --- Main loop ---
+     references = adata.obs[obs_column].cat.categories
+
+     for ref in tqdm(references, desc="Processing References"):
+         ref_subset = adata[adata.obs[obs_column] == ref]
+
+         # Process each methbase and build a combined mask along the way
+         combined_mask = None
+         for methbase in methbases:
+             mask = {
+                 "a": ref_subset.var[f"{ref}_strand_FASTA_base"] == "A",
+                 "gpc": ref_subset.var[f"{ref}_GpC_site"] == True,
+                 "cpg": ref_subset.var[f"{ref}_CpG_site"] == True
+             }[methbase.lower()]
+             combined_mask = mask if combined_mask is None else combined_mask | mask
+
+             methbase_subset = ref_subset[:, mask]
+             matrix = methbase_subset.layers[layer] if layer else methbase_subset.X
+
+             # Randomly impute missing calls so the HMM sees complete binary sequences
+             processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
+             tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+             coords = methbase_subset.var_names
+             for key, fs in feature_sets.items():
+                 if key == 'cpg':
+                     continue
+                 state_target = fs["state"]
+                 feature_map = fs["features"]
+
+                 pred_states = model.predict(tensor_batch)
+                 probs = model.predict_proba(tensor_batch)
+                 classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+                 for i, idx in enumerate(methbase_subset.obs.index):
+                     for start, length, label, prob in classifications[i]:
+                         adata.obs.at[idx, f"{methbase}_{label}"].append([start, length, prob])
+                         adata.obs.at[idx, f"{methbase}_all_{key}_features"].append([start, length, prob])
+
+         # Combined subset
+         if len(methbases) > 1:
+             combined_subset = ref_subset[:, combined_mask]
+             combined_matrix = combined_subset.layers[layer] if layer else combined_subset.X
+             processed_combined_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in combined_matrix]
+             tensor_combined_batch = torch.tensor(processed_combined_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+             coords = combined_subset.var_names
+             for key, fs in feature_sets.items():
+                 if key == 'cpg':
+                     continue
+                 state_target = fs["state"]
+                 feature_map = fs["features"]
+
+                 pred_states = model.predict(tensor_combined_batch)
+                 probs = model.predict_proba(tensor_combined_batch)
+                 classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+                 for i, idx in enumerate(combined_subset.obs.index):
+                     for start, length, label, prob in classifications[i]:
+                         adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
+                         adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
+
+     # --- Special handling for CpG ---
+     if cpg:
+         for ref in tqdm(references, desc="Processing CpG"):
+             ref_subset = adata[adata.obs[obs_column] == ref]
+             mask = (ref_subset.var[f"{ref}_CpG_site"] == True)
+             cpg_subset = ref_subset[:, mask]
+             matrix = cpg_subset.layers[layer] if layer else cpg_subset.X
+
+             processed_reads = [[int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in read] for read in matrix]
+             tensor_batch = torch.tensor(processed_reads, dtype=torch.long, device=device).unsqueeze(-1)
+
+             coords = cpg_subset.var_names
+             fs = feature_sets['cpg']
+             state_target = fs["state"]
+             feature_map = fs["features"]
+
+             pred_states = model.predict(tensor_batch)
+             probs = model.predict_proba(tensor_batch)
+             classifications = classify_batch(pred_states, probs, coords, feature_map, target_state=state_target)
+
+             for i, idx in enumerate(cpg_subset.obs.index):
+                 for start, length, label, prob in classifications[i]:
+                     adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
+                     adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
+
+     # --- Binarization + Distance ---
+     for feature in tqdm(all_features, desc="Finalizing Layers"):
+         bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+         counts = np.zeros(adata.shape[0], dtype=int)
+         for row_idx, intervals in enumerate(adata.obs[feature]):
+             if not isinstance(intervals, list):
+                 intervals = []
+             for start, length, prob in intervals:
+                 if prob > threshold:
+                     bin_matrix[row_idx, start:start+length] = 1
+                     counts[row_idx] += 1
+         adata.layers[f"{feature}"] = bin_matrix
+         adata.obs[f"n_{feature}"] = counts
+         adata.obs[f"{feature}_distances"] = calculate_batch_distances(adata.obs[feature].tolist(), threshold)
+
+ def calculate_batch_distances(intervals_list, threshold=0.9):
+     """
+     Vectorized calculation of distances across multiple reads.
+
+     Parameters:
+         intervals_list (list of list): Outer list = reads, inner list = intervals [start, length, prob]
+         threshold (float): Minimum probability threshold for filtering
+
+     Returns:
+         List of distance lists, one per read.
+     """
+     results = []
+     for intervals in intervals_list:
+         if not isinstance(intervals, list) or len(intervals) == 0:
+             results.append([])
+             continue
+         valid = [iv for iv in intervals if iv[2] > threshold]
+         valid = sorted(valid, key=lambda x: x[0])
+         dists = [(valid[i + 1][0] - (valid[i][0] + valid[i][1])) for i in range(len(valid) - 1)]
+         results.append(dists)
+     return results
+
+
+ def classify_batch(predicted_states_batch, probabilities_batch, coordinates, classification_mapping, target_state="Methylated"):
+     """
+     Classify batched sequences efficiently.
+
+     Parameters:
+         predicted_states_batch: Tensor [batch_size, seq_len]
+         probabilities_batch: Tensor [batch_size, seq_len, n_states]
+         coordinates: list of genomic coordinates
+         classification_mapping: dict of feature bins
+         target_state: state name ("Methylated" or "Non-Methylated")
+
+     Returns:
+         List of classifications for each sequence.
+     """
+     state_labels = ["Non-Methylated", "Methylated"]
+     target_idx = state_labels.index(target_state)
+     batch_size = predicted_states_batch.shape[0]
+
+     all_classifications = []
+
+     for b in range(batch_size):
+         predicted_states = predicted_states_batch[b].cpu().numpy()
+         probabilities = probabilities_batch[b].cpu().numpy()
+
+         regions = []
+         current_start, current_length, current_probs = None, 0, []
+
+         # Collect runs of the target state as (start index, run length, mean probability)
+         for i, state_index in enumerate(predicted_states):
+             state_prob = probabilities[i][state_index]
+             if state_index == target_idx:
+                 if current_start is None:
+                     current_start = i
+                 current_length += 1
+                 current_probs.append(state_prob)
+             elif current_start is not None:
+                 regions.append((current_start, current_length, np.mean(current_probs)))
+                 current_start, current_length, current_probs = None, 0, []
+
+         if current_start is not None:
+             regions.append((current_start, current_length, np.mean(current_probs)))
+
+         # Convert index runs into genomic intervals and bin them by length
+         final = []
+         for start, length, prob in regions:
+             feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+             label = next((ftype for ftype, rng in classification_mapping.items() if rng[0] <= feature_length < rng[1]), target_state)
+             final.append((int(coordinates[start]) + 1, feature_length, label, prob))
+         all_classifications.append(final)
+
+     return all_classifications
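To make the run-length segmentation and length-binning in classify_batch concrete, here is a small hand-constructed example (the state/probability tensors, coordinates, and bin edges are made up for illustration and are not taken from the package):

    import numpy as np
    import torch
    from smftools.tools.apply_hmm_batched import classify_batch

    pred = torch.tensor([[0, 0, 1, 1, 1, 0]])                 # one read, 6 positions; 1 = Methylated
    probs = torch.tensor([[[0.9, 0.1], [0.8, 0.2], [0.1, 0.9],
                           [0.2, 0.8], [0.1, 0.9], [0.7, 0.3]]])
    coords = ["100", "110", "120", "130", "140", "150"]       # var_names-style genomic positions
    feature_map = {"small_accessible_patch": [0, 25], "large_accessible_patch": [25, np.inf]}

    calls = classify_batch(pred, probs, coords, feature_map, target_state="Methylated")
    # The Methylated run spans coords 120-140, so feature_length = 140 - 120 + 1 = 21,
    # which falls in the [0, 25) bin: [(121, 21, 'small_accessible_patch', ~0.87)]
    print(calls)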