smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/recipes.py CHANGED
@@ -1,6 +1,6 @@
  # recipes

- def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory, mapping_key_column='Sample', reference_column = 'Reference', sample_names_col='Sample_names', invert=False):
+ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory, mapping_key_column='Sample', reference_column = 'Reference', sample_names_col='Sample_names', invert=True):
  """
  The first part of the preprocessing workflow applied to the smf.inform.pod_to_adata() output derived from Kissiov_and_McKenna_2025.

@@ -9,9 +9,10 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
  2) Appends a boolean indicating whether each position in var_names is within a given reference.
  3) Appends the cytosine context to each position from each reference.
  4) Calculate read level methylation statistics.
- 5) Optionally inverts the adata to flip the position coordinate orientation.
- 6) Calculates read length statistics (start position, end position, read length)
- 7) Returns a dictionary to pass the variable namespace to the parent scope.
+ 5) Calculates read length statistics (start position, end position, read length).
+ 6) Optionally inverts the adata to flip the position coordinate orientation.
+ 7) Adds new layers containing NaN replaced variants of adata.X (fill_closest, nan0_0minus1, nan1_12).
+ 8) Returns a dictionary to pass the variable namespace to the parent scope.

  Parameters:
  adata (AnnData): The AnnData object to use as input.
@@ -34,6 +35,7 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
  from .calculate_converted_read_methylation_stats import calculate_converted_read_methylation_stats
  from .invert_adata import invert_adata
  from .calculate_read_length_stats import calculate_read_length_stats
+ from .clean_NaN import clean_NaN

  # Clean up some of the Reference metadata and save variable names that point to sets of values in the column.
  adata.obs[reference_column] = adata.obs[reference_column].astype('category')
@@ -42,7 +44,7 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
  reference_mapping = {k: v for k, v in split_references}
  adata.obs[f'{reference_column}_short'] = adata.obs[reference_column].map(reference_mapping)
  short_references = set(adata.obs[f'{reference_column}_short'])
- binary_layers = adata.layers.keys()
+ binary_layers = list(adata.layers.keys())

  # load sample sheet metadata
  load_sample_sheet(adata, sample_sheet_path, mapping_key_column)
@@ -59,16 +61,19 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
  append_C_context(adata, obs_column=reference_column, use_consensus=False)

  # Calculate read level methylation statistics. Assess if GpC methylation level is above other_C methylation level as a QC.
- calculate_converted_read_methylation_stats(adata, reference_column, sample_names_col, output_directory, show_methylation_histogram=False, save_methylation_histogram=False)
+ calculate_converted_read_methylation_stats(adata, reference_column, sample_names_col)
+
+ # Calculate read length statistics
+ upper_bound, lower_bound = calculate_read_length_stats(adata, reference_column, sample_names_col)

  # Invert the adata object (ie flip the strand orientation for visualization)
  if invert:
- invert_adata(adata)
+ adata = invert_adata(adata)
  else:
  pass

- # Calculate read length statistics, with options to display or save the read length histograms
- upper_bound, lower_bound = calculate_read_length_stats(adata, reference_column, sample_names_col, output_directory, show_read_length_histogram=False, save_read_length_histogram=False)
+ # NaN replacement strategies stored in additional layers. Having layer=None uses adata.X
+ clean_NaN(adata, layer=None)

  variables = {
  "short_references": short_references,
@@ -80,22 +85,21 @@ def recipe_1_Kissiov_and_McKenna_2025(adata, sample_sheet_path, output_directory
  }
  return variables

- def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers, hamming_distance_thresholds={}, reference_column = 'Reference', sample_names_col='Sample_names'):
+ def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers, distance_thresholds={}, reference_column = 'Reference', sample_names_col='Sample_names'):
  """
  The second part of the preprocessing workflow applied to the adata that has already been preprocessed by recipe_1_Kissiov_and_McKenna_2025.

  Performs the following tasks:
- 1) Adds new layers containing NaN replaced variants of adata.X (fill_closest, nan0_0minus1, nan1_12).
- 2) Marks putative PCR duplicates using pairwise hamming distance metrics.
- 3) Performs a complexity analysis of the library based on the PCR duplicate detection rate.
- 4) Removes PCR duplicates from the adata.
- 5) Returns two adata object: one for the filtered adata and one for the duplicate adata.
+ 1) Marks putative PCR duplicates using pairwise hamming distance metrics.
+ 2) Performs a complexity analysis of the library based on the PCR duplicate detection rate.
+ 3) Removes PCR duplicates from the adata.
+ 4) Returns two adata object: one for the filtered adata and one for the duplicate adata.

  Parameters:
  adata (AnnData): The AnnData object to use as input.
  output_directory (str): String representing the path to the output directory for plots.
  binary_layers (list): A list of layers to used for the binary encoding of read sequences. Used for duplicate detection.
- hamming_distance_thresholds (dict): A dictionary keyed by obs_column categories that points to a float corresponding to the distance threshold to apply. Default is an empty dict.
+ distance_thresholds (dict): A dictionary keyed by obs_column categories that points to a float corresponding to the distance threshold to apply. Default is an empty dict.
  reference_column (str): The name of the reference column to use.
  sample_names_col (str): The name of the sample name column to use.

@@ -106,16 +110,14 @@ def recipe_2_Kissiov_and_McKenna_2025(adata, output_directory, binary_layers, ha
  import anndata as ad
  import pandas as pd
  import numpy as np
- from .clean_NaN import clean_NaN
  from .mark_duplicates import mark_duplicates
  from .calculate_complexity import calculate_complexity
  from .remove_duplicates import remove_duplicates

- # NaN replacement strategies stored in additional layers. Having layer=None uses adata.X
- clean_NaN(adata, layer=None)
+ # Add here a way to remove reads below a given read quality (based on nan content). Need to also add a way to pull from BAM files the read quality from each read

  # Duplicate detection using pairwise hamming distance across reads
- mark_duplicates(adata, binary_layers, obs_column=reference_column, sample_col=sample_names_col, hamming_distance_thresholds=hamming_distance_thresholds)
+ mark_duplicates(adata, binary_layers, obs_column=reference_column, sample_col=sample_names_col, distance_thresholds=distance_thresholds, method='N_masked_distances')

  # Complexity analysis using the marked duplicates and the lander-watermann algorithm
  calculate_complexity(adata, output_directory, obs_column=reference_column, sample_col=sample_names_col, plot=True, save_plot=False)
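For orientation, a minimal sketch of how the two updated recipes might be chained. The recipes mutate `adata` in place and are defined in smftools/preprocessing/recipes.py; `adata` is assumed to be an AnnData object from the smftools loading step, and the sample sheet path and output directory are placeholders:

    from smftools.preprocessing.recipes import (
        recipe_1_Kissiov_and_McKenna_2025,
        recipe_2_Kissiov_and_McKenna_2025,
    )

    # Recipe 1: sample-sheet metadata, context annotation, methylation and read-length QC,
    # optional inversion (now defaulting to True), and NaN-cleaned layers derived from adata.X.
    variables = recipe_1_Kissiov_and_McKenna_2025(
        adata,
        sample_sheet_path="sample_sheet.csv",  # placeholder
        output_directory="qc_output/",         # placeholder
    )

    # Recipe 2: duplicate marking (note the renamed distance_thresholds argument), complexity
    # analysis, and duplicate removal; per its docstring it returns a filtered and a duplicate AnnData.
    adata_filtered, adata_duplicates = recipe_2_Kissiov_and_McKenna_2025(
        adata,
        output_directory="qc_output/",
        binary_layers=list(adata.layers.keys()),  # layers used for duplicate detection
        distance_thresholds={},
    )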
smftools/preprocessing/subsample_adata.py ADDED
@@ -0,0 +1,58 @@
+ def subsample_adata(adata, obs_columns=None, max_samples=2000, random_seed=42):
+ """
+ Subsamples an AnnData object so that each unique combination of categories
+ in the given `obs_columns` has at most `max_samples` observations.
+ If `obs_columns` is None or empty, the function randomly subsamples the entire dataset.
+
+ Parameters:
+ adata (AnnData): The AnnData object to subsample.
+ obs_columns (list of str, optional): List of observation column names to group by.
+ max_samples (int): The maximum number of observations per category combination.
+ random_seed (int): Random seed for reproducibility.
+
+ Returns:
+ AnnData: A new AnnData object with subsampled observations.
+ """
+ import anndata as ad
+ import numpy as np
+
+ np.random.seed(random_seed) # Ensure reproducibility
+
+ if not obs_columns: # If no obs columns are given, sample globally
+ if adata.shape[0] > max_samples:
+ sampled_indices = np.random.choice(adata.obs.index, max_samples, replace=False)
+ else:
+ sampled_indices = adata.obs.index # Keep all if fewer than max_samples
+
+ return adata[sampled_indices].copy()
+
+ sampled_indices = []
+
+ # Get unique category combinations from all specified obs columns
+ unique_combinations = adata.obs[obs_columns].drop_duplicates()
+
+ for _, row in unique_combinations.iterrows():
+ # Build filter condition dynamically for multiple columns
+ condition = (adata.obs[obs_columns] == row.values).all(axis=1)
+
+ # Get indices for the current category combination
+ subset_indices = adata.obs[condition].index.to_numpy()
+
+ # Subsample or take all
+ if len(subset_indices) > max_samples:
+ sampled = np.random.choice(subset_indices, max_samples, replace=False)
+ else:
+ sampled = subset_indices # Keep all if fewer than max_samples
+
+ sampled_indices.extend(sampled)
+
+ # ⚠ Handle backed mode detection
+ if adata.isbacked:
+ print("⚠ Detected backed mode. Subset will be loaded fully into memory.")
+ subset = adata[sampled_indices]
+ subset = subset.to_memory()
+ else:
+ subset = adata[sampled_indices]
+
+ # Create a new AnnData object with only the selected indices
+ return subset[sampled_indices].copy()
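As a quick illustration of the grouping behavior of subsample_adata, here is a self-contained sketch using a synthetic AnnData object; the import path assumes the module is reachable as smftools.preprocessing.subsample_adata:

    import anndata as ad
    import numpy as np
    import pandas as pd
    from smftools.preprocessing.subsample_adata import subsample_adata  # assumed module path

    rng = np.random.default_rng(0)
    n = 500
    obs = pd.DataFrame({
        "Reference": rng.choice(["amplicon_A", "amplicon_B"], size=n),
        "Sample_names": rng.choice(["s1", "s2"], size=n),
    }, index=[f"read_{i}" for i in range(n)])
    adata = ad.AnnData(X=rng.random((n, 10)), obs=obs)

    # Cap each (Reference, Sample_names) combination at 50 reads; smaller groups are kept whole.
    small = subsample_adata(adata, obs_columns=["Reference", "Sample_names"], max_samples=50)
    print(small.obs.groupby(["Reference", "Sample_names"]).size())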
smftools/readwrite.py CHANGED
@@ -23,27 +23,46 @@ def time_string():
  ######################################################################################################
  ## Numpy, Pandas, Anndata functionality
+
  def adata_to_df(adata, layer=None):
  """
- Input: An adata object with a specified layer.
- Output: A dataframe for the specific layer.
+ Convert an AnnData object into a Pandas DataFrame.
+
+ Parameters:
+ adata (AnnData): The input AnnData object.
+ layer (str, optional): The layer to extract. If None, uses adata.X.
+
+ Returns:
+ pd.DataFrame: A DataFrame where rows are observations and columns are positions.
  """
  import pandas as pd
  import anndata as ad
+ import numpy as np
+
+ # Validate that the requested layer exists
+ if layer and layer not in adata.layers:
+ raise ValueError(f"Layer '{layer}' not found in adata.layers.")
+
+ # Extract the data matrix
+ data_matrix = adata.layers.get(layer, adata.X)
+
+ # Ensure matrix is dense (handle sparse formats)
+ if hasattr(data_matrix, "toarray"):
+ data_matrix = data_matrix.toarray()
+
+ # Ensure obs and var have unique indices
+ if adata.obs.index.duplicated().any():
+ raise ValueError("Duplicate values found in `adata.obs.index`. Ensure unique observation indices.")
+
+ if adata.var.index.duplicated().any():
+ raise ValueError("Duplicate values found in `adata.var.index`. Ensure unique variable indices.")
+
+ # Convert to DataFrame
+ df = pd.DataFrame(data_matrix, index=adata.obs.index, columns=adata.var.index)

- # Extract the data matrix from the given layer
- if layer:
- data_matrix = adata.layers[layer]
- else:
- data_matrix = adata.X
- # Extract observation (read) annotations
- obs_df = adata.obs
- # Extract variable (position) annotations
- var_df = adata.var
- # Convert data matrix and annotations to pandas DataFrames
- df = pd.DataFrame(data_matrix, index=obs_df.index, columns=var_df.index)
  return df

+
  def save_matrix(matrix, save_name):
  """
  Input: A numpy matrix and a save_name
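The rewritten adata_to_df densifies sparse matrices and validates the layer name and index uniqueness before building the DataFrame. A small illustrative check (scipy is only needed for this sketch; the import of adata_to_df from smftools.readwrite matches the file shown above):

    import anndata as ad
    import numpy as np
    from scipy import sparse
    from smftools.readwrite import adata_to_df

    adata = ad.AnnData(X=sparse.random(4, 6, density=0.5, format="csr"))
    adata.layers["binary"] = (adata.X > 0).astype(int)

    df = adata_to_df(adata, layer="binary")   # dense DataFrame, observations x positions
    print(df.shape)                           # (4, 6)
    # adata_to_df(adata, layer="missing") now raises a ValueError instead of a bare KeyError.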
@@ -103,4 +122,77 @@ def concatenate_h5ads(output_file, file_suffix='h5ad.gz', delete_inputs=True):
  print(f"Error deleting file {hdf}: {e}")
  else:
  print('Keeping input files')
+
+ def safe_write_h5ad(adata, path, compression="gzip", backup=False, backup_dir="./"):
+ """
+ Saves an AnnData object safely by omitting problematic columns from .obs and .var.
+
+ Parameters:
+ adata (AnnData): The AnnData object to save.
+ path (str): Output .h5ad file path.
+ compression (str): Compression method for h5ad file.
+ backup (bool): If True, saves problematic columns to CSV files.
+ backup_dir (str): Directory to store backups if backup=True.
+ """
+ import anndata as ad
+ import pandas as pd
+ import os
+
+ os.makedirs(backup_dir, exist_ok=True)
+
+ def filter_df(df, df_name):
+ bad_cols = []
+ for col in df.columns:
+ if df[col].dtype == 'object':
+ if not df[col].apply(lambda x: isinstance(x, (str, type(None)))).all():
+ bad_cols.append(col)
+ if bad_cols:
+ print(f"⚠️ Skipping columns from {df_name}: {bad_cols}")
+ if backup:
+ df[bad_cols].to_csv(os.path.join(backup_dir, f"{df_name}_skipped_columns.csv"))
+ print(f"📝 Backed up skipped columns to {backup_dir}/{df_name}_skipped_columns.csv")
+ return df.drop(columns=bad_cols)
+
+ # Clean obs and var
+ obs_clean = filter_df(adata.obs, "obs")
+ var_clean = filter_df(adata.var, "var")
+
+ # Save clean version
+ adata_copy = ad.AnnData(
+ X=adata.X,
+ obs=obs_clean,
+ var=var_clean,
+ layers=adata.layers,
+ uns=adata.uns,
+ obsm=adata.obsm,
+ varm=adata.varm
+ )
+ adata_copy.write_h5ad(path, compression=compression)
+ print(f"✅ Saved safely to {path}")
+
+ def merge_barcoded_anndatas(adata_single, adata_double):
+ import numpy as np
+ import anndata as ad
+
+ # Step 1: Identify overlap
+ overlap = np.intersect1d(adata_single.obs_names, adata_double.obs_names)
+
+ # Step 2: Filter out overlaps from adata_single
+ adata_single_filtered = adata_single[~adata_single.obs_names.isin(overlap)].copy()
+
+ # Step 3: Add source tag
+ adata_single_filtered.obs['source'] = 'single_barcode'
+ adata_double.obs['source'] = 'double_barcode'
+
+ # Step 4: Concatenate all components
+ adata_merged = ad.concat([
+ adata_single_filtered,
+ adata_double
+ ], join='outer', merge='same') # merge='same' preserves matching layers, obsm, etc.
+
+ # Step 5: Merge `.uns`
+ adata_merged.uns = {**adata_single.uns, **adata_double.uns}
+
+ return adata_merged
+
  ######################################################################################################
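A hedged sketch of how the two new helpers might be used together, assuming they are exposed at module level in smftools.readwrite as shown above; adata_single and adata_double stand for AnnData objects from single- and double-barcode demultiplexing, and paths are placeholders:

    from smftools.readwrite import merge_barcoded_anndatas, safe_write_h5ad

    # Reads already present in the double-barcode object are dropped from the single-barcode one,
    # each read is tagged with a 'source' column, and the two objects are concatenated.
    merged = merge_barcoded_anndatas(adata_single, adata_double)

    # Writes an .h5ad after dropping object-dtype obs/var columns holding non-string values
    # (which commonly break h5ad serialization), optionally backing those columns up as CSVs.
    safe_write_h5ad(merged, "merged.h5ad.gz", compression="gzip", backup=True, backup_dir="./h5ad_backups")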
smftools/tools/__init__.py ADDED
@@ -0,0 +1,49 @@
+ from .apply_hmm import apply_hmm
+ from .apply_hmm_batched import apply_hmm_batched
+ from .position_stats import calculate_relative_risk_on_activity, compute_positionwise_statistic
+ from .calculate_distances import calculate_distances
+ from .calculate_umap import calculate_umap
+ from .call_hmm_peaks import call_hmm_peaks
+ from .classifiers import run_training_loop, run_inference, evaluate_models_by_subgroup, prepare_melted_model_data, sliding_window_train_test
+ from .cluster_adata_on_methylation import cluster_adata_on_methylation
+ from .display_hmm import display_hmm
+ from .general_tools import create_nan_mask_from_X, combine_layers, create_nan_or_non_gpc_mask
+ from .hmm_readwrite import load_hmm, save_hmm
+ from .nucleosome_hmm_refinement import refine_nucleosome_calls, infer_nucleosomes_in_large_bound
+ from .read_stats import calculate_row_entropy
+ from .subset_adata import subset_adata
+ from .train_hmm import train_hmm
+
+ from . import models
+ from . import data
+ from . import utils
+ from . import evaluation
+ from . import inference
+ from . import training
+
+ __all__ = [
+ "apply_hmm",
+ "apply_hmm_batched",
+ "calculate_distances",
+ "compute_positionwise_statistic",
+ "calculate_row_entropy",
+ "calculate_umap",
+ "calculate_relative_risk_on_activity",
+ "call_hmm_peaks",
+ "cluster_adata_on_methylation",
+ "create_nan_mask_from_X",
+ "create_nan_or_non_gpc_mask",
+ "combine_layers",
+ "display_hmm",
+ "evaluate_models_by_subgroup",
+ "load_hmm",
+ "prepare_melted_model_data",
+ "refine_nucleosome_calls",
+ "infer_nucleosomes_in_large_bound",
+ "run_training_loop",
+ "run_inference",
+ "save_hmm",
+ "sliding_window_train_test"
+ "subset_adata",
+ "train_hmm"
+ ]
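Note that in __all__ as written, "sliding_window_train_test" is not followed by a comma, so implicit string concatenation merges it with "subset_adata" into a single nonexistent entry; plain attribute access on the package, sketched below, is unaffected:

    import smftools.tools as tl

    tl.apply_hmm        # re-exported from smftools.tools.apply_hmm
    tl.train_hmm        # re-exported from smftools.tools.train_hmm
    tl.models           # submodule exposed alongside the functions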
smftools/tools/apply_hmm.py ADDED
@@ -0,0 +1,202 @@
+ import numpy as np
+ import pandas as pd
+ import torch
+ from tqdm import tqdm
+
+ def apply_hmm(adata, model, obs_column, layer=None, footprints=True, accessible_patches=False, cpg=False, methbases=["GpC", "CpG", "A"], device="cpu", threshold=0.7):
+ """
+ Applies an HMM model to an AnnData object using tensor-based sequence inputs.
+ If multiple methbases are passed, generates a combined feature set.
+ """
+ model.to(device)
+
+ # --- Feature Definitions ---
+ feature_sets = {}
+ if footprints:
+ feature_sets["footprint"] = {
+ "features": {
+ "small_bound_stretch": [0, 30],
+ "medium_bound_stretch": [30, 80],
+ "putative_nucleosome": [80, 200],
+ "large_bound_stretch": [200, np.inf]
+ },
+ "state": "Non-Methylated"
+ }
+ if accessible_patches:
+ feature_sets["accessible"] = {
+ "features": {
+ "small_accessible_patch": [0, 30],
+ "mid_accessible_patch": [30, 80],
+ "large_accessible_patch": [80, np.inf]
+ },
+ "state": "Methylated"
+ }
+ if cpg:
+ feature_sets["cpg"] = {
+ "features": {
+ "cpg_patch": [0, np.inf]
+ },
+ "state": "Methylated"
+ }
+
+ # --- Init columns ---
+ all_features = []
+ combined_prefix = "Combined"
+ for key, fs in feature_sets.items():
+ if key == 'cpg':
+ all_features += [f"CpG_{f}" for f in fs["features"]]
+ all_features.append(f"CpG_all_{key}_features")
+ else:
+ for methbase in methbases:
+ all_features += [f"{methbase}_{f}" for f in fs["features"]]
+ all_features.append(f"{methbase}_all_{key}_features")
+ all_features += [f"{combined_prefix}_{f}" for f in fs["features"]]
+ all_features.append(f"{combined_prefix}_all_{key}_features")
+
+ for feature in all_features:
+ adata.obs[feature] = pd.Series([[] for _ in range(adata.shape[0])], dtype=object, index=adata.obs.index)
+ adata.obs[f"{feature}_distances"] = pd.Series([None] * adata.shape[0])
+ adata.obs[f"n_{feature}"] = -1
+
+ # --- Main loop ---
+ references = adata.obs[obs_column].cat.categories
+
+ for ref in tqdm(references, desc="Processing References"):
+ ref_subset = adata[adata.obs[obs_column] == ref]
+
+ # Create combined mask for methbases
+ combined_mask = None
+ for methbase in methbases:
+ mask = {
+ "a": ref_subset.var[f"{ref}_strand_FASTA_base"] == "A",
+ "gpc": ref_subset.var[f"{ref}_GpC_site"] == True,
+ "cpg": ref_subset.var[f"{ref}_CpG_site"] == True
+ }[methbase.lower()]
+ combined_mask = mask if combined_mask is None else combined_mask | mask
+
+ methbase_subset = ref_subset[:, mask]
+ matrix = methbase_subset.layers[layer] if layer else methbase_subset.X
+
+ for i, raw_read in enumerate(matrix):
+ read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
+ tensor_read = torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)
+ coords = methbase_subset.var_names
+
+ for key, fs in feature_sets.items():
+ if key == 'cpg':
+ continue
+ state_target = fs["state"]
+ feature_map = fs["features"]
+
+ classifications = classify_features(tensor_read, model, coords, feature_map, target_state=state_target)
+ idx = methbase_subset.obs.index[i]
+
+ for start, length, label, prob in classifications:
+ adata.obs.at[idx, f"{methbase}_{label}"].append([start, length, prob])
+ adata.obs.at[idx, f"{methbase}_all_{key}_features"].append([start, length, prob])
+
+ # Combined methbase subset
+ combined_subset = ref_subset[:, combined_mask]
+ combined_matrix = combined_subset.layers[layer] if layer else combined_subset.X
+
+ for i, raw_read in enumerate(combined_matrix):
+ read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
+ tensor_read = torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)
+ coords = combined_subset.var_names
+
+ for key, fs in feature_sets.items():
+ if key == 'cpg':
+ continue
+ state_target = fs["state"]
+ feature_map = fs["features"]
+
+ classifications = classify_features(tensor_read, model, coords, feature_map, target_state=state_target)
+ idx = combined_subset.obs.index[i]
+
+ for start, length, label, prob in classifications:
+ adata.obs.at[idx, f"{combined_prefix}_{label}"].append([start, length, prob])
+ adata.obs.at[idx, f"{combined_prefix}_all_{key}_features"].append([start, length, prob])
+
+ # --- Special handling for CpG ---
+ if cpg:
+ for ref in tqdm(references, desc="Processing CpG"):
+ ref_subset = adata[adata.obs[obs_column] == ref]
+ mask = (ref_subset.var[f"{ref}_CpG_site"] == True)
+ cpg_subset = ref_subset[:, mask]
+ matrix = cpg_subset.layers[layer] if layer else cpg_subset.X
+
+ for i, raw_read in enumerate(matrix):
+ read = [int(x) if not np.isnan(x) else np.random.choice([0, 1]) for x in raw_read]
+ tensor_read = torch.tensor(read, dtype=torch.long, device=device).unsqueeze(0).unsqueeze(-1)
+ coords = cpg_subset.var_names
+ fs = feature_sets['cpg']
+ state_target = fs["state"]
+ feature_map = fs["features"]
+ classifications = classify_features(tensor_read, model, coords, feature_map, target_state=state_target)
+ idx = cpg_subset.obs.index[i]
+ for start, length, label, prob in classifications:
+ adata.obs.at[idx, f"CpG_{label}"].append([start, length, prob])
+ adata.obs.at[idx, f"CpG_all_cpg_features"].append([start, length, prob])
+
+ # --- Binarization + Distance ---
+ for feature in tqdm(all_features, desc="Finalizing Layers"):
+ bin_matrix = np.zeros((adata.shape[0], adata.shape[1]), dtype=int)
+ counts = np.zeros(adata.shape[0], dtype=int)
+ for row_idx, intervals in enumerate(adata.obs[feature]):
+ if not isinstance(intervals, list):
+ intervals = []
+ for start, length, prob in intervals:
+ if prob > threshold:
+ bin_matrix[row_idx, start:start+length] = 1
+ counts[row_idx] += 1
+ adata.layers[f"{feature}"] = bin_matrix
+ adata.obs[f"n_{feature}"] = counts
+ adata.obs[f"{feature}_distances"] = adata.obs[feature].apply(lambda x: calculate_distances(x, threshold))
+
+ def calculate_distances(intervals, threshold=0.9):
+ """Calculates distances between consecutive features in a read."""
+ intervals = sorted([iv for iv in intervals if iv[2] > threshold], key=lambda x: x[0])
+ distances = [(intervals[i + 1][0] - (intervals[i][0] + intervals[i][1]))
+ for i in range(len(intervals) - 1)]
+ return distances
+
+
+ def classify_features(sequence, model, coordinates, classification_mapping={}, target_state="Methylated"):
+ """
+ Classifies regions based on HMM state.
+
+ Parameters:
+ sequence (torch.Tensor): Tensor of binarized data [batch_size, seq_len, 1]
+ model: Trained pomegranate HMM
+ coordinates (list): Genomic coordinates for sequence
+ classification_mapping (dict): Mapping for feature labeling
+ target_state (str): The state to classify ("Methylated" or "Non-Methylated")
+ """
+ predicted_states = model.predict(sequence).squeeze(-1).squeeze(0).cpu().numpy()
+ probabilities = model.predict_proba(sequence).squeeze(0).cpu().numpy()
+ state_labels = ["Non-Methylated", "Methylated"]
+
+ classifications, current_start, current_length, current_probs = [], None, 0, []
+
+ for i, state_index in enumerate(predicted_states):
+ state_name = state_labels[state_index]
+ state_prob = probabilities[i][state_index]
+
+ if state_name == target_state:
+ if current_start is None:
+ current_start = i
+ current_length += 1
+ current_probs.append(state_prob)
+ elif current_start is not None:
+ classifications.append((current_start, current_length, avg := np.mean(current_probs)))
+ current_start, current_length, current_probs = None, 0, []
+
+ if current_start is not None:
+ classifications.append((current_start, current_length, avg := np.mean(current_probs)))
+
+ final = []
+ for start, length, prob in classifications:
+ feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+ label = next((ftype for ftype, rng in classification_mapping.items() if rng[0] <= feature_length < rng[1]), target_state)
+ final.append((int(coordinates[start]) + 1, feature_length, label, prob))
+ return final
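To make the segmentation step concrete, here is a small self-contained sketch of what classify_features does: runs of the target HMM state are collapsed into (start, length, mean probability) intervals and then labeled by genomic span against the footprint bins defined in apply_hmm. The HMM is replaced by a hard-coded state path, so this is illustrative only and does not call the package:

    import numpy as np

    predicted_states = np.array([0, 0, 1, 1, 1, 0, 0, 1, 0])   # 0 = Non-Methylated, 1 = Methylated
    probabilities = np.array([[0.9, 0.1]] * 9)                  # per-position state probabilities
    coordinates = [str(c) for c in range(100, 1000, 100)]       # genomic coordinates of the 9 sites
    footprint_bins = {"small_bound_stretch": [0, 30], "medium_bound_stretch": [30, 80],
                      "putative_nucleosome": [80, 200], "large_bound_stretch": [200, np.inf]}

    # Collapse consecutive Non-Methylated positions into (start index, run length, mean probability).
    runs, start = [], None
    for i, s in enumerate(list(predicted_states) + [1]):        # sentinel 1 closes a trailing run
        if s == 0 and start is None:
            start = i
        elif s != 0 and start is not None:
            runs.append((start, i - start, float(probabilities[start:i, 0].mean())))
            start = None

    # Label each run by its genomic span, mirroring the classification_mapping lookup.
    for start, length, prob in runs:
        span = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
        label = next(name for name, (lo, hi) in footprint_bins.items() if lo <= span < hi)
        print(start, span, label, round(prob, 2))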