cosmoglint 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2 @@
1
+ from .generation_utils import generate_galaxy, generate_galaxy_TransNF, populate_galaxies_in_lightcone
2
+ from .io_utils import MyDataset, load_global_params, load_halo_data, load_lightcone_data, my_save_model
@@ -0,0 +1,194 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import json
5
+
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import torch
9
+
10
+ import warnings, os
11
+
12
+ import copy
13
+
14
def short_formatwarning(msg, category, filename, lineno, line=None):
    """
    Format a warning as one ``file:lineno: Category: message`` line.

    Only the base name of filename is shown; the line argument is accepted
    for signature compatibility with warnings.formatwarning and is ignored.
    """
    source_name = os.path.basename(filename)
    category_name = category.__name__
    return f"{source_name}:{lineno}: {category_name}: {msg}\n"
16
+
17
+ warnings.formatwarning = short_formatwarning
18
+ warnings.filterwarnings("always", category=RuntimeWarning)
19
+
20
+ import astropy.units as u
21
+ from astropy.constants import c as cspeed # [m/s]
22
+ from astropy.cosmology import FlatLambdaCDM
23
+ cosmo_default = FlatLambdaCDM(H0=67.74, Om0=0.3089)
24
+
25
def cMpc_to_arcsec(l_cMpc, z, cosmo=cosmo_default, l_with_hlittle=False):
    """
    Convert a transverse comoving length at redshift z into an angle.

    l_cMpc: length in [Mpc], or [Mpc/h] when l_with_hlittle is True
    z: redshift handed to cosmo.comoving_transverse_distance
    cosmo: cosmology object providing H0 and comoving_transverse_distance
    l_with_hlittle: if True, the input is first divided by h = H0 / 100

    Returns the bare numerical value of the angle in [arcsec].
    """
    # [Mpc/h] -> [Mpc] when requested
    length_mpc = l_cMpc / (cosmo.H0.value / 100) if l_with_hlittle else l_cMpc

    transverse_distance = cosmo.comoving_transverse_distance(z)
    angle_rad = length_mpc * u.Mpc / transverse_distance
    return (angle_rad * u.radian).to(u.arcsec).value
32
+
33
def arcsec_to_cMpc(l_arcsec, z, cosmo=cosmo_default, l_with_hlittle=False):
    """
    Convert an angle at redshift z into a transverse comoving length.

    l_arcsec: angle in [arcsec]
    z: redshift handed to cosmo.comoving_transverse_distance
    cosmo: cosmology object providing H0 and comoving_transverse_distance
    l_with_hlittle: if True, the result is multiplied by h = H0 / 100

    Returns the bare numerical value in [Mpc/h] if l_with_hlittle else [Mpc].
    """
    angle = l_arcsec * u.arcsec / u.radian
    length = (cosmo.comoving_transverse_distance(z) * angle).to(u.Mpc)
    if not l_with_hlittle:
        return length.value

    # [Mpc] -> [Mpc/h]
    return (length * (cosmo.H0.value / 100)).value
40
+
41
def dcMpc_to_dz(l_cMpc, z, cosmo=cosmo_default, l_with_hlittle=False):
    """
    Convert a line-of-sight comoving length at redshift z into a redshift interval.

    l_cMpc: length in [Mpc], or [Mpc/h] when l_with_hlittle is True
    z: redshift at which c / H(z) is evaluated
    cosmo: cosmology object providing H0 and H(z)
    l_with_hlittle: if True, the input is first divided by h = H0 / 100

    Returns l_cMpc divided by c / H(z) expressed in [Mpc].
    """
    # comoving length per unit redshift, in [Mpc]
    mpc_per_unit_z = (cspeed / cosmo.H(z)).to(u.Mpc).value

    if l_with_hlittle:
        # [Mpc/h] -> [Mpc]
        l_cMpc = l_cMpc / (cosmo.H0.value / 100)

    return l_cMpc / mpc_per_unit_z
