tpcav 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tpcav/helper.py ADDED
@@ -0,0 +1,165 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Lightweight data loading helpers for sequences and chromatin tracks.
4
+ """
5
+
6
+ from typing import Iterable, List, Optional
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import seqchromloader as scl
11
+ import torch
12
+ from deeplift.dinuc_shuffle import dinuc_shuffle
13
+ from pyfaidx import Fasta
14
+ from seqchromloader.utils import dna2OneHot, extract_bw
15
+
16
+
17
def load_bed_and_center(bed_file: str, window: int) -> pd.DataFrame:
    """
    Read a BED file and resize every region to `window` bp around its midpoint.

    Returns a DataFrame with columns [chrom, start, end]; starts may go
    negative for regions whose midpoint is closer than window//2 to zero.
    """
    regions = pd.read_table(bed_file, usecols=[0, 1, 2], names=["chrom", "start", "end"])
    midpoints = ((regions["start"] + regions["end"]) // 2).astype(int)
    regions["start"] = midpoints - (window // 2)
    regions["end"] = regions["start"] + window
    return regions[["chrom", "start", "end"]]
27
+
28
+
29
def bed_to_fasta_iter(
    bed_file: str, genome_fasta: str, batch_size: int
) -> Iterable[List[str]]:
    """
    Lazily yield batches of fasta sequences for the regions in a BED file.

    The BED file is only read when iteration starts (generator semantics).
    """
    columns = ["chrom", "start", "end"]
    regions = pd.read_table(bed_file, usecols=[0, 1, 2], names=columns)
    yield from dataframe_to_fasta_iter(regions, genome_fasta, batch_size)
37
+
38
+
39
def dataframe_to_fasta_iter(
    df: pd.DataFrame, genome_fasta: str, batch_size: int
) -> Iterable[List[str]]:
    """
    Yield batches of uppercase genomic sequences for regions in `df`.

    df must carry columns [chrom, start, end]; the final batch may be
    shorter than `batch_size`.
    """
    genome = Fasta(genome_fasta)
    batch: List[str] = []
    for region in df.itertuples(index=False):
        sequence = str(genome[region.chrom][region.start : region.end]).upper()
        batch.append(sequence)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Flush the remainder so no region is silently dropped.
    if batch:
        yield batch
55
+
56
+
57
class DataFrame2FastaIterator:
    """
    Re-iterable wrapper around `dataframe_to_fasta_iter`.

    Each call to `iter()` starts a fresh pass over the DataFrame, which a
    one-shot generator cannot do.
    """

    def __init__(self, df: pd.DataFrame, genome_fasta: str, batch_size: int):
        # df must carry columns [chrom, start, end].
        self.genome_fasta = genome_fasta
        self.df = df
        self.batch_size = batch_size

    def __iter__(self):
        return dataframe_to_fasta_iter(
            self.df,
            genome_fasta=self.genome_fasta,
            batch_size=self.batch_size,
        )
71
+
72
+
73
def bed_to_chrom_tracks_iter(
    bed_file: str, genome_fasta: str, bigwigs: List[str]
) -> Iterable[torch.Tensor]:
    """
    Yield chromatin tracks for BED regions using bigwig files.
    Each batch is shaped [batch, num_tracks, window].

    genome_fasta is kept for backward compatibility of the signature but is
    not needed to extract bigwig signal.
    """
    bed_df = pd.read_table(bed_file, usecols=[0, 1, 2], names=["chrom", "start", "end"])
    # BUG FIX: the original called
    # dataframe_to_chrom_tracks_iter(bed_df, genome_fasta, bigwigs), but that
    # function's signature is (df, bigwigs, batch_size) — the fasta path was
    # consumed as the bigwig list and the bigwig list as the batch size.
    yield from dataframe_to_chrom_tracks_iter(bed_df, bigwigs=bigwigs)
82
+
83
+
84
+ def dataframe_to_chrom_tracks_iter(
85
+ df: pd.DataFrame,
86
+ bigwigs: List[str] | None,
87
+ batch_size: int = 1,
88
+ ) -> Iterable[torch.Tensor | None]:
89
+ """
90
+ Yield chromatin tracks for regions from a DataFrame using bigwig files.
91
+ """
92
+ if bigwigs is not None and len(bigwigs) > 0:
93
+ chrom_arrs = []
94
+ for row in df.itertuples(index=False):
95
+ chrom = extract_bw(
96
+ row.chrom, row.start, row.end, getattr(row, "strand", "+"), bigwigs
97
+ )
98
+ chrom_arrs.append(chrom)
99
+ if batch_size is not None and len(chrom_arrs) == batch_size:
100
+ yield torch.tensor(np.stack(chrom_arrs))
101
+ chrom_arrs = []
102
+ if chrom_arrs:
103
+ yield torch.tensor(np.stack(chrom_arrs))
104
+ else:
105
+ while True:
106
+ yield torch.full((batch_size, 1), torch.nan)
107
+
108
+
109
+ class DataFrame2ChromTracksIterator:
110
+ """
111
+ Iterator class to yield chromatin tracks from a DataFrame with columns [chrom, start, end].
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ df: pd.DataFrame,
117
+ bigwigs: List[str] | None,
118
+ batch_size: int = 1,
119
+ ):
120
+ self.bigwigs = bigwigs
121
+ self.df = df
122
+ self.batch_size = batch_size
123
+
124
+ def __iter__(self):
125
+ return dataframe_to_chrom_tracks_iter(
126
+ self.df,
127
+ bigwigs=self.bigwigs,
128
+ batch_size=self.batch_size,
129
+ )
130
+
131
+
132
def fasta_to_one_hot_sequences(seqs: List[str]) -> torch.Tensor:
    """
    Stack per-sequence one-hot encodings ([4, L] each) into one
    [N, 4, L] torch tensor. All sequences must share the same length.
    """
    encoded = [dna2OneHot(sequence) for sequence in seqs]
    return torch.tensor(np.stack(encoded))
