smftools 0.1.3__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. smftools/__init__.py +5 -1
  2. smftools/_version.py +1 -1
  3. smftools/informatics/__init__.py +2 -0
  4. smftools/informatics/archived/print_bam_query_seq.py +29 -0
  5. smftools/informatics/basecall_pod5s.py +80 -0
  6. smftools/informatics/conversion_smf.py +63 -10
  7. smftools/informatics/direct_smf.py +66 -18
  8. smftools/informatics/helpers/LoadExperimentConfig.py +1 -0
  9. smftools/informatics/helpers/__init__.py +16 -2
  10. smftools/informatics/helpers/align_and_sort_BAM.py +27 -16
  11. smftools/informatics/helpers/aligned_BAM_to_bed.py +49 -48
  12. smftools/informatics/helpers/bam_qc.py +66 -0
  13. smftools/informatics/helpers/binarize_converted_base_identities.py +69 -21
  14. smftools/informatics/helpers/canoncall.py +12 -3
  15. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +5 -4
  16. smftools/informatics/helpers/converted_BAM_to_adata.py +34 -22
  17. smftools/informatics/helpers/converted_BAM_to_adata_II.py +369 -0
  18. smftools/informatics/helpers/demux_and_index_BAM.py +52 -0
  19. smftools/informatics/helpers/extract_base_identities.py +33 -46
  20. smftools/informatics/helpers/extract_mods.py +55 -23
  21. smftools/informatics/helpers/extract_read_features_from_bam.py +31 -0
  22. smftools/informatics/helpers/extract_read_lengths_from_bed.py +25 -0
  23. smftools/informatics/helpers/find_conversion_sites.py +33 -44
  24. smftools/informatics/helpers/generate_converted_FASTA.py +87 -86
  25. smftools/informatics/helpers/modcall.py +13 -5
  26. smftools/informatics/helpers/modkit_extract_to_adata.py +762 -396
  27. smftools/informatics/helpers/ohe_batching.py +65 -41
  28. smftools/informatics/helpers/ohe_layers_decode.py +32 -0
  29. smftools/informatics/helpers/one_hot_decode.py +27 -0
  30. smftools/informatics/helpers/one_hot_encode.py +45 -9
  31. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +1 -0
  32. smftools/informatics/helpers/run_multiqc.py +28 -0
  33. smftools/informatics/helpers/split_and_index_BAM.py +3 -8
  34. smftools/informatics/load_adata.py +58 -3
  35. smftools/plotting/__init__.py +15 -0
  36. smftools/plotting/classifiers.py +355 -0
  37. smftools/plotting/general_plotting.py +205 -0
  38. smftools/plotting/position_stats.py +462 -0
  39. smftools/preprocessing/__init__.py +6 -7
  40. smftools/preprocessing/append_C_context.py +22 -9
  41. smftools/preprocessing/{mark_duplicates.py → archives/mark_duplicates.py} +38 -26
  42. smftools/preprocessing/binarize_on_Youden.py +35 -32
  43. smftools/preprocessing/binary_layers_to_ohe.py +13 -3
  44. smftools/preprocessing/calculate_complexity.py +3 -2
  45. smftools/preprocessing/calculate_converted_read_methylation_stats.py +44 -46
  46. smftools/preprocessing/calculate_coverage.py +26 -25
  47. smftools/preprocessing/calculate_pairwise_differences.py +49 -0
  48. smftools/preprocessing/calculate_position_Youden.py +18 -7
  49. smftools/preprocessing/calculate_read_length_stats.py +39 -46
  50. smftools/preprocessing/clean_NaN.py +33 -25
  51. smftools/preprocessing/filter_adata_by_nan_proportion.py +31 -0
  52. smftools/preprocessing/filter_converted_reads_on_methylation.py +20 -5
  53. smftools/preprocessing/filter_reads_on_length.py +14 -4
  54. smftools/preprocessing/flag_duplicate_reads.py +149 -0
  55. smftools/preprocessing/invert_adata.py +18 -11
  56. smftools/preprocessing/load_sample_sheet.py +30 -16
  57. smftools/preprocessing/recipes.py +22 -20
  58. smftools/preprocessing/subsample_adata.py +58 -0
  59. smftools/readwrite.py +105 -13
  60. smftools/tools/__init__.py +49 -0
  61. smftools/tools/apply_hmm.py +202 -0
  62. smftools/tools/apply_hmm_batched.py +241 -0
  63. smftools/tools/archived/classify_methylated_features.py +66 -0
  64. smftools/tools/archived/classify_non_methylated_features.py +75 -0
  65. smftools/tools/archived/subset_adata_v1.py +32 -0
  66. smftools/tools/archived/subset_adata_v2.py +46 -0
  67. smftools/tools/calculate_distances.py +18 -0
  68. smftools/tools/calculate_umap.py +62 -0
  69. smftools/tools/call_hmm_peaks.py +105 -0
  70. smftools/tools/classifiers.py +787 -0
  71. smftools/tools/cluster_adata_on_methylation.py +105 -0
  72. smftools/tools/data/__init__.py +2 -0
  73. smftools/tools/data/anndata_data_module.py +90 -0
  74. smftools/tools/data/preprocessing.py +6 -0
  75. smftools/tools/display_hmm.py +18 -0
  76. smftools/tools/general_tools.py +69 -0
  77. smftools/tools/hmm_readwrite.py +16 -0
  78. smftools/tools/inference/__init__.py +1 -0
  79. smftools/tools/inference/lightning_inference.py +41 -0
  80. smftools/tools/models/__init__.py +9 -0
  81. smftools/tools/models/base.py +14 -0
  82. smftools/tools/models/cnn.py +34 -0
  83. smftools/tools/models/lightning_base.py +41 -0
  84. smftools/tools/models/mlp.py +17 -0
  85. smftools/tools/models/positional.py +17 -0
  86. smftools/tools/models/rnn.py +16 -0
  87. smftools/tools/models/sklearn_models.py +40 -0
  88. smftools/tools/models/transformer.py +133 -0
  89. smftools/tools/models/wrappers.py +20 -0
  90. smftools/tools/nucleosome_hmm_refinement.py +104 -0
  91. smftools/tools/position_stats.py +239 -0
  92. smftools/tools/read_stats.py +70 -0
  93. smftools/tools/subset_adata.py +19 -23
  94. smftools/tools/train_hmm.py +78 -0
  95. smftools/tools/training/__init__.py +1 -0
  96. smftools/tools/training/train_lightning_model.py +47 -0
  97. smftools/tools/utils/__init__.py +2 -0
  98. smftools/tools/utils/device.py +10 -0
  99. smftools/tools/utils/grl.py +14 -0
  100. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/METADATA +47 -11
  101. smftools-0.1.7.dist-info/RECORD +136 -0
  102. smftools/tools/apply_HMM.py +0 -1
  103. smftools/tools/read_HMM.py +0 -1
  104. smftools/tools/train_HMM.py +0 -43
  105. smftools-0.1.3.dist-info/RECORD +0 -84
  106. /smftools/preprocessing/{remove_duplicates.py → archives/remove_duplicates.py} +0 -0
  107. /smftools/tools/{cluster.py → evaluation/__init__.py} +0 -0
  108. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/WHEEL +0 -0
  109. {smftools-0.1.3.dist-info → smftools-0.1.7.dist-info}/licenses/LICENSE +0 -0
smftools/tools/models/transformer.py
@@ -0,0 +1,133 @@
+ import torch
+ import torch.nn as nn
+ from .base import BaseTorchModel
+ from .positional import PositionalEncoding
+ from ..utils.grl import grad_reverse
+
+
+ class BaseTransformer(BaseTorchModel):
+     def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, seq_len=None, use_learnable_pos=False, **kwargs):
+         super().__init__(**kwargs)
+         # Input FC layer to map D_input to D_model
+         self.input_fc = nn.Linear(input_dim, model_dim)
+
+         if use_learnable_pos:
+             assert seq_len is not None, "Must provide seq_len if use_learnable_pos=True"
+             self.pos_embed = nn.Parameter(torch.randn(seq_len, model_dim))  # (S, D)
+             self.pos_encoder = None
+         else:
+             self.pos_encoder = PositionalEncoding(model_dim)
+             self.pos_embed = None
+
+         # Specify the transformer encoder structure
+         encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, batch_first=False)
+         # Stack the transformer encoder layers
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+     def encode(self, x, mask=None):
+         """
+         x: (B, S, D_input)
+         mask: (B, S) optional
+         """
+         x = self.input_fc(x)  # (B, S, D_model)
+         if self.pos_embed is not None:
+             x = x + self.pos_embed.unsqueeze(0).to(x.device)  # (B, S, D_model)
+         elif self.pos_encoder is not None:
+             x = self.pos_encoder(x)  # (B, S, D_model)
+         if mask is not None:
+             x = x * mask.unsqueeze(-1)  # (B, S, D_model)
+         x = x.permute(1, 0, 2)  # (S, B, D_model)
+         encoded = self.transformer(x)  # (S, B, D_model)
+         return encoded.permute(1, 0, 2)  # (B, S, D_model)
+
+ class TransformerClassifier(BaseTransformer):
+     def __init__(self, input_dim, model_dim, num_classes, num_heads=4, num_layers=2, **kwargs):
+         super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)
+         # Classification head
+         self.cls_head = nn.Linear(model_dim, num_classes)
+
+     def forward(self, x):
+         """
+         x: (batch, seq_len, input_dim)
+         """
+         encoded = self.encode(x)  # -> (B, S, D_model)
+         pooled = encoded.mean(dim=1)  # -> (B, D_model)
+         return self.cls_head(pooled)  # -> (B, C)
+
+ class DANNTransformerClassifier(TransformerClassifier):
+     def __init__(self, input_dim, model_dim, num_classes, n_domains, **kwargs):
+         super().__init__(input_dim, model_dim, num_classes, **kwargs)
+         self.domain_classifier = nn.Sequential(
+             nn.Linear(model_dim, 128),
+             nn.ReLU(),
+             nn.Linear(128, n_domains)
+         )
+
+     def forward(self, x, alpha=1.0):
+         encoded = self.encode(x)  # (B, S, D_model)
+         pooled = encoded.mean(dim=1)  # (B, D_model)
+
+         class_logits = self.cls_head(pooled)
+         domain_logits = self.domain_classifier(grad_reverse(pooled, alpha))
+
+         return class_logits, domain_logits
+
+ class MaskedTransformerPretrainer(BaseTransformer):
+     def __init__(self, input_dim, model_dim, num_heads=4, num_layers=2, **kwargs):
+         super().__init__(input_dim, model_dim, num_heads, num_layers, **kwargs)
+         self.decoder = nn.Linear(model_dim, input_dim)
+
+     def forward(self, x, mask):
+         """
+         x: (batch, seq_len, input_dim)
+         mask: (batch, seq_len) optional
+         """
+         if x.dim() == 2:
+             x = x.unsqueeze(-1)
+         encoded = self.encode(x, mask=mask)  # -> (B, S, D_model)
+         return self.decoder(encoded)  # -> (B, D_input)
+
+ class DANNTransformer(BaseTransformer):
+     """
+     """
+     def __init__(self, seq_len, model_dim, n_heads, n_layers, n_domains):
+         super().__init__(
+             input_dim=1,  # 1D scalar input per token
+             model_dim=model_dim,
+             num_heads=n_heads,
+             num_layers=n_layers,
+             seq_len=seq_len,
+             use_learnable_pos=True  # enables learnable pos_embed in base
+         )
+
+         # Reconstruction head
+         self.recon_head = nn.Linear(model_dim, 1)
+
+         # Domain classification head
+         self.domain_classifier = nn.Sequential(
+             nn.Linear(model_dim, 128),
+             nn.ReLU(),
+             nn.Linear(128, n_domains)
+         )
+
+     def forward(self, x, alpha=1.0):
+         """
+         x: Tensor of shape (B, S) or (B, S, 1)
+         alpha: GRL coefficient (float)
+         """
+         if x.dim() == 2:
+             x = x.unsqueeze(-1)  # (B, S, 1)
+
+         # Encode sequence
+         h = self.encode(x)  # (B, S, D_model)
+
+         # Head 1: Reconstruction
+         recon = self.recon_head(h).squeeze(-1)  # (B, S)
+
+         # Head 2: Domain classification via GRL
+         pooled = h.mean(dim=1)  # (B, D_model)
+         rev = grad_reverse(pooled, alpha)
+         domain_logits = self.domain_classifier(rev)  # (B, n_batches)
+
+         return recon, domain_logits
+
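
Below is a minimal usage sketch for the new transformer models (not part of the diff). The input shapes, hidden sizes, and class/domain counts are illustrative assumptions, as is the premise that BaseTorchModel and PositionalEncoding need no extra constructor arguments.

import torch
from smftools.tools.models.transformer import TransformerClassifier, DANNTransformer

batch, seq_len = 8, 512
reads = torch.rand(batch, seq_len, 1)  # (B, S, D_input), e.g. a per-position methylation signal

# Sequence-level classification: mean-pool the encoded positions, then project to class logits.
clf = TransformerClassifier(input_dim=1, model_dim=64, num_classes=2)
logits = clf(reads)  # (B, 2)

# Domain-adversarial reconstruction model with a learnable positional embedding.
dann = DANNTransformer(seq_len=seq_len, model_dim=64, n_heads=4, n_layers=2, n_domains=3)
recon, domain_logits = dann(reads, alpha=0.5)  # (B, S) reconstruction, (B, 3) domain logits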
smftools/tools/models/wrappers.py
@@ -0,0 +1,20 @@
+ import torch
+ import torch.nn as nn
+
+ class ScaledModel(nn.Module):
+     def __init__(self, model, mean, std):
+         super().__init__()
+         self.model = model
+         self.register_buffer("mean", torch.tensor(mean, dtype=torch.float32))
+         self.register_buffer("std", torch.tensor(std, dtype=torch.float32))
+
+     def forward(self, x):
+         mean = self.mean.to(x.device)
+         std = torch.clamp(self.std.to(x.device), min=1e-8)
+         if x.dim() == 2:
+             x = (x - mean) / std
+         elif x.dim() == 3:
+             x = (x - mean[None, None, :]) / std[None, None, :]
+         else:
+             raise ValueError(f"Unsupported input shape {x.shape}")
+         return self.model(x)
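
A short sketch of how the ScaledModel wrapper might be used (not part of the diff); the wrapped network and the mean/std values are placeholders.

import torch
import torch.nn as nn
from smftools.tools.models.wrappers import ScaledModel

base = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
# Per-feature standardization happens inside forward(), so callers pass raw inputs.
scaled = ScaledModel(base, mean=[0.1, 0.2, 0.3, 0.4], std=[1.0, 2.0, 1.5, 0.5])

x = torch.rand(32, 4)   # (batch, features)
logits = scaled(x)      # (32, 2)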
smftools/tools/nucleosome_hmm_refinement.py
@@ -0,0 +1,104 @@
+ def refine_nucleosome_calls(adata, layer_name, nan_mask_layer, hexamer_size=120, octamer_size=147, max_wiggle=40, device="cpu"):
+     import numpy as np
+
+     nucleosome_layer = adata.layers[layer_name]
+     nan_mask = adata.layers[nan_mask_layer]  # should be binary mask: 1 = nan region, 0 = valid data
+
+     hexamer_layer = np.zeros_like(nucleosome_layer)
+     octamer_layer = np.zeros_like(nucleosome_layer)
+
+     for read_idx, row in enumerate(nucleosome_layer):
+         in_patch = False
+         start_idx = None
+
+         for pos in range(len(row)):
+             if row[pos] == 1 and not in_patch:
+                 in_patch = True
+                 start_idx = pos
+             if (row[pos] == 0 or pos == len(row) - 1) and in_patch:
+                 in_patch = False
+                 end_idx = pos if row[pos] == 0 else pos + 1
+
+                 # Expand boundaries into NaNs
+                 left_expand = 0
+                 right_expand = 0
+
+                 # Left
+                 for i in range(1, max_wiggle + 1):
+                     if start_idx - i >= 0 and nan_mask[read_idx, start_idx - i] == 1:
+                         left_expand += 1
+                     else:
+                         break
+                 # Right
+                 for i in range(1, max_wiggle + 1):
+                     if end_idx + i < nucleosome_layer.shape[1] and nan_mask[read_idx, end_idx + i] == 1:
+                         right_expand += 1
+                     else:
+                         break
+
+                 expanded_start = start_idx - left_expand
+                 expanded_end = end_idx + right_expand
+
+                 available_size = expanded_end - expanded_start
+
+                 # Octamer placement
+                 if available_size >= octamer_size:
+                     center = (expanded_start + expanded_end) // 2
+                     half_oct = octamer_size // 2
+                     octamer_layer[read_idx, center - half_oct: center - half_oct + octamer_size] = 1
+
+                 # Hexamer placement
+                 elif available_size >= hexamer_size:
+                     center = (expanded_start + expanded_end) // 2
+                     half_hex = hexamer_size // 2
+                     hexamer_layer[read_idx, center - half_hex: center - half_hex + hexamer_size] = 1
+
+     adata.layers[f"{layer_name}_hexamers"] = hexamer_layer
+     adata.layers[f"{layer_name}_octamers"] = octamer_layer
+
+     print(f"✅ Added layers: {layer_name}_hexamers and {layer_name}_octamers")
+     return adata
+
+ def infer_nucleosomes_in_large_bound(adata, large_bound_layer, combined_nuc_layer, nan_mask_layer, nuc_size=147, linker_size=50, exclusion_buffer=30, device="cpu"):
+     import numpy as np
+
+     large_bound = adata.layers[large_bound_layer]
+     existing_nucs = adata.layers[combined_nuc_layer]
+     nan_mask = adata.layers[nan_mask_layer]
+
+     inferred_layer = np.zeros_like(large_bound)
+
+     for read_idx, row in enumerate(large_bound):
+         in_patch = False
+         start_idx = None
+
+         for pos in range(len(row)):
+             if row[pos] == 1 and not in_patch:
+                 in_patch = True
+                 start_idx = pos
+             if (row[pos] == 0 or pos == len(row) - 1) and in_patch:
+                 in_patch = False
+                 end_idx = pos if row[pos] == 0 else pos + 1
+
+                 # Adjust boundaries into flanking NaN regions without getting too close to existing nucleosomes
+                 left_expand = start_idx
+                 while left_expand > 0 and nan_mask[read_idx, left_expand - 1] == 1 and np.sum(existing_nucs[read_idx, max(0, left_expand - exclusion_buffer):left_expand]) == 0:
+                     left_expand -= 1
+
+                 right_expand = end_idx
+                 while right_expand < row.shape[0] and nan_mask[read_idx, right_expand] == 1 and np.sum(existing_nucs[read_idx, right_expand:min(row.shape[0], right_expand + exclusion_buffer)]) == 0:
+                     right_expand += 1
+
+                 # Phase nucleosomes with linker spacing only
+                 region = (left_expand, right_expand)
+                 pos_cursor = region[0]
+                 while pos_cursor + nuc_size <= region[1]:
+                     if np.all((existing_nucs[read_idx, pos_cursor - exclusion_buffer:pos_cursor + nuc_size + exclusion_buffer] == 0)):
+                         inferred_layer[read_idx, pos_cursor:pos_cursor + nuc_size] = 1
+                         pos_cursor += nuc_size + linker_size
+                     else:
+                         pos_cursor += 1
+
+     adata.layers[f"{large_bound_layer}_phased_nucleosomes"] = inferred_layer
+     print(f"✅ Added layer: {large_bound_layer}_phased_nucleosomes")
+     return adata
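
A hypothetical call sequence for the nucleosome refinement helpers (not part of the diff); the layer names are invented and must match whatever binary HMM-call and NaN-mask layers already exist in the AnnData object.

from smftools.tools.nucleosome_hmm_refinement import (
    refine_nucleosome_calls,
    infer_nucleosomes_in_large_bound,
)

# adata: an existing AnnData with binary per-read footprint layers.
# Snap raw HMM footprint calls to hexamer/octamer-sized particles.
adata = refine_nucleosome_calls(
    adata,
    layer_name="hmm_nucleosome_calls",   # hypothetical binary footprint layer
    nan_mask_layer="nan_mask",           # 1 = missing signal, 0 = informative position
)

# Phase additional nucleosomes into long protected patches, avoiding existing calls.
adata = infer_nucleosomes_in_large_bound(
    adata,
    large_bound_layer="hmm_large_bound_calls",
    combined_nuc_layer="hmm_nucleosome_calls_octamers",
    nan_mask_layer="nan_mask",
)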
smftools/tools/position_stats.py
@@ -0,0 +1,239 @@
+ # ------------------------- Utilities -------------------------
+ def random_fill_nans(X):
+     import numpy as np
+     nan_mask = np.isnan(X)
+     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
+     return X
+
+ def calculate_relative_risk_on_activity(adata, sites, alpha=0.05, groupby=None):
+     """
+     Perform Bayesian-style methylation vs activity analysis independently within each group.
+
+     Parameters:
+         adata (AnnData): Annotated data matrix.
+         sites (list of str): List of site keys (e.g., ['GpC_site', 'CpG_site']).
+         alpha (float): FDR threshold for significance.
+         groupby (str or list of str): Column(s) in adata.obs to group by.
+
+     Returns:
+         results_dict (dict): Dictionary with structure:
+             results_dict[ref][group_label] = (results_df, sig_df)
+     """
+     import numpy as np
+     import pandas as pd
+     from scipy.stats import fisher_exact
+     from statsmodels.stats.multitest import multipletests
+
+     def compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values):
+         p_adj = multipletests(p_values, method='fdr_bh')[1] if p_values else []
+
+         genomic_positions = np.array(site_subset.var_names)[positions_list]
+         is_gpc_site = site_subset.var[f"{ref}_GpC_site"].values[positions_list]
+         is_cpg_site = site_subset.var[f"{ref}_CpG_site"].values[positions_list]
+
+         results_df = pd.DataFrame({
+             'Feature_Index': positions_list,
+             'Genomic_Position': genomic_positions.astype(int),
+             'Relative_Risk': relative_risks,
+             'Adjusted_P_Value': p_adj,
+             'GpC_Site': is_gpc_site,
+             'CpG_Site': is_cpg_site
+         })
+
+         results_df['log2_Relative_Risk'] = np.log2(results_df['Relative_Risk'].replace(0, 1e-300))
+         results_df['-log10_Adj_P'] = -np.log10(results_df['Adjusted_P_Value'].replace(0, 1e-300))
+         sig_df = results_df[results_df['Adjusted_P_Value'] < alpha]
+         return results_df, sig_df
+
+     results_dict = {}
+
+     for ref in adata.obs['Reference_strand'].unique():
+         ref_subset = adata[adata.obs['Reference_strand'] == ref].copy()
+         if ref_subset.shape[0] == 0:
+             continue
+
+         # Normalize groupby to list
+         if groupby is not None:
+             if isinstance(groupby, str):
+                 groupby = [groupby]
+             def format_group_label(row):
+                 return ",".join([f"{col}={row[col]}" for col in groupby])
+
+             combined_label = '__'.join(groupby)
+             ref_subset.obs[combined_label] = ref_subset.obs.apply(format_group_label, axis=1)
+             groups = ref_subset.obs[combined_label].unique()
+         else:
+             combined_label = None
+             groups = ['all']
+
+         results_dict[ref] = {}
+
+         for group in groups:
+             if group == 'all':
+                 group_subset = ref_subset
+             else:
+                 group_subset = ref_subset[ref_subset.obs[combined_label] == group]
+
+             if group_subset.shape[0] == 0:
+                 continue
+
+             # Build site mask
+             site_mask = np.zeros(group_subset.shape[1], dtype=bool)
+             for site in sites:
+                 site_mask |= group_subset.var[f"{ref}_{site}"]
+             site_subset = group_subset[:, site_mask].copy()
+
+             # Matrix and labels
+             X = random_fill_nans(site_subset.X.copy())
+             y = site_subset.obs['activity_status'].map({'Active': 1, 'Silent': 0}).values
+             P_active = np.mean(y)
+
+             # Analysis
+             positions_list, relative_risks, p_values = [], [], []
+             for pos in range(X.shape[1]):
+                 methylation_state = (X[:, pos] > 0).astype(int)
+                 table = pd.crosstab(methylation_state, y)
+                 if table.shape != (2, 2):
+                     continue
+
+                 P_methylated = np.mean(methylation_state)
+                 P_methylated_given_active = np.mean(methylation_state[y == 1])
+                 P_methylated_given_inactive = np.mean(methylation_state[y == 0])
+
+                 if P_methylated_given_inactive == 0 or P_methylated in [0, 1]:
+                     continue
+
+                 P_active_given_methylated = (P_methylated_given_active * P_active) / P_methylated
+                 P_active_given_unmethylated = ((1 - P_methylated_given_active) * P_active) / (1 - P_methylated)
+                 RR = P_active_given_methylated / P_active_given_unmethylated
+
+                 _, p_value = fisher_exact(table)
+                 positions_list.append(pos)
+                 relative_risks.append(RR)
+                 p_values.append(p_value)
+
+             results_df, sig_df = compute_risk_df(ref, site_subset, positions_list, relative_risks, p_values)
+             results_dict[ref][group] = (results_df, sig_df)
+
+     return results_dict
+
+ def compute_positionwise_statistic(
+     adata,
+     layer,
+     method="pearson",
+     groupby=["Reference_strand"],
+     output_key="positionwise_result",
+     site_config=None,
+     encoding="signed",
+     max_threads=None
+ ):
+     """
+     Computes a position-by-position matrix (correlation, RR, or Chi-squared) from an adata layer.
+
+     Parameters:
+         adata (AnnData): Annotated data matrix.
+         layer (str): Name of the adata layer to use.
+         method (str): 'pearson', 'binary_covariance', 'relative_risk', or 'chi_squared'.
+         groupby (str or list): Column(s) in adata.obs to group by.
+         output_key (str): Key in adata.uns to store results.
+         site_config (dict): Optional {ref: [site_types]} to restrict sites per reference.
+         encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
+         max_threads (int): Number of parallel threads to use (joblib).
+     """
+     import numpy as np
+     import pandas as pd
+     from scipy.stats import chi2_contingency
+     from joblib import Parallel, delayed
+     from tqdm import tqdm
+
+     if isinstance(groupby, str):
+         groupby = [groupby]
+
+     adata.uns[output_key] = {}
+     adata.uns[output_key + "_n"] = {}
+
+     label_col = "__".join(groupby)
+     adata.obs[label_col] = adata.obs[groupby].astype(str).agg("_".join, axis=1)
+
+     for group in adata.obs[label_col].unique():
+         subset = adata[adata.obs[label_col] == group].copy()
+         if subset.shape[0] == 0:
+             continue
+
+         ref = subset.obs["Reference_strand"].unique()[0] if "Reference_strand" in groupby else None
+
+         if site_config and ref in site_config:
+             site_mask = np.zeros(subset.shape[1], dtype=bool)
+             for site in site_config[ref]:
+                 site_mask |= subset.var[f"{ref}_{site}"]
+             subset = subset[:, site_mask].copy()
+
+         X = subset.layers[layer].copy()
+
+         if encoding == "signed":
+             X_bin = np.where(X == 1, 1, np.where(X == -1, 0, np.nan))
+         else:
+             X_bin = np.where(X == 1, 1, np.where(X == 0, 0, np.nan))
+
+         n_pos = subset.shape[1]
+         mat = np.zeros((n_pos, n_pos))
+
+         if method == "pearson":
+             with np.errstate(invalid='ignore'):
+                 mat = np.corrcoef(np.nan_to_num(X_bin).T)
+
+         elif method == "binary_covariance":
+             binary = (X_bin == 1).astype(float)
+             valid = (X_bin == 1) | (X_bin == 0)  # Only consider true binary (ignore NaN)
+             valid = valid.astype(float)
+
+             numerator = np.dot(binary.T, binary)
+             denominator = np.dot(valid.T, valid)
+
+             with np.errstate(divide='ignore', invalid='ignore'):
+                 mat = np.true_divide(numerator, denominator)
+                 mat[~np.isfinite(mat)] = 0
+
+         elif method in ["relative_risk", "chi_squared"]:
+             def compute_row(i):
+                 row = np.zeros(n_pos)
+                 xi = X_bin[:, i]
+                 for j in range(n_pos):
+                     xj = X_bin[:, j]
+                     mask = ~np.isnan(xi) & ~np.isnan(xj)
+                     if np.sum(mask) < 10:
+                         row[j] = np.nan
+                         continue
+                     if method == "relative_risk":
+                         a = np.sum((xi[mask] == 1) & (xj[mask] == 1))
+                         b = np.sum((xi[mask] == 1) & (xj[mask] == 0))
+                         c = np.sum((xi[mask] == 0) & (xj[mask] == 1))
+                         d = np.sum((xi[mask] == 0) & (xj[mask] == 0))
+                         if (a + b) > 0 and (c + d) > 0 and c > 0:
+                             p1 = a / (a + b)
+                             p2 = c / (c + d)
+                             row[j] = p1 / p2 if p2 > 0 else np.nan
+                         else:
+                             row[j] = np.nan
+                     elif method == "chi_squared":
+                         table = pd.crosstab(xi[mask], xj[mask])
+                         if table.shape != (2, 2):
+                             row[j] = np.nan
+                         else:
+                             chi2, _, _, _ = chi2_contingency(table, correction=False)
+                             row[j] = chi2
+                 return row
+
+             mat = np.array(
+                 Parallel(n_jobs=max_threads)(
+                     delayed(compute_row)(i) for i in tqdm(range(n_pos), desc=f"{method}: {group}")
+                 )
+             )
+
+         else:
+             raise ValueError(f"Unsupported method: {method}")
+
+         var_names = subset.var_names.astype(int)
+         mat_df = pd.DataFrame(mat, index=var_names, columns=var_names)
+         adata.uns[output_key][group] = mat_df
+         adata.uns[output_key + "_n"][group] = subset.shape[0]
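
A minimal sketch of the two position-level analyses (not part of the diff). The layer name, groupby column, and site keys are assumptions; per the code above, the functions expect adata.obs['activity_status'] with 'Active'/'Silent' labels and per-reference site columns such as '<ref>_GpC_site' in adata.var.

from smftools.tools.position_stats import (
    calculate_relative_risk_on_activity,
    compute_positionwise_statistic,
)

# Per-position relative risk of activity given methylation, computed within each sample group.
rr_results = calculate_relative_risk_on_activity(
    adata, sites=["GpC_site", "CpG_site"], alpha=0.05, groupby="Sample"
)

# Position-by-position co-methylation matrix, stored in adata.uns["positionwise_result"].
compute_positionwise_statistic(
    adata,
    layer="binarized_methylation",   # hypothetical signed methylation layer (1/-1/0)
    method="binary_covariance",
    groupby=["Reference_strand"],
)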
smftools/tools/read_stats.py
@@ -0,0 +1,70 @@
+ # ------------------------- Utilities -------------------------
+ def random_fill_nans(X):
+     import numpy as np
+     nan_mask = np.isnan(X)
+     X[nan_mask] = np.random.rand(*X[nan_mask].shape)
+     return X
+
+ def calculate_row_entropy(
+     adata,
+     layer,
+     output_key="entropy",
+     site_config=None,
+     ref_col="Reference_strand",
+     encoding="signed",
+     max_threads=None):
+     """
+     Adds an obs column to the adata that calculates entropy within each read from a given layer
+     when looking at each site type passed in the site_config list.
+
+     Parameters:
+         adata (AnnData): The annotated data matrix.
+         layer (str): Name of the layer to use for entropy calculation.
+         method (str): Unused currently. Placeholder for potential future methods.
+         output_key (str): Base name for the entropy column in adata.obs.
+         site_config (dict): {ref: [site_types]} for masking relevant sites.
+         ref_col (str): Column in adata.obs denoting reference strands.
+         encoding (str): 'signed' (1/-1/0) or 'binary' (1/0/NaN).
+         max_threads (int): Number of threads for parallel processing.
+     """
+     import numpy as np
+     import pandas as pd
+     from scipy.stats import entropy
+     from joblib import Parallel, delayed
+     from tqdm import tqdm
+
+     entropy_values = []
+     row_indices = []
+
+     for ref in adata.obs[ref_col].cat.categories:
+         subset = adata[adata.obs[ref_col] == ref].copy()
+         if subset.shape[0] == 0:
+             continue
+
+         if site_config and ref in site_config:
+             site_mask = np.zeros(subset.shape[1], dtype=bool)
+             for site in site_config[ref]:
+                 site_mask |= subset.var[f"{ref}_{site}"]
+             subset = subset[:, site_mask].copy()
+
+         X = subset.layers[layer].copy()
+
+         if encoding == "signed":
+             X_bin = np.where(X == 1, 1, np.where(X == -1, 0, np.nan))
+         else:
+             X_bin = np.where(X == 1, 1, np.where(X == 0, 0, np.nan))
+
+         def compute_entropy(row):
+             counts = pd.Series(row).value_counts(dropna=True).sort_index()
+             probs = counts / counts.sum()
+             return entropy(probs, base=2)
+
+         entropies = Parallel(n_jobs=max_threads)(
+             delayed(compute_entropy)(X_bin[i, :]) for i in tqdm(range(X_bin.shape[0]), desc=f"Entropy: {ref}")
+         )
+
+         entropy_values.extend(entropies)
+         row_indices.extend(subset.obs_names.tolist())
+
+     entropy_key = f"{output_key}_entropy"
+     adata.obs.loc[row_indices, entropy_key] = entropy_values
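
A minimal sketch for the per-read entropy helper (not part of the diff); the layer name and site_config contents are hypothetical examples of the expected {reference: [site_type]} structure.

from smftools.tools.read_stats import calculate_row_entropy

# adata: an existing AnnData with a categorical Reference_strand obs column.
calculate_row_entropy(
    adata,
    layer="binarized_methylation",               # hypothetical signed methylation layer
    output_key="GpC",                            # result lands in adata.obs["GpC_entropy"]
    site_config={"plasmid_top": ["GpC_site"]},   # hypothetical reference/site mapping
    ref_col="Reference_strand",
    encoding="signed",
    max_threads=4,
)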
smftools/tools/subset_adata.py
@@ -1,32 +1,28 @@
  # subset_adata
 
- def subset_adata(adata, obs_columns):
+ def subset_adata(adata, columns, cat_type='obs'):
      """
-     Subsets an AnnData object based on categorical values in specified `.obs` columns.
+     Adds subset metadata to an AnnData object based on categorical values in specified .obs or .var columns.
 
      Parameters:
-         adata (AnnData): The AnnData object to subset.
-         obs_columns (list of str): List of `.obs` column names to subset by. The order matters.
+         adata (AnnData): The AnnData object to add subset metadata to.
+         columns (list of str): List of .obs or .var column names to subset by. The order matters.
+         cat_type (str): obs or var. Default is obs
 
      Returns:
-         dict: A dictionary where keys are tuples of category values and values are corresponding AnnData subsets.
+         None
      """
+     import pandas as pd
+     import anndata as ad
 
-     def subset_recursive(adata_subset, columns):
-         if not columns:
-             return {(): adata_subset}
-
-         current_column = columns[0]
-         categories = adata_subset.obs[current_column].cat.categories
-
-         subsets = {}
-         for cat in categories:
-             subset = adata_subset[adata_subset.obs[current_column] == cat]
-             subsets.update(subset_recursive(subset, columns[1:]))
-
-         return subsets
-
-     # Start the recursive subset process
-     subsets_dict = subset_recursive(adata, obs_columns)
-
-     return subsets_dict
+     subgroup_name = '_'.join(columns)
+     if 'obs' in cat_type:
+         df = adata.obs[columns]
+         adata.obs[subgroup_name] = df.apply(lambda row: '_'.join(row.astype(str)), axis=1)
+         adata.obs[subgroup_name] = adata.obs[subgroup_name].astype('category')
+     elif 'var' in cat_type:
+         df = adata.var[columns]
+         adata.var[subgroup_name] = df.apply(lambda row: '_'.join(row.astype(str)), axis=1)
+         adata.var[subgroup_name] = adata.var[subgroup_name].astype('category')
+
+     return None
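
A short sketch of the reworked subset_adata (not part of the diff); the column names are hypothetical. Unlike the 0.1.3 version, it no longer returns a dict of AnnData subsets; it annotates the object in place with a combined categorical column.

from smftools.tools.subset_adata import subset_adata

# Adds adata.obs["Sample_Reference_strand"] as a categorical column combining both fields.
subset_adata(adata, columns=["Sample", "Reference_strand"], cat_type="obs")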