48
+
49
def dz_to_dcMpc(dz, z, cosmo=cosmo_default, l_with_hlittle=False):
    """
    Convert a redshift interval at redshift z into a line-of-sight comoving length.

    dz: redshift interval
    z: redshift at which c / H(z) is evaluated
    cosmo: cosmology object providing H0 and H(z)
    l_with_hlittle: if True, the result is multiplied by h = H0 / 100

    Returns the length in [Mpc/h] if l_with_hlittle else [Mpc].
    """
    # comoving length per unit redshift, in [Mpc]
    mpc_per_unit_z = (cspeed / cosmo.H(z)).to(u.Mpc).value
    length = dz * mpc_per_unit_z

    if not l_with_hlittle:
        return length

    # [Mpc] -> [Mpc/h]
    return length * (cosmo.H0.value / 100)
56
+
57
def freq_to_comdis(nu_obs, nu_rest, cosmo=cosmo_default, l_with_hlittle=False):
    """
    Comoving distance to the redshift at which nu_rest is observed at nu_obs.

    nu_obs: observed frequency (scalar or array), same unit as nu_rest
    nu_rest: rest-frame frequency
    cosmo: cosmology object providing H0 and comoving_distance
    l_with_hlittle: if True, the result is multiplied by h = H0 / 100

    Returns the distance in [Mpc/h] if l_with_hlittle else [Mpc].
    Raises ValueError if any implied redshift is negative.
    """
    z = nu_rest / nu_obs - 1

    # np.any keeps the check valid for array input, where a bare `if z < 0`
    # is ambiguous; raising lets the caller decide instead of killing the
    # interpreter from inside a library function.
    if np.any(z < 0):
        raise ValueError("Error: z < 0")

    l_cMpc = cosmo.comoving_distance(z).to(u.Mpc).value
    if l_with_hlittle:
        hlittle = cosmo.H0.value / 100
        l_cMpc = l_cMpc * hlittle  # [Mpc] -> [Mpc/h]

    return l_cMpc  # [Mpc/h] if l_with_hlittle else [Mpc]
68
+
69
def z_to_log_lumi_dis(z, cosmo=cosmo_default):
    """
    Return log10 of the luminosity distance at redshift z, with the distance in [cm].
    """
    lumi_dis_cm = cosmo.luminosity_distance(z).to(u.cm).value
    return np.log10(lumi_dis_cm)
