tmnt 0.7.59__py3-none-any.whl → 0.7.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tmnt/__init__.py CHANGED
@@ -1,11 +1,9 @@
  # coding: utf-8

- import os
- from .distribution import *
+
  from .preprocess import *
- #from .models import *
+ from .sparse import *
  from .utils import *
+ from .distribution import *

- os.environ["MXNET_STORAGE_FALLBACK_LOG_VERBOSE"] = "0"
-
- __all__ = distribution.__all__ + preprocess.__all__ + utils.__all__
+ __all__ = distribution.__all__ + preprocess.__all__ + utils.__all__ # + sparse.__all__
tmnt/estimator.py CHANGED
@@ -15,7 +15,7 @@ import numpy as np
  import scipy.sparse as sp
  import json

- from sklearn.metrics import average_precision_score, top_k_accuracy_score, roc_auc_score, ndcg_score, precision_recall_fscore_support
+ from sklearn.metrics import average_precision_score, top_k_accuracy_score, roc_auc_score, ndcg_score
  from tmnt.data_loading import PairedDataLoader, SingletonWrapperLoader, SparseDataLoader, get_llm_model
  from tmnt.modeling import BowVAEModel, SeqBowVED, BaseVAE
  from tmnt.modeling import CrossBatchCosineSimilarityLoss, GeneralizedSDMLLoss, MultiNegativeCrossEntropyLoss, MetricSeqBowVED, MetricBowVAEModel
@@ -29,18 +29,15 @@ from torcheval.metrics import MultilabelAUPRC, MulticlassAUPRC
  ## huggingface specifics
  from transformers.trainer_pt_utils import get_parameter_names
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
- from transformers.optimization import AdamW, get_scheduler
+ from transformers.optimization import get_scheduler

  ## model selection
  import optuna

- from itertools import cycle
  import pickle
  from typing import List, Tuple, Dict, Optional, Union, NoReturn

  import torch
- from torch.utils.data import Dataset, DataLoader
- from tqdm import tqdm

  MAX_DESIGN_MATRIX = 250000000

tmnt/sparse/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # coding: utf-8
+ """
+ Copyright (c) 2019 The MITRE Corporation.
+ """
+
+
+ from .config import *
+ from .estimator import *
+ from .inference import *
+ from .modeling import *
+
+ __all__ = config.__all__ + estimator.__all__ + inference.__all__ + modeling.__all__
tmnt/sparse/config.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+
+ __all__ = ['get_default_cfg']
+
+ def get_default_cfg():
+     default_cfg = {
+         "seed": 49,
+         "batch_size": 4096,
+         "lr": 3e-4,
+         "num_samples": int(1e9),
+         "l1_coeff": 0,
+         "beta1": 0.9,
+         "beta2": 0.99,
+         "max_grad_norm": 100000,
+         "seq_len": 128,
+         "dtype": torch.float32,
+         "site": "resid_pre",
+         "layer": 8,
+         "act_size": 768,
+         "dict_size": 12288,
+         "device": "cuda:0",
+         "input_unit_norm": True,
+         "perf_log_freq": 1000,
+         "sae_type": "topk",
+         "checkpoint_freq": 10000,
+         "n_batches_to_dead": 5,
+
+         # (Batch)TopKSAE specific
+         "top_k": 32,
+         "top_k_aux": 512,
+         "aux_penalty": (1/32),
+         # for jumprelu
+         "bandwidth": 0.001,
+     }
+     return default_cfg
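
A minimal sketch of how these defaults might be adjusted before training, assuming the surrounding tmnt.sparse modules; the "activation_path" key is an assumption here, read by ActivationsStore in tmnt/sparse/estimator.py but not set by get_default_cfg(), and the small sizes are illustrative values, not package recommendations:

    from tmnt.sparse.config import get_default_cfg

    cfg = get_default_cfg()
    # illustrative overrides for a small CPU-only run
    cfg.update({"device": "cpu", "act_size": 384, "dict_size": 6144,
                "top_k": 16, "activation_path": "acts.arrow"})
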
tmnt/sparse/estimator.py ADDED
@@ -0,0 +1,98 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from datasets import Dataset, IterableDataset
+ import tqdm
+ from datasets.arrow_writer import ArrowWriter
+ from tmnt.inference import SeqVEDInferencer
+ import io, json
+ from tmnt.sparse.modeling import BaseAutoencoder
+ from typing import List
+
+ __all__ = ['ActivationsStore', 'build_activation_store', 'build_activation_store_batching', 'train_sparse_encoder_decoder']
+
+ class ActivationsStore:
+     def __init__(
+         self,
+         cfg: dict,
+     ):
+         self.device = cfg["device"]
+         self.activation_path = cfg["activation_path"]
+         shuffle = cfg.get("shuffle_data", False)
+         #self.dataset = Dataset.from_file(self.activation_path).with_format('torch', device=self.device)
+         self.dataset = Dataset.from_file(self.activation_path).select_columns(['data']).shuffle(seed=42).with_format('torch', device=self.device)
+         self.dataloader = DataLoader(self.dataset,
+                                      batch_size=cfg["batch_size"], shuffle=shuffle)
+         self.dataloader_iter = iter(self.dataloader)
+         self.cfg = cfg
+
+     def next_batch(self):
+         try:
+             return next(self.dataloader_iter)['data']
+         except (StopIteration, AttributeError):
+             # restart the iterator so training can stream indefinitely
+             self.dataloader_iter = iter(self.dataloader)
+             return next(self.dataloader_iter)['data']
+
+ def build_activation_store(json_input_texts, emb_model_path, arrow_output, max_seq_len=512, json_txt_key='text', device='cpu'):
+
+     inferencer = SeqVEDInferencer.from_saved(emb_model_path, max_length=max_seq_len, device=device)
+     with io.open(json_input_texts) as fp:
+         with ArrowWriter(path=arrow_output) as writer:
+             for l in fp:
+                 js = json.loads(l)
+                 tokenization_result = inferencer.prep_text(js[json_txt_key])
+                 llm_out = inferencer.model.llm(tokenization_result['input_ids'].to(inferencer.device),
+                                                tokenization_result['attention_mask'].to(inferencer.device))
+                 cls_vec = inferencer.model._get_embedding(llm_out, tokenization_result['attention_mask'].to(inferencer.device))
+                 enc : List[float] = cls_vec.cpu().detach()[0].tolist()
+                 writer.write({'data': enc})
+             writer.finalize()
+
+ def build_activation_store_batching(json_input_texts, emb_model_path, arrow_output, max_seq_len=512, batch_size=42, json_txt_key='text', device='cpu'):
+     inferencer = SeqVEDInferencer.from_saved(emb_model_path, max_length=max_seq_len, device=device)
+     def encode_batch(txt_batch):
+         tokenization_result = inferencer.prep_text(txt_batch)
+         llm_out = inferencer.model.llm(tokenization_result['input_ids'].to(inferencer.device),
+                                        tokenization_result['attention_mask'].to(inferencer.device))
+         cls_vec = inferencer.model._get_embedding(llm_out, tokenization_result['attention_mask'].to(inferencer.device))
+         encs : List[List[float]] = cls_vec.cpu().detach().tolist()
+         return zip(txt_batch, encs)
+
+     def write_encodings(writer: ArrowWriter, txt_enc_pairs):
+         for (t, e) in txt_enc_pairs:
+             writer.write({'text': t, 'data': e})
+
+     with io.open(json_input_texts) as fp:
+         with ArrowWriter(path=arrow_output) as writer:
+             txt_batch = []
+             for l in fp:
+                 js = json.loads(l)
+                 txt_batch.append(js[json_txt_key])
+                 if len(txt_batch) >= batch_size:
+                     encodings = encode_batch(txt_batch)
+                     write_encodings(writer, encodings)
+                     txt_batch = []
+             if len(txt_batch) > 0:
+                 encodings = encode_batch(txt_batch)
+                 write_encodings(writer, encodings)
+             writer.finalize()
+
+
+ def train_sparse_encoder_decoder(sed: BaseAutoencoder, activation_store: ActivationsStore, cfg: dict):
+     num_batches = cfg["num_samples"] // cfg["batch_size"]
+     optimizer = torch.optim.Adam(sed.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
+     pbar = tqdm.trange(num_batches)
+
+     for i in pbar:
+         batch = activation_store.next_batch()
+         sed_output = sed(batch)
+
+         loss = sed_output["loss"]
+         pbar.set_postfix({"Loss": f"{loss.item():.4f}", "Dead": f"{sed_output['num_dead_features']:.4f}", "L0": f"{sed_output['l0_norm']:.4f}", "L2": f"{sed_output['l2_loss']:.4f}", "L1": f"{sed_output['l1_loss']:.4f}", "L1_norm": f"{sed_output['l1_norm']:.4f}"})
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(sed.parameters(), cfg["max_grad_norm"])
+         sed.make_decoder_weights_and_grad_unit_norm()
+         optimizer.step()
+         optimizer.zero_grad()
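
Taken together with tmnt/sparse/config.py and tmnt/sparse/modeling.py, the pieces above suggest the following end-to-end flow. This is a hedged sketch, not shipped code: the file names are placeholders, "activation_path" must be supplied by the caller, and cfg["act_size"] has to match the width of the embeddings written by build_activation_store.

    from tmnt.sparse.config import get_default_cfg
    from tmnt.sparse.estimator import (ActivationsStore, build_activation_store,
                                       train_sparse_encoder_decoder)
    from tmnt.sparse.modeling import TopKEncoder, TopKSAE

    # 1) embed a JSONL corpus ({"text": ...} per line) into an Arrow file
    build_activation_store('texts.jsonl', 'seq_ved_model/', 'acts.arrow', device='cpu')

    # 2) stream the stored activations and fit a per-example top-k SAE
    cfg = get_default_cfg()
    cfg.update({"device": "cpu", "activation_path": "acts.arrow",
                "num_samples": cfg["batch_size"] * 100})  # 100 training batches
    sae = TopKSAE(cfg, TopKEncoder(cfg))
    train_sparse_encoder_decoder(sae, ActivationsStore(cfg), cfg)
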
tmnt/sparse/inference.py ADDED
@@ -0,0 +1,55 @@
+ from tmnt.inference import SeqVEDInferencer
+ from scipy.sparse import csr_matrix
+ import numpy as np
+ from typing import List, Tuple
+ from tmnt.distribution import ConceptLogisticGaussianDistribution
+ import torch
+ from torch.utils.data import DataLoader
+ from datasets import Dataset, IterableDataset
+ import tqdm
+ from datasets.arrow_writer import ArrowWriter
+ import io, json
+
+ __all__ = ['batch_process_to_arrow']
+
+ def csr_to_indices_data(csr_mat):
+     return [ (csr_mat.getrow(ri).indices, csr_mat.getrow(ri).data) for ri in range(csr_mat.shape[0]) ]
+
+ def batch_process_to_arrow(model_path, json_input_texts, output_db_path, max_seq_len=512, device='cuda', batch_size=200, json_txt_key='text'):
+
+     inferencer = SeqVEDInferencer.from_saved(model_path, max_length=max_seq_len, device=device)
+     def encode_batch(txt_batch):
+         tokenization_result = inferencer.prep_text(txt_batch)
+         llm_out = inferencer.model.llm(tokenization_result['input_ids'].to(inferencer.device),
+                                        tokenization_result['attention_mask'].to(inferencer.device))
+         cls_vecs = inferencer.model._get_embedding(llm_out, tokenization_result['attention_mask'].to(inferencer.device))
+         raw_concepts = inferencer.model.latent_distribution.get_sparse_encoding(cls_vecs).cpu().detach()
+         mu_emb = inferencer.model.latent_distribution.get_mu_encoding(cls_vecs)
+         encs : List[List[float]] = cls_vecs.cpu().detach().tolist()
+         sparse_concepts : List[Tuple[List[int], List[float]]] = csr_to_indices_data(csr_matrix(raw_concepts))
+         topic_embeddings : List[List[float]] = mu_emb.cpu().detach().tolist()
+         print("Lengths: {}, {}, {}, {}".format(len(txt_batch), len(encs), len(sparse_concepts), len(topic_embeddings)))
+         return zip(txt_batch, encs, sparse_concepts, topic_embeddings)
+
+     def write_encodings(writer: ArrowWriter, txt_enc_pairs):
+         for (text, embedding, sparse_indices_and_data, topic_embedding) in txt_enc_pairs:
+             writer.write({'text': text, 'embedding': embedding, 'indices': sparse_indices_and_data[0],
+                           'values': sparse_indices_and_data[1], 'topic_embedding': topic_embedding})
+
+     with io.open(json_input_texts) as fp:
+         with ArrowWriter(path=output_db_path) as writer:
+             txt_batch = []
+             for l in fp:
+                 js = json.loads(l)
+                 txt_batch.append(js[json_txt_key])
+                 if len(txt_batch) >= batch_size:
+                     encodings = encode_batch(txt_batch)
+                     write_encodings(writer, encodings)
+                     txt_batch = []
+             if len(txt_batch) > 0:
+                 encodings = encode_batch(txt_batch)
+                 write_encodings(writer, encodings)
+             writer.finalize()
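
A hedged usage sketch for the batch encoder above, assuming placeholder paths and a saved model directory; the resulting Arrow file can be read back directly with datasets:

    from datasets import Dataset
    from tmnt.sparse.inference import batch_process_to_arrow

    batch_process_to_arrow('seq_ved_model/', 'texts.jsonl', 'encoded.arrow', device='cpu')
    ds = Dataset.from_file('encoded.arrow')
    row = ds[0]  # keys: 'text', 'embedding', 'indices', 'values', 'topic_embedding'
    print(row['indices'], row['values'])
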
tmnt/sparse/modeling.py ADDED
@@ -0,0 +1,385 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.autograd as autograd
+
+
+ __all__ = ['BatchTopKEncoder', 'BatchTopKSAE', 'TopKEncoder', 'TopKSAE', 'VanillaEncoder', 'VanillaSAE', 'JumpReLUEncoder', 'JumpReLUSAE']
+
+ class BaseEncoder(nn.Module):
+
+     def __init__(self, cfg):
+         super().__init__()
+         self.cfg = cfg
+         torch.manual_seed(self.cfg['seed'])
+         self.b_enc = nn.Parameter(torch.zeros(self.cfg['dict_size']))
+         self.W_enc = nn.Parameter(
+             torch.nn.init.kaiming_uniform_(
+                 torch.empty(self.cfg['act_size'], self.cfg['dict_size'])
+             )
+         )
+         self.to(cfg['dtype']).to(cfg['device'])
+
+     def get_dict_size(self):
+         return int(self.cfg['dict_size'])
+
+     def preprocess_input(self, x):
+         if self.cfg.get("input_unit_norm", False):
+             x_mean = x.mean(dim=-1, keepdim=True)
+             x = x - x_mean
+             x_std = x.std(dim=-1, keepdim=True)
+             x = x / (x_std + 1e-5)
+             return x, x_mean, x_std
+         else:
+             return x, None, None
+
+
+ class BaseAutoencoder(nn.Module):
+     """Base class for autoencoder models."""
+
+     def __init__(self, cfg, encoder: BaseEncoder):
+         super().__init__()
+
+         self.cfg = cfg
+         torch.manual_seed(self.cfg["seed"])
+
+         self.encoder = encoder
+
+         self.b_dec = nn.Parameter(torch.zeros(self.cfg["act_size"]))
+         self.W_dec = nn.Parameter(
+             torch.nn.init.kaiming_uniform_(
+                 torch.empty(self.cfg["dict_size"], self.cfg["act_size"])
+             )
+         )
+         # tie the decoder to the transposed encoder weights at init, then unit-normalize rows
+         self.W_dec.data[:] = self.encoder.W_enc.t().data
+         self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
+         self.num_batches_not_active = torch.zeros((self.cfg["dict_size"],)).to(
+             cfg["device"]
+         )
+
+         self.to(cfg["dtype"]).to(cfg["device"])
+
+     def postprocess_output(self, x_reconstruct, x_mean, x_std):
+         if self.cfg.get("input_unit_norm", False):
+             x_reconstruct = x_reconstruct * x_std + x_mean
+         return x_reconstruct
+
+     @torch.no_grad()
+     def make_decoder_weights_and_grad_unit_norm(self):
+         # project the gradient onto the tangent space of the unit sphere so
+         # optimizer steps preserve unit-norm decoder rows
+         W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
+         W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
+             -1, keepdim=True
+         ) * W_dec_normed
+         self.W_dec.grad -= W_dec_grad_proj
+         self.W_dec.data = W_dec_normed
+
+     def update_inactive_features(self, acts):
+         self.num_batches_not_active += (acts.sum(0) == 0).float()
+         self.num_batches_not_active[acts.sum(0) > 0] = 0
+
+
+ class BatchTopKEncoder(BaseEncoder):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+
+         acts = F.relu(x @ self.W_enc)
+         # top-k is taken jointly over the flattened batch, so the number of
+         # active features per example can vary while averaging top_k per row
+         acts_topk = torch.topk(acts.flatten(), self.cfg["top_k"] * x.shape[0], dim=-1)
+         acts_topk = (
+             torch.zeros_like(acts.flatten())
+             .scatter(-1, acts_topk.indices, acts_topk.values)
+             .reshape(acts.shape)
+         )
+         return acts, acts_topk, x, x_mean, x_std
+
+
+ class BatchTopKSAE(BaseAutoencoder):
+     def __init__(self, cfg: dict, encoder: BatchTopKEncoder):
+         super().__init__(cfg, encoder)
+
+     def forward(self, x):
+         acts, acts_topk, x, x_mean, x_std = self.encoder(x)
+         x_reconstruct = acts_topk @ self.W_dec + self.b_dec
+
+         self.update_inactive_features(acts_topk)
+         output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
+         return output
+
+     def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+         l1_norm = acts_topk.float().abs().sum(-1).mean()
+         l1_loss = self.cfg["l1_coeff"] * l1_norm
+         l0_norm = (acts_topk > 0).float().sum(-1).mean()
+         aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
+         loss = l2_loss + l1_loss + aux_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+         ).sum()
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts_topk,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0_norm,
+             "l1_norm": l1_norm,
+             "aux_loss": aux_loss,
+         }
+         return output
+
+     def get_auxiliary_loss(self, x, x_reconstruct, acts):
+         # push features that have been dead for n_batches_to_dead batches to
+         # explain the residual left by the live features
+         dead_features = self.num_batches_not_active >= self.cfg["n_batches_to_dead"]
+         if dead_features.sum() > 0:
+             residual = x.float() - x_reconstruct.float()
+             acts_topk_aux = torch.topk(
+                 acts[:, dead_features],
+                 min(self.cfg["top_k_aux"], dead_features.sum()),
+                 dim=-1,
+             )
+             acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
+                 -1, acts_topk_aux.indices, acts_topk_aux.values
+             )
+             x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
+             l2_loss_aux = (
+                 self.cfg["aux_penalty"]
+                 * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
+             )
+             return l2_loss_aux
+         else:
+             return torch.tensor(0, dtype=x.dtype, device=x.device)
+
+
+ class TopKEncoder(BaseEncoder):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+         acts = F.relu(x @ self.W_enc)
+         # per-example top-k: every row keeps exactly top_k active features
+         acts_topk = torch.topk(acts, self.cfg["top_k"], dim=-1)
+         acts_topk = torch.zeros_like(acts).scatter(
+             -1, acts_topk.indices, acts_topk.values
+         )
+         return acts, acts_topk, x, x_mean, x_std
+
+
+ class TopKSAE(BaseAutoencoder):
+     def __init__(self, cfg: dict, encoder: TopKEncoder):
+         super().__init__(cfg, encoder)
+
+     def forward(self, x):
+         acts, acts_topk, x, x_mean, x_std = self.encoder(x)
+         x_reconstruct = acts_topk @ self.W_dec + self.b_dec
+         self.update_inactive_features(acts_topk)
+         output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
+         return output
+
+     def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+         l1_norm = acts_topk.float().abs().sum(-1).mean()
+         l1_loss = self.cfg["l1_coeff"] * l1_norm
+         l0_norm = (acts_topk > 0).float().sum(-1).mean()
+         aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
+         loss = l2_loss + l1_loss + aux_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+         ).sum()
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts_topk,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0_norm,
+             "l1_norm": l1_norm,
+             "aux_loss": aux_loss,
+         }
+         return output
+
+     def get_auxiliary_loss(self, x, x_reconstruct, acts):
+         dead_features = self.num_batches_not_active >= self.cfg["n_batches_to_dead"]
+         if dead_features.sum() > 0:
+             residual = x.float() - x_reconstruct.float()
+             acts_topk_aux = torch.topk(
+                 acts[:, dead_features],
+                 min(self.cfg["top_k_aux"], dead_features.sum()),
+                 dim=-1,
+             )
+             acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
+                 -1, acts_topk_aux.indices, acts_topk_aux.values
+             )
+             x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
+             l2_loss_aux = (
+                 self.cfg["aux_penalty"]
+                 * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
+             )
+             return l2_loss_aux
+         else:
+             return torch.tensor(0, dtype=x.dtype, device=x.device)
+
+
+ class VanillaEncoder(BaseEncoder):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+         acts = F.relu(x @ self.W_enc + self.b_enc)
+         return acts, x, x_mean, x_std
+
+
+ class VanillaSAE(BaseAutoencoder):
+     def __init__(self, cfg, encoder: VanillaEncoder):
+         super().__init__(cfg, encoder)
+
+     def forward(self, x):
+         acts, x, x_mean, x_std = self.encoder(x)
+         x_reconstruct = acts @ self.W_dec + self.b_dec
+         self.update_inactive_features(acts)
+         output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
+         return output
+
+     def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+         l1_norm = acts.float().abs().sum(-1).mean()
+         l1_loss = self.cfg["l1_coeff"] * l1_norm
+         l0_norm = (acts > 0).float().sum(-1).mean()
+         loss = l2_loss + l1_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+         ).sum()
+
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0_norm,
+             "l1_norm": l1_norm,
+         }
+         return output
+
+
+ class RectangleFunction(autograd.Function):
+     @staticmethod
+     def forward(ctx, x):
+         ctx.save_for_backward(x)
+         return ((x > -0.5) & (x < 0.5)).float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         (x,) = ctx.saved_tensors
+         grad_input = grad_output.clone()
+         grad_input[(x <= -0.5) | (x >= 0.5)] = 0
+         return grad_input
+
+
+ class JumpReLUFunction(autograd.Function):
+     @staticmethod
+     def forward(ctx, x, log_threshold, bandwidth):
+         ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
+         threshold = torch.exp(log_threshold)
+         return x * (x > threshold).float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         x, log_threshold, bandwidth_tensor = ctx.saved_tensors
+         bandwidth = bandwidth_tensor.item()
+         threshold = torch.exp(log_threshold)
+         x_grad = (x > threshold).float() * grad_output
+         # straight-through estimate of the threshold gradient via a rectangle window
+         threshold_grad = (
+             -(threshold / bandwidth)
+             * RectangleFunction.apply((x - threshold) / bandwidth)
+             * grad_output
+         )
+         return x_grad, threshold_grad, None  # None for bandwidth
+
+
+ class JumpReLU(nn.Module):
+     def __init__(self, feature_size, bandwidth, device='cpu'):
+         super(JumpReLU, self).__init__()
+         self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
+         self.bandwidth = bandwidth
+
+     def forward(self, x):
+         return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)
+
+
+ class StepFunction(autograd.Function):
+     @staticmethod
+     def forward(ctx, x, log_threshold, bandwidth):
+         ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
+         threshold = torch.exp(log_threshold)
+         return (x > threshold).float()
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         x, log_threshold, bandwidth_tensor = ctx.saved_tensors
+         bandwidth = bandwidth_tensor.item()
+         threshold = torch.exp(log_threshold)
+         x_grad = torch.zeros_like(x)
+         threshold_grad = (
+             -(1.0 / bandwidth)
+             * RectangleFunction.apply((x - threshold) / bandwidth)
+             * grad_output
+         )
+         return x_grad, threshold_grad, None  # None for bandwidth
+
+
+ class JumpReLUEncoder(BaseEncoder):
+     def __init__(self, cfg):
+         super().__init__(cfg)
+         # the learned threshold lives on the encoder, where forward() applies it
+         self.jumprelu = JumpReLU(feature_size=cfg["dict_size"], bandwidth=cfg["bandwidth"], device=cfg["device"])
+
+     def forward(self, x):
+         x, x_mean, x_std = self.preprocess_input(x)
+
+         pre_activations = torch.relu(x @ self.W_enc + self.b_enc)
+         feature_magnitudes = self.jumprelu(pre_activations)
+         return feature_magnitudes, x, x_mean, x_std
+
+
+ class JumpReLUSAE(BaseAutoencoder):
+     def __init__(self, cfg, encoder: JumpReLUEncoder):
+         super().__init__(cfg, encoder)
+
+     def forward(self, x):
+         feature_magnitudes, x, x_mean, x_std = self.encoder(x)
+         x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
+
+         return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)
+
+     def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
+         l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+
+         l0 = StepFunction.apply(acts, self.encoder.jumprelu.log_threshold, self.cfg["bandwidth"]).sum(dim=-1).mean()
+         l0_loss = self.cfg["l1_coeff"] * l0
+         l1_loss = l0_loss
+
+         loss = l2_loss + l1_loss
+         num_dead_features = (
+             self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+         ).sum()
+
+         sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+         output = {
+             "sae_out": sae_out,
+             "feature_acts": acts,
+             "num_dead_features": num_dead_features,
+             "loss": loss,
+             "l1_loss": l1_loss,
+             "l2_loss": l2_loss,
+             "l0_norm": l0,
+             "l1_norm": l0,
+         }
+         return output
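
A self-contained smoke test (a sketch, not shipped code) that exercises one of the SAE variants above on random activations and shows the loss-dictionary contract shared by all of them:

    import torch
    from tmnt.sparse.config import get_default_cfg
    from tmnt.sparse.modeling import BatchTopKEncoder, BatchTopKSAE

    cfg = get_default_cfg()
    # small illustrative sizes so the test runs anywhere
    cfg.update({"device": "cpu", "act_size": 64, "dict_size": 512, "top_k": 8})

    sae = BatchTopKSAE(cfg, BatchTopKEncoder(cfg))
    out = sae(torch.randn(16, cfg["act_size"]))
    print(sorted(out.keys()))
    # ['aux_loss', 'feature_acts', 'l0_norm', 'l1_loss', 'l1_norm', 'l2_loss',
    #  'loss', 'num_dead_features', 'sae_out']
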
{tmnt-0.7.59.dist-info → tmnt-0.7.61.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tmnt
- Version: 0.7.59
+ Version: 0.7.61
  Summary: Topic modeling neural toolkit
  Home-page: https://github.com/mitre/tmnt.git
  Author: The MITRE Corporation
@@ -48,7 +48,7 @@ Dynamic: summary
  The Topic Modeling Neural Toolkit (TMNT) is a software library that enables training
  topic models as neural network-based variational auto-encoders.

- Current stable version is: 0.7.54
+ Current stable version is: 0.7.61

  Documentation can be found here: https://tmnt.readthedocs.io/en/stable/

{tmnt-0.7.59.dist-info → tmnt-0.7.61.dist-info}/RECORD RENAMED
@@ -1,14 +1,19 @@
- tmnt/__init__.py,sha256=EPNq1H7UMyMewWT_zTGBaC7ZouvCywX_gMX4G1dtmvw,250
+ tmnt/__init__.py,sha256=3W6sRKacvIMQuxGZpscjU52B4sDcasEA1RCXJVHaK-Q,203
  tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
  tmnt/data_loading.py,sha256=LcVcXX00UsuAillRPILcvmqj3AsCIgzB6V_S6lfsbIY,19335
  tmnt/distribution.py,sha256=4gn1wnszVAErzICCvZXSYki0G78WC3_jyBr27N-Aj3E,15108
- tmnt/estimator.py,sha256=KnnvSNXm6cRL0GwDrGdgqqPX5ZubpCQ0WqcSXJDkUU4,68072
+ tmnt/estimator.py,sha256=SJW9koDHb_lTwmqXm4Ilgga7MMMmRGLd3ylEFYu3bZg,67933
  tmnt/eval_npmi.py,sha256=8S-IE-bEhtQofF6oKeXs7oaUeu-7yDlaEqjMj52gmNQ,6549
  tmnt/inference.py,sha256=Iwc2_w7QrS1epiVEm_Ewx5sYFNNMDfvhMJETOgJqm0E,15783
  tmnt/modeling.py,sha256=rGHQsW7ldycFUd1f9NzcnNuSRElr600vLwmYPl6YY0M,30215
  tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
  tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
  tmnt/preprocess/vectorizer.py,sha256=RaianZ_DG3Nc-RI96FtmI4PCZPi5Nipx9a5xndLZ52M,20689
+ tmnt/sparse/__init__.py,sha256=BEhOm_o0UrVUKTG3rSiBJzE7qQQL9HRSZ1MHCA2GJu8,249
+ tmnt/sparse/config.py,sha256=DzqzmclxbUmmBZHS7poMicsvgOZCMXTXWEY5bTpPrRI,827
+ tmnt/sparse/estimator.py,sha256=PSqIopCZG7WO5MXYcN-CAs7R7M36jZXjQeUp0I5kG9Q,4721
+ tmnt/sparse/inference.py,sha256=KiSUjo_lO14qkUzzo9HuOWZ5S5oJRzoSsSa2dOg54UY,2850
+ tmnt/sparse/modeling.py,sha256=hy2nkI2tq3Bpmj7OfXr_zwHRajWBwsqz7T3wHTuEbyE,13753
  tmnt/utils/__init__.py,sha256=1PZsxRPsHI_DnOpxD0iAhLxhxHnx6Svzg3W-79YfWWs,237
  tmnt/utils/csv2json.py,sha256=A1TXy-uxA4dc9tw0tjiHzL7fv4C6b0Uc_bwI1keTmKU,795
  tmnt/utils/log_utils.py,sha256=ZtR4nF_Iee23ev935YQcTtXv-cCC7lgXkXLl_yokfS4,2075
@@ -18,9 +23,9 @@ tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,12
  tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
  tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
  tmnt/utils/vocab.py,sha256=J6GFGLyvDgdmtVQjYlyzWjuykRD3kllCKPG1z0lI0P8,3504
- tmnt-0.7.59.dist-info/licenses/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
- tmnt-0.7.59.dist-info/licenses/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
- tmnt-0.7.59.dist-info/METADATA,sha256=WnXkGjITudOxSCMTfqQVMdx2UvfxN4O-SLvdYVtTckY,1663
- tmnt-0.7.59.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- tmnt-0.7.59.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
- tmnt-0.7.59.dist-info/RECORD,,
+ tmnt-0.7.61.dist-info/licenses/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
+ tmnt-0.7.61.dist-info/licenses/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
+ tmnt-0.7.61.dist-info/METADATA,sha256=hyptC3YXersP-7oR89NgKU4UksTrvsbcC5Ml38WSYG0,1663
+ tmnt-0.7.61.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
+ tmnt-0.7.61.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
+ tmnt-0.7.61.dist-info/RECORD,,
{tmnt-0.7.59.dist-info → tmnt-0.7.61.dist-info}/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (78.1.0)
+ Generator: setuptools (80.0.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
