cosmoglint 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,368 @@
1
+
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .transformer import Transformer1, Transformer2, TransformerWithGlobalCond, Transformer1WithAttn, Transformer2WithAttn
8
+
9
+ from nflows.flows import Flow
10
+ from nflows.distributions import StandardNormal, ConditionalDiagonalNormal
11
+ from nflows.transforms import CompositeTransform, RandomPermutation, AffineCouplingTransform, ActNorm, PiecewiseRationalQuadraticCDF, AffineTransform
12
+ from nflows.nn.nets import ResidualNet
13
+ from nflows.utils import torchutils
14
+
15
+ from torch.distributions import Normal, Categorical, MixtureSameFamily, Independent
16
+
17
+
18
def transformer_nf_model(args, **kwargs):
    """Build the (transformer, normalizing flow) model pair.

    The transformer autoregressively turns a sequence into per-step context
    vectors; the flow models the per-step feature distribution conditioned on
    that context.

    Args:
        args: configuration namespace; fields read here include model_name,
            num_features_cond, d_model, num_layers, num_heads, max_length,
            num_features_in, num_context, num_flows, hidden_dim, base_dist,
            and (for the *_with_global_cond variants) num_features_global.
        **kwargs: forwarded verbatim to the transformer constructor.

    Returns:
        tuple: ``(model, flow)`` — the transformer and an ``nflows.Flow``.

    Raises:
        ValueError: for an unrecognized ``args.model_name`` or
            ``args.base_dist``.
    """

    ### Transformer model ###
    # Map the configured name onto a transformer class; the
    # "*_with_global_cond" variants wrap a base transformer class that is
    # selected here as well and passed along in common_args.
    if args.model_name == "transformer1":
        model_class = Transformer1
    elif args.model_name == "transformer2":
        model_class = Transformer2
    elif args.model_name == "transformer1_with_global_cond":
        model_class = TransformerWithGlobalCond
        transformer_cls = Transformer1
    elif args.model_name == "transformer2_with_global_cond":
        model_class = TransformerWithGlobalCond
        transformer_cls = Transformer2
    elif args.model_name == "transformer1_with_attn":
        model_class = Transformer1WithAttn
    elif args.model_name == "transformer2_with_attn":
        model_class = Transformer2WithAttn
    else:
        raise ValueError(f"Invalid model: {args.model_name}")

    common_args = dict(
        num_condition = args.num_features_cond,
        d_model = args.d_model,
        num_layers = args.num_layers,
        num_heads = args.num_heads,
        max_length = args.max_length,
        num_features_in = args.num_features_in,
        num_features_out = args.num_context,  # transformer output becomes the flow context
        num_token_types = 1,
        last_activation = nn.Tanh(),
        pred_prob = False,
        **kwargs
    )

    if "with_global_cond" in args.model_name:
        common_args["num_features_global"] = args.num_features_global
        common_args["transformer_cls"] = transformer_cls

    model = model_class(**common_args)

    ### Flow model ###
    transforms = []

    # Pre-squash: affine maps x -> alpha + (1 - 2*alpha)*x, then Logit
    # (defined at the bottom of this module) maps (0, 1) onto the real line.
    # NOTE(review): this assumes flow targets lie in [0, 1] — TODO confirm
    # against the data pipeline.
    alpha = 0.05 # value in RealNVP
    transforms.append(
        AffineTransform(
            shift=torch.full([args.num_features_in], fill_value=alpha),
            scale=torch.full([args.num_features_in], fill_value=1.0 - 2.0*alpha),
        )
    )
    transforms.append(Logit())

    for ilayer in range(args.num_flows):
        #transforms.append(ActNorm(args.num_features_out)) # for stabilyzing training and quick convergence
        transforms.append(
            AffineCouplingTransform(
                #mask=torch.arange(args.num_features_out) % 2,
                # Alternate the binary coupling mask each layer so every
                # feature gets transformed.
                mask = torch.arange(args.num_features_in) % 2 if ilayer % 2 == 0 else (torch.arange(args.num_features_in) + 1) % 2,
                # Late binding of `args` inside this lambda is safe: the
                # lambda is only invoked while this layer is being built and
                # `args` never changes afterwards.
                transform_net_create_fn=lambda in_features, out_features: ResidualNet(
                    in_features=in_features,
                    out_features=out_features,
                    hidden_features=args.hidden_dim,
                    context_features=args.num_context,
                    num_blocks=2,
                    activation=nn.ReLU()
                )
            )
        )
        transforms.append(PiecewiseRationalQuadraticCDF(shape=[args.num_features_in], num_bins=8, tails='linear', tail_bound=3.0))
        transforms.append(RandomPermutation(args.num_features_in))

    transform = CompositeTransform(transforms)

    # Base distribution of the flow, optionally conditioned on the context.
    if args.base_dist == "normal":
        base_dist = StandardNormal(shape=[args.num_features_in])
    elif args.base_dist == "conditional_normal":
        base_dist = ConditionalDiagonalNormal(shape=[args.num_features_in],
                        context_encoder=nn.Sequential(nn.Linear(args.num_context, args.hidden_dim),
                                                      nn.LeakyReLU(),
                                                      nn.Linear(args.hidden_dim, 2 * args.num_features_in)
                                                      )
                        )

    elif args.base_dist == "bimodal":
        base_dist = BimodalNormal(shape=[args.num_features_in], offset=2.0)
    elif args.base_dist == "conditional_bimodal":
        base_dist = ConditionalBimodal(context_dim=args.num_context, latent_dim=args.num_features_in, hidden_dim=args.hidden_dim)
    else:
        raise ValueError(f"Invalid base distribution: {args.base_dist}")
    flow = Flow(transform, base_dist)

    return model, flow