71
+
72
def populate_galaxies_in_lightcone(args, logm, pos, redshift, cosmo=cosmo_default):
    """
    Generate galaxies for lightcone haloes and scatter them around their hosts.

    args: args.model_dir, args.model_config_file, args.param_dir and
          args.redshift_space are read here; a deep copy with model_dir,
          max_sfr_file and prob_threshold set is handed to generate_galaxy
          (which documents args.gpu_id and args.threshold as its own inputs)
    logm: (num_halos, )
    pos: (num_halos, 3); columns 0 and 1 receive offsets in arcsec and
         column 2 an offset in redshift, so they are presumably angular
         coordinates and redshift -- TODO confirm against the caller
    redshift: (num_halos, )
    cosmo: cosmology used for the length -> angle / redshift conversions

    Returns (sfr, pos_galaxies, redshift_central_all) for the valid galaxies.
    Raises ValueError for model directories containing "Transformer_NF".
    """

    print("# Use Transformer to generate SFR")

    opt = copy.deepcopy(args)

    generated_all = []
    pos_central_all = []
    redshift_central_all = []
    flag_central_all = []

    with open(args.model_config_file, "r") as f:
        snapshot_dict_str = json.load(f)
    snapshot_dict = {int(k): v for k, v in snapshot_dict_str.items()}

    print("# Model config:", snapshot_dict)

    # Each halo is assigned to the snapshot whose redshift is nearest:
    # bin edges are the midpoints between consecutive snapshot redshifts.
    redshifts_of_snapshots = np.array([v[1] for v in snapshot_dict.values()])
    bin_edges = (redshifts_of_snapshots[:-1] + redshifts_of_snapshots[1:]) / 2.0
    bin_indices = np.digitize(redshift, bin_edges)

    if args.param_dir is None:
        max_sfr_file_list = [None for snapshot_number in snapshot_dict]
    else:
        max_sfr_file_list = ["{}/max_nbin20_{:d}.txt".format(args.param_dir, snapshot_number) for snapshot_number in snapshot_dict]

    for i, snapshot_number in enumerate(snapshot_dict):
        model_path, redshift_of_snapshot = snapshot_dict[snapshot_number]
        print("# Snapshot number: {:d}, Redshift: {:.2f}".format(snapshot_number, redshift_of_snapshot))

        ### Skip if no haloes in this redshift bin
        mask_z = (bin_indices == i)
        if not np.any(mask_z):
            print("# No haloes in redshift bin {:d} (snapshot number {:d}), skipping...".format(i, snapshot_number))
            continue

        logm_now = logm[mask_z]          # (num_halos_in_bin, )
        pos_now = pos[mask_z]            # (num_halos_in_bin, 3)
        redshift_now = redshift[mask_z]  # (num_halos_in_bin, )

        opt.model_dir = "{}/{}".format(args.model_dir, model_path)
        opt.max_sfr_file = max_sfr_file_list[i]
        opt.prob_threshold = 1e-5

        if "Transformer_NF" in opt.model_dir:
            raise ValueError("Transformer_NF model is not supported yet. Please use a different model.")

        from .generation_utils import generate_galaxy
        generated, mask = generate_galaxy(opt, logm_now)

        seq_length = mask.shape[1]
        num_features = generated.shape[-1]

        # Define flag_central: the first entry of every sequence is the central
        flag_central = np.zeros_like(mask, dtype=bool)
        flag_central[:, 0] = True

        # Flatten the arrays
        mask = mask.reshape(-1)
        generated = generated.reshape(-1, num_features)  # (num_halos * seq_length, num_features)
        pos_central = np.repeat(pos_now[:, None, :], seq_length, axis=1).reshape(-1, 3)  # (num_halos * seq_length, 3)
        redshift_central = np.repeat(redshift_now[:, None], seq_length, axis=1).reshape(-1)  # (num_halos * seq_length, )
        flag_central = flag_central.reshape(-1)

        # Apply mask and append
        generated_all.append(generated[mask])                # (num_galaxies_valid, num_features)
        pos_central_all.append(pos_central[mask])            # (num_galaxies_valid, 3)
        redshift_central_all.append(redshift_central[mask])  # (num_galaxies_valid, )
        flag_central_all.append(flag_central[mask])          # (num_galaxies_valid, )

    if not generated_all:
        # No snapshot produced anything (e.g. no input haloes);
        # np.concatenate would raise on an empty list.
        return np.empty(0), np.empty((0, 3)), np.empty(0)

    generated_all = np.concatenate(generated_all, axis=0)  # (num_galaxies_valid, num_features)
    pos_central_all = np.concatenate(pos_central_all, axis=0)
    redshift_central_all = np.concatenate(redshift_central_all, axis=0)  # (num_galaxies_valid,)
    flag_central_all = np.concatenate(flag_central_all, axis=0)  # (num_galaxies_valid,)

    ### Distributes galaxies in lightcone
    sfr = generated_all[:, 0]
    distance = generated_all[:, 1]

    num_gal = len(sfr)

    # Determine positions of galaxies: isotropic direction around the host
    print("# Generate positions of galaxies")
    _phi = np.random.uniform(0, 2 * np.pi, size=num_gal)
    _cos_theta = np.random.uniform(-1, 1, size=num_gal)
    _sin_theta = np.sqrt(1 - _cos_theta ** 2)

    # Convert the comoving offset to an angle [arcsec] and a redshift interval
    distance_arcsec = cMpc_to_arcsec(distance, redshift_central_all, cosmo=cosmo, l_with_hlittle=True)
    distance_z = dcMpc_to_dz(distance, redshift_central_all, cosmo=cosmo, l_with_hlittle=True)

    # Use the concatenated host positions; the per-snapshot `pos_central`
    # only holds the last processed snapshot.
    pos_galaxies = pos_central_all
    pos_galaxies[:, 0] += distance_arcsec * _sin_theta * np.cos(_phi)
    pos_galaxies[:, 1] += distance_arcsec * _sin_theta * np.sin(_phi)
    pos_galaxies[:, 2] += distance_z * _cos_theta

    # Add redshift-space distortion
    if args.redshift_space:

        relative_vel_rad = generated_all[:, 2]
        relative_vel_tan = generated_all[:, 3]
        relative_vel_rad[flag_central_all] = 0  # Set vr to 0 for central galaxies
        alpha = np.random.uniform(0, 2 * np.pi, size=num_gal)
        vz_gal = - relative_vel_rad * _cos_theta + relative_vel_tan * _sin_theta * np.cos(alpha)

        # cspeed is an astropy Quantity in [m/s]; take its plain value in
        # [km/s] so that beta is a dimensionless ndarray.
        beta = vz_gal / cspeed.to(u.km / u.s).value  # [(km/s) / (km/s)]

        redshift_rest = pos_galaxies[:, 2]
        pos_galaxies[:, 2] = (1. + redshift_rest) * np.sqrt((1. + beta) / (1. - beta)) - 1.0

    return sfr, pos_galaxies, redshift_central_all
