ler 0.3.8__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ler might be problematic; see the package's registry page for more details.

ler/rates/gwrates.py CHANGED
@@ -3,10 +3,10 @@
3
3
  This module contains the main class for calculating the rates of detectable gravitational waves events. The class inherits the :class:`~ler.gw_source_population.CBCSourceParameterDistribution` class for source parameters sampling and uses `gwsnr` package for SNR calculation.
4
4
  """
5
5
 
6
- import contextlib
7
6
  import os
8
7
  import warnings
9
8
  warnings.filterwarnings("ignore")
9
+ import contextlib
10
10
  import numpy as np
11
11
  from scipy.stats import norm
12
12
  from astropy.cosmology import LambdaCDM
@@ -29,25 +29,51 @@ class GWRATES(CBCSourceParameterDistribution):
29
29
  z_max : `float`
30
30
  maximum redshift.
31
31
  default z_max = 10.
32
- for popI_II, popIII, primordial, BNS z_max = 10., 40., 40., 2. respectively.
32
+ for popI_II, popIII, primordial, BNS z_max = 10., 40., 40., 5. respectively.
33
+ event_type : `str`
34
+ type of event to generate.
35
+ default event_type = 'BBH'. Other options are 'BNS', 'NSBH'.
33
36
  size : `int`
34
37
  number of samples for sampling.
35
- default size = 100000.
38
+ default size = 100000. To get stable rates, size should be large (>=1e6).
36
39
  batch_size : `int`
37
40
  batch size for SNR calculation.
38
- default batch_size = 25000.
41
+ default batch_size = 50000.
39
42
  reduce the batch size if you are getting memory error.
40
- recommended batch_size = 50000, if size = 1000000.
41
- snr_finder : `str`
43
+ recommended batch_size = 200000, if size = 1000000.
44
+ cosmology : `astropy.cosmology`
45
+ cosmology to use for the calculation.
46
+ default cosmology = LambdaCDM(H0=70, Om0=0.3, Ode0=0.7).
47
+ snr_finder : `str` or `function`
42
48
  default snr_finder = 'gwsnr'.
43
- if 'gwsnr', the SNR will be calculated using the gwsnr package.
44
- if 'custom', the SNR will be calculated using a custom function.
45
- The custom function should have input and output as given in GWSNR.snr method.
49
+ if None, the SNR will be calculated using the gwsnr package.
50
+ if custom snr finder function is provided, the SNR will be calculated using a custom function. The custom function should follow the following signature:
51
+ def snr_finder(gw_param_dict):
52
+ ...
53
+ return optimal_snr_dict
54
+ where optimal_snr_dict.keys = ['optimal_snr_net']. Refer to `gwsnr` package's GWSNR.snr attribute for more details.
55
+ pdet_finder : `function`
56
+ default pdet_finder = None.
57
+ The rate calculation uses either the pdet_finder or the snr_finder to calculate the detectable events. The custom pdet finder function should follow the following signature:
58
+ def pdet_finder(gw_param_dict):
59
+ ...
60
+ return pdet_net_dict
61
+ where pdet_net_dict.keys = ['pdet_net']. For example uses, refer to [GRB pdet example](https://ler.readthedocs.io/en/latest/examples/rates/grb%20detection%20rate.html).
46
62
  json_file_names: `dict`
47
63
  names of the json files to strore the necessary parameters.
48
- default json_file_names = {'ler_param': 'LeR_params.json', 'gw_param': 'gw_param.json', 'gw_param_detectable': 'gw_param_detectable.json'}.\n
64
+ default json_file_names = {'gwrates_params':'gwrates_params.json', 'gw_param': 'gw_param.json', 'gw_param_detectable': 'gw_param_detectable.json'}.
65
+ interpolator_directory : `str`
66
+ directory to store the interpolators.
67
+ default interpolator_directory = './interpolator_pickle'. This is used for storing the various interpolators related to `ler` and `gwsnr` package.
68
+ ler_directory : `str`
69
+ directory to store the parameters.
70
+ default ler_directory = './ler_data'. This is used for storing the parameters of the simulated events.
71
+ verbose : `bool`
72
+ default verbose = True.
73
+ if True, the function will print all chosen parameters.
74
+ Choose False to prevent anything from printing.
49
75
  kwargs : `keyword arguments`
50
- Note : kwargs takes input for initializing the :class:`~ler.gw_source_population.CBCSourceParameterDistribution`, :meth:`~gwsnr_intialization`.
76
+ Note : kwargs takes input for initializing the :class:`~ler.gw_source_population.CBCSourceParameterDistribution` and :class:`~ler.gw_source_population.CBCSourceRedshiftDistribution` classes. If snr_finder='gwsnr', then kwargs also takes input for initializing the :class:`~gwsnr.GWSNR` class. Please refer to the respective classes for more details.
51
77
 
52
78
  Examples
53
79
  ----------
@@ -78,7 +104,11 @@ class GWRATES(CBCSourceParameterDistribution):
78
104
  +-------------------------------------+----------------------------------+
79
105
  |:attr:`~json_file_names` | `dict` |
80
106
  +-------------------------------------+----------------------------------+
81
- |:attr:`~directory` | `str` |
107
+ |:attr:`~interpolator_directory` | `str` |
108
+ +-------------------------------------+----------------------------------+
109
+ |:attr:`~ler_directory` | `str` |
110
+ +-------------------------------------+----------------------------------+
111
+ |:attr:`~gwsnr` | `bool` |
82
112
  +-------------------------------------+----------------------------------+
83
113
  |:attr:`~gw_param_sampler_dict` | `dict` |
84
114
  +-------------------------------------+----------------------------------+
@@ -161,14 +191,24 @@ class GWRATES(CBCSourceParameterDistribution):
161
191
 
162
192
  json_file_names = None
163
193
  """``dict`` \n
164
- Names of the json files to strore the necessary parameters.
194
+ Names of the json files to store the necessary parameters.
165
195
  """
166
196
 
167
- directory = None
197
+ interpolator_directory = None
168
198
  """``str`` \n
169
199
  Directory to store the interpolators.
170
200
  """
171
201
 
202
+ ler_directory = None
203
+ """``str`` \n
204
+ Directory to store the parameters.
205
+ """
206
+
207
+ gwsnr = None
208
+ """``bool`` \n
209
+ If True, the SNR will be calculated using the gwsnr package.
210
+ """
211
+
172
212
  gw_param_sampler_dict = None
173
213
  """``dict`` \n
174
214
  Dictionary of parameters to initialize the ``CBCSourceParameterDistribution`` class.
@@ -179,6 +219,55 @@ class GWRATES(CBCSourceParameterDistribution):
179
219
  Dictionary of parameters to initialize the ``GWSNR`` class.
