tmnt 0.7.58-py3-none-any.whl → 0.7.60-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tmnt/__init__.py +2 -5
- tmnt/sparse/__init__.py +12 -0
- tmnt/sparse/config.py +33 -0
- tmnt/sparse/estimator.py +96 -0
- tmnt/sparse/inference.py +54 -0
- tmnt/sparse/modeling.py +382 -0
- {tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/METADATA +2 -2
- {tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/RECORD +12 -7
- {tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/WHEEL +1 -1
- {tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/licenses/LICENSE +0 -0
- {tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/licenses/NOTICE +0 -0
- {tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/top_level.txt +0 -0
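Substantively, 0.7.60 adds a new tmnt.sparse subpackage: default hyperparameters (config.py), an Arrow-backed activation store plus training loop (estimator.py), batch inference helpers (inference.py), and sparse autoencoder models in vanilla, top-k, batch top-k, and JumpReLU variants (modeling.py). The top-level tmnt package now re-exports its symbols.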
tmnt/__init__.py
CHANGED
@@ -1,11 +1,8 @@
 # coding: utf-8
 
-import os
 from .distribution import *
 from .preprocess import *
-
+from .sparse import *
 from .utils import *
 
-
-
-__all__ = distribution.__all__ + preprocess.__all__ + utils.__all__
+__all__ = distribution.__all__ + preprocess.__all__ + utils.__all__ + sparse.__all__
tmnt/sparse/__init__.py
ADDED
@@ -0,0 +1,12 @@
+# coding: utf-8
+"""
+Copyright (c) 2019 The MITRE Corporation.
+"""
+
+
+from .config import *
+from .estimator import *
+from .inference import *
+from .modeling import *
+
+__all__ = config.__all__ + estimator.__all__ + inference.__all__ + modeling.__all__
tmnt/sparse/config.py
ADDED
@@ -0,0 +1,33 @@
+import torch
+
+def get_default_cfg():
+    default_cfg = {
+        "seed": 49,
+        "batch_size": 4096,
+        "lr": 3e-4,
+        "num_samples": int(1e9),
+        "l1_coeff": 0,
+        "beta1": 0.9,
+        "beta2": 0.99,
+        "max_grad_norm": 100000,
+        "seq_len": 128,
+        "dtype": torch.float32,
+        "site": "resid_pre",
+        "layer": 8,
+        "act_size": 768,
+        "dict_size": 12288,
+        "device": "cuda:0",
+        "input_unit_norm": True,
+        "perf_log_freq": 1000,
+        "sae_type": "topk",
+        "checkpoint_freq": 10000,
+        "n_batches_to_dead": 5,
+
+        # (Batch)TopKSAE specific
+        "top_k": 32,
+        "top_k_aux": 512,
+        "aux_penalty": (1/32),
+        # for jumprelu
+        "bandwidth": 0.001,
+    }
+    return default_cfg
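For orientation, a minimal usage sketch (the override values below are illustrative, not TMNT defaults; note that ActivationsStore in estimator.py also reads an "activation_path" key, which get_default_cfg() does not set):

from tmnt.sparse.config import get_default_cfg

cfg = get_default_cfg()
cfg["device"] = "cpu"                     # defaults assume "cuda:0"
cfg["dict_size"] = 16 * cfg["act_size"]   # illustrative override of the 12288 default
cfg["activation_path"] = "acts.arrow"     # hypothetical path; required by ActivationsStore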
tmnt/sparse/estimator.py
ADDED
@@ -0,0 +1,96 @@
+import torch
+from torch.utils.data import DataLoader
+from datasets import Dataset, IterableDataset
+import tqdm
+from datasets.arrow_writer import ArrowWriter
+from tmnt.inference import SeqVEDInferencer
+import io, json
+from tmnt.sparse.modeling import BaseAutoencoder
+from typing import List
+
+class ActivationsStore:
+    def __init__(
+        self,
+        cfg: dict,
+    ):
+        self.device = cfg["device"]
+        self.activation_path = cfg["activation_path"]
+        shuffle = cfg.get("shuffle_data", False)
+        #self.dataset = Dataset.from_file(self.activation_path).with_format('torch', device=self.device)
+        self.dataset = Dataset.from_file(self.activation_path).select_columns(['data']).shuffle(seed=42).with_format('torch', device=self.device)
+        self.dataloader = DataLoader(self.dataset,
+                                     batch_size=cfg["batch_size"], shuffle=shuffle)
+        self.dataloader_iter = iter(self.dataloader)
+        self.cfg = cfg
+
+    def next_batch(self):
+        try:
+            return next(self.dataloader_iter)['data']
+        except (StopIteration, AttributeError):
+            self.dataloader_iter = iter(self.dataloader)
+            return next(self.dataloader_iter)['data']
+
+def build_activation_store(json_input_texts, emb_model_path, arrow_output, max_seq_len=512, json_txt_key='text', device='cpu'):
+
+    inferencer = SeqVEDInferencer.from_saved(emb_model_path, max_length=max_seq_len, device=device)
+    with io.open(json_input_texts) as fp:
+        with ArrowWriter(path=arrow_output) as writer:
+            for l in fp:
+                js = json.loads(l)
+                tokenization_result = inferencer.prep_text(js[json_txt_key])
+                llm_out = inferencer.model.llm(tokenization_result['input_ids'].to(inferencer.device),
+                                               tokenization_result['attention_mask'].to(inferencer.device))
+                cls_vec = inferencer.model._get_embedding(llm_out, tokenization_result['attention_mask'].to(inferencer.device))
+                enc : List[float] = cls_vec.cpu().detach()[0].tolist()
+                writer.write({'data': enc})
+            writer.finalize()
+
+def build_activation_store_batching(json_input_texts, emb_model_path, arrow_output, max_seq_len=512, batch_size=42, json_txt_key='text', device='cpu'):
+    inferencer = SeqVEDInferencer.from_saved(emb_model_path, max_length=max_seq_len, device=device)
+    def encode_batch(txt_batch):
+        tokenization_result = inferencer.prep_text(txt_batch)
+        llm_out = inferencer.model.llm(tokenization_result['input_ids'].to(inferencer.device),
+                                       tokenization_result['attention_mask'].to(inferencer.device))
+        cls_vec = inferencer.model._get_embedding(llm_out, tokenization_result['attention_mask'].to(inferencer.device))
+        encs : List[List[float]] = cls_vec.cpu().detach().tolist()
+        return zip(txt_batch, encs)
+
+    def write_encodings(writer: ArrowWriter, txt_enc_pairs):
+        for (t, e) in txt_enc_pairs:
+            writer.write({'text': t, 'data': e})
+
+    with io.open(json_input_texts) as fp:
+        with ArrowWriter(path=arrow_output) as writer:
+            txt_batch = []
+            for l in fp:
+                js = json.loads(l)
+                txt_batch.append(js[json_txt_key])
+                if len(txt_batch) >= batch_size:
+                    encodings = encode_batch(txt_batch)
+                    write_encodings(writer, encodings)
+                    txt_batch = []
+            if len(txt_batch) > 0:
+                encodings = encode_batch(txt_batch)
+                write_encodings(writer, encodings)
+            writer.finalize()
+
+
+def train_sparse_encoder_decoder(sed: BaseAutoencoder, activation_store: ActivationsStore, cfg: dict):
+    num_batches = cfg["num_samples"] // cfg["batch_size"]
+    optimizer = torch.optim.Adam(sed.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
+    pbar = tqdm.trange(num_batches)
+
+    for i in pbar:
+        batch = activation_store.next_batch()
+        sed_output = sed(batch)
+
+        loss = sed_output["loss"]
+        pbar.set_postfix({"Loss": f"{loss.item():.4f}", "Dead": f"{sed_output['num_dead_features']:.4f}", "L0": f"{sed_output['l0_norm']:.4f}", "L2": f"{sed_output['l2_loss']:.4f}", "L1": f"{sed_output['l1_loss']:.4f}", "L1_norm": f"{sed_output['l1_norm']:.4f}"})
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(sed.parameters(), cfg["max_grad_norm"])
+        sed.make_decoder_weights_and_grad_unit_norm()
+        optimizer.step()
+        optimizer.zero_grad()
+
+
+
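Putting the pieces together, a hedged end-to-end sketch (the corpus file, saved-model directory, and Arrow file names are hypothetical placeholders; BatchTopKEncoder and BatchTopKSAE come from tmnt.sparse.modeling below):

from tmnt.sparse.config import get_default_cfg
from tmnt.sparse.estimator import (ActivationsStore, build_activation_store_batching,
                                   train_sparse_encoder_decoder)
from tmnt.sparse.modeling import BatchTopKEncoder, BatchTopKSAE

cfg = get_default_cfg()
cfg["device"] = "cpu"
cfg["activation_path"] = "acts.arrow"   # hypothetical Arrow file produced below
cfg["num_samples"] = 100_000            # shorten the run; the default is int(1e9)

# 1. Pool one embedding per JSON-lines document into an Arrow activation store.
build_activation_store_batching("corpus.jsonl", "saved_tmnt_model/", "acts.arrow",
                                batch_size=32, device=cfg["device"])

# 2. Stream stored activations back and fit the sparse autoencoder.
store = ActivationsStore(cfg)
sae = BatchTopKSAE(cfg, BatchTopKEncoder(cfg))
train_sparse_encoder_decoder(sae, store, cfg)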
tmnt/sparse/inference.py
ADDED
@@ -0,0 +1,54 @@
+from tmnt.inference import SeqVEDInferencer
+from scipy.sparse import csr_matrix
+import numpy as np
+from typing import List, Tuple
+from tmnt.distribution import ConceptLogisticGaussianDistribution
+import torch
+from torch.utils.data import DataLoader
+from datasets import Dataset, IterableDataset
+import tqdm
+from datasets.arrow_writer import ArrowWriter
+from tmnt.inference import SeqVEDInferencer
+import io, json
+
+
+def csr_to_indices_data(csr_mat):
+    return [ (csr_mat.getrow(ri).indices, csr_mat.getrow(ri).data) for ri in range(csr_mat.shape[0]) ]
+
+def batch_process_to_arrow(model_path, json_input_texts, output_db_path, max_seq_len=512, device='cuda', batch_size=200, json_txt_key='text'):
+
+    inferencer = SeqVEDInferencer.from_saved(model_path, max_length=max_seq_len, device=device)
+    def encode_batch(txt_batch):
+        tokenization_result = inferencer.prep_text(txt_batch)
+        llm_out = inferencer.model.llm(tokenization_result['input_ids'].to(inferencer.device),
+                                       tokenization_result['attention_mask'].to(inferencer.device))
+        cls_vecs = inferencer.model._get_embedding(llm_out, tokenization_result['attention_mask'].to(inferencer.device))
+        raw_concepts = inferencer.model.latent_distribution.get_sparse_encoding(cls_vecs).cpu().detach()
+        mu_emb = inferencer.model.latent_distribution.get_mu_encoding(cls_vecs)
+        encs : List[List[float]] = cls_vecs.cpu().detach().tolist()
+        sparse_concepts : List[Tuple[List[int], List[float]]] = csr_to_indices_data(csr_matrix(raw_concepts))
+        topic_embeddings : List[List[float]] = mu_emb.cpu().detach().tolist()
+        print("Lengths: {}, {}, {}, {}".format(len(txt_batch), len(encs), len(sparse_concepts), len(topic_embeddings)))
+        return zip(txt_batch, encs, sparse_concepts, topic_embeddings)
+
+    def write_encodings(writer: ArrowWriter, txt_enc_pairs):
+        for (text, embedding, sparse_indices_and_data, topic_embedding) in txt_enc_pairs:
+            writer.write({'text': text, 'embedding': embedding, 'indices': sparse_indices_and_data[0],
+                          'values': sparse_indices_and_data[1], 'topic_embedding': topic_embedding})
+
+    with io.open(json_input_texts) as fp:
+        with ArrowWriter(path=output_db_path) as writer:
+            txt_batch = []
+            for l in fp:
+                js = json.loads(l)
+                txt_batch.append(js[json_txt_key])
+                if len(txt_batch) >= batch_size:
+                    encodings = encode_batch(txt_batch)
+                    write_encodings(writer, encodings)
+                    txt_batch = []
+            if len(txt_batch) > 0:
+                encodings = encode_batch(txt_batch)
+                write_encodings(writer, encodings)
+            writer.finalize()
+
+
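csr_to_indices_data simply splits a SciPy CSR matrix into per-row (indices, values) pairs, which is the layout write_encodings persists. A small self-contained sketch:

import numpy as np
from scipy.sparse import csr_matrix
from tmnt.sparse.inference import csr_to_indices_data

# Two rows standing in for a batch of sparse concept encodings.
mat = csr_matrix(np.array([[0.0, 1.5, 0.0, 0.0, 2.0],
                           [0.0, 0.0, 3.0, 0.0, 0.0]]))
for indices, values in csr_to_indices_data(mat):
    print(indices.tolist(), values.tolist())
# -> [1, 4] [1.5, 2.0]
# -> [2] [3.0]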
tmnt/sparse/modeling.py
ADDED
@@ -0,0 +1,382 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.autograd as autograd
+
+class BaseEncoder(nn.Module):
+
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        torch.manual_seed(self.cfg['seed'])
+        self.b_enc = nn.Parameter(torch.zeros(self.cfg['dict_size']))
+        self.W_enc = nn.Parameter(
+            torch.nn.init.kaiming_uniform_(
+                torch.empty(self.cfg['act_size'], self.cfg['dict_size'])
+            )
+        )
+        self.to(cfg['dtype']).to(cfg['device'])
+
+    def get_dict_size(self):
+        return int(self.cfg['dict_size'])
+
+    def preprocess_input(self, x):
+        if self.cfg.get("input_unit_norm", False):
+            x_mean = x.mean(dim=-1, keepdim=True)
+            x = x - x_mean
+            x_std = x.std(dim=-1, keepdim=True)
+            x = x / (x_std + 1e-5)
+            return x, x_mean, x_std
+        else:
+            return x, None, None
+
+
+class BaseAutoencoder(nn.Module):
+    """Base class for autoencoder models."""
+
+    def __init__(self, cfg, encoder: BaseEncoder):
+        super().__init__()
+
+        self.cfg = cfg
+        torch.manual_seed(self.cfg["seed"])
+
+        self.encoder = encoder
+
+        self.b_dec = nn.Parameter(torch.zeros(self.cfg["act_size"]))
+        self.W_dec = nn.Parameter(
+            torch.nn.init.kaiming_uniform_(
+                torch.empty(self.cfg["dict_size"], self.cfg["act_size"])
+            )
+        )
+        self.W_dec.data[:] = self.encoder.W_enc.t().data
+        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
+        self.num_batches_not_active = torch.zeros((self.cfg["dict_size"],)).to(
+            cfg["device"]
+        )
+
+        self.to(cfg["dtype"]).to(cfg["device"])
+
+    def postprocess_output(self, x_reconstruct, x_mean, x_std):
+        if self.cfg.get("input_unit_norm", False):
+            x_reconstruct = x_reconstruct * x_std + x_mean
+        return x_reconstruct
+
+    @torch.no_grad()
+    def make_decoder_weights_and_grad_unit_norm(self):
+        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
+        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
+            -1, keepdim=True
+        ) * W_dec_normed
+        self.W_dec.grad -= W_dec_grad_proj
+        self.W_dec.data = W_dec_normed
+
+    def update_inactive_features(self, acts):
+        self.num_batches_not_active += (acts.sum(0) == 0).float()
+        self.num_batches_not_active[acts.sum(0) > 0] = 0
+
+
+class BatchTopKEncoder(BaseEncoder):
+    def __init__(self, cfg):
+        super().__init__(cfg)
+
+    def forward(self, x):
+        x, x_mean, x_std = self.preprocess_input(x)
+
+        acts = F.relu(x @ self.W_enc)
+        acts_topk = torch.topk(acts.flatten(), self.cfg["top_k"] * x.shape[0], dim=-1)
+        acts_topk = (
+            torch.zeros_like(acts.flatten())
+            .scatter(-1, acts_topk.indices, acts_topk.values)
+            .reshape(acts.shape)
+        )
+        return acts, acts_topk, x, x_mean, x_std
+
+
+
+class BatchTopKSAE(BaseAutoencoder):
+    def __init__(self, cfg: dict, encoder: BatchTopKEncoder):
+        super().__init__(cfg, encoder)
+
+    def forward(self, x):
+        acts, acts_topk, x, x_mean, x_std = self.encoder(x)
+        x_reconstruct = acts_topk @ self.W_dec + self.b_dec
+
+        self.update_inactive_features(acts_topk)
+        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
+        return output
+
+
+    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
+        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+        l1_norm = acts_topk.float().abs().sum(-1).mean()
+        l1_loss = self.cfg["l1_coeff"] * l1_norm
+        l0_norm = (acts_topk > 0).float().sum(-1).mean()
+        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
+        loss = l2_loss + l1_loss + aux_loss
+        num_dead_features = (
+            self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+        ).sum()
+        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+        output = {
+            "sae_out": sae_out,
+            "feature_acts": acts_topk,
+            "num_dead_features": num_dead_features,
+            "loss": loss,
+            "l1_loss": l1_loss,
+            "l2_loss": l2_loss,
+            "l0_norm": l0_norm,
+            "l1_norm": l1_norm,
+            "aux_loss": aux_loss,
+        }
+        return output
+
+    def get_auxiliary_loss(self, x, x_reconstruct, acts):
+        dead_features = self.num_batches_not_active >= self.cfg["n_batches_to_dead"]
+        if dead_features.sum() > 0:
+            residual = x.float() - x_reconstruct.float()
+            acts_topk_aux = torch.topk(
+                acts[:, dead_features],
+                min(self.cfg["top_k_aux"], dead_features.sum()),
+                dim=-1,
+            )
+            acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
+                -1, acts_topk_aux.indices, acts_topk_aux.values
+            )
+            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
+            l2_loss_aux = (
+                self.cfg["aux_penalty"]
+                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
+            )
+            return l2_loss_aux
+        else:
+            return torch.tensor(0, dtype=x.dtype, device=x.device)
+
+
+class TopKEncoder(BaseEncoder):
+    def __init__(self, cfg):
+        super().__init__(cfg)
+
+    def forward(self, x):
+        x, x_mean, x_std = self.preprocess_input(x)
+        acts = F.relu(x @ self.W_enc)
+        acts_topk = torch.topk(acts, self.cfg["top_k"], dim=-1)
+        acts_topk = torch.zeros_like(acts).scatter(
+            -1, acts_topk.indices, acts_topk.values
+        )
+        return acts, acts_topk, x, x_mean, x_std
+
+class TopKSAE(BaseAutoencoder):
+    def __init__(self, cfg: dict, encoder: TopKEncoder):
+        super().__init__(cfg, encoder)
+
+    def forward(self, x):
+        acts, acts_topk, x, x_mean, x_std = self.encoder(x)
+        x_reconstruct = acts_topk @ self.W_dec + self.b_dec
+        self.update_inactive_features(acts_topk)
+        output = self.get_loss_dict(x, x_reconstruct, acts, acts_topk, x_mean, x_std)
+        return output
+
+    def get_loss_dict(self, x, x_reconstruct, acts, acts_topk, x_mean, x_std):
+        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+        l1_norm = acts_topk.float().abs().sum(-1).mean()
+        l1_loss = self.cfg["l1_coeff"] * l1_norm
+        l0_norm = (acts_topk > 0).float().sum(-1).mean()
+        aux_loss = self.get_auxiliary_loss(x, x_reconstruct, acts)
+        loss = l2_loss + l1_loss + aux_loss
+        num_dead_features = (
+            self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+        ).sum()
+        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+        output = {
+            "sae_out": sae_out,
+            "feature_acts": acts_topk,
+            "num_dead_features": num_dead_features,
+            "loss": loss,
+            "l1_loss": l1_loss,
+            "l2_loss": l2_loss,
+            "l0_norm": l0_norm,
+            "l1_norm": l1_norm,
+            "aux_loss": aux_loss,
+        }
+        return output
+
+    def get_auxiliary_loss(self, x, x_reconstruct, acts):
+        dead_features = self.num_batches_not_active >= self.cfg["n_batches_to_dead"]
+        if dead_features.sum() > 0:
+            residual = x.float() - x_reconstruct.float()
+            acts_topk_aux = torch.topk(
+                acts[:, dead_features],
+                min(self.cfg["top_k_aux"], dead_features.sum()),
+                dim=-1,
+            )
+            acts_aux = torch.zeros_like(acts[:, dead_features]).scatter(
+                -1, acts_topk_aux.indices, acts_topk_aux.values
+            )
+            x_reconstruct_aux = acts_aux @ self.W_dec[dead_features]
+            l2_loss_aux = (
+                self.cfg["aux_penalty"]
+                * (x_reconstruct_aux.float() - residual.float()).pow(2).mean()
+            )
+            return l2_loss_aux
+        else:
+            return torch.tensor(0, dtype=x.dtype, device=x.device)
+
+
+class VanillaEncoder(BaseEncoder):
+    def __init__(self, cfg):
+        super().__init__(cfg)
+
+    def forward(self, x):
+        x, x_mean, x_std = self.preprocess_input(x)
+        acts = F.relu(x @ self.W_enc + self.b_enc)
+        return acts, x, x_mean, x_std
+
+
+class VanillaSAE(BaseAutoencoder):
+    def __init__(self, cfg, encoder: VanillaEncoder):
+        super().__init__(cfg, encoder)
+
+    def forward(self, x):
+        acts, x, x_mean, x_std = self.encoder(x)
+        x_reconstruct = acts @ self.W_dec + self.b_dec
+        self.update_inactive_features(acts)
+        output = self.get_loss_dict(x, x_reconstruct, acts, x_mean, x_std)
+        return output
+
+    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
+        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+        l1_norm = acts.float().abs().sum(-1).mean()
+        l1_loss = self.cfg["l1_coeff"] * l1_norm
+        l0_norm = (acts > 0).float().sum(-1).mean()
+        loss = l2_loss + l1_loss
+        num_dead_features = (
+            self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+        ).sum()
+
+        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+        output = {
+            "sae_out": sae_out,
+            "feature_acts": acts,
+            "num_dead_features": num_dead_features,
+            "loss": loss,
+            "l1_loss": l1_loss,
+            "l2_loss": l2_loss,
+            "l0_norm": l0_norm,
+            "l1_norm": l1_norm,
+        }
+        return output
+
+import torch
+import torch.nn as nn
+import torch.autograd as autograd
+
+class RectangleFunction(autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return ((x > -0.5) & (x < 0.5)).float()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (x,) = ctx.saved_tensors
+        grad_input = grad_output.clone()
+        grad_input[(x <= -0.5) | (x >= 0.5)] = 0
+        return grad_input
+
+class JumpReLUFunction(autograd.Function):
+    @staticmethod
+    def forward(ctx, x, log_threshold, bandwidth):
+        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
+        threshold = torch.exp(log_threshold)
+        return x * (x > threshold).float()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
+        bandwidth = bandwidth_tensor.item()
+        threshold = torch.exp(log_threshold)
+        x_grad = (x > threshold).float() * grad_output
+        threshold_grad = (
+            -(threshold / bandwidth)
+            * RectangleFunction.apply((x - threshold) / bandwidth)
+            * grad_output
+        )
+        return x_grad, threshold_grad, None  # None for bandwidth
+
+class JumpReLU(nn.Module):
+    def __init__(self, feature_size, bandwidth, device='cpu'):
+        super(JumpReLU, self).__init__()
+        self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
+        self.bandwidth = bandwidth
+
+    def forward(self, x):
+        return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)
+
+class StepFunction(autograd.Function):
+    @staticmethod
+    def forward(ctx, x, log_threshold, bandwidth):
+        ctx.save_for_backward(x, log_threshold, torch.tensor(bandwidth))
+        threshold = torch.exp(log_threshold)
+        return (x > threshold).float()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, log_threshold, bandwidth_tensor = ctx.saved_tensors
+        bandwidth = bandwidth_tensor.item()
+        threshold = torch.exp(log_threshold)
+        x_grad = torch.zeros_like(x)
+        threshold_grad = (
+            -(1.0 / bandwidth)
+            * RectangleFunction.apply((x - threshold) / bandwidth)
+            * grad_output
+        )
+        return x_grad, threshold_grad, None  # None for bandwidth
+
+
+class JumpReLUEncoder(BaseEncoder):
+    def __init__(self, cfg):
+        super().__init__(cfg)
+
+    def forward(self, x):
+        x, x_mean, x_std = self.preprocess_input(x)
+
+        pre_activations = torch.relu(x @ self.W_enc + self.b_enc)
+        feature_magnitudes = self.jumprelu(pre_activations)
+        return feature_magnitudes, x, x_mean, x_std
+
+
+class JumpReLUSAE(BaseAutoencoder):
+    def __init__(self, cfg):
+        super().__init__(cfg)
+        self.jumprelu = JumpReLU(feature_size=cfg["dict_size"], bandwidth=cfg["bandwidth"], device=cfg["device"])
+
+    def forward(self, x):
+        feature_magnitudes, x, x_mean, x_std = self.encoder(x)
+        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
+
+        return self.get_loss_dict(x, x_reconstructed, feature_magnitudes, x_mean, x_std)
+
+    def get_loss_dict(self, x, x_reconstruct, acts, x_mean, x_std):
+        l2_loss = (x_reconstruct.float() - x.float()).pow(2).mean()
+
+        l0 = StepFunction.apply(acts, self.jumprelu.log_threshold, self.cfg["bandwidth"]).sum(dim=-1).mean()
+        l0_loss = self.cfg["l1_coeff"] * l0
+        l1_loss = l0_loss
+
+        loss = l2_loss + l1_loss
+        num_dead_features = (
+            self.num_batches_not_active > self.cfg["n_batches_to_dead"]
+        ).sum()
+
+        sae_out = self.postprocess_output(x_reconstruct, x_mean, x_std)
+        output = {
+            "sae_out": sae_out,
+            "feature_acts": acts,
+            "num_dead_features": num_dead_features,
+            "loss": loss,
+            "l1_loss": l1_loss,
+            "l2_loss": l2_loss,
+            "l0_norm": l0,
+            "l1_norm": l0,
+        }
+        return output
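A quick smoke test of the classes above on random activations; a sketch, assuming only the shipped defaults with the device overridden (with input_unit_norm on, inputs are standardized before encoding and de-standardized in sae_out):

import torch
from tmnt.sparse.config import get_default_cfg
from tmnt.sparse.modeling import TopKEncoder, TopKSAE

cfg = get_default_cfg()
cfg["device"] = "cpu"

sae = TopKSAE(cfg, TopKEncoder(cfg))
x = torch.randn(16, cfg["act_size"])   # fake batch of 768-dim activations
out = sae(x)
print(out["loss"].item())              # l2 + l1 + auxiliary loss
print(out["l0_norm"].item())           # ≈ cfg["top_k"] active features per row
print(out["sae_out"].shape)            # torch.Size([16, 768])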
{tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tmnt
-Version: 0.7.58
+Version: 0.7.60
 Summary: Topic modeling neural toolkit
 Home-page: https://github.com/mitre/tmnt.git
 Author: The MITRE Corporation
@@ -48,7 +48,7 @@ Dynamic: summary
 The Topic Modeling Neural Toolkit (TMNT) is a software library that enables training
 topic models as neural network-based variational auto-encoders.
 
-Current stable version is: 0.7.58
+Current stable version is: 0.7.60
 
 Documentation can be found here: https://tmnt.readthedocs.io/en/stable/
 
{tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-tmnt/__init__.py,sha256=
+tmnt/__init__.py,sha256=s7YqLj32HKhIYO1QbD0zms8rDlTrleJM8LRjKt8bDPk,200
 tmnt/configuration.py,sha256=P8PEhzVPKO5xG0FrdTLRQ60OYWigbzPY-OSx_hzQlrY,10054
 tmnt/data_loading.py,sha256=LcVcXX00UsuAillRPILcvmqj3AsCIgzB6V_S6lfsbIY,19335
 tmnt/distribution.py,sha256=4gn1wnszVAErzICCvZXSYki0G78WC3_jyBr27N-Aj3E,15108
@@ -9,6 +9,11 @@ tmnt/modeling.py,sha256=rGHQsW7ldycFUd1f9NzcnNuSRElr600vLwmYPl6YY0M,30215
 tmnt/preprocess/__init__.py,sha256=gwMejkQrnqKS05i0JVsUru2hDUR5jE1hKC10dL934GU,170
 tmnt/preprocess/tokenizer.py,sha256=-ZgowfbHrM040vbNTktZM_hdl6HDTqxSJ4mDAxq3dUs,14050
 tmnt/preprocess/vectorizer.py,sha256=RaianZ_DG3Nc-RI96FtmI4PCZPi5Nipx9a5xndLZ52M,20689
+tmnt/sparse/__init__.py,sha256=BEhOm_o0UrVUKTG3rSiBJzE7qQQL9HRSZ1MHCA2GJu8,249
+tmnt/sparse/config.py,sha256=gfJ1BAP3zzMKKUJExrs8D0hHB7XOVvqXPpmx0ECAPIE,796
+tmnt/sparse/estimator.py,sha256=SOEPkQo7T2RuBJLYRAk65MhoYjgYtNCAzlvhqq10c5E,4596
+tmnt/sparse/inference.py,sha256=etOuXTxh8bKc7EoohZLYYufFAD51aEpq8mGj1vjULbg,2813
+tmnt/sparse/modeling.py,sha256=IejHbRXyj5WsEthkVbp3vYF1wTAAsihbneaPuCgAfGA,13612
 tmnt/utils/__init__.py,sha256=1PZsxRPsHI_DnOpxD0iAhLxhxHnx6Svzg3W-79YfWWs,237
 tmnt/utils/csv2json.py,sha256=A1TXy-uxA4dc9tw0tjiHzL7fv4C6b0Uc_bwI1keTmKU,795
 tmnt/utils/log_utils.py,sha256=ZtR4nF_Iee23ev935YQcTtXv-cCC7lgXkXLl_yokfS4,2075
@@ -18,9 +23,9 @@ tmnt/utils/pubmed_utils.py,sha256=3sHwoun7vxb0GV-arhpXLMUbAZne0huAh9xQNy6H40E,12
 tmnt/utils/random.py,sha256=qY75WG3peWoMh9pUyCPBEo6q8IvkF6VRjeb5CqJOBF8,327
 tmnt/utils/recalibrate.py,sha256=TmpB8An8bslICZ13UTJfIvr8VoqiSedtpHxec4n8CHk,1439
 tmnt/utils/vocab.py,sha256=J6GFGLyvDgdmtVQjYlyzWjuykRD3kllCKPG1z0lI0P8,3504
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
-tmnt-0.7.
+tmnt-0.7.60.dist-info/licenses/LICENSE,sha256=qFZJrfJ7Zi4IXDiyiGVrHWic_l1h2tc36tI8Z7rK9bs,11356
+tmnt-0.7.60.dist-info/licenses/NOTICE,sha256=p0kYIVAkReTFaGb4C-qPa7h5ztze6hGzOpjCMMbOipU,425
+tmnt-0.7.60.dist-info/METADATA,sha256=DbmpuasEyoW6FNHmu-YfNsgq5J1lV0MjWPe04CCq8Fk,1663
+tmnt-0.7.60.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
+tmnt-0.7.60.dist-info/top_level.txt,sha256=RpYgUl187sXnqmiwKjZZdcDlHz2AALs6bGdUcukyd_E,5
+tmnt-0.7.60.dist-info/RECORD,,
{tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/licenses/LICENSE
File without changes

{tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/licenses/NOTICE
File without changes

{tmnt-0.7.58.dist-info → tmnt-0.7.60.dist-info}/top_level.txt
File without changes