ler 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ler might be problematic. Click here for more details.
- ler/__init__.py +26 -26
- ler/gw_source_population/__init__.py +1 -0
- ler/gw_source_population/cbc_source_parameter_distribution.py +1073 -815
- ler/gw_source_population/cbc_source_redshift_distribution.py +618 -294
- ler/gw_source_population/jit_functions.py +484 -9
- ler/gw_source_population/sfr_with_time_delay.py +107 -0
- ler/image_properties/image_properties.py +41 -12
- ler/image_properties/multiprocessing_routine.py +5 -209
- ler/lens_galaxy_population/__init__.py +2 -0
- ler/lens_galaxy_population/epl_shear_cross_section.py +0 -0
- ler/lens_galaxy_population/jit_functions.py +101 -9
- ler/lens_galaxy_population/lens_galaxy_parameter_distribution.py +813 -881
- ler/lens_galaxy_population/lens_param_data/density_profile_slope_sl.txt +5000 -0
- ler/lens_galaxy_population/lens_param_data/external_shear_sl.txt +2 -0
- ler/lens_galaxy_population/lens_param_data/number_density_zl_zs.txt +48 -0
- ler/lens_galaxy_population/lens_param_data/optical_depth_epl_shear_vd_ewoud.txt +48 -0
- ler/lens_galaxy_population/mp copy.py +554 -0
- ler/lens_galaxy_population/mp.py +736 -138
- ler/lens_galaxy_population/optical_depth.py +2248 -616
- ler/rates/__init__.py +1 -2
- ler/rates/gwrates.py +126 -72
- ler/rates/ler.py +218 -111
- ler/utils/__init__.py +2 -0
- ler/utils/function_interpolation.py +322 -0
- ler/utils/gwsnr_training_data_generator.py +233 -0
- ler/utils/plots.py +1 -1
- ler/utils/test.py +1078 -0
- ler/utils/utils.py +492 -125
- {ler-0.4.2.dist-info → ler-0.4.3.dist-info}/METADATA +30 -17
- ler-0.4.3.dist-info/RECORD +34 -0
- {ler-0.4.2.dist-info → ler-0.4.3.dist-info}/WHEEL +1 -1
- ler/rates/ler copy.py +0 -2097
- ler-0.4.2.dist-info/RECORD +0 -25
- {ler-0.4.2.dist-info → ler-0.4.3.dist-info/licenses}/LICENSE +0 -0
- {ler-0.4.2.dist-info → ler-0.4.3.dist-info}/top_level.txt +0 -0
ler/rates/__init__.py
CHANGED
ler/rates/gwrates.py
CHANGED
|
@@ -6,6 +6,8 @@ This module contains the main class for calculating the rates of detectable grav
|
|
|
6
6
|
import os
|
|
7
7
|
import warnings
|
|
8
8
|
warnings.filterwarnings("ignore")
|
|
9
|
+
import logging
|
|
10
|
+
logging.getLogger('numexpr.utils').setLevel(logging.ERROR)
|
|
9
11
|
import contextlib
|
|
10
12
|
import numpy as np
|
|
11
13
|
from scipy.stats import norm
|
|
@@ -51,7 +53,7 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
51
53
|
def snr_finder(gw_param_dict):
|
|
52
54
|
...
|
|
53
55
|
return optimal_snr_dict
|
|
54
|
-
where optimal_snr_dict.keys = ['
|
|
56
|
+
where optimal_snr_dict.keys = ['snr_net']. Refer to `gwsnr` package's GWSNR.snr attribute for more details.
|
|
55
57
|
pdet_finder : `function`
|
|
56
58
|
default pdet_finder = None.
|
|
57
59
|
The rate calculation uses either the pdet_finder or the snr_finder to calculate the detectable events. The custom pdet finder function should follow the following signature:
|
|
@@ -128,7 +130,7 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
128
130
|
|:meth:`~class_initialization` | Function to initialize the |
|
|
129
131
|
| | parent classes |
|
|
130
132
|
+-------------------------------------+----------------------------------+
|
|
131
|
-
|:meth:`~
|
|
133
|
+
|:meth:`~gwsnr_initialization` | Function to initialize the |
|
|
132
134
|
| | gwsnr class |
|
|
133
135
|
+-------------------------------------+----------------------------------+
|
|
134
136
|
|:meth:`~snr` | Function to get the snr with the |
|
|
@@ -282,11 +284,13 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
282
284
|
list_of_detectors=None,
|
|
283
285
|
json_file_names=None,
|
|
284
286
|
interpolator_directory="./interpolator_pickle",
|
|
287
|
+
create_new_interpolator=False,
|
|
285
288
|
ler_directory="./ler_data",
|
|
286
289
|
verbose=True,
|
|
287
290
|
**kwargs,
|
|
288
291
|
):
|
|
289
292
|
|
|
293
|
+
print("\nInitializing GWRATES class...\n")
|
|
290
294
|
self.npool = npool
|
|
291
295
|
self.z_min = z_min
|
|
292
296
|
self.z_max = z_max
|
|
@@ -298,6 +302,7 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
298
302
|
if json_file_names:
|
|
299
303
|
self.json_file_names.update(json_file_names)
|
|
300
304
|
self.interpolator_directory = interpolator_directory
|
|
305
|
+
kwargs["create_new_interpolator"] = create_new_interpolator
|
|
301
306
|
self.ler_directory = ler_directory
|
|
302
307
|
# create directory if not exists
|
|
303
308
|
if not os.path.exists(ler_directory):
|
|
@@ -308,7 +313,7 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
308
313
|
self.class_initialization(params=kwargs)
|
|
309
314
|
# initialization self.snr and self.pdet from GWSNR class
|
|
310
315
|
if not snr_finder and not pdet_finder:
|
|
311
|
-
self.
|
|
316
|
+
self.gwsnr_initialization(params=kwargs)
|
|
312
317
|
self.gwsnr = True
|
|
313
318
|
self.pdet = pdet_finder
|
|
314
319
|
else:
|
|
@@ -320,20 +325,21 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
320
325
|
# store all the gwrates input parameters
|
|
321
326
|
self.store_gwrates_params(output_jsonfile=self.json_file_names["gwrates_params"])
|
|
322
327
|
|
|
328
|
+
# if not verbose, prevent anything from printing
|
|
323
329
|
if verbose:
|
|
324
330
|
initialization()
|
|
325
|
-
self.
|
|
331
|
+
self.print_all_params_ler()
|
|
326
332
|
else:
|
|
327
333
|
with contextlib.redirect_stdout(None):
|
|
328
334
|
initialization()
|
|
329
335
|
|
|
330
|
-
def
|
|
336
|
+
def print_all_params_ler(self):
|
|
331
337
|
"""
|
|
332
338
|
Function to print all the parameters.
|
|
333
339
|
"""
|
|
334
340
|
|
|
335
341
|
# print all relevant functions and sampler priors
|
|
336
|
-
print("\n GWRATES set up params:")
|
|
342
|
+
print("\n # GWRATES set up params:")
|
|
337
343
|
print(f'npool = {self.npool},')
|
|
338
344
|
print(f'z_min = {self.z_min},')
|
|
339
345
|
print(f'z_max = {self.z_max},')
|
|
@@ -346,10 +352,10 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
346
352
|
if self.pdet:
|
|
347
353
|
print(f'pdet_finder = {self.pdet},')
|
|
348
354
|
print(f'json_file_names = {self.json_file_names},')
|
|
349
|
-
print(f
|
|
350
|
-
print(f
|
|
355
|
+
print(f"interpolator_directory = '{self.interpolator_directory}',")
|
|
356
|
+
print(f"ler_directory = '{self.ler_directory}',")
|
|
351
357
|
|
|
352
|
-
print("\n GWRATES also takes CBCSourceParameterDistribution params as kwargs, as follows:")
|
|
358
|
+
print("\n # GWRATES also takes CBCSourceParameterDistribution params as kwargs, as follows:")
|
|
353
359
|
print(f"source_priors = {self.gw_param_sampler_dict['source_priors']},")
|
|
354
360
|
print(f"source_priors_params = {self.gw_param_sampler_dict['source_priors_params']},")
|
|
355
361
|
print(f"spin_zero = {self.gw_param_sampler_dict['spin_zero']},")
|
|
@@ -357,7 +363,7 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
357
363
|
print(f"create_new_interpolator = {self.gw_param_sampler_dict['create_new_interpolator']},")
|
|
358
364
|
|
|
359
365
|
if self.gwsnr:
|
|
360
|
-
print("\n LeR also takes gwsnr.GWSNR params as kwargs, as follows:")
|
|
366
|
+
print("\n # LeR also takes gwsnr.GWSNR params as kwargs, as follows:")
|
|
361
367
|
print(f"mtot_min = {self.snr_calculator_dict['mtot_min']},")
|
|
362
368
|
print(f"mtot_max = {self.snr_calculator_dict['mtot_max']},")
|
|
363
369
|
print(f"ratio_min = {self.snr_calculator_dict['ratio_min']},")
|
|
@@ -530,10 +536,10 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
530
536
|
self,
|
|
531
537
|
z_min=input_params["z_min"],
|
|
532
538
|
z_max=input_params["z_max"],
|
|
539
|
+
cosmology=input_params["cosmology"],
|
|
533
540
|
event_type=input_params["event_type"],
|
|
534
541
|
source_priors=input_params["source_priors"],
|
|
535
542
|
source_priors_params=input_params["source_priors_params"],
|
|
536
|
-
cosmology=input_params["cosmology"],
|
|
537
543
|
spin_zero=input_params["spin_zero"],
|
|
538
544
|
spin_precession=input_params["spin_precession"],
|
|
539
545
|
directory=input_params["directory"],
|
|
@@ -543,7 +549,7 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
543
549
|
self.gw_param_sampler_dict["source_priors"]=self.gw_param_samplers.copy()
|
|
544
550
|
self.gw_param_sampler_dict["source_priors_params"]=self.gw_param_samplers_params.copy()
|
|
545
551
|
|
|
546
|
-
def
|
|
552
|
+
def gwsnr_initialization(self, params=None):
|
|
547
553
|
"""
|
|
548
554
|
Function to initialize the GWSNR class from the `gwsnr` package.
|
|
549
555
|
|
|
@@ -557,23 +563,35 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
557
563
|
# initialization of GWSNR class
|
|
558
564
|
input_params = dict(
|
|
559
565
|
npool=self.npool,
|
|
560
|
-
mtot_min=2.
|
|
561
|
-
mtot_max=
|
|
566
|
+
mtot_min=2*self.gw_param_samplers_params['source_frame_masses']['mminbh'],
|
|
567
|
+
mtot_max=2*self.gw_param_samplers_params['source_frame_masses']['mmaxbh']+10.0,
|
|
562
568
|
ratio_min=0.1,
|
|
563
569
|
ratio_max=1.0,
|
|
564
|
-
|
|
565
|
-
|
|
570
|
+
spin_max=0.99,
|
|
571
|
+
mtot_resolution=200,
|
|
572
|
+
ratio_resolution=20,
|
|
573
|
+
spin_resolution=10,
|
|
566
574
|
sampling_frequency=2048.0,
|
|
567
575
|
waveform_approximant="IMRPhenomD",
|
|
576
|
+
frequency_domain_source_model='lal_binary_black_hole',
|
|
568
577
|
minimum_frequency=20.0,
|
|
578
|
+
duration_max=None,
|
|
579
|
+
duration_min=None,
|
|
569
580
|
snr_type="interpolation",
|
|
570
581
|
psds=None,
|
|
571
582
|
ifos=None,
|
|
572
|
-
interpolator_dir=
|
|
583
|
+
interpolator_dir="./interpolator_pickle",
|
|
573
584
|
create_new_interpolator=False,
|
|
574
|
-
gwsnr_verbose=
|
|
585
|
+
gwsnr_verbose=True,
|
|
575
586
|
multiprocessing_verbose=True,
|
|
576
|
-
mtot_cut=
|
|
587
|
+
mtot_cut=False,
|
|
588
|
+
pdet=False,
|
|
589
|
+
snr_th=8.0,
|
|
590
|
+
snr_th_net=8.0,
|
|
591
|
+
ann_path_dict=None,
|
|
592
|
+
snr_recalculation=False,
|
|
593
|
+
snr_recalculation_range=[4,12],
|
|
594
|
+
snr_recalculation_waveform_approximant="IMRPhenomXPHM",
|
|
577
595
|
)
|
|
578
596
|
# if self.event_type == "BNS":
|
|
579
597
|
# input_params["mtot_max"]= 18.
|
|
@@ -582,25 +600,49 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
582
600
|
if key in input_params:
|
|
583
601
|
input_params[key] = value
|
|
584
602
|
self.snr_calculator_dict = input_params
|
|
603
|
+
|
|
604
|
+
# dealing with create_new_interpolator param
|
|
605
|
+
if isinstance(input_params["create_new_interpolator"], bool):
|
|
606
|
+
pass
|
|
607
|
+
elif isinstance(input_params["create_new_interpolator"], dict):
|
|
608
|
+
# check input_params["gwsnr"] exists
|
|
609
|
+
if "gwsnr" in input_params["create_new_interpolator"]:
|
|
610
|
+
if isinstance(input_params["create_new_interpolator"]["gwsnr"], bool):
|
|
611
|
+
input_params["create_new_interpolator"] = input_params["create_new_interpolator"]["gwsnr"]
|
|
612
|
+
else:
|
|
613
|
+
raise ValueError("create_new_interpolator['gwsnr'] should be a boolean.")
|
|
614
|
+
else:
|
|
615
|
+
input_params["create_new_interpolator"] = False
|
|
616
|
+
|
|
617
|
+
# initialization of GWSNR class
|
|
585
618
|
gwsnr = GWSNR(
|
|
586
619
|
npool=input_params["npool"],
|
|
587
620
|
mtot_min=input_params["mtot_min"],
|
|
588
621
|
mtot_max=input_params["mtot_max"],
|
|
589
622
|
ratio_min=input_params["ratio_min"],
|
|
590
623
|
ratio_max=input_params["ratio_max"],
|
|
624
|
+
spin_max=input_params["spin_max"],
|
|
591
625
|
mtot_resolution=input_params["mtot_resolution"],
|
|
592
626
|
ratio_resolution=input_params["ratio_resolution"],
|
|
627
|
+
spin_resolution=input_params["spin_resolution"],
|
|
593
628
|
sampling_frequency=input_params["sampling_frequency"],
|
|
594
629
|
waveform_approximant=input_params["waveform_approximant"],
|
|
630
|
+
frequency_domain_source_model=input_params["frequency_domain_source_model"],
|
|
595
631
|
minimum_frequency=input_params["minimum_frequency"],
|
|
632
|
+
duration_max=input_params["duration_max"],
|
|
633
|
+
duration_min=input_params["duration_min"],
|
|
596
634
|
snr_type=input_params["snr_type"],
|
|
597
635
|
psds=input_params["psds"],
|
|
598
636
|
ifos=input_params["ifos"],
|
|
599
637
|
interpolator_dir=input_params["interpolator_dir"],
|
|
600
|
-
|
|
638
|
+
create_new_interpolator=input_params["create_new_interpolator"],
|
|
601
639
|
gwsnr_verbose=input_params["gwsnr_verbose"],
|
|
602
640
|
multiprocessing_verbose=input_params["multiprocessing_verbose"],
|
|
603
641
|
mtot_cut=input_params["mtot_cut"],
|
|
642
|
+
pdet=input_params["pdet"],
|
|
643
|
+
snr_th=input_params["snr_th"],
|
|
644
|
+
snr_th_net=input_params["snr_th_net"],
|
|
645
|
+
ann_path_dict=input_params["ann_path_dict"],
|
|
604
646
|
)
|
|
605
647
|
|
|
606
648
|
self.snr = gwsnr.snr
|
|
@@ -610,6 +652,16 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
610
652
|
self.snr_calculator_dict["psds"] = gwsnr.psds_list
|
|
611
653
|
#self.pdet = gwsnr.pdet
|
|
612
654
|
|
|
655
|
+
# warm-up the snr calculator
|
|
656
|
+
if input_params["snr_type"]!="inner_product":
|
|
657
|
+
with contextlib.suppress(Exception):
|
|
658
|
+
# if snr_type is not inner_product, then we can warm up the snr calculator
|
|
659
|
+
# this is useful to avoid the first call to snr being slow
|
|
660
|
+
mass_1 = np.array([10.])
|
|
661
|
+
ratio = np.array([0.8])
|
|
662
|
+
dl = np.array([1000.0]) # in Mpc
|
|
663
|
+
snr = self.snr(mass_1=mass_1, mass_2=mass_1*ratio, luminosity_distance=dl)
|
|
664
|
+
|
|
613
665
|
def store_gwrates_params(self, output_jsonfile="gwrates_params.json"):
|
|
614
666
|
"""
|
|
615
667
|
Function to store the all the necessary parameters. This is useful for reproducing the results. All the parameters stored are in string format to make it json compatible.
|
|
@@ -646,13 +698,13 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
646
698
|
for key, value in snr_calculator_dict.items():
|
|
647
699
|
snr_calculator_dict[key] = str(value)
|
|
648
700
|
parameters_dict.update({"snr_calculator_dict": snr_calculator_dict})
|
|
649
|
-
|
|
650
|
-
file_name = output_jsonfile
|
|
651
|
-
append_json(self.ler_directory+"/"+file_name, parameters_dict, replace=True)
|
|
652
701
|
except:
|
|
653
702
|
# if snr_calculator is custom function
|
|
654
703
|
pass
|
|
655
704
|
|
|
705
|
+
file_name = output_jsonfile
|
|
706
|
+
append_json(self.ler_directory+"/"+file_name, parameters_dict, replace=True)
|
|
707
|
+
|
|
656
708
|
def gw_cbc_statistics(
|
|
657
709
|
self, size=None, resume=False, save_batch=False, output_jsonfile=None,
|
|
658
710
|
):
|
|
@@ -691,36 +743,16 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
691
743
|
output_path = os.path.join(self.ler_directory, output_jsonfile)
|
|
692
744
|
print(f"Simulated GW params will be stored in {output_path}")
|
|
693
745
|
|
|
694
|
-
|
|
695
|
-
if resume and os.path.exists(output_path):
|
|
696
|
-
# get sample from json file
|
|
697
|
-
self.dict_buffer = get_param_from_json(output_path)
|
|
698
|
-
else:
|
|
699
|
-
self.dict_buffer = None
|
|
700
|
-
|
|
701
|
-
batch_handler(
|
|
746
|
+
gw_param = batch_handler(
|
|
702
747
|
size=size,
|
|
703
748
|
batch_size=self.batch_size,
|
|
704
749
|
sampling_routine=self.gw_sampling_routine,
|
|
705
750
|
output_jsonfile=output_path,
|
|
706
751
|
save_batch=save_batch,
|
|
707
752
|
resume=resume,
|
|
753
|
+
param_name="gw parameters",
|
|
708
754
|
)
|
|
709
755
|
|
|
710
|
-
if save_batch:
|
|
711
|
-
gw_param = get_param_from_json(output_path)
|
|
712
|
-
else:
|
|
713
|
-
# this if condition is required if there is nothing to save
|
|
714
|
-
if self.dict_buffer:
|
|
715
|
-
gw_param = self.dict_buffer.copy()
|
|
716
|
-
# store all params in json file
|
|
717
|
-
print(f"saving all gw_params in {output_path}...")
|
|
718
|
-
append_json(output_path, gw_param, replace=True)
|
|
719
|
-
else:
|
|
720
|
-
print("gw_params already sampled.")
|
|
721
|
-
gw_param = get_param_from_json(output_path)
|
|
722
|
-
self.dict_buffer = None # save memory
|
|
723
|
-
|
|
724
756
|
return gw_param
|
|
725
757
|
|
|
726
758
|
def gw_sampling_routine(self, size, output_jsonfile, resume=False, save_batch=True):
|
|
@@ -760,17 +792,6 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
760
792
|
pdet = self.pdet(gw_param_dict=gw_param)
|
|
761
793
|
gw_param.update(pdet)
|
|
762
794
|
|
|
763
|
-
# adding batches
|
|
764
|
-
if not save_batch:
|
|
765
|
-
if self.dict_buffer is None:
|
|
766
|
-
self.dict_buffer = gw_param
|
|
767
|
-
else:
|
|
768
|
-
for key, value in gw_param.items():
|
|
769
|
-
self.dict_buffer[key] = np.concatenate((self.dict_buffer[key], value))
|
|
770
|
-
else:
|
|
771
|
-
# store all params in json file
|
|
772
|
-
self.dict_buffer = append_json(file_name=output_jsonfile, new_dictionary=gw_param, old_dictionary=self.dict_buffer, replace=not (resume))
|
|
773
|
-
|
|
774
795
|
return gw_param
|
|
775
796
|
|
|
776
797
|
def gw_rate(
|
|
@@ -897,7 +918,7 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
897
918
|
dictionary of GW source parameters.
|
|
898
919
|
"""
|
|
899
920
|
|
|
900
|
-
snr_param = gw_param["
|
|
921
|
+
snr_param = gw_param["snr_net"]
|
|
901
922
|
idx_detectable = (snr_param > snr_threshold_recalculation[0]) & (snr_param < snr_threshold_recalculation[1])
|
|
902
923
|
# reduce the size of the dict
|
|
903
924
|
for key, value in gw_param.items():
|
|
@@ -931,15 +952,15 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
931
952
|
"""
|
|
932
953
|
|
|
933
954
|
if self.snr:
|
|
934
|
-
if "
|
|
935
|
-
raise ValueError("'
|
|
955
|
+
if "snr_net" not in gw_param:
|
|
956
|
+
raise ValueError("'snr_net' not in gw param dict provided")
|
|
936
957
|
if detectability_condition == "step_function":
|
|
937
958
|
print("given detectability_condition == 'step_function'")
|
|
938
|
-
param = gw_param["
|
|
959
|
+
param = gw_param["snr_net"]
|
|
939
960
|
threshold = snr_threshold
|
|
940
961
|
elif detectability_condition == "pdet":
|
|
941
962
|
print("given detectability_condition == 'pdet'")
|
|
942
|
-
param = 1 - norm.cdf(snr_threshold - gw_param["
|
|
963
|
+
param = 1 - norm.cdf(snr_threshold - gw_param["snr_net"])
|
|
943
964
|
gw_param["pdet_net"] = param
|
|
944
965
|
threshold = pdet_threshold
|
|
945
966
|
elif self.pdet:
|
|
@@ -1061,13 +1082,17 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
1061
1082
|
self,
|
|
1062
1083
|
size=100,
|
|
1063
1084
|
batch_size=None,
|
|
1085
|
+
stopping_criteria=dict(
|
|
1086
|
+
relative_diff_percentage=0.5,
|
|
1087
|
+
number_of_last_batches_to_check=4,
|
|
1088
|
+
),
|
|
1064
1089
|
snr_threshold=8.0,
|
|
1065
1090
|
pdet_threshold=0.5,
|
|
1066
1091
|
resume=False,
|
|
1067
1092
|
output_jsonfile="gw_params_n_detectable.json",
|
|
1068
1093
|
meta_data_file="meta_gw.json",
|
|
1069
1094
|
detectability_condition="step_function",
|
|
1070
|
-
trim_to_size=
|
|
1095
|
+
trim_to_size=False,
|
|
1071
1096
|
snr_recalculation=False,
|
|
1072
1097
|
snr_threshold_recalculation=[4, 12],
|
|
1073
1098
|
):
|
|
@@ -1124,10 +1149,11 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
1124
1149
|
"""
|
|
1125
1150
|
|
|
1126
1151
|
# initial setup
|
|
1127
|
-
n, events_total, output_path, meta_data_path, buffer_file = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
|
|
1152
|
+
n, events_total, output_path, meta_data_path, buffer_file, batch_size = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
|
|
1128
1153
|
|
|
1129
1154
|
# loop until n samples are collected
|
|
1130
|
-
|
|
1155
|
+
continue_condition = True
|
|
1156
|
+
while continue_condition:
|
|
1131
1157
|
# disable print statements
|
|
1132
1158
|
with contextlib.redirect_stdout(None):
|
|
1133
1159
|
self.dict_buffer = None # this is used to store the sampled gw_param in batches when running the sampling_routine
|
|
@@ -1152,9 +1178,33 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
1152
1178
|
total_rate = self.rate_function(n, events_total, verbose=False)
|
|
1153
1179
|
|
|
1154
1180
|
# bookmark
|
|
1155
|
-
self._append_meta_data(meta_data_path, n, events_total, total_rate)
|
|
1156
|
-
|
|
1157
|
-
|
|
1181
|
+
buffer_dict = self._append_meta_data(meta_data_path, n, events_total, total_rate)
|
|
1182
|
+
|
|
1183
|
+
if isinstance(stopping_criteria, dict):
|
|
1184
|
+
total_rates = np.array(buffer_dict['total_rate'])
|
|
1185
|
+
limit = stopping_criteria['relative_diff_percentage']
|
|
1186
|
+
num_a = stopping_criteria['number_of_last_batches_to_check']
|
|
1187
|
+
if len(total_rates)>num_a:
|
|
1188
|
+
num_a = int(-1*(num_a))
|
|
1189
|
+
# num_b = int(num_a)
|
|
1190
|
+
percentage_diff = np.abs((total_rates[num_a:]-total_rates[-1])/total_rates[-1])*100
|
|
1191
|
+
print(f"percentage difference for the last {abs(num_a)} batches = {percentage_diff}")
|
|
1192
|
+
if np.any(percentage_diff>limit):
|
|
1193
|
+
continue_condition &= True
|
|
1194
|
+
else:
|
|
1195
|
+
print(rf"stopping criteria of rate relative difference of {limit}% reached. If you want to collect more events, reduce stopping_criteria['relative_diff_percentage'] or put stopping_criteria=None.")
|
|
1196
|
+
continue_condition &= False
|
|
1197
|
+
|
|
1198
|
+
if isinstance(size, int):
|
|
1199
|
+
if n<size:
|
|
1200
|
+
continue_condition |= True
|
|
1201
|
+
else:
|
|
1202
|
+
print(rf"Given size={size} reached")
|
|
1203
|
+
continue_condition |= False
|
|
1204
|
+
if stopping_criteria is None:
|
|
1205
|
+
continue_condition &= False
|
|
1206
|
+
|
|
1207
|
+
print(f"stored detectable unlensed params in {output_path}")
|
|
1158
1208
|
print(f"stored meta data in {meta_data_path}")
|
|
1159
1209
|
|
|
1160
1210
|
if trim_to_size:
|
|
@@ -1221,6 +1271,8 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
1221
1271
|
if not resume:
|
|
1222
1272
|
n = 0 # iterator
|
|
1223
1273
|
events_total = 0
|
|
1274
|
+
# the following file will be removed if it exists
|
|
1275
|
+
print(f"removing {output_path} and {meta_data_path} if they exist")
|
|
1224
1276
|
if os.path.exists(output_path):
|
|
1225
1277
|
os.remove(output_path)
|
|
1226
1278
|
if os.path.exists(meta_data_path):
|
|
@@ -1239,7 +1291,7 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
1239
1291
|
buffer_file = "params_buffer.json"
|
|
1240
1292
|
print("collected number of detectable events = ", n)
|
|
1241
1293
|
|
|
1242
|
-
return n, events_total, output_path, meta_data_path, buffer_file
|
|
1294
|
+
return n, events_total, output_path, meta_data_path, buffer_file, batch_size
|
|
1243
1295
|
|
|
1244
1296
|
def _trim_results_to_size(self, size, output_path, meta_data_path):
|
|
1245
1297
|
"""
|
|
@@ -1316,12 +1368,14 @@ class GWRATES(CBCSourceParameterDistribution):
|
|
|
1316
1368
|
|
|
1317
1369
|
if os.path.exists(meta_data_path):
|
|
1318
1370
|
try:
|
|
1319
|
-
append_json(meta_data_path, meta_data, replace=False)
|
|
1371
|
+
dict_ = append_json(meta_data_path, meta_data, replace=False)
|
|
1320
1372
|
except:
|
|
1321
|
-
append_json(meta_data_path, meta_data, replace=True)
|
|
1373
|
+
dict_ = append_json(meta_data_path, meta_data, replace=True)
|
|
1322
1374
|
else:
|
|
1323
|
-
append_json(meta_data_path, meta_data, replace=True)
|
|
1375
|
+
dict_ = append_json(meta_data_path, meta_data, replace=True)
|
|
1324
1376
|
|
|
1325
1377
|
print("collected number of detectable events = ", n)
|
|
1326
1378
|
print("total number of events = ", events_total)
|
|
1327
|
-
print(f"total rate (yr^-1): {total_rate}")
|
|
1379
|
+
print(f"total rate (yr^-1): {total_rate}")
|
|
1380
|
+
|
|
1381
|
+
return dict_
|