180
220
  """
181
221
 
222
+ gw_param = None
223
+ """``dict`` \n
224
+ Dictionary of GW source parameters. The included parameters and their units are as follows (for default settings):\n
225
+ +--------------------+--------------+--------------------------------------+
226
+ | Parameter | Units | Description |
227
+ +====================+==============+======================================+
228
+ | zs | | redshift of the source |
229
+ +--------------------+--------------+--------------------------------------+
230
+ | geocent_time | s | GPS time of coalescence |
231
+ +--------------------+--------------+--------------------------------------+
232
+ | ra | rad | right ascension |
233
+ +--------------------+--------------+--------------------------------------+
234
+ | dec | rad | declination |
235
+ +--------------------+--------------+--------------------------------------+
236
+ | phase | rad | phase of GW at reference frequency |
237
+ +--------------------+--------------+--------------------------------------+
238
+ | psi | rad | polarization angle |
239
+ +--------------------+--------------+--------------------------------------+
240
+ | theta_jn | rad | inclination angle |
241
+ +--------------------+--------------+--------------------------------------+
242
+ | luminosity_distance| Mpc | luminosity distance |
243
+ +--------------------+--------------+--------------------------------------+
244
+ | mass_1_source | Msun | mass_1 of the compact binary |
245
+ | | | (source frame) |
246
+ +--------------------+--------------+--------------------------------------+
247
+ | mass_2_source | Msun | mass_2 of the compact binary |
248
+ | | | (source frame) |
249
+ +--------------------+--------------+--------------------------------------+
250
+ | mass_1 | Msun | mass_1 of the compact binary |
251
+ | | | (detector frame) |
252
+ +--------------------+--------------+--------------------------------------+
253
+ | mass_2 | Msun | mass_2 of the compact binary |
254
+ | | | (detector frame) |
255
+ +--------------------+--------------+--------------------------------------+
256
+ | L1 | | optimal snr of L1 |
257
+ +--------------------+--------------+--------------------------------------+
258
+ | H1 | | optimal snr of H1 |
259
+ +--------------------+--------------+--------------------------------------+
260
+ | V1 | | optimal snr of V1 |
261
+ +--------------------+--------------+--------------------------------------+
262
+ | optimal_snr_net | | optimal snr of the network |
263
+ +--------------------+--------------+--------------------------------------+
264
+ """
265
+
266
+ gw_param_detectable = None
267
+ """``dict`` \n
268
+ Dictionary of detectable GW source parameters. It includes the same parameters as the :attr:`~gw_param` attribute.
269
+ """
270
+
182
271
  def __init__(
183
272
  self,
184
273
  npool=int(4),
@@ -205,7 +294,7 @@ class GWRATES(CBCSourceParameterDistribution):
205
294
  self.cosmo = cosmology if cosmology else LambdaCDM(H0=70, Om0=0.3, Ode0=0.7)
206
295
  self.size = size
207
296
  self.batch_size = batch_size
208
- self.json_file_names = dict(gwrates_param="gwrates_params.json", gw_param="gw_param.json", gw_param_detectable="gw_param_detectable.json",)
297
+ self.json_file_names = dict(gwrates_params="gwrates_params.json", gw_param="gw_param.json", gw_param_detectable="gw_param_detectable.json",)
209
298
  if json_file_names:
210
299
  self.json_file_names.update(json_file_names)
211
300
  self.interpolator_directory = interpolator_directory
@@ -229,7 +318,7 @@ class GWRATES(CBCSourceParameterDistribution):
229
318
  self.list_of_detectors = list_of_detectors
230
319
 
231
320
  # store all the gwrates input parameters
232
- self.store_gwrates_params(output_jsonfile=self.json_file_names["gwrates_param"])
321
+ self.store_gwrates_params(output_jsonfile=self.json_file_names["gwrates_params"])
233
322
 
234
323
  if verbose:
235
324
  initialization()
@@ -245,79 +334,79 @@ class GWRATES(CBCSourceParameterDistribution):
245
334
 
246
335
  # print all relevant functions and sampler priors
247
336
  print("\n GWRATES set up params:")
248
- print("npool = ", self.npool)
249
- print("z_min = ", self.z_min)
250
- print("z_max = ", self.z_max)
251
- print("event_type = ", self.event_type)
252
- print("size = ", self.size)
253
- print("batch_size = ", self.batch_size)
254
- print("cosmology = ", self.cosmo)
337
+ print(f'npool = {self.npool},')
338
+ print(f'z_min = {self.z_min},')
339
+ print(f'z_max = {self.z_max},')
340
+ print(f"event_type = '{self.event_type}',")
341
+ print(f'size = {self.size},')
342
+ print(f'batch_size = {self.batch_size},')
343
+ print(f'cosmology = {self.cosmo},')
255
344
  if self.snr:
256
- print("snr_finder = ", self.snr)
345
+ print(f'snr_finder = {self.snr},')
257
346
  if self.pdet:
258
- print("pdet_finder = ", self.pdet)
259
- print("json_file_names = ", self.json_file_names)
260
- print("interpolator_directory = ", self.interpolator_directory)
261
- print("ler_directory = ", self.ler_directory)
347
+ print(f'pdet_finder = {self.pdet},')
348
+ print(f'json_file_names = {self.json_file_names},')
349
+ print(f'interpolator_directory = {self.interpolator_directory},')
350
+ print(f'ler_directory = {self.ler_directory},')
262
351
 
263
352
  print("\n GWRATES also takes CBCSourceParameterDistribution params as kwargs, as follows:")
264
- print("source_priors=", self.gw_param_sampler_dict["source_priors"])
265
- print("source_priors_params=", self.gw_param_sampler_dict["source_priors_params"])
266
- print("spin_zero=", self.gw_param_sampler_dict["spin_zero"])
267
- print("spin_precession=", self.gw_param_sampler_dict["spin_precession"])
268
- print("create_new_interpolator=", self.gw_param_sampler_dict["create_new_interpolator"])
353
+ print(f"source_priors = {self.gw_param_sampler_dict['source_priors']},")
354
+ print(f"source_priors_params = {self.gw_param_sampler_dict['source_priors_params']},")
355
+ print(f"spin_zero = {self.gw_param_sampler_dict['spin_zero']},")
356
+ print(f"spin_precession = {self.gw_param_sampler_dict['spin_precession']},")
357
+ print(f"create_new_interpolator = {self.gw_param_sampler_dict['create_new_interpolator']},")
269
358
 
270
359
  if self.gwsnr:
271
- print("\n GWRATES also takes GWSNR params as kwargs, as follows:")
272
- print("mtot_min = ", self.snr_calculator_dict["mtot_min"])
273
- print("mtot_max = ", self.snr_calculator_dict["mtot_max"])
274
- print("ratio_min = ", self.snr_calculator_dict["ratio_min"])
275
- print("ratio_max = ", self.snr_calculator_dict["ratio_max"])
276
- print("mtot_resolution = ", self.snr_calculator_dict["mtot_resolution"])
277
- print("ratio_resolution = ", self.snr_calculator_dict["ratio_resolution"])
278
- print("sampling_frequency = ", self.snr_calculator_dict["sampling_frequency"])
279
- print("waveform_approximant = ", self.snr_calculator_dict["waveform_approximant"])
280
- print("minimum_frequency = ", self.snr_calculator_dict["minimum_frequency"])
281
- print("snr_type = ", self.snr_calculator_dict["snr_type"])
282
- print("psds = ", self.snr_calculator_dict["psds"])
283
- print("ifos = ", self.snr_calculator_dict["ifos"])
284
- print("interpolator_dir = ", self.snr_calculator_dict["interpolator_dir"])
285
- print("create_new_interpolator = ", self.snr_calculator_dict["create_new_interpolator"])
286
- print("gwsnr_verbose = ", self.snr_calculator_dict["gwsnr_verbose"])
287
- print("multiprocessing_verbose = ", self.snr_calculator_dict["multiprocessing_verbose"])
288
- print("mtot_cut = ", self.snr_calculator_dict["mtot_cut"])
289
- del self.gwsnr
360
+ print("\n LeR also takes gwsnr.GWSNR params as kwargs, as follows:")
361
+ print(f"mtot_min = {self.snr_calculator_dict['mtot_min']},")
362
+ print(f"mtot_max = {self.snr_calculator_dict['mtot_max']},")
363
+ print(f"ratio_min = {self.snr_calculator_dict['ratio_min']},")
364
+ print(f"ratio_max = {self.snr_calculator_dict['ratio_max']},")
365
+ print(f"mtot_resolution = {self.snr_calculator_dict['mtot_resolution']},")
366
+ print(f"ratio_resolution = {self.snr_calculator_dict['ratio_resolution']},")
367
+ print(f"sampling_frequency = {self.snr_calculator_dict['sampling_frequency']},")
368
+ print(f"waveform_approximant = '{self.snr_calculator_dict['waveform_approximant']}',")
369
+ print(f"minimum_frequency = {self.snr_calculator_dict['minimum_frequency']},")
370
+ print(f"snr_type = '{self.snr_calculator_dict['snr_type']}',")
371
+ print(f"psds = {self.snr_calculator_dict['psds']},")
372
+ print(f"ifos = {self.snr_calculator_dict['ifos']},")
373
+ print(f"interpolator_dir = '{self.snr_calculator_dict['interpolator_dir']}',")
374
+ print(f"create_new_interpolator = {self.snr_calculator_dict['create_new_interpolator']},")
375
+ print(f"gwsnr_verbose = {self.snr_calculator_dict['gwsnr_verbose']},")
376
+ print(f"multiprocessing_verbose = {self.snr_calculator_dict['multiprocessing_verbose']},")
377
+ print(f"mtot_cut = {self.snr_calculator_dict['mtot_cut']},")
378
+ # del self.gwsnr
290
379
 
291
380
  print("\n For reference, the chosen source parameters are listed below:")
292
- print("merger_rate_density = ", self.gw_param_samplers["merger_rate_density"])
381
+ print(f"merger_rate_density = '{self.gw_param_samplers['merger_rate_density']}'")
293
382
  print("merger_rate_density_params = ", self.gw_param_samplers_params["merger_rate_density"])
294
- print("source_frame_masses = ", self.gw_param_samplers["source_frame_masses"])
383
+ print(f"source_frame_masses = '{self.gw_param_samplers['source_frame_masses']}'")
295
384
  print("source_frame_masses_params = ", self.gw_param_samplers_params["source_frame_masses"])
296
- print("geocent_time = ", self.gw_param_samplers["geocent_time"])
385
+ print(f"geocent_time = '{self.gw_param_samplers['geocent_time']}'")
297
386
  print("geocent_time_params = ", self.gw_param_samplers_params["geocent_time"])
298
- print("ra = ", self.gw_param_samplers["ra"])
387
+ print(f"ra = '{self.gw_param_samplers['ra']}'")
299
388
  print("ra_params = ", self.gw_param_samplers_params["ra"])
300
- print("dec = ", self.gw_param_samplers["dec"])
389
+ print(f"dec = '{self.gw_param_samplers['dec']}'")
301
390
  print("dec_params = ", self.gw_param_samplers_params["dec"])
302
- print("phase = ", self.gw_param_samplers["phase"])
391
+ print(f"phase = '{self.gw_param_samplers['phase']}'")
303
392
  print("phase_params = ", self.gw_param_samplers_params["phase"])
304
- print("psi = ", self.gw_param_samplers["psi"])
393
+ print(f"psi = '{self.gw_param_samplers['psi']}'")
305
394
  print("psi_params = ", self.gw_param_samplers_params["psi"])
306
- print("theta_jn = ", self.gw_param_samplers["theta_jn"])
395
+ print(f"theta_jn = '{self.gw_param_samplers['theta_jn']}'")
307
396
  print("theta_jn_params = ", self.gw_param_samplers_params["theta_jn"])
308
- if self.spin_zero==False:
309
- print("a_1 = ", self.gw_param_samplers["a_1"])
397
+ if self.spin_zero is False:
398
+ print(f"a_1 = '{self.gw_param_samplers['a_1']}'")
310
399
  print("a_1_params = ", self.gw_param_samplers_params["a_1"])
311
- print("a_2 = ", self.gw_param_samplers["a_2"])
400
+ print(f"a_2 = '{self.gw_param_samplers['a_2']}'")
312
401
  print("a_2_params = ", self.gw_param_samplers_params["a_2"])
313
- if self.spin_precession==True:
314
- print("tilt_1 = ", self.gw_param_samplers["tilt_1"])
402
+ if self.spin_precession is True:
403
+ print(f"tilt_1 = '{self.gw_param_samplers['tilt_1']}'")
315
404
  print("tilt_1_params = ", self.gw_param_samplers_params["tilt_1"])
316
- print("tilt_2 = ", self.gw_param_samplers["tilt_2"])
405
+ print(f"tilt_2 = '{self.gw_param_samplers['tilt_2']}'")
317
406
  print("tilt_2_params = ", self.gw_param_samplers_params["tilt_2"])
318
- print("phi_12 = ", self.gw_param_samplers["phi_12"])
407
+ print(f"phi_12 = '{self.gw_param_samplers['phi_12']}'")
319
408
  print("phi_12_params = ", self.gw_param_samplers_params["phi_12"])
320
- print("phi_jl = ", self.gw_param_samplers["phi_jl"])
409
+ print(f"phi_jl = '{self.gw_param_samplers['phi_jl']}'")
321
410
  print("phi_jl_params = ", self.gw_param_samplers_params["phi_jl"])
322
411
 
323
412
 
@@ -409,10 +498,7 @@ class GWRATES(CBCSourceParameterDistribution):
409
498
 
410
499
  def class_initialization(self, params=None):
411
500
  """