137
+
138
+
139
def random_regions_dataframe(
    genome_size_file: str, window: int, n: int, seed: int = 1
) -> pd.DataFrame:
    """
    Draw `n` random regions of width `window` from a genome-sizes file,
    returned as a DataFrame with columns [chrom, start, end].
    """
    coords = scl.random_coords(gs=genome_size_file, l=window, n=n, seed=seed)
    return coords[["chrom", "start", "end"]]
148
+
149
+
150
def dinuc_shuffle_sequences(
    seqs: Iterable[str], num_shuffles: int = 10, seed: int = 1
) -> List[List[str]]:
    """
    Produce `num_shuffles` dinucleotide-preserving shuffles per sequence.

    A single seeded RandomState is threaded through every call so results
    are reproducible across runs.
    """
    rng = np.random.RandomState(seed)
    return [
        dinuc_shuffle(sequence, num_shufs=num_shuffles, rng=rng)
        for sequence in seqs
    ]
tpcav/logging_utils.py ADDED
@@ -0,0 +1,10 @@
1
+ import logging
2
+
3
+
4
def set_verbose(level: str = "INFO") -> None:
    """
    Set the logging level on the root and "tpcav" loggers.

    Unrecognized level names silently fall back to INFO.
    """
    resolved = getattr(logging, level.upper(), logging.INFO)
    for logger_name in (None, "tpcav"):  # None addresses the root logger
        logging.getLogger(logger_name).setLevel(resolved)
