cosmoglint 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosmoglint/__init__.py +1 -0
- cosmoglint/model/__init__.py +2 -0
- cosmoglint/model/transformer.py +500 -0
- cosmoglint/model/transformer_nf.py +368 -0
- cosmoglint/utils/ReadPinocchio5.py +1022 -0
- cosmoglint/utils/__init__.py +2 -0
- cosmoglint/utils/cosmology_utils.py +194 -0
- cosmoglint/utils/generation_utils.py +366 -0
- cosmoglint/utils/io_utils.py +397 -0
- cosmoglint-1.0.0.dist-info/METADATA +164 -0
- cosmoglint-1.0.0.dist-info/RECORD +14 -0
- cosmoglint-1.0.0.dist-info/WHEEL +5 -0
- cosmoglint-1.0.0.dist-info/licenses/LICENSE +21 -0
- cosmoglint-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,368 @@
|
|
|
1
|
+
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
from .transformer import Transformer1, Transformer2, TransformerWithGlobalCond, Transformer1WithAttn, Transformer2WithAttn
|
|
8
|
+
|
|
9
|
+
from nflows.flows import Flow
|
|
10
|
+
from nflows.distributions import StandardNormal, ConditionalDiagonalNormal
|
|
11
|
+
from nflows.transforms import CompositeTransform, RandomPermutation, AffineCouplingTransform, ActNorm, PiecewiseRationalQuadraticCDF, AffineTransform
|
|
12
|
+
from nflows.nn.nets import ResidualNet
|
|
13
|
+
from nflows.utils import torchutils
|
|
14
|
+
|
|
15
|
+
from torch.distributions import Normal, Categorical, MixtureSameFamily, Independent
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def transformer_nf_model(args, **kwargs):
    """Build the (transformer, flow) pair for the transformer + normalizing-flow model.

    The transformer autoregressively emits a per-galaxy context vector of size
    ``args.num_context``; the flow models the galaxy features conditioned on
    that vector.

    Args:
        args: hyperparameter namespace (model_name, d_model, num_layers,
            num_heads, max_length, num_features_in, num_context, num_flows,
            hidden_dim, base_dist, ...).
        **kwargs: extra keyword arguments forwarded to the transformer constructor.

    Returns:
        (model, flow): the transformer module and the nflows ``Flow``.

    Raises:
        ValueError: if ``args.model_name`` or ``args.base_dist`` is unknown.
    """
    ### Transformer model ###
    # model_name -> (top-level class, inner transformer class for *_with_global_cond)
    model_registry = {
        "transformer1": (Transformer1, None),
        "transformer2": (Transformer2, None),
        "transformer1_with_global_cond": (TransformerWithGlobalCond, Transformer1),
        "transformer2_with_global_cond": (TransformerWithGlobalCond, Transformer2),
        "transformer1_with_attn": (Transformer1WithAttn, None),
        "transformer2_with_attn": (Transformer2WithAttn, None),
    }
    if args.model_name not in model_registry:
        raise ValueError(f"Invalid model: {args.model_name}")
    model_class, transformer_cls = model_registry[args.model_name]

    common_args = dict(
        num_condition=args.num_features_cond,
        d_model=args.d_model,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        max_length=args.max_length,
        num_features_in=args.num_features_in,
        num_features_out=args.num_context,  # the transformer emits the flow's context vector
        num_token_types=1,
        last_activation=nn.Tanh(),
        pred_prob=False,
        **kwargs
    )

    if "with_global_cond" in args.model_name:
        common_args["num_features_global"] = args.num_features_global
        common_args["transformer_cls"] = transformer_cls

    model = model_class(**common_args)

    ### Flow model ###
    n_dim = args.num_features_in

    # Pre-squash the data away from {0, 1} and map to the real line with a
    # logit, as done in RealNVP: x -> logit(alpha + (1 - 2*alpha) * x).
    alpha = 0.05  # value in RealNVP
    transforms = [
        AffineTransform(
            shift=torch.full([n_dim], fill_value=alpha),
            scale=torch.full([n_dim], fill_value=1.0 - 2.0 * alpha),
        ),
        Logit(),
    ]

    even_mask = torch.arange(n_dim) % 2
    odd_mask = (torch.arange(n_dim) + 1) % 2
    for ilayer in range(args.num_flows):
        transforms.append(
            AffineCouplingTransform(
                # Alternate which half of the features is transformed each layer.
                mask=even_mask if ilayer % 2 == 0 else odd_mask,
                transform_net_create_fn=lambda in_features, out_features: ResidualNet(
                    in_features=in_features,
                    out_features=out_features,
                    hidden_features=args.hidden_dim,
                    context_features=args.num_context,
                    num_blocks=2,
                    activation=nn.ReLU()
                )
            )
        )
        transforms.append(PiecewiseRationalQuadraticCDF(shape=[n_dim], num_bins=8, tails='linear', tail_bound=3.0))
        transforms.append(RandomPermutation(n_dim))

    transform = CompositeTransform(transforms)

    # Base distribution of the flow.
    if args.base_dist == "normal":
        base_dist = StandardNormal(shape=[n_dim])
    elif args.base_dist == "conditional_normal":
        base_dist = ConditionalDiagonalNormal(
            shape=[n_dim],
            context_encoder=nn.Sequential(
                nn.Linear(args.num_context, args.hidden_dim),
                nn.LeakyReLU(),
                nn.Linear(args.hidden_dim, 2 * n_dim)
            )
        )
    elif args.base_dist == "bimodal":
        base_dist = BimodalNormal(shape=[n_dim], offset=2.0)
    elif args.base_dist == "conditional_bimodal":
        base_dist = ConditionalBimodal(context_dim=args.num_context, latent_dim=n_dim, hidden_dim=args.hidden_dim)
    else:
        raise ValueError(f"Invalid base distribution: {args.base_dist}")

    flow = Flow(transform, base_dist)

    return model, flow
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def calculate_transformer_nf_loss(transformer, flow, batch, stop=None, stop_predictor=None):
    """Compute the flow negative-log-likelihood loss (and optionally a stop loss).

    The transformer maps the halo condition and the shifted target sequence to
    per-position context vectors; the flow scores each valid (unmasked)
    position's features under that context.

    Args:
        transformer: autoregressive context model; called as
            ``transformer(condition, input_seq, global_cond=...)``.
        flow: conditional flow exposing ``log_prob(inputs, context=...)``.
        batch: dict with "context", "target", "mask", "global_context" tensors.
        stop: optional (batch, max_length) stop targets for BCE training.
        stop_predictor: optional module mapping context vectors to stop probs.

    Returns:
        ``loss`` (scalar), or ``(loss, loss_stop)`` when both ``stop`` and
        ``stop_predictor`` are given.
    """
    # Run on whatever device the transformer's parameters live on.
    device = next(transformer.parameters()).device

    condition = batch["context"].to(device)
    seq = batch["target"].to(device)
    mask = batch["mask"].to(device)  # (batch, max_length, num_context)
    global_cond = batch["global_context"].to(device)

    # Teacher forcing: feed the sequence shifted by one position.
    input_seq = seq[:, :-1]  # (batch, max_length-1, num_features)
    output = transformer(condition, input_seq, global_cond=global_cond)  # (batch, max_length, num_context)

    # Collapse the mask to one flag per position.
    # NOTE(review): used below as a boolean index — assumes mask is a bool tensor; confirm upstream.
    mask = mask[:, :, 0].reshape(-1)  # (batch * max_length)

    _, _, num_features = seq.shape
    _, _, num_context = output.shape

    # Keep only valid (unmasked) positions for the flow.
    context_for_flow = output.reshape(-1, num_context)[mask]  # flow context (num_galaxies, num_context)
    target_for_flow = seq.reshape(-1, num_features)[mask]  # flow target (num_galaxies, num_features)
    log_prob = flow.log_prob(target_for_flow, context=context_for_flow)
    # Normalize by feature count so the loss scale is comparable across setups.
    loss = - log_prob.mean() / num_features

    if stop is not None and stop_predictor is not None:
        stop = stop.to(device)  # (batch, max_length)
        stop = stop.reshape(-1)[mask]  # (num_galaxies,)
        # BCELoss requires float targets in [0, 1] and sigmoid outputs from the predictor.
        stop_pred = stop_predictor(context_for_flow)
        loss_stop = torch.nn.BCELoss()(stop_pred, stop)  # scalar
        return loss, loss_stop
    else:
        return loss