412
- Function to initialize the parent classes. List of relevant initialized instances, \n
413
- 1. self.sample_source_redshift
414
- 2. self.sample_gw_parameters
415
- 3. self.normalization_pdf_z
501
+ Function to initialize the parent classes.
416
502
 
417
503
  Parameters
418
504
  ----------
@@ -422,10 +508,6 @@ class GWRATES(CBCSourceParameterDistribution):
422
508
 
423
509
  # initialization of CompactBinaryPopulation class
424
510
  # it also initializes the CBCSourceRedshiftDistribution class
425
- # list of relevant initialized instances,
426
- # 1. self.sample_source_redshift
427
- # 2. self.sample_gw_parameters
428
- # 3. self.normalization_pdf_z
429
511
  input_params = dict(
430
512
  z_min=self.z_min,
431
513
  z_max=self.z_max,
@@ -433,10 +515,9 @@ class GWRATES(CBCSourceParameterDistribution):
433
515
  event_type=self.event_type,
434
516
  source_priors=None,
435
517
  source_priors_params=None,
436
-
437
518
  spin_zero=True,
438
519
  spin_precession=False,
439
- interpolator_directory=self.interpolator_directory,
520
+ directory=self.interpolator_directory,
440
521
  create_new_interpolator=False,
441
522
  )
