smftools 0.1.6__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. smftools/__init__.py +34 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/cli.py +184 -0
  5. smftools/config/__init__.py +1 -0
  6. smftools/config/conversion.yaml +33 -0
  7. smftools/config/deaminase.yaml +56 -0
  8. smftools/config/default.yaml +253 -0
  9. smftools/config/direct.yaml +17 -0
  10. smftools/config/experiment_config.py +1191 -0
  11. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  12. smftools/datasets/F1_sample_sheet.csv +5 -0
  13. smftools/datasets/__init__.py +9 -0
  14. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  15. smftools/datasets/datasets.py +28 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/hmm/apply_hmm_batched.py +242 -0
  19. smftools/hmm/calculate_distances.py +18 -0
  20. smftools/hmm/call_hmm_peaks.py +106 -0
  21. smftools/hmm/display_hmm.py +18 -0
  22. smftools/hmm/hmm_readwrite.py +16 -0
  23. smftools/hmm/nucleosome_hmm_refinement.py +104 -0
  24. smftools/hmm/train_hmm.py +78 -0
  25. smftools/informatics/__init__.py +14 -0
  26. smftools/informatics/archived/bam_conversion.py +59 -0
  27. smftools/informatics/archived/bam_direct.py +63 -0
  28. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  29. smftools/informatics/archived/conversion_smf.py +132 -0
  30. smftools/informatics/archived/deaminase_smf.py +132 -0
  31. smftools/informatics/archived/direct_smf.py +137 -0
  32. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  33. smftools/informatics/basecall_pod5s.py +80 -0
  34. smftools/informatics/fast5_to_pod5.py +24 -0
  35. smftools/informatics/helpers/__init__.py +73 -0
  36. smftools/informatics/helpers/align_and_sort_BAM.py +86 -0
  37. smftools/informatics/helpers/aligned_BAM_to_bed.py +85 -0
  38. smftools/informatics/helpers/archived/informatics.py +260 -0
  39. smftools/informatics/helpers/archived/load_adata.py +516 -0
  40. smftools/informatics/helpers/bam_qc.py +66 -0
  41. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  42. smftools/informatics/helpers/binarize_converted_base_identities.py +172 -0
  43. smftools/informatics/helpers/canoncall.py +34 -0
  44. smftools/informatics/helpers/complement_base_list.py +21 -0
  45. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +378 -0
  46. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  47. smftools/informatics/helpers/converted_BAM_to_adata_II.py +505 -0
  48. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  49. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  50. smftools/informatics/helpers/discover_input_files.py +100 -0
  51. smftools/informatics/helpers/extract_base_identities.py +70 -0
  52. smftools/informatics/helpers/extract_mods.py +83 -0
  53. smftools/informatics/helpers/extract_read_features_from_bam.py +33 -0
  54. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  55. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  56. smftools/informatics/helpers/find_conversion_sites.py +51 -0
  57. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  58. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  59. smftools/informatics/helpers/get_native_references.py +28 -0
  60. smftools/informatics/helpers/index_fasta.py +12 -0
  61. smftools/informatics/helpers/make_dirs.py +21 -0
  62. smftools/informatics/helpers/make_modbed.py +27 -0
  63. smftools/informatics/helpers/modQC.py +27 -0
  64. smftools/informatics/helpers/modcall.py +36 -0
  65. smftools/informatics/helpers/modkit_extract_to_adata.py +887 -0
  66. smftools/informatics/helpers/ohe_batching.py +76 -0
  67. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  68. smftools/informatics/helpers/one_hot_decode.py +27 -0
  69. smftools/informatics/helpers/one_hot_encode.py +57 -0
  70. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  71. smftools/informatics/helpers/run_multiqc.py +28 -0
  72. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  73. smftools/informatics/helpers/split_and_index_BAM.py +32 -0
  74. smftools/informatics/readwrite.py +106 -0
  75. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  76. smftools/informatics/subsample_pod5.py +104 -0
  77. smftools/load_adata.py +1346 -0
  78. smftools/machine_learning/__init__.py +12 -0
  79. smftools/machine_learning/data/__init__.py +2 -0
  80. smftools/machine_learning/data/anndata_data_module.py +234 -0
  81. smftools/machine_learning/data/preprocessing.py +6 -0
  82. smftools/machine_learning/evaluation/__init__.py +2 -0
  83. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  84. smftools/machine_learning/evaluation/evaluators.py +223 -0
  85. smftools/machine_learning/inference/__init__.py +3 -0
  86. smftools/machine_learning/inference/inference_utils.py +27 -0
  87. smftools/machine_learning/inference/lightning_inference.py +68 -0
  88. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  89. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  90. smftools/machine_learning/models/__init__.py +9 -0
  91. smftools/machine_learning/models/base.py +295 -0
  92. smftools/machine_learning/models/cnn.py +138 -0
  93. smftools/machine_learning/models/lightning_base.py +345 -0
  94. smftools/machine_learning/models/mlp.py +26 -0
  95. smftools/machine_learning/models/positional.py +18 -0
  96. smftools/machine_learning/models/rnn.py +17 -0
  97. smftools/machine_learning/models/sklearn_models.py +273 -0
  98. smftools/machine_learning/models/transformer.py +303 -0
  99. smftools/machine_learning/models/wrappers.py +20 -0
  100. smftools/machine_learning/training/__init__.py +2 -0
  101. smftools/machine_learning/training/train_lightning_model.py +135 -0
  102. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  103. smftools/machine_learning/utils/__init__.py +2 -0
  104. smftools/machine_learning/utils/device.py +10 -0
  105. smftools/machine_learning/utils/grl.py +14 -0
  106. smftools/plotting/__init__.py +18 -0
  107. smftools/plotting/autocorrelation_plotting.py +611 -0
  108. smftools/plotting/classifiers.py +355 -0
  109. smftools/plotting/general_plotting.py +682 -0
  110. smftools/plotting/hmm_plotting.py +260 -0
  111. smftools/plotting/position_stats.py +462 -0
  112. smftools/plotting/qc_plotting.py +270 -0
  113. smftools/preprocessing/__init__.py +38 -0
  114. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  115. smftools/preprocessing/append_base_context.py +122 -0
  116. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  117. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  118. smftools/preprocessing/archives/preprocessing.py +614 -0
  119. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  120. smftools/preprocessing/binarize_on_Youden.py +45 -0
  121. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  122. smftools/preprocessing/calculate_complexity.py +72 -0
  123. smftools/preprocessing/calculate_complexity_II.py +248 -0
  124. smftools/preprocessing/calculate_consensus.py +47 -0
  125. smftools/preprocessing/calculate_coverage.py +51 -0
  126. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  127. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  128. smftools/preprocessing/calculate_position_Youden.py +115 -0
  129. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  130. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  131. smftools/preprocessing/clean_NaN.py +62 -0
  132. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  133. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  134. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  135. smftools/preprocessing/flag_duplicate_reads.py +1351 -0
  136. smftools/preprocessing/invert_adata.py +37 -0
  137. smftools/preprocessing/load_sample_sheet.py +53 -0
  138. smftools/preprocessing/make_dirs.py +21 -0
  139. smftools/preprocessing/min_non_diagonal.py +25 -0
  140. smftools/preprocessing/recipes.py +127 -0
  141. smftools/preprocessing/subsample_adata.py +58 -0
  142. smftools/readwrite.py +1004 -0
  143. smftools/tools/__init__.py +20 -0
  144. smftools/tools/archived/apply_hmm.py +202 -0
  145. smftools/tools/archived/classifiers.py +787 -0
  146. smftools/tools/archived/classify_methylated_features.py +66 -0
  147. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  148. smftools/tools/archived/subset_adata_v1.py +32 -0
  149. smftools/tools/archived/subset_adata_v2.py +46 -0
  150. smftools/tools/calculate_umap.py +62 -0
  151. smftools/tools/cluster_adata_on_methylation.py +105 -0
  152. smftools/tools/general_tools.py +69 -0
  153. smftools/tools/position_stats.py +601 -0
  154. smftools/tools/read_stats.py +184 -0
  155. smftools/tools/spatial_autocorrelation.py +562 -0
  156. smftools/tools/subset_adata.py +28 -0
  157. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/METADATA +9 -2
  158. smftools-0.2.1.dist-info/RECORD +161 -0
  159. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  160. smftools-0.1.6.dist-info/RECORD +0 -4
  161. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  162. {smftools-0.1.6.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,64 @@
+ # classify_methylated_features
+
+ def classify_methylated_features(read, model, coordinates, classification_mapping={}):
+     """
+     Classifies methylated features (accessible features or CpG-methylated features).
+
+     Parameters:
+         read (np.ndarray): An array of binarized SMF data representing a read.
+         model (pomegranate HMM): A trained pomegranate hidden Markov model.
+         coordinates (list): A list of positional coordinates corresponding to the positions in the read.
+         classification_mapping (dict): A dictionary keyed by classification name that points to a 2-element list containing size boundary constraints for that feature.
+     Returns:
+         final_classifications (list): A list of tuples, one per methylated feature in the read. Each tuple contains: feature start, feature length, feature classification, and HMM probability.
+     """
+     import numpy as np
+
+     sequence = list(read)
+     # Get the predicted states using the MAP algorithm
+     predicted_states = model.predict(sequence, algorithm='map')
+
+     # Get the probabilities for each state using the forward-backward algorithm
+     probabilities = model.predict_proba(sequence)
+
+     # Initialize lists to store the classifications and their probabilities
+     classifications = []
+     current_start = None
+     current_length = 0
+     current_probs = []
+
+     for i, state_index in enumerate(predicted_states):
+         state_name = model.states[state_index].name
+         state_prob = probabilities[i][state_index]
+
+         if state_name == "Methylated":
+             if current_start is None:
+                 current_start = i
+             current_length += 1
+             current_probs.append(state_prob)
+         else:
+             if current_start is not None:
+                 avg_prob = np.mean(current_probs)
+                 classifications.append((current_start, current_length, "Methylated", avg_prob))
+                 current_start = None
+                 current_length = 0
+                 current_probs = []
+
+     if current_start is not None:
+         avg_prob = np.mean(current_probs)
+         classifications.append((current_start, current_length, "Methylated", avg_prob))
+
+     final_classifications = []
+     for start, length, classification, prob in classifications:
+         final_classification = None
+         feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+         for feature_type, size_range in classification_mapping.items():
+             if size_range[0] <= feature_length < size_range[1]:
+                 final_classification = feature_type
+                 break
+         if not final_classification:
+             final_classification = classification
+
+         final_classifications.append((int(coordinates[start]) + 1, feature_length, final_classification, prob))
+
+     return final_classifications
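A minimal usage sketch for the function above (not part of the package), assuming the pomegranate 0.x API it targets, which exposes model.states, predict(..., algorithm='map'), and predict_proba as used here; the two-state model, read, coordinates, and size classes are illustrative stand-ins:

import numpy as np
from pomegranate import DiscreteDistribution, HiddenMarkovModel, State

# Two-state HMM whose state names match what the classifier looks for.
methylated = State(DiscreteDistribution({0: 0.1, 1: 0.9}), name="Methylated")
non_methylated = State(DiscreteDistribution({0: 0.9, 1: 0.1}), name="Non-Methylated")
model = HiddenMarkovModel()
model.add_states(methylated, non_methylated)
model.add_transition(model.start, methylated, 0.5)
model.add_transition(model.start, non_methylated, 0.5)
model.add_transition(methylated, methylated, 0.9)
model.add_transition(methylated, non_methylated, 0.1)
model.add_transition(non_methylated, non_methylated, 0.9)
model.add_transition(non_methylated, methylated, 0.1)
model.bake()

read = np.array([1, 1, 1, 0, 0, 1, 1])             # binarized SMF calls
coords = [100, 120, 140, 160, 180, 200, 220]       # reference position per call
mapping = {"small_accessible_patch": [0, 80],      # hypothetical size classes (bp)
           "large_accessible_patch": [80, 500]}

features = classify_methylated_features(read, model, coords, mapping)
# Each tuple: (1-based start coordinate, feature length in bp, classification, mean state probability)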
@@ -0,0 +1,73 @@
+ # classify_non_methylated_features
+
+ def classify_non_methylated_features(read, model, coordinates, classification_mapping={}):
+     """
+     Classifies non-methylated features (inaccessible features).
+
+     Parameters:
+         read (np.ndarray): An array of binarized SMF data representing a read.
+         model (pomegranate HMM): A trained pomegranate hidden Markov model.
+         coordinates (list): A list of positional coordinates corresponding to the positions in the read.
+         classification_mapping (dict): A dictionary keyed by classification name that points to a 2-element list containing size boundary constraints for that feature.
+     Returns:
+         final_classifications (list): A list of tuples, one per non-methylated feature in the read. Each tuple contains: feature start, feature length, feature classification, and HMM probability.
+     """
+     import numpy as np
+
+     sequence = list(read)
+     # Get the predicted states using the MAP algorithm
+     predicted_states = model.predict(sequence, algorithm='map')
+
+     # Get the probabilities for each state using the forward-backward algorithm
+     probabilities = model.predict_proba(sequence)
+
+     # Initialize lists to store the classifications and their probabilities
+     classifications = []
+     current_start = None
+     current_length = 0
+     current_probs = []
+
+     for i, state_index in enumerate(predicted_states):
+         state_name = model.states[state_index].name
+         state_prob = probabilities[i][state_index]
+
+         if state_name == "Non-Methylated":
+             if current_start is None:
+                 current_start = i
+             current_length += 1
+             current_probs.append(state_prob)
+         else:
+             if current_start is not None:
+                 avg_prob = np.mean(current_probs)
+                 classifications.append((current_start, current_length, "Non-Methylated", avg_prob))
+                 current_start = None
+                 current_length = 0
+                 current_probs = []
+
+     if current_start is not None:
+         avg_prob = np.mean(current_probs)
+         classifications.append((current_start, current_length, "Non-Methylated", avg_prob))
+
+     final_classifications = []
+     for start, length, classification, prob in classifications:
+         final_classification = None
+         feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+         for feature_type, size_range in classification_mapping.items():
+             if size_range[0] <= feature_length < size_range[1]:
+                 final_classification = feature_type
+                 break
+         if not final_classification:
+             final_classification = classification
+
+         final_classifications.append((int(coordinates[start]) + 1, feature_length, final_classification, prob))
+
+     return final_classifications
+
+     # if feature_length < 80:
+     #     final_classification = 'small_bound_stretch'
+     # elif 80 <= feature_length <= 200:
+     #     final_classification = 'Putative_Nucleosome'
+     # elif 200 < feature_length:
+     #     final_classification = 'large_bound_stretch'
+     # else:
+     #     pass
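The non-methylated variant is the mirror image of the previous function and can be exercised with the same sketch as above, swapping only the call (size classes again illustrative):

# Reusing model, read, and coords from the previous sketch.
bound = classify_non_methylated_features(read, model, coords,
                                         {"Putative_Nucleosome": [80, 200]})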
@@ -0,0 +1,34 @@
+ # subset_adata
+
+ def subset_adata(adata, obs_columns):
+     """
+     Subsets an AnnData object based on categorical values in specified `.obs` columns.
+
+     Parameters:
+         adata (AnnData): The AnnData object to subset.
+         obs_columns (list of str): List of `.obs` column names to subset by. The order matters.
+
+     Returns:
+         dict: A dictionary where keys are tuples of category values and values are the corresponding AnnData subsets.
+     """
+
+     def subset_recursive(adata_subset, columns, key_prefix=()):
+         # Base case: no columns left, so this subset is final
+         if not columns:
+             return {key_prefix: adata_subset}
+
+         current_column = columns[0]
+         categories = adata_subset.obs[current_column].cat.categories
+
+         subsets = {}
+         for cat in categories:
+             subset = adata_subset[adata_subset.obs[current_column] == cat]
+             # Thread the accumulated category values through as the dict key
+             subsets.update(subset_recursive(subset, columns[1:], key_prefix + (cat,)))
+
+         return subsets
+
+     # Start the recursive subset process
+     subsets_dict = subset_recursive(adata, obs_columns)
+
+     return subsets_dict
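A usage sketch for this archived v1 helper, assuming categorical .obs columns (all names and toy data below are illustrative):

import anndata as ad
import numpy as np
import pandas as pd

adata = ad.AnnData(
    X=np.random.rand(6, 4),
    obs=pd.DataFrame(
        {
            "Reference": pd.Categorical(["chr1", "chr1", "chr2", "chr2", "chr1", "chr2"]),
            "Sample": pd.Categorical(["A", "B", "A", "B", "A", "B"]),
        },
        index=[f"read_{i}" for i in range(6)],
    ),
)

subsets = subset_adata(adata, ["Reference", "Sample"])
for key, sub in subsets.items():
    print(key, sub.shape)   # e.g. ('chr1', 'A') (2, 4)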
@@ -0,0 +1,46 @@
+ # subset_adata
+
+ def subset_adata(adata, columns, cat_type='obs'):
+     """
+     Subsets an AnnData object based on categorical values in specified .obs or .var columns.
+
+     Parameters:
+         adata (AnnData): The AnnData object to subset.
+         columns (list of str): List of .obs or .var column names to subset by. The order matters.
+         cat_type (str): 'obs' or 'var'. Default is 'obs'.
+
+     Returns:
+         dict: A dictionary where keys are tuples of category values and values are the corresponding AnnData subsets.
+     """
+
+     def subset_recursive(adata_subset, columns, cat_type, key_prefix=()):
+         # Returns when the bottom of the stack is reached
+         if not columns:
+             # If only one column was given, return the key as a single value, not a tuple
+             if len(key_prefix) == 1:
+                 return {key_prefix[0]: adata_subset}
+             return {key_prefix: adata_subset}
+
+         current_column = columns[0]
+         subsets = {}
+
+         if 'obs' in cat_type:
+             categories = adata_subset.obs[current_column].cat.categories
+             for cat in categories:
+                 subset = adata_subset[adata_subset.obs[current_column] == cat].copy()
+                 new_key = key_prefix + (cat,)
+                 subsets.update(subset_recursive(subset, columns[1:], cat_type, new_key))
+
+         elif 'var' in cat_type:
+             categories = adata_subset.var[current_column].cat.categories
+             for cat in categories:
+                 subset = adata_subset[:, adata_subset.var[current_column] == cat].copy()
+                 new_key = key_prefix + (cat,)
+                 subsets.update(subset_recursive(subset, columns[1:], cat_type, new_key))
+
+         return subsets
+
+     # Start the recursive subset process
+     subsets_dict = subset_recursive(adata, columns, cat_type)
+
+     return subsets_dict
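The v2 variant adds the cat_type switch; along the variable axis it splits on a categorical .var column instead. Reusing the toy adata from the sketch above, with an illustrative column name:

adata.var["site_type"] = pd.Categorical(["GpC", "CpG", "GpC", "other"])
by_site = subset_adata(adata, ["site_type"], cat_type="var")
print(by_site["GpC"].shape)   # single-column case: keys are bare category values; here (6, 2)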
@@ -0,0 +1,60 @@
+ def calculate_umap(adata, layer='nan_half', var_filters=None, n_pcs=15, knn_neighbors=100, overwrite=True, threads=8):
+     import scanpy as sc
+     import numpy as np
+     import os
+     from scipy.sparse import issparse
+
+     os.environ["OMP_NUM_THREADS"] = str(threads)
+
+     # Step 1: Apply var filter
+     if var_filters:
+         subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+         adata_subset = adata[:, subset_mask].copy()
+         print(f"Subsetting adata: retained {adata_subset.shape[1]} features based on filters {var_filters}")
+     else:
+         adata_subset = adata.copy()
+         print("No var filters provided. Using all features.")
+
+     # Step 2: NaN handling inside the layer
+     if layer:
+         data = adata_subset.layers[layer]
+         if not issparse(data):
+             if np.isnan(data).any():
+                 print("⚠ NaNs detected, filling with 0.5 before PCA + neighbors.")
+                 data = np.nan_to_num(data, nan=0.5)
+                 adata_subset.layers[layer] = data
+             else:
+                 print("No NaNs detected.")
+         else:
+             print("Sparse matrix detected; skipping NaN check (sparse formats typically do not store NaNs).")
+
+     # Step 3: PCA + neighbors + UMAP on the subset
+     if "X_umap" not in adata_subset.obsm or overwrite:
+         n_pcs = min(adata_subset.shape[1], n_pcs)
+         print(f"Running PCA with n_pcs={n_pcs}")
+         sc.pp.pca(adata_subset, layer=layer)
+         print('Running neighborhood graph')
+         sc.pp.neighbors(adata_subset, use_rep="X_pca", n_pcs=n_pcs, n_neighbors=knn_neighbors)
+         print('Running UMAP')
+         sc.tl.umap(adata_subset)
+
+         # Step 4: Store results in the original adata
+         adata.obsm["X_pca"] = adata_subset.obsm["X_pca"]
+         adata.obsm["X_umap"] = adata_subset.obsm["X_umap"]
+         adata.obsp["distances"] = adata_subset.obsp["distances"]
+         adata.obsp["connectivities"] = adata_subset.obsp["connectivities"]
+         adata.uns["neighbors"] = adata_subset.uns["neighbors"]
+
+         # Pad varm["PCs"] back out to the full feature set to avoid a shape mismatch
+         pc_matrix = np.zeros((adata.shape[1], adata_subset.varm["PCs"].shape[1]))
+         if var_filters:
+             subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+             pc_matrix[subset_mask, :] = adata_subset.varm["PCs"]
+         else:
+             pc_matrix = adata_subset.varm["PCs"]  # No subsetting case
+
+         adata.varm["PCs"] = pc_matrix
+
+     print("Stored: adata.obsm['X_pca'] and adata.obsm['X_umap']")
+
+     return adata
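A usage sketch on toy data, assuming scanpy's sc.pp.pca accepts a layer argument (scanpy >= 1.9); the boolean var filter and all names are illustrative:

import anndata as ad
import numpy as np

rng = np.random.default_rng(0)
demo = ad.AnnData(X=rng.random((50, 30)))
demo.layers["nan_half"] = demo.X.copy()
demo.var["GpC_site"] = rng.random(30) > 0.3   # boolean mask marking informative sites

demo = calculate_umap(demo, layer="nan_half", var_filters=["GpC_site"],
                      n_pcs=10, knn_neighbors=15)
print(demo.obsm["X_umap"].shape)   # (50, 2)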
@@ -0,0 +1,102 @@
+ # cluster_adata_on_methylation
+
+ def cluster_adata_on_methylation(adata, obs_columns, method='hierarchical', n_clusters=3, layer=None, site_types=['GpC_site', 'CpG_site']):
+     """
+     Adds cluster groups to the adata object as an observation column.
+
+     Parameters:
+         adata (AnnData): The AnnData object to cluster.
+         obs_columns (list of str): Observation columns that define the subgroups to cluster within.
+         method (str): 'hierarchical' or 'kmeans'. Default is 'hierarchical'.
+         n_clusters (int): Number of clusters for kmeans. Default is 3.
+         layer (str): Layer to cluster on.
+         site_types (list of str): Site-type var annotations to cluster on. Default is ['GpC_site', 'CpG_site'].
+
+     Returns:
+         None
+     """
+     import pandas as pd
+     import numpy as np
+     from . import subset_adata
+     from ..readwrite import adata_to_df
+
+     # Ensure obs_columns are categorical
+     for col in obs_columns:
+         adata.obs[col] = adata.obs[col].astype('category')
+
+     references = adata.obs['Reference'].cat.categories
+
+     # Add subset metadata to the adata
+     subset_adata(adata, obs_columns)
+
+     subgroup_name = '_'.join(obs_columns)
+     subgroups = adata.obs[subgroup_name].cat.categories
+
+     subgroup_to_reference_map = {}
+     for subgroup in subgroups:
+         for reference in references:
+             if reference in subgroup:
+                 subgroup_to_reference_map[subgroup] = reference
+
+     if method == 'hierarchical':
+         for site_type in site_types:
+             adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'] = pd.Series(-1, index=adata.obs_names, dtype=int)
+     elif method == 'kmeans':
+         for site_type in site_types:
+             adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = pd.Series(-1, index=adata.obs_names, dtype=int)
+
+     for subgroup in subgroups:
+         subgroup_subset = adata[adata.obs[subgroup_name] == subgroup].copy()
+         reference = subgroup_to_reference_map[subgroup]
+         for site_type in site_types:
+             site_subset = subgroup_subset[:, np.array(subgroup_subset.var[f'{reference}_{site_type}'])].copy()
+             df = adata_to_df(site_subset, layer=layer)
+             df2 = df.reset_index(drop=True)
+             if method == 'hierarchical':
+                 try:
+                     from scipy.cluster.hierarchy import linkage, dendrogram
+                     # Perform hierarchical clustering on rows using the average linkage method and Euclidean metric
+                     row_linkage = linkage(df2.values, method='average', metric='euclidean')
+
+                     # Generate the dendrogram to get the ordered indices
+                     dendro = dendrogram(row_linkage, no_plot=True)
+                     reordered_row_indices = np.array(dendro['leaves']).astype(int)
+
+                     # Get the reordered observation names
+                     reordered_obs_names = [df.index[i] for i in reordered_row_indices]
+
+                     temp_obs_data = pd.DataFrame({f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}': np.arange(0, len(reordered_obs_names), 1)}, index=reordered_obs_names, dtype=int)
+                     adata.obs.update(temp_obs_data)
+                 except Exception as e:
+                     print(f'Error found in {subgroup} of {site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}: {e}')
+             elif method == 'kmeans':
+                 try:
+                     from sklearn.cluster import KMeans
+                     kmeans = KMeans(n_clusters=n_clusters)
+                     kmeans.fit(site_subset.layers[layer])
+                     # Get the cluster labels for each data point
+                     cluster_labels = kmeans.labels_
+                     # Add the kmeans cluster data as an observation to the anndata object
+                     site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = cluster_labels.astype(str)
+                     # Calculate the mean of each observation category within each cluster
+                     cluster_means = site_subset.obs.groupby(f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}').mean()
+                     # Sort the cluster indices by mean methylation value
+                     sorted_clusters = cluster_means.sort_values(by=f'{site_type}_row_methylation_means', ascending=False).index
+                     # Create a mapping of the old cluster values to the new cluster values
+                     sorted_cluster_mapping = {old: new for new, old in enumerate(sorted_clusters)}
+                     # Apply the mapping to create a new observation value: kmeans_labels_reordered
+                     site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'].map(sorted_cluster_mapping)
+                     temp_obs_data = pd.DataFrame({f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}': site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}']}, index=site_subset.obs_names, dtype=int)
+                     adata.obs.update(temp_obs_data)
+                 except Exception as e:
+                     print(f'Error found in {subgroup} of {site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}: {e}')
+
+     # Ensure that the clustering index observation columns are type int
+     if method == 'hierarchical':
+         for site_type in site_types:
+             adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'] = adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'].astype(int)
+     elif method == 'kmeans':
+         for site_type in site_types:
+             adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'].astype(int)
+
+     return None
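A calling-pattern sketch rather than a runnable example, since this function leans on smftools-internal annotations: a categorical 'Reference' column in .obs, per-reference boolean var columns named '{Reference}_GpC_site' / '{Reference}_CpG_site', per-site-type row methylation means for kmeans sorting, and a combined obs column named by joining obs_columns with '_':

# Assumes `adata` came out of the smftools pipeline with the annotations above;
# the column and layer names here are illustrative.
cluster_adata_on_methylation(
    adata,
    obs_columns=['Reference', 'Sample'],   # grouped under a combined 'Reference_Sample' column
    method='hierarchical',                 # or 'kmeans' with n_clusters
    layer='nan_half',                      # layer to cluster on
)
# Adds e.g. adata.obs['GpC_site_nan_half_hierarchical_clustering_index_within_Reference_Sample']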
@@ -0,0 +1,70 @@
+ def create_nan_mask_from_X(adata, new_layer_name="nan_mask"):
+     """
+     Generates a NaN mask where 1 = NaN in adata.X and 0 = valid value.
+     """
+     import numpy as np
+     nan_mask = np.isnan(adata.X).astype(int)
+     adata.layers[new_layer_name] = nan_mask
+     print(f"Created '{new_layer_name}' layer based on NaNs in adata.X")
+     return adata
+
+ def create_nan_or_non_gpc_mask(adata, obs_column, new_layer_name="nan_or_non_gpc_mask"):
+     """Generates a mask where 1 = NaN in adata.X or a non-GpC position for the read's reference (taken from obs_column)."""
+     import numpy as np
+
+     nan_mask = np.isnan(adata.X).astype(int)
+     combined_mask = np.zeros_like(nan_mask)
+
+     for idx, row in enumerate(adata.obs.itertuples()):
+         ref = getattr(row, obs_column)
+         gpc_mask = adata.var[f"{ref}_GpC_site"].astype(int).values
+         combined_mask[idx, :] = 1 - gpc_mask  # non-GpC positions get 1
+
+     mask = np.maximum(nan_mask, combined_mask)
+     adata.layers[new_layer_name] = mask
+
+     print(f"Created '{new_layer_name}' layer based on NaNs in adata.X and non-GpC regions using {obs_column}")
+     return adata
+
+ def combine_layers(adata, input_layers, output_layer, negative_mask=None, values=None, binary_mode=False):
+     """
+     Combines layers into a single layer with specific coding:
+     - Background stays 0
+     - If binary_mode=True: any overlap = 1
+     - If binary_mode=False:
+         - Defaults to [1, 2, 3, ...] if values=None
+         - Later layers take precedence in overlaps
+
+     Parameters:
+         adata: AnnData object
+         input_layers: list of str
+         output_layer: str, name of the output layer
+         negative_mask: str (optional), name of a binary layer; positions where it is 0 are forced to 0
+         values: list of ints (optional), values to assign to each input layer
+         binary_mode: bool, if True, creates a simple 0/1 mask regardless of values
+
+     Returns:
+         Updated AnnData with the new layer.
+     """
+     import numpy as np
+     combined = np.zeros_like(adata.layers[input_layers[0]])
+
+     if binary_mode:
+         for layer in input_layers:
+             combined = np.logical_or(combined, adata.layers[layer] > 0)
+         combined = combined.astype(int)
+     else:
+         if values is None:
+             values = list(range(1, len(input_layers) + 1))
+         for i, layer in enumerate(input_layers):
+             arr = adata.layers[layer]
+             combined[arr > 0] = values[i]
+
+     if negative_mask:
+         mask = adata.layers[negative_mask]
+         combined[mask == 0] = 0
+
+     adata.layers[output_layer] = combined
+     print(f"Combined layers into {output_layer} {'(binary)' if binary_mode else f'with values {values}'}")
+
+     return adata
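A usage sketch chaining the three helpers on toy data; the layer and column names are made up for the example:

import anndata as ad
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
X = rng.random((4, 6))
X[0, 2] = np.nan                                     # one missing call
demo = ad.AnnData(X=X)
demo.obs["Reference"] = pd.Categorical(["chr1"] * 4)
demo.var["chr1_GpC_site"] = [True, False, True, True, False, True]

demo = create_nan_mask_from_X(demo)
demo = create_nan_or_non_gpc_mask(demo, obs_column="Reference")

demo.layers["accessible"] = (np.nan_to_num(demo.X) > 0.5).astype(int)   # toy feature calls
demo.layers["cpg_meth"] = (np.nan_to_num(demo.X) > 0.8).astype(int)
demo = combine_layers(demo, ["accessible", "cpg_meth"], "features")     # coded 1 and 2; later layer wins overlaps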