110
+
111
+
112
+
113
def calculate_transformer_nf_loss(transformer, flow, batch, stop=None, stop_predictor=None):
    """Compute the flow negative log-likelihood (and optional stop loss).

    Args:
        transformer: autoregressive context model; its parameters determine
            the device everything is moved to.
        flow: conditional normalizing flow with a ``log_prob`` method.
        batch: dict with keys "context", "target", "mask" and
            "global_context".
        stop: optional (batch, max_length) stop targets.
        stop_predictor: optional module mapping context vectors to stop
            probabilities in [0, 1].

    Returns:
        The NLL loss alone, or ``(loss, loss_stop)`` when both ``stop`` and
        ``stop_predictor`` are supplied.
    """
    device = next(transformer.parameters()).device

    condition = batch["context"].to(device)
    seq = batch["target"].to(device)
    global_cond = batch["global_context"].to(device)
    # Flatten the per-step validity mask: (batch * max_length,)
    valid = batch["mask"].to(device)[:, :, 0].reshape(-1)

    # Teacher forcing: feed every step except the last.
    pred_context = transformer(condition, seq[:, :-1], global_cond=global_cond)

    num_features = seq.shape[-1]
    num_context = pred_context.shape[-1]

    # Keep only the valid (galaxy) positions for the flow.
    flow_context = pred_context.reshape(-1, num_context)[valid]  # (num_galaxies, num_context)
    flow_target = seq.reshape(-1, num_features)[valid]           # (num_galaxies, num_features)
    loss = - flow.log_prob(flow_target, context=flow_context).mean() / num_features

    if stop is None or stop_predictor is None:
        return loss

    stop_target = stop.to(device).reshape(-1)[valid]  # (num_galaxies,)
    loss_stop = torch.nn.BCELoss()(stop_predictor(flow_context), stop_target)
    return loss, loss_stop
142
+
143
def generate_with_transformer_nf(transformer, flow, x_cond, stop_predictor=None, prob_threshold=0, stop_threshold=None):
    """Autoregressively generate sequences with the transformer + flow pair.

    Args:
        transformer: autoregressive context model exposing ``max_length`` and
            ``num_features_in`` attributes.
        flow: conditional normalizing flow sampled once per step.
        x_cond: (batch, num_condition) conditioning tensor; also fixes the
            device used for generation.
        stop_predictor: optional module mapping a context vector to a stop
            probability. NOTE(review): when supplied, ``stop_threshold`` must
            also be given, otherwise the comparison below raises — confirm
            callers always pass both.
        prob_threshold: reserved; any value > 0 raises (not implemented).
        stop_threshold: threshold on the stop probability (with
            ``stop_predictor``) or on feature 0 of the latest sample
            (without it).

    Returns:
        (batch, max_length, num_features) tensor of generated steps. Entries
        after a sequence's stop point are whatever was generated before the
        early break; they are not masked out here.
    """
    transformer.eval()
    flow.eval()
    if stop_predictor is not None:
        stop_predictor.eval()

    batch_size = len(x_cond)
    max_length = transformer.max_length
    num_features = transformer.num_features_in

    x_seq = torch.zeros(batch_size, max_length, num_features).to(x_cond.device)
    stop_flags = torch.zeros(batch_size, dtype=torch.bool).to(x_cond.device)

    for t in range(max_length):
        with torch.no_grad():
            # Context for step t from everything generated so far.
            context = transformer(x_cond, x_seq[:,:t,:])  # (batch, t+1, num_context)
            context = context[:, -1, :]  # (batch, num_context)
            if prob_threshold > 0:
                # The manual sample_and_log_prob path that used to follow this
                # raise was unreachable dead code and has been removed.
                raise ValueError("# prob_threshold > 0 is not yet implemented")
            samples = flow.sample(1, context=context)  # (batch, 1, num_features)

        x_seq[:, t, :] = samples.squeeze(1)

        if stop_predictor is not None:
            stop_prob = stop_predictor(context)  # (batch, )
            # Fix: accumulate with |= so a stopped sequence stays stopped.
            # The previous plain assignment could "revive" sequences on a
            # later step, inconsistent with the threshold branch below.
            stop_flags |= stop_prob > stop_threshold

        elif stop_threshold is not None:
            if t > 0:
                stop_now = x_seq[:,t,0] < stop_threshold  # (batch, )
                stop_flags |= stop_now

        if t == 0:
            # Overwrite the central object's positional features with tiny
            # random offsets rather than flow samples.
            if num_features > 1:
                x_seq[:,0,1] = torch.randn(batch_size).to(x_cond.device) * 1e-3  # Random distance for central
            if num_features > 2:
                x_seq[:,0,2] = torch.randn(batch_size).to(x_cond.device) * 1e-3

        if stop_flags.all():
            break

    return x_seq