442
523
  if params:
@@ -455,7 +536,7 @@ class GWRATES(CBCSourceParameterDistribution):
455
536
  cosmology=input_params["cosmology"],
456
537
  spin_zero=input_params["spin_zero"],
457
538
  spin_precession=input_params["spin_precession"],
458
- directory=input_params["interpolator_directory"],
539
+ directory=input_params["directory"],
459
540
  create_new_interpolator=input_params["create_new_interpolator"],
460
541
  )
461
542
 
@@ -464,7 +545,7 @@ class GWRATES(CBCSourceParameterDistribution):
464
545
 
465
546
  def gwsnr_intialization(self, params=None):
466
547
  """
467
- Function to initialize the gwsnr class
548
+ Function to initialize the GWSNR class from the `gwsnr` package.
468
549
 
469
550
  Parameters
470
551
  ----------
@@ -490,10 +571,12 @@ class GWRATES(CBCSourceParameterDistribution):
490
571
  ifos=None,
491
572
  interpolator_dir=self.interpolator_directory,
492
573
  create_new_interpolator=False,
493
- gwsnr_verbose=True,
574
+ gwsnr_verbose=False,
494
575
  multiprocessing_verbose=True,
495
576
  mtot_cut=True,
496
577
  )
578
+ # if self.event_type == "BNS":
579
+ # input_params["mtot_max"]= 18.
497
580
  if params:
498
581
  for key, value in params.items():
499
582
  if key in input_params:
@@ -574,7 +657,7 @@ class GWRATES(CBCSourceParameterDistribution):
574
657
  self, size=None, resume=False, save_batch=False, output_jsonfile=None,
575
658
  ):
576
659
  """
577
- Function to generate gw GW source parameters. This function also stores the parameters in json file.
660
+ Function to generate gw GW source parameters. This function calls the gw_sampling_routine function to generate the parameters in batches. The generated parameters are stored in a json file; and if save_batch=True, it keeps updating the file in batches.
578
661
 
579
662
  Parameters
580
663
  ----------
@@ -588,13 +671,12 @@ class GWRATES(CBCSourceParameterDistribution):
588
671
  if True, the function will save the parameters in batches. if False, the function will save all the parameters at the end of sampling. save_batch=False is faster.
589
672
  output_jsonfile : `str`
590
673
  json file name for storing the parameters.
591
- default output_jsonfile = 'gw_params.json'.
674
+ default output_jsonfile = 'gw_params.json'. Note that this file will be stored in the self.ler_directory.
592
675
 
593
676
  Returns
594
677
  ----------
595
678
  gw_param : `dict`
596
- dictionary of gw GW source parameters.
597
- gw_param.keys() = ['zs', 'geocent_time', 'ra', 'dec', 'phase', 'psi', 'theta_jn', 'luminosity_distance', 'mass_1_source', 'mass_2_source', 'mass_1', 'mass_2', 'optimal_snr_net', 'L1', 'H1', 'V1']
679
+ dictionary of gw GW source parameters. Refer to :attr:`~gw_param` for details.
598
680
 
599
681
  Examples
600
682
  ----------
@@ -643,25 +725,26 @@ class GWRATES(CBCSourceParameterDistribution):
643
725
 
644
726
  def gw_sampling_routine(self, size, output_jsonfile, resume=False, save_batch=True):
645
727
  """
646
- Function to generate gw GW source parameters. This function also stores the parameters in json file.
728
+ Function to generate GW source parameters. This function also stores the parameters in json file in the current batch if save_batch=True.
647
729
 
648
730
  Parameters
649
731
  ----------
650
732
  size : `int`
651
733
  number of samples.
652
734
  default size = 100000.
653
- resume : `bool`
654
- resume = False (default) or True.
655
- if True, the function will resume from the last batch.
656
735
  output_jsonfile : `str`
657
736
  json file name for storing the parameters.
658
- default output_jsonfile = 'gw_params.json'.
737
+ default output_jsonfile = 'gw_params.json'. Note that this file will be stored in the self.ler_directory.
738
+ resume : `bool`
739
+ resume = False (default) or True.
740
+ if True, it appends the new samples to the existing json file.
741
+ save_batch : `bool`
742
+ if True, the function will save the parameters in batches. if False, the function will save all the parameters at the end of sampling. save_batch=False is faster.
659
743
 
660
744
  Returns
661
745
  ----------
662
746
  gw_param : `dict`
663
- dictionary of gw GW source parameters.
664
- gw_param.keys() = ['zs', 'geocent_time', 'ra', 'dec', 'phase', 'psi', 'theta_jn', 'luminosity_distance', 'mass_1_source', 'mass_2_source', 'mass_1', 'mass_2', 'optimal_snr_net', 'L1', 'H1', 'V1']
747
+ dictionary of gw GW source parameters. Refer to :attr:`~gw_param` for details.
665
748
  """
666
749
 
667
750
  # get gw params
@@ -694,13 +777,19 @@ class GWRATES(CBCSourceParameterDistribution):
694
777
  self,
695
778
  gw_param=None,
696
779
  snr_threshold=8.0,
780
+ pdet_threshold=0.5,
697
781
  output_jsonfile=None,
698
782
  detectability_condition="step_function",
699
783
  snr_recalculation=False,
700
- threshold_snr_recalculation=6.0,
784
+ snr_threshold_recalculation=[4, 20],
701
785
  ):
