cosmoglint 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cosmoglint/__init__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "1.0.0"
@@ -0,0 +1,2 @@
+ from .transformer import transformer_model, Transformer1, Transformer2, Transformer1WithAttn, Transformer2WithAttn
+ from .transformer_nf import transformer_nf_model, my_stop_predictor, calculate_transformer_nf_loss, generate_with_transformer_nf
@@ -0,0 +1,500 @@
+ import random
+
+ import torch
+ import torch.nn as nn
+
+ from torch.distributions import Categorical
+
+ def transformer_model(args, **kwargs):
+
+     if "transformer" in args.model_name:
+         if args.model_name == "transformer1":
+             model_class = Transformer1
+         elif args.model_name == "transformer2":
+             model_class = Transformer2
+         elif args.model_name == "transformer1_with_global_cond":
+             model_class = TransformerWithGlobalCond
+             transformer_cls = Transformer1
+         elif args.model_name == "transformer2_with_global_cond":
+             model_class = TransformerWithGlobalCond
+             transformer_cls = Transformer2
+         elif args.model_name == "transformer1_with_attn":
+             model_class = Transformer1WithAttn
+         elif args.model_name == "transformer2_with_attn":
+             model_class = Transformer2WithAttn
+         else:
+             raise ValueError(f"Invalid model: {args.model_name}")
+
+         if len(args.output_features) != args.num_features_in:
+             raise ValueError(f"num_features ({args.num_features_in}) is not consistent with the list of output features ({args.output_features})")
+
+         common_args = dict(
+             d_model=args.d_model,
+             num_layers=args.num_layers,
+             num_heads=args.num_heads,
+             num_condition=args.num_features_cond,
+             num_features_out=args.num_features_out,
+             output_features=args.output_features,
+             **kwargs,
+         )
+
+         if args.use_flat_representation:
+             common_args["max_length"] = args.max_length * args.num_features_in
+             common_args["num_features_in"] = 1
+             common_args["num_token_types"] = args.num_features_in
+
+         else:
+             common_args["max_length"] = args.max_length
+             common_args["num_features_in"] = args.num_features_in
+             common_args["num_token_types"] = 1
+
+
+         if "with_global_cond" in args.model_name:
+             common_args["num_features_global"] = args.num_features_global
+             common_args["transformer_cls"] = transformer_cls
+
+         model = model_class(**common_args)
+
+     else:
+         raise ValueError(f"Invalid model: {args.model_name}")
+
+     return model
+
+ class TransformerBase(nn.Module):
+     def __init__(
+         self,
+         num_condition = 1,
+         d_model = 128,
+         num_layers = 4,
+         num_heads = 8,
+         max_length = 10,
+         num_features_in = 1,
+         num_features_out = 1,
+         num_token_types = 1,
+         output_features = ["SubhaloSFR"],
+         central_values = {"SubhaloDist": 0.0, "SubhaloVrad": 0.5},
+         dropout = 0
+     ):
+         super().__init__()
+
+         self.d_model = d_model
+         self.max_length = max_length
+         self.num_features_in = num_features_in
+         self.num_features_out = num_features_out
+         self.num_token_types = num_token_types
+
+         self.output_idx_map = {name: i for i, name in enumerate(output_features)}
+         self.output_features = output_features
+         self.central_values = central_values
+
+         # Position and feature type embedding
91
+ actual_max_length = max_length // num_token_types
92
+ self.pos_embedding = nn.Embedding(actual_max_length, d_model)
93
+ self.token_type_embedding = nn.Embedding(num_token_types, d_model)
94
+
95
+ token_pos_id = torch.arange(actual_max_length).repeat_interleave(num_token_types)
96
+ token_type_id = torch.arange(num_token_types).repeat(actual_max_length)
97
+ self.register_buffer("token_pos_id", token_pos_id.long())
98
+ self.register_buffer("token_type_id", token_type_id.long())
99
+
100
+ def forward(self, context, x, global_cond=None):
101
+ raise NotImplementedError("forward method not implemented")
102
+
103
+ def generate_square_subsequent_mask(self, sz):
104
+ mask = torch.triu(torch.ones(sz, sz), diagonal=1)
105
+ mask = mask.masked_fill(mask==1, float('-inf'))
106
+ return mask
107
+
108
+ def _set_to_zero(self, x, mask):
109
+ zero_tensor = torch.tensor(0.0).to(x.device)
110
+ return torch.where(mask, zero_tensor, x)
111
+
112
+ def generate(
113
+ self,
114
+ context,
115
+ global_cond = None,
116
+ seq = None,
117
+ teacher_forcing_ratio = 0.0,
118
+ temperature = 1.0,
119
+ stop_criterion = None,
120
+ prob_threshold = 1e-5,
121
+ monotonicity_start_index = 1,
122
+ max_ids = None,
123
+ buffer_percent = 0.05
124
+ ):
125
+ """
126
+ context: (B, num_features_in)
127
+ global_cond: (B, num_features_global) -- simply ignored
128
+ seq: (B, L, num_features_out)
129
+ teacher_forcing_ratio:
130
+ monotonicity_start_index: from which galaxy to enforce monotonicity. No enforcement if < 0.
131
+ max_ids: torch tensor listing the indices of maximum value for the primary parameter for differen
132
+ """
133
+         # context: (batch, num_condition)
+         batch_size = len(context)
+
+         if len(context.shape) == 1:
+             context = context.unsqueeze(-1)
+
+         # used for enforce_monotonicity and max_ids
+         # buffer = 1 indicates the bins just above the max_ids are avoided.
+         buffer = max(int(buffer_percent * self.num_features_out), 1)
+         if max_ids is not None:
+             max_ids = max_ids.to(context.device) + buffer # (nbins, )
+             nbins = len(max_ids)
+             context_bins = torch.linspace(0, 1, nbins, device=context.device) # (nbins, )
+             context_bin_indices = torch.bucketize(context[:, 0], context_bins) - 1 # (batch, )
+             bin_indices = torch.arange(self.num_features_out, device=context.device) # (num_features_out, )
+             mask_max_ids = (bin_indices.unsqueeze(0) > max_ids[context_bin_indices].unsqueeze(1)) # (batch, num_features_out)
+
+         generated = torch.zeros(batch_size, self.max_length, self.num_features_in).to(context.device) # (batch, max_length, num_features_in)
+         mask_all_batch = torch.ones(batch_size, dtype=torch.bool).to(context.device)
+
+         for t in range(self.max_length):
+
+             if seq is not None and t < seq.size(1) and random.random() < teacher_forcing_ratio:
+                 next_token = seq[:, t]
+             else:
+                 x = self(context, generated[:, :t], global_cond=global_cond)
+                 # generated[:, :t]: (batch, t, num_features_in)
+                 # x: (batch, t+1, num_features_in, num_features_out)
161
+
162
+ x_last = x[:, -1, :, :]
163
+                 # last token: (batch, num_features_in, num_features_out)
164
+
165
+ x_last = x_last / temperature
166
+
167
+ token_type = t % self.num_token_types
168
+
169
+ if token_type == 0:
170
+ if monotonicity_start_index is not None:
171
+ # Set the probability at x(t) >= x(t-1) to zero for the primary parameter
172
+ if t > monotonicity_start_index:
173
+ previous_token_bin = (generated[:, t - self.num_token_types, 0] * self.num_features_out).long() + buffer
174
+ previous_token_bin = previous_token_bin.contiguous().view(-1, 1) # (batch, 1)
175
+ bin_indices = torch.arange(self.num_features_out, device=context.device).view(1, -1) # (1, num_features_out)
176
+ mask = (bin_indices > previous_token_bin) # (batch, num_features_out)
177
+ mask = mask & (x_last[:, 0, :] >= prob_threshold) # (batch, num_features_out)
178
+ x_last[:, 0, :] = self._set_to_zero(x_last[:, 0, :], mask) # set the probability to zero for bins above the previous bin
179
+
180
+ if max_ids is not None:
181
+ # Even when monotonicity is not enforced, set the probability at x > x_max to zero for the primary parameter if max_ids is defined.
182
+ x_last[:, 0, :] = self._set_to_zero(x_last[:, 0, :], mask_max_ids)
183
+
184
+ x_last = self._set_to_zero(x_last, x_last < prob_threshold) # set the probability to zero if less than prob_threshold
185
+
186
+ x_last = x_last.reshape(-1, self.num_features_out) # (batch * num_features_in, num_features_out)
187
+ bin_indices = Categorical(probs=x_last).sample().float().view(-1, self.num_features_in) # (batch, num_features_in)
188
+ uniform_noise = torch.rand_like(bin_indices, device=context.device) # (batch, num_features_in)
189
+ next_token = (bin_indices + uniform_noise) / self.num_features_out # (batch, num_features_in)
190
+
191
+ if token_type == 0:
192
+ next_token[:, 0] = self._set_to_zero(next_token[:, 0], next_token[:,0] < 1./ self.num_features_out) # strictly set the sampled primary parameter to zero if it is less than 1/num_features_out
193
+
194
+ mask_all_batch = torch.ones(batch_size, dtype=torch.bool).to(context.device)
195
+
196
+ # Set the central galaxy's parameters to fixed values
197
+ is_first_gal = ( t // self.num_token_types == 0 )
198
+ if is_first_gal:
199
+ if self.num_token_types == 1:
200
+ for feat, cval in self.central_values.items():
201
+ idx = self.output_idx_map.get(feat)
202
+ if idx is not None:
203
+ next_token[:, idx] = cval + self._set_to_zero(next_token[:, idx], mask_all_batch)
204
+ else:
205
+ feat = self.output_features[token_type]
206
+ cval = self.central_values.get(feat)
207
+ if cval is not None:
208
+ next_token[:, 0] = cval + self._set_to_zero(next_token[:, 0], mask_all_batch)
209
+
210
+ # Stop generation if the primary parameter is below criterion
211
+ if token_type == 0:
212
+ if stop_criterion is not None:
213
+ if torch.all(next_token[:,0] < stop_criterion):
214
+ return generated, x
215
+
216
+ generated[:, t, :] = next_token # (batch, num_features_in)
217
+
218
+ if seq is not None and teacher_forcing_ratio > 0:
219
+ x = self(context, generated[:,:-1])
220
+
221
+ return generated, x
222
+
223
+
224
+ class Transformer1(TransformerBase): # add logM at first in the sequence
+     def __init__(
+         self,
+         num_condition = 1,
+         d_model = 128,
+         num_layers = 4,
+         num_heads = 8,
+         max_length = 10,
+         num_features_in = 1,
+         num_features_out = 1,
+         num_token_types = 1,
+         output_features = ["SubhaloSFR"],
+         dropout = 0,
+         last_activation = nn.Softmax(dim=-1),
+         pred_prob = True
+     ):
+         super().__init__(num_condition=num_condition, d_model=d_model, num_layers=num_layers, num_heads=num_heads, max_length=max_length, num_features_in=num_features_in, num_features_out=num_features_out, num_token_types=num_token_types, output_features=output_features, dropout=dropout)
+
+         self.embedding_layers = nn.Sequential(
+             nn.Linear(num_features_in, d_model),
+             nn.LeakyReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(d_model, d_model),
+         )
+         self.context_embedding_layers = nn.Sequential(
+             nn.Linear(num_condition, d_model),
+             nn.LeakyReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(d_model, d_model),
+         )
+
+         decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads, batch_first=True, dropout=dropout)
+         self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+         self.pred_prob = pred_prob
+         if pred_prob:
+             self.output_layer = nn.Linear(d_model, num_features_in*num_features_out)
+         else:
+             self.output_layer = nn.Linear(d_model, num_features_out)
+
+         self.out_activation = last_activation
+
+
+     def forward(self, context, x, global_cond=None):
+         # context: (batch, num_condition)
+         # x: (batch, seq_length, num_features_in)
+
+         batch_size, seq_length, num_features_in = x.shape
+         total_seq_length = 1 + seq_length # total length of (context and x)
+
+         # concatenate embeddings of context and x
+         context = context.view(batch_size, 1, -1) # (batch, 1, num_condition)
+         for layer in self.context_embedding_layers:
+             context = layer(context)
+         # context: (batch, 1, d_model)
+
+         for layer in self.embedding_layers:
+             x = layer(x)
+         x = torch.cat([context, x], dim=1) # (batch, seq_length + 1, d_model)
+
+         # add position and type embedding
+         pos_emb = self.pos_embedding(self.token_pos_id[:total_seq_length]).unsqueeze(0) # (1, seq_length + 1, d_model)
+         type_emb = self.token_type_embedding(self.token_type_id[:total_seq_length]).unsqueeze(0) # (1, seq_length + 1, d_model)
+         x = x + pos_emb + type_emb # (batch, seq_length + 1, d_model)
+
+         # decode
+         causal_mask = self.generate_square_subsequent_mask(total_seq_length).to(x.device)
+         dummy_memory = torch.zeros(batch_size, 1, self.d_model, device=x.device)
+         x = self.decoder(x, memory=dummy_memory, tgt_mask=causal_mask) # (batch, seq_length + 1, d_model)
+
+         # output layer
+         x = self.output_layer(x) # (batch, seq_length + 1, num_features_in * num_features_out) or (batch, seq_length + 1, num_features_out)
+
+         if self.pred_prob:
+             x = x.view(batch_size, total_seq_length, self.num_features_in, -1) # (batch, seq_length + 1, num_features_in, num_features_out)
+
+         x = self.out_activation(x)
+         # x = entmax(x, dim=-1)
+
+         return x
+
+ class Transformer2(TransformerBase): # embed context and x together, and then add positional embedding
+     def __init__(
+         self,
+         num_condition = 1,
+         d_model = 128,
+         num_layers = 4,
+         num_heads = 8,
+         max_length = 10,
+         num_features_in = 1,
+         num_features_out = 1,
+         num_token_types = 1,
+         output_features = ["SubhaloSFR"],
+         dropout = 0,
+         last_activation = nn.Softmax(dim=-1),
+         pred_prob = True
+     ):
+
+         super().__init__(num_condition=num_condition, d_model=d_model, num_layers=num_layers, num_heads=num_heads, max_length=max_length, num_features_in=num_features_in, num_features_out=num_features_out, num_token_types=num_token_types, output_features=output_features, dropout=dropout)
+
+         self.start_token = torch.ones(1)
+
+         self.embedding_layers = nn.Sequential(
+             nn.Linear(num_condition+num_features_in, d_model),
+             nn.LeakyReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(d_model, d_model),
+         )
+
+         decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads, batch_first=True, dropout=dropout)
+         self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+         self.pred_prob = pred_prob
+         if pred_prob:
+             self.output_layer = nn.Linear(d_model, num_features_in*num_features_out)
+         else:
+             self.output_layer = nn.Linear(d_model, num_features_out)
+
+         self.out_activation = last_activation
+
+     def forward(self, context, x, global_cond=None):
+         # context: (batch, num_condition)
+         # x: (batch, seq_length, num_features_in)
+
+         batch_size, seq_length, num_features_in = x.shape
+         total_seq_length = 1 + seq_length # total length of (start token and x)
+
+         context = context.view(batch_size, 1, -1) # (batch, 1, num_condition)
+         context = context.expand(batch_size, seq_length+1, -1) # (batch, seq_length+1, num_condition)
+
+         ## add start token
+         start_token = self.start_token.expand(batch_size, 1, self.num_features_in).to(x.device) # (batch, 1, num_features_in)
+         x = torch.cat([start_token, x], dim=1) # (batch, seq_length+1, num_features_in)
+
+         ## concatenate context and x
+         x = torch.cat([context, x], dim=2) # (batch, seq_length+1, num_condition + num_features_in)
+
+         ## embedding
+         for layer in self.embedding_layers:
+             x = layer(x)
+         # x: (batch, seq_length+1, d_model)
+
+         # add position and type embedding
+         pos_emb = self.pos_embedding(self.token_pos_id[:total_seq_length]).unsqueeze(0) # (1, seq_length + 1, d_model)
+         type_emb = self.token_type_embedding(self.token_type_id[:total_seq_length]).unsqueeze(0) # (1, seq_length + 1, d_model)
+         x = x + pos_emb + type_emb # (batch, seq_length + 1, d_model)
+
+         # decode
+         causal_mask = self.generate_square_subsequent_mask(seq_length+1).to(x.device)
+         dummy_memory = torch.zeros(batch_size, 1, self.d_model, device=x.device)
+         x = self.decoder(x, memory=dummy_memory, tgt_mask=causal_mask) # (batch, seq_length+1, d_model)
+
+         # output layer
+         x = self.output_layer(x) # (batch, seq_length+1, num_features_in * num_features_out) or (batch, seq_length+1, num_features_out)
+
+         if self.pred_prob:
+             x = x.view(batch_size, seq_length+1, self.num_features_in, -1) # (batch, seq_length+1, num_features_in, num_features_out)
+
+         x = self.out_activation(x)
+
+         return x
+
+
+ class TransformerWithGlobalCond(nn.Module):
+     def __init__(
+         self,
+         num_features_global,
+         transformer_cls = Transformer1,
+         num_condition = 1,
+         d_model = 128,
+         num_layers = 4,
+         num_heads = 8,
+         max_length = 10,
+         num_features_in = 1,
+         num_features_out = 1,
+         num_token_types = 1,
+         output_features = ["SubhaloSFR"],
+         dropout = 0,
+         last_activation = nn.Softmax(dim=-1),
+         pred_prob = True
+     ):
+         super().__init__()
+
+         self.transformer = transformer_cls(
+             num_condition = num_condition + num_features_global,
+             d_model = d_model,
+             num_layers = num_layers,
+             num_heads = num_heads,
+             max_length = max_length,
+             num_features_in = num_features_in,
+             num_features_out = num_features_out,
+             num_token_types = num_token_types,
+             output_features = output_features,
+             dropout = dropout,
+             last_activation = last_activation,
+             pred_prob = pred_prob
+         )
+
+     def forward(self, context, x, global_cond):
+         if len(context.shape) == 1:
+             context = context.unsqueeze(-1)
+         ctx = torch.cat([context, global_cond], dim=1)
+         return self.transformer(ctx, x)
+
+     def generate(self, context, global_cond, **kwargs):
+         if len(context.shape) == 1:
+             context = context.unsqueeze(-1)
+         ctx = torch.cat([context, global_cond], dim=1)
+         return self.transformer.generate(ctx, **kwargs)
+
+
+ from typing import Optional
+
+ class TransformerDecoderLayerWithAttn(nn.TransformerDecoderLayer):
+     def _sa_block(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor],
+         key_padding_mask: Optional[torch.Tensor],
+         is_causal: bool = False
+     ) -> torch.Tensor:
+
+         attn_output, attn_weights = self.self_attn(
+             x, x, x,
+             attn_mask=attn_mask,
+             key_padding_mask=key_padding_mask,
+             need_weights=True,
+             is_causal=is_causal,
+         )
+
+         self.attn_weights = attn_weights.detach().cpu()
+         return self.dropout1(attn_output)
+
+
+ class Transformer1WithAttn(Transformer1):
+     def __init__(
+         self,
+         num_condition = 1,
+         d_model = 128,
+         num_layers = 4,
+         num_heads = 8,
+         max_length = 10,
+         num_features_in = 1,
+         num_features_out = 1,
+         num_token_types = 1,
+         output_features = ["SubhaloSFR"],
+         dropout = 0,
+         last_activation = nn.Softmax(dim=-1),
+         pred_prob = True
+     ):
+         super().__init__(num_condition=num_condition, d_model=d_model, num_layers=num_layers, num_heads=num_heads, max_length=max_length, num_features_in=num_features_in, num_features_out=num_features_out, num_token_types=num_token_types, output_features=output_features, dropout=dropout, last_activation=last_activation, pred_prob=pred_prob)
+
+         decoder_layer = TransformerDecoderLayerWithAttn(d_model=d_model, nhead=num_heads, batch_first=True, dropout=dropout)
+         self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+ class Transformer2WithAttn(Transformer2):
+     def __init__(
+         self,
+         num_condition = 1,
+         d_model = 128,
+         num_layers = 4,
+         num_heads = 8,
+         max_length = 10,
+         num_features_in = 1,
+         num_features_out = 1,
+         num_token_types = 1,
+         output_features = ["SubhaloSFR"],
+         dropout = 0,
+         last_activation = nn.Softmax(dim=-1),
+         pred_prob = True
+     ):
+         super().__init__(num_condition=num_condition, d_model=d_model, num_layers=num_layers, num_heads=num_heads, max_length=max_length, num_features_in=num_features_in, num_features_out=num_features_out, num_token_types=num_token_types, output_features=output_features, dropout=dropout, last_activation=last_activation, pred_prob=pred_prob)
+
+         decoder_layer = TransformerDecoderLayerWithAttn(d_model=d_model, nhead=num_heads, batch_first=True, dropout=dropout)
+         self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+
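A minimal usage sketch for the code added in the hunk above: it instantiates a model through transformer_model() and samples sequences with generate(). The import path cosmoglint.transformer and every value in the args namespace are assumptions chosen only to satisfy the attributes the factory reads; they are not taken from the package.

# Hypothetical usage sketch -- module path and argument values are assumed, not from the package.
import argparse
import torch

from cosmoglint.transformer import transformer_model  # assumed module location

args = argparse.Namespace(
    model_name="transformer1_with_attn",  # one of the names handled by transformer_model()
    d_model=128,
    num_layers=4,
    num_heads=8,
    num_features_cond=1,                  # conditioning features per halo
    num_features_in=3,                    # features generated per galaxy
    num_features_out=64,                  # number of discrete bins per feature
    output_features=["SubhaloSFR", "SubhaloDist", "SubhaloVrad"],
    max_length=10,                        # maximum number of galaxies per sequence
    use_flat_representation=False,
)

model = transformer_model(args)
model.eval()

# Sample satellite sequences for a batch of (normalized) halo conditions.
context = torch.rand(4, args.num_features_cond)  # (batch, num_condition)
with torch.no_grad():
    generated, probs = model.generate(context, temperature=1.0, stop_criterion=1e-3)
print(generated.shape)  # (batch, max_length, num_features_in)

# The *WithAttn variants cache the averaged self-attention weights of the most
# recent forward pass on each decoder layer.
attn = model.decoder.layers[-1].attn_weights  # (batch, seq, seq)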