cosmoglint 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosmoglint/__init__.py +1 -0
- cosmoglint/model/__init__.py +2 -0
- cosmoglint/model/transformer.py +500 -0
- cosmoglint/model/transformer_nf.py +368 -0
- cosmoglint/utils/ReadPinocchio5.py +1022 -0
- cosmoglint/utils/__init__.py +2 -0
- cosmoglint/utils/cosmology_utils.py +194 -0
- cosmoglint/utils/generation_utils.py +366 -0
- cosmoglint/utils/io_utils.py +397 -0
- cosmoglint-1.0.0.dist-info/METADATA +164 -0
- cosmoglint-1.0.0.dist-info/RECORD +14 -0
- cosmoglint-1.0.0.dist-info/WHEEL +5 -0
- cosmoglint-1.0.0.dist-info/licenses/LICENSE +21 -0
- cosmoglint-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,397 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import os
|
|
3
|
+
from argparse import Namespace
|
|
4
|
+
import random
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
import h5py
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
|
|
13
|
+
from torch.utils.data import Dataset
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def my_save_model(model, fname):
    """Serialize *model*'s parameters (its state dict) to *fname* and log the path."""
    state = model.state_dict()
    torch.save(state, fname)
    print(f"# Model saved to {fname}")
|
|
19
|
+
|
|
20
|
+
def save_catalog_data(pos_list, value, args, output_fname):
    """Write a plain-text catalog: x y z1 [z2 ...] value per row.

    Parameters
    ----------
    pos_list : array or list of arrays
        Position array(s) of shape (N, 3). The x/y columns are taken from the
        first array; one z column is written per array in the list.
    value : sequence
        Per-object value written as the last column (one row per entry).
    args : Namespace
        Unused here; kept for interface compatibility with the other savers.
    output_fname : str
        Destination text file.
    """
    if not isinstance(pos_list, list):
        pos_list = [pos_list]

    rows = []
    for idx, val in enumerate(value):
        fields = [f"{pos_list[0][idx, 0]} {pos_list[0][idx, 1]} "]
        fields.extend(f"{pos[idx, 2]} " for pos in pos_list)
        fields.append(f"{val}\n")
        rows.append("".join(fields))

    with open(output_fname, 'w') as f:
        f.writelines(rows)

    print(f"# Catalog saved to {output_fname}")
|
|
33
|
+
|
|
34
|
+
def save_intensity_data(intensity, args, output_fname):
    """Store one or more intensity cubes in an HDF5 file, with the run
    arguments attached as file attributes.

    Parameters
    ----------
    intensity : array or list of arrays
        Data cube(s); each is stored as dataset 'intensity{i}'.
    args : Namespace
        Run configuration; None values are stored as the string "None"
        because HDF5 attributes cannot hold None.
    output_fname : str
        Destination HDF5 file.
    """
    attrs = {key: ("None" if val is None else val) for key, val in vars(args).items()}

    cubes = intensity if isinstance(intensity, list) else [intensity]

    with h5py.File(output_fname, 'w') as f:
        for idx, cube in enumerate(cubes):
            f.create_dataset(f'intensity{idx}', data=cube)
        for key, val in attrs.items():
            f.attrs[key] = val

    print(f"# Data cube saved as {output_fname}")
|
|
47
|
+
|
|
48
|
+
def namespace_to_dict(ns):
    """Recursively convert argparse.Namespace objects (and dicts containing
    them) into plain nested dicts; any other value is returned as-is."""
    if isinstance(ns, Namespace):
        items = vars(ns).items()
    elif isinstance(ns, dict):
        items = ns.items()
    else:
        return ns
    return {key: namespace_to_dict(val) for key, val in items}
|
|
55
|
+
|
|
56
|
+
def convert_to_log(val, val_min):
    """Log10-transform *val*, flooring entries at or below 10**val_min to val_min.

    NOTE(review): np.full_like preserves *val*'s dtype, so an integer input
    array would silently truncate the log values — presumably callers always
    pass floating arrays; confirm.
    """
    floor = 10 ** val_min
    log_val = np.full_like(val, val_min)
    above = val > floor
    log_val[above] = np.log10(val[above])
    return log_val