tpcav/tpcav_model.py ADDED
@@ -0,0 +1,427 @@
1
+ import logging
2
+ from functools import partial
3
+ from typing import Dict, Iterable, List, Optional, Tuple
4
+
5
+ import torch
6
+ from captum.attr import DeepLift
7
+ from scipy.linalg import svd
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ def _abs_attribution_func(multipliers, inputs, baselines):
13
+ "Multiplier x abs(inputs - baselines) to avoid double-sign effects."
14
+ # print(f"inputs: {inputs[1][:5]}")
15
+ # print(f"baselines: {baselines[1][:5]}")
16
+ # print(f"multipliers: {multipliers[0][:5]}")
17
+ # print(f"multipliers: {multipliers[1][:5]}")
18
+ return tuple(
19
+ (input_ - baseline).abs() * multiplier
20
+ for input_, baseline, multiplier in zip(inputs, baselines, multipliers)
21
+ )
22
+
23
+
24
+ class TPCAV(torch.nn.Module):
25
+ """End-to-end PCA fitting, projection, and attribution utilities."""
26
+
27
+ def __init__(
28
+ self,
29
+ model,
30
+ device: Optional[str] = None,
31
+ layer_name: Optional[str] = None,
32
+ ) -> None:
33
+ """
34
+ layer_name: optional module name to intercept activations via forward hook
35
+ (useful if forward_until_select_layer is not implemented).
36
+ """
37
+ super().__init__()
38
+ self.model = model
39
+ self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
40
+ self.model.to(self.device)
41
+ self.model.eval()
42
+ self.fitted = False
43
+ self.layer_name = layer_name
44
+
45
+ def list_module_names(self) -> List[str]:
46
+ """List all module names in the model for layer selection."""
47
+ return [name for name, _ in self.model.named_modules()]
48
+
49
+ def tpcav_state_dict(self) -> Dict:
50
+ """Export TPCAV buffers."""
51
+ return {
52
+ "layer_name": self.layer_name,
53
+ "zscore_mean": getattr(self, "zscore_mean", None),
54
+ "zscore_std": getattr(self, "zscore_std", None),
55
+ "pca_inv": getattr(self, "pca_inv", None),
56
+ "orig_shape": getattr(self, "orig_shape", None),
57
+ }
58
+
59
+ def restore_tpcav_state(self, tpcav_state_dict: Dict) -> None:
60
+ """Load PCA buffers from disk."""
61
+ self.layer_name = tpcav_state_dict["layer_name"]
62
+ self._set_buffer("zscore_mean", tpcav_state_dict["zscore_mean"])
63
+ self._set_buffer("zscore_std", tpcav_state_dict["zscore_std"])
64
+ self._set_buffer("pca_inv", tpcav_state_dict["pca_inv"])
65
+ self._set_buffer("orig_shape", tpcav_state_dict["orig_shape"])
66
+ self.fitted = True
67
+ logger.warning(
68
+ "Restored TPCAV state, please set model attribute!\n\n Example: self.model = Model_class()",
69
+ )
70
+
71
+ def fit_pca(
72
+ self,
73
+ concepts: Iterable,
74
+ num_samples_per_concept: int = 10,
75
+ num_pc: Optional[int] | str = None,
76
+ ) -> Dict[str, torch.Tensor]:
77
+ """Sample activations, compute PCA, and attach buffers to the model."""
78
+ sampled_avs = []
79
+ for concept in concepts:
80
+ avs = self._sample_concept(concept, num_samples=num_samples_per_concept)
81
+ logger.info(
82
+ "Sampled %s activations from concept %s", avs.shape[0], concept.name
83
+ )
84
+ sampled_avs.append(avs)
85
+ all_avs = torch.cat(sampled_avs)
86
+ orig_shape = all_avs.shape
87
+ flat = all_avs.flatten(start_dim=1)
88
+
89
+ mean = flat.mean(dim=0)
90
+ std = flat.std(dim=0)
91
+ std[std == 0] = -1
92
+ standardized = (flat - mean) / std
93
+
94
+ v_inverse = None
95
+ if num_pc is None or num_pc == "full":
96
+ _, _, v = svd(standardized, lapack_driver="gesvd", full_matrices=False)
97
+ v_inverse = torch.tensor(v)
98
+ elif int(num_pc) == 0:
99
+ v_inverse = None
100
+ else:
101
+ _, _, v = svd(standardized, lapack_driver="gesvd", full_matrices=False)
102
+ v_inverse = torch.tensor(v[: int(num_pc)])
103
+
104
+ self._set_buffer("zscore_mean", mean.to(self.device))
105
+ self._set_buffer("zscore_std", std.to(self.device))
106
+ self._set_buffer(
107
+ "pca_inv", v_inverse.to(self.device) if v_inverse is not None else None
108
+ )
109
+ self._set_buffer("orig_shape", torch.tensor(orig_shape).to(self.device))
110
+ self.fitted = True
111
+
112
+ return {
113
+ "zscore_mean": mean,
114
+ "zscore_std": std,
115
+ "pca_inv": v_inverse,
116
+ "orig_shape": torch.tensor(orig_shape),
117
+ }
118
+
119
+ def project_activations(
120
+ self, activations: torch.Tensor
121
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
122
+ """Project flattened activations into PCA space and residual."""
123
+ if not self.fitted:
124
+ raise RuntimeError("Call fit_pca before projecting activations.")
125
+
126
+ y = activations.flatten(start_dim=1).to(self.device)
127
+ if self.pca_inv is not None:
128
+ V = self.pca_inv.T
129
+ zscore_mean = getattr(self, "zscore_mean", 0.0)
130
+ zscore_std = getattr(self, "zscore_std", 1.0)
131
+ y_standardized = (y - zscore_mean) / zscore_std
132
+ y_projected = torch.matmul(y_standardized, V)
133
+ y_residual = y_standardized - torch.matmul(y_projected, self.pca_inv)
134
+ return y_residual, y_projected
135
+ else:
136
+ return y, None
137
+
138
+ def concept_embeddings(self, concept, num_samples: int) -> torch.Tensor:
139
+ """Return concatenated projected + residual activations for a concept."""
140
+ avs = self._sample_concept(concept, num_samples=num_samples)
141
+ residual, projected = self.project_activations(avs)
142
+ if projected is not None:
143
+ return torch.cat((projected, residual), dim=1)
144
+ return residual
145
+
146
+ def forward_from_embeddings_at_layer(
147
+ self,
148
+ avs_residual: torch.Tensor,
149
+ avs_projected: Optional[torch.Tensor] = None,
150
+ model_inputs: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
151
+ ) -> torch.Tensor:
152
+ """
153
+ Resume model forward by injecting activations at a named layer.
154
+
155
+ If layer_name is not set, falls back to forward_with_embeddings which
156
+ expects model.forward_from_projected_and_residual to exist.
157
+ """
158
+ name = self.layer_name
159
+ if name is None:
160
+ raise ValueError("layer name must be defined")
161
+ if model_inputs is None:
162
+ raise ValueError(
163
+ "model_inputs (seq, chrom) must be provided to run forward."
164
+ )
165
+
166
+ y_hat = self.embedding_to_layer_activation(avs_residual, avs_projected)
167
+ return self.forward_patched(
168
+ model_inputs,
169
+ layer_activation=y_hat,
170
+ )
171
+
172
+ def layer_attributions(
173
+ self,
174
+ target_batches: Iterable,
175
+ baseline_batches: Iterable,
176
+ multiply_by_inputs: bool = True,
177
+ ) -> Dict[str, torch.Tensor]:
178
+ """Compute DeepLift attributions on PCA embedding space.
179
+
180
+ target_batches and baseline_batches should yield (seq, chrom) pairs of matching length.
181
+ """
182
+ if not self.fitted:
183
+ raise RuntimeError("Call fit_pca before attributing.")
184
+ self.forward = self.forward_from_embeddings_at_layer
185
+ deeplift = DeepLift(self, multiply_by_inputs=multiply_by_inputs)
186
+
187
+ attributions = []
188
+ for inputs, binputs in zip(target_batches, baseline_batches):
189
+ avs = self._layer_output(*[i.to(self.device) for i in inputs])
190
+ avs_residual, avs_projected = self.project_activations(avs)
191
+
192
+ bavs = self._layer_output(*[bi.to(self.device) for bi in binputs])
193
+ bavs_residual, bavs_projected = self.project_activations(bavs)
194
+
195
+ # detach the projected tensor as it's connnected to the original input graph,
196
+ # detaching it would keep the gradients on it
197
+ if avs_projected is not None:
198
+ avs_projected = avs_projected.detach()
199
+ bavs_projected = bavs_projected.detach()
200
+ attribution = deeplift.attribute(
201
+ (avs_residual.to(self.device), avs_projected.to(self.device)),
202
+ baselines=(
203
+ bavs_residual.to(self.device),
204
+ bavs_projected.to(self.device),
205
+ ),
206
+ additional_forward_args=(inputs,),
207
+ custom_attribution_func=(
208
+ None if not multiply_by_inputs else _abs_attribution_func
209
+ ),
210
+ )
211
+ attr_residual, attr_projected = attribution
212
+ attribution = torch.cat((attr_projected, attr_residual), dim=1)
213
+ else:
214
+ attribution = deeplift.attribute(
215
+ (avs_residual.to(self.device),),
216
+ baselines=(bavs_residual.to(self.device),),
217
+ additional_forward_args=(
218
+ None,
219
+ inputs,
220
+ ),
221
+ custom_attribution_func=(
222
+ None if not multiply_by_inputs else _abs_attribution_func
223
+ ),
224
+ )[0]
225
+
226
+ attributions.append(attribution.detach().cpu())
227
+
228
+ with torch.no_grad():
229
+ del (
230
+ avs,
231
+ avs_projected,
232
+ avs_residual,
233
+ bavs,
234
+ bavs_projected,
235
+ bavs_residual,
236
+ )
237
+ torch.cuda.empty_cache()
238
+
239
+ return {
240
+ "attributions": torch.cat(attributions),
241
+ }
242
+
243
+ def input_attributions(
244
+ self,
245
+ target_batches: Iterable,
246
+ baseline_batches: Iterable,
247
+ multiply_by_inputs: bool = True,
248
+ cavs_list: List[torch.Tensor] | None = None,
249
+ ) -> List[torch.Tensor]:
250
+ """Compute DeepLift attributions on PCA embedding space.
251
+
252
+ target_batches and baseline_batches should yield (seq, chrom) pairs of matching length.
253
+ """
254
+ if not self.fitted:
255
+ raise RuntimeError("Call fit_pca before attributing.")
256
+ self.forward = partial(
257
+ self.forward_patched_tensor,
258
+ layer_activation=None,
259
+ cavs_list=cavs_list,
260
+ mute_x_avs=False,
261
+ mute_remainder=True,
262
+ )
263
+ deeplift = DeepLift(self, multiply_by_inputs=multiply_by_inputs)
264
+
265
+ attributions = []
266
+ for inputs, binputs in zip(target_batches, baseline_batches):
267
+ attribution = deeplift.attribute(
268
+ tuple([i.to(self.device) for i in inputs]),
269
+ baselines=tuple([bi.to(self.device) for bi in binputs]),
270
+ )
271
+ attributions.append(
272
+ [a.detach().cpu() for a in attribution]
273
+ if isinstance(attribution, tuple)
274
+ else attribution.detach().cpu()
275
+ )
276
+
277
+ return [torch.cat(z) for z in zip(*attributions)]
278
+
279
+ def _sample_concept(self, concept, num_samples: int) -> torch.Tensor:
280
+ avs: List[torch.Tensor] = []
281
+ num = 0
282
+ for inputs in concept.data_iter:
283
+ av = self._layer_output(*[i.to(self.device) for i in inputs])
284
+ avs.append(av.detach().cpu())
285
+ num += av.shape[0]
286
+ if num >= num_samples:
287
+ break
288
+ if not avs:
289
+ raise ValueError(f"No activations gathered for concept {concept.name}")
290
+ return torch.cat(avs)[:num_samples]
291
+
292
+ def _layer_output(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
293
+ """Return activations from the configured layer or model hook."""
294
+ if self.layer_name is None:
295
+ # No layer configured; return model output directly.
296
+ self.model(*inputs)
297
+ layer = self._resolve_layer(self.layer_name)
298
+ cache: List[torch.Tensor] = []
299
+
300
+ def hook_fn(_module, _inputs, output):
301
+ cache.append(output)
302
+
303
+ handle = layer.register_forward_hook(hook_fn)
304
+ try:
305
+ inputs = [inp.to(self.device) if inp is not None else inp for inp in inputs]
306
+ _ = self.model(*inputs)
307
+ finally:
308
+ handle.remove()
309
+
310
+ if not cache:
311
+ raise RuntimeError(f"No activation captured for layer {self.layer_name}")
312
+ return cache[0]
313
+
314
+ def _resolve_layer(self, name: str):
315
+ for module_name, module in self.model.named_modules():
316
+ if module_name == name:
317
+ return module
318
+ raise ValueError(f"Layer {name} not found in model.")
319
+
320
+ def embedding_to_layer_activation(
321
+ self, avs_residual: torch.Tensor, avs_projected: Optional[torch.Tensor]
322
+ ) -> torch.Tensor:
323
+ """
324
+ Combine projected/residual embeddings into the layer activation space,
325
+ mirroring scripts/models.py merge logic.
326
+ """
327
+ y_hat = torch.matmul(avs_projected, self.pca_inv) + avs_residual
328
+ y_hat = y_hat * self.zscore_std + self.zscore_mean
329
+
330
+ return y_hat.reshape((-1, *self.orig_shape[1:]))
331
+
332
+ def forward_patched_tensor(
333
+ self,
334
+ *model_inputs: torch.Tensor,
335
+ layer_activation: Optional[torch.Tensor] = None,
336
+ cavs_list: Optional[List[torch.Tensor]] = None,
337
+ mute_x_avs: bool = False,
338
+ mute_remainder: bool = True,
339
+ ) -> torch.Tensor:
340
+ return self.forward_patched(
341
+ model_inputs, layer_activation, cavs_list, mute_x_avs, mute_remainder
342
+ )
343
+
344
+ def forward_patched(
345
+ self,
346
+ model_inputs: Tuple[torch.Tensor, Optional[torch.Tensor]],
347
+ layer_activation: Optional[torch.Tensor] = None,
348
+ cavs_list: Optional[List[torch.Tensor]] = None,
349
+ mute_x_avs: bool = False,
350
+ mute_remainder: bool = True,
351
+ ) -> torch.Tensor:
352
+ """
353
+ Full forward pass with optional activation replacement and/or CAV-based gradient muting.
354
+ """
355
+ name = self.layer_name
356
+ if name is None:
357
+ raise ValueError("layer_name must be set on TPCAV to use forward_patched.")
358
+ layer = self._resolve_layer(name)
359
+
360
+ def hook_fn(_module, _inputs, output):
361
+ y = layer_activation if layer_activation is not None else output
362
+ if cavs_list is None or len(cavs_list) == 0:
363
+ return y
364
+ return self._disentangle_with_cavs(
365
+ y, cavs_list, mute_x_avs=mute_x_avs, mute_remainder=mute_remainder
366
+ )
367
+
368
+ handle = layer.register_forward_hook(hook_fn)
369
+ try:
370
+ output = self.model(*[i.to(self.device) for i in model_inputs])
371
+ finally:
372
+ handle.remove()
373
+ return output
374
+
375
+ def _disentangle_with_cavs(
376
+ self,
377
+ layer_output: torch.Tensor,
378
+ cavs_list: List[torch.Tensor],
379
+ mute_x_avs: bool = False,
380
+ mute_remainder: bool = True,
381
+ ) -> torch.Tensor:
382
+ """
383
+ Project activations into CAV subspace and optionally zero gradients for
384
+ subspaces (default mutes orthogonal/remainder gradients).
385
+ """
386
+ y = layer_output.flatten(start_dim=1)
387
+ y_residual, y_projected = self.project_activations(y)
388
+ y_pca_all = torch.cat([y_projected, y_residual], dim=1)
389
+
390
+ cavs_matrix = torch.stack(cavs_list, dim=1).to(y.device) # [dims, #cavs]
391
+
392
+ if cavs_matrix.shape[1] > cavs_matrix.shape[0]:
393
+ logger.warning(
394
+ "CAV matrix has more columns than rows; remainder should be near zero."
395
+ )
396
+
397
+ mrank = torch.linalg.matrix_rank(cavs_matrix)
398
+ cavs_ortho = torch.linalg.qr(cavs_matrix, mode="reduced").Q[:, :mrank].detach()
399
+ if not torch.allclose(
400
+ cavs_ortho.T @ cavs_ortho,
401
+ torch.eye(mrank, device=cavs_ortho.device),
402
+ atol=1e-3,
403
+ rtol=1e-3,
404
+ ):
405
+ logger.warning("Q^TQ not identity; check CAV matrix conditioning.")
406
+
407
+ y_pca_x_avs = y_pca_all @ cavs_ortho @ cavs_ortho.T
408
+ y_pca_remainder = y_pca_all - y_pca_x_avs
409
+
410
+ if mute_x_avs:
411
+ y_pca_x_avs.register_hook(lambda grad: torch.zeros_like(grad))
412
+ if mute_remainder:
413
+ y_pca_remainder.register_hook(lambda grad: torch.zeros_like(grad))
414
+
415
+ y_pca_hat = y_pca_x_avs + y_pca_remainder
416
+
417
+ dim_projected = y_projected.shape[1] if y_projected is not None else 0
418
+
419
+ return self.embedding_to_layer_activation(
420
+ y_pca_hat[:, dim_projected:], y_pca_hat[:, :dim_projected]
421
+ )
422
+
423
+ def _set_buffer(self, name: str, value: Optional[torch.Tensor]) -> None:
424
+ if hasattr(self.model, "_buffers") and name in self.model._buffers:
425
+ self._buffers[name] = value # type: ignore[index]
426
+ else:
427
+ self.register_buffer(name, value)