tpcav-0.1.0-py3-none-any.whl

tpcav/__init__.py ADDED
@@ -0,0 +1,39 @@
+ """
+ Lightweight, reusable TCAV utilities built from the repository scripts.
+
+ The package keeps existing scripts untouched while offering programmatic
+ access to concept construction and PCA/attribution workflows.
+ """
+
+ import logging
+
+ # Set the logging level to INFO
+ logging.basicConfig(level=logging.INFO)
+
+ from .cavs import CavTrainer
+ from .concepts import ConceptBuilder
+ from .helper import (
+     bed_to_chrom_tracks_iter,
+     bed_to_fasta_iter,
+     dataframe_to_chrom_tracks_iter,
+     dataframe_to_fasta_iter,
+     dinuc_shuffle_sequences,
+     fasta_to_one_hot_sequences,
+     random_regions_dataframe,
+ )
+ from .logging_utils import set_verbose
+ from .tpcav_model import TPCAV
+
+ __all__ = [
+     "ConceptBuilder",
+     "CavTrainer",
+     "TPCAV",
+     "bed_to_fasta_iter",
+     "dataframe_to_fasta_iter",
+     "bed_to_chrom_tracks_iter",
+     "dataframe_to_chrom_tracks_iter",
+     "fasta_to_one_hot_sequences",
+     "random_regions_dataframe",
+     "dinuc_shuffle_sequences",
+     "set_verbose",
+ ]
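
For orientation, everything listed in `__all__` is importable from the package root. A minimal sketch; `set_verbose` is assumed to take no required arguments, which this diff does not show:

    from tpcav import CavTrainer, ConceptBuilder, TPCAV, set_verbose

    set_verbose()  # assumed: toggles verbose logging, defined in tpcav/logging_utils.py
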
tpcav/cavs.py ADDED
@@ -0,0 +1,334 @@
+ #!/usr/bin/env python3
+ """
+ CAV training and attribution utilities built on TPCAV.
+ """
+
+ import logging
+ import multiprocessing
+ from pathlib import Path
+ from typing import Iterable, List, Optional, Tuple
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import seaborn as sns
+ import torch
+ from sklearn.linear_model import SGDClassifier
+ from sklearn.metrics import precision_recall_fscore_support
+ from sklearn.metrics.pairwise import cosine_similarity
+ from sklearn.model_selection import GridSearchCV
+ from torch.utils.data import DataLoader, TensorDataset, random_split
+
+ from tpcav.tpcav_model import TPCAV
+
+ logger = logging.getLogger(__name__)
+
+
+ def _load_all_tensors_to_numpy(dataloaders: Iterable[DataLoader]):
+     """Concatenate activations and labels from one or more dataloaders into numpy arrays."""
+     if not isinstance(dataloaders, list):
+         dataloaders = [dataloaders]
+     avs, ls = [], []
+     for dataloader in dataloaders:
+         for av, l in dataloader:
+             avs.append(av.cpu().numpy())
+             ls.append(l.cpu().numpy())
+     return np.concatenate(avs), np.concatenate(ls)
+
+
+ class _SGDWrapper:
+     """Lightweight SGD concept classifier tuned with a small grid search."""
+
+     def __init__(self, penalty: str = "l2", n_jobs: int = -1):
+         self.lm = SGDClassifier(
+             max_iter=1000,
+             early_stopping=True,
+             validation_fraction=0.1,
+             learning_rate="optimal",
+             n_iter_no_change=10,
+             n_jobs=n_jobs,
+             penalty=penalty,
+         )
+         if penalty == "l2":
+             params = {"alpha": [1e-2, 1e-4, 1e-6]}
+         elif penalty == "l1":
+             params = {"alpha": [1e-1, 1]}
+         else:
+             raise ValueError(f"Unexpected penalty type {penalty}")
+         self.search = GridSearchCV(self.lm, params)
+
+     def fit(self, train_dl: DataLoader, val_dl: DataLoader):
+         # GridSearchCV performs its own cross-validation, so the training and
+         # validation splits are pooled before the search.
+         train_avs, train_ls = _load_all_tensors_to_numpy([train_dl, val_dl])
+         self.search.fit(train_avs, train_ls)
+         self.lm = self.search.best_estimator_
+         logger.info(
+             "Best Params: %s | Iterations: %s",
+             self.search.best_params_,
+             self.lm.n_iter_,
+         )
+
+     @property
+     def weights(self) -> torch.Tensor:
+         # sklearn's single binary coefficient vector points toward class 1;
+         # stacking its negation first makes row 0 the direction of class 0
+         # (the concept, under the label convention used in _train below).
+         if len(self.lm.coef_) == 1:
+             return torch.tensor(np.array([-1 * self.lm.coef_[0], self.lm.coef_[0]]))
+         return torch.tensor(self.lm.coef_)
+
+     @property
+     def classes_(self):
+         return self.lm.classes_
+
+     def predict(self, x: np.ndarray) -> np.ndarray:
+         return self.lm.predict(x)
+
+
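
A quick synthetic check (illustrative data and settings only) of the orientation convention behind `weights`: sklearn's binary `coef_` points toward class 1, so its negation points toward the concept class labeled 0:

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    rng = np.random.default_rng(0)
    x = np.vstack([rng.normal(-1, 1, (100, 8)), rng.normal(1, 1, (100, 8))])
    y = np.array([0] * 100 + [1] * 100)  # 0 = concept, 1 = control
    clf = SGDClassifier(max_iter=1000).fit(x, y)
    cav = -clf.coef_[0]  # row 0 of _SGDWrapper.weights: points toward the concept
    print(cav @ x[:100].mean(axis=0) > cav @ x[100:].mean(axis=0))  # expected: True
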
+ def _train(
+     concept_embeddings: torch.Tensor,
+     control_embeddings: torch.Tensor,
+     output_dir: str,
+     penalty: str = "l2",
+ ) -> Tuple[float, torch.Tensor]:
+     """
+     Train a binary CAV classifier for a concept against cached control embeddings.
+
+     The caller supplies the control embeddings; CavTrainer caches them via
+     set_control so they are not recomputed for every concept.
+     """
+     output_dir = Path(output_dir)
+
+     # Label convention: concept samples are class 0, control samples class 1.
+     avd = TensorDataset(
+         concept_embeddings, torch.full((concept_embeddings.shape[0],), 0)
+     )
+     cvd = TensorDataset(
+         control_embeddings, torch.full((control_embeddings.shape[0],), 1)
+     )
+     train_ds, val_ds, test_ds = random_split(avd, [0.8, 0.1, 0.1])
+     c_train, c_val, c_test = random_split(cvd, [0.8, 0.1, 0.1])
+
+     train_dl = DataLoader(train_ds + c_train, batch_size=32, shuffle=True)
+     val_dl = DataLoader(val_ds + c_val, batch_size=32)
+     test_dl = DataLoader(test_ds + c_test, batch_size=32)
+
+     clf = _SGDWrapper(penalty=penalty)
+     clf.fit(train_dl, val_dl)
+
+     def _eval(split_dl: DataLoader, name: str):
+         y_preds, y_trues = [], []
+         for x, y in split_dl:
+             y_pred = clf.predict(x.cpu().numpy())
+             y_preds.append(y_pred)
+             y_trues.append(y.cpu().numpy())
+         y_preds = np.concatenate(y_preds)
+         y_trues = np.concatenate(y_trues)
+         acc = (y_preds == y_trues).sum() / len(y_trues)
+         precision, recall, fscore, support = precision_recall_fscore_support(
+             y_trues, y_preds, average="binary", pos_label=1
+         )
+         logger.info("[%s] Accuracy: %.4f | F-score: %.4f", name, acc, fscore)
+         (output_dir / f"classifier_perform_on_{name}.txt").write_text(
+             f"Accuracy: {acc}\n"
+         )
+         return fscore
+
+     output_dir.mkdir(parents=True, exist_ok=True)
+     _eval(train_dl, "train")
+     _eval(val_dl, "val")
+     test_fscore = _eval(test_dl, "test")
+
+     weights = clf.weights
+     assert len(weights.shape) == 2 and weights.shape[0] == 2
+     torch.save(weights, output_dir / "classifier_weights.pt")
+
+     # Row 0 points toward the concept class (see _SGDWrapper.weights).
+     return test_fscore, weights[0]
+
+
+ class CavTrainer:
+     """Train CAVs and compute attribution-driven TCAV scores."""
+
+     def __init__(self, tpcav: TPCAV, penalty: str = "l2") -> None:
+         self.tpcav = tpcav
+         self.penalty = penalty
+         self.cavs_fscores = {}
+         self.cav_weights = {}
+         self.control_embeddings: Optional[torch.Tensor] = None
+         self.cavs_list: List[torch.Tensor] = []
+
+     def save_state(self, output_path: str = "cav_trainer_state.pt"):
+         """Save CavTrainer state to a file."""
+         state = {
+             "cavs_fscores": self.cavs_fscores,
+             "cav_weights": self.cav_weights,
+             "control_embeddings": self.control_embeddings,
+             "cavs_list": self.cavs_list,
+         }
+         torch.save(state, output_path)
+
+     def restore_state(self, input_path: str = "cav_trainer_state.pt"):
+         """Restore CavTrainer state from a file."""
+         state = torch.load(input_path, map_location="cpu")
+         self.cavs_fscores = state["cavs_fscores"]
+         self.cav_weights = state["cav_weights"]
+         self.control_embeddings = state["control_embeddings"]
+         self.cavs_list = state["cavs_list"]
+
+     def set_control(self, control_concept, num_samples: int) -> torch.Tensor:
+         """Set and cache control embeddings to avoid recomputation across CAV trainings."""
+         self.control_embeddings = self.tpcav.concept_embeddings(
+             control_concept, num_samples=num_samples
+         )
+         return self.control_embeddings
+
+     def train_concepts(
+         self,
+         concept_list,
+         num_samples: int,
+         output_dir: str,
+         num_processes: int = 1,
+     ):
+         """Train one CAV per concept against the cached control embeddings."""
+         if self.control_embeddings is None:
+             raise ValueError(
+                 "Call set_control(control_concept, num_samples=...) before training CAVs."
+             )
+
+         if num_processes == 1:
+             for c in concept_list:
+                 concept_embeddings = self.tpcav.concept_embeddings(
+                     c, num_samples=num_samples
+                 )
+                 fscore, weight = _train(
+                     concept_embeddings,
+                     self.control_embeddings,
+                     Path(output_dir) / c.name,
+                     self.penalty,
+                 )
+                 self.cavs_fscores[c.name] = fscore
+                 self.cav_weights[c.name] = weight
+                 self.cavs_list.append(weight)
+         else:
+             # Embeddings are computed serially (they may need the GPU); only
+             # the sklearn classifier training is parallelized across workers.
+             pool = multiprocessing.Pool(processes=num_processes)
+             results = []
+             for c in concept_list:
+                 concept_embeddings = self.tpcav.concept_embeddings(
+                     c, num_samples=num_samples
+                 )
+                 res = pool.apply_async(
+                     _train,
+                     args=(
+                         concept_embeddings,
+                         self.control_embeddings,
+                         Path(output_dir) / c.name,
+                         self.penalty,
+                     ),
+                 )
+                 logger.info("Submitted CAV training for concept %s", c.name)
+                 results.append((c.name, res))
+             pool.close()
+             pool.join()
+             results = [(name, res.get()) for name, res in results]
+             for name, (fscore, weight) in results:
+                 self.cavs_fscores[name] = fscore
+                 self.cav_weights[name] = weight
+                 self.cavs_list.append(weight)
+
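
A minimal sketch of the intended training flow; `tpcav`, `control_concept`, and `concept_list` are placeholders built elsewhere (the TPCAV model wrapper and the concepts from tpcav/concepts.py):

    trainer = CavTrainer(tpcav, penalty="l2")
    trainer.set_control(control_concept, num_samples=5000)  # cache control embeddings once
    trainer.train_concepts(concept_list, num_samples=5000, output_dir="cavs_out")
    trainer.save_state("cav_trainer_state.pt")  # fscores, weights, control embeddings
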
+     def tpcav_score(
+         self, concept_name: str, attributions: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Compute per-sample TCAV scores: the directional attribution along the
+         concept CAV, returned as a (num_samples, 1) tensor.
+         """
+         if concept_name not in self.cav_weights:
+             raise ValueError(f"No CAV weights stored for concept {concept_name}")
+         weights = self.cav_weights[concept_name]
+         flat_attr = attributions.flatten(start_dim=1)
+         # Match device and dtype: the stored weights are float64 from sklearn,
+         # while attributions are typically float32.
+         weights = weights.to(device=flat_attr.device, dtype=flat_attr.dtype)
+         scores = torch.matmul(flat_attr, weights.unsqueeze(-1))
+
+         return scores
+
+     def tpcav_score_binary_log_ratio(
+         self, concept_name: str, attributions: torch.Tensor, pseudocount: float = 1.0
+     ) -> float:
+         """
+         Compute the TCAV log-ratio score: the natural log of the ratio of
+         positive to negative directional attributions, with a pseudocount.
+         """
+         scores = self.tpcav_score(concept_name, attributions)
+
+         pos_count = (scores > 0).sum().item()
+         neg_count = (scores < 0).sum().item()
+
+         return np.log((pos_count + pseudocount) / (neg_count + pseudocount))
+
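
A worked instance of the log-ratio score: with 80 samples attributed along the CAV direction, 20 against it, and the default pseudocount, the score is ln(81/21) ≈ 1.35; a score near 0 means the two directions are balanced:

    import numpy as np

    pos_count, neg_count, pseudocount = 80, 20, 1.0
    print(np.log((pos_count + pseudocount) / (neg_count + pseudocount)))  # ~1.35
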
+     def plot_cavs_similarity_heatmap(
+         self,
+         attributions: torch.Tensor,
+         concept_list: List[str] | None = None,
+         fscore_thresh=0.8,
+         output_path: str = "cavs_similarity_heatmap.png",
+     ):
+         """Cluster CAVs by cosine similarity and plot a heatmap alongside TCAV log-ratio bars."""
+         if concept_list is None:
+             cavs_names = list(self.cav_weights.keys())
+         else:
+             cavs_names = concept_list
+         cavs_pass = []
+         cavs_names_pass = []
+         for cname in cavs_names:
+             if self.cavs_fscores[cname] >= fscore_thresh:
+                 cavs_pass.append(self.cav_weights[cname])
+                 cavs_names_pass.append(cname)
+             else:
+                 logger.info(
+                     "Skipping CAV %s with F-score %.3f below threshold %.3f",
+                     cname,
+                     self.cavs_fscores[cname],
+                     fscore_thresh,
+                 )
+         if len(cavs_pass) == 0:
+             logger.warning("No CAVs passed the F-score threshold %.3f.", fscore_thresh)
+             return
+
+         # compute pairwise cosine similarity between the surviving CAVs
+         matrix_similarity = cosine_similarity(cavs_pass)
+
+         # clustered heatmap of CAV similarities
+         cm = sns.clustermap(
+             matrix_similarity,
+             xticklabels=False,
+             yticklabels=False,
+             cmap="bwr",
+             vmin=-1,
+             vmax=1,
+         )
+         cm.gs.update(left=0.05, right=0.5)
+         cm.ax_cbar.set_position([0.01, 0.9, 0.05, 0.05])
+
+         cavs_names_sorted = [
+             cavs_names_pass[i] for i in cm.dendrogram_col.reordered_ind
+         ]
+
+         # bar plot of TCAV log ratios, aligned with the clustered heatmap rows
+         ax_log = cm.figure.add_subplot()
+         heatmap_bbox = cm.ax_heatmap.get_position()
+         ax_log.set_position([0.5, heatmap_bbox.y0, 0.2, heatmap_bbox.height])
+         # used to leave space for motif logos
+         # ax_log.tick_params(
+         #     axis="y", which="major", pad=cm.figure.get_size_inches()[0] * 0.2 * 72
+         # )
+
+         log_ratios_reordered = [
+             self.tpcav_score_binary_log_ratio(cname, attributions)
+             for cname in cavs_names_sorted
+         ]
+         sns.barplot(y=cavs_names_sorted, x=log_ratios_reordered, orient="y", ax=ax_log)
+         # color bars by sign: red for positive log ratios, blue for negative
+         for idx in range(len(ax_log.containers[0])):
+             if ax_log.containers[0].datavalues[idx] > 0:
+                 ax_log.containers[0][idx].set_color("red")
+             else:
+                 ax_log.containers[0][idx].set_color("blue")
+
+         ax_log.set_xlim(left=-5, right=5)
+         ax_log.yaxis.tick_right()
+         ax_log.set_title("TCAV log ratio")
+
+         plt.savefig(output_path, dpi=300, bbox_inches="tight")
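
A sketch of how scoring and plotting combine once CAVs are trained; `trainer`, `attributions` (embedding-space attributions of shape `(num_samples, ...)`), and the concept name are placeholders:

    scores = trainer.tpcav_score("my_concept", attributions)  # per-sample scores
    log_ratio = trainer.tpcav_score_binary_log_ratio("my_concept", attributions)
    trainer.plot_cavs_similarity_heatmap(attributions, output_path="cavs.png")
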
tpcav/concepts.py ADDED
@@ -0,0 +1,309 @@
+ import logging
+ from typing import Dict, Iterable, List, Optional, Sequence
+
+ import numpy as np
+ import pandas as pd
+ import seqchromloader as scl
+ import webdataset as wds
+ from Bio import motifs as Bio_motifs
+ from captum.concept import Concept
+ from torch.utils.data import DataLoader
+
+ from . import helper, utils
+
+ logger = logging.getLogger(__name__)
+
+
+ class _PairedLoader:
+     """Allow repeated iteration over paired sequence/chromatin dataloaders."""
+
+     def __init__(self, seq_dl: Iterable, chrom_dl: Iterable) -> None:
+         self.seq_dl = seq_dl
+         self.chrom_dl = chrom_dl
+         self.apply_func = None
+
+     def apply(self, apply_func):
+         self.apply_func = apply_func
+
+     def __iter__(self):
+         for inputs in zip(self.seq_dl, self.chrom_dl):
+             if self.apply_func:
+                 inputs = self.apply_func(*inputs)
+             yield inputs
+
+
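
A sketch of a transform compatible with `apply` above. Each zipped pair is unpacked as `apply_func(seq_item, chrom_item)`; what the items contain depends on the helper iterators, so the body here is a hypothetical pass-through:

    def to_model_inputs(seq_item, chrom_item):
        # Hypothetical: package the paired batches however the model expects.
        return seq_item, chrom_item

    paired = _PairedLoader(seq_dl, chrom_dl)  # seq_dl/chrom_dl are placeholders
    paired.apply(to_model_inputs)
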
+ def _construct_motif_concept_dataloader_from_control(
+     control_seq_df: pd.DataFrame,
+     genome_fasta: str,
+     motifs: Sequence,
+     num_motifs: int,
+     motif_mode: str,
+     batch_size: int,
+     num_workers: int,
+ ) -> DataLoader:
+     """Mirror the motif-based dataloader logic used in the TCAV script."""
+     datasets = []
+     for motif in motifs:
+         ds = utils.IterateSeqDataFrame(
+             control_seq_df,
+             genome_fasta,
+             motif=motif,
+             motif_mode=motif_mode,
+             num_motifs=num_motifs,
+             start_buffer=0,
+             end_buffer=0,
+             print_warning=False,
+             infinite=False,
+         )
+         datasets.append(ds)
+
+     # randomly interleave samples from the per-motif datasets
+     mixed_dl = DataLoader(
+         wds.RandomMix(datasets),
+         batch_size=batch_size,
+         num_workers=num_workers,
+         pin_memory=True,
+         drop_last=False,
+     )
+     return mixed_dl
+
+
+ class ConceptBuilder:
+     """Build and store concepts/control concepts in a reusable, programmatic way."""
+
+     def __init__(
+         self,
+         genome_fasta: str,
+         genome_size_file: str,
+         input_window_length: int = 1024,
+         bws: Optional[List[str]] = None,
+         batch_size: int = 8,
+         num_workers: int = 0,
+         num_motifs: int = 12,
+         include_reverse_complement: bool = False,
+         min_samples: int = 5000,
+         rng_seed: int = 1001,
+     ) -> None:
+         self.genome_fasta = genome_fasta
+         self.genome_size_file = genome_size_file
+         self.input_window_length = input_window_length
+         self.bws = bws or []
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.num_motifs = num_motifs
+         self.include_reverse_complement = include_reverse_complement
+         self.min_samples = min_samples
+         self.rng_seed = rng_seed
+
+         self.control_regions: pd.DataFrame | None = None
+         self.control_concepts: List[Concept] = []
+         self.concepts: List[Concept] = []
+         self.metadata: Dict[str, object] = {}
+         self._next_concept_id = 0
+         self._control_seq_loader: Optional[Iterable] = None
+         self._control_chrom_loader: Optional[Iterable] = None
+
+     def build_control(self, name: str = "random_regions") -> Concept:
+         """Create the background/control concept from random genomic regions."""
+         control_regions = scl.random_coords(
+             gs=self.genome_size_file,
+             l=self.input_window_length,
+             n=self.min_samples,
+         )
+         control_regions["label"] = -1
+         control_regions["strand"] = "+"
+         self.control_regions = control_regions
+
+         concept = Concept(
+             id=self._reserve_id(is_control=True),
+             name=name,
+             data_iter=_PairedLoader(self._control_seq_dl(), self._control_chrom_dl()),
+         )
+         self.control_concepts = [concept]
+         self.metadata["control_regions"] = control_regions
+         return concept
+
+     def _control_seq_dl(self):
+         if self.control_regions is None:
+             raise ValueError("Call build_control before creating control regions.")
+         seq_fasta_iter = helper.DataFrame2FastaIterator(
+             self.control_regions, self.genome_fasta, batch_size=self.batch_size
+         )
+         return seq_fasta_iter
+
+     def _control_chrom_dl(self):
+         if self.control_regions is None:
+             raise ValueError("Call build_control before creating control regions.")
+         chrom_iter = helper.DataFrame2ChromTracksIterator(
+             self.control_regions,
+             self.bws,
+             batch_size=self.batch_size,
+         )
+         return chrom_iter
+
+     def add_custom_motif_concepts(
+         self, motif_table: str, control_regions: Optional[pd.DataFrame] = None
+     ) -> List[Concept]:
+         """Add concepts from a tab-delimited motif table: motif_name<TAB>consensus."""
+         if control_regions is None:
+             if not self.control_concepts:
+                 raise ValueError("Call build_control or pass control_regions first.")
+             control_regions = self.metadata.get("control_regions")
+             assert control_regions is not None
+         df = pd.read_table(motif_table, names=["motif_name", "consensus_seq"])
+         added: List[Concept] = []
+         for motif_name in np.unique(df.motif_name):
+             consensus = df.loc[df.motif_name == motif_name, "consensus_seq"].tolist()
+             motifs = []
+             for idx, cons in enumerate(consensus):
+                 motif = utils.CustomMotif(f"{motif_name}_{idx}", cons)
+                 motifs.append(motif)
+                 if self.include_reverse_complement:
+                     motifs.append(motif.reverse_complement())
+             seq_dl = _construct_motif_concept_dataloader_from_control(
+                 control_regions,
+                 self.genome_fasta,
+                 motifs=motifs,
+                 num_motifs=self.num_motifs,
+                 motif_mode="consensus",
+                 batch_size=self.batch_size,
+                 num_workers=self.num_workers,
+             )
+             concept = Concept(
+                 id=self._reserve_id(),
+                 name=motif_name,
+                 data_iter=_PairedLoader(seq_dl, self._control_chrom_dl()),
+             )
+             self.concepts.append(concept)
+             added.append(concept)
+         return added
+
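
For reference, the motif table consumed above is headerless and tab-delimited, one consensus per row; repeating a motif name groups several consensus sequences under one concept. The names and sequences below are illustrative only:

    GATA1<TAB>WGATAR
    TAL1<TAB>CAGCTG
    TAL1<TAB>CAGGTG

    builder.add_custom_motif_concepts("motifs.tsv")  # placeholder path
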
+     def add_meme_motif_concepts(
+         self, meme_file: str, control_regions: Optional[pd.DataFrame] = None
+     ) -> List[Concept]:
+         """Add concepts from a MEME minimal-format motif file."""
+         if control_regions is None:
+             if not self.control_concepts:
+                 raise ValueError("Call build_control or pass control_regions first.")
+             control_regions = self.metadata.get("control_regions")
+             assert control_regions is not None
+
+         added: List[Concept] = []
+         with open(meme_file) as handle:
+             for motif in Bio_motifs.parse(handle, fmt="MINIMAL"):
+                 motifs = [motif]
+                 if self.include_reverse_complement:
+                     motifs.append(motif.reverse_complement())
+                 motif_name = motif.name.replace("/", "-")
+                 seq_dl = _construct_motif_concept_dataloader_from_control(
+                     control_regions,
+                     self.genome_fasta,
+                     motifs=motifs,
+                     num_motifs=self.num_motifs,
+                     motif_mode="pwm",
+                     batch_size=self.batch_size,
+                     num_workers=self.num_workers,
+                 )
+                 concept = Concept(
+                     id=self._reserve_id(),
+                     name=motif_name,
+                     data_iter=_PairedLoader(seq_dl, self._control_chrom_dl()),
+                 )
+                 self.concepts.append(concept)
+                 added.append(concept)
+         return added
+
+     def add_bed_sequence_concepts(self, bed_paths: Iterable[str]) -> List[Concept]:
+         """Add concepts backed by BED sequences, with concept_name in column 5."""
+         added: List[Concept] = []
+         for bed in bed_paths:
+             bed_df = pd.read_table(
+                 bed,
+                 header=None,
+                 usecols=[0, 1, 2, 3, 4],
+                 names=["chrom", "start", "end", "strand", "concept_name"],
+             )
+             for concept_name in bed_df.concept_name.unique():
+                 concept_df = bed_df.loc[bed_df.concept_name == concept_name]
+                 if len(concept_df) < self.min_samples:
+                     logger.warning(
+                         "Concept %s has %s samples, fewer than min_samples=%s; skipping",
+                         concept_name,
+                         len(concept_df),
+                         self.min_samples,
+                     )
+                     continue
+                 seq_fasta_iter = helper.dataframe_to_fasta_iter(
+                     concept_df.sample(n=self.min_samples, random_state=self.rng_seed),
+                     self.genome_fasta,
+                     batch_size=self.batch_size,
+                 )
+                 concept = Concept(
+                     id=self._reserve_id(),
+                     name=concept_name,
+                     data_iter=_PairedLoader(seq_fasta_iter, self._control_chrom_dl()),
+                 )
+                 self.concepts.append(concept)
+                 added.append(concept)
+         return added
+
+     def add_bed_chrom_concepts(self, bed_paths: Iterable[str]) -> List[Concept]:
+         """Add concepts backed by chromatin-signal bigwigs at BED coordinates."""
+         added: List[Concept] = []
+         for bed in bed_paths:
+             bed_df = pd.read_table(
+                 bed,
+                 header=None,
+                 usecols=[0, 1, 2, 3, 4],
+                 names=["chrom", "start", "end", "strand", "concept_name"],
+             )
+             for concept_name in bed_df.concept_name.unique():
+                 concept_df = bed_df.loc[bed_df.concept_name == concept_name]
+                 if len(concept_df) < self.min_samples:
+                     logger.warning(
+                         "Concept %s has %s samples, fewer than min_samples=%s; skipping",
+                         concept_name,
+                         len(concept_df),
+                         self.min_samples,
+                     )
+                     continue
+                 chrom_dl = helper.dataframe_to_chrom_tracks_iter(
+                     concept_df.sample(n=self.min_samples, random_state=self.rng_seed),
+                     self.genome_fasta,
+                     self.bws,
+                     batch_size=self.batch_size,
+                 )
+                 concept = Concept(
+                     id=self._reserve_id(),
+                     name=concept_name,
+                     data_iter=_PairedLoader(self._control_seq_dl(), chrom_dl),
+                 )
+                 self.concepts.append(concept)
+                 added.append(concept)
+         return added
+
+     def all_concepts(self) -> List[Concept]:
+         """Return test + control concepts."""
+         return [*self.concepts, *self.control_concepts]
+
+     def summary(self) -> Dict[str, object]:
+         """Lightweight run metadata."""
+         return {
+             "num_concepts": len(self.concepts),
+             "num_control": len(self.control_concepts),
+             "concept_names": [c.name for c in self.concepts],
+             "control_names": [c.name for c in self.control_concepts],
+             "input_window_length": self.input_window_length,
+             "num_motifs": self.num_motifs,
+             "bigwigs": self.bws,
+         }
+
+     def _reserve_id(self, is_control: bool = False) -> int:
+         # The shared counter keeps ids unique; control concepts take the
+         # negated counter value to mark them apart from test concepts.
+         cid = -self._next_concept_id if is_control else self._next_concept_id
+         self._next_concept_id += 1
+         return cid
+
+     def apply_transform(self, transform):
+         """Apply a transform function to the data iterators of all concepts."""
+         for c in self.all_concepts():
+             c.data_iter.apply(transform)
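
Taken together, a minimal end-to-end sketch of the builder-to-trainer flow; all paths are placeholders, and the `TPCAV` constructor does not appear in this diff, so its call is left elided:

    from tpcav import CavTrainer, ConceptBuilder, TPCAV

    builder = ConceptBuilder(
        genome_fasta="genome.fa",               # placeholder
        genome_size_file="genome.chrom.sizes",  # placeholder
    )
    control = builder.build_control()
    builder.add_meme_motif_concepts("motifs.meme")  # placeholder MEME file

    model = TPCAV(...)  # constructor arguments not shown in this diff
    trainer = CavTrainer(model)
    trainer.set_control(control, num_samples=builder.min_samples)
    trainer.train_concepts(
        builder.concepts, num_samples=builder.min_samples, output_dir="cavs_out"
    )
    print(builder.summary())
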