|
|
142
|
+
|
|
143
|
+
def generate_with_transformer_nf(transformer, flow, x_cond, stop_predictor=None, prob_threshold=0, stop_threshold=None):
    """Autoregressively sample galaxy sequences with the transformer + flow.

    At each step the transformer produces a context vector from the sequence
    generated so far; the flow draws one sample per batch element from that
    context. Generation halts early once every batch element has stopped.

    Args:
        transformer: context model with ``max_length`` and ``num_features_in`` attrs.
        flow: conditional nflows ``Flow``.
        x_cond: (batch, ...) condition tensor; also fixes the device.
        stop_predictor: optional module giving a per-element stop probability.
        prob_threshold: not implemented; any value > 0 raises ValueError.
        stop_threshold: threshold compared against the stop probability (when a
            predictor is given) or against feature 0 of the latest sample.

    Returns:
        x_seq: (batch, max_length, num_features) generated sequence; positions
        after an element's stop remain zero-filled.
    """
    transformer.eval()
    flow.eval()
    if stop_predictor is not None:
        stop_predictor.eval()

    batch_size = len(x_cond)
    max_length = transformer.max_length
    num_features = transformer.num_features_in

    x_seq = torch.zeros(batch_size, max_length, num_features).to(x_cond.device)
    stop_flags = torch.zeros(batch_size, dtype=torch.bool).to(x_cond.device)

    for t in range(max_length):
        with torch.no_grad():
            context = transformer(x_cond, x_seq[:, :t, :])  # (batch, t+1, num_context)
            context = context[:, -1, :]  # keep only the newest position: (batch, num_context)
            if prob_threshold > 0:
                raise ValueError("# prob_threshold > 0 is not yet implemented")
                # NOTE(review): everything below in this branch is unreachable
                # (dead code after the raise) — kept as a sketch of a manual
                # sample path through the flow internals.
                num_samples = 1
                embedded_context = flow._embedding_net(context)

                # NOTE(review): passes the raw `context`, not `embedded_context`,
                # to the base distribution — confirm intent if ever enabled.
                noise, log_prob = flow._distribution.sample_and_log_prob(num_samples, context)

                if embedded_context is not None:
                    # Merge the context dimension with sample dimension in order to apply the transform.
                    noise = torchutils.merge_leading_dims(noise, num_dims=2)
                    embedded_context = torchutils.repeat_rows(embedded_context, num_reps=1)

                samples, _ = flow._transform.inverse(noise, context=embedded_context)

                if embedded_context is not None:
                    # Split the context dimension from sample dimension.
                    samples = torchutils.split_leading_dim(samples, shape=[-1, num_samples])

            else:
                samples = flow.sample(1, context=context)  # (batch, 1, num_features)

            x_seq[:, t, :] = samples.squeeze(1)

            if stop_predictor is not None:
                stop_prob = stop_predictor(context)  # (batch,)
                stop_now = stop_prob > stop_threshold
                # NOTE(review): this overwrites stop_flags each step instead of
                # OR-ing (`|=`) like the branch below, so an element can
                # "un-stop" — confirm this is intentional.
                stop_flags = stop_now

            elif stop_threshold is not None:
                if t > 0:
                    # Heuristic stop: feature 0 of the latest sample dropping
                    # below the threshold marks the element as finished.
                    stop_now = x_seq[:, t, 0] < stop_threshold  # (batch,)
                    stop_flags |= stop_now

            if t == 0:
                # Central galaxy: give it a small random offset in the
                # distance/velocity features instead of exact zeros.
                if num_features > 1:
                    x_seq[:, 0, 1] = torch.randn(batch_size).to(x_cond.device) * 1e-3  # Random distance for central
                if num_features > 2:
                    x_seq[:, 0, 2] = torch.randn(batch_size).to(x_cond.device) * 1e-3

        # Stop early once every batch element has raised its stop flag.
        if stop_flags.all():
            break

    return x_seq
