smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +7 -1
  5. smftools/cli/hmm_adata.py +902 -244
  6. smftools/cli/load_adata.py +318 -198
  7. smftools/cli/preprocess_adata.py +285 -171
  8. smftools/cli/spatial_adata.py +137 -53
  9. smftools/cli_entry.py +94 -178
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +5 -1
  12. smftools/config/deaminase.yaml +1 -1
  13. smftools/config/default.yaml +22 -17
  14. smftools/config/direct.yaml +8 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +505 -276
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2125 -1426
  21. smftools/hmm/__init__.py +2 -3
  22. smftools/hmm/archived/call_hmm_peaks.py +16 -1
  23. smftools/hmm/call_hmm_peaks.py +173 -193
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +379 -156
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +195 -29
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +347 -168
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +145 -85
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +8 -8
  84. smftools/preprocessing/append_base_context.py +105 -79
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  86. smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +127 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +44 -22
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +103 -55
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +688 -271
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +93 -27
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +264 -109
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.4.dist-info/RECORD +0 -176
  128. /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
  129. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  130. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  131. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  132. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  133. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
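A change that cuts across nearly every module in this release: bare print calls are replaced with namespaced loggers obtained from the new smftools/logging_utils.py (+51 -0 above), imported as "from smftools.logging_utils import get_logger". The helper's implementation is not shown in this diff; the following is only a minimal sketch of what such a wrapper typically looks like, assuming a thin layer over the stdlib logging module with a one-time package-level handler:

import logging

def get_logger(name: str) -> logging.Logger:
    """Hypothetical sketch; the real smftools.logging_utils.get_logger is not shown in this diff."""
    logger = logging.getLogger(name)
    root = logging.getLogger("smftools")
    if not root.handlers:  # attach a default handler once for the package namespace
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s: %(message)s"))
        root.addHandler(handler)
        root.setLevel(logging.INFO)
    return logger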
smftools/tools/calculate_umap.py
@@ -1,7 +1,42 @@
-def calculate_umap(adata, layer='nan_half', var_filters=None, n_pcs=15, knn_neighbors=100, overwrite=True, threads=8):
-    import scanpy as sc
-    import numpy as np
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def calculate_umap(
+    adata: "ad.AnnData",
+    layer: str | None = "nan_half",
+    var_filters: Sequence[str] | None = None,
+    n_pcs: int = 15,
+    knn_neighbors: int = 100,
+    overwrite: bool = True,
+    threads: int = 8,
+) -> "ad.AnnData":
+    """Compute PCA, neighbors, and UMAP embeddings.
+
+    Args:
+        adata: AnnData object to update.
+        layer: Layer name to use for PCA/UMAP (``None`` uses ``adata.X``).
+        var_filters: Optional list of var masks to subset features.
+        n_pcs: Number of principal components.
+        knn_neighbors: Number of neighbors for the graph.
+        overwrite: Whether to recompute embeddings if they exist.
+        threads: Number of OMP threads for computation.
+
+    Returns:
+        anndata.AnnData: Updated AnnData object.
+    """
     import os
+
+    import numpy as np
+    import scanpy as sc
     from scipy.sparse import issparse
 
     os.environ["OMP_NUM_THREADS"] = str(threads)
@@ -10,32 +45,38 @@ def calculate_umap(adata, layer='nan_half', var_filters=None, n_pcs=15, knn_neig
     if var_filters:
         subset_mask = np.logical_or.reduce([adata.var[f].values for f in var_filters])
         adata_subset = adata[:, subset_mask].copy()
-        print(f"Subsetting adata: Retained {adata_subset.shape[1]} features based on filters {var_filters}")
+        logger.info(
+            "Subsetting adata: retained %s features based on filters %s",
+            adata_subset.shape[1],
+            var_filters,
+        )
     else:
         adata_subset = adata.copy()
-        print("No var filters provided. Using all features.")
+        logger.info("No var filters provided. Using all features.")
 
     # Step 2: NaN handling inside layer
     if layer:
         data = adata_subset.layers[layer]
         if not issparse(data):
             if np.isnan(data).any():
-                print("NaNs detected, filling with 0.5 before PCA + neighbors.")
+                logger.warning("NaNs detected, filling with 0.5 before PCA + neighbors.")
                 data = np.nan_to_num(data, nan=0.5)
                 adata_subset.layers[layer] = data
             else:
-                print("No NaNs detected.")
+                logger.info("No NaNs detected.")
         else:
-            print("Sparse matrix detected; skipping NaN check (sparse formats typically do not store NaNs).")
+            logger.info(
+                "Sparse matrix detected; skipping NaN check (sparse formats typically do not store NaNs)."
+            )
 
     # Step 3: PCA + neighbors + UMAP on subset
     if "X_umap" not in adata_subset.obsm or overwrite:
         n_pcs = min(adata_subset.shape[1], n_pcs)
-        print(f"Running PCA with n_pcs={n_pcs}")
+        logger.info("Running PCA with n_pcs=%s", n_pcs)
         sc.pp.pca(adata_subset, layer=layer)
-        print('Running neighborhood graph')
+        logger.info("Running neighborhood graph")
         sc.pp.neighbors(adata_subset, use_rep="X_pca", n_pcs=n_pcs, n_neighbors=knn_neighbors)
-        print('Running UMAP')
+        logger.info("Running UMAP")
         sc.tl.umap(adata_subset)
 
     # Step 4: Store results in original adata
@@ -45,7 +86,6 @@ def calculate_umap(adata, layer='nan_half', var_filters=None, n_pcs=15, knn_neig
     adata.obsp["connectivities"] = adata_subset.obsp["connectivities"]
     adata.uns["neighbors"] = adata_subset.uns["neighbors"]
 
-
     # Fix varm["PCs"] shape mismatch
     pc_matrix = np.zeros((adata.shape[1], adata_subset.varm["PCs"].shape[1]))
     if var_filters:
@@ -56,7 +96,6 @@ def calculate_umap(adata, layer='nan_half', var_filters=None, n_pcs=15, knn_neig
 
     adata.varm["PCs"] = pc_matrix
 
+    logger.info("Stored: adata.obsm['X_pca'] and adata.obsm['X_umap']")
 
-    print(f"Stored: adata.obsm['X_pca'] and adata.obsm['X_umap']")
-
-    return adata
+    return adata
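The refactor above preserves calculate_umap's behavior (PCA on an optional layer, neighbors graph, UMAP, results copied back onto the full object) while adding type hints, lazy imports, and logging. A usage sketch against the new 0.2.5 signature; the data, the nan_half layer contents, and the GpC_site var mask below are synthetic placeholders, not part of the package:

import anndata as ad
import numpy as np

from smftools.tools.calculate_umap import calculate_umap

rng = np.random.default_rng(0)
adata = ad.AnnData(X=rng.random((200, 50)))
adata.layers["nan_half"] = adata.X.copy()      # stand-in for the real layer
adata.var["GpC_site"] = rng.random(50) > 0.5   # hypothetical boolean var mask

adata = calculate_umap(adata, layer="nan_half", var_filters=["GpC_site"], n_pcs=10)
print(adata.obsm["X_umap"].shape)              # embeddings written back to the full object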
smftools/tools/cluster_adata_on_methylation.py
@@ -1,35 +1,50 @@
+from __future__ import annotations
+
 # cluster_adata_on_methylation
+from typing import TYPE_CHECKING, Sequence
 
-def cluster_adata_on_methylation(adata, obs_columns, method='hierarchical', n_clusters=3, layer=None, site_types = ['GpC_site', 'CpG_site']):
-    """
-    Adds cluster groups to the adata object as an observation column
-
-    Parameters:
-        adata
-        obs_columns
-        method
-        n_clusters
-        layer
-        site_types
-
-    Returns:
-        None
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def cluster_adata_on_methylation(
+    adata: "ad.AnnData",
+    obs_columns: Sequence[str],
+    method: str = "hierarchical",
+    n_clusters: int = 3,
+    layer: str | None = None,
+    site_types: Sequence[str] = ("GpC_site", "CpG_site"),
+) -> None:
+    """Add clustering groups to ``adata.obs`` based on methylation patterns.
+
+    Args:
+        adata: AnnData object to annotate.
+        obs_columns: Observation columns to define subgroups.
+        method: Clustering method (``"hierarchical"`` or ``"kmeans"``).
+        n_clusters: Number of clusters for k-means.
+        layer: Layer to use for clustering.
+        site_types: Site types to analyze.
     """
-    import pandas as pd
     import numpy as np
-    from . import subset_adata
+    import pandas as pd
+
     from ..readwrite import adata_to_df
+    from . import subset_adata
 
     # Ensure obs_columns are categorical
     for col in obs_columns:
-        adata.obs[col] = adata.obs[col].astype('category')
+        adata.obs[col] = adata.obs[col].astype("category")
 
-    references = adata.obs['Reference'].cat.categories
+    references = adata.obs["Reference"].cat.categories
 
     # Add subset metadata to the adata
     subset_adata(adata, obs_columns)
 
-    subgroup_name = '_'.join(obs_columns)
+    subgroup_name = "_".join(obs_columns)
     subgroups = adata.obs[subgroup_name].cat.categories
 
     subgroup_to_reference_map = {}
@@ -40,66 +55,120 @@ def cluster_adata_on_methylation(adata, obs_columns, method='hierarchical', n_cl
         else:
             pass
 
-    if method == 'hierarchical':
+    if method == "hierarchical":
         for site_type in site_types:
-            adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'] = pd.Series(-1, index=adata.obs_names, dtype=int)
-    elif method == 'kmeans':
+            adata.obs[
+                f"{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}"
+            ] = pd.Series(-1, index=adata.obs_names, dtype=int)
+    elif method == "kmeans":
         for site_type in site_types:
-            adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = pd.Series(-1, index=adata.obs_names, dtype=int)
-
+            adata.obs[f"{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}"] = (
+                pd.Series(-1, index=adata.obs_names, dtype=int)
+            )
+
     for subgroup in subgroups:
         subgroup_subset = adata[adata.obs[subgroup_name] == subgroup].copy()
         reference = subgroup_to_reference_map[subgroup]
         for site_type in site_types:
-            site_subset = subgroup_subset[:, np.array(subgroup_subset.var[f'{reference}_{site_type}'])].copy()
+            site_subset = subgroup_subset[
+                :, np.array(subgroup_subset.var[f"{reference}_{site_type}"])
+            ].copy()
             df = adata_to_df(site_subset, layer=layer)
             df2 = df.reset_index(drop=True)
-            if method == 'hierarchical':
+            if method == "hierarchical":
                 try:
-                    from scipy.cluster.hierarchy import linkage, dendrogram
+                    from scipy.cluster.hierarchy import dendrogram, linkage
+
                     # Perform hierarchical clustering on rows using the average linkage method and Euclidean metric
-                    row_linkage = linkage(df2.values, method='average', metric='euclidean')
+                    row_linkage = linkage(df2.values, method="average", metric="euclidean")
 
                     # Generate the dendrogram to get the ordered indices
                     dendro = dendrogram(row_linkage, no_plot=True)
-                    reordered_row_indices = np.array(dendro['leaves']).astype(int)
+                    reordered_row_indices = np.array(dendro["leaves"]).astype(int)
 
                     # Get the reordered observation names
                     reordered_obs_names = [df.index[i] for i in reordered_row_indices]
 
-                    temp_obs_data = pd.DataFrame({f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}': np.arange(0, len(reordered_obs_names), 1)}, index=reordered_obs_names, dtype=int)
+                    temp_obs_data = pd.DataFrame(
+                        {
+                            f"{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}": np.arange(
+                                0, len(reordered_obs_names), 1
+                            )
+                        },
+                        index=reordered_obs_names,
+                        dtype=int,
+                    )
                     adata.obs.update(temp_obs_data)
-                except:
-                    print(f'Error found in {subgroup} of {site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}')
-            elif method == 'kmeans':
+                except Exception:
+                    logger.exception(
+                        "Error found in %s of %s_%s_hierarchical_clustering_index_within_%s",
+                        subgroup,
+                        site_type,
+                        layer,
+                        subgroup_name,
+                    )
+            elif method == "kmeans":
                 try:
                     from sklearn.cluster import KMeans
+
                     kmeans = KMeans(n_clusters=n_clusters)
                     kmeans.fit(site_subset.layers[layer])
                     # Get the cluster labels for each data point
                     cluster_labels = kmeans.labels_
                     # Add the kmeans cluster data as an observation to the anndata object
-                    site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = cluster_labels.astype(str)
+                    site_subset.obs[
+                        f"{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}"
+                    ] = cluster_labels.astype(str)
                     # Calculate the mean of each observation categoty of each cluster
-                    cluster_means = site_subset.obs.groupby(f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}').mean()
+                    cluster_means = site_subset.obs.groupby(
+                        f"{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}"
+                    ).mean()
                     # Sort the cluster indices by mean methylation value
-                    sorted_clusters = cluster_means.sort_values(by=f'{site_type}_row_methylation_means', ascending=False).index
+                    sorted_clusters = cluster_means.sort_values(
+                        by=f"{site_type}_row_methylation_means", ascending=False
+                    ).index
                     # Create a mapping of the old cluster values to the new cluster values
                     sorted_cluster_mapping = {old: new for new, old in enumerate(sorted_clusters)}
                     # Apply the mapping to create a new observation value: kmeans_labels_reordered
-                    site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'].map(sorted_cluster_mapping)
-                    temp_obs_data = pd.DataFrame({f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}': site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}']}, index=site_subset.obs_names, dtype=int)
-                    adata.obs.update(temp_obs_data)
-                except:
-                    print(f'Error found in {subgroup} of {site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}')
+                    site_subset.obs[
+                        f"{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}"
+                    ] = site_subset.obs[
+                        f"{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}"
+                    ].map(sorted_cluster_mapping)
+                    temp_obs_data = pd.DataFrame(
+                        {
+                            f"{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}": site_subset.obs[
+                                f"{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}"
+                            ]
+                        },
+                        index=site_subset.obs_names,
+                        dtype=int,
+                    )
+                    adata.obs.update(temp_obs_data)
+                except Exception:
+                    logger.exception(
+                        "Error found in %s of %s_%s_kmeans_clustering_index_within_%s",
+                        subgroup,
+                        site_type,
+                        layer,
+                        subgroup_name,
+                    )
 
-    if method == 'hierarchical':
+    if method == "hierarchical":
         # Ensure that the observation values are type int
         for site_type in site_types:
-            adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'] = adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'].astype(int)
-    elif method == 'kmeans':
+            adata.obs[
+                f"{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}"
+            ] = adata.obs[
+                f"{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}"
+            ].astype(int)
+    elif method == "kmeans":
         # Ensure that the observation values are type int
         for site_type in site_types:
-            adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'].astype(int)
+            adata.obs[f"{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}"] = (
+                adata.obs[
+                    f"{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}"
+                ].astype(int)
+            )
 
-    return None
+    return None
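One detail worth noting in the kmeans branch above: clusters are renumbered by mean methylation, so index 0 always denotes the most methylated cluster, via the sorted_cluster_mapping dict built from the sorted per-cluster means. The same idiom in isolation, with toy values:

import pandas as pd

labels = pd.Series(["2", "0", "2", "1"])            # raw KMeans labels, stored as str
means = pd.Series({"0": 0.1, "1": 0.9, "2": 0.5})   # per-cluster mean methylation
sorted_clusters = means.sort_values(ascending=False).index
sorted_cluster_mapping = {old: new for new, old in enumerate(sorted_clusters)}
print(labels.map(sorted_cluster_mapping).astype(int).tolist())  # [1, 2, 1, 0]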
smftools/tools/general_tools.py
@@ -1,14 +1,48 @@
-def create_nan_mask_from_X(adata, new_layer_name="nan_mask"):
-    """
-    Generates a nan mask where 1 = NaN in adata.X and 0 = valid value.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+from smftools.logging_utils import get_logger
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+logger = get_logger(__name__)
+
+
+def create_nan_mask_from_X(adata: "ad.AnnData", new_layer_name: str = "nan_mask") -> "ad.AnnData":
+    """Generate a NaN mask layer from ``adata.X``.
+
+    Args:
+        adata: AnnData object.
+        new_layer_name: Name of the output mask layer.
+
+    Returns:
+        anndata.AnnData: Updated AnnData object.
     """
     import numpy as np
+
     nan_mask = np.isnan(adata.X).astype(int)
     adata.layers[new_layer_name] = nan_mask
-    print(f"Created '{new_layer_name}' layer based on NaNs in adata.X")
+    logger.info("Created '%s' layer based on NaNs in adata.X", new_layer_name)
     return adata
 
-def create_nan_or_non_gpc_mask(adata, obs_column, new_layer_name="nan_or_non_gpc_mask"):
+
+def create_nan_or_non_gpc_mask(
+    adata: "ad.AnnData",
+    obs_column: str,
+    new_layer_name: str = "nan_or_non_gpc_mask",
+) -> "ad.AnnData":
+    """Generate a mask layer combining NaNs and non-GpC positions.
+
+    Args:
+        adata: AnnData object.
+        obs_column: Obs column used to derive reference-specific GpC masks.
+        new_layer_name: Name of the output mask layer.
+
+    Returns:
+        anndata.AnnData: Updated AnnData object.
+    """
     import numpy as np
 
     nan_mask = np.isnan(adata.X).astype(int)
@@ -22,30 +56,37 @@ def create_nan_or_non_gpc_mask(adata, obs_column, new_layer_name="nan_or_non_gpc
     mask = np.maximum(nan_mask, combined_mask)
     adata.layers[new_layer_name] = mask
 
-    print(f"Created '{new_layer_name}' layer based on NaNs in adata.X and non-GpC regions using {obs_column}")
+    logger.info(
+        "Created '%s' layer based on NaNs in adata.X and non-GpC regions using %s",
+        new_layer_name,
+        obs_column,
+    )
     return adata
 
-def combine_layers(adata, input_layers, output_layer, negative_mask=None, values=None, binary_mode=False):
-    """
-    Combines layers into a single layer with specific coding:
-    - Background stays 0
-    - If binary_mode=True: any overlap = 1
-    - If binary_mode=False:
-        - Defaults to [1, 2, 3, ...] if values=None
-        - Later layers take precedence in overlaps
-
-    Parameters:
-        adata: AnnData object
-        input_layers: list of str
-        output_layer: str, name of the output layer
-        negative_mask: str (optional), binary mask to enforce 0s
-        values: list of ints (optional), values to assign to each input layer
-        binary_mode: bool, if True, creates a simple 0/1 mask regardless of values
-
+
+def combine_layers(
+    adata: "ad.AnnData",
+    input_layers: Sequence[str],
+    output_layer: str,
+    negative_mask: str | None = None,
+    values: Sequence[int] | None = None,
+    binary_mode: bool = False,
+) -> "ad.AnnData":
+    """Combine layers into a single coded layer.
+
+    Args:
+        adata: AnnData object.
+        input_layers: Input layer names.
+        output_layer: Name of the output layer.
+        negative_mask: Optional binary mask layer to enforce zeros.
+        values: Values assigned to each input layer when ``binary_mode`` is ``False``.
+        binary_mode: Whether to build a simple 0/1 mask.
+
     Returns:
-        Updated AnnData with new layer.
+        anndata.AnnData: Updated AnnData object.
     """
     import numpy as np
+
     combined = np.zeros_like(adata.layers[input_layers[0]])
 
     if binary_mode:
@@ -64,6 +105,10 @@ def combine_layers(adata, input_layers, output_layer, negative_mask=None, values
         combined[mask == 0] = 0
 
     adata.layers[output_layer] = combined
-    print(f"Combined layers into {output_layer} {'(binary)' if binary_mode else f'with values {values}'}")
+    logger.info(
+        "Combined layers into %s %s",
+        output_layer,
+        "(binary)" if binary_mode else f"with values {values}",
+    )
 
     return adata
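Per combine_layers' docstring (the removed version spells this out), background stays 0, each input layer is coded 1, 2, ... by default, and later layers take precedence on overlaps; the assignment loop itself falls outside these hunks. A toy usage sketch under those documented semantics; the layer names are illustrative and the import path is assumed from the file list above:

import anndata as ad
import numpy as np

from smftools.tools.general_tools import combine_layers  # assumed module path

adata = ad.AnnData(X=np.zeros((2, 4)))
adata.layers["accessible"] = np.array([[1, 1, 0, 0], [0, 1, 1, 0]])
adata.layers["nucleosome"] = np.array([[0, 1, 1, 0], [0, 0, 1, 1]])

adata = combine_layers(adata, ["accessible", "nucleosome"], "combined")
# Expected per the docstring: accessible -> 1, nucleosome -> 2, overlaps keep 2
print(adata.layers["combined"])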