@@ -0,0 +1,366 @@
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import json
5
+ import copy
6
+ import h5py
7
+ import re
8
+
9
+ from tqdm import tqdm
10
+
11
+ import numpy as np
12
+
13
+ import torch
14
+
15
+ cspeed = 3e10 # [cm/s]
16
+
17
+ from .io_utils import normalize, namespace_to_dict
18
+
19
def create_mask(array, threshold):
    """
    Build the boolean selection mask for generated sequences.

    An entry is kept only if its value is above threshold and no entry in
    columns 1..j of the same row (itself included) is at or below threshold.
    Column 0 (the central) is rejected when it is not above threshold, but it
    never cuts off the entries that follow it.

    array: (num_halos, max_length)
    Returns a bool array of the same shape.
    """
    num_halos, max_length = array.shape  # unpacking requires a 2-D input

    above = array > threshold      # (num_halos, max_length)
    blocking = array <= threshold  # (num_halos, max_length)
    blocking[:, 0] = False         # the central never truncates the sequence

    # True from the first blocking satellite onwards along each row
    truncated = np.logical_or.accumulate(blocking, axis=1)

    return above & ~truncated
39
+
40
def generate_galaxy(args, x_in, global_params=None, verbose=True):
    """
    Generate galaxy sequences for a set of haloes with the Transformer model.

    args: args.gpu_id, args.model_dir, args.threshold, and args.max_sfr_file are used;
          args.prob_threshold is used when present (default 1e-5)
    x_in: (num_halos, num_features_in); halo properties. The input is copied,
          so the caller's array is not normalized in place.
    global_params: optional mapping indexed by opt.global_features and
          exposing .to_numpy (presumably a pandas Series -- TODO confirm)
    verbose: if True, print the loaded options and the model

    Returns (generated, mask):
      generated: (num_halos, seq_length, num_features), de-normalized
      mask: (num_halos, seq_length) bool, True for valid galaxies
    """

    print("# Use Transformer to generate SFR")

    from cosmoglint.model.transformer import transformer_model
    device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu")

    ### load Transformer
    with open("{}/args.json".format(args.model_dir), "r") as f:
        opt = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
    opt.norm_param_dict = namespace_to_dict(opt.norm_param_dict)

    model = transformer_model(opt)
    model.load_state_dict(torch.load("{}/model.pth".format(args.model_dir), map_location="cpu"))
    model.to(device)
    model.eval()

    if verbose:
        print("opt: ", opt)
        print(model)

    ### generate galaxies
    print("# Generate galaxies (batch size: {:d})".format(opt.batch_size))

    # Work on a copy: normalizing in place would silently modify the
    # caller's halo catalogue.
    x_in = np.array(x_in, copy=True)
    for i, key in enumerate(opt.input_features):
        x_in[..., i] = normalize(x_in[..., i], key, opt.norm_param_dict)
    x_in = torch.from_numpy(x_in).float().to(device)

    if global_params is not None:
        global_params = global_params[opt.global_features].to_numpy(dtype=np.float32)
        for i, key in enumerate(opt.global_features):
            global_params[..., i] = normalize(global_params[..., i], key, opt.norm_param_dict)
        global_params = torch.tensor(np.array(global_params), dtype=torch.float32).to(device)

    if args.max_sfr_file is None:
        print("# No max SFR file provided, using default max IDs")
        max_ids = None
    else:
        max_ids = np.loadtxt(args.max_sfr_file)
        max_ids = (max_ids * opt.num_features_out).astype(int)
        max_ids = torch.tensor(max_ids).to(device)  # (num_features, )

    # Honour a caller-supplied probability threshold, falling back to the
    # value that used to be hard-coded.
    prob_threshold = getattr(args, "prob_threshold", 1e-5)

    num_batch = (len(x_in) + opt.batch_size - 1) // opt.batch_size
    stop_criterion = normalize(args.threshold, opt.output_features[0], opt.norm_param_dict)  # stop criterion for SFR
    generated = []
    for batch_idx in tqdm(range(num_batch)):
        start = batch_idx * opt.batch_size
        x_batch = x_in[start: start + opt.batch_size]  # (batch_size, num_features)
        global_cond_batch = global_params.unsqueeze(0).repeat(len(x_batch), 1) if global_params is not None else None  # (batch_size, num_global_features)
        with torch.no_grad():
            generated_batch, _ = model.generate(x_batch, global_cond=global_cond_batch, prob_threshold=prob_threshold, stop_criterion=stop_criterion, max_ids=max_ids)  # (batch_size, seq_length, num_features)

        generated.append(generated_batch.cpu().detach().numpy())

    generated = np.concatenate(generated, axis=0)  # (num_halos, seq_length, num_features) or (num_halos, seq_length * num_features, 1)

    if opt.use_flat_representation:
        # NOTE(review): the reshape uses opt.num_features_in as in the
        # original code; confirm this should not be the number of output
        # features.
        generated = generated.squeeze(-1).reshape(len(generated), -1, opt.num_features_in)  # (num_halos, max_length, num_features)

    # The mask is derived from the (still normalized) first output feature.
    mask = create_mask(generated[:, :, 0], stop_criterion)  # (num_halos, seq_length)

    # De-normalize
    for i, key in enumerate(opt.output_features):
        generated[..., i] = normalize(generated[..., i], key, opt.norm_param_dict, inverse=True)

    print("# Number of valid galaxies: {:d}".format(int(mask.sum())))

    return generated, mask