|
|
61
|
+
|
|
62
|
+
def convert_to_log_with_sign(val):
    """Signed log transform: sign(x) * log10(|x| + 1). Odd function, 0 at 0."""
    magnitude = np.log10(np.abs(val) + 1)
    return np.sign(val) * magnitude
|
|
64
|
+
|
|
65
|
+
def inverse_convert_to_log_with_sign(val):
    """Invert convert_to_log_with_sign: sign(x) * (10**|x| - 1)."""
    magnitude = 10 ** np.abs(val) - 1
    return np.sign(val) * magnitude
|
|
67
|
+
|
|
68
|
+
def normalize(x, key, norm_param_dict, inverse=False, convert=True):
    """Apply (or undo) min-max scaling, optionally preceded by a log transform.

    Parameters
    ----------
    x : array-like
        Values to transform.
    key : str
        Feature name used to look up settings in ``norm_param_dict``.
    norm_param_dict : dict or None
        e.g. {"GroupMass": {"min": 10, "max": 15, "norm": "log"}, ...}.
        If None, ``x`` is returned unchanged (as an ndarray).
    inverse : bool
        If True, undo the normalization (un-scale first, then un-convert).
    convert : bool
        If True, apply the log / signed-log conversion selected by "norm".
    """
    x = np.array(x)

    if norm_param_dict is None:
        return x

    settings = norm_param_dict[key]
    xmin = settings["min"]
    xmax = settings["max"]
    norm = settings["norm"]
    span = xmax - xmin

    if inverse:
        # Undo min-max scaling first, then the log conversion.
        x = x * span + xmin
        if convert:
            if norm == "log":
                x = 10 ** x
            elif norm == "log_with_sign":
                x = inverse_convert_to_log_with_sign(x)
        return x

    # Forward: convert to log scale first, then min-max scale to [0, 1].
    if convert:
        if norm == "log":
            x = convert_to_log(x, xmin)
        elif norm == "log_with_sign":
            x = convert_to_log_with_sign(x)
    return (x - xmin) / span
|
|
106
|
+
|
|
107
|
+
def load_global_params(global_param_file, global_features, norm_param_dict=None):
    """Load and normalize global (per-simulation) parameters from text files.

    Parameters
    ----------
    global_param_file : str, list of str, or None
        Whitespace-delimited text file(s) with a header row naming columns.
    global_features : list of str or None
        Column names to extract. If None, global parameters are disabled and
        None is returned.
    norm_param_dict : dict or None
        Normalization settings forwarded to ``normalize`` per feature.

    Returns
    -------
    np.ndarray of shape (ndata, num_features_global), or None.

    Raises
    ------
    ValueError
        If features are requested but no parameter file is given.
    """
    if global_features is None:
        return None

    if global_param_file is None:
        raise ValueError("global_param_file must be specified when global_features is provided.")

    files = global_param_file if isinstance(global_param_file, list) else [global_param_file]

    blocks = []
    for fname in files:
        table = np.genfromtxt(fname, names=True, dtype=None, encoding="utf-8")
        columns = [table[name] for name in global_features]
        blocks.append(np.vstack(columns).T.astype(np.float32))

    global_params = np.vstack(blocks)

    # Normalize each feature column in place.
    for i, key in enumerate(global_features):
        global_params[..., i] = normalize(global_params[..., i], key, norm_param_dict)

    return global_params # (ndata, num_features_global)