203
+
204
def my_stop_predictor(args):
    """Factory: build a StopPredictor sized from the run configuration."""
    return StopPredictor(
        context_dim=args.num_context,
        hidden_dim=args.hidden_dim_stop,
    )
206
+
207
+
208
class StopPredictor(nn.Module):
    """Maps a flow context vector to a per-sequence stop probability.

    A two-layer MLP with a sigmoid head, so outputs always lie in [0, 1]
    and can be fed straight into ``BCELoss``.
    """

    def __init__(self, context_dim, hidden_dim=64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(context_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, context):
        """Return stop probabilities with shape (batch,), values in [0, 1]."""
        return self.layers(context).squeeze(-1)
221
+
222
class BimodalNormal(StandardNormal):
    """Equal-weight mixture of two unit-variance Gaussians at ±offset.

    The component means are fixed at ``-offset`` and ``+offset`` along every
    dimension; only the offset is configurable.
    """

    def __init__(self, shape, offset=2.0):
        super().__init__(shape)
        self.offset = offset

    def _sample(self, num_samples, context=None):
        """Draw samples, half from each mode, in randomized order."""
        if context is None:
            device = self._log_z.device
            n_low = num_samples // 2
            low = torch.randn(n_low, *self._shape, device=device) - self.offset
            high = torch.randn(num_samples - n_low, *self._shape, device=device) + self.offset
            mixed = torch.cat([low, high], dim=0)
            # Shuffle so the two modes are interleaved, not blocked.
            return mixed[torch.randperm(num_samples)]

        context_size = context.shape[0]
        device = context.device
        total = context_size * num_samples
        n_low = total // 2

        low = torch.randn(n_low, *self._shape, device=device) - self.offset
        high = torch.randn(total - n_low, *self._shape, device=device) + self.offset
        mixed = torch.cat([low, high], dim=0)[torch.randperm(total)]

        # (context_size, num_samples, *shape) as nflows expects.
        return torchutils.split_leading_dim(mixed, [context_size, num_samples])

    def _log_prob(self, inputs, context=None):
        """Exact mixture log-density: logaddexp of the two components minus log 2."""
        if inputs.shape[1:] != self._shape:
            raise ValueError(
                "Expected input of shape {}, got {}".format(
                    self._shape, inputs.shape[1:]
                )
            )

        flat = inputs.view(inputs.size(0), -1)

        # Full Gaussian log-densities around -offset and +offset.
        log_low = -0.5 * torch.sum((flat + self.offset)**2, dim=1) - self._log_z
        log_high = -0.5 * torch.sum((flat - self.offset)**2, dim=1) - self._log_z

        return torch.logaddexp(log_low, log_high) - np.log(2.0)

    def _mean(self, context):
        # The mixture is symmetric about zero, so its mean is exactly zero.
        if context is None:
            return self._log_z.new_zeros(self._shape)
        return context.new_zeros(context.shape[0], *self._shape)
273
+
274
class ConditionalBimodal(nn.Module):
    """Context-conditioned two-component Gaussian mixture over the latent space.

    A small MLP maps each context vector to the mixture parameters: two
    component means (``latent_dim`` each), one shared raw scale, and one
    mixture-weight logit. Fixes over the original: removed the unused
    ``batch_size`` local and corrected the stale parameter-layout comment.
    """

    def __init__(self, context_dim, latent_dim, hidden_dim=64):
        super().__init__()
        self.latent_dim = latent_dim

        # context -> [mean1 (latent_dim), mean2 (latent_dim), scale (1), weight logit (1)]
        self.net = nn.Sequential(
            nn.Linear(context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim + 2) # mean1, mean2, scale, weight_logits
        )

    def get_distribution(self, context):
        """Build the ``MixtureSameFamily`` distribution for a batch of contexts."""
        params = self.net(context)  # (batch, 2*latent_dim + 2)

        # Split the flat parameter vector into its pieces.
        mean1 = params[:, :self.latent_dim]
        mean2 = params[:, self.latent_dim:2*self.latent_dim]
        scale_raw = params[:, 2*self.latent_dim:2*self.latent_dim+1]  # shape (batch, 1)
        weight_logits = params[:, 2*self.latent_dim+1:]  # shape (batch, 1)

        # Softplus keeps the scale positive; the epsilon avoids degenerate components.
        scale = F.softplus(scale_raw) + 1e-3

        # Symmetric logits (w, -w) define the two mixture weights.
        mix_logits = torch.cat([weight_logits, -weight_logits], dim=1)  # (batch, 2)
        mix_dist = Categorical(logits=mix_logits)

        means = torch.stack([mean1, mean2], dim=1)  # (batch, 2, latent_dim)
        # One isotropic scale shared across both components and all dims.
        scales = scale.unsqueeze(1).expand_as(means)  # (batch, 2, latent_dim)

        comp_dist = Independent(Normal(means, scales), 1)  # treat latent_dim as event dim
        return MixtureSameFamily(mix_dist, comp_dist)

    def sample(self, num_samples, context):
        """Draw samples per context.

        Args:
            num_samples: samples per context row.
            context: (batch, context_dim)

        Returns:
            (batch, num_samples, latent_dim)
        """
        dist = self.get_distribution(context)
        samples = dist.sample((num_samples,))  # (num_samples, batch, latent_dim)
        return samples.permute(1, 0, 2)  # (batch, num_samples, latent_dim)

    def log_prob(self, x, context):
        """Return per-row log density.

        Args:
            x: (batch, latent_dim)
            context: (batch, context_dim)

        Returns:
            (batch,)
        """
        dist = self.get_distribution(context)
        return dist.log_prob(x)
325
+
326
+
327
+ from nflows.transforms.base import InputOutsideDomain, Transform, InverseTransform
328
+
329
class Sigmoid(Transform):
    """Elementwise tempered sigmoid transform with tractable log-Jacobian.

    Forward maps the real line to (0, 1) via ``sigmoid(temperature * x)``;
    ``inverse`` is the matching tempered logit. The temperature may be a
    learnable parameter or a fixed buffer.
    """

    def __init__(self, temperature=1, eps=1e-6, learn_temperature=False):
        super().__init__()
        self.eps = eps
        temp = torch.Tensor([temperature])
        if learn_temperature:
            self.temperature = nn.Parameter(temp)
        else:
            self.register_buffer('temperature', temp)

    def forward(self, inputs, context=None):
        temperature = self.temperature.to(inputs.device)

        scaled = temperature * inputs
        outputs = torch.sigmoid(scaled)

        # d/dx sigmoid(Tx) = T * s * (1 - s); in log form that is
        # log T - softplus(-Tx) - softplus(Tx).
        logabsdet = torchutils.sum_except_batch(
            torch.log(temperature) - F.softplus(-scaled) - F.softplus(scaled)
        )
        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        # Reject values outside the closed unit interval outright.
        if torch.min(inputs) < 0 or torch.max(inputs) > 1:
            raise InputOutsideDomain()

        # Clamp away from {0, 1} so the logit stays finite.
        inputs = torch.clamp(inputs, self.eps, 1 - self.eps)

        temperature = self.temperature.to(inputs.device)
        outputs = (1 / temperature) * (torch.log(inputs) - torch.log1p(-inputs))
        # Negated forward log-Jacobian evaluated at the recovered point.
        logabsdet = -torchutils.sum_except_batch(
            torch.log(temperature)
            - F.softplus(-temperature * outputs)
            - F.softplus(temperature * outputs)
        )
        return outputs, logabsdet
364
+
365
+
366
class Logit(InverseTransform):
    """Logit transform: the inverse of ``Sigmoid``, mapping (0, 1) onto R."""

    def __init__(self, temperature=1, eps=1e-6):
        # Delegate entirely to the inverted Sigmoid transform.
        super().__init__(Sigmoid(temperature=temperature, eps=eps))