702
786
  """
703
- Function to calculate the gw rate. This function also stores the parameters of the detectable events in json file.
787
+ Function to calculate the GW rate. This function also stores the parameters of the detectable events in json file. There are two conditions for detectability: 'step_function' and 'pdet'.
788
+
789
+ 1. 'step_function': If two images have SNR>8.0, then the event is detectable. This is a step function. This is with the assumption that SNR function is provided and not None.
790
+ 2. 'pdet':
791
+ i) If self.pdet is None and self.snr is not None, then it will calculate the pdet from the snr. There is no hard cut for this pdet and can have value ranging from 0 to 1 near the threshold.
792
+ ii) If self.pdet is not None, then it will use the generated pdet.
704
793
 
705
794
  Parameters
706
795
  ----------
@@ -710,6 +799,9 @@ class GWRATES(CBCSourceParameterDistribution):
710
799
  snr_threshold : `float`
711
800
  threshold for detection signal to noise ratio.
712
801
  e.g. snr_threshold = 8.
802
+ pdet_threshold : `float`
803
+ threshold for detection probability.
804
+ e.g. pdet_threshold = 0.5.
713
805
  output_jsonfile : `str`
714
806
  json file name for storing the parameters of the detectable events.
715
807
  default output_jsonfile = 'gw_params_detectable.json'.
@@ -718,152 +810,409 @@ class GWRATES(CBCSourceParameterDistribution):
718
810
  default detectability_condition = 'step_function'.
719
811
  other options are 'pdet'.
720
812
  snr_recalculation : `bool`
721
- if True, the SNR of centain events (snr>threshold_snr_recalculation)will be recalculate with 'inner product'. This is useful when the snr is calculated with 'ann' method.
813
+ if True, the SNR of centain events (snr>snr_threshold_recalculation)will be recalculate with 'inner-product' method. This is useful when the snr is calculated with 'ann' method.
722
814
  default snr_recalculation = False.
723
- threshold_snr_recalculation : `float`
724
- threshold for recalculation of detection signal to noise ratio.
815
+ snr_threshold_recalculation : `list`
816
+ lower and upper threshold for recalculation of detection signal to noise ratio.
817
+ default snr_threshold_recalculation = [4, 20].
725
818
 
726
819
  Returns
727
820
  ----------
728
821
  total_rate : `float`
729
822
  total gw rate (Mpc^-3 yr^-1).
730
823
  gw_param : `dict`
731
- dictionary of gw GW source parameters of the detectable events.
732
- gw_param.keys() = ['zs', 'geocent_time', 'ra', 'dec', 'phase', 'psi', 'theta_jn', 'luminosity_distance', 'mass_1_source', 'mass_2_source', 'mass_1', 'mass_2', 'optimal_snr_net', 'L1', 'H1', 'V1']
824
+ dictionary of gw GW source parameters of the detectable events. Refer to :attr:`~gw_param` for details.
733
825
 
734
826
  Examples
735
827
  ----------
736
828
  >>> from ler.rates import GWRATES
737
829
  >>> ler = GWRATES()
830
+ >>> ler.gw_cbc_statistics();
738
831
  >>> total_rate, gw_param = ler.gw_rate()
739
832
  """
740
833
 
741
- # call self.json_file_names["gwrates_param"] and for adding the final results
742
- data = load_json(self.ler_directory+"/"+self.json_file_names["gwrates_param"])
834
+ gw_param = self._load_param(gw_param)
835
+ total_events = len(gw_param["zs"])
836
+
837
+ # below is use when the snr is calculated with 'ann' method of `gwsnr`
838
+ if snr_recalculation:
839
+ gw_param = self._recalculate_snr(gw_param, snr_threshold_recalculation)
840
+
841
+ # find index of detectable events
842
+ idx_detectable = self._find_detectable_index(gw_param, snr_threshold, pdet_threshold, detectability_condition)
843
+
844
+ detectable_events = np.sum(idx_detectable)
845
+ # montecarlo integration
846
+ # The total rate R = norm <Theta(rho-rhoc)>
847
+ total_rate = self.rate_function(detectable_events, total_events)
848
+
849
+ # store all detectable params in json file
850
+ self._save_detectable_params(output_jsonfile, gw_param, idx_detectable, key_file_name="gw_param_detectable", nan_to_num=False, verbose=True, replace_jsonfile=True)
851
+
852
+ # append ler_param and save it
853
+ self._append_ler_param(total_rate, detectability_condition)
743
854
 
744
- # get gw params from json file if not provided
745
- if gw_param is None:
746
- gw_param = self.json_file_names["gw_param"]
747
- if type(gw_param) == str:
748
- self.json_file_names["gw_param"] = gw_param
749
- path_ = self.ler_directory+"/"+gw_param
750
- print(f"getting gw_params from json file {path_}...")
751
- gw_param = get_param_from_json(self.ler_directory+"/"+gw_param)
855
+ return total_rate, gw_param
856
+
857
+ def _load_param(self, param):
858
+ """
859
+ Helper function to load or copy GW parameters.
860
+
861
+ Parameters
862
+ ----------
863
+ param : `dict` or `str`
864
+ dictionary of GW parameters or json file name.
865
+
866
+ Returns
867
+ ----------
868
+ param : `dict`
869
+ dictionary of GW parameters.
870
+ """
871
+
872
+ if param is None:
873
+ param = self.json_file_names["gw_param"]
874
+ if isinstance(param, str):
875
+ path_ = self.ler_directory + "/" + param
876
+ print(f"Getting GW parameters from json file {path_}...")
877
+ return get_param_from_json(path_)
752
878
  else:
753
- print("using provided gw_param dict...")
754
- # store all params in json file self.json_file_names["gw_param"]
755
- gw_param = gw_param.copy()
879
+ print("Using provided {param_type} dict...")
880
+ return param.copy()
756
881
 
757
- # recalculate snr if required
758
- # this ensures that the snr is recalculated for the detectable events
759
- # with inner product
760
- total_events = len(gw_param["zs"])
761
- if snr_recalculation:
762
- # select only above centain snr threshold
763
- param = gw_param["optimal_snr_net"]
764
- idx_detectable = param > threshold_snr_recalculation
765
- # reduce the size of the dict
766
- for key, value in gw_param.items():
767
- gw_param[key] = value[idx_detectable]
768
- # recalculate more accurate snrs
769
- snrs = self.snr_bilby(gw_param_dict=gw_param)
770
- gw_param.update(snrs)
882
+ def _recalculate_snr(self, gw_param, snr_threshold_recalculation):
883
+ """
884
+ Recalculates SNR for events where the initial SNR is above a given threshold.
771
885
 
886
+ Parameters
887
+ ----------
888
+ gw_param : `dict`
889
+ dictionary of GW source parameters.
890
+ snr_threshold_recalculation : `list`
891
+ lower and upper threshold for recalculation of detection signal to noise ratio.
892
+ default snr_threshold_recalculation = [4, 20].
893
+
894
+ Returns
895
+ ----------
896
+ gw_param : `dict`
897
+ dictionary of GW source parameters.
898
+ """
899
+
900
+ snr_param = gw_param["optimal_snr_net"]
901
+ idx_detectable = (snr_param > snr_threshold_recalculation[0]) & (snr_param < snr_threshold_recalculation[1])
902
+ # reduce the size of the dict
903
+ for key, value in gw_param.items():
904
+ gw_param[key] = value[idx_detectable]
905
+ # recalculate more accurate snrs
906
+ snrs = self.snr_bilby(gw_param_dict=gw_param)
907
+ gw_param.update(snrs)
908
+ return gw_param
909
+
910
+ def _find_detectable_index(self, gw_param, snr_threshold, pdet_threshold, detectability_condition):
911
+ """
912
+ Find the index of detectable events based on SNR or p_det.
913
+
914
+ Parameters
915
+ ----------
916
+ gw_param : `dict`
917
+ dictionary of GW source parameters.
918
+ snr_threshold : `float`
919
+ threshold for detection signal to noise ratio.
920
+ pdet_threshold : `float`
921
+ threshold for detection probability.
922
+ detectability_condition : `str`
923
+ detectability condition.
924
+ default detectability_condition = 'step_function'.
925
+ other options are 'pdet'.
926
+
927
+ Returns
928
+ ----------
929
+ idx_detectable : `numpy.ndarray`
930
+ index of detectable events.
931
+ """
932
+
772
933
  if self.snr:
773
934
  if "optimal_snr_net" not in gw_param:
774
- raise ValueError("'optimal_snr_net' not in gw parm dict provided")
935
+ raise ValueError("'optimal_snr_net' not in gw param dict provided")
775
936
  if detectability_condition == "step_function":
776
937
  print("given detectability_condition == 'step_function'")
777
938
  param = gw_param["optimal_snr_net"]
778
939
  threshold = snr_threshold
779
-
780
-
781
940
  elif detectability_condition == "pdet":
782
941
  print("given detectability_condition == 'pdet'")
783
942
  param = 1 - norm.cdf(snr_threshold - gw_param["optimal_snr_net"])
784
943
  gw_param["pdet_net"] = param
785
- threshold = 0.5
944
+ threshold = pdet_threshold
786
945
  elif self.pdet:
787
946
  if "pdet_net" in gw_param:
788
947
  print("given detectability_condition == 'pdet'")
789
948
  param = gw_param["pdet_net"]
790
- threshold = 0.5
949
+ threshold = pdet_threshold
791
950
  else:
792
- raise ValueError("'pdet_net' not in gw parm dict provided")
951
+ raise ValueError("'pdet_net' not in gw param dict provided")
793
952
 
794
953
  idx_detectable = param > threshold
795
- detectable_events = np.sum(idx_detectable)
796
- # montecarlo integration
797
- # The total rate R = norm <Theta(rho-rhoc)>
798
- total_rate = self.normalization_pdf_z * detectable_events / total_events
799
- print(f"total gw rate (yr^-1) (with step function): {total_rate}")
800
- print(f"number of simulated gw detectable events: {detectable_events}")
801
- print(f"number of all simulated gw events: {total_events}")
802
-
954
+ return idx_detectable
955
+
956
+ def rate_function(self, detectable_size, total_size, verbose=True):
957
+ """
958
+ General helper function to calculate the rate for GW events.
959
+
960
+ Parameters
961
+ ----------
962
+ detectable_size : `int`
963
+ number of detectable events.
964
+ total_size : `int`
965
+ total number of events.
966
+ param_type : `str`
967
+ type of parameters.
968
+
969
+ Returns
970
+ ----------
971
+ rate : `float`
972
+ rate of the events.
973
+
974
+ Examples
975
+ ----------
976
+ >>> from ler.rates import LeR
977
+ >>> ler = LeR()
978
+ >>> rate = ler.rate_function(detectable_size=100, total_size=1000)
979
+ """
980
+
981
+ normalization = self.normalization_pdf_z
982
+ rate = normalization * detectable_size / total_size
983
+
984
+ if verbose:
985
+ print(f"total GW event rate (yr^-1): {rate}")
986
+ print(f"number of simulated GW detectable events: {detectable_size}")
987
+ print(f"number of simulated all GW events: {total_size}")
988
+
989
+ return rate
990
+
991
+ def _save_detectable_params(self,
992
+ output_jsonfile,
993
+ param,
994
+ idx_detectable,
995
+ key_file_name="gw_param_detectable",
996
+ nan_to_num=False,
997
+ verbose=True,
998
+ replace_jsonfile=True,
999
+ ):
1000
+ """
1001
+ Helper function to save the detectable parameters in json file.
1002
+
1003
+ Parameters
1004
+ ----------
1005
+ output_jsonfile : `str`
1006
+ json file name for storing the parameters of the detectable events. This is stored in the self.ler_directory.
1007
+ param : `dict`
1008
+ dictionary of GW source parameters.
1009
+ idx_detectable : `numpy.ndarray`
1010
+ index of detectable events.
1011
+ key_file_name : `str`
1012
+ key name for the json file to be added in self.json_file_names.
1013
+ nan_to_num : `bool`
1014
+ if True, it will replace nan with 0.
1015
+ default nan_to_num = False.
1016
+ verbose : `bool`
1017
+ if True, it will print the path of the json file.
1018
+ default verbose = True.
1019
+ replace_jsonfile : `bool`
1020
+ if True, it will replace the json file. If False, it will append the json file.
1021
+ """
1022
+
803
1023
  # store all detectable params in json file
804
- for key, value in gw_param.items():
805
- gw_param[key] = value[idx_detectable]
1024
+ if nan_to_num:
1025
+ for key, value in param.items():
1026
+ param[key] = np.nan_to_num(value[idx_detectable])
1027
+ else:
1028
+ for key, value in param.items():
1029
+ param[key] = value[idx_detectable]
806
1030
 
807
1031
  # store all detectable params in json file
808
1032
  if output_jsonfile is None:
809
- output_jsonfile = self.json_file_names["gw_param_detectable"]
1033
+ output_jsonfile = self.json_file_names[key_file_name]
810
1034
  else:
811
- self.json_file_names["gw_param_detectable"] = output_jsonfile
812
- path_ = self.ler_directory+"/"+output_jsonfile
813
- print(f"storing detectable gw params in {path_}")
814
- append_json(self.ler_directory+"/"+output_jsonfile, gw_param, replace=True)
1035
+ self.json_file_names[key_file_name] = output_jsonfile
1036
+
1037
+ output_path = self.ler_directory+"/"+output_jsonfile
1038
+ if verbose:
1039
+ print(f"storing detectable params in {output_path}")
1040
+ append_json(output_path, param, replace=replace_jsonfile)
1041
+
1042
+ def _append_ler_param(self, total_rate, detectability_condition):
1043
+ """
1044
+ Helper function to append the final results, total_rate, in the json file.
815
1045
 
1046
+ Parameters
1047
+ ----------
1048
+ total_rate : `float`
1049
+ total rate.
1050
+ detectability_condition : `str`
1051
+ detectability condition.
1052
+ """
1053
+
1054
+ data = load_json(self.ler_directory+"/"+self.json_file_names["gwrates_params"])
816
1055
  # write the results
817
- data['detectable_gw_rate_per_year'] = total_rate
1056
+ data["detectable_gw_rate_per_year"] = total_rate
818
1057
  data["detectability_condition"] = detectability_condition
819
- append_json(self.ler_directory+"/"+self.json_file_names["gwrates_param"], data, replace=True)
820
-
821
- return total_rate, gw_param
1058
+ append_json(self.ler_directory+"/"+self.json_file_names["gwrates_params"], data, replace=True)
822
1059
 
823
1060
  def selecting_n_gw_detectable_events(
824
1061
  self,
825
1062
  size=100,
826
1063
  batch_size=None,
827
1064
  snr_threshold=8.0,
1065
+ pdet_threshold=0.5,
828
1066
  resume=False,
829
1067
  output_jsonfile="gw_params_n_detectable.json",
830
1068
  meta_data_file="meta_gw.json",
1069
+ detectability_condition="step_function",
831
1070
  trim_to_size=True,
1071
+ snr_recalculation=False,
1072
+ snr_threshold_recalculation=[4, 12],
832
1073
  ):