112
+
113
def generate_galaxy_TransNF(args, x_in, global_params=None, verbose=True):
    """
    Generate galaxy sequences for a set of haloes with the Transformer-NF model.

    args: args.gpu_id, args.model_dir, and args.threshold are used
    x_in: (num_halos, num_features_in), halo properties. The input is copied,
          so the caller's array is not normalized in place.
    global_params: optional mapping indexed by the names in opt.global_features
    verbose: if True, print the loaded options, the model and the flow

    Returns (generated, mask):
      generated: (num_halos, max_length, num_features) NumPy array, de-normalized
      mask: (num_halos, max_length) bool, True for valid galaxies
    """

    print("# Use Transformer-NF to generate galaxies")

    from cosmoglint.model.transformer_nf import transformer_nf_model, generate_with_transformer_nf
    device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else "cpu")
    print("Using device: ", device)

    ### load Transformer
    with open("{}/args.json".format(args.model_dir), "r") as f:
        opt = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
    opt.norm_param_dict = namespace_to_dict(opt.norm_param_dict)

    model, flow = transformer_nf_model(opt)

    model.load_state_dict(torch.load("{}/model.pth".format(args.model_dir), map_location="cpu"))
    model.to(device)
    model.eval()

    flow.load_state_dict(torch.load("{}/flow.pth".format(args.model_dir), map_location="cpu"))
    flow.to(device)
    flow.eval()

    if verbose:
        print("opt: ", opt)
        print(model)
        print(flow)

    ### generate galaxies
    print("# Generate galaxies (batch size: {:d})".format(opt.batch_size))

    # Work on a copy: normalizing in place would silently modify the
    # caller's halo catalogue.
    x_in = np.array(x_in, copy=True)
    for i, key in enumerate(opt.input_features):
        x_in[..., i] = normalize(x_in[..., i], key, opt.norm_param_dict)
    x_in = torch.from_numpy(x_in).float().to(device)

    if global_params is not None:
        global_params = np.array([global_params[name] for name in opt.global_features], dtype=np.float32)
        for i, key in enumerate(opt.global_features):
            global_params[..., i] = normalize(global_params[..., i], key, opt.norm_param_dict)
        global_params = torch.from_numpy(global_params).float().to(device)

    num_batch = (len(x_in) + opt.batch_size - 1) // opt.batch_size
    generated = []

    def stop_criterion(sample):
        # sample: (batch, num_features)
        return (sample[:, 0] < 1).all()

    for batch_idx in tqdm(range(num_batch)):
        start = batch_idx * opt.batch_size
        x_batch = x_in[start: start + opt.batch_size]  # (batch_size, num_features_in)
        global_cond_batch = global_params.unsqueeze(0).repeat(len(x_batch), 1) if global_params is not None else None
        generated_batch = generate_with_transformer_nf(model, flow, x_batch, global_cond=global_cond_batch, stop_criterion=stop_criterion)  # (batch_size, max_length, num_features)
        generated.append(generated_batch.cpu().detach().numpy())

    # The batches were converted to NumPy above, so they must be joined with
    # np.concatenate; torch.cat rejects ndarrays.
    generated = np.concatenate(generated, axis=0)  # (num_halos, max_length, num_features)

    # De-normalize
    for i, key in enumerate(opt.output_features):
        generated[..., i] = normalize(generated[..., i], key, opt.norm_param_dict, inverse=True)

    # Set mask for selection (on the de-normalized first output feature)
    sfr = generated[..., 0]
    mask = create_mask(sfr, args.threshold)  # (num_halos, seq_length)

    print("# Number of valid galaxies: {:d}".format(int(mask.sum())))

    return generated, mask