|
|
203
|
+
|
|
204
|
+
def my_stop_predictor(args):
    """Build a StopPredictor sized from the run configuration.

    Args:
        args: namespace providing ``num_context`` and ``hidden_dim_stop``.

    Returns:
        A freshly constructed ``StopPredictor``.
    """
    context_dim = args.num_context
    hidden_dim = args.hidden_dim_stop
    return StopPredictor(context_dim=context_dim, hidden_dim=hidden_dim)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class StopPredictor(nn.Module):
    """Per-galaxy stop classifier.

    A one-hidden-layer MLP mapping a context vector to a probability in
    [0, 1] (sigmoid output), used to decide when sequence generation halts.
    """

    def __init__(self, context_dim, hidden_dim=64):
        """
        Args:
            context_dim: size of the input context vector.
            hidden_dim: width of the hidden layer.
        """
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(context_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, context):
        """Return stop probabilities with the trailing singleton dim removed."""
        return self.layers(context).squeeze(-1)  # (batch_size,) in [0, 1]
|
|
221
|
+
|
|
222
|
+
class BimodalNormal(StandardNormal):
    """Equal-weight mixture of two unit-variance Gaussians at -offset and +offset."""

    def __init__(self, shape, offset=2.0):
        """
        Args:
            shape: event shape of the distribution.
            offset: distance of each mode from the origin (per dimension).
        """
        super().__init__(shape)
        self.offset = offset

    def _sample(self, num_samples, context=None):
        """Draw samples by splitting the budget evenly between the two modes."""
        if context is None:
            device = self._log_z.device
            n_low = num_samples // 2

            # Half the draws from the -offset mode, the rest from +offset.
            low = torch.randn(n_low, *self._shape, device=device) - self.offset
            high = torch.randn(num_samples - n_low, *self._shape, device=device) + self.offset
            mixed = torch.cat([low, high], dim=0)
            # Shuffle so mode membership is not position-dependent.
            return mixed[torch.randperm(num_samples)]

        else:
            context_size = context.shape[0]
            total = context_size * num_samples
            device = context.device
            n_low = total // 2

            low = torch.randn(n_low, *self._shape, device=device) - self.offset
            high = torch.randn(total - n_low, *self._shape, device=device) + self.offset
            mixed = torch.cat([low, high], dim=0)
            mixed = mixed[torch.randperm(total)]

            # Reshape back to (context_size, num_samples, *shape).
            return torchutils.split_leading_dim(mixed, [context_size, num_samples])

    def _log_prob(self, inputs, context=None):
        """Log-density of the two-component mixture (context is ignored)."""
        if inputs.shape[1:] != self._shape:
            raise ValueError(
                "Expected input of shape {}, got {}".format(
                    self._shape, inputs.shape[1:]
                )
            )

        flat = inputs.view(inputs.size(0), -1)

        # Per-component unit-variance Gaussian log-densities (_log_z is the
        # StandardNormal normalization constant for this event shape).
        comp_low = -0.5 * torch.sum((flat + self.offset) ** 2, dim=1) - self._log_z
        comp_high = -0.5 * torch.sum((flat - self.offset) ** 2, dim=1) - self._log_z

        # log(0.5 * exp(comp_low) + 0.5 * exp(comp_high)), computed stably.
        return torch.logaddexp(comp_low, comp_high) - np.log(2.0)

    def _mean(self, context):
        """The mixture mean is the origin (modes cancel by symmetry)."""
        if context is None:
            return self._log_z.new_zeros(self._shape)
        else:
            return context.new_zeros(context.shape[0], *self._shape)
