tmnt 0.7.58__py3-none-any.whl → 0.7.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/__init__.py CHANGED
@@ -1,11 +1,8 @@
  # coding: utf-8
 
- import os
  from .distribution import *
  from .preprocess import *
- #from .models import *
+ from .sparse import *
  from .utils import *
 
- os.environ["MXNET_STORAGE_FALLBACK_LOG_VERBOSE"] = "0"
-
- __all__ = distribution.__all__ + preprocess.__all__ + utils.__all__
+ __all__ = distribution.__all__ + preprocess.__all__ + utils.__all__ + sparse.__all__
tmnt/sparse/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # coding: utf-8
+ """
+ Copyright (c) 2019 The MITRE Corporation.
+ """
+
+
+ from .config import *
+ from .estimator import *
+ from .inference import *
+ from .modeling import *
+
+ __all__ = config.__all__ + estimator.__all__ + inference.__all__ + modeling.__all__
tmnt/sparse/config.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+
+ def get_default_cfg():
+     default_cfg = {
+         "seed": 49,
+         "batch_size": 4096,
+         "lr": 3e-4,
+         "num_samples": int(1e9),
+         "l1_coeff": 0,
+         "beta1": 0.9,
+         "beta2": 0.99,
+         "max_grad_norm": 100000,
+         "seq_len": 128,
+         "dtype": torch.float32,
+         "site": "resid_pre",
+         "layer": 8,
+         "act_size": 768,
+         "dict_size": 12288,
+         "device": "cuda:0",
+         "input_unit_norm": True,
+         "perf_log_freq": 1000,
+         "sae_type": "topk",
+         "checkpoint_freq": 10000,
+         "n_batches_to_dead": 5,
+
+         # (Batch)TopKSAE specific
+         "top_k": 32,
+         "top_k_aux": 512,
+         "aux_penalty": (1/32),
+         # for jumprelu
+         "bandwidth": 0.001,
+     }
+     return default_cfg
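A minimal sketch of overriding these defaults for a CPU run (the smaller "act_size" below is a hypothetical value, not a TMNT default):

    from tmnt.sparse.config import get_default_cfg

    cfg = get_default_cfg()
    cfg["device"] = "cpu"                   # defaults assume "cuda:0"
    cfg["act_size"] = 384                   # hypothetical embedding width
    cfg["dict_size"] = 8 * cfg["act_size"]  # learned dictionary size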
tmnt/sparse/estimator.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from datasets import Dataset, IterableDataset
+ import tqdm
+ from datasets.arrow_writer import ArrowWriter
+ from tmnt.inference import SeqVEDInferencer
+ import io, json
+ from tmnt.sparse.modeling import BaseAutoencoder
+ from typing import List
+
+ class ActivationsStore:
+     def __init__(
+         self,
+         cfg: dict,
+     ):
+         self.device = cfg["device"]
+         self.activation_path = cfg["activation_path"]
+         shuffle = cfg.get("shuffle_data", False)
+         #self.dataset = Dataset.from_file(self.activation_path).with_format('torch', device=self.device)
+         self.dataset = Dataset.from_file(self.activation_path).select_columns(['data']).shuffle(seed=42).with_format('torch', device=self.device)
+         self.dataloader = DataLoader(self.dataset,
+                                      batch_size=cfg["batch_size"], shuffle=shuffle)
+         self.dataloader_iter = iter(self.dataloader)
+         self.cfg = cfg
+
+     def next_batch(self):
+         try:
+             return next(self.dataloader_iter)['data']
+         except (StopIteration, AttributeError):
+             self.dataloader_iter = iter(self.dataloader)
+             return next(self.dataloader_iter)['data']
+
+ def build_activation_store(json_input_texts, emb_model_path, arrow_output, max_seq_len=512, json_txt_key='text', device='cpu'):
+
+     inferencer = SeqVEDInferencer.from_saved(emb_model_path, max_length=max_seq_len, device=device)
+     with io.open(json_input_texts) as fp:
+         with ArrowWriter(path=arrow_output) as writer:
+             for l in fp:
+                 js = json.loads(l)
+                 tokenization_result = inferencer.prep_text(js[json_txt_key])
+                 llm_out = inferencer.model.llm(tokenization_result['input_ids'].to(inferencer.device),
+                                                tokenization_result['attention_mask'].to(inferencer.device))
+                 cls_vec = inferencer.model._get_embedding(llm_out, tokenization_result['attention_mask'].to(inferencer.device))
+                 enc : List[float] = cls_vec.cpu().detach()[0].tolist()
+                 writer.write({'data': enc})
+             writer.finalize()
+
+ def build_activation_store_batching(json_input_texts, emb_model_path, arrow_output, max_seq_len=512, batch_size=42, json_txt_key='text', device='cpu'):
+     inferencer = SeqVEDInferencer.from_saved(emb_model_path, max_length=max_seq_len, device=device)
+     def encode_batch(txt_batch):
+         tokenization_result = inferencer.prep_text(txt_batch)
+         llm_out = inferencer.model.llm(tokenization_result['input_ids'].to(inferencer.device),
+                                        tokenization_result['attention_mask'].to(inferencer.device))
+         cls_vec = inferencer.model._get_embedding(llm_out, tokenization_result['attention_mask'].to(inferencer.device))
+         encs : List[List[float]] = cls_vec.cpu().detach().tolist()
+         return zip(txt_batch, encs)
+
+     def write_encodings(writer: ArrowWriter, txt_enc_pairs):
+         for (t, e) in txt_enc_pairs:
+             writer.write({'text': t, 'data': e})
+
+     with io.open(json_input_texts) as fp:
+         with ArrowWriter(path=arrow_output) as writer:
+             txt_batch = []
+             for l in fp:
+                 js = json.loads(l)
+                 txt_batch.append(js[json_txt_key])
+                 if len(txt_batch) >= batch_size:
+                     encodings = encode_batch(txt_batch)
+                     write_encodings(writer, encodings)
+                     txt_batch = []
+             if len(txt_batch) > 0:
+                 encodings = encode_batch(txt_batch)
+                 write_encodings(writer, encodings)
+             writer.finalize()
+
+
+ def train_sparse_encoder_decoder(sed: BaseAutoencoder, activation_store: ActivationsStore, cfg: dict):
+     num_batches = cfg["num_samples"] // cfg["batch_size"]
+     optimizer = torch.optim.Adam(sed.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
+     pbar = tqdm.trange(num_batches)
+
+     for i in pbar:
+         batch = activation_store.next_batch()
+         sed_output = sed(batch)
+
+         loss = sed_output["loss"]
+         pbar.set_postfix({"Loss": f"{loss.item():.4f}", "Dead": f"{sed_output['num_dead_features']:.4f}", "L0": f"{sed_output['l0_norm']:.4f}", "L2": f"{sed_output['l2_loss']:.4f}", "L1": f"{sed_output['l1_loss']:.4f}", "L1_norm": f"{sed_output['l1_norm']:.4f}"})
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(sed.parameters(), cfg["max_grad_norm"])
+         sed.make_decoder_weights_and_grad_unit_norm()
+         optimizer.step()
+         optimizer.zero_grad()
+
+
+
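A minimal end-to-end sketch of how these pieces compose. The paths "docs.jsonl" and "acts.arrow" and the saved-model directory "model_dir" are hypothetical, the small "num_samples" override just keeps the run short, and the saved model's embedding width is assumed to match cfg["act_size"]:

    from tmnt.sparse.config import get_default_cfg
    from tmnt.sparse.estimator import (ActivationsStore, build_activation_store_batching,
                                       train_sparse_encoder_decoder)
    from tmnt.sparse.modeling import BatchTopKEncoder, BatchTopKSAE

    # 1) Encode a JSON-lines corpus into an Arrow file of embeddings.
    build_activation_store_batching("docs.jsonl", "model_dir", "acts.arrow", device="cpu")

    # 2) Train a sparse autoencoder over the stored activations. Note that
    #    ActivationsStore reads "activation_path" (and an optional
    #    "shuffle_data" flag) from the same cfg dict.
    cfg = get_default_cfg()
    cfg.update({"device": "cpu", "activation_path": "acts.arrow",
                "num_samples": 10 * cfg["batch_size"]})
    store = ActivationsStore(cfg)
    sae = BatchTopKSAE(cfg, BatchTopKEncoder(cfg))
    train_sparse_encoder_decoder(sae, store, cfg)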
tmnt/sparse/inference.py ADDED
@@ -0,0 +1,54 @@
+ from tmnt.inference import SeqVEDInferencer
+ from scipy.sparse import csr_matrix
+ import numpy as np
+ from typing import List, Tuple
+ from tmnt.distribution import ConceptLogisticGaussianDistribution
+ import torch
+ from torch.utils.data import DataLoader
+ from datasets import Dataset, IterableDataset
+ import tqdm
+ from datasets.arrow_writer import ArrowWriter
+ from tmnt.inference import SeqVEDInferencer
+ import io, json
+
+
+ def csr_to_indices_data(csr_mat):
+     return [ (csr_mat.getrow(ri).indices, csr_mat.getrow(ri).data) for ri in range(csr_mat.shape[0]) ]
+
+ def batch_process_to_arrow(model_path, json_input_texts, output_db_path, max_seq_len=512, device='cuda', batch_size=200, json_txt_key='text'):
+
+     inferencer = SeqVEDInferencer.from_saved(model_path, max_length=max_seq_len, device=device)
+     def encode_batch(txt_batch):
+         tokenization_result = inferencer.prep_text(txt_batch)
+         llm_out = inferencer.model.llm(tokenization_result['input_ids'].to(inferencer.device),
+                                        tokenization_result['attention_mask'].to(inferencer.device))
+         cls_vecs = inferencer.model._get_embedding(llm_out, tokenization_result['attention_mask'].to(inferencer.device))
+         raw_concepts = inferencer.model.latent_distribution.get_sparse_encoding(cls_vecs).cpu().detach()
+         mu_emb = inferencer.model.latent_distribution.get_mu_encoding(cls_vecs)
+         encs : List[List[float]] = cls_vecs.cpu().detach().tolist()
+         sparse_concepts : List[Tuple[List[int], List[float]]] = csr_to_indices_data(csr_matrix(raw_concepts))
+         topic_embeddings : List[List[float]] = mu_emb.cpu().detach().tolist()
+         print("Lengths: {}, {}, {}, {}".format(len(txt_batch), len(encs), len(sparse_concepts), len(topic_embeddings)))
+         return zip(txt_batch, encs, sparse_concepts, topic_embeddings)
+
+     def write_encodings(writer: ArrowWriter, txt_enc_pairs):
+         for (text, embedding, sparse_indices_and_data, topic_embedding) in txt_enc_pairs:
+             writer.write({'text': text, 'embedding': embedding, 'indices': sparse_indices_and_data[0],
+                           'values': sparse_indices_and_data[1], 'topic_embedding': topic_embedding})
+
+     with io.open(json_input_texts) as fp:
+         with ArrowWriter(path=output_db_path) as writer:
+             txt_batch = []
+             for l in fp:
+                 js = json.loads(l)
+                 txt_batch.append(js[json_txt_key])
+                 if len(txt_batch) >= batch_size:
+                     encodings = encode_batch(txt_batch)
+                     write_encodings(writer, encodings)
+                     txt_batch = []
+             if len(txt_batch) > 0:
+                 encodings = encode_batch(txt_batch)
+                 write_encodings(writer, encodings)
+             writer.finalize()
+
+
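A small read-back sketch, assuming "encoded.arrow" is a file written by batch_process_to_arrow above: it rebuilds the per-document sparse concept vectors into a single scipy CSR matrix from the 'indices' and 'values' columns:

    import numpy as np
    from datasets import Dataset
    from scipy.sparse import csr_matrix

    ds = Dataset.from_file("encoded.arrow")
    indptr, indices, values = [0], [], []
    for row in ds:
        indices.extend(row["indices"])   # concept ids for one document
        values.extend(row["values"])     # their activation strengths
        indptr.append(len(indices))
    n_cols = max(indices) + 1 if indices else 0
    concepts = csr_matrix((np.array(values), np.array(indices), np.array(indptr)),
                          shape=(len(ds), n_cols))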
tmnt/sparse/modeling.py ADDED
@@ -0,0 +1,382 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.autograd as autograd
+
+ class BaseEncoder(nn.Module):
+
+     def __init__(self, cfg):
+         super().__init__()
+         self.cfg = cfg
+         torch.manual_seed(self.cfg['seed'])
+         self.b_enc = nn.Parameter(torch.zeros(self.cfg['dict_size']))
+         self.W_enc = nn.Parameter(
+             torch.nn.init.kaiming_uniform_(
+                 torch.empty(self.cfg['act_size'], self.cfg['dict_size'])
+             )
+         )
+         self.to(cfg['dtype']).to(cfg['device'])
+
+     def get_dict_size(self):
+         return int(self.cfg['dict_size'])
+
+     def preprocess_input(self, x):
+         if self.cfg.get("input_unit_norm", False):
+             x_mean = x.mean(dim=-1, keepdim=True)
+             x = x - x_mean
+             x_std = x.std(dim=-1, keepdim=True)
+             x = x / (x_std + 1e-5)
+             return x, x_mean, x_std
+         else:
+             return x, None, None
+
+
+ class BaseAutoencoder(nn.Module):
+     """Base class for autoencoder models."""
+
+     def __init__(self, cfg, encoder: BaseEncoder):
+         super().__init__()
+
+         self.cfg = cfg
+         torch.manual_seed(self.cfg["seed"])
+
+         self.encoder = encoder
+
+         self.b_dec = nn.Parameter(torch.zeros(self.cfg["act_size"]))
+         self.W_dec = nn.Parameter(
+             torch.nn.init.kaiming_uniform_(
+                 torch.empty(self.cfg["dict_size"], self.cfg["act_size"])
+             )
+         )
+         self.W_dec.data[:] = self.encoder.W_enc.t().data
+         self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
+         self.num_batches_not_active = torch.zeros((self.cfg["dict_size"],)).to(
+             cfg["device"]
+         )
+
+         self.to(cfg["dtype"]).to(cfg["device"])
+
+     def postprocess_output(self, x_reconstruct, x_mean, x_std):
+         if self.cfg.get("input_unit_norm", False):
+             x_reconstruct = x_reconstruct * x_std + x_mean
+         return x_reconstruct
+
+     @torch.no_grad()
+     def make_decoder_weights_and_grad_unit_norm(self):
+         W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
+         W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
+             -1, keepdim=True
+         ) * W_dec_normed
+         self.W_dec.grad -= W_dec_grad_proj
+         self.W_dec.data = W_dec_normed
+
+     def update_inactive_features(self, acts):
+         self.num_batches_not_active += (acts.sum(0) == 0).float()
+         self.num_batches_not_active[acts.sum(0) > 0] = 0
+
+
+ class BatchTopKEncoder(BaseEncoder):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+
+         acts = F.relu(x @ self.W_enc)
+         acts_topk = torch.topk(acts.flatten(), self.cfg["top_k"] * x.shape[0], dim=-1)
+         acts_topk = (
+             torch.zeros_like(acts.flatten())
+             .scatter(-1, acts_topk.indices, acts_topk.values)
+             .reshape(acts.shape)
+         )
+         return acts, acts_topk, x, x_mean, x_std
+
+
+
+ class BatchTopKSAE(BaseAutoencoder):
+     def __init__(self, cfg: dict, encoder: BatchTopKEncoder):
+         super().__init__(cfg, encoder)
+
+     def forward(self, x):
+         acts, acts_topk, x, x_mean, x_std = self.encoder(x)
+         x_reconstruct = acts_topk @ self.W_dec + self.b_dec
+
+         self.update_inactive_features(acts_topk)
+         output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
+         return output
+
+
+     def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+         l1_norm = acts_topk.float().abs().sum(-1).mean()
+         l1_loss = self.cfg["l1_coeff"] * l1_norm
+         l0_norm = (acts_topk > 0).float().sum(-1).mean()
+         aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
+         loss = l2_loss + l1_loss + aux_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+         ).sum()
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts_topk,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0_norm,
+             "l1_norm": l1_norm,
+             "aux_loss": aux_loss,
+         }
+         return output
+
+     def get_auxiliary_loss(self, x, x_reconstruct, acts):
+         dead_features = self.num_batches_not_active >= self.cfg["n_batches_to_dead"]
+         if dead_features.sum() > 0:
+             residual = x.float() - x_reconstruct.float()
+             acts_topk_aux = torch.topk(
+                 acts[:, dead_features],
+                 min(self.cfg["top_k_aux"], dead_features.sum()),
+                 dim=-1,
+             )
+             acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
+                 -1, acts_topk_aux.indices, acts_topk_aux.values
+             )
+             x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
+             l2_loss_aux = (
+                 self.cfg["aux_penalty"]
+                 * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
+             )
+             return l2_loss_aux
+         else:
+             return torch.tensor(0, dtype=x.dtype, device=x.device)
+
+
+ class TopKEncoder(BaseEncoder):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+         acts = F.relu(x @ self.W_enc)
+         acts_topk = torch.topk(acts, self.cfg["top_k"], dim=-1)
+         acts_topk = torch.zeros_like(acts).scatter(
+             -1, acts_topk.indices, acts_topk.values
+         )
+         return acts, acts_topk, x, x_mean, x_std
+
+ class TopKSAE(BaseAutoencoder):
+     def __init__(self, cfg: dict, encoder: TopKEncoder):
+         super().__init__(cfg, encoder)
+
+     def forward(self, x):
+         acts, acts_topk, x, x_mean, x_std = self.encoder(x)
+         x_reconstruct = acts_topk @ self.W_dec + self.b_dec
+         self.update_inactive_features(acts_topk)
+         output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
+         return output
+
+     def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+         l1_norm = acts_topk.float().abs().sum(-1).mean()
+         l1_loss = self.cfg["l1_coeff"] * l1_norm
+         l0_norm = (acts_topk > 0).float().sum(-1).mean()
+         aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
+         loss = l2_loss + l1_loss + aux_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+         ).sum()
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts_topk,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0_norm,
+             "l1_norm": l1_norm,
+             "aux_loss": aux_loss,
+         }
+         return output
+
+     def get_auxiliary_loss(self, x, x_reconstruct, acts):
+         dead_features = self.num_batches_not_active >= self.cfg["n_batches_to_dead"]
+         if dead_features.sum() > 0:
+             residual = x.float() - x_reconstruct.float()
+             acts_topk_aux = torch.topk(
+                 acts[:, dead_features],
+                 min(self.cfg["top_k_aux"], dead_features.sum()),
+                 dim=-1,
+             )
+             acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
+                 -1, acts_topk_aux.indices, acts_topk_aux.values
+             )
+             x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
+             l2_loss_aux = (
+                 self.cfg["aux_penalty"]
+                 * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
+             )
+             return l2_loss_aux
+         else:
+             return torch.tensor(0, dtype=x.dtype, device=x.device)
+
+
+ class VanillaEncoder(BaseEncoder):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+         acts = F.relu(x @ self.W_enc + self.b_enc)
+         return acts, x, x_mean, x_std
+
+
+ class VanillaSAE(BaseAutoencoder):
+     def __init__(self, cfg, encoder: VanillaEncoder):
+         super().__init__(cfg, encoder)
+
+     def forward(self, x):
+         acts, x, x_mean, x_std = self.encoder(x)
+         x_reconstruct = acts @ self.W_dec + self.b_dec
+         self.update_inactive_features(acts)
+         output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
+         return output
+
+     def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+         l1_norm = acts.float().abs().sum(-1).mean()
+         l1_loss = self.cfg["l1_coeff"] * l1_norm
+         l0_norm = (acts > 0).float().sum(-1).mean()
+         loss = l2_loss + l1_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+         ).sum()
+
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0_norm,
+             "l1_norm": l1_norm,
+         }
+         return output
+
+ import torch
+ import torch.nn as nn
+ import torch.autograd as autograd
+
+ class RectangleFunction(autograd.Function):
+     @staticmethod
+     def forward(ctx, x):
+         ctx.save_for_backward(x)
+         return ((x > -0.5) & (x < 0.5)).float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         (x,) = ctx.saved_tensors
+         grad_input = grad_output.clone()
+         grad_input[(x <= -0.5) | (x >= 0.5)] = 0
+         return grad_input
+
+ class JumpReLUFunction(autograd.Function):
+     @staticmethod
+     def forward(ctx, x, log_threshold, bandwidth):
+         ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
+         threshold = torch.exp(log_threshold)
+         return x * (x > threshold).float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         x, log_threshold, bandwidth_tensor = ctx.saved_tensors
+         bandwidth = bandwidth_tensor.item()
+         threshold = torch.exp(log_threshold)
+         x_grad = (x > threshold).float() * grad_output
+         threshold_grad = (
+             -(threshold / bandwidth)
+             * RectangleFunction.apply((x - threshold) / bandwidth)
+             * grad_output
+         )
+         return x_grad, threshold_grad, None  # None for bandwidth
+
+ class JumpReLU(nn.Module):
+     def __init__(self, feature_size, bandwidth, device='cpu'):
+         super(JumpReLU, self).__init__()
+         self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
+         self.bandwidth = bandwidth
+
+     def forward(self, x):
+         return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)
+
+ class StepFunction(autograd.Function):
+     @staticmethod
+     def forward(ctx, x, log_threshold, bandwidth):
+         ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
+         threshold = torch.exp(log_threshold)
+         return (x > threshold).float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         x, log_threshold, bandwidth_tensor = ctx.saved_tensors
+         bandwidth = bandwidth_tensor.item()
+         threshold = torch.exp(log_threshold)
+         x_grad = torch.zeros_like(x)
+         threshold_grad = (
+             -(1.0 / bandwidth)
+             * RectangleFunction.apply((x - threshold) / bandwidth)
+             * grad_output
+         )
+         return x_grad, threshold_grad, None  # None for bandwidth
+
+
+ class JumpReLUEncoder(BaseEncoder):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+         # JumpReLU threshold module applied to the pre-activations in forward()
+         self.jumprelu = JumpReLU(feature_size=cfg["dict_size"], bandwidth=cfg["bandwidth"], device=cfg["device"])
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+
+         pre_activations = torch.relu(x @ self.W_enc + self.b_enc)
+         feature_magnitudes = self.jumprelu(pre_activations)
+         return feature_magnitudes, x, x_mean, x_std
+
+
+ class JumpReLUSAE(BaseAutoencoder):
+     def __init__(self, cfg, encoder: JumpReLUEncoder):
+         super().__init__(cfg, encoder)
+
+     def forward(self, x):
+         feature_magnitudes, x, x_mean, x_std = self.encoder(x)
+         x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
+
+         return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)
+
+     def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+
+         l0 = StepFunction.apply(acts, self.encoder.jumprelu.log_threshold, self.cfg["bandwidth"]).sum(dim=-1).mean()
+         l0_loss = self.cfg["l1_coeff"] * l0
+         l1_loss = l0_loss
+
+         loss = l2_loss + l1_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+         ).sum()
+
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0,
+             "l1_norm": l0,
+         }
+         return output
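A short sketch of a single forward pass under the defaults above, overridden to CPU; the random batch is purely illustrative:

    import torch
    from tmnt.sparse.config import get_default_cfg
    from tmnt.sparse.modeling import TopKEncoder, TopKSAE

    cfg = get_default_cfg()
    cfg["device"] = "cpu"
    sae = TopKSAE(cfg, TopKEncoder(cfg))
    x = torch.randn(8, cfg["act_size"])   # batch of 8 "activations"
    out = sae(x)
    # l0_norm is bounded by cfg["top_k"] per example; sae_out has x's shape
    print(out["loss"].item(), out["l0_norm"].item(), out["sae_out"].shape)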
{tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tmnt
- Version: 0.7.58
+ Version: 0.7.60
  Summary: Topic modeling neural toolkit
  Home-page: https://github.com/mitre/tmnt.git
  Author: The MITRE Corporation
@@ -48,7 +48,7 @@ Dynamic: summary
  The Topic Modeling Neural Toolkit (TMNT) is a software library that enables training
  topic models as neural network-based variational auto-encoders.
 
- Current stable version is: 0.7.54
+ Current stable version is: 0.7.60
 
  Documentation can be found here: https://tmnt.readthedocs.io/en/stable/
 
{tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/RECORD RENAMED
@@ -1,4 +1,4 @@
- tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
+ tmnt/__init__.py,sha256=s7YqLj32HKhIYO1QbD0zms8rDlTrleJM8LRjKt8bDPk,200
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
  tmnt/data_loading.py,sha256=LcVcXX00UsuAillRPILcvmqj3AsCIgzB6V_S6lfsbIY,19335
  tmnt/distribution.py,sha256=4gn1wnszVAErzICCvZXSYki0G78WC3_jyBr27N-Aj3E,15108
@@ -9,6 +9,11 @@ tmnt/modeling.py,sha256=rGHQsW7ldycFUd1f9NzcnNuSRElr600vLwmYPl6YY0M,30215
  tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
  tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
  tmnt/preprocess/vectorizer.py,sha256=RaianZ_DG3Nc-RI96FtmI4PCZPi5Nipx9a5xndLZ52M,20689
+ tmnt/sparse/__init__.py,sha256=BEhOm_o0UrVUKTG3rSiBJzE7qQQL9HRSZ1MHCA2GJu8,249
+ tmnt/sparse/config.py,sha256=gfJ1BAP3zzMKKUJExrs8D0hHB7XOVvqXPpmx0ECAPIE,796
+ tmnt/sparse/estimator.py,sha256=SOEPkQo7T2RuBJLYRAk65MhoYjgYtNCAzlvhqq10c5E,4596
+ tmnt/sparse/inference.py,sha256=etOuXTxh8bKc7EoohZLYYufFAD51aEpq8mGj1vjULbg,2813
+ tmnt/sparse/modeling.py,sha256=IejHbRXyj5WsEthkVbp3vYF1wTAAsihbneaPuCgAfGA,13612
  tmnt/utils/__init__.py,sha256=1PZsxRPsHI_DnOpxD0iAhLxhxHnx6Svzg3W-79YfWWs,237
  tmnt/utils/csv2json.py,sha256=A1TXy-uxA4dc9tw0tjiHzL7fv4C6b0Uc_bwI1keTmKU,795
  tmnt/utils/log_utils.py,sha256=ZtR4nF_Iee23ev935YQcTtXv-cCC7lgXkXLl_yokfS4,2075
@@ -18,9 +23,9 @@ tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,12
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
  tmnt/utils/vocab.py,sha256=J6GFGLyvDgdmtVQjYlyzWjuykRD3kllCKPG1z0lI0P8,3504
- tmnt-0.7.58.dist-info/licenses/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
- tmnt-0.7.58.dist-info/licenses/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
- tmnt-0.7.58.dist-info/METADATA,sha256=drdqhfVdpDs5LD_FMAMZjPRWw_TnNqFlGsh0QGtm8QE,1663
- tmnt-0.7.58.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- tmnt-0.7.58.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
- tmnt-0.7.58.dist-info/RECORD,,
+ tmnt-0.7.60.dist-info/licenses/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
+ tmnt-0.7.60.dist-info/licenses/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
+ tmnt-0.7.60.dist-info/METADATA,sha256=DbmpuasEyoW6FNHmu-YfNsgq5J1lV0MjWPe04CCq8Fk,1663
+ tmnt-0.7.60.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
+ tmnt-0.7.60.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
+ tmnt-0.7.60.dist-info/RECORD,,
{tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (78.1.0)
+ Generator: setuptools (80.0.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 