183
+
184
def populate_galaxies_in_cube(args, x_in, pos, vel, redshift, cosmo, global_params=None):
    """
    Generate galaxies for the haloes of a simulation cube and place them around their hosts.

    args: args.model_dir and args.redshift_space are read here; args is also
          forwarded to generate_galaxy / generate_galaxy_TransNF
    x_in: halo properties forwarded to the generator
    pos: (num_halos, 3) halo positions
    vel: (num_halos, 3) halo velocities
    redshift: redshift at which H(z) and the scale factor are evaluated
    cosmo: cosmology object providing H(z)
    global_params: forwarded to the generator

    Returns (sfr, pos_galaxies_real, pos_galaxies); the last array has the
    redshift-space shift applied to column 2 when args.redshift_space is set,
    otherwise it equals pos_galaxies_real.
    """

    use_nf = "Transformer_NF" in args.model_dir
    generator = generate_galaxy_TransNF if use_nf else generate_galaxy
    generated, mask = generator(args, x_in, global_params=global_params)

    seq_length = mask.shape[1]
    num_features = generated.shape[-1]
    num_gal = mask.sum()

    # The first entry of every sequence is the central galaxy
    is_central = np.zeros_like(mask, dtype=bool)  # (num_halos, seq_length)
    is_central[:, 0] = True

    # Flatten to one row per (halo, sequence slot) and keep the valid ones
    keep = mask.reshape(-1)                                                           # (num_halos * seq_length, )
    generated = generated.reshape(-1, num_features)[keep]                             # (num_galaxies_valid, num_features)
    pos_central = np.repeat(pos[:, None, :], seq_length, axis=1).reshape(-1, 3)[keep]  # (num_galaxies_valid, 3)
    vel_central = np.repeat(vel[:, None, :], seq_length, axis=1).reshape(-1, 3)[keep]  # (num_galaxies_valid, 3)
    is_central = is_central.reshape(-1)[keep]                                         # (num_galaxies_valid, )

    # Distribute galaxies in cube
    print("# Generate positions of galaxies")

    sfr = generated[:, 0]
    distance = generated[:, 1]

    # Isotropic direction of each galaxy relative to its host
    phi = np.random.uniform(0, 2 * np.pi, size=num_gal)
    cos_theta = np.random.uniform(-1, 1, size=num_gal)
    sin_theta = np.sqrt(1 - cos_theta ** 2)

    pos_galaxies = pos_central
    pos_galaxies[:, 0] += distance * sin_theta * np.cos(phi)
    pos_galaxies[:, 1] += distance * sin_theta * np.sin(phi)
    pos_galaxies[:, 2] += distance * cos_theta

    pos_galaxies_real = copy.deepcopy(pos_galaxies)

    if not args.redshift_space:
        return sfr, pos_galaxies_real, pos_galaxies

    # Add redshift-space distortion along the third axis
    import astropy.units as u
    H = cosmo.H(redshift).to(u.km / u.s / u.Mpc).value  # [km/s/Mpc]
    hlittle = cosmo.H(0).to(u.km / u.s / u.Mpc).value / 100.0
    scale_factor = 1 / (1 + redshift)

    v_rad = generated[:, 2]
    v_tan = generated[:, 3]
    v_rad[is_central] = 0  # Set vr to 0 for central galaxies
    alpha = np.random.uniform(0, 2 * np.pi, size=num_gal)
    vz_gal = - v_rad * cos_theta + v_tan * sin_theta * np.cos(alpha)
    pos_galaxies[:, 2] += (vel_central[:, 2] + vz_gal) / scale_factor / H * hlittle

    return sfr, pos_galaxies_real, pos_galaxies
