cosmoglint 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosmoglint/__init__.py +1 -0
- cosmoglint/model/__init__.py +2 -0
- cosmoglint/model/transformer.py +500 -0
- cosmoglint/model/transformer_nf.py +368 -0
- cosmoglint/utils/ReadPinocchio5.py +1022 -0
- cosmoglint/utils/__init__.py +2 -0
- cosmoglint/utils/cosmology_utils.py +194 -0
- cosmoglint/utils/generation_utils.py +366 -0
- cosmoglint/utils/io_utils.py +397 -0
- cosmoglint-1.0.0.dist-info/METADATA +164 -0
- cosmoglint-1.0.0.dist-info/RECORD +14 -0
- cosmoglint-1.0.0.dist-info/WHEEL +5 -0
- cosmoglint-1.0.0.dist-info/licenses/LICENSE +21 -0
- cosmoglint-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
import warnings, os
|
|
11
|
+
|
|
12
|
+
import copy
|
|
13
|
+
|
|
14
|
+
def short_formatwarning(msg, category, filename, lineno, line=None):
    """Render a warning as one compact line for ``warnings.formatwarning``.

    Only the base name of ``filename`` is shown; ``line`` is accepted for
    signature compatibility and ignored.
    Output shape: ``<basename>:<lineno>: <CategoryName>: <msg>`` plus newline.
    """
    short_name = os.path.basename(filename)
    return "{}:{}: {}: {}\n".format(short_name, lineno, category.__name__, msg)
|
|
16
|
+
|
|
17
|
+
# NOTE(review): these two statements run at import time and modify the
# process-wide `warnings` configuration, affecting every module, not only this one.
# Print warnings as a single short line produced by short_formatwarning.
warnings.formatwarning = short_formatwarning
# Show every RuntimeWarning occurrence instead of only the first one per location.
warnings.filterwarnings("always", category=RuntimeWarning)
|
|
19
|
+
|
|
20
|
+
import astropy.units as u
|
|
21
|
+
from astropy.constants import c as cspeed # [m/s]
|
|
22
|
+
from astropy.cosmology import FlatLambdaCDM
|
|
23
|
+
# Default flat Lambda-CDM cosmology (H0 in km/s/Mpc, matter density Om0),
# used by the helpers below whenever no `cosmo` argument is passed.
cosmo_default = FlatLambdaCDM(H0=67.74, Om0=0.3089)
|
|
24
|
+
|
|
25
|
+
def cMpc_to_arcsec(l_cMpc, z, cosmo=cosmo_default, l_with_hlittle=False):
    """Convert a comoving transverse length at redshift ``z`` into an angle.

    l_cMpc: comoving length in [Mpc], or in [Mpc/h] when ``l_with_hlittle``.
    z: redshift (scalar or array) at which the length is measured.
    cosmo: astropy cosmology providing ``comoving_transverse_distance``.

    Returns the angle in [arcsec] as a bare number / ndarray.
    """
    # [Mpc/h] -> [Mpc] when the input carries the little-h factor
    length_mpc = l_cMpc / (cosmo.H0.value / 100) if l_with_hlittle else l_cMpc
    angle = (length_mpc * u.Mpc / cosmo.comoving_transverse_distance(z)) * u.radian
    return angle.to(u.arcsec).value
|
|
32
|
+
|
|
33
|
+
def arcsec_to_cMpc(l_arcsec, z, cosmo=cosmo_default, l_with_hlittle=False):
    """Convert an angle at redshift ``z`` into a comoving transverse length.

    l_arcsec: angle in [arcsec] (scalar or array).
    z: redshift at which the angle is measured.
    cosmo: astropy cosmology providing ``comoving_transverse_distance``.

    Returns the length in [Mpc], or in [Mpc/h] when ``l_with_hlittle``.
    """
    transverse_dist = cosmo.comoving_transverse_distance(z)
    length = (transverse_dist * (l_arcsec * u.arcsec / u.radian)).to(u.Mpc)
    # [Mpc] -> [Mpc/h]
    length = length * (cosmo.H0.value / 100) if l_with_hlittle else length
    return length.value
|
|
40
|
+
|
|
41
|
+
def dcMpc_to_dz(l_cMpc, z, cosmo=cosmo_default, l_with_hlittle=False):
    """Convert a small comoving line-of-sight separation into a redshift interval.

    Uses the local relation dz = dx / (c / H(z)).

    l_cMpc: separation in [Mpc], or in [Mpc/h] when ``l_with_hlittle``.
    z: redshift at which the separation is evaluated.
    """
    if l_with_hlittle:
        l_cMpc = l_cMpc / (cosmo.H0.value / 100)  # [Mpc/h] -> [Mpc]
    hubble_length_mpc = (cspeed / cosmo.H(z)).to(u.Mpc).value  # c / H(z) in [Mpc]
    return l_cMpc / hubble_length_mpc