833
1074
  """
834
- Function to select n gw detectable events.
1075
+ Function to generate n GW detectable events. This fuction samples the GW parameters and save only the detectable events in json file. It also records metadata in the JSON file, which includes the total number of events and the cumulative rate of events. This functionality is particularly useful for generating a fixed or large number of detectable events until the event rates stabilize.
835
1076
 
836
1077
  Parameters
837
1078
  ----------
838
1079
  size : `int`
839
1080
  number of samples to be selected.
840
1081
  default size = 100.
1082
+ batch_size : `int`
1083
+ batch size for sampling.
1084
+ default batch_size = 50000.
841
1085
  snr_threshold : `float`
842
1086
  threshold for detection signal to noise ratio.
843
1087
  e.g. snr_threshold = 8.
1088
+ pdet_threshold : `float`
1089
+ threshold for detection probability.
1090
+ default pdet_threshold = 0.5.
844
1091
  resume : `bool`
845
- if True, it will resume the sampling from the last batch.
846
- default resume = False.
1092
+ resume = False (default) or True.
1093
+ if True, the function will resume from the last batch.
847
1094
  output_jsonfile : `str`
848
- json file name for storing the parameters.
849
- default output_jsonfile = 'gw_params_detectable.json'.
1095
+ json file name for storing the parameters of the detectable events.
1096
+ default output_jsonfile = 'n_gw_param_detectable.json'.
1097
+ meta_data_file : `str`
1098
+ json file name for storing the metadata.
1099
+ default meta_data_file = 'meta_gw.json'.
1100
+ detectability_condition : `str`
1101
+ detectability condition.
1102
+ default detectability_condition = 'step_function'.
1103
+ other options are 'pdet'.
1104
+ trim_to_size : `bool`
1105
+ if True, the final result will be trimmed to size.
1106
+ default trim_to_size = True.
1107
+ snr_recalculation : `bool`
1108
+ if True, the SNR of centain events (snr>snr_threshold_recalculation)will be recalculate with 'inner-product' method. This is useful when the snr is calculated with 'ann' method of `gwsnr`.
1109
+ default snr_recalculation = False.
1110
+ snr_threshold_recalculation : `list`
1111
+ lower and upper threshold for recalculation of detection signal to noise ratio.
1112
+ default snr_threshold_recalculation = [4, 12].
850
1113
 
851
1114
  Returns
852
1115
  ----------
853
1116
  param_final : `dict`
854
- dictionary of gw GW source parameters of the detectable events.
855
- param_final.keys() = ['zs', 'geocent_time', 'ra', 'dec', 'phase', 'psi', 'theta_jn', 'luminosity_distance', 'mass_1_source', 'mass_2_source', 'mass_1', 'mass_2', 'optimal_snr_net', 'L1', 'H1', 'V1']
1117
+ dictionary of gw GW source parameters of the detectable events. Refer to :attr:`~gw_param` for details.
856
1118
 
857
1119
  Examples
858
1120
  ----------
859
- >>> from ler.rates import GWRATES
860
- >>> ler = GWRATES()
861
- >>> param_final = ler.selecting_n_gw_detectable_events(size=500)
1121
+ >>> from ler.rates import LeR
1122
+ >>> ler = LeR()
1123
+ >>> gw_param = ler.selecting_n_gw_detectable_events(size=100)
1124
+ """
1125
+
1126
+ # initial setup
1127
+ n, events_total, output_path, meta_data_path, buffer_file = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
1128
+
1129
+ # loop until n samples are collected
1130
+ while n < size:
1131
+ # disable print statements
1132
+ with contextlib.redirect_stdout(None):
1133
+ self.dict_buffer = None # this is used to store the sampled gw_param in batches when running the sampling_routine
1134
+ gw_param = self.gw_sampling_routine(
1135
+ size=batch_size, output_jsonfile=buffer_file, save_batch=False,resume=False
1136
+ )
1137
+
1138
+ total_events_in_this_iteration = len(gw_param["zs"])
1139
+ # below is use when the snr is calculated with 'ann' method of `gwsnr`
1140
+ if snr_recalculation:
1141
+ # select only above centain snr threshold
1142
+ gw_param = self._recalculate_snr(gw_param, snr_threshold_recalculation)
1143
+
1144
+ # find index of detectable events
1145
+ idx_detectable = self._find_detectable_index(gw_param, snr_threshold, pdet_threshold, detectability_condition)
1146
+
1147
+ # store all params in json file
1148
+ self._save_detectable_params(output_jsonfile, gw_param, idx_detectable, key_file_name="n_gw_detectable_events", nan_to_num=False, verbose=False, replace_jsonfile=False)
1149
+
1150
+ n += np.sum(idx_detectable)
1151
+ events_total += total_events_in_this_iteration
1152
+ total_rate = self.rate_function(n, events_total, verbose=False)
1153
+
1154
+ # bookmark
1155
+ self._append_meta_data(meta_data_path, n, events_total, total_rate)
1156
+
1157
+ print(f"stored detectable gw params in {output_path}")
1158
+ print(f"stored meta data in {meta_data_path}")
1159
+
1160
+ if trim_to_size:
1161
+ param_final, total_rate = self._trim_results_to_size(size, output_path, meta_data_path)
1162
+ else:
1163
+ param_final = get_param_from_json(output_path)
1164
+
1165
+ # call self.json_file_names["ler_param"] and for adding the final results
1166
+ data = load_json(self.ler_directory+"/"+self.json_file_names["gwrates_params"])
1167
+ # write the results
1168
+ try:
1169
+ data["detectable_gw_rate_per_year"] = total_rate
1170
+ data["detectability_condition"] = detectability_condition
1171
+ except:
1172
+ meta = get_param_from_json(meta_data_path)
1173
+ data["detectable_gw_rate_per_year"] = meta["total_rate"][-1]
1174
+ data["detectability_condition"] = detectability_condition
1175
+
1176
+ append_json(self.ler_directory+"/"+self.json_file_names["gwrates_params"], data, replace=True)
1177
+
1178
+ return param_final
1179
+
1180
+
1181
+ def _initial_setup_for_n_event_selection(self, meta_data_file, output_jsonfile, resume, batch_size):
1182
+ """Helper function for selecting_n_gw_detectable_events and selecting_n_lensed_detectable_events functions.
1183
+
1184
+ Parameters
1185
+ ----------
1186
+ meta_data_file : `str`
1187
+ json file name for storing the metadata.
1188
+ output_jsonfile : `str`
1189
+ json file name for storing the parameters of the detectable events.
1190
+ resume : `bool`
1191
+ resume = False (default) or True.
1192
+ if True, the function will resume from the last batch.
1193
+ batch_size : `int`
1194
+ batch size for sampling.
1195
+ default batch_size = 50000.
1196
+
1197
+ Returns
1198
+ ----------
1199
+ n : `int`
1200
+ iterator.
1201
+ events_total : `int`
1202
+ total number of events.
1203
+ output_path : `str`
1204
+ path to the output json file.
1205
+ meta_data_path : `str`
1206
+ path to the metadata json file.
1207
+ buffer_file : `str`
1208
+ path to the buffer json file.
862
1209
  """