|
|
273
|
+
|
|
274
|
+
class ConditionalBimodal(nn.Module):
    """Context-conditioned two-component Gaussian mixture over the latent space.

    A small MLP maps the context to both component means, a shared positive
    scale, and the mixture weight logit.
    """

    def __init__(self, context_dim, latent_dim, hidden_dim=64):
        """
        Args:
            context_dim: size of the conditioning vector.
            latent_dim: dimensionality of the modeled latent variable.
            hidden_dim: width of the parameter network's hidden layer.
        """
        super().__init__()
        self.latent_dim = latent_dim

        # context -> mean1 (latent_dim), mean2 (latent_dim), scale (1), weight logit (1)
        self.net = nn.Sequential(
            nn.Linear(context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim + 2)  # mean1, mean2, scale, weight_logits
        )

    def get_distribution(self, context):
        """Build the batched MixtureSameFamily distribution for ``context``."""
        raw = self.net(context)  # (batch, 2*latent_dim + 2)
        d = self.latent_dim

        # Unpack the parameter vector.
        mu_a = raw[:, :d]
        mu_b = raw[:, d:2 * d]
        scale_raw = raw[:, 2 * d:2 * d + 1]  # (batch, 1)
        weight_logits = raw[:, 2 * d + 1:]  # (batch, 1)

        # Softplus keeps the scale strictly positive; the floor avoids collapse.
        sigma = F.softplus(scale_raw) + 1e-3

        # Symmetric logits (w, -w) give a two-way categorical mixing weight.
        mix_dist = Categorical(logits=torch.cat([weight_logits, -weight_logits], dim=1))

        means = torch.stack([mu_a, mu_b], dim=1)  # (batch, 2, latent_dim)
        scales = sigma.unsqueeze(1).expand_as(means)  # (batch, 2, latent_dim)

        # Independent makes latent_dim an event dimension of each component.
        components = Independent(Normal(means, scales), 1)
        return MixtureSameFamily(mix_dist, components)

    def sample(self, num_samples, context):
        """
        context: (batch, context_dim)
        return: (batch, num_samples, latent_dim)
        """
        mixture = self.get_distribution(context)
        draws = mixture.sample((num_samples,))  # (num_samples, batch, latent_dim)
        return draws.permute(1, 0, 2)  # (batch, num_samples, latent_dim)

    def log_prob(self, x, context):
        """
        x: (batch, latent_dim)
        context: (batch, context_dim)
        return: (batch,)
        """
        return self.get_distribution(context).log_prob(x)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