|
|
48
|
+
|
|
49
|
+
def dz_to_dcMpc(dz, z, cosmo=cosmo_default, l_with_hlittle=False):
    """Convert a small redshift interval into a comoving line-of-sight separation.

    Uses the local relation dx = dz * c / H(z).

    Returns the separation in [Mpc], or in [Mpc/h] when ``l_with_hlittle``.
    """
    separation = dz * (cspeed / cosmo.H(z)).to(u.Mpc).value
    # [Mpc] -> [Mpc/h] when requested
    return separation * (cosmo.H0.value / 100) if l_with_hlittle else separation
|
|
56
|
+
|
|
57
|
+
def freq_to_comdis(nu_obs, nu_rest, cosmo=cosmo_default, l_with_hlittle=False):
    """Comoving distance at which a line emitted at ``nu_rest`` is seen at ``nu_obs``.

    The redshift is ``z = nu_rest / nu_obs - 1``; both frequencies must be
    given in the same units. Scalar or array input is accepted.

    Parameters
    ----------
    nu_obs : observed frequency (scalar or array).
    nu_rest : rest-frame frequency, same units as ``nu_obs``.
    cosmo : astropy cosmology providing ``comoving_distance``.
    l_with_hlittle : if True, return the distance in [Mpc/h] instead of [Mpc].

    Returns
    -------
    Comoving distance in [Mpc] (or [Mpc/h] if ``l_with_hlittle``).

    Raises
    ------
    ValueError
        If any implied redshift is negative, i.e. ``nu_obs > nu_rest``.
    """
    z = nu_rest / nu_obs - 1
    # Raise rather than sys.exit() so library callers can handle the error,
    # and use np.any so array input does not fail on an ambiguous truth value.
    if np.any(z < 0):
        raise ValueError("Error: z < 0 (nu_obs must not exceed nu_rest)")

    l_cMpc = cosmo.comoving_distance(z).to(u.Mpc).value
    if l_with_hlittle:
        hlittle = cosmo.H0.value / 100
        l_cMpc = l_cMpc * hlittle  # [Mpc] -> [Mpc/h]

    return l_cMpc  # [Mpc/h] if l_with_hlittle else [Mpc]
|
|
68
|
+
|
|
69
|
+
def z_to_log_lumi_dis(z, cosmo=cosmo_default):
    """Return log10 of the luminosity distance to redshift ``z``, distance in [cm]."""
    lumi_dist_cm = cosmo.luminosity_distance(z).to(u.cm).value
    return np.log10(lumi_dist_cm)