863
1210
 
864
1211
  meta_data_path = self.ler_directory+"/"+meta_data_file
865
1212
  output_path = self.ler_directory+"/"+output_jsonfile
866
-
1213
+ if meta_data_path==output_path:
1214
+ raise ValueError("meta_data_file and output_jsonfile cannot be same.")
1215
+
867
1216
  if batch_size is None:
868
1217
  batch_size = self.batch_size
869
1218
  else:
@@ -877,66 +1226,102 @@ class GWRATES(CBCSourceParameterDistribution):
877
1226
  if os.path.exists(meta_data_path):
878
1227
  os.remove(meta_data_path)
879
1228
  else:
1229
+ # get sample size as size from json file
880
1230
  if os.path.exists(output_path):
881
- # get sample size as nsamples from json file
882
1231
  param_final = get_param_from_json(output_path)
883
- n = load_json(meta_data_path)["detectable_events"][-1]
1232
+ n = len(param_final["zs"])
884
1233
  events_total = load_json(meta_data_path)["events_total"][-1]
885
1234
  del param_final
886
1235
  else:
887
1236
  n = 0
888
1237
  events_total = 0
889
1238
 
890
- buffer_file = self.ler_directory+"/"+"gw_params_buffer.json"
1239
+ buffer_file = "params_buffer.json"
891
1240
  print("collected number of detectable events = ", n)
892
- # loop until n samples are collected
893
- while n < size:
894
- # disable print statements
895
- with contextlib.redirect_stdout(None):
896
- self.dict_buffer = None
897
- gw_param = self.gw_sampling_routine(
898
- size=batch_size, output_jsonfile=buffer_file, save_batch=False, resume=False
899
- )
900
-
901
- # get snr
902
- snr = gw_param["optimal_snr_net"]
903
- # index of detectable events
904
- idx = snr > snr_threshold
905
1241
 
906
- # store all params in json file
907
- for key, value in gw_param.items():
908
- gw_param[key] = value[idx]
909
- append_json(file_name=output_path, new_dictionary=gw_param, replace=False)
910
-
911
- n += np.sum(idx)
912
- events_total += len(idx)
913
- total_rate = self.normalization_pdf_z * n / events_total
1242
+ return n, events_total, output_path, meta_data_path, buffer_file
914
1243
 
915
- # save meta data
916
- meta_data = dict(events_total=[events_total], detectable_events=[float(n)], total_rate=[total_rate])
917
- if os.path.exists(meta_data_path):
918
- try:
919
- append_json(file_name=meta_data_path, new_dictionary=meta_data, replace=False)
920
- except:
921
- append_json(file_name=meta_data_path, new_dictionary=meta_data, replace=True)
922
- else:
923
- append_json(file_name=meta_data_path, new_dictionary=meta_data, replace=True)
924
-
925
- print("collected number of detectable events = ", n)
926
- print("total number of events = ", events_total)
927
- print(f"total gw rate (yr^-1): {total_rate}")
1244
+ def _trim_results_to_size(self, size, output_path, meta_data_path):
1245
+ """
1246
+ Helper function of 'selecting_n_gw_detectable_events' and 'selecting_n_lensed_detectable_events' functions. Trims the data in the output file to the specified size and updates the metadata accordingly.
928
1247
 
929
- print(f"storing detectable gw params in {output_path}")
1248
+ Parameters
1249
+ ----------
1250
+ size : `int`
1251
+ number of samples to be selected.
1252
+ output_path : `str`
1253
+ path to the output json file.
1254
+ meta_data_path : `str`
1255
+ path to the metadata json file.
930
1256
 
931
- if trim_to_size:
932
- # trim the final param dictionary
933
- print(f"\n trmming final result to size={size}")
934
- param_final = get_param_from_json(output_path)
935
- # trim the final param dictionary, randomly, without repeating
936
- for key, value in param_final.items():
937
- param_final[key] = param_final[key][:size]
1257
+ Returns
1258
+ ----------
1259
+ param_final : `dict`
1260
+ dictionary of GW source parameters of the detectable events. Refer to :attr:`~gw_param`
1261
+ new_total_rate : `float`
1262
+ total rate (Mpc^-3 yr^-1).
1263
+ """
938
1264
 
1265
+ print(f"\n trmming final result to size={size}")
1266
+ param_final = get_param_from_json(output_path)
1267
+ # randomly select size number of samples
1268
+ len_ = len(list(param_final.values())[0])
1269
+ idx = np.random.choice(len_, size, replace=False)
1270
+ # trim the final param dictionary, randomly, without repeating
1271
+ for key, value in param_final.items():
1272
+ param_final[key] = value[idx]
1273
+
1274
+ # change meta data
1275
+ meta_data = load_json(meta_data_path)
1276
+ old_events_total = meta_data["events_total"][-1]
1277
+ old_detectable_events = meta_data["detectable_events"][-1]
1278
+
1279
+ # adjust the meta data
1280
+ # following is to keep rate the same
1281
+ new_events_total = np.round(size*old_events_total/old_detectable_events)
1282
+ new_total_rate = self.rate_function(size, new_events_total, verbose=False)
1283
+ meta_data["events_total"][-1] = new_events_total
1284
+ meta_data["detectable_events"][-1] = size
1285
+ meta_data["total_rate"][-1] = new_total_rate
1286
+
1287
+ print("collected number of detectable events = ", size)
1288
+ print("total number of events = ", new_events_total)
1289
+ print(f"total GW event rate (yr^-1): {new_total_rate}")
1290
+
1291
+ # save the meta data
1292
+ append_json(meta_data_path, meta_data, replace=True)
939
1293
  # save the final param dictionary
940
1294
  append_json(output_path, param_final, replace=True)
941
1295
 
942
- return param_final
1296
+ return param_final, new_total_rate
1297
+
1298
+ def _append_meta_data(self, meta_data_path, n, events_total, total_rate):
1299
+ """
1300
+ Helper function for appending meta data json file.
1301
+
1302
+ Parameters
1303
+ ----------
1304
+ meta_data_path : `str`
1305
+ path to the metadata json file.
1306
+ n : `int`
1307
+ iterator.
1308
+ events_total : `int`
1309
+ total number of events.
1310
+ total_rate : `float`
1311
+ total rate (Mpc^-3 yr^-1).
1312
+ """
1313
+
1314
+ # save meta data
1315
+ meta_data = dict(events_total=[events_total], detectable_events=[float(n)], total_rate=[total_rate])
1316
+
1317
+ if os.path.exists(meta_data_path):
1318
+ try:
1319
+ append_json(meta_data_path, meta_data, replace=False)
1320
+ except:
1321
+ append_json(meta_data_path, meta_data, replace=True)
1322
+ else:
1323
+ append_json(meta_data_path, meta_data, replace=True)
1324
+
1325
+ print("collected number of detectable events = ", n)
1326
+ print("total number of events = ", events_total)
1327
+ print(f"total rate (yr^-1): {total_rate}")