cosmoglint-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosmoglint/__init__.py +1 -0
- cosmoglint/model/__init__.py +2 -0
- cosmoglint/model/transformer.py +500 -0
- cosmoglint/model/transformer_nf.py +368 -0
- cosmoglint/utils/ReadPinocchio5.py +1022 -0
- cosmoglint/utils/__init__.py +2 -0
- cosmoglint/utils/cosmology_utils.py +194 -0
- cosmoglint/utils/generation_utils.py +366 -0
- cosmoglint/utils/io_utils.py +397 -0
- cosmoglint-1.0.0.dist-info/METADATA +164 -0
- cosmoglint-1.0.0.dist-info/RECORD +14 -0
- cosmoglint-1.0.0.dist-info/WHEEL +5 -0
- cosmoglint-1.0.0.dist-info/licenses/LICENSE +21 -0
- cosmoglint-1.0.0.dist-info/top_level.txt +1 -0
cosmoglint/__init__.py
ADDED
@@ -0,0 +1 @@
+__version__ = "1.0.0"
cosmoglint/model/transformer.py
ADDED
@@ -0,0 +1,500 @@
+import random
+
+import torch
+import torch.nn as nn
+
+from torch.distributions import Categorical
+
+def transformer_model(args, **kwargs):
+
+    if "transformer" in args.model_name:
+        if args.model_name == "transformer1":
+            model_class = Transformer1
+        elif args.model_name == "transformer2":
+            model_class = Transformer2
+        elif args.model_name == "transformer1_with_global_cond":
+            model_class = TransformerWithGlobalCond
+            transformer_cls = Transformer1
+        elif args.model_name == "transformer2_with_global_cond":
+            model_class = TransformerWithGlobalCond
+            transformer_cls = Transformer2
+        elif args.model_name == "transformer1_with_attn":
+            model_class = Transformer1WithAttn
+        elif args.model_name == "transformer2_with_attn":
+            model_class = Transformer2WithAttn
+        else:
+            raise ValueError(f"Invalid model: {args.model_name}")
+
+        if len(args.output_features) != args.num_features_in:
+            raise ValueError(f"num_features ({args.num_features_in}) is not consistent with the list of output features ({args.output_features})")
+
+        common_args = dict(
+            d_model=args.d_model,
+            num_layers=args.num_layers,
+            num_heads=args.num_heads,
+            num_condition=args.num_features_cond,
+            num_features_out=args.num_features_out,
+            output_features=args.output_features,
+            **kwargs,
+        )
+
+        if args.use_flat_representation:
+            common_args["max_length"] = args.max_length * args.num_features_in
+            common_args["num_features_in"] = 1
+            common_args["num_token_types"] = args.num_features_in
+
+        else:
+            common_args["max_length"] = args.max_length
+            common_args["num_features_in"] = args.num_features_in
+            common_args["num_token_types"] = 1
+
+
+        if "with_global_cond" in args.model_name:
+            common_args["num_features_global"] = args.num_features_global
+            common_args["transformer_cls"] = transformer_cls
+
+        model = model_class(**common_args)
+
+    else:
+        raise ValueError(f"Invalid model: {args.model_name}")
+
+    return model
+
+class TransformerBase(nn.Module):
+    def __init__(
+        self,
+        num_condition = 1,
+        d_model = 128,
+        num_layers = 4,
+        num_heads = 8,
+        max_length = 10,
+        num_features_in = 1,
+        num_features_out = 1,
+        num_token_types = 1,
+        output_features = ["SubhaloSFR"],
+        central_values = {"SubhaloDist": 0.0, "SubhaloVrad": 0.5},
+        dropout = 0
+    ):
+        super().__init__()
+
+        self.d_model = d_model
+        self.max_length = max_length
+        self.num_features_in = num_features_in
+        self.num_features_out = num_features_out
+        self.num_token_types = num_token_types
+
+        self.output_idx_map = {name: i for i, name in enumerate(output_features)}
+        self.output_features = output_features
+        self.central_values = central_values
+
+        # Position and feature type embedding
+        actual_max_length = max_length // num_token_types
+        self.pos_embedding = nn.Embedding(actual_max_length, d_model)
+        self.token_type_embedding = nn.Embedding(num_token_types, d_model)
+
+        token_pos_id = torch.arange(actual_max_length).repeat_interleave(num_token_types)
+        token_type_id = torch.arange(num_token_types).repeat(actual_max_length)
+        self.register_buffer("token_pos_id", token_pos_id.long())
+        self.register_buffer("token_type_id", token_type_id.long())
+
+    def forward(self, context, x, global_cond=None):
+        raise NotImplementedError("forward method not implemented")
+
+    def generate_square_subsequent_mask(self, sz):
+        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
+        mask = mask.masked_fill(mask==1, float('-inf'))
+        return mask
+
+    def _set_to_zero(self, x, mask):
+        zero_tensor = torch.tensor(0.0).to(x.device)
+        return torch.where(mask, zero_tensor, x)
+
+    def generate(
+        self,
+        context,
+        global_cond = None,
+        seq = None,
+        teacher_forcing_ratio = 0.0,
+        temperature = 1.0,
+        stop_criterion = None,
+        prob_threshold = 1e-5,
+        monotonicity_start_index = 1,
+        max_ids = None,
+        buffer_percent = 0.05
+    ):
+        """
+        context: (B, num_condition)
+        global_cond: (B, num_features_global) -- simply ignored
+        seq: (B, L, num_features_in)
+        teacher_forcing_ratio: probability of taking the next token from seq instead of sampling it
+        monotonicity_start_index: from which galaxy to enforce monotonicity. No enforcement if < 0.
+        max_ids: torch tensor listing the indices of maximum value for the primary parameter for different context bins
+        """
+        # context: (batch, num_condition)
+        batch_size = len(context)
+
+        if len(context.shape) == 1:
+            context = context.unsqueeze(-1)
+
+        # used for enforce_monotonicity and max_ids
+        # buffer = 1 indicates the bins just above the max_ids are avoided.
+        buffer = max(int(buffer_percent * self.num_features_out), 1)
+        if max_ids is not None:
+            max_ids = max_ids.to(context.device) + buffer # (nbins, )
+            nbins = len(max_ids)
+            context_bins = torch.linspace(0, 1, nbins, device=context.device) # (nbins, )
+            context_bin_indices = torch.bucketize(context[:, 0], context_bins) - 1 # (batch, )
+            bin_indices = torch.arange(self.num_features_out, device=context.device) # (num_features_out, )
+            mask_max_ids = (bin_indices.unsqueeze(0) > max_ids[context_bin_indices].unsqueeze(1)) # (batch, num_features_out)
+
+        generated = torch.zeros(batch_size, self.max_length, self.num_features_in).to(context.device) # (batch, max_length, num_features_in)
+        mask_all_batch = torch.ones(batch_size, dtype=torch.bool).to(context.device)
+
+        for t in range(self.max_length):
+
+            if seq is not None and t < seq.size(1) and random.random() < teacher_forcing_ratio:
+                next_token = seq[:, t]
+            else:
+                x = self(context, generated[:, :t], global_cond=global_cond)
+                # generated[:, :t]: (batch, t, num_features_in)
+                # x: (batch, t+1, num_features_in, num_features_out)
+
+                x_last = x[:, -1, :, :]
+                # last token: (batch, num_features_in, num_features_out)
+
+                x_last = x_last / temperature
+
+                token_type = t % self.num_token_types
+
+                if token_type == 0:
+                    if monotonicity_start_index is not None:
+                        # Set the probability at x(t) >= x(t-1) to zero for the primary parameter
+                        if t > monotonicity_start_index:
+                            previous_token_bin = (generated[:, t - self.num_token_types, 0] * self.num_features_out).long() + buffer
+                            previous_token_bin = previous_token_bin.contiguous().view(-1, 1) # (batch, 1)
+                            bin_indices = torch.arange(self.num_features_out, device=context.device).view(1, -1) # (1, num_features_out)
+                            mask = (bin_indices > previous_token_bin) # (batch, num_features_out)
+                            mask = mask & (x_last[:, 0, :] >= prob_threshold) # (batch, num_features_out)
+                            x_last[:, 0, :] = self._set_to_zero(x_last[:, 0, :], mask) # set the probability to zero for bins above the previous bin
+
+                    if max_ids is not None:
+                        # Even when monotonicity is not enforced, set the probability at x > x_max to zero for the primary parameter if max_ids is defined.
+                        x_last[:, 0, :] = self._set_to_zero(x_last[:, 0, :], mask_max_ids)
+
+                x_last = self._set_to_zero(x_last, x_last < prob_threshold) # set the probability to zero if less than prob_threshold
+
+                x_last = x_last.reshape(-1, self.num_features_out) # (batch * num_features_in, num_features_out)
+                bin_indices = Categorical(probs=x_last).sample().float().view(-1, self.num_features_in) # (batch, num_features_in)
+                uniform_noise = torch.rand_like(bin_indices, device=context.device) # (batch, num_features_in)
+                next_token = (bin_indices + uniform_noise) / self.num_features_out # (batch, num_features_in)
+
+                if token_type == 0:
+                    next_token[:, 0] = self._set_to_zero(next_token[:, 0], next_token[:,0] < 1./ self.num_features_out) # strictly set the sampled primary parameter to zero if it is less than 1/num_features_out
+
+                mask_all_batch = torch.ones(batch_size, dtype=torch.bool).to(context.device)
+
+                # Set the central galaxy's parameters to fixed values
+                is_first_gal = ( t // self.num_token_types == 0 )
+                if is_first_gal:
+                    if self.num_token_types == 1:
+                        for feat, cval in self.central_values.items():
+                            idx = self.output_idx_map.get(feat)
+                            if idx is not None:
+                                next_token[:, idx] = cval + self._set_to_zero(next_token[:, idx], mask_all_batch)
+                    else:
+                        feat = self.output_features[token_type]
+                        cval = self.central_values.get(feat)
+                        if cval is not None:
+                            next_token[:, 0] = cval + self._set_to_zero(next_token[:, 0], mask_all_batch)
+
+                # Stop generation if the primary parameter is below criterion
+                if token_type == 0:
+                    if stop_criterion is not None:
+                        if torch.all(next_token[:,0] < stop_criterion):
+                            return generated, x
+
+            generated[:, t, :] = next_token # (batch, num_features_in)
+
+        if seq is not None and teacher_forcing_ratio > 0:
+            x = self(context, generated[:,:-1])
+
+        return generated, x
+
+
+class Transformer1(TransformerBase): # add logM at the start of the sequence
+    def __init__(
+        self,
+        num_condition = 1,
+        d_model = 128,
+        num_layers = 4,
+        num_heads = 8,
+        max_length = 10,
+        num_features_in = 1,
+        num_features_out = 1,
+        num_token_types = 1,
+        output_features = ["SubhaloSFR"],
+        dropout = 0,
+        last_activation = nn.Softmax(dim=-1),
+        pred_prob = True
+    ):
+        super().__init__(num_condition=num_condition, d_model=d_model, num_layers=num_layers, num_heads=num_heads, max_length=max_length, num_features_in=num_features_in, num_features_out=num_features_out, num_token_types=num_token_types, output_features=output_features, dropout=dropout)
+
+        self.embedding_layers = nn.Sequential(
+            nn.Linear(num_features_in, d_model),
+            nn.LeakyReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(d_model, d_model),
+        )
+        self.context_embedding_layers = nn.Sequential(
+            nn.Linear(num_condition, d_model),
+            nn.LeakyReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(d_model, d_model),
+        )
+
+        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads, batch_first=True, dropout=dropout)
+        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+        self.pred_prob = pred_prob
+        if pred_prob:
+            self.output_layer = nn.Linear(d_model, num_features_in*num_features_out)
+        else:
+            self.output_layer = nn.Linear(d_model, num_features_out)
+
+        self.out_activation = last_activation
+
+
+    def forward(self, context, x, global_cond=None):
+        # context: (batch, num_condition)
+        # x: (batch, seq_length, num_features_in)
+
+        batch_size, seq_length, num_features_in = x.shape
+        total_seq_length = 1 + seq_length # total length of (context and x)
+
+        # concatenate embeddings of context and x
+        context = context.view(batch_size, 1, -1) # (batch, 1, num_condition)
+        for layer in self.context_embedding_layers:
+            context = layer(context)
+        # context: (batch, 1, d_model)
+
+        for layer in self.embedding_layers:
+            x = layer(x)
+        x = torch.cat([context, x], dim=1) # (batch, seq_length + 1, d_model)
+
+        # add position and type embedding
+        pos_emb = self.pos_embedding(self.token_pos_id[:total_seq_length]).unsqueeze(0) # (1, seq_length + 1, d_model)
+        type_emb = self.token_type_embedding(self.token_type_id[:total_seq_length]).unsqueeze(0) # (1, seq_length + 1, d_model)
+        x = x + pos_emb + type_emb # (batch, seq_length + 1, d_model)
+
+        # decode
+        causal_mask = self.generate_square_subsequent_mask(total_seq_length).to(x.device)
+        dummy_memory = torch.zeros(batch_size, 1, self.d_model, device=x.device)
+        x = self.decoder(x, memory=dummy_memory, tgt_mask=causal_mask) # (batch, seq_length + 1, d_model)
+
+        # output layer
+        x = self.output_layer(x) # (batch, seq_length + 1, num_features_in * num_features_out) or (batch, seq_length + 1, num_features_out)
+
+        if self.pred_prob:
+            x = x.view(batch_size, total_seq_length, self.num_features_in, -1) # (batch, seq_length + 1, num_features_in, num_features_out)
+
+        x = self.out_activation(x)
+        # x = entmax(x, dim=-1)
+
+        return x
+
+class Transformer2(TransformerBase): # embed context and x together, and then add positional embedding
+    def __init__(
+        self,
+        num_condition = 1,
+        d_model = 128,
+        num_layers = 4,
+        num_heads = 8,
+        max_length = 10,
+        num_features_in = 1,
+        num_features_out = 1,
+        num_token_types = 1,
+        output_features = ["SubhaloSFR"],
+        dropout = 0,
+        last_activation = nn.Softmax(dim=-1),
+        pred_prob = True
+    ):
+
+        super().__init__(num_condition=num_condition, d_model=d_model, num_layers=num_layers, num_heads=num_heads, max_length=max_length, num_features_in=num_features_in, num_features_out=num_features_out, num_token_types=num_token_types, output_features=output_features, dropout=dropout)
+
+        self.start_token = torch.ones(1)
+
+        self.embedding_layers = nn.Sequential(
+            nn.Linear(num_condition+num_features_in, d_model),
+            nn.LeakyReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(d_model, d_model),
+        )
+
+        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads, batch_first=True, dropout=dropout)
+        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+        self.pred_prob = pred_prob
+        if pred_prob:
+            self.output_layer = nn.Linear(d_model, num_features_in*num_features_out)
+        else:
+            self.output_layer = nn.Linear(d_model, num_features_out)
+
+        self.out_activation = last_activation
+
+    def forward(self, context, x, global_cond=None):
+        # context: (batch, num_condition)
+        # x: (batch, seq_length, num_features_in)
+
+        batch_size, seq_length, num_features_in = x.shape
+        total_seq_length = 1 + seq_length # total length of (start token and x)
+
+        context = context.view(batch_size, 1, -1) # (batch, 1, num_condition)
+        context = context.expand(batch_size, seq_length+1, -1) # (batch, seq_length+1, num_condition)
+
+        ## add start token
+        start_token = self.start_token.expand(batch_size, 1, self.num_features_in).to(x.device) # (batch, 1, num_features_in)
+        x = torch.cat([start_token, x], dim=1) # (batch, seq_length+1, num_features_in)
+
+        ## concatenate context and x
+        x = torch.cat([context, x], dim=2) # (batch, seq_length+1, num_condition + num_features_in)
+
+        ## embedding
+        for layer in self.embedding_layers:
+            x = layer(x)
+        # x: (batch, seq_length+1, d_model)
+
+        # add position and type embedding
+        pos_emb = self.pos_embedding(self.token_pos_id[:total_seq_length]).unsqueeze(0) # (1, seq_length + 1, d_model)
+        type_emb = self.token_type_embedding(self.token_type_id[:total_seq_length]).unsqueeze(0) # (1, seq_length + 1, d_model)
+        x = x + pos_emb + type_emb # (batch, seq_length + 1, d_model)
+
+        # decode
+        causal_mask = self.generate_square_subsequent_mask(seq_length+1).to(x.device)
+        dummy_memory = torch.zeros(batch_size, 1, self.d_model, device=x.device)
+        x = self.decoder(x, memory=dummy_memory, tgt_mask=causal_mask) # (batch, seq_length+1, d_model)
+
+        # output layer
+        x = self.output_layer(x) # (batch, seq_length+1, num_features_in * num_features_out) or (batch, seq_length+1, num_features_out)
+
+        if self.pred_prob:
+            x = x.view(batch_size, seq_length+1, self.num_features_in, -1) # (batch, seq_length+1, num_features_in, num_features_out)
+
+        x = self.out_activation(x)
+
+        return x
+
+
+class TransformerWithGlobalCond(nn.Module):
+    def __init__(
+        self,
+        num_features_global,
+        transformer_cls = Transformer1,
+        num_condition = 1,
+        d_model = 128,
+        num_layers = 4,
+        num_heads = 8,
+        max_length = 10,
+        num_features_in = 1,
+        num_features_out = 1,
+        num_token_types = 1,
+        output_features = ["SubhaloSFR"],
+        dropout = 0,
+        last_activation = nn.Softmax(dim=-1),
+        pred_prob = True
+    ):
+        super().__init__()
+
+        self.transformer = transformer_cls(
+            num_condition = num_condition + num_features_global,
+            d_model = d_model,
+            num_layers = num_layers,
+            num_heads = num_heads,
+            max_length = max_length,
+            num_features_in = num_features_in,
+            num_features_out = num_features_out,
+            num_token_types = num_token_types,
+            output_features = output_features,
+            dropout = dropout,
+            last_activation = last_activation,
+            pred_prob = pred_prob
+        )
+
+    def forward(self, context, x, global_cond):
+        if len(context.shape) == 1:
+            context = context.unsqueeze(-1)
+        ctx = torch.cat([context, global_cond], dim=1)
+        return self.transformer(ctx, x)
+
+    def generate(self, context, global_cond, **kwargs):
+        if len(context.shape) == 1:
+            context = context.unsqueeze(-1)
+        ctx = torch.cat([context, global_cond], dim=1)
+        return self.transformer.generate(ctx, **kwargs)
+
+
+from typing import Optional
+
+class TransformerDecoderLayerWithAttn(nn.TransformerDecoderLayer):
+    def _sa_block(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
+        key_padding_mask: Optional[torch.Tensor],
+        is_causal: bool = False
+    ) -> torch.Tensor:
+
+        attn_output, attn_weights = self.self_attn(
+            x, x, x,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
+            need_weights=True,
+            is_causal=is_causal,
+        )
+
+        self.attn_weights = attn_weights.detach().cpu()
+        return self.dropout1(attn_output)
+
+
+class Transformer1WithAttn(Transformer1):
+    def __init__(
+        self,
+        num_condition = 1,
+        d_model = 128,
+        num_layers = 4,
+        num_heads = 8,
+        max_length = 10,
+        num_features_in = 1,
+        num_features_out = 1,
+        num_token_types = 1,
+        output_features = ["SubhaloSFR"],
+        dropout = 0,
+        last_activation = nn.Softmax(dim=-1),
+        pred_prob = True
+    ):
+        super().__init__(num_condition=num_condition, d_model=d_model, num_layers=num_layers, num_heads=num_heads, max_length=max_length, num_features_in=num_features_in, num_features_out=num_features_out, num_token_types=num_token_types, output_features=output_features, dropout=dropout, last_activation=last_activation, pred_prob=pred_prob)

+        decoder_layer = TransformerDecoderLayerWithAttn(d_model=d_model, nhead=num_heads, batch_first=True, dropout=dropout)
+        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+class Transformer2WithAttn(Transformer2):
+    def __init__(
+        self,
+        num_condition = 1,
+        d_model = 128,
+        num_layers = 4,
+        num_heads = 8,
+        max_length = 10,
+        num_features_in = 1,
+        num_features_out = 1,
+        num_token_types = 1,
+        output_features = ["SubhaloSFR"],
+        dropout = 0,
+        last_activation = nn.Softmax(dim=-1),
+        pred_prob = True
+    ):
+        super().__init__(num_condition=num_condition, d_model=d_model, num_layers=num_layers, num_heads=num_heads, max_length=max_length, num_features_in=num_features_in, num_features_out=num_features_out, num_token_types=num_token_types, output_features=output_features, dropout=dropout, last_activation=last_activation, pred_prob=pred_prob)
+
+        decoder_layer = TransformerDecoderLayerWithAttn(d_model=d_model, nhead=num_heads, batch_first=True, dropout=dropout)
+        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+
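Usage note (editorial, not part of the package diff): the sketch below shows how the transformer_model factory and the autoregressive generate method defined above might be driven. The attribute names on the args namespace are the ones read by transformer_model in the code; the namespace itself, its values, and the batch size are illustrative assumptions.

import argparse
import torch

from cosmoglint.model.transformer import transformer_model

# Hypothetical configuration; only the attribute names are taken from the code above.
args = argparse.Namespace(
    model_name="transformer1",
    d_model=128,
    num_layers=4,
    num_heads=8,
    max_length=10,
    num_features_in=2,
    num_features_cond=1,
    num_features_out=64,
    output_features=["SubhaloSFR", "SubhaloDist"],
    use_flat_representation=False,
)

model = transformer_model(args)
model.eval()

# One conditioning value per halo (e.g. a normalized log-mass), batch of 4.
context = torch.rand(4, 1)

with torch.no_grad():
    generated, probs = model.generate(context, temperature=1.0, stop_criterion=1e-3)

print(generated.shape)  # torch.Size([4, 10, 2]) -> (batch, max_length, num_features_in)

Here generated holds the sampled, normalized feature values per step, and probs is the last forward pass's per-bin probability tensor of shape (batch, steps, num_features_in, num_features_out), as returned by generate above.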