smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/tools/cluster_adata_on_methylation.py
@@ -0,0 +1,105 @@
+ # cluster_adata_on_methylation
+
+ def cluster_adata_on_methylation(adata, obs_columns, method='hierarchical', n_clusters=3, layer=None, site_types=['GpC_site', 'CpG_site']):
+     """
+     Adds cluster group indices to the adata object as observation columns.
+
+     Parameters:
+         adata (AnnData): AnnData object to cluster.
+         obs_columns (list of str): obs columns used to define the subgroups that are clustered independently.
+         method (str): Clustering method, either 'hierarchical' or 'kmeans'.
+         n_clusters (int): Number of clusters to use for kmeans.
+         layer (str): Layer of the AnnData object to cluster on.
+         site_types (list of str): Site type suffixes of the var columns to cluster on (e.g. 'GpC_site', 'CpG_site').
+
+     Returns:
+         None
+     """
+     import pandas as pd
+     import numpy as np
+     from . import subset_adata
+     from ..readwrite import adata_to_df
+
+     # Ensure obs_columns are categorical
+     for col in obs_columns:
+         adata.obs[col] = adata.obs[col].astype('category')
+
+     references = adata.obs['Reference'].cat.categories
+
+     # Add subset metadata to the adata
+     subset_adata(adata, obs_columns)
+
+     subgroup_name = '_'.join(obs_columns)
+     subgroups = adata.obs[subgroup_name].cat.categories
+
+     # Map each subgroup to the reference whose name it contains
+     subgroup_to_reference_map = {}
+     for subgroup in subgroups:
+         for reference in references:
+             if reference in subgroup:
+                 subgroup_to_reference_map[subgroup] = reference
+
+     if method == 'hierarchical':
+         for site_type in site_types:
+             adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'] = pd.Series(-1, index=adata.obs_names, dtype=int)
+     elif method == 'kmeans':
+         for site_type in site_types:
+             adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = pd.Series(-1, index=adata.obs_names, dtype=int)
+
+     for subgroup in subgroups:
+         subgroup_subset = adata[adata.obs[subgroup_name] == subgroup].copy()
+         reference = subgroup_to_reference_map[subgroup]
+         for site_type in site_types:
+             site_subset = subgroup_subset[:, np.array(subgroup_subset.var[f'{reference}_{site_type}'])].copy()
+             df = adata_to_df(site_subset, layer=layer)
+             df2 = df.reset_index(drop=True)
+             if method == 'hierarchical':
+                 try:
+                     from scipy.cluster.hierarchy import linkage, dendrogram
+                     # Perform hierarchical clustering on rows using the average linkage method and Euclidean metric
+                     row_linkage = linkage(df2.values, method='average', metric='euclidean')
+
+                     # Generate the dendrogram to get the ordered indices
+                     dendro = dendrogram(row_linkage, no_plot=True)
+                     reordered_row_indices = np.array(dendro['leaves']).astype(int)
+
+                     # Get the reordered observation names
+                     reordered_obs_names = [df.index[i] for i in reordered_row_indices]
+
+                     temp_obs_data = pd.DataFrame({f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}': np.arange(0, len(reordered_obs_names), 1)}, index=reordered_obs_names, dtype=int)
+                     adata.obs.update(temp_obs_data)
+                 except Exception as e:
+                     print(f'Error found in {subgroup} of {site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}: {e}')
+             elif method == 'kmeans':
+                 try:
+                     from sklearn.cluster import KMeans
+                     kmeans = KMeans(n_clusters=n_clusters)
+                     kmeans.fit(site_subset.layers[layer])
+                     # Get the cluster labels for each data point
+                     cluster_labels = kmeans.labels_
+                     # Add the kmeans cluster data as an observation to the anndata object
+                     site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = cluster_labels.astype(str)
+                     # Calculate the mean of each observation category within each cluster
+                     cluster_means = site_subset.obs.groupby(f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}').mean()
+                     # Sort the cluster indices by mean methylation value
+                     sorted_clusters = cluster_means.sort_values(by=f'{site_type}_row_methylation_means', ascending=False).index
+                     # Create a mapping of the old cluster values to the new cluster values
+                     sorted_cluster_mapping = {old: new for new, old in enumerate(sorted_clusters)}
+                     # Apply the mapping to create a new observation value: kmeans_labels_reordered
+                     site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'].map(sorted_cluster_mapping)
+                     temp_obs_data = pd.DataFrame({f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}': site_subset.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}']}, index=site_subset.obs_names, dtype=int)
+                     adata.obs.update(temp_obs_data)
+                 except Exception as e:
+                     print(f'Error found in {subgroup} of {site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}: {e}')
+
+     if method == 'hierarchical':
+         # Ensure that the observation values are type int
+         for site_type in site_types:
+             adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'] = adata.obs[f'{site_type}_{layer}_hierarchical_clustering_index_within_{subgroup_name}'].astype(int)
+     elif method == 'kmeans':
+         # Ensure that the observation values are type int
+         for site_type in site_types:
+             adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'] = adata.obs[f'{site_type}_{layer}_kmeans_clustering_index_within_{subgroup_name}'].astype(int)
+
+     return None
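A minimal usage sketch for the new clustering helper, assuming an AnnData whose obs contains a 'Reference' column plus the columns passed in obs_columns, whose var contains boolean '<reference>_GpC_site' / '<reference>_CpG_site' columns, and whose chosen layer holds binarized methylation calls; the file, column, and layer names below are hypothetical:

# Usage sketch (hypothetical file, column, and layer names)
import anndata as ad
from smftools.tools.cluster_adata_on_methylation import cluster_adata_on_methylation

adata = ad.read_h5ad("experiment.h5ad")
cluster_adata_on_methylation(
    adata,
    obs_columns=["Reference", "Sample"],
    method="hierarchical",
    layer="binarized_methylation",
)
# Per-read ordering lands in, e.g.,
# adata.obs["GpC_site_binarized_methylation_hierarchical_clustering_index_within_Reference_Sample"]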
smftools/tools/data/__init__.py
@@ -0,0 +1,2 @@
+ from .anndata_data_module import AnnDataModule
+ from .preprocessing import random_fill_nans
smftools/tools/data/anndata_data_module.py
@@ -0,0 +1,90 @@
+ import torch
+ from torch.utils.data import DataLoader, TensorDataset, random_split
+ import pytorch_lightning as pl
+ import numpy as np
+ import pandas as pd
+
+ class AnnDataModule(pl.LightningDataModule):
+     def __init__(self, adata, tensor_source="X", tensor_key=None, label_col="labels",
+                  batch_size=64, train_frac=0.7, random_seed=42, split_col='train_val_split', split_save_path=None, load_existing_split=False,
+                  inference_mode=False):
+         super().__init__()
+         self.adata = adata  # The adata object
+         self.tensor_source = tensor_source  # X, layers, obsm
+         self.tensor_key = tensor_key  # name of the layer or obsm key
+         self.label_col = label_col  # name of the label column in obs
+         self.batch_size = batch_size
+         self.train_frac = train_frac
+         self.random_seed = random_seed
+         self.split_col = split_col  # Name of obs column to store "train"/"val"
+         self.split_save_path = split_save_path  # Where to save the obs_names and train/test split logging
+         self.load_existing_split = load_existing_split  # Whether to load from an existing split
+         self.inference_mode = inference_mode  # Whether to load the AnnDataModule in inference mode.
+
+     def setup(self, stage=None):
+         # Load feature matrix
+         if self.tensor_source == "X":
+             X = self.adata.X
+         elif self.tensor_source == "layers":
+             assert self.tensor_key in self.adata.layers, f"Layer '{self.tensor_key}' not found."
+             X = self.adata.layers[self.tensor_key]
+         elif self.tensor_source == "obsm":
+             assert self.tensor_key in self.adata.obsm, f"obsm key '{self.tensor_key}' not found."
+             X = self.adata.obsm[self.tensor_key]
+         else:
+             raise ValueError(f"Invalid tensor_source: {self.tensor_source}")
+
+         # Convert to tensor
+         X_tensor = torch.tensor(X, dtype=torch.float32)
+
+         if self.inference_mode:
+             self.infer_dataset = TensorDataset(X_tensor)
+
+         else:
+             # Load and encode labels
+             y = self.adata.obs[self.label_col]
+             if y.dtype.name == 'category':
+                 y = y.cat.codes
+             y_tensor = torch.tensor(y.values, dtype=torch.long)
+
+             # Use existing split
+             if self.load_existing_split:
+                 split_df = pd.read_csv(self.split_save_path, index_col=0)
+                 assert self.split_col in split_df.columns, f"'{self.split_col}' column missing in split file."
+                 self.adata.obs[self.split_col] = split_df.loc[self.adata.obs_names][self.split_col].values
+
+             # If no split exists, create one
+             if self.split_col not in self.adata.obs:
+                 full_dataset = TensorDataset(X_tensor, y_tensor)
+                 n_train = int(self.train_frac * len(full_dataset))
+                 n_val = len(full_dataset) - n_train
+                 self.train_set, self.val_set = random_split(
+                     full_dataset, [n_train, n_val],
+                     generator=torch.Generator().manual_seed(self.random_seed)
+                 )
+                 # Assign split labels
+                 split_array = np.full(len(self.adata), "val", dtype=object)
+                 train_idx = self.train_set.indices if hasattr(self.train_set, "indices") else self.train_set._indices
+                 split_array[train_idx] = "train"
+                 self.adata.obs[self.split_col] = split_array
+
+                 # Save to disk
+                 if self.split_save_path:
+                     self.adata.obs[[self.split_col]].to_csv(self.split_save_path)
+             else:
+                 split_labels = self.adata.obs[self.split_col].values
+                 train_mask = split_labels == "train"
+                 val_mask = split_labels == "val"
+                 self.train_set = TensorDataset(X_tensor[train_mask], y_tensor[train_mask])
+                 self.val_set = TensorDataset(X_tensor[val_mask], y_tensor[val_mask])
+
+     def train_dataloader(self):
+         return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
+
+     def val_dataloader(self):
+         return DataLoader(self.val_set, batch_size=self.batch_size)
+
+     def predict_dataloader(self):
+         if not self.inference_mode:
+             raise RuntimeError("predict_dataloader only available in inference mode.")
+         return DataLoader(self.infer_dataset, batch_size=self.batch_size)
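A training-split usage sketch for AnnDataModule, assuming a dense layer and a categorical label column; the layer and label names are hypothetical:

# Usage sketch (hypothetical layer and label names)
from smftools.tools.data import AnnDataModule

dm = AnnDataModule(
    adata,
    tensor_source="layers",
    tensor_key="binarized_methylation",
    label_col="condition",
    batch_size=64,
    train_frac=0.7,
    split_save_path="train_val_split.csv",
)
dm.setup()                           # builds train/val sets and records the split in obs and the CSV
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()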
smftools/tools/data/preprocessing.py
@@ -0,0 +1,6 @@
+ import numpy as np
+
+ def random_fill_nans(X):
+     nan_mask = np.isnan(X)
+     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
+     return X
smftools/tools/display_hmm.py
@@ -0,0 +1,18 @@
+ def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"]):
+     import torch
+     print("\n🔹 **HMM Model Overview**")
+     print(hmm)
+
+     print("\n🔹 **Transition Matrix**")
+     transition_matrix = torch.exp(hmm.edges).detach().cpu().numpy()
+     for i, row in enumerate(transition_matrix):
+         label = state_labels[i] if state_labels else f"State {i}"
+         formatted_row = ", ".join(f"{p:.6f}" for p in row)
+         print(f"{label}: [{formatted_row}]")
+
+     print("\n🔹 **Emission Probabilities**")
+     for i, dist in enumerate(hmm.distributions):
+         label = state_labels[i] if state_labels else f"State {i}"
+         probs = dist.probs.detach().cpu().numpy()
+         formatted_emissions = {obs_labels[j]: probs[j] for j in range(len(probs))}
+         print(f"{label}: {formatted_emissions}")
smftools/tools/general_tools.py
@@ -0,0 +1,69 @@
+ def create_nan_mask_from_X(adata, new_layer_name="nan_mask"):
+     """
+     Generates a nan mask where 1 = NaN in adata.X and 0 = valid value.
+     """
+     import numpy as np
+     nan_mask = np.isnan(adata.X).astype(int)
+     adata.layers[new_layer_name] = nan_mask
+     print(f"✅ Created '{new_layer_name}' layer based on NaNs in adata.X")
+     return adata
+
+ def create_nan_or_non_gpc_mask(adata, obs_column, new_layer_name="nan_or_non_gpc_mask"):
+     """
+     Generates a mask where 1 = NaN in adata.X or a non-GpC position for the read's reference (taken from obs_column).
+     """
+     import numpy as np
+
+     nan_mask = np.isnan(adata.X).astype(int)
+     combined_mask = np.zeros_like(nan_mask)
+
+     for idx, row in enumerate(adata.obs.itertuples()):
+         ref = getattr(row, obs_column)
+         gpc_mask = adata.var[f"{ref}_GpC_site"].astype(int).values
+         combined_mask[idx, :] = 1 - gpc_mask  # non-GpC is 1
+
+     mask = np.maximum(nan_mask, combined_mask)
+     adata.layers[new_layer_name] = mask
+
+     print(f"✅ Created '{new_layer_name}' layer based on NaNs in adata.X and non-GpC regions using {obs_column}")
+     return adata
+
+ def combine_layers(adata, input_layers, output_layer, negative_mask=None, values=None, binary_mode=False):
+     """
+     Combines layers into a single layer with specific coding:
+     - Background stays 0
+     - If binary_mode=True: any overlap = 1
+     - If binary_mode=False:
+         - Defaults to [1, 2, 3, ...] if values=None
+         - Later layers take precedence in overlaps
+
+     Parameters:
+         adata: AnnData object
+         input_layers: list of str
+         output_layer: str, name of the output layer
+         negative_mask: str (optional), binary mask to enforce 0s
+         values: list of ints (optional), values to assign to each input layer
+         binary_mode: bool, if True, creates a simple 0/1 mask regardless of values
+
+     Returns:
+         Updated AnnData with new layer.
+     """
+     import numpy as np
+     combined = np.zeros_like(adata.layers[input_layers[0]])
+
+     if binary_mode:
+         for layer in input_layers:
+             combined = np.logical_or(combined, adata.layers[layer] > 0)
+         combined = combined.astype(int)
+     else:
+         if values is None:
+             values = list(range(1, len(input_layers) + 1))
+         for i, layer in enumerate(input_layers):
+             arr = adata.layers[layer]
+             combined[arr > 0] = values[i]
+
+     if negative_mask:
+         mask = adata.layers[negative_mask]
+         combined[mask == 0] = 0
+
+     adata.layers[output_layer] = combined
+     print(f"✅ Combined layers into {output_layer} {'(binary)' if binary_mode else f'with values {values}'}")
+
+     return adata
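A sketch of how these masking helpers might be chained on an AnnData object `adata`; the feature layer names are hypothetical:

# Usage sketch (hypothetical layer names)
from smftools.tools.general_tools import (
    create_nan_mask_from_X,
    create_nan_or_non_gpc_mask,
    combine_layers,
)

create_nan_mask_from_X(adata, new_layer_name="nan_mask")
create_nan_or_non_gpc_mask(adata, obs_column="Reference", new_layer_name="nan_or_non_gpc_mask")
combine_layers(
    adata,
    input_layers=["accessible_patches", "nucleosome_footprints"],
    output_layer="combined_features",
    values=[1, 2],  # later layers take precedence in overlaps
)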
smftools/tools/hmm_readwrite.py
@@ -0,0 +1,16 @@
+ def load_hmm(model_path, device='cpu'):
+     """
+     Reads in a pretrained HMM.
+
+     Parameters:
+         model_path (str): Path to a pretrained HMM
+     """
+     import torch
+     # Load model using PyTorch
+     hmm = torch.load(model_path)
+     hmm.to(device)
+     return hmm
+
+ def save_hmm(model, model_path):
+     import torch
+     torch.save(model, model_path)
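For example, a previously trained HMM (e.g. one produced by tools/train_hmm.py) could be reloaded and inspected with the display helper above; the model paths are hypothetical:

# Usage sketch (hypothetical model paths)
from smftools.tools.hmm_readwrite import load_hmm, save_hmm
from smftools.tools.display_hmm import display_hmm

hmm = load_hmm("trained_smf_hmm.pt", device="cpu")
display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=["0", "1"])
save_hmm(hmm, "trained_smf_hmm_copy.pt")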
smftools/tools/inference/__init__.py
@@ -0,0 +1 @@
+ from .lightning_inference import run_lightning_inference
smftools/tools/inference/lightning_inference.py
@@ -0,0 +1,41 @@
+ import torch
+ import pandas as pd
+ import numpy as np
+ from pytorch_lightning import Trainer
+
+ def run_lightning_inference(
+     adata,
+     model,
+     datamodule,
+     label_col="labels",
+     prefix="model"
+ ):
+
+     # Get class labels
+     if label_col in adata.obs and pd.api.types.is_categorical_dtype(adata.obs[label_col]):
+         class_labels = adata.obs[label_col].cat.categories.tolist()
+     else:
+         raise ValueError("label_col must be a categorical column in adata.obs")
+
+     # Run predictions
+     trainer = Trainer(accelerator="auto", devices=1, logger=False, enable_checkpointing=False)
+     preds = trainer.predict(model, datamodule=datamodule)
+     probs = torch.cat(preds, dim=0).cpu().numpy()  # (N, C)
+     pred_class_idx = probs.argmax(axis=1)
+     pred_class_labels = [class_labels[i] for i in pred_class_idx]
+     pred_class_probs = probs[np.arange(len(probs)), pred_class_idx]
+
+     # Construct full prefix with label_col
+     full_prefix = f"{prefix}_{label_col}"
+
+     # Store predictions in obs
+     adata.obs[f"{full_prefix}_pred"] = pred_class_idx
+     adata.obs[f"{full_prefix}_pred_label"] = pd.Categorical(pred_class_labels, categories=class_labels)
+     adata.obs[f"{full_prefix}_pred_prob"] = pred_class_probs
+
+     # Per-class probabilities
+     for i, class_name in enumerate(class_labels):
+         adata.obs[f"{full_prefix}_prob_{class_name}"] = probs[:, i]
+
+     # Full probability matrix in obsm
+     adata.obsm[f"{full_prefix}_pred_prob_all"] = probs
smftools/tools/models/__init__.py
@@ -0,0 +1,9 @@
+ from .base import BaseTorchModel
+ from .mlp import MLPClassifier
+ from .cnn import CNNClassifier
+ from .rnn import RNNClassifier
+ from .transformer import BaseTransformer, TransformerClassifier, DANNTransformerClassifier, MaskedTransformerPretrainer
+ from .positional import PositionalEncoding
+ from .wrappers import ScaledModel
+ from .lightning_base import TorchClassifierWrapper
+ from .sklearn_models import SklearnModelWrapper
smftools/tools/models/base.py
@@ -0,0 +1,14 @@
+ import torch.nn as nn
+ from ..utils.device import detect_device
+
+ class BaseTorchModel(nn.Module):
+     """
+     Minimal base class for torch models that:
+     - Stores device
+     - Moves model to detected device on init
+     """
+     def __init__(self, dropout_rate=0.2):
+         super().__init__()
+         self.device = detect_device()  # detects available devices
+         self.dropout_rate = dropout_rate  # default dropout rate to be used in regularization.
+         self.to(self.device)  # move model to device
smftools/tools/models/cnn.py
@@ -0,0 +1,34 @@
+ import torch
+ import torch.nn as nn
+ from .base import BaseTorchModel
+
+ class CNNClassifier(BaseTorchModel):
+     def __init__(self, input_size, num_classes, **kwargs):
+         super().__init__(**kwargs)
+         # Define convolutional layers
+         self.conv1 = nn.Conv1d(1, 16, kernel_size=3, padding=1)
+         self.conv2 = nn.Conv1d(16, 32, kernel_size=3, padding=1)
+         # Define activation function
+         self.relu = nn.ReLU()
+
+         # Determine the flattened size dynamically
+         dummy_input = torch.zeros(1, 1, input_size).to(self.device)
+         with torch.no_grad():
+             dummy_output = self._forward_conv(dummy_input)
+         flattened_size = dummy_output.view(1, -1).shape[1]
+
+         # Define fully connected layers
+         self.fc1 = nn.Linear(flattened_size, 64)
+         self.fc2 = nn.Linear(64, num_classes)
+
+     def _forward_conv(self, x):
+         x = self.relu(self.conv1(x))
+         x = self.relu(self.conv2(x))
+         return x
+
+     def forward(self, x):
+         x = x.unsqueeze(1)  # [B, 1, L]
+         x = self._forward_conv(x)
+         x = x.view(x.size(0), -1)  # flatten
+         x = self.relu(self.fc1(x))
+         return self.fc2(x)
smftools/tools/models/lightning_base.py
@@ -0,0 +1,41 @@
+ import torch
+ import pytorch_lightning as pl
+
+ class TorchClassifierWrapper(pl.LightningModule):
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         optimizer_cls=torch.optim.AdamW,
+         optimizer_kwargs=None,
+         criterion_cls=torch.nn.CrossEntropyLoss,
+         criterion_kwargs=None,
+         lr: float = 1e-3,
+     ):
+         super().__init__()
+         self.model = model
+         self.save_hyperparameters(ignore=['model'])  # logs all except actual model instance
+         self.optimizer_cls = optimizer_cls
+         self.optimizer_kwargs = optimizer_kwargs or {}
+         self.criterion = criterion_cls(**(criterion_kwargs or {}))
+         self.lr = lr
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.criterion(logits, y)
+         self.log("train_loss", loss, prog_bar=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.criterion(logits, y)
+         acc = (logits.argmax(dim=1) == y).float().mean()
+         self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=True)
+         return loss
+
+     def configure_optimizers(self):
+         return self.optimizer_cls(self.parameters(), lr=self.lr, **self.optimizer_kwargs)
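A training sketch tying the wrapper to the data module and an MLP; the label column is hypothetical, `dm` is the AnnDataModule from the earlier sketch, and the resulting `trained_model` is the one assumed in the inference sketch above:

# Usage sketch (hypothetical label column; dm is the AnnDataModule from the earlier sketch)
import pytorch_lightning as pl
from smftools.tools.models import MLPClassifier, TorchClassifierWrapper

n_features = adata.shape[1]
n_classes = adata.obs["condition"].cat.categories.size

trained_model = TorchClassifierWrapper(
    MLPClassifier(input_dim=n_features, num_classes=n_classes),
    lr=1e-3,
)
trainer = pl.Trainer(max_epochs=20, accelerator="auto", devices=1)
trainer.fit(trained_model, datamodule=dm)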
smftools/tools/models/mlp.py
@@ -0,0 +1,17 @@
+ import torch
+ import torch.nn as nn
+ from .base import BaseTorchModel
+
+ class MLPClassifier(BaseTorchModel):
+     def __init__(self, input_dim, num_classes, hidden_sizes=(128, 64), **kwargs):
+         super().__init__(**kwargs)
+         layers = []
+         prev = input_dim
+         for h in hidden_sizes:
+             layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(self.dropout_rate)])
+             prev = h
+         layers.append(nn.Linear(prev, num_classes))
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.model(x)
smftools/tools/models/positional.py
@@ -0,0 +1,17 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super().__init__()
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len).unsqueeze(1).float()
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.pe = pe.unsqueeze(0)  # (1, max_len, d_model)
+
+     def forward(self, x):
+         x = x + self.pe[:, :x.size(1)].to(x.device)
+         return x
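A small shape sketch for PositionalEncoding, with hypothetical dimensions:

# Usage sketch (hypothetical dimensions)
import torch
from smftools.tools.models import PositionalEncoding

pos_enc = PositionalEncoding(d_model=64, max_len=1000)
tokens = torch.randn(8, 200, 64)   # [batch, sequence length, embedding dim]
tokens = pos_enc(tokens)           # adds sinusoidal positional information; shape stays [8, 200, 64]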
smftools/tools/models/rnn.py
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+ from .base import BaseTorchModel
+
+ class RNNClassifier(BaseTorchModel):
+     def __init__(self, input_size, hidden_dim, num_classes, **kwargs):
+         super().__init__(**kwargs)
+         # Define LSTM layer
+         self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_dim, batch_first=True)
+         # Define fully connected output layer
+         self.fc = nn.Linear(hidden_dim, num_classes)
+
+     def forward(self, x):
+         x = x.unsqueeze(1)  # [B, 1, L] → for LSTM expecting batch_first
+         _, (h_n, _) = self.lstm(x)  # h_n: [1, B, H]
+         return self.fc(h_n.squeeze(0))  # [B, H] → [B, num_classes]
smftools/tools/models/sklearn_models.py
@@ -0,0 +1,40 @@
+ from sklearn.ensemble import RandomForestClassifier
+ from sklearn.naive_bayes import GaussianNB
+ from sklearn.metrics import (
+     roc_curve, precision_recall_curve, auc, f1_score, confusion_matrix
+ )
+ from sklearn.utils.class_weight import compute_class_weight
+
+ import numpy as np
+
+ class SklearnModelWrapper:
+     def __init__(self, model):
+         self.model = model
+
+     def fit(self, X_train, y_train):
+         self.model.fit(X_train, y_train)
+
+     def predict(self, X):
+         return self.model.predict(X)
+
+     def predict_proba(self, X):
+         return self.model.predict_proba(X)
+
+     def evaluate(self, X_test, y_test):
+         probs = self.predict_proba(X_test)[:, 1]
+         preds = self.predict(X_test)
+
+         fpr, tpr, _ = roc_curve(y_test, probs)
+         precision, recall, _ = precision_recall_curve(y_test, probs)
+         f1 = f1_score(y_test, preds)
+         auc_score = auc(fpr, tpr)
+         pr_auc = auc(recall, precision)
+         cm = confusion_matrix(y_test, preds)
+         pos_freq = np.mean(y_test == 1)
+         pr_auc_norm = pr_auc / pos_freq
+
+         return {
+             "fpr": fpr, "tpr": tpr, "precision": precision, "recall": recall,
+             "f1": f1, "auc": auc_score, "pr_auc": pr_auc,
+             "pr_auc_norm": pr_auc_norm, "confusion_matrix": cm
+         }
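A sketch of the scikit-learn path for a binary task, reusing the hypothetical layer/label names from the earlier sketches (evaluate assumes binary labels, since it reads the probability of class 1):

# Usage sketch (hypothetical layer and label names; binary classification)
from sklearn.ensemble import RandomForestClassifier
from smftools.tools.data import random_fill_nans
from smftools.tools.models import SklearnModelWrapper

X = random_fill_nans(adata.layers["binarized_methylation"].astype(float))
y = adata.obs["condition"].cat.codes.values
train = (adata.obs["train_val_split"] == "train").values

rf = SklearnModelWrapper(RandomForestClassifier(n_estimators=200, random_state=0))
rf.fit(X[train], y[train])
metrics = rf.evaluate(X[~train], y[~train])
print(metrics["auc"], metrics["f1"], metrics["pr_auc_norm"])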