|
|
131
|
+
|
|
132
|
+
def load_halo_data(
    file_path,
    input_features,
    output_features,
    norm_param_dict=None,
    max_length=10,
    sort=True,
    ndata=None,
    exclude_ratio=0.0,
    use_excluded_region=False,
):
    """Load halo (input) and galaxy (output) features from one HDF5 catalog.

    Parameters
    ----------
    file_path : str
        HDF5 file with one dataset per feature (per-halo for inputs,
        per-subhalo for outputs) plus ``GroupNsubs`` (subhalos per halo).
    input_features, output_features : list of str
        Dataset names to read as halo inputs / galaxy outputs.
    norm_param_dict : dict or None
        Normalization settings forwarded to ``normalize``.
    max_length : int
        Maximum number of galaxies kept per halo (sequence truncation).
    sort : bool
        If True, entries 2..N of each halo are sorted by the first output
        feature in descending order; the first entry stays in place.
    ndata : int or None
        If given, keep only the first ``ndata`` halos.
    exclude_ratio : float
        Side fraction of a corner sub-box used for the spatial cut;
        0 disables it.
    use_excluded_region : bool
        If True, keep only halos inside the corner sub-box (e.g. as a
        held-out volume) instead of excluding them.

    Returns
    -------
    (x, y_list) : (torch.Tensor, list of torch.Tensor)
        ``x`` has shape (Nhalo, num_features_in); ``y_list[i]`` has shape
        (n_gal_i, num_features_out) with n_gal_i <= max_length.
    """

    def load_values(f, key):
        # Fail loudly on a missing dataset rather than returning garbage.
        if key not in f:
            raise ValueError(f"Key '{key}' not found in the file.")

        data = f[key][:]
        return data

    num_features_in = len(input_features)
    num_features_out = len(output_features)

    with h5py.File(file_path, "r") as f:

        # Load input features
        source_list = []
        for feature in input_features:
            x = load_values(f, feature)
            source_list.append(x)

        source = np.stack(source_list, axis=1) # (N, num_features_in)
        for i, key in enumerate(input_features):
            source[:,i] = normalize(source[:,i], key, norm_param_dict)

        # Keep only halos whose normalized inputs are all strictly positive.
        # NOTE(review): this treats 0/negative normalized values as invalid —
        # confirm against the min/max settings in norm_param_dict.
        mask = np.ones(len(source), dtype=bool)
        for i in range(num_features_in):
            mask = mask & ( source[:,i] > 0 )

        if exclude_ratio > 0:
            boxsize = f.attrs["BoxSize"] # [kpc/h]
            halo_pos = f["GroupPos"][:] # [kpc/h]
            # Corner cube of side (exclude_ratio * BoxSize) at the far corner.
            mask_exclude = (halo_pos[:,0] > boxsize * (1.-exclude_ratio)) \
                & (halo_pos[:,1] > boxsize * (1.-exclude_ratio)) \
                & (halo_pos[:,2] > boxsize * (1.-exclude_ratio))

            if use_excluded_region:
                print("# Using excluded region of size ({} * BoxSize)^3".format(exclude_ratio))
                mask = mask & mask_exclude
            else:
                print("# Exclude halos in the corner of size ({} * BoxSize)^3".format(exclude_ratio))
                print("# The excluded region is {:.2f} % of the entire volume".format(100.0 * (exclude_ratio**3)))
                mask = mask & (~mask_exclude)

        if mask.sum() == 0:
            print("# No halo is found in {}".format(file_path))
            return torch.empty((0, num_features_in), dtype=torch.float32), []

        # Load output features
        target_list = []
        for feature in output_features:
            y = load_values(f, feature)
            target_list.append(y)

        target = np.stack(target_list, axis=1) # (N, num_features_out)
        for i, key in enumerate(output_features):
            target[:,i] = normalize(target[:,i], key, norm_param_dict)

        num_subgroups = f["GroupNsubs"][:]

        # Walk the flat subhalo table: halo j owns the contiguous slice
        # [start:end). The offset must advance even for masked-out halos,
        # so the mask check comes AFTER the offset update.
        offset = 0
        y_list = []
        for j in range(len(source)):
            start = offset
            end = start + num_subgroups[j]
            offset = end

            if not mask[j]:
                continue

            if num_subgroups[j] == 0:
                y_j = np.zeros((1, num_features_out)) # handle empty subgroups
            else:
                y_j = target[start:end, :]

            if sort:
                # Keep the first (central) entry fixed; sort the rest by the
                # first output feature, descending.
                sorted_indices = [0] + sorted(range(1, len(y_j)), key=lambda k: y_j[k,0], reverse=True)
                y_j = y_j[sorted_indices]

            y_j = y_j[:max_length] # truncate
            y_j = torch.tensor(y_j, dtype=torch.float32)
            y_list.append(y_j)

    x = source[mask]
    x = torch.tensor(x, dtype=torch.float32)

    if ndata is not None:
        x = x[:ndata]
        y_list = y_list[:ndata]

    return x, y_list