244
+
245
def populate_galaxies_in_lightcone(args, x_in, pos, redshift, cosmo, global_params=None):
    """
    Generate galaxies for lightcone haloes and scatter them around their hosts.

    args: args.model_dir, args.model_config_file, args.param_dir and
          args.redshift_space are read here; a deep copy with model_dir and
          max_sfr_file set is handed to the generator (which documents
          args.gpu_id and args.threshold as its own inputs)
    x_in: halo properties, first axis of length num_halos
    pos: (num_halos, 3); columns 0 and 1 receive offsets in arcsec and
         column 2 an offset in redshift, so they are presumably angular
         coordinates and redshift -- TODO confirm against the caller
    redshift: (num_halos, )
    cosmo: cosmology used for the length -> angle / redshift conversions
    global_params: forwarded to the generator

    Returns (sfr, pos_galaxies, redshift_central_all) for the valid galaxies.
    """

    print("# Use Transformer to generate SFR")

    opt = copy.deepcopy(args)

    generated_all = []
    pos_central_all = []
    redshift_central_all = []
    flag_central_all = []

    with open(args.model_config_file, "r") as f:
        snapshot_dict_str = json.load(f)
    snapshot_dict = {int(k): v for k, v in snapshot_dict_str.items()}

    print("# Model config:", snapshot_dict)

    # Each halo is assigned to the snapshot whose redshift is nearest:
    # bin edges are the midpoints between consecutive snapshot redshifts.
    redshifts_of_snapshots = np.array([v[1] for v in snapshot_dict.values()])
    bin_edges = (redshifts_of_snapshots[:-1] + redshifts_of_snapshots[1:]) / 2.0
    bin_indices = np.digitize(redshift, bin_edges)

    if args.param_dir is None:
        max_sfr_file_list = [None for snapshot_number in snapshot_dict]
    else:
        max_sfr_file_list = ["{}/max_nbin20_{:d}.txt".format(args.param_dir, snapshot_number) for snapshot_number in snapshot_dict]

    for i, snapshot_number in enumerate(snapshot_dict):
        model_path, redshift_of_snapshot = snapshot_dict[snapshot_number]
        print("# Snapshot number: {:d}, Redshift: {:.2f}".format(snapshot_number, redshift_of_snapshot))

        ### Skip if no haloes in this redshift bin
        mask_z = (bin_indices == i)
        if not np.any(mask_z):
            print("# No haloes in redshift bin {:d} (snapshot number {:d}), skipping...".format(i, snapshot_number))
            continue

        x_now = x_in[mask_z]             # halo properties of this bin (boolean indexing copies)
        pos_now = pos[mask_z]            # (num_halos_in_bin, 3)
        redshift_now = redshift[mask_z]  # (num_halos_in_bin, )

        opt.model_dir = "{}/{}".format(args.model_dir, model_path)
        opt.max_sfr_file = max_sfr_file_list[i]

        if "Transformer_NF" in opt.model_dir:
            generated, mask = generate_galaxy_TransNF(opt, x_now, global_params=global_params, verbose=False)
        else:
            generated, mask = generate_galaxy(opt, x_now, global_params=global_params, verbose=False)

        seq_length = mask.shape[1]
        num_features = generated.shape[-1]

        # Define flag_central: the first entry of every sequence is the central
        flag_central = np.zeros_like(mask, dtype=bool)
        flag_central[:, 0] = True

        # Flatten the arrays
        mask = mask.reshape(-1)
        generated = generated.reshape(-1, num_features)  # (num_halos * seq_length, num_features)
        pos_central = np.repeat(pos_now[:, None, :], seq_length, axis=1).reshape(-1, 3)  # (num_halos * seq_length, 3)
        redshift_central = np.repeat(redshift_now[:, None], seq_length, axis=1).reshape(-1)  # (num_halos * seq_length, )
        flag_central = flag_central.reshape(-1)

        # Apply mask and append
        generated_all.append(generated[mask])                # (num_galaxies_valid, num_features)
        pos_central_all.append(pos_central[mask])            # (num_galaxies_valid, 3)
        redshift_central_all.append(redshift_central[mask])  # (num_galaxies_valid, )
        flag_central_all.append(flag_central[mask])          # (num_galaxies_valid, )

    if not generated_all:
        # No snapshot produced anything (e.g. no input haloes);
        # np.concatenate would raise on an empty list.
        return np.empty(0), np.empty((0, 3)), np.empty(0)

    generated_all = np.concatenate(generated_all, axis=0)  # (num_galaxies_valid, num_features)
    pos_central_all = np.concatenate(pos_central_all, axis=0)
    redshift_central_all = np.concatenate(redshift_central_all, axis=0)  # (num_galaxies_valid,)
    flag_central_all = np.concatenate(flag_central_all, axis=0)  # (num_galaxies_valid,)

    ### Distribute galaxies in lightcone
    sfr = generated_all[:, 0]
    distance = generated_all[:, 1]

    num_gal = len(sfr)

    # Determine positions of galaxies: isotropic direction around the host
    print("# Generate positions of galaxies")
    _phi = np.random.uniform(0, 2 * np.pi, size=num_gal)
    _cos_theta = np.random.uniform(-1, 1, size=num_gal)
    _sin_theta = np.sqrt(1 - _cos_theta ** 2)

    # Convert the comoving offset to an angle [arcsec] and a redshift interval
    from .cosmology_utils import cMpc_to_arcsec, dcMpc_to_dz
    distance_arcsec = cMpc_to_arcsec(distance, redshift_central_all, cosmo=cosmo, l_with_hlittle=True)
    distance_z = dcMpc_to_dz(distance, redshift_central_all, cosmo=cosmo, l_with_hlittle=True)

    pos_galaxies = pos_central_all
    pos_galaxies[:, 0] += distance_arcsec * _sin_theta * np.cos(_phi)
    pos_galaxies[:, 1] += distance_arcsec * _sin_theta * np.sin(_phi)
    pos_galaxies[:, 2] += distance_z * _cos_theta

    # Add redshift-space distortion
    if args.redshift_space:

        relative_vel_rad = generated_all[:, 2]
        relative_vel_tan = generated_all[:, 3]
        relative_vel_rad[flag_central_all] = 0  # Set vr to 0 for central galaxies
        alpha = np.random.uniform(0, 2 * np.pi, size=num_gal)
        vz_gal = - relative_vel_rad * _cos_theta + relative_vel_tan * _sin_theta * np.cos(alpha)

        # The module-level cspeed is in [cm/s]; 1 km = 1e5 cm, so multiply
        # by 1e-5 to get [km/s] and a dimensionless beta.
        beta = vz_gal / (cspeed * 1e-5)  # [(km/s) / (km/s)]

        redshift_rest = pos_galaxies[:, 2]
        pos_galaxies[:, 2] = (1. + redshift_rest) * np.sqrt((1. + beta) / (1. - beta)) - 1.0

    return sfr, pos_galaxies, redshift_central_all