ler 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (35)
  1. ler/__init__.py +26 -26
  2. ler/gw_source_population/__init__.py +1 -0
  3. ler/gw_source_population/cbc_source_parameter_distribution.py +1076 -818
  4. ler/gw_source_population/cbc_source_redshift_distribution.py +619 -295
  5. ler/gw_source_population/jit_functions.py +484 -9
  6. ler/gw_source_population/sfr_with_time_delay.py +107 -0
  7. ler/image_properties/image_properties.py +44 -13
  8. ler/image_properties/multiprocessing_routine.py +5 -209
  9. ler/lens_galaxy_population/__init__.py +2 -0
  10. ler/lens_galaxy_population/epl_shear_cross_section.py +0 -0
  11. ler/lens_galaxy_population/jit_functions.py +101 -9
  12. ler/lens_galaxy_population/lens_galaxy_parameter_distribution.py +817 -885
  13. ler/lens_galaxy_population/lens_param_data/density_profile_slope_sl.txt +5000 -0
  14. ler/lens_galaxy_population/lens_param_data/external_shear_sl.txt +2 -0
  15. ler/lens_galaxy_population/lens_param_data/number_density_zl_zs.txt +48 -0
  16. ler/lens_galaxy_population/lens_param_data/optical_depth_epl_shear_vd_ewoud.txt +48 -0
  17. ler/lens_galaxy_population/mp copy.py +554 -0
  18. ler/lens_galaxy_population/mp.py +736 -138
  19. ler/lens_galaxy_population/optical_depth.py +2248 -616
  20. ler/rates/__init__.py +1 -2
  21. ler/rates/gwrates.py +129 -75
  22. ler/rates/ler.py +257 -116
  23. ler/utils/__init__.py +2 -0
  24. ler/utils/function_interpolation.py +322 -0
  25. ler/utils/gwsnr_training_data_generator.py +233 -0
  26. ler/utils/plots.py +1 -1
  27. ler/utils/test.py +1078 -0
  28. ler/utils/utils.py +553 -125
  29. {ler-0.4.1.dist-info → ler-0.4.3.dist-info}/METADATA +22 -9
  30. ler-0.4.3.dist-info/RECORD +34 -0
  31. {ler-0.4.1.dist-info → ler-0.4.3.dist-info}/WHEEL +1 -1
  32. ler/rates/ler copy.py +0 -2097
  33. ler-0.4.1.dist-info/RECORD +0 -25
  34. {ler-0.4.1.dist-info → ler-0.4.3.dist-info/licenses}/LICENSE +0 -0
  35. {ler-0.4.1.dist-info → ler-0.4.3.dist-info}/top_level.txt +0 -0
ler/utils/__init__.py CHANGED
@@ -1,2 +1,4 @@
  from .utils import *
  from .plots import *
+ from .gwsnr_training_data_generator import *
+ from .function_interpolation import *
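The two new wildcard imports re-export the utilities introduced in this release. Assuming the new modules expose their classes at module level (no restrictive __all__), they should be importable directly from the utils subpackage, for example:

    from ler.utils import FunctionConditioning, TrainingDataGenerator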
ler/utils/function_interpolation.py ADDED
@@ -0,0 +1,322 @@
+
+ from .utils import interpolator_pickle_path, cubic_spline_interpolator, cubic_spline_interpolator2d_array, inverse_transform_sampler, pdf_cubic_spline_interpolator2d_array, save_pickle, load_pickle, inverse_transform_sampler2d
+ import numpy as np
+ from scipy.integrate import quad, cumulative_trapezoid
+ from scipy.interpolate import CubicSpline
+ from scipy.stats import gaussian_kde
+ from numba import njit
+
+
+ class FunctionConditioning():
+
+     def __init__(self,
+         function=None,  # can also be an array of function values
+         x_array=None,
+         conditioned_y_array=None,  # if this is not None, 2D interpolation will be used
+         y_array=None,
+         gaussian_kde=False,
+         gaussian_kde_kwargs={},
+         param_dict_given={},
+         directory='./interpolator_pickle',
+         sub_directory='default',
+         name='default',
+         create_new=False,
+         create_function=False,
+         create_function_inverse=False,
+         create_pdf=False,
+         create_rvs=False,
+         multiprocessing_function=False,
+         callback=None,
+     ):
+         create = self.create_decision_function(create_function, create_function_inverse, create_pdf, create_rvs)
+         self.info = param_dict_given
+         self.callback = callback
+
+         if create:
+             # create_interpolator input list
+             input_list = [function, x_array, conditioned_y_array, create_function_inverse, create_pdf, create_rvs, multiprocessing_function]
+             input_list_kde = [x_array, y_array, gaussian_kde_kwargs]
+
+             # check first whether the directory, subdirectory and pickle exist
+             path_inv_cdf, it_exist = interpolator_pickle_path(
+                 param_dict_given=param_dict_given,
+                 directory=directory,
+                 sub_directory=sub_directory,
+                 interpolator_name=name,
+             )
+
+             # if the interpolator exists, load it
+             if create_new:
+                 it_exist = False
+
+             if it_exist:
+                 print(f"{name} interpolator will be loaded from {path_inv_cdf}")
+                 # load the interpolator
+                 interpolator = load_pickle(path_inv_cdf)
+             else:
+                 print(f"{name} interpolator will be generated at {path_inv_cdf}")
+                 if not gaussian_kde:
+                     interpolator = self.create_interpolator(*input_list)
+                     # save the interpolator
+                     save_pickle(path_inv_cdf, interpolator)
+                 else:
+                     # gaussian KDE
+                     interpolator = self.create_gaussian_kde(*input_list_kde)
+                     save_pickle(path_inv_cdf, interpolator)
+
+             if not gaussian_kde:
+                 x_array = interpolator['x_array']
+                 z_array = interpolator['z_array']
+                 conditioned_y_array = interpolator['conditioned_y_array']
+                 y_array = None
+                 function_spline = interpolator['function_spline']
+                 function_inverse_spline = interpolator['function_inverse_spline']
+                 pdf_norm_const = interpolator['pdf_norm_const']
+                 cdf_values = interpolator['cdf_values']
+
+                 if conditioned_y_array is None:
+                     # function is 1D
+                     self.function = njit(lambda x: cubic_spline_interpolator(x, function_spline, x_array)) if create_function else None
+                     # inverse function is 1D
+                     self.function_inverse = njit(lambda x: cubic_spline_interpolator(x, function_inverse_spline, z_array)) if create_function_inverse else None
+
+                     # pdf is 1D
+                     self.pdf = njit(lambda x: cubic_spline_interpolator(x, function_spline, x_array)/pdf_norm_const) if create_pdf else None
+                     # sampler is 1D
+                     self.rvs = njit(lambda size: inverse_transform_sampler(size, cdf_values, x_array)) if create_rvs else None
+
+                 else:
+                     self.function = njit(lambda x, y: cubic_spline_interpolator2d_array(x, y, function_spline, x_array, conditioned_y_array)) if create_function else None
+
+                     self.function_inverse = njit(lambda x, y: cubic_spline_interpolator2d_array(x, y, function_inverse_spline, z_array, conditioned_y_array)) if create_function_inverse else None
+
+                     self.pdf = njit(lambda x, y: pdf_cubic_spline_interpolator2d_array(x, y, pdf_norm_const, function_spline, x_array, conditioned_y_array)) if create_pdf else None
+
+                     self.rvs = njit(lambda size, y: inverse_transform_sampler2d(size, y, cdf_values, x_array, conditioned_y_array)) if create_rvs else None
+
+                 self.x_array = x_array
+                 self.z_array = z_array
+                 self.conditioned_y_array = conditioned_y_array
+                 self.function_spline = function_spline
+                 self.function_inverse_spline = function_inverse_spline
+                 self.pdf_norm_const = pdf_norm_const
+                 self.cdf_values = cdf_values
+             else:
+                 x_array = interpolator['x_array']
+                 y_array = interpolator['y_array']
+                 kde_object = interpolator['kde_object']
+
+                 self.pdf = lambda x: kde_object.pdf(x) if create_pdf else None
+                 if y_array is None:
+                     self.rvs = lambda size: kde_object.resample(size)[0] if create_rvs else None
+                 else:
+                     self.rvs = lambda size: kde_object.resample(size) if create_rvs else None
+
+                 self.x_array = x_array
+                 self.y_array = y_array
+                 self.kde_object = kde_object
+
+     def __call__(self, *args):
+         args = [np.array(arg).reshape(-1) if isinstance(arg, float) else arg for arg in args]
+         return getattr(self, self.callback)(*args)
+
+
+     def create_decision_function(self, create_function, create_function_inverse, create_pdf, create_rvs):
+
+         decision_bool = True
+         if not isinstance(create_function, bool) and callable(create_function):
+             self.function = create_function
+             decision_bool = False
+         if not isinstance(create_function_inverse, bool) and callable(create_function_inverse):
+             self.function_inverse = create_function_inverse
+             decision_bool = False
+         if not isinstance(create_pdf, bool) and callable(create_pdf):
+             self.pdf = create_pdf
+             decision_bool = False
+         if not isinstance(create_rvs, bool) and callable(create_rvs):
+             self.rvs = create_rvs
+             decision_bool = False
+
+         return decision_bool
+
+
+     def create_gaussian_kde(self, x_array, y_array, gaussian_kde_kwargs):
+
+         # 1d KDE
+         if y_array is None:
+             kde = gaussian_kde(x_array, **gaussian_kde_kwargs)
+         else:
+             data = np.vstack([x_array, y_array])
+             kde = gaussian_kde(data, **gaussian_kde_kwargs)
+
+         return {
+             'x_array': x_array,
+             'y_array': y_array,
+             'kde_object': kde,
+         }
+
+     def create_interpolator(self, function, x_array, conditioned_y_array, create_function_inverse, create_pdf, create_rvs, multiprocessing_function):
+
+         # function can be numpy array or callable
+         # x_array, z_array are 2D arrays if conditioned_y_array is not None
+         x_array, z_array, conditioned_y_array = self.create_z_array(x_array, function, conditioned_y_array, create_pdf, create_rvs, multiprocessing_function)
+         del function
+
+         function_spline = self.function_spline_generator(x_array, z_array, conditioned_y_array)
+
+         if create_function_inverse:
+             if conditioned_y_array is None:
+                 idx_sort = np.argsort(z_array)
+             else:
+                 idx_sort = np.argsort(z_array, axis=1)
+             x_array_ = x_array[idx_sort]
+             z_array_ = z_array[idx_sort]
+
+             # check z_array is strictly increasing
+             # if (not np.all(np.diff(z_array) > 0)) or (not np.all(np.diff(z_array) < 0)):
+             #     raise ValueError("z_array must be strictly increasing")
+
+             function_inverse_spline = self.function_spline_generator(z_array_, x_array_, conditioned_y_array)
+         else:
+             function_inverse_spline = None
+
+         if create_pdf or create_rvs:
+             # cannot have -ve pdf
+             pdf_norm_const = self.pdf_norm_const_generator(x_array, function_spline, conditioned_y_array)
+
+             if create_rvs:
+                 cdf_values = self.cdf_values_generator(x_array, z_array, conditioned_y_array)
+             else:
+                 cdf_values = None
+         else:
+             pdf_norm_const = None
+             cdf_values = None
+
+         return {
+             'x_array': x_array,
+             'z_array': z_array,
+             'conditioned_y_array': conditioned_y_array,
+             'function_spline': function_spline,
+             'function_inverse_spline': function_inverse_spline,
+             'pdf_norm_const': pdf_norm_const,
+             'cdf_values': cdf_values,
+         }
+
+     def create_z_array(self, x_array, function, conditioned_y_array, create_pdf, create_rvs, multiprocessing_function):
+
+         if callable(function):
+             # 1D
+             if conditioned_y_array is None:
+                 z_array = function(x_array)
+                 # remove nan values
+                 idx = np.argwhere(np.isnan(z_array))
+                 x_array = np.delete(x_array, idx)
+                 z_array = np.delete(z_array, idx)
+             # 2D
+             else:
+                 # check if x_array is 2D, if not, make it 2D of shape (len(conditioned_y_array), len(x_array))
+                 if x_array.ndim == 1:
+                     x_array = np.array([x_array]*len(conditioned_y_array))
+
+                 idx = np.argsort(conditioned_y_array)
+                 conditioned_y_array = conditioned_y_array[idx]
+                 # x_array is 2D here, each row corresponds to a different y value
+                 x_array = x_array[idx]
+                 # sort each row of x_array
+                 x_array = np.sort(x_array, axis=1)
+
+                 if multiprocessing_function:
+                     z_array = function(x_array, conditioned_y_array)
+                 else:
+                     z_list = []
+                     for i, y in enumerate(conditioned_y_array):
+                         try:
+                             z_list.append(function(x_array[i], y*np.ones_like(x_array[i])))
+                         except:
+                             # print(x_array[i], y)
+                             z_list.append(function(x_array[i], y))
+                     z_array = np.array(z_list)
+         else:
+             if conditioned_y_array is None:
+                 z_array = function
+                 # remove nan values
+                 idx = np.argwhere(np.isnan(z_array))
+                 x_array = np.delete(x_array, idx)
+                 z_array = np.delete(z_array, idx)
+             else:
+                 if x_array.ndim == 1:
+                     x_array = np.array([x_array]*len(conditioned_y_array))
+                 if function.ndim == 1:
+                     raise ValueError('function must be 2D array if conditioned_y_array is not None')
+                 # row sort
+                 idx = np.argsort(conditioned_y_array)
+                 conditioned_y_array = conditioned_y_array[idx]
+                 # x_array is 2D here, each row corresponds to a different y value
+                 x_array = x_array[idx]
+                 z_array = function[idx]
+
+                 z_list = []
+                 x_list = []
+                 for i in range(len(conditioned_y_array)):
+                     # column sort
+                     idx = np.argsort(x_array[i])
+                     x_list.append(x_array[i][idx])
+                     z_list.append(z_array[i][idx])
+                 x_array = np.array(x_list)
+                 z_array = np.array(z_list)
+
+         # cannot have -ve pdf
+         if create_pdf or create_rvs:
+             z_array[z_array < 0.0] = 0.0
+
+         return x_array, z_array, conditioned_y_array
+
+     def cdf_values_generator(self, x_array, z_array, conditioned_y_array):
+         # 1D case
+         if conditioned_y_array is None:
+             # z_array[z_array<0.] = 0.  # already done
+             cdf_values = cumulative_trapezoid(z_array, x_array, initial=0)
+             cdf_values = cdf_values/cdf_values[-1]
+         # 2D case
+         else:
+             cdf_values = []
+             for i, y in enumerate(conditioned_y_array):
+                 z_array_ = z_array[i]
+                 z_array_[z_array_<0.] = 0.
+                 cdfs_ = cumulative_trapezoid(z_array_, x_array[i], initial=0)
+
+                 cdf_values.append(cdfs_/cdfs_[-1])
+                 # cdf_values.append(cdfs_)
+
+         return np.array(cdf_values)
+
+     def pdf_norm_const_generator(self, x_array, function_spline, conditioned_y_array):
+         # 1D case
+         if conditioned_y_array is None:
+             pdf_unorm = lambda x: cubic_spline_interpolator(np.array([x]), function_spline, x_array)
+
+             norm = quad(pdf_unorm, min(x_array), max(x_array))[0]
+             return norm
+         # 2D case
+         else:
+             norm = []
+             for i, y in enumerate(conditioned_y_array):
+                 pdf_unorm = lambda x: cubic_spline_interpolator(np.array([x]), function_spline[i], x_array[i])
+
+                 norm.append(quad(pdf_unorm, min(x_array[i]), max(x_array[i]))[0])
+
+             return np.array(norm)
+
+     def function_spline_generator(self, x_array, z_array, conditioned_y_array):
+         # 1D case
+         if conditioned_y_array is None:
+             function_spline = CubicSpline(x_array, z_array).c
+         # 2D case
+         else:
+             function_spline = []
+             for i, y in enumerate(conditioned_y_array):
+                 function_spline.append(CubicSpline(x_array[i], z_array[i]).c)
+
+         return np.array(function_spline)
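For orientation, the sketch below conditions a toy 1D density on a grid with the new FunctionConditioning class and draws samples from it. It only uses arguments visible in the constructor above (function, x_array, create_pdf, create_rvs, callback, name, create_new); the toy density, the interpolator name and the assumption that the class is re-exported from ler.utils are illustrative, not part of the release.

    import numpy as np
    from ler.utils import FunctionConditioning  # assumed re-export via the new __init__.py

    # toy, unnormalised 1D density on a redshift-like grid (purely illustrative)
    z_grid = np.linspace(0.001, 5.0, 500)
    toy_density = lambda z: z**2 * np.exp(-z)

    zs = FunctionConditioning(
        function=toy_density,    # a callable; an array of function values also works
        x_array=z_grid,
        create_pdf=True,         # build the normalised pdf interpolator
        create_rvs=True,         # build the inverse-transform sampler
        callback='rvs',          # make the instance itself dispatch to rvs()
        name='toy_redshift',     # pickle name under ./interpolator_pickle
        create_new=True,         # regenerate instead of loading a cached pickle
    )

    samples = zs.rvs(1000)               # 1000 draws from the interpolated density
    p = zs.pdf(np.array([1.0]))          # normalised pdf evaluated at z = 1
    more_samples = zs(500)               # __call__ forwards to the 'rvs' callback

Note that the numba-compiled pdf/rvs closures are rebuilt on every construction; only the grids, spline coefficients, normalisation constant and CDF values are pickled.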
ler/utils/gwsnr_training_data_generator.py ADDED
@@ -0,0 +1,233 @@
+ import os
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import contextlib
+ from gwsnr import GWSNR
+ from .utils import append_json, get_param_from_json
+
+ class TrainingDataGenerator():
+
+     def __init__(self,
+         npool=4,
+         z_min=0.0,
+         z_max=5.0,
+         verbose=True,
+         **kwargs,  # ler and gwsnr arguments
+     ):
+
+         self.npool = npool
+         self.z_min = z_min
+         self.z_max = z_max
+         self.verbose = verbose
+
+         self.ler_init_args = dict(
+             event_type="BBH",
+             cosmology=None,
+             ler_directory="./ler_data",
+             spin_zero=False,
+             spin_precession=True,
+             # gwsnr args
+             mtot_min=2*4.98,  # 4.98 Mo is the minimum component mass of BBH systems in GWTC-3
+             mtot_max=2*112.5+10.0,  # 112.5 Mo is the maximum component mass of BBH systems in GWTC-3. 10.0 Mo is added to avoid edge effects.
+             ratio_min=0.1,
+             ratio_max=1.0,
+             spin_max=0.99,
+             mtot_resolution=200,
+             ratio_resolution=20,
+             spin_resolution=10,
+             sampling_frequency=2048.0,
+             waveform_approximant="IMRPhenomXPHM",
+             minimum_frequency=20.0,
+             snr_type="interpolation_aligned_spins",
+             psds=None,
+             ifos=None,
+             interpolator_dir="./interpolator_pickle",
+             create_new_interpolator=False,
+             gwsnr_verbose=True,
+             multiprocessing_verbose=True,
+             mtot_cut=False,
+         )
+         self.ler_init_args.update(kwargs)
+
+     def gw_parameters_generator(self,
+         size,
+         batch_size=100000,
+         snr_recalculation=True,
+         trim_to_size=False,
+         verbose=True,
+         replace=False,
+         data_distribution_range=[0, 2, 4, 6, 8, 10, 12, 14, 16, 100],
+         output_jsonfile="gw_parameters.json",
+     ):
+
+         args = self.ler_init_args.copy()
+         if snr_recalculation:
+             snr_type = 'interpolation_aligned_spins'
+         else:
+             snr_type = 'inner_product'
+
+         from ler.rates import GWRATES
+
+         # ler initialization
+         ler = GWRATES(
+             npool=self.npool,
+             z_min=self.z_min,
+             z_max=self.z_max,  # be careful with this value
+             verbose=self.verbose,
+             # ler
+             event_type=args['event_type'],
+             cosmology=args['cosmology'],
+             ler_directory=args['ler_directory'],
+             spin_zero=args['spin_zero'],
+             spin_precession=args['spin_precession'],
+             # gwsnr args
+             mtot_min=args['mtot_min'],
+             mtot_max=args['mtot_max'],
+             ratio_min=args['ratio_min'],
+             ratio_max=args['ratio_max'],
+             mtot_resolution=args['mtot_resolution'],
+             ratio_resolution=args['ratio_resolution'],
+             sampling_frequency=args['sampling_frequency'],
+             waveform_approximant=args['waveform_approximant'],
+             minimum_frequency=args['minimum_frequency'],
+             snr_type=snr_type,
+             psds=args['psds'],
+             ifos=args['ifos'],
+             interpolator_dir=args['interpolator_dir'],
+             create_new_interpolator=args['create_new_interpolator'],
+             gwsnr_verbose=args['gwsnr_verbose'],
+             multiprocessing_verbose=args['multiprocessing_verbose'],
+             mtot_cut=args['mtot_cut'],
+         )
+         ler.batch_size = batch_size
+
+         # path to save parameters
+         json_path = f"{args['ler_directory']}/{output_jsonfile}"
+         if replace:
+             if os.path.exists(json_path):
+                 os.remove(json_path)
+             len_final = 0
+         else:
+             if os.path.exists(json_path):
+                 gw_param = get_param_from_json(json_path)
+                 len_final = len(gw_param['snr_net'])
+                 print(f'current size of the json file: {len_final}\n')
+             else:
+                 len_final = 0
+
+         print(f'total event to collect: {size}\n')
+         while len_final<size:
+             with contextlib.redirect_stdout(None):
+                 gw_param = ler.gw_cbc_statistics(size=ler.batch_size, resume=False)
+
+             if data_distribution_range is not None:
+                 gw_param = self.helper_data_distribution(gw_param, data_distribution_range)
+
+             if gw_param is None:
+                 continue
+
+             if snr_recalculation:
+                 snrs = ler.snr_bilby(gw_param_dict=gw_param)
+                 gw_param.update(snrs)
+
+                 gw_param = self.helper_data_distribution(gw_param, data_distribution_range)
+
+                 if gw_param is None:
+                     print("No data in one of the given range")
+                     continue
+             # save the parameters
+             append_json(json_path, gw_param, replace=False);
+
+             # print(f"Collected number of events: {len_}")
+             len_final += len(gw_param['snr_net'])
+             if verbose:
+                 print(f"Collected number of events: {len_final}")
+
+         if trim_to_size:
+             gw_param = get_param_from_json(json_path)
+             for key, value in gw_param.items():
+                 gw_param[key] = value[:size]
+             append_json(json_path, gw_param, replace=True);
+             len_final = len(gw_param['snr_net'])
+
+         print(f"final size: {len_final}\n")
+         print(f"json file saved at: {json_path}\n")
+
+     def helper_data_distribution(self, gw_param, data_distribution_range):
+         # optimal SNR
+         snr = np.array(gw_param['snr_net'])
+
+         idx_arr = []
+         snr_range = np.array(data_distribution_range)
+         len_ = len(snr_range)
+         len_arr = []  # size of len_arr is len_-1
+         for j in range(len_-1):
+             idx_ = np.argwhere((snr>=snr_range[j]) & (snr<snr_range[j+1])).flatten()
+             idx_arr.append(idx_)
+             len_arr.append(len(idx_))
+
+         idx_arr = np.array(idx_arr, dtype=object)
+         len_ref = min(len_arr)
+
+         if len_ref==0:
+             print("No data in one of the given range")
+             return None
+         else:
+             gw_param_final = {}
+             for j, len_ in enumerate(len_arr):  # loop over snr range
+                 idx_buffer = np.random.choice(idx_arr[j], len_ref, replace=False)
+
+                 for key, value in gw_param.items():
+
+                     try:
+                         buffer_ = value[idx_buffer]
+                     except IndexError:
+                         print(f"IndexError")
+                         print(f"key: {key}, len(value): {len(value)}, len(idx_buffer): {len(idx_buffer)}")
+                         print(f"rerun the code again with: replace=False")
+                         return None
+                     if j==0:
+                         gw_param_final[key] = buffer_
+                     else:
+                         gw_param_final[key] = np.concatenate([gw_param_final[key], buffer_])
+
+             return gw_param_final
+
+     def combine_dicts(self,
+         file_name_list=None,
+         path_list=None,
+         detector='L1',
+         parameter_list=['mass_1', 'mass_2', 'luminosity_distance', 'theta_jn', 'psi', 'geocent_time', 'ra', 'dec', 'a_1', 'a_2', 'tilt_1', 'tilt_2'],
+         output_jsonfile="combined_data.json",
+     ):
+
+         parameter_list += [detector]
+         combined_dict = {}
+
+         if file_name_list is not None:
+             path_list = [f"{self.ler_init_args['ler_directory']}/{file_name}" for file_name in file_name_list]
+         elif path_list is None:
+             print("Please provide either file_name_list or path_list")
+             return None
+
+         for path in path_list:
+             data = get_param_from_json(path)
+             for key, value in data.items():
+                 if key in parameter_list:
+                     if key in combined_dict:
+                         combined_dict[key] = np.concatenate([combined_dict[key], value])
+                     else:
+                         combined_dict[key] = value
+
+         # if 'snr_net' is not in the combined_dict, we can calculate it
+         combined_dict['snr_net'] = combined_dict[detector]
+
+         json_path = f"{self.ler_init_args['ler_directory']}/{output_jsonfile}"
+         print(f"json file saved at: {json_path}\n")
+         append_json(json_path, combined_dict, replace=True);
+
+     def delete_json_file(self, path_list):
+         for path in path_list:
+             if os.path.exists(path):
+                 os.remove(path)
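The new TrainingDataGenerator is a thin driver around GWRATES for producing SNR-balanced event sets (the module name suggests training data for gwsnr). A minimal usage sketch based only on the signatures shown above; the sizes, bin edges and output file name are hypothetical, and ler plus gwsnr must be installed:

    from ler.utils import TrainingDataGenerator  # assumed re-export via the new __init__.py

    tdg = TrainingDataGenerator(
        npool=4,
        z_min=0.0,
        z_max=5.0,
        waveform_approximant="IMRPhenomXPHM",  # forwarded to GWRATES via ler_init_args
    )

    # Draw events in batches, rebalance them across the SNR bins given by
    # data_distribution_range, and append them to ./ler_data/<output_jsonfile>
    # until `size` events have been collected.
    tdg.gw_parameters_generator(
        size=50_000,
        batch_size=10_000,
        snr_recalculation=True,   # re-evaluate SNRs of the kept events with ler.snr_bilby
        data_distribution_range=[0, 4, 8, 12, 16, 100],
        output_jsonfile="gw_parameters_training.json",
    )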
ler/utils/plots.py CHANGED
@@ -174,7 +174,7 @@ def relative_mu_dt_lensed(
      # get magnifications, time_delays and snr
      mu = lensed_param["magnifications"]
      dt = lensed_param["time_delays"]
-     snr = lensed_param["optimal_snr_net"]
+     snr = lensed_param["snr_net"]
      image_type = lensed_param["image_type"]

      # pair images wrt to image_type
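The only change in plots.py is the SNR dictionary key: relative_mu_dt_lensed now reads lensed_param["snr_net"] instead of lensed_param["optimal_snr_net"]. A hypothetical one-line migration for lensed-parameter dictionaries saved by older versions:

    # assumed migration for dictionaries produced with the old key name
    if "optimal_snr_net" in lensed_param and "snr_net" not in lensed_param:
        lensed_param["snr_net"] = lensed_param.pop("optimal_snr_net")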