|
|
232
|
+
|
|
233
|
+
class MyDataset(Dataset):
    """Dataset pairing halo inputs with padded galaxy-sequence targets.

    Loads one or more HDF5 halo catalogs via ``load_halo_data`` and exposes,
    per halo:
      - "context": input features, shape (num_features_in,)
      - "global_context": per-file global parameters (or a length-1 dummy)
      - "target": galaxy features padded to ``max_length``
      - "mask": True for valid galaxy slots plus one extra slot, so the
        model can learn where the sequence stops
    """

    def __init__(
        self,
        path,
        input_features,
        output_features,
        global_params=None,
        norm_param_dict=None,
        max_length=10,
        sort=True,
        ndata=None,
        exclude_ratio=0.0,
        use_excluded_region=False,
        use_flat_representation=False,
        show_pbar=True,
    ):
        # Accept a single path or a list of paths.
        if not isinstance(path, list):
            path = [path]

        if global_params is not None:
            # Exactly one row of global parameters per input file.
            if len(global_params) != len(path):
                raise ValueError("The number of global parameter sets must match the number of data files")

        x = []
        self.y = []  # per-halo galaxy tensors, variable length
        self.g = []  # per-halo global-parameter rows

        plist = path
        if len(path) < 20:
            # Few files: log one line per file instead of a progress bar.
            verbose = True
        else:
            verbose = False
            if show_pbar:
                plist = tqdm(plist, file=sys.stderr)
            print("# Loading halo data from {} to {} ({} files)".format(path[0], path[-1], len(path)))

        for i, p in enumerate(plist):
            if verbose:
                print(f"# Loading halo data from {p}")

            x_tmp, y_tmp = load_halo_data(p, input_features, output_features, norm_param_dict=norm_param_dict, max_length=max_length, sort=sort, ndata=ndata, exclude_ratio=exclude_ratio, use_excluded_region=use_excluded_region)
            x.append(x_tmp)
            self.y = self.y + y_tmp

            if global_params is not None:
                # Broadcast this file's global parameters to all of its halos.
                global_param = global_params[i]
                g_tmp = np.repeat(global_param[None, :], len(x_tmp), axis=0) # (Nhalo, num_features_global)
            else:
                g_tmp = np.zeros((len(x_tmp), 1)) # dummy (Nhalo, 1)

            self.g.append( g_tmp )

        self.x = torch.cat(x, dim=0)

        if len(self.x) == 0:
            raise ValueError("No halo is found.")

        self.g = np.vstack(self.g)
        self.g = torch.tensor(self.g, dtype=torch.float32)

        _, num_params = (self.y[0]).shape

        # Pad the variable-length galaxy sequences into fixed-size tensors
        # and build the matching validity mask.
        self.y_padded = torch.zeros(len(self.x), max_length, num_params)
        self.mask = torch.zeros(len(self.x), max_length, num_params, dtype=torch.bool)

        for i, y_i in enumerate(self.y):
            length = len(y_i)
            self.y_padded[i, :length, :] = y_i[:max_length]
            # One slot past the end (clamped at max_length by the slice).
            self.mask[i, :length+1, :] = True # use the last + 1 value to learn when to stop

        if use_flat_representation:
            # Flatten (max_length, num_params) into a single token axis.
            self.y_padded = self.y_padded.reshape(len(self.y_padded), -1, 1) # (Nhalo, max_length * output_features, 1)
            self.mask = self.mask.reshape(len(self.mask), -1, 1) # (Nhalo, max_length * output_features, 1)

    def __len__(self):
        # Number of halos in the dataset.
        return len(self.x)

    def __getitem__(self, idx):
        # One halo: inputs, global context, padded targets, validity mask.
        batch = {
            "context": self.x[idx],
            "global_context": self.g[idx],
            "target": self.y_padded[idx],
            "mask": self.mask[idx]
        }
        return batch