from nflows.transforms.base import InputOutsideDomain, Transform, InverseTransform
|
|
328
|
+
|
|
329
|
+
class Sigmoid(Transform):
    """Element-wise sigmoid bijection with temperature.

    forward:  y = sigmoid(T * x), with log|det J| = sum(log T - softplus(-Tx) - softplus(Tx))
    inverse:  x = logit(y) / T, inputs clamped to [eps, 1 - eps] for stability.

    Unlike the stock nflows Sigmoid, the temperature is moved to the input's
    device on every call.
    """

    def __init__(self, temperature=1, eps=1e-6, learn_temperature=False):
        """
        Args:
            temperature: slope of the sigmoid (T above).
            eps: clamp margin applied in ``inverse`` to avoid infinite logits.
            learn_temperature: if True, T becomes a trainable parameter;
                otherwise it is registered as a buffer.
        """
        super().__init__()
        self.eps = eps
        if learn_temperature:
            self.temperature = nn.Parameter(torch.Tensor([temperature]))
        else:
            temperature = torch.Tensor([temperature])
            self.register_buffer('temperature', temperature)

    def forward(self, inputs, context=None):
        """Map R -> (0, 1); returns (outputs, logabsdet per batch element)."""
        # Keep the buffer on the same device as the data.
        temperature = self.temperature.to(inputs.device)

        inputs = temperature * inputs
        outputs = torch.sigmoid(inputs)

        # log sigmoid'(z) = -softplus(-z) - softplus(z); chain rule adds log T.
        logabsdet = torchutils.sum_except_batch(
            torch.log(temperature) - F.softplus(-inputs) - F.softplus(inputs)
        )
        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        """Map (0, 1) -> R; raises InputOutsideDomain outside [0, 1]."""
        if torch.min(inputs) < 0 or torch.max(inputs) > 1:
            raise InputOutsideDomain()

        # Clamp away from the endpoints so the logit stays finite.
        inputs = torch.clamp(inputs, self.eps, 1 - self.eps)

        temperature = self.temperature.to(inputs.device)
        outputs = (1 / temperature) * (torch.log(inputs) - torch.log1p(-inputs))
        # Inverse Jacobian is the negation of the forward one, evaluated at
        # the recovered pre-image `outputs`.
        logabsdet = -torchutils.sum_except_batch(
            torch.log(temperature)
            - F.softplus(-temperature * outputs)
            - F.softplus(temperature * outputs)
        )
        return outputs, logabsdet
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
class Logit(InverseTransform):
    """Logit bijection: the inverse of ``Sigmoid``.

    Its ``forward`` maps (0, 1) -> R and its ``inverse`` maps back, by
    wrapping a ``Sigmoid`` in nflows' ``InverseTransform``.
    """

    def __init__(self, temperature=1, eps=1e-6):
        # Delegate both directions (and Jacobians) to the wrapped Sigmoid.
        super().__init__(Sigmoid(temperature=temperature, eps=eps))
|