smftools 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (137)
  1. smftools/__init__.py +29 -0
  2. smftools/_settings.py +20 -0
  3. smftools/_version.py +1 -0
  4. smftools/datasets/F1_hybrid_NKG2A_enhander_promoter_GpC_conversion_SMF.h5ad.gz +0 -0
  5. smftools/datasets/F1_sample_sheet.csv +5 -0
  6. smftools/datasets/__init__.py +9 -0
  7. smftools/datasets/dCas9_m6A_invitro_kinetics.h5ad.gz +0 -0
  8. smftools/datasets/datasets.py +28 -0
  9. smftools/informatics/__init__.py +16 -0
  10. smftools/informatics/archived/bam_conversion.py +59 -0
  11. smftools/informatics/archived/bam_direct.py +63 -0
  12. smftools/informatics/archived/basecalls_to_adata.py +71 -0
  13. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  14. smftools/informatics/basecall_pod5s.py +80 -0
  15. smftools/informatics/conversion_smf.py +132 -0
  16. smftools/informatics/direct_smf.py +137 -0
  17. smftools/informatics/fast5_to_pod5.py +21 -0
  18. smftools/informatics/helpers/LoadExperimentConfig.py +75 -0
  19. smftools/informatics/helpers/__init__.py +74 -0
  20. smftools/informatics/helpers/align_and_sort_BAM.py +59 -0
  21. smftools/informatics/helpers/aligned_BAM_to_bed.py +74 -0
  22. smftools/informatics/helpers/archived/informatics.py +260 -0
  23. smftools/informatics/helpers/archived/load_adata.py +516 -0
  24. smftools/informatics/helpers/bam_qc.py +66 -0
  25. smftools/informatics/helpers/bed_to_bigwig.py +39 -0
  26. smftools/informatics/helpers/binarize_converted_base_identities.py +79 -0
  27. smftools/informatics/helpers/canoncall.py +34 -0
  28. smftools/informatics/helpers/complement_base_list.py +21 -0
  29. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +55 -0
  30. smftools/informatics/helpers/converted_BAM_to_adata.py +245 -0
  31. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  32. smftools/informatics/helpers/count_aligned_reads.py +43 -0
  33. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  34. smftools/informatics/helpers/extract_base_identities.py +44 -0
  35. smftools/informatics/helpers/extract_mods.py +83 -0
  36. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  37. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  38. smftools/informatics/helpers/extract_readnames_from_BAM.py +22 -0
  39. smftools/informatics/helpers/find_conversion_sites.py +50 -0
  40. smftools/informatics/helpers/generate_converted_FASTA.py +99 -0
  41. smftools/informatics/helpers/get_chromosome_lengths.py +32 -0
  42. smftools/informatics/helpers/get_native_references.py +28 -0
  43. smftools/informatics/helpers/index_fasta.py +12 -0
  44. smftools/informatics/helpers/make_dirs.py +21 -0
  45. smftools/informatics/helpers/make_modbed.py +27 -0
  46. smftools/informatics/helpers/modQC.py +27 -0
  47. smftools/informatics/helpers/modcall.py +36 -0
  48. smftools/informatics/helpers/modkit_extract_to_adata.py +884 -0
  49. smftools/informatics/helpers/ohe_batching.py +76 -0
  50. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  51. smftools/informatics/helpers/one_hot_decode.py +27 -0
  52. smftools/informatics/helpers/one_hot_encode.py +57 -0
  53. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +53 -0
  54. smftools/informatics/helpers/run_multiqc.py +28 -0
  55. smftools/informatics/helpers/separate_bam_by_bc.py +43 -0
  56. smftools/informatics/helpers/split_and_index_BAM.py +36 -0
  57. smftools/informatics/load_adata.py +182 -0
  58. smftools/informatics/readwrite.py +106 -0
  59. smftools/informatics/subsample_fasta_from_bed.py +47 -0
  60. smftools/informatics/subsample_pod5.py +104 -0
  61. smftools/plotting/__init__.py +15 -0
  62. smftools/plotting/classifiers.py +355 -0
  63. smftools/plotting/general_plotting.py +205 -0
  64. smftools/plotting/position_stats.py +462 -0
  65. smftools/preprocessing/__init__.py +33 -0
  66. smftools/preprocessing/append_C_context.py +82 -0
  67. smftools/preprocessing/archives/mark_duplicates.py +146 -0
  68. smftools/preprocessing/archives/preprocessing.py +614 -0
  69. smftools/preprocessing/archives/remove_duplicates.py +21 -0
  70. smftools/preprocessing/binarize_on_Youden.py +45 -0
  71. smftools/preprocessing/binary_layers_to_ohe.py +40 -0
  72. smftools/preprocessing/calculate_complexity.py +72 -0
  73. smftools/preprocessing/calculate_consensus.py +47 -0
  74. smftools/preprocessing/calculate_converted_read_methylation_stats.py +94 -0
  75. smftools/preprocessing/calculate_coverage.py +42 -0
  76. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  77. smftools/preprocessing/calculate_pairwise_hamming_distances.py +27 -0
  78. smftools/preprocessing/calculate_position_Youden.py +115 -0
  79. smftools/preprocessing/calculate_read_length_stats.py +79 -0
  80. smftools/preprocessing/clean_NaN.py +46 -0
  81. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  82. smftools/preprocessing/filter_converted_reads_on_methylation.py +44 -0
  83. smftools/preprocessing/filter_reads_on_length.py +51 -0
  84. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  85. smftools/preprocessing/invert_adata.py +30 -0
  86. smftools/preprocessing/load_sample_sheet.py +38 -0
  87. smftools/preprocessing/make_dirs.py +21 -0
  88. smftools/preprocessing/min_non_diagonal.py +25 -0
  89. smftools/preprocessing/recipes.py +127 -0
  90. smftools/preprocessing/subsample_adata.py +58 -0
  91. smftools/readwrite.py +198 -0
  92. smftools/tools/__init__.py +49 -0
  93. smftools/tools/apply_hmm.py +202 -0
  94. smftools/tools/apply_hmm_batched.py +241 -0
  95. smftools/tools/archived/classify_methylated_features.py +66 -0
  96. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  97. smftools/tools/archived/subset_adata_v1.py +32 -0
  98. smftools/tools/archived/subset_adata_v2.py +46 -0
  99. smftools/tools/calculate_distances.py +18 -0
  100. smftools/tools/calculate_umap.py +62 -0
  101. smftools/tools/call_hmm_peaks.py +105 -0
  102. smftools/tools/classifiers.py +787 -0
  103. smftools/tools/cluster_adata_on_methylation.py +105 -0
  104. smftools/tools/data/__init__.py +2 -0
  105. smftools/tools/data/anndata_data_module.py +90 -0
  106. smftools/tools/data/preprocessing.py +6 -0
  107. smftools/tools/display_hmm.py +18 -0
  108. smftools/tools/evaluation/__init__.py +0 -0
  109. smftools/tools/general_tools.py +69 -0
  110. smftools/tools/hmm_readwrite.py +16 -0
  111. smftools/tools/inference/__init__.py +1 -0
  112. smftools/tools/inference/lightning_inference.py +41 -0
  113. smftools/tools/models/__init__.py +9 -0
  114. smftools/tools/models/base.py +14 -0
  115. smftools/tools/models/cnn.py +34 -0
  116. smftools/tools/models/lightning_base.py +41 -0
  117. smftools/tools/models/mlp.py +17 -0
  118. smftools/tools/models/positional.py +17 -0
  119. smftools/tools/models/rnn.py +16 -0
  120. smftools/tools/models/sklearn_models.py +40 -0
  121. smftools/tools/models/transformer.py +133 -0
  122. smftools/tools/models/wrappers.py +20 -0
  123. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  124. smftools/tools/position_stats.py +239 -0
  125. smftools/tools/read_stats.py +70 -0
  126. smftools/tools/subset_adata.py +28 -0
  127. smftools/tools/train_hmm.py +78 -0
  128. smftools/tools/training/__init__.py +1 -0
  129. smftools/tools/training/train_lightning_model.py +47 -0
  130. smftools/tools/utils/__init__.py +2 -0
  131. smftools/tools/utils/device.py +10 -0
  132. smftools/tools/utils/grl.py +14 -0
  133. {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/METADATA +5 -2
  134. smftools-0.1.7.dist-info/RECORD +136 -0
  135. smftools-0.1.6.dist-info/RECORD +0 -4
  136. {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  137. {smftools-0.1.6.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,66 @@
+ # classify_methylated_features
+
+ def classify_methylated_features(read, model, coordinates, classification_mapping={}):
+     """
+     Classifies methylated features (accessible features or CpG methylated features)
+
+     Parameters:
+         read (np.ndarray): An array of binarized SMF data representing a read
+         model (HiddenMarkovModel): A trained pomegranate HMM
+         coordinates (list): A list of positional coordinates corresponding to the positions in the read
+         classification_mapping (dict): A dictionary keyed by classification name that points to a 2-element list containing size boundary constraints for that feature.
+     Returns:
+         final_classifications (list): A list of tuples, where each tuple is an instance of a methylated feature in the read. The tuple contains: feature start, feature length, feature classification, and HMM probability.
+     """
+     import numpy as np
+
+     sequence = list(read)
+     # Get the predicted states using the MAP algorithm
+     predicted_states = model.predict(sequence, algorithm='map')
+
+     # Get the probabilities for each state using the forward-backward algorithm
+     probabilities = model.predict_proba(sequence)
+
+     # Initialize lists to store the classifications and their probabilities
+     classifications = []
+     current_start = None
+     current_length = 0
+     current_probs = []
+
+     for i, state_index in enumerate(predicted_states):
+         state_name = model.states[state_index].name
+         state_prob = probabilities[i][state_index]
+
+         if state_name == "Methylated":
+             if current_start is None:
+                 current_start = i
+             current_length += 1
+             current_probs.append(state_prob)
+         else:
+             if current_start is not None:
+                 avg_prob = np.mean(current_probs)
+                 classifications.append((current_start, current_length, "Methylated", avg_prob))
+                 current_start = None
+                 current_length = 0
+                 current_probs = []
+
+     if current_start is not None:
+         avg_prob = np.mean(current_probs)
+         classifications.append((current_start, current_length, "Methylated", avg_prob))
+
+     final_classifications = []
+     for start, length, classification, prob in classifications:
+         final_classification = None
+         feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+         for feature_type, size_range in classification_mapping.items():
+             if size_range[0] <= feature_length < size_range[1]:
+                 final_classification = feature_type
+                 break
+         else:
+             pass
+         if not final_classification:
+             final_classification = classification
+
+         final_classifications.append((int(coordinates[start]) + 1, feature_length, final_classification, prob))
+
+     return final_classifications
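A minimal usage sketch for the classifier above. The read array, coordinate list, and model path are hypothetical placeholders; it assumes a classic-API pomegranate HMM whose states are named "Methylated" and "Non-Methylated", which is what the function checks against.

    import numpy as np
    from pomegranate import HiddenMarkovModel  # assumes the classic (pre-1.0) pomegranate API

    # Hypothetical inputs: a binarized read and its reference coordinates
    read = np.array([0, 0, 1, 1, 1, 0, 0, 1, 1, 0])
    coordinates = list(range(1000, 1010))
    model = HiddenMarkovModel.from_json(open("trained_smf_hmm.json").read())  # placeholder path

    # Each tuple: (feature start coordinate, feature length, classification, mean HMM probability)
    for start, length, label, prob in classify_methylated_features(read, model, coordinates):
        print(f"{label} at {start}, {length} bp (p={prob:.2f})")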
@@ -0,0 +1,75 @@
+ # classify_non_methylated_features
+
+ def classify_non_methylated_features(read, model, coordinates, classification_mapping={}):
+     """
+     Classifies non-methylated features (inaccessible features)
+
+     Parameters:
+         read (np.ndarray): An array of binarized SMF data representing a read
+         model (HiddenMarkovModel): A trained pomegranate HMM
+         coordinates (list): A list of positional coordinates corresponding to the positions in the read
+         classification_mapping (dict): A dictionary keyed by classification name that points to a 2-element list containing size boundary constraints for that feature.
+     Returns:
+         final_classifications (list): A list of tuples, where each tuple is an instance of a non-methylated feature in the read. The tuple contains: feature start, feature length, feature classification, and HMM probability.
+     """
+     import numpy as np
+
+     sequence = list(read)
+     # Get the predicted states using the MAP algorithm
+     predicted_states = model.predict(sequence, algorithm='map')
+
+     # Get the probabilities for each state using the forward-backward algorithm
+     probabilities = model.predict_proba(sequence)
+
+     # Initialize lists to store the classifications and their probabilities
+     classifications = []
+     current_start = None
+     current_length = 0
+     current_probs = []
+
+     for i, state_index in enumerate(predicted_states):
+         state_name = model.states[state_index].name
+         state_prob = probabilities[i][state_index]
+
+         if state_name == "Non-Methylated":
+             if current_start is None:
+                 current_start = i
+             current_length += 1
+             current_probs.append(state_prob)
+         else:
+             if current_start is not None:
+                 avg_prob = np.mean(current_probs)
+                 classifications.append((current_start, current_length, "Non-Methylated", avg_prob))
+                 current_start = None
+                 current_length = 0
+                 current_probs = []
+
+     if current_start is not None:
+         avg_prob = np.mean(current_probs)
+         classifications.append((current_start, current_length, "Non-Methylated", avg_prob))
+
+     final_classifications = []
+     for start, length, classification, prob in classifications:
+         final_classification = None
+         feature_length = int(coordinates[start + length - 1]) - int(coordinates[start]) + 1
+         for feature_type, size_range in classification_mapping.items():
+             if size_range[0] <= feature_length < size_range[1]:
+                 final_classification = feature_type
+                 break
+         else:
+             pass
+         if not final_classification:
+             final_classification = classification
+
+         final_classifications.append((int(coordinates[start]) + 1, feature_length, final_classification, prob))
+
+     return final_classifications
+
+     # if feature_length < 80:
+     #     final_classification = 'small_bound_stretch'
+     # elif 80 <= feature_length <= 200:
+     #     final_classification = 'Putative_Nucleosome'
+     # elif 200 < feature_length:
+     #     final_classification = 'large_bound_stretch'
+     # else:
+     #     pass
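The commented-out thresholds at the end of this file indicate the size classes the classification_mapping argument was meant to encode. A sketch of an equivalent mapping; note the lookup is half-open (size_range[0] <= feature_length < size_range[1]), so the inclusive 80-200 range from the comments is approximated with an upper bound of 201:

    # Hypothetical mapping reproducing the commented-out thresholds above
    classification_mapping = {
        'small_bound_stretch': [0, 80],               # feature_length < 80
        'Putative_Nucleosome': [80, 201],             # 80 <= feature_length <= 200
        'large_bound_stretch': [201, float('inf')],   # feature_length > 200
    }

    features = classify_non_methylated_features(read, model, coordinates, classification_mapping)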
@@ -0,0 +1,32 @@
+ # subset_adata
+
+ def subset_adata(adata, obs_columns):
+     """
+     Subsets an AnnData object based on categorical values in specified `.obs` columns.
+
+     Parameters:
+         adata (AnnData): The AnnData object to subset.
+         obs_columns (list of str): List of `.obs` column names to subset by. The order matters.
+
+     Returns:
+         dict: A dictionary where keys are tuples of category values and values are corresponding AnnData subsets.
+     """
+
+     def subset_recursive(adata_subset, columns):
+         if not columns:
+             return {(): adata_subset}
+
+         current_column = columns[0]
+         categories = adata_subset.obs[current_column].cat.categories
+
+         subsets = {}
+         for cat in categories:
+             subset = adata_subset[adata_subset.obs[current_column] == cat]
+             subsets.update(subset_recursive(subset, columns[1:]))
+
+         return subsets
+
+     # Start the recursive subset process
+     subsets_dict = subset_recursive(adata, obs_columns)
+
+     return subsets_dict
@@ -0,0 +1,46 @@
+ # subset_adata
+
+ def subset_adata(adata, columns, cat_type='obs'):
+     """
+     Subsets an AnnData object based on categorical values in specified .obs or .var columns.
+
+     Parameters:
+         adata (AnnData): The AnnData object to subset.
+         columns (list of str): List of .obs or .var column names to subset by. The order matters.
+         cat_type (str): 'obs' or 'var'. Default is 'obs'.
+
+     Returns:
+         dict: A dictionary where keys are tuples of category values and values are corresponding AnnData subsets.
+     """
+
+     def subset_recursive(adata_subset, columns, cat_type, key_prefix=()):
+         # Returns when the bottom of the stack is reached
+         if not columns:
+             # If there's only one column, return the key as a single value, not a tuple
+             if len(key_prefix) == 1:
+                 return {key_prefix[0]: adata_subset}
+             return {key_prefix: adata_subset}
+
+         current_column = columns[0]
+         subsets = {}
+
+         if 'obs' in cat_type:
+             categories = adata_subset.obs[current_column].cat.categories
+             for cat in categories:
+                 subset = adata_subset[adata_subset.obs[current_column] == cat].copy()
+                 new_key = key_prefix + (cat,)
+                 subsets.update(subset_recursive(subset, columns[1:], cat_type, new_key))
+
+         elif 'var' in cat_type:
+             categories = adata_subset.var[current_column].cat.categories
+             for cat in categories:
+                 subset = adata_subset[:, adata_subset.var[current_column] == cat].copy()
+                 new_key = key_prefix + (cat,)
+                 subsets.update(subset_recursive(subset, columns[1:], cat_type, new_key))
+
+         return subsets
+
+     # Start the recursive subset process
+     subsets_dict = subset_recursive(adata, columns, cat_type)
+
+     return subsets_dict
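Unlike the archived v1 above, whose recursion never threads the accumulated category values into the result keys (its base case always returns the empty tuple as key, so sibling subsets overwrite one another), this v2 carries a key_prefix through each level. A usage sketch with hypothetical column names; the listed columns must be categorical:

    # 'Sample' is a hypothetical .obs column; 'Reference_strand' mirrors the
    # obs_column default used elsewhere in this package
    subsets = subset_adata(adata, ['Sample', 'Reference_strand'], cat_type='obs')

    # Keys are tuples of category values, e.g. ('sample1', 'top');
    # with a single column, the key is the bare category value instead of a 1-tuple
    for key, sub in subsets.items():
        print(key, sub.shape)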
@@ -0,0 +1,18 @@
+ # calculate_distances
+
+ def calculate_distances(intervals, threshold=0.9):
+     """
+     Calculates the distance between adjacent features in a read.
+     Takes a list of intervals (feature start, feature length, probability); intervals with probability <= threshold are discarded.
+     """
+     # Sort intervals by start position
+     intervals = sorted(intervals, key=lambda x: x[0])
+     intervals = [interval for interval in intervals if interval[2] > threshold]
+
+     # Calculate distances
+     distances = []
+     for i in range(len(intervals) - 1):
+         end_current = intervals[i][0] + intervals[i][1]
+         start_next = intervals[i + 1][0]
+         distances.append(start_next - end_current)
+     return distances
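A worked example, assuming each interval is a (start, length, probability) triple; the probability is read from index 2, so the 4-tuples produced by the classify_* helpers above would need their classification element removed first:

    # The middle interval falls below the default threshold of 0.9 and is dropped
    intervals = [(100, 50, 0.95), (180, 20, 0.50), (300, 120, 0.99)]

    # Kept: (100, 50) ends at 150; (300, 120) starts at 300 -> gap of 150
    print(calculate_distances(intervals))  # [150]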
@@ -0,0 +1,62 @@
+ def calculate_umap(adata, layer='nan_half', var_filters=None, n_pcs=15, knn_neighbors=100, overwrite=True, threads=8):
+     import scanpy as sc
+     import numpy as np
+     import os
+     from scipy.sparse import issparse
+
+     os.environ["OMP_NUM_THREADS"] = str(threads)
+
+     # Step 1: Apply var filter
+     if var_filters:
+         subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+         adata_subset = adata[:, subset_mask].copy()
+         print(f"🔹 Subsetting adata: Retained {adata_subset.shape[1]} features based on filters {var_filters}")
+     else:
+         adata_subset = adata.copy()
+         print("🔹 No var filters provided. Using all features.")
+
+     # Step 2: NaN handling inside layer
+     if layer:
+         data = adata_subset.layers[layer]
+         if not issparse(data):
+             if np.isnan(data).any():
+                 print("⚠ NaNs detected, filling with 0.5 before PCA + neighbors.")
+                 data = np.nan_to_num(data, nan=0.5)
+                 adata_subset.layers[layer] = data
+             else:
+                 print("✅ No NaNs detected.")
+         else:
+             print("✅ Sparse matrix detected; skipping NaN check (sparse formats typically do not store NaNs).")
+
+     # Step 3: PCA + neighbors + UMAP on subset
+     if "X_umap" not in adata_subset.obsm or overwrite:
+         n_pcs = min(adata_subset.shape[1], n_pcs)
+         print(f"Running PCA with n_pcs={n_pcs}")
+         sc.pp.pca(adata_subset, layer=layer)
+         print('Running neighborhood graph')
+         sc.pp.neighbors(adata_subset, use_rep="X_pca", n_pcs=n_pcs, n_neighbors=knn_neighbors)
+         print('Running UMAP')
+         sc.tl.umap(adata_subset)
+
+     # Step 4: Store results in original adata
+     adata.obsm["X_pca"] = adata_subset.obsm["X_pca"]
+     adata.obsm["X_umap"] = adata_subset.obsm["X_umap"]
+     adata.obsp["distances"] = adata_subset.obsp["distances"]
+     adata.obsp["connectivities"] = adata_subset.obsp["connectivities"]
+     adata.uns["neighbors"] = adata_subset.uns["neighbors"]
+
+
+     # Fix varm["PCs"] shape mismatch
+     pc_matrix = np.zeros((adata.shape[1], adata_subset.varm["PCs"].shape[1]))
+     if var_filters:
+         subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
+         pc_matrix[subset_mask, :] = adata_subset.varm["PCs"]
+     else:
+         pc_matrix = adata_subset.varm["PCs"]  # No subsetting case
+
+     adata.varm["PCs"] = pc_matrix
+
+
+     print("✅ Stored: adata.obsm['X_pca'] and adata.obsm['X_umap']")
+
+     return adata
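A usage sketch for the function above. The layer name matches the function's default, while the boolean .var filter columns are hypothetical and assumed to already exist on the object:

    adata = calculate_umap(
        adata,
        layer='nan_half',
        var_filters=['GpC_site', 'CpG_site'],  # hypothetical boolean .var columns
        n_pcs=15,
        knn_neighbors=100,
    )

    # PCA/UMAP coordinates and the neighbor graph are written back onto the full object
    import scanpy as sc
    sc.pl.umap(adata)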
@@ -0,0 +1,105 @@
+ def call_hmm_peaks(adata, feature_configs, obs_column='Reference_strand', site_types=['GpC_site', 'CpG_site'], save_plot=False, output_dir=None, date_tag=None):
+     """
+     Calls peaks from HMM feature layers and annotates them into the AnnData object.
+
+     Parameters:
+         adata : AnnData object with HMM layers (from apply_hmm)
+         feature_configs : dict keyed by HMM feature layer name; each config may set:
+             min_distance : minimum distance between peaks
+             peak_width : window size around peak centers
+             peak_prominence : required peak prominence
+             peak_threshold : threshold for labeling a read as "present" at a peak
+         site_types : list of var site types to aggregate
+         save_plot : whether to save the plot
+         output_dir : path to save the figure if save_plot=True
+         date_tag : optional tag for filename
+     """
+     import matplotlib.pyplot as plt
+     from scipy.signal import find_peaks
+     import os
+     import numpy as np
+
+     peak_columns = []
+
+     for feature_layer, config in feature_configs.items():
+         min_distance = config.get('min_distance', 200)
+         peak_width = config.get('peak_width', 200)
+         peak_prominence = config.get('peak_prominence', 0.2)
+         peak_threshold = config.get('peak_threshold', 0.8)
+
+         # 1️⃣ Calculate mean intensity profile
+         matrix = adata.layers[feature_layer]
+         means = np.mean(matrix, axis=0)
+         feature_peak_columns = []
+
+         # 2️⃣ Peak calling
+         peak_centers, _ = find_peaks(means, prominence=peak_prominence, distance=min_distance)
+         adata.uns[f'{feature_layer} peak_centers'] = peak_centers
+
+         # 3️⃣ Plot
+         plt.figure(figsize=(6, 3))
+         plt.plot(range(len(means)), means)
+         plt.title(f"{feature_layer} density with peak calls")
+         plt.xlabel("Genomic position")
+         plt.ylabel("Mean feature density")
+         y = max(means) / 2
+         for i, center in enumerate(peak_centers):
+             plus_minus_width = peak_width // 2
+             start = center - plus_minus_width
+             end = center + plus_minus_width
+             plt.axvspan(start, end, color='purple', alpha=0.2)
+             plt.axvline(center, color='red', linestyle='--')
+             if i % 2:
+                 aligned = [end, 'left']
+             else:
+                 aligned = [start, 'right']
+             plt.text(aligned[0], 0, f"Peak {i}\n{center}", color='red', ha=aligned[1])
+
+         if save_plot and output_dir:
+             filename = f"{output_dir}/{date_tag or 'output'}_{feature_layer}_peaks.png"
+             plt.savefig(filename, bbox_inches='tight')
+             print(f"Saved plot to {filename}")
+         else:
+             plt.show()
+
+         # 4️⃣ Annotate peaks back into adata.obs
+         for center in peak_centers:
+             half_width = peak_width // 2
+             start, end = center - half_width, center + half_width
+             colname = f'{feature_layer}_peak_{center}'
+             peak_columns.append(colname)
+             feature_peak_columns.append(colname)
+
+             adata.var[colname] = (
+                 (adata.var_names.astype(int) >= start) &
+                 (adata.var_names.astype(int) <= end)
+             )
+
+             # Feature layer intensity around peak
+             mean_values = np.mean(matrix[:, start:end+1], axis=1)
+             sum_values = np.sum(matrix[:, start:end+1], axis=1)
+             adata.obs[f'mean_{feature_layer}_around_{center}'] = mean_values
+             adata.obs[f'sum_{feature_layer}_around_{center}'] = sum_values
+             adata.obs[f'{feature_layer}_present_at_{center}'] = mean_values > peak_threshold
+
+             # Site-type based aggregation
+             for site_type in site_types:
+                 adata.obs[f'{site_type}_sum_around_{center}'] = 0
+                 adata.obs[f'{site_type}_mean_around_{center}'] = np.nan
+
+             references = adata.obs[obs_column].cat.categories
+             for ref in references:
+                 subset = adata[adata.obs[obs_column] == ref]
+                 for site_type in site_types:
+                     mask = subset.var.get(f'{ref}_{site_type}', None)
+                     if mask is not None:
+                         region_mask = (subset.var_names[mask].astype(int) >= start) & (subset.var_names[mask].astype(int) <= end)
+                         region = subset[:, mask].X[:, region_mask]
+                         adata.obs.loc[subset.obs.index, f'{site_type}_sum_around_{center}'] = np.nansum(region, axis=1)
+                         adata.obs.loc[subset.obs.index, f'{site_type}_mean_around_{center}'] = np.nanmean(region, axis=1)
+
+         adata.var[f'is_in_any_{feature_layer}_peak'] = adata.var[feature_peak_columns].any(axis=1)
+         print(f"✅ Peak annotation completed for {feature_layer} with {len(peak_centers)} peaks.")
+
+     # Combine all peaks into a single "is_in_any_peak" column
+     adata.var['is_in_any_peak'] = adata.var[peak_columns].any(axis=1)
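A usage sketch; the layer name 'hmm_accessible' is a hypothetical stand-in for a layer produced by apply_hmm, and the config values shown mirror the defaults the function reads via config.get(...):

    feature_configs = {
        'hmm_accessible': {             # hypothetical HMM feature layer
            'min_distance': 200,        # minimum spacing between called peaks
            'peak_width': 200,          # window annotated around each peak center
            'peak_prominence': 0.2,     # prominence passed to scipy.signal.find_peaks
            'peak_threshold': 0.8,      # mean density needed to mark a read "present"
        }
    }

    call_hmm_peaks(adata, feature_configs, obs_column='Reference_strand',
                   site_types=['GpC_site', 'CpG_site'])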