|
|
319
|
+
|
|
320
|
+
def load_lightcone_data(input_fname, cosmo):
    """Read a lightcone halo catalog (Pinocchio format only).

    Parameters
    ----------
    input_fname : str
        Catalog path; must contain "pinocchio". If it also contains
        "old_version", the legacy binary layout is read via load_old_plc,
        otherwise ReadPinocchio5 is used.
    cosmo : astropy cosmology
        Used only to obtain h = H0 / (100 km/s/Mpc) for the mass units.

    Returns
    -------
    (mass [Msun], pos_x [arcsec], pos_y [arcsec], redshift_obs, redshift_real)
    """
    print(f"# Load {input_fname}")

    if "pinocchio" not in input_fname:
        raise ValueError("Unknown input file format")

    if "old_version" in input_fname:
        # Legacy binary layout handled by the local reader.
        mass, theta, phi, _, redshift_obs, redshift_real = load_old_plc(input_fname)
    else:
        import cosmoglint.utils.ReadPinocchio5 as rp
        catalog = rp.plc(input_fname).data

        mass = catalog["Mass"]
        theta = catalog["theta"]
        phi = catalog["phi"]
        redshift_obs = catalog["obsz"]
        redshift_real = catalog["truez"]

    # Masses are stored in Msun/h; divide by h to get Msun.
    import astropy.units as u
    hlittle = cosmo.H(0).to(u.km/u.s/u.Mpc).value / 100.0
    mass = mass / hlittle # [Msun]

    # Polar angle (deg) -> angular offset from the pole, in arcsec,
    # then project onto the sky plane with the azimuth.
    theta = ( 90. - theta ) * 3600 # [arcsec]
    phi_rad = phi * np.pi / 180.
    pos_x = theta * np.cos(phi_rad) # [arcsec]
    pos_y = theta * np.sin(phi_rad) # [arcsec]

    print("# Minimum log mass in catalog: {:.5f}".format(np.min(np.log10(mass))))
    print("# Maximum pos: ({:.3f}, {:.3f}) arcsec".format(np.max(pos_x), np.max(pos_y)))
    print("# Minimum pos: ({:.3f}, {:.3f}) arcsec".format(np.min(pos_x), np.min(pos_y)))
    print("# Redshift: {:.3f} - {:.3f}".format(np.min(redshift_real), np.max(redshift_real)))
    print("# Number of halos: {}".format(len(mass)))

    return mass, pos_x, pos_y, redshift_obs, redshift_real
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def load_old_plc(filename):
    """Read a Pinocchio "old version" past-light-cone binary catalog.

    The file is a sequence of Fortran-style records: a 4-byte little-endian
    length marker, the payload, and the same 4-byte marker again. Payload
    layout (little-endian): uint64 halo id, double true redshift, 3 doubles
    position, 3 doubles velocity, then doubles M, theta, phi, v_los, obs-z.

    Fixes over the previous version: the trailing record marker is checked
    for EOF (a truncated file now stops cleanly instead of raising
    struct.error), the builtin name ``id`` is no longer shadowed, and the
    struct formats are precompiled once instead of re-parsed per record.

    Returns
    -------
    Six 1-D numpy arrays: (M, theta, phi, v_los, z_obs, z_true).
    """
    import struct

    record = struct.Struct("<Q d ddd ddd ddddd")  # Q=uint64, d=double, little-endian
    marker = struct.Struct("<i")  # Fortran record-length marker

    M_list = []
    th_list = []
    ph_list = []
    vl_list = []
    zo_list = []
    z_list = []
    with open(filename, "rb") as f:
        while True:
            head = f.read(4)
            if len(head) < 4:
                break  # EOF

            nbytes = marker.unpack(head)[0]
            payload = f.read(nbytes)
            if len(payload) != nbytes:
                break  # incomplete record; stop reading

            (
                _halo_id, z, _x1, _x2, _x3, _v1, _v2, _v3,
                M, th, ph, vl, zo,
            ) = record.unpack(payload)

            # Trailing record marker; a truncated file ends here gracefully.
            tail = f.read(4)
            if len(tail) < 4:
                break
            marker.unpack(tail)

            M_list.append(M)
            th_list.append(th)
            ph_list.append(ph)
            vl_list.append(vl)
            zo_list.append(zo)
            z_list.append(z)

    return np.array(M_list), np.array(th_list), np.array(ph_list), np.array(vl_list), np.array(zo_list), np.array(z_list)
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cosmoglint
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Transformer-based generative model for galaxies
|
|
5
|
+
Author: Kana Moriwaki
|
|
6
|
+
License: MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2026 Kana Moriwaki
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
|
|
28
|
+
Project-URL: Homepage, https://github.com/knmoriwaki/cosmoglint
|
|
29
|
+
Requires-Python: >=3.9
|
|
30
|
+
Description-Content-Type: text/markdown
|
|
31
|
+
License-File: LICENSE
|
|
32
|
+
Requires-Dist: numpy
|
|
33
|
+
Requires-Dist: tqdm
|
|
34
|
+
Requires-Dist: h5py
|
|
35
|
+
Requires-Dist: nflows
|
|
36
|
+
Requires-Dist: astropy
|
|
37
|
+
Dynamic: license-file
|
|
38
|
+
|
|
39
|
+
# CosmoGLINT: Cosmological Generative model for Line INtensity mapping with Transformer
|
|
40
|
+
|
|
41
|
+
This repository includes:
|
|
42
|
+
|
|
43
|
+
- cosmoglint, a package of Transformer-based models that generate galaxy properties from halo mass.
|
|
44
|
+
- Scripts for training and mock catalog generation.
|
|
45
|
+
- Example notebooks for result visualization.
|
|
46
|
+
|
|
47
|
+
Models trained with TNG300-1 at z = 0.5 - 6 and generated data are available at [Google Drive](https://drive.google.com/drive/folders/1IFje9tNRf4Dr3NufqzlDdGMFTEDpsm35?usp=share_link).
|
|
48
|
+
|
|
49
|
+
For detailed usage and options, see [DOCUMENTATION](./DOCUMENTATION.md).
|
|
50
|
+
|
|
51
|
+
---
|
|
52
|
+
|
|
53
|
+
## Installation
|
|
54
|
+
|
|
55
|
+
Python>=3.9 is required.
|
|
56
|
+
|
|
57
|
+
This package requires PyTorch.
|
|
58
|
+
Please install PyTorch first following https://pytorch.org
|
|
59
|
+
|
|
60
|
+
Install the package from a local clone:
|
|
61
|
+
|
|
62
|
+
```bash
|
|
63
|
+
git clone https://github.com/knmoriwaki/cosmoglint.git
|
|
64
|
+
cd cosmoglint
|
|
65
|
+
pip install .
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
If you only need the `cosmoglint` package (e.g., to import it in your own code), you can install it directly:
|
|
69
|
+
|
|
70
|
+
```bash
|
|
71
|
+
pip install git+https://github.com/knmoriwaki/cosmoglint.git
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
Several additional libraries need to be installed to run the scripts and notebooks:
|
|
76
|
+
```bash
|
|
77
|
+
pip install -r requirements.txt
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
## Training
|
|
81
|
+
|
|
82
|
+
Example:
|
|
83
|
+
```bash
|
|
84
|
+
python train_transformer.py --data_path [data_path] --norm_param_file [norm_param_file]
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
Options:
|
|
88
|
+
- `--data_path`: Path(s) to the training data. Data is an hdf5 file that contains properties of halos and galaxies. In addition to those for input and output features, the number of galaxies in each halo (`GroupNsubs`) should be provided. Multiple files can be passed.
|
|
89
|
+
- `--norm_param_file`: Path to the json file that specifies the normalization settings. Each key (e.g., `HaloMass`) maps to a dictionary with `min` / `max` and `norm`. If `norm` is `"log"` or `"log_with_sign"`, the `min` / `max` normalization is applied after the log conversion.
|
|
90
|
+
- `--input_features`: List of the input properties (default: `["GroupMass"]`)
|
|
91
|
+
- `--output_features`: List of the output properties (default: `["SubhaloSFR", "SubhaloDist", "SubhaloVrad", "SubhaloVtan"]`)
|
|
92
|
+
- `--max_length`: Maximum number of galaxies (sequence length) per halo (default: 30).
|
|
93
|
+
- `--use_flat_representation`: If true, use flattened point features (B, N * M). If false, keep (B, N, M). Use `--no-use_flat_representation` to set it to false (default: true).
|
|
94
|
+
|
|
95
|
+
## Create mock data cube
|
|
96
|
+
|
|
97
|
+
Example:
|
|
98
|
+
```bash
|
|
99
|
+
python create_data_cube.py --input_fname [input_fname] --model_dir [model_dir]
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
Options:
|
|
103
|
+
- `--input_fname`: Path to the halo catalog. Two formats are supported: a text file that contains halo mass [Msun] in log scale (1st column), comoving positions [Mpc/h] (2nd to 4th columns), and velocities [km/s] (5th to 8th columns); or a catalog in [Pinocchio](https://github.com/pigimonaco/Pinocchio) format.
|
|
104
|
+
- `--model_dir`: Path to a directory containing the trained model (`model.pth` and `args.json`). If not set, column 7 of the input file is used as intensity.
|
|
105
|
+
- `--boxsize`: Size of the simulation box in comoving units [Mpc/h] (default: 100.0).
|
|
106
|
+
- `--redshift_space`: If set, positions are converted to redshift space using halo velocities.
|
|
107
|
+
- `--gen_both`: If set, generates both real-space and redshift-space data cubes.
|
|
108
|
+
- `--npix`: Number of pixels in the x and y directions for the data cube (default: 100).
|
|
109
|
+
- `--npix_z`: Number of pixels in the z direction (default: 90).
|
|
110
|
+
|
|
111
|
+
## Create lightcone
|
|
112
|
+
|
|
113
|
+
Example:
|
|
114
|
+
```bash
|
|
115
|
+
python create_lightcone.py --input_fname [input_fname] --model_dir [model_dir] --model_config_file [model_config_file]
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
Example of `model_config_file`:
|
|
119
|
+
```json
|
|
120
|
+
{
|
|
121
|
+
"33": ["transformer1_33_ep40_bs512_w0.02", 2.002],
|
|
122
|
+
"21": ["transformer1_21_ep60_bs512_w0.02", 4.008]
|
|
123
|
+
}
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
- `--input_fname`: Path to the lightcone halo catalog. Pinocchio format is supported.
|
|
127
|
+
- `--output_fname`: Output filename (HDF5 format).
|
|
128
|
+
- `--model_dir`: Path to a directory containing the trained models.
|
|
129
|
+
- `--model_config_file`: Path to a JSON file that contains the names of the trained models to be used for each redshift bin. The JSON file is a dictionary where each key is a stringified snapshot ID, and the value is a list containing the model directory relative to `model_dir` and the redshift.
|
|
130
|
+
- `--redshift_space`: If set, generate output in redshift space.
|
|
131
|
+
- `--redshift_min`, `--redshift_max`: Redshift range for the lightcone.
|
|
132
|
+
- `--dz`: Redshift bin width. Indicates dlogz if `--use_logz` is given.
|
|
133
|
+
- `--use_logz`: Use dlogz instead of dz for redshift binning.
|
|
134
|
+
- `--side_length`, `--angular_resolution`: Angular size and resolution (arcsec) of the simulated map.
|
|
135
|
+
- `--gen_catalog`: If set, generate a galaxy catalog with SFR greater than --catalog_threshold.
|
|
136
|
+
- `--catalog_threshold`: SFR threshold for inclusion in the catalog.
|
|
137
|
+
|
|
138
|
+
## Visualization
|
|
139
|
+
|
|
140
|
+
Example Jupyter notebooks are available in the `notebooks/` directory:
|
|
141
|
+
|
|
142
|
+
- `plot_transformer.ipynb`: visualize training results
|
|
143
|
+
- `plot_data_cube.ipynb`: visualize created data cube
|
|
144
|
+
- `plot_lightcone.ipynb`: visualize lightcone data
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
## Citation
|
|
148
|
+
|
|
149
|
+
If you use CosmoGLINT in your research, please cite [Moriwaki et al. 2025](https://arxiv.org/abs/2506.16843)
|
|
150
|
+
|
|
151
|
+
```
|
|
152
|
+
@ARTICLE{CosmoGLINT,
|
|
153
|
+
title = {CosmoGLINT: Cosmological Generative Model for Line Intensity Mapping with Transformer},
|
|
154
|
+
author = {{Moriwaki}, Kana and {Jun}, Rui Lan and {Osato}, Ken and {Yoshida}, Naoki},
|
|
155
|
+
journal = {arXiv preprints},
|
|
156
|
+
year = 2025,
|
|
157
|
+
month = jun,
|
|
158
|
+
eid = {arXiv:2506.16843},
|
|
159
|
+
doi = {10.48550/arXiv.2506.16843},
|
|
160
|
+
archivePrefix = {arXiv},
|
|
161
|
+
eprint = {2506.16843},
|
|
162
|
+
primaryClass = {astro-ph.CO}
|
|
163
|
+
}
|
|
164
|
+
```
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
cosmoglint/__init__.py,sha256=Aj77VL1d5Mdku7sgCgKQmPuYavPpAHuZuJcy6bygQZE,21
|
|
2
|
+
cosmoglint/model/__init__.py,sha256=gmiP1R51lrK8dSjjAnLXHshGllnXa1SV-dhsSqqiqis,244
|
|
3
|
+
cosmoglint/model/transformer.py,sha256=woFgnk_uWKZEl6EIQX6PBU0ERtUKSk1W_2_sPjYujnM,21949
|
|
4
|
+
cosmoglint/model/transformer_nf.py,sha256=MXMYdj1KY1dZgmEzCeFxWp_QdqgHO7Sc3tDguoNG7P0,14459
|
|
5
|
+
cosmoglint/utils/ReadPinocchio5.py,sha256=17v3GFydE-Ha45TohwtUsU60xmGzfuZ1sNjPJEzUguI,39714
|
|
6
|
+
cosmoglint/utils/__init__.py,sha256=r19103JGuFfnKqPa0Oi0jmeAeKM5darGcTEY-_wLzpg,206
|
|
7
|
+
cosmoglint/utils/cosmology_utils.py,sha256=Ergv-REB-0LERfbXc0AnB4KJTLdUOCeVmze2u3H4KUE,7606
|
|
8
|
+
cosmoglint/utils/generation_utils.py,sha256=cnMVTjZdHaefKKBUpEFL7BSN2pSEyiQgyR48MdMqy9c,15301
|
|
9
|
+
cosmoglint/utils/io_utils.py,sha256=XUF7znHmGSZ8wQpOPk-KqNJvYN6-0rU8NE8mEtO1q2I,13068
|
|
10
|
+
cosmoglint-1.0.0.dist-info/licenses/LICENSE,sha256=xMgUlRtQRou9wDsNl4o8Op0w_VHlQmhY_uPwFCl7SWo,1070
|
|
11
|
+
cosmoglint-1.0.0.dist-info/METADATA,sha256=LJu3TSkCsnAqenZZWKKNYk0NNUVZGDIQWnJHBGO_kR0,7140
|
|
12
|
+
cosmoglint-1.0.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
13
|
+
cosmoglint-1.0.0.dist-info/top_level.txt,sha256=gKGGdFf41h3PsCSMtVYESsfUqfxNLsYXSCe3tQKiPxw,11
|
|
14
|
+
cosmoglint-1.0.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Kana Moriwaki
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
cosmoglint
|