|
|
71
|
+
|
|
72
|
+
def populate_galaxies_in_lightcone(args, logm, pos, redshift, cosmo=cosmo_default):
    """Generate galaxies around lightcone haloes and place them on the sky.

    Each halo is assigned to a snapshot model by redshift. The snapshots and
    their redshifts are listed in ``args.model_config_file``, and the bin edges
    are the midpoints between consecutive snapshot redshifts. That model
    generates the galaxies hosted by the halo. Each galaxy is then offset from
    its host by the generated distance in a random direction.

    args: uses args.model_dir, args.model_config_file, args.param_dir and
        args.redshift_space; args.gpu_id and args.threshold are passed on to
        generate_galaxy.
    logm: (num_halos, ) halo property fed to the generator.
    pos: (num_halos, 3) host positions. Columns 0 and 1 are shifted by offsets
        in [arcsec]; column 2 is treated as redshift.
    redshift: (num_halos, ) host redshifts.
    cosmo: astropy cosmology used for the [Mpc/h] -> arcsec / redshift conversions.

    Returns
    -------
    sfr : (num_galaxies, ) column 0 of the generator output.
    pos_galaxies : (num_galaxies, 3) galaxy positions in the same convention as
        ``pos``; column 2 includes redshift-space distortion if
        ``args.redshift_space``.
    redshift_central_all : (num_galaxies, ) redshift of each galaxy's host.

    Raises
    ------
    ValueError
        If a selected model directory contains "Transformer_NF", or if there
        are no haloes to populate.
    """

    print("# Use Transformer to generate SFR")

    opt = copy.deepcopy(args)

    generated_all = []
    pos_central_all = []
    redshift_central_all = []
    flag_central_all = []

    with open(args.model_config_file, "r") as f:
        snapshot_dict_str = json.load(f)
    snapshot_dict = {int(k): v for k, v in snapshot_dict_str.items()}

    print("# Model config:", snapshot_dict)
    redshifts_of_snapshots = np.array([v[1] for v in snapshot_dict.values()])
    bin_edges = (redshifts_of_snapshots[:-1] + redshifts_of_snapshots[1:]) / 2.0
    bin_indices = np.digitize(redshift, bin_edges)

    if args.param_dir is None:
        max_sfr_file_list = [None for snapshot_number in snapshot_dict]
    else:
        max_sfr_file_list = ["{}/max_nbin20_{:d}.txt".format(args.param_dir, snapshot_number) for snapshot_number in snapshot_dict]

    for i, snapshot_number in enumerate(snapshot_dict):
        model_path, redshift_of_snapshot = snapshot_dict[snapshot_number]
        print("# Snapshot number: {:d}, Redshift: {:.2f}".format(snapshot_number, redshift_of_snapshot))

        ### Skip if no haloes in this redshift bin
        mask_z = (bin_indices == i)
        if not np.any(mask_z):
            print("# No haloes in redshift bin {:d} (snapshot number {:d}), skipping...".format(i, snapshot_number))
            continue

        logm_now = logm[mask_z]  # (num_halos_in_bin, )
        pos_now = pos[mask_z]  # (num_halos_in_bin, 3)
        redshift_now = redshift[mask_z]  # (num_halos_in_bin, )

        opt.model_dir = "{}/{}".format(args.model_dir, model_path)
        opt.max_sfr_file = max_sfr_file_list[i]
        opt.prob_threshold = 1e-5

        if "Transformer_NF" in opt.model_dir:
            raise ValueError("Transformer_NF model is not supported yet. Please use a different model.")
        else:
            from .generation_utils import generate_galaxy
            generated, mask = generate_galaxy(opt, logm_now)

        seq_length = mask.shape[1]
        num_features = generated.shape[-1]

        # The first slot of every halo's sequence is its central galaxy
        flag_central = np.zeros_like(mask, dtype=bool)
        flag_central[:, 0] = True

        # Flatten the arrays
        mask = mask.reshape(-1)
        generated = generated.reshape(-1, num_features)  # (num_halos * seq_length, num_features)
        pos_central = np.repeat(pos_now[:, None, :], seq_length, axis=1).reshape(-1, 3)  # (num_halos * seq_length, 3)
        redshift_central = np.repeat(redshift_now[:, None], seq_length, axis=1).reshape(-1)  # (num_halos * seq_length, )
        flag_central = flag_central.reshape(-1)

        # Apply mask to arrays
        generated = generated[mask]  # (num_galaxies_valid, num_features)
        pos_central = pos_central[mask]  # (num_galaxies_valid, 3)
        redshift_central = redshift_central[mask]  # (num_galaxies_valid, )
        flag_central = flag_central[mask]  # (num_galaxies_valid, )

        # Append
        generated_all.append(generated)
        pos_central_all.append(pos_central)
        redshift_central_all.append(redshift_central)
        flag_central_all.append(flag_central)

    if not generated_all:
        raise ValueError("No haloes to populate: every snapshot redshift bin is empty")

    generated_all = np.concatenate(generated_all, axis=0)  # (num_galaxies_valid, num_features)
    pos_central_all = np.concatenate(pos_central_all, axis=0)  # (num_galaxies_valid, 3)
    redshift_central_all = np.concatenate(redshift_central_all, axis=0)  # (num_galaxies_valid, )
    flag_central_all = np.concatenate(flag_central_all, axis=0)  # (num_galaxies_valid, )

    ### Distribute galaxies in lightcone
    sfr = generated_all[:, 0]
    distance = generated_all[:, 1]  # offset from the host in [Mpc/h]

    num_gal = len(sfr)

    # Determine positions of galaxies: isotropic random direction around the host
    print("# Generate positions of galaxies")
    _phi = np.random.uniform(0, 2 * np.pi, size=num_gal)
    _cos_theta = np.random.uniform(-1, 1, size=num_gal)
    _sin_theta = np.sqrt(1 - _cos_theta ** 2)

    # Convert [Mpc/h] offsets to [arcsec] (transverse) and redshift (line of sight)
    distance_arcsec = cMpc_to_arcsec(distance, redshift_central_all, cosmo=cosmo, l_with_hlittle=True)
    distance_z = dcMpc_to_dz(distance, redshift_central_all, cosmo=cosmo, l_with_hlittle=True)

    # Start from the hosts of ALL bins, not just the last one processed.
    pos_galaxies = pos_central_all
    pos_galaxies[:, 0] += distance_arcsec * _sin_theta * np.cos(_phi)
    pos_galaxies[:, 1] += distance_arcsec * _sin_theta * np.sin(_phi)
    pos_galaxies[:, 2] += distance_z * _cos_theta

    # Add redshift-space distortion
    if args.redshift_space:

        relative_vel_rad = generated_all[:, 2]
        relative_vel_tan = generated_all[:, 3]
        relative_vel_rad[flag_central_all] = 0  # Set vr to 0 for central galaxies
        alpha = np.random.uniform(0, 2 * np.pi, size=num_gal)
        vz_gal = - relative_vel_rad * _cos_theta + relative_vel_tan * _sin_theta * np.cos(alpha)

        # cspeed is astropy's c [m/s]; express it in [km/s] so beta is dimensionless
        beta = vz_gal / cspeed.to(u.km / u.s).value  # [(km/s) / (km/s)]

        redshift_rest = pos_galaxies[:, 2]
        pos_galaxies[:, 2] = (1. + redshift_rest) * np.sqrt((1. + beta) / (1. - beta)) - 1.0

    return sfr, pos_galaxies, redshift_central_all
|
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
import copy
|
|
6
|
+
import h5py
|
|
7
|
+
import re
|
|
8
|
+
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
|
|
15
|
+
# Speed of light in [cm/s] (rounded value), used by the redshift-space code below.
cspeed = 3e10 # [cm/s]
|
|
16
|
+
|
|
17
|
+
from .io_utils import normalize, namespace_to_dict
|
|
18
|
+
|
|
19
|
+
def create_mask(array, threshold):
    """
    Build a boolean mask of the galaxies to keep from a (num_halos, seq_length) array.

    An entry is rejected when either:
    - it is the central (column 0) and its value is not above threshold;
    - it is a satellite and it, or any earlier satellite of the same halo,
      is at or below threshold.
    """

    _, num_cols = array.shape

    above = array > threshold  # (num_halos, seq_length)

    # Satellites at or below the threshold truncate the rest of the sequence;
    # the central never truncates.
    sat_stop = array <= threshold
    sat_stop[:, 0] = False

    has_stop = sat_stop.any(axis=1)
    cutoff = np.where(has_stop, sat_stop.argmax(axis=1), num_cols)  # (num_halos, )
    before_cutoff = np.arange(num_cols)[np.newaxis, :] < cutoff[:, np.newaxis]

    return np.logical_and(before_cutoff, above)
|
|
39
|
+
|
|
40
|
+
def generate_galaxy(args, x_in, global_params=None, verbose=True):
    """
    Sample galaxy sequences for each halo with a trained Transformer.

    args: args.gpu_id, args.model_dir, args.threshold, and args.max_sfr_file are used
    x_in: (num_halos, num_features_in) halo properties, or (num_halos, ) when the
        model has a single input feature. The caller's array is not modified.
    global_params: optional; ``global_params[opt.global_features].to_numpy(...)``
        is called on it (presumably a pandas Series -- confirm with callers).
    verbose: print the loaded options and the model architecture.

    Returns
    -------
    generated : ndarray, (num_halos, seq_length, num_features)
        Generated properties, de-normalized with ``normalize(..., inverse=True)``.
    mask : ndarray of bool, (num_halos, seq_length)
        Entries kept by ``create_mask`` (evaluated on column 0 in normalized units).
    """

    print("# Use Transformer to generate SFR")

    from cosmoglint.model.transformer import transformer_model
    device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu")

    ### load Transformer
    with open("{}/args.json".format(args.model_dir), "r") as f:
        opt = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
    opt.norm_param_dict = namespace_to_dict(opt.norm_param_dict)

    model = transformer_model(opt)
    model.load_state_dict(torch.load("{}/model.pth".format(args.model_dir), map_location="cpu"))
    model.to(device)
    model.eval()

    if verbose:
        print("opt: ", opt)
        print(model)

    ### generate galaxies
    print("# Generate galaxies (batch size: {:d})".format(opt.batch_size))

    # Normalize a private copy so the caller's halo array is left untouched;
    # integer input is promoted so normalized values are not truncated.
    x_in = np.array(x_in, copy=True)
    if not np.issubdtype(x_in.dtype, np.floating):
        x_in = x_in.astype(np.float64)
    if x_in.ndim == 1 and len(opt.input_features) == 1:
        # One value per halo: x_in[..., 0] would address only the first halo.
        x_in = np.asarray(normalize(x_in, opt.input_features[0], opt.norm_param_dict))
    else:
        for i, key in enumerate(opt.input_features):
            x_in[..., i] = normalize(x_in[..., i], key, opt.norm_param_dict)
    x_in = torch.from_numpy(x_in).float().to(device)

    if global_params is not None:
        global_params = global_params[opt.global_features].to_numpy(dtype=np.float32)
        for i, key in enumerate(opt.global_features):
            global_params[..., i] = normalize(global_params[..., i], key, opt.norm_param_dict)
        global_params = torch.tensor(np.array(global_params), dtype=torch.float32).to(device)

    if args.max_sfr_file is None:
        print("# No max SFR file provided, using default max IDs")
        max_ids = None
    else:
        max_ids = np.loadtxt(args.max_sfr_file)
        max_ids = (max_ids * opt.num_features_out).astype(int)
        max_ids = torch.tensor(max_ids).to(device)  # (num_features, )

    num_batch = (len(x_in) + opt.batch_size - 1) // opt.batch_size
    # SFR threshold expressed in normalized units; used both to stop generation and to build the mask
    stop_criterion = normalize(args.threshold, opt.output_features[0], opt.norm_param_dict)
    generated = []
    for batch_idx in tqdm(range(num_batch)):
        start = batch_idx * opt.batch_size
        x_batch = x_in[start: start + opt.batch_size]  # (batch_size, num_features)
        global_cond_batch = global_params.unsqueeze(0).repeat(len(x_batch), 1) if global_params is not None else None  # (batch_size, num_global_features)
        with torch.no_grad():
            generated_batch, _ = model.generate(x_batch, global_cond=global_cond_batch, prob_threshold=1e-5, stop_criterion=stop_criterion, max_ids=max_ids)  # (batch_size, seq_length, num_features)

        generated.append(generated_batch.cpu().detach().numpy())

    generated = np.concatenate(generated, axis=0)  # (num_halos, seq_length, num_features) or (num_halos, seq_length * num_features, 1)

    if opt.use_flat_representation:
        # Restore one column per output feature; the mask is built below from
        # the reshaped array, so there is nothing else to reshape here.
        generated = generated.squeeze(-1).reshape(len(generated), -1, len(opt.output_features))  # (num_halos, max_length, num_features)

    mask = create_mask(generated[:, :, 0], stop_criterion)  # (num_halos, seq_length)

    # De-normalize
    for i, key in enumerate(opt.output_features):
        generated[..., i] = normalize(generated[..., i], key, opt.norm_param_dict, inverse=True)

    print("# Number of valid galaxies: {:d}".format(int(mask.sum())))

    return generated, mask
|
|
112
|
+
|
|
113
|
+
def generate_galaxy_TransNF(args, x_in, global_params=None, verbose=True):
    """
    Sample galaxy sequences for each halo with a Transformer + normalizing flow.

    args: args.gpu_id, args.model_dir, and args.threshold are used
    x_in: (num_halos, num_features_in) halo properties, or (num_halos, ) when the
        model has a single input feature. The caller's array is not modified.
    global_params: optional container indexed by each name in ``opt.global_features``.
    verbose: print the loaded options, model and flow.

    Returns
    -------
    generated : ndarray
        De-normalized samples; presumably (num_halos, max_length, num_features),
        as produced by ``generate_with_transformer_nf`` -- confirm there.
    mask : ndarray of bool, (num_halos, seq_length)
        Entries kept by ``create_mask`` applied to column 0 with ``args.threshold``.
    """

    print("# Use Transformer-NF to generate galaxies")

    from cosmoglint.model.transformer_nf import transformer_nf_model, generate_with_transformer_nf
    device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu")
    print("Using device: ", device)

    ### load Transformer
    with open("{}/args.json".format(args.model_dir), "r") as f:
        opt = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
    opt.norm_param_dict = namespace_to_dict(opt.norm_param_dict)

    model, flow = transformer_nf_model(opt)

    model.load_state_dict(torch.load("{}/model.pth".format(args.model_dir), map_location="cpu"))
    model.to(device)
    model.eval()

    flow.load_state_dict(torch.load("{}/flow.pth".format(args.model_dir), map_location="cpu"))
    flow.to(device)
    flow.eval()

    if verbose:
        print("opt: ", opt)
        print(model)
        print(flow)

    ### generate galaxies
    print("# Generate galaxies (batch size: {:d})".format(opt.batch_size))

    # Normalize a private copy so the caller's halo array is left untouched;
    # integer input is promoted so normalized values are not truncated.
    x_in = np.array(x_in, copy=True)
    if not np.issubdtype(x_in.dtype, np.floating):
        x_in = x_in.astype(np.float64)
    if x_in.ndim == 1 and len(opt.input_features) == 1:
        # One value per halo: x_in[..., 0] would address only the first halo.
        x_in = np.asarray(normalize(x_in, opt.input_features[0], opt.norm_param_dict))
    else:
        for i, key in enumerate(opt.input_features):
            x_in[..., i] = normalize(x_in[..., i], key, opt.norm_param_dict)
    x_in = torch.from_numpy(x_in).float().to(device)

    if global_params is not None:
        global_params = np.array([global_params[name] for name in opt.global_features], dtype=np.float32)
        for i, key in enumerate(opt.global_features):
            global_params[..., i] = normalize(global_params[..., i], key, opt.norm_param_dict)
        global_params = torch.from_numpy(global_params).float().to(device)

    num_batch = (len(x_in) + opt.batch_size - 1) // opt.batch_size
    generated = []

    def stop_criterion(sample):
        # sample: (batch, num_features); stop once column 0 is below 1 for every row
        return (sample[:, 0] < 1).all()

    for batch_idx in tqdm(range(num_batch)):
        start = batch_idx * opt.batch_size
        x_batch = x_in[start: start + opt.batch_size]
        global_cond_batch = global_params.unsqueeze(0).repeat(len(x_batch), 1) if global_params is not None else None
        generated_batch = generate_with_transformer_nf(model, flow, x_batch, global_cond=global_cond_batch, stop_criterion=stop_criterion)  # (batch_size, max_length, num_features)
        generated.append(generated_batch.cpu().detach().numpy())
    # The batches are NumPy arrays, so join them with NumPy (torch.cat rejects ndarrays).
    generated = np.concatenate(generated, axis=0)  # (num_halos, max_length, num_features)

    # De-normalize
    for i, key in enumerate(opt.output_features):
        generated[..., i] = normalize(generated[..., i], key, opt.norm_param_dict, inverse=True)

    # Set mask for selection (threshold applied in de-normalized units)
    sfr = generated[..., 0]
    mask = create_mask(sfr, args.threshold)  # (num_halos, seq_length)

    print("# Number of valid galaxies: {:d}".format(int(mask.sum())))

    return generated, mask
|
|
183
|
+
|
|
184
|
+
def populate_galaxies_in_cube(args, x_in, pos, vel, redshift, cosmo, global_params=None):
    """Generate galaxies for the haloes of a simulation box and place them in 3-D.

    A model directory containing "Transformer_NF" is sampled with
    generate_galaxy_TransNF, otherwise with generate_galaxy. Each kept galaxy is
    displaced from its host halo by the generated distance (column 1) in an
    isotropic random direction.

    x_in: halo properties passed to the generator.
    pos, vel: (num_halos, 3) host positions and velocities. The redshift-space
        shift below is computed in [Mpc/h] from velocities in [km/s], so
        ``pos`` is presumably comoving [Mpc/h] -- confirm with callers.
    redshift: redshift of the box; cosmo: astropy cosmology (used only for RSD).

    Returns (sfr, pos_galaxies_real, pos_galaxies): column 0 of the generator
    output, real-space positions, and positions with the line-of-sight (axis 2)
    shifted into redshift space when ``args.redshift_space`` (otherwise a copy of
    the real-space positions).
    """

    sampler = generate_galaxy_TransNF if "Transformer_NF" in args.model_dir else generate_galaxy
    generated, mask = sampler(args, x_in, global_params=global_params)

    num_halos, seq_length = mask.shape[0], mask.shape[1]
    num_gal = mask.sum()
    flat_mask = mask.reshape(-1)  # (num_halos * seq_length, )

    # For every kept (halo, slot) entry, recover its host halo and its slot;
    # slot 0 of each sequence is the central galaxy.
    host_index = np.repeat(np.arange(num_halos), seq_length)[flat_mask]
    slot_index = np.tile(np.arange(seq_length), num_halos)[flat_mask]
    is_central = slot_index == 0  # (num_galaxies_valid, )

    props = generated.reshape(-1, generated.shape[-1])[flat_mask]  # (num_galaxies_valid, num_features)
    host_pos = pos[host_index]  # (num_galaxies_valid, 3)
    host_vel = vel[host_index]  # (num_galaxies_valid, 3)

    # Distribute galaxies in cube
    print("# Generate positions of galaxies")

    sfr = props[:, 0]
    offset = props[:, 1]

    phi = np.random.uniform(0, 2 * np.pi, size=num_gal)
    cos_theta = np.random.uniform(-1, 1, size=num_gal)
    sin_theta = np.sqrt(1 - cos_theta ** 2)

    pos_galaxies = host_pos
    pos_galaxies[:, 0] += offset * sin_theta * np.cos(phi)
    pos_galaxies[:, 1] += offset * sin_theta * np.sin(phi)
    pos_galaxies[:, 2] += offset * cos_theta

    pos_galaxies_real = pos_galaxies.copy()

    if not args.redshift_space:
        return sfr, pos_galaxies_real, pos_galaxies

    # Redshift-space distortion along axis 2: s = x + v_z / (a H) [Mpc] * h -> [Mpc/h]
    import astropy.units as u
    hubble_unit = u.km / u.s / u.Mpc
    hubble_z = cosmo.H(redshift).to(hubble_unit).value  # [km/s/Mpc]
    hlittle = cosmo.H(0).to(hubble_unit).value / 100.0
    scale_factor = 1 / (1 + redshift)

    v_rad = props[:, 2]
    v_tan = props[:, 3]
    v_rad[is_central] = 0  # centrals get no radial velocity relative to the host
    alpha = np.random.uniform(0, 2 * np.pi, size=num_gal)
    vz_gal = - v_rad * cos_theta + v_tan * sin_theta * np.cos(alpha)
    pos_galaxies[:, 2] += (host_vel[:, 2] + vz_gal) / scale_factor / hubble_z * hlittle

    return sfr, pos_galaxies_real, pos_galaxies
|
|
244
|
+
|
|
245
|
+
def populate_galaxies_in_lightcone(args, x_in, pos, redshift, cosmo, global_params=None):
    """
    Generate galaxies around lightcone haloes and place them on the sky.

    Each halo is assigned to a snapshot model by redshift. The snapshots and
    their redshifts are listed in ``args.model_config_file``, and the bin edges
    are the midpoints between consecutive snapshot redshifts. The matching model
    (Transformer or Transformer-NF) generates the halo's galaxies. Each galaxy is
    then offset from its host by the generated distance in a random direction.

    args: args.gpu_id, args.model_dir, args.model_config_file, args.threshold,
        args.param_dir and args.redshift_space are used
    x_in: halo properties, (num_halos, ) or (num_halos, num_features_in)
    pos: (num_halos, 3); columns 0 and 1 are shifted by offsets in [arcsec],
        column 2 is treated as redshift
    redshift: (num_halos, )

    Returns
    -------
    sfr : (num_galaxies, ) column 0 of the generator output.
    pos_galaxies : (num_galaxies, 3) galaxy positions in the same convention as
        ``pos``; column 2 includes redshift-space distortion if ``args.redshift_space``.
    redshift_central_all : (num_galaxies, ) redshift of each galaxy's host.

    Raises
    ------
    ValueError
        If there are no haloes to populate.
    """

    print("# Use Transformer to generate SFR")

    opt = copy.deepcopy(args)

    generated_all = []
    pos_central_all = []
    redshift_central_all = []
    flag_central_all = []

    with open(args.model_config_file, "r") as f:
        snapshot_dict_str = json.load(f)
    snapshot_dict = {int(k): v for k, v in snapshot_dict_str.items()}

    print("# Model config:", snapshot_dict)
    redshifts_of_snapshots = np.array([v[1] for v in snapshot_dict.values()])
    bin_edges = (redshifts_of_snapshots[:-1] + redshifts_of_snapshots[1:]) / 2.0
    bin_indices = np.digitize(redshift, bin_edges)

    if args.param_dir is None:
        max_sfr_file_list = [None for snapshot_number in snapshot_dict]
    else:
        max_sfr_file_list = ["{}/max_nbin20_{:d}.txt".format(args.param_dir, snapshot_number) for snapshot_number in snapshot_dict]

    for i, snapshot_number in enumerate(snapshot_dict):
        model_path, redshift_of_snapshot = snapshot_dict[snapshot_number]
        print("# Snapshot number: {:d}, Redshift: {:.2f}".format(snapshot_number, redshift_of_snapshot))

        ### Skip if no haloes in this redshift bin
        mask_z = (bin_indices == i)
        if not np.any(mask_z):
            print("# No haloes in redshift bin {:d} (snapshot number {:d}), skipping...".format(i, snapshot_number))
            continue

        x_now = x_in[mask_z]  # (num_halos_in_bin, ...)
        pos_now = pos[mask_z]  # (num_halos_in_bin, 3)
        redshift_now = redshift[mask_z]  # (num_halos_in_bin, )

        opt.model_dir = "{}/{}".format(args.model_dir, model_path)
        opt.max_sfr_file = max_sfr_file_list[i]

        if "Transformer_NF" in opt.model_dir:
            generated, mask = generate_galaxy_TransNF(opt, x_now, global_params=global_params, verbose=False)
        else:
            generated, mask = generate_galaxy(opt, x_now, global_params=global_params, verbose=False)

        seq_length = mask.shape[1]
        num_features = generated.shape[-1]

        # The first slot of every halo's sequence is its central galaxy
        flag_central = np.zeros_like(mask, dtype=bool)
        flag_central[:, 0] = True

        # Flatten the arrays
        mask = mask.reshape(-1)
        generated = generated.reshape(-1, num_features)  # (num_halos * seq_length, num_features)
        pos_central = np.repeat(pos_now[:, None, :], seq_length, axis=1).reshape(-1, 3)  # (num_halos * seq_length, 3)
        redshift_central = np.repeat(redshift_now[:, None], seq_length, axis=1).reshape(-1)  # (num_halos * seq_length, )
        flag_central = flag_central.reshape(-1)

        # Apply mask to arrays
        generated = generated[mask]  # (num_galaxies_valid, num_features)
        pos_central = pos_central[mask]  # (num_galaxies_valid, 3)
        redshift_central = redshift_central[mask]  # (num_galaxies_valid, )
        flag_central = flag_central[mask]  # (num_galaxies_valid, )

        # Append
        generated_all.append(generated)
        pos_central_all.append(pos_central)
        redshift_central_all.append(redshift_central)
        flag_central_all.append(flag_central)

    if not generated_all:
        raise ValueError("No haloes to populate: every snapshot redshift bin is empty")

    generated_all = np.concatenate(generated_all, axis=0)  # (num_galaxies_valid, num_features)
    pos_central_all = np.concatenate(pos_central_all, axis=0)  # (num_galaxies_valid, 3)
    redshift_central_all = np.concatenate(redshift_central_all, axis=0)  # (num_galaxies_valid, )
    flag_central_all = np.concatenate(flag_central_all, axis=0)  # (num_galaxies_valid, )

    ### Distribute galaxies in lightcone
    sfr = generated_all[:, 0]
    distance = generated_all[:, 1]  # offset from the host in [Mpc/h]

    num_gal = len(sfr)

    # Determine positions of galaxies: isotropic random direction around the host
    print("# Generate positions of galaxies")
    _phi = np.random.uniform(0, 2 * np.pi, size=num_gal)
    _cos_theta = np.random.uniform(-1, 1, size=num_gal)
    _sin_theta = np.sqrt(1 - _cos_theta ** 2)

    # Convert [Mpc/h] offsets to [arcsec] (transverse) and redshift (line of sight)
    from .cosmology_utils import cMpc_to_arcsec, dcMpc_to_dz
    distance_arcsec = cMpc_to_arcsec(distance, redshift_central_all, cosmo=cosmo, l_with_hlittle=True)
    distance_z = dcMpc_to_dz(distance, redshift_central_all, cosmo=cosmo, l_with_hlittle=True)

    pos_galaxies = pos_central_all
    pos_galaxies[:, 0] += distance_arcsec * _sin_theta * np.cos(_phi)
    pos_galaxies[:, 1] += distance_arcsec * _sin_theta * np.sin(_phi)
    pos_galaxies[:, 2] += distance_z * _cos_theta

    # Add redshift-space distortion
    if args.redshift_space:

        relative_vel_rad = generated_all[:, 2]
        relative_vel_tan = generated_all[:, 3]
        relative_vel_rad[flag_central_all] = 0  # Set vr to 0 for central galaxies
        alpha = np.random.uniform(0, 2 * np.pi, size=num_gal)
        vz_gal = - relative_vel_rad * _cos_theta + relative_vel_tan * _sin_theta * np.cos(alpha)

        # cspeed is in [cm/s]; convert to [km/s] so beta = v / c is dimensionless
        beta = vz_gal / (cspeed * 1e-5)  # [(km/s) / (km/s)]

        redshift_rest = pos_galaxies[:, 2]
        pos_galaxies[:, 2] = (1. + redshift_rest) * np.sqrt((1. + beta) / (1. - beta)) - 1.0

    return sfr, pos_galaxies, redshift_central_all
|