ler 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ler might be problematic. Click here for more details.

Files changed (35) hide show
  1. ler/__init__.py +26 -26
  2. ler/gw_source_population/__init__.py +1 -0
  3. ler/gw_source_population/cbc_source_parameter_distribution.py +1073 -815
  4. ler/gw_source_population/cbc_source_redshift_distribution.py +618 -294
  5. ler/gw_source_population/jit_functions.py +484 -9
  6. ler/gw_source_population/sfr_with_time_delay.py +107 -0
  7. ler/image_properties/image_properties.py +41 -12
  8. ler/image_properties/multiprocessing_routine.py +5 -209
  9. ler/lens_galaxy_population/__init__.py +2 -0
  10. ler/lens_galaxy_population/epl_shear_cross_section.py +0 -0
  11. ler/lens_galaxy_population/jit_functions.py +101 -9
  12. ler/lens_galaxy_population/lens_galaxy_parameter_distribution.py +813 -881
  13. ler/lens_galaxy_population/lens_param_data/density_profile_slope_sl.txt +5000 -0
  14. ler/lens_galaxy_population/lens_param_data/external_shear_sl.txt +2 -0
  15. ler/lens_galaxy_population/lens_param_data/number_density_zl_zs.txt +48 -0
  16. ler/lens_galaxy_population/lens_param_data/optical_depth_epl_shear_vd_ewoud.txt +48 -0
  17. ler/lens_galaxy_population/mp copy.py +554 -0
  18. ler/lens_galaxy_population/mp.py +736 -138
  19. ler/lens_galaxy_population/optical_depth.py +2248 -616
  20. ler/rates/__init__.py +1 -2
  21. ler/rates/gwrates.py +126 -72
  22. ler/rates/ler.py +218 -111
  23. ler/utils/__init__.py +2 -0
  24. ler/utils/function_interpolation.py +322 -0
  25. ler/utils/gwsnr_training_data_generator.py +233 -0
  26. ler/utils/plots.py +1 -1
  27. ler/utils/test.py +1078 -0
  28. ler/utils/utils.py +492 -125
  29. {ler-0.4.2.dist-info → ler-0.4.3.dist-info}/METADATA +30 -17
  30. ler-0.4.3.dist-info/RECORD +34 -0
  31. {ler-0.4.2.dist-info → ler-0.4.3.dist-info}/WHEEL +1 -1
  32. ler/rates/ler copy.py +0 -2097
  33. ler-0.4.2.dist-info/RECORD +0 -25
  34. {ler-0.4.2.dist-info → ler-0.4.3.dist-info/licenses}/LICENSE +0 -0
  35. {ler-0.4.2.dist-info → ler-0.4.3.dist-info}/top_level.txt +0 -0
ler/rates/ler.py CHANGED
@@ -4,6 +4,7 @@ This module contains the main class for calculating the rates of detectable grav
4
4
  """
5
5
 
6
6
  import os
7
+ # os.environ['OMP_NESTED'] = 'FALSE'
7
8
  import warnings
8
9
  warnings.filterwarnings("ignore")
9
10
  import logging
@@ -16,12 +17,6 @@ from ..lens_galaxy_population import LensGalaxyParameterDistribution
16
17
  from ..utils import load_json, append_json, get_param_from_json, batch_handler
17
18
 
18
19
 
19
- # # multiprocessing guard code
20
- # def main():
21
- # obj = LeR()
22
-
23
- # if __name__ == '__main__':
24
-
25
20
  class LeR(LensGalaxyParameterDistribution):
26
21
  """Class to sample of lensed and unlensed events and calculate it's rates. Please note that parameters of the simulated events are stored in json file but not as an attribute of the class. This saves RAM memory.
27
22
 
@@ -59,7 +54,7 @@ class LeR(LensGalaxyParameterDistribution):
59
54
  def snr_finder(gw_param_dict):
60
55
  ...
61
56
  return optimal_snr_dict
62
- where optimal_snr_dict.keys = ['optimal_snr_net']. Refer to `gwsnr` package's GWSNR.snr attribute for more details.
57
+ where optimal_snr_dict.keys = ['snr_net']. Refer to `gwsnr` package's GWSNR.snr attribute for more details.
63
58
  pdet_finder : `function`
64
59
  default pdet_finder = None.
65
60
  The rate calculation uses either the pdet_finder or the snr_finder to calculate the detectable events. The custom pdet finder function should follow the following signature:
@@ -165,7 +160,7 @@ class LeR(LensGalaxyParameterDistribution):
165
160
  |:meth:`~class_initialization` | Function to initialize the |
166
161
  | | parent classes |
167
162
  +-------------------------------------+----------------------------------+
168
- |:meth:`~gwsnr_intialization` | Function to initialize the |
163
+ |:meth:`~gwsnr_initialization` | Function to initialize the |
169
164
  | | gwsnr class |
170
165
  +-------------------------------------+----------------------------------+
171
166
  |:meth:`~snr` | Function to get the snr with the |
@@ -460,6 +455,7 @@ class LeR(LensGalaxyParameterDistribution):
460
455
  **kwargs,
461
456
  ):
462
457
 
458
+ print("\nInitializing LeR class...\n")
463
459
  self.npool = npool
464
460
  self.z_min = z_min
465
461
  self.z_max = z_max
@@ -482,7 +478,7 @@ class LeR(LensGalaxyParameterDistribution):
482
478
  self.class_initialization(params=kwargs)
483
479
  # initialization self.snr and self.pdet from GWSNR class
484
480
  if not snr_finder and not pdet_finder:
485
- self.gwsnr_intialization(params=kwargs)
481
+ self.gwsnr_initialization(params=kwargs)
486
482
  self.gwsnr = True
487
483
  self.pdet = pdet_finder
488
484
  else:
@@ -494,15 +490,15 @@ class LeR(LensGalaxyParameterDistribution):
494
490
  # store all the ler input parameters
495
491
  self.store_ler_params(output_jsonfile=self.json_file_names["ler_params"])
496
492
 
497
- # if verbose, prevent anything from printing
493
+ # if not verbose, prevent anything from printing
498
494
  if verbose:
499
495
  initialization()
500
- self.print_all_params()
496
+ self.print_all_params_ler()
501
497
  else:
502
498
  with contextlib.redirect_stdout(None):
503
499
  initialization()
504
500
 
505
- def print_all_params(self):
501
+ def print_all_params_ler(self):
506
502
  """
507
503
  Function to print all the parameters.
508
504
  """
@@ -533,14 +529,15 @@ class LeR(LensGalaxyParameterDistribution):
533
529
  print("\n # LeR also takes LensGalaxyParameterDistribution class params as kwargs, as follows:")
534
530
  print(f"lens_type = '{self.gw_param_sampler_dict['lens_type']}',")
535
531
  print(f"lens_functions = {self.gw_param_sampler_dict['lens_functions']},")
536
- print(f"lens_priors = {self.gw_param_sampler_dict['lens_priors']},")
537
- print(f"lens_priors_params = {self.gw_param_sampler_dict['lens_priors_params']},")
532
+ print(f"lens_param_samplers = {self.gw_param_sampler_dict['lens_param_samplers']},")
533
+ print(f"lens_param_samplers_params = {self.gw_param_sampler_dict['lens_param_samplers_params']},")
538
534
 
539
535
  print("\n # LeR also takes ImageProperties class params as kwargs, as follows:")
540
536
  print(f"n_min_images = {self.n_min_images},")
541
537
  print(f"n_max_images = {self.n_max_images},")
542
- print(f"geocent_time_min = {self.geocent_time_min},")
543
- print(f"geocent_time_max = {self.geocent_time_max},")
538
+ print(f"time_window = {self.time_window},")
539
+ # print(f"geocent_time_min = {self.geocent_time_min},")
540
+ # print(f"geocent_time_max = {self.geocent_time_max},")
544
541
  print(f"lens_model_list = {self.lens_model_list},")
545
542
 
546
543
  if self.gwsnr:
@@ -604,14 +601,16 @@ class LeR(LensGalaxyParameterDistribution):
604
601
  print("axis_ratio_params = ", self.lens_param_samplers_params["axis_ratio"])
605
602
  print(f"axis_rotation_angle = '{self.lens_param_samplers['axis_rotation_angle']}'")
606
603
  print("axis_rotation_angle_params = ", self.lens_param_samplers_params["axis_rotation_angle"])
607
- print(f"shear = '{self.lens_param_samplers['shear']}'")
608
- print("shear_params = ", self.lens_param_samplers_params["shear"])
609
- print(f"mass_density_spectral_index = '{self.lens_param_samplers['mass_density_spectral_index']}'")
610
- print("mass_density_spectral_index_params = ", self.lens_param_samplers_params["mass_density_spectral_index"])
604
+ print(f"shear = '{self.lens_param_samplers['external_shear']}'")
605
+ print("shear_params = ", self.lens_param_samplers_params['external_shear'])
606
+ print(f"density_profile_slope = '{self.lens_param_samplers['density_profile_slope']}'")
607
+ print("density_profile_slope_params = ", self.lens_param_samplers_params["density_profile_slope"])
611
608
  # lens functions
612
609
  print("Lens functions:")
613
610
  print(f"strong_lensing_condition = '{self.lens_functions['strong_lensing_condition']}'")
614
611
  print(f"optical_depth = '{self.lens_functions['optical_depth']}'")
612
+ print(f"optical_depth_params = '{self.lens_functions_params['optical_depth']}'")
613
+ print(f"param_sampler_type = '{self.lens_functions['param_sampler_type']}'")
615
614
 
616
615
  @property
617
616
  def snr(self):
@@ -743,15 +742,18 @@ class LeR(LensGalaxyParameterDistribution):
743
742
  z_max=self.z_max,
744
743
  cosmology=self.cosmo,
745
744
  event_type=self.event_type,
746
- lens_type="epl_galaxy",
745
+ lens_type="epl_shear_galaxy",
747
746
  lens_functions= None,
748
- lens_priors=None,
749
- lens_priors_params=None,
747
+ lens_functions_params=None,
748
+ lens_param_samplers=None,
749
+ lens_param_samplers_params=None,
750
+ buffer_size=1000,
750
751
  # ImageProperties class params
751
752
  n_min_images=2,
752
753
  n_max_images=4,
753
- geocent_time_min=1126259462.4,
754
- geocent_time_max=1126259462.4+365*24*3600*20,
754
+ time_window=365*24*3600*20,
755
+ # geocent_time_min=1126259462.4,
756
+ # geocent_time_max=1126259462.4+365*24*3600*20,
755
757
  lens_model_list=['EPL_NUMBA', 'SHEAR'],
756
758
  # CBCSourceParameterDistribution class params
757
759
  source_priors=None,
@@ -775,13 +777,16 @@ class LeR(LensGalaxyParameterDistribution):
775
777
  event_type=input_params["event_type"],
776
778
  lens_type=input_params["lens_type"],
777
779
  lens_functions=input_params["lens_functions"],
778
- lens_priors=input_params["lens_priors"],
779
- lens_priors_params=input_params["lens_priors_params"],
780
+ lens_functions_params=input_params["lens_functions_params"],
781
+ lens_param_samplers=input_params["lens_param_samplers"],
782
+ lens_param_samplers_params=input_params["lens_param_samplers_params"],
780
783
  n_min_images=input_params["n_min_images"],
781
784
  n_max_images=input_params["n_max_images"],
782
- geocent_time_min=input_params["geocent_time_min"],
783
- geocent_time_max=input_params["geocent_time_max"],
785
+ time_window=input_params["time_window"],
786
+ # geocent_time_min=input_params["geocent_time_min"],
787
+ # geocent_time_max=input_params["geocent_time_max"],
784
788
  lens_model_list=input_params["lens_model_list"],
789
+ buffer_size=input_params["buffer_size"],
785
790
  source_priors=input_params["source_priors"],
786
791
  source_priors_params=input_params["source_priors_params"],
787
792
  spin_zero=input_params["spin_zero"],
@@ -792,11 +797,12 @@ class LeR(LensGalaxyParameterDistribution):
792
797
 
793
798
  self.gw_param_sampler_dict["source_priors"]=self.gw_param_samplers.copy()
794
799
  self.gw_param_sampler_dict["source_priors_params"]=self.gw_param_samplers_params.copy()
795
- self.gw_param_sampler_dict["lens_priors"]=self.lens_param_samplers.copy()
796
- self.gw_param_sampler_dict["lens_priors_params"]=self.lens_param_samplers_params.copy()
800
+ self.gw_param_sampler_dict["lens_param_samplers"]=self.lens_param_samplers.copy()
801
+ self.gw_param_sampler_dict["lens_param_samplers_params"]=self.lens_param_samplers_params.copy()
797
802
  self.gw_param_sampler_dict["lens_functions"]=self.lens_functions.copy()
803
+ self.gw_param_sampler_dict["lens_functions_params"]=self.lens_functions_params.copy()
798
804
 
799
- def gwsnr_intialization(self, params=None):
805
+ def gwsnr_initialization(self, params=None):
800
806
  """
801
807
  Function to initialize the GWSNR class from the `gwsnr` package.
802
808
 
@@ -808,29 +814,55 @@ class LeR(LensGalaxyParameterDistribution):
808
814
  from gwsnr import GWSNR
809
815
 
810
816
  # initialization of GWSNR class
817
+ if 'mminbh' in self.gw_param_samplers_params['source_frame_masses']:
818
+ min_bh_mass = self.gw_param_samplers_params['source_frame_masses']['mminbh']
819
+ else:
820
+ min_bh_mass = 2.0
821
+
822
+ if 'mmaxbh' in self.gw_param_samplers_params['source_frame_masses']:
823
+ max_bh_mass = self.gw_param_samplers_params['source_frame_masses']['mmaxbh']
824
+ else:
825
+ max_bh_mass = 200.0
811
826
  input_params = dict(
827
+ # General settings
812
828
  npool=self.npool,
813
- mtot_min=2.0,
814
- mtot_max=200,
829
+ snr_method="interpolation_aligned_spins",
830
+ snr_type="optimal_snr",
831
+ gwsnr_verbose=True,
832
+ multiprocessing_verbose=True,
833
+ pdet_kwargs=None,
834
+ # Settings for interpolation grid
835
+ mtot_min=min_bh_mass*2,
836
+ mtot_max=max_bh_mass*2*(1+self.z_max) if max_bh_mass*2*(1+self.z_max)<500.0 else 500.0,
815
837
  ratio_min=0.1,
816
838
  ratio_max=1.0,
817
- mtot_resolution=500,
818
- ratio_resolution=50,
839
+ spin_max=0.99,
840
+ mtot_resolution=200,
841
+ ratio_resolution=20,
842
+ spin_resolution=10,
843
+ batch_size_interpolation=1000000,
844
+ interpolator_dir="./interpolator_pickle",
845
+ create_new_interpolator=False,
846
+ # GW signal settings
819
847
  sampling_frequency=2048.0,
820
848
  waveform_approximant="IMRPhenomD",
849
+ frequency_domain_source_model='lal_binary_black_hole',
821
850
  minimum_frequency=20.0,
822
- snr_type="interpolation",
851
+ reference_frequency=None,
852
+ duration_max=None,
853
+ duration_min=None,
854
+ fixed_duration=None,
855
+ mtot_cut=False,
856
+ # Detector settings
823
857
  psds=None,
824
858
  ifos=None,
825
- interpolator_dir=self.interpolator_directory,
826
- create_new_interpolator=False,
827
- gwsnr_verbose=False,
828
- multiprocessing_verbose=True,
829
- mtot_cut=True,
830
- pdet=False,
831
- snr_th=8.0,
832
- snr_th_net=8.0,
859
+ noise_realization=None, # not implemented yet
860
+ # ANN settings
833
861
  ann_path_dict=None,
862
+ # Hybrid SNR recalculation settings
863
+ snr_recalculation=False,
864
+ snr_recalculation_range=[6,14],
865
+ snr_recalculation_waveform_approximant="IMRPhenomXPHM",
834
866
  )
835
867
  # if self.event_type == "BNS":
836
868
  # input_params["mtot_max"]= 18.
@@ -851,40 +883,51 @@ class LeR(LensGalaxyParameterDistribution):
851
883
  else:
852
884
  raise ValueError("create_new_interpolator['gwsnr'] should be a boolean.")
853
885
  else:
854
- raise ValueError("create_new_interpolator should be a boolean or a dictionary with 'gwsnr' key.")
886
+ input_params["create_new_interpolator"] = False
855
887
 
856
888
  # initialization of GWSNR class
857
889
  gwsnr = GWSNR(
858
- npool=input_params["npool"],
859
- mtot_min=input_params["mtot_min"],
860
- mtot_max=input_params["mtot_max"],
861
- ratio_min=input_params["ratio_min"],
862
- ratio_max=input_params["ratio_max"],
863
- mtot_resolution=input_params["mtot_resolution"],
864
- ratio_resolution=input_params["ratio_resolution"],
865
- sampling_frequency=input_params["sampling_frequency"],
866
- waveform_approximant=input_params["waveform_approximant"],
867
- minimum_frequency=input_params["minimum_frequency"],
868
- snr_type=input_params["snr_type"],
869
- psds=input_params["psds"],
870
- ifos=input_params["ifos"],
871
- interpolator_dir=input_params["interpolator_dir"],
872
- create_new_interpolator=input_params["create_new_interpolator"],
873
- gwsnr_verbose=input_params["gwsnr_verbose"],
874
- multiprocessing_verbose=input_params["multiprocessing_verbose"],
875
- mtot_cut=input_params["mtot_cut"],
876
- pdet=input_params["pdet"],
877
- snr_th=input_params["snr_th"],
878
- snr_th_net=input_params["snr_th_net"],
879
- ann_path_dict=input_params["ann_path_dict"],
880
- )
890
+ npool=input_params["npool"],
891
+ snr_method=input_params["snr_method"],
892
+ snr_type=input_params["snr_type"],
893
+ gwsnr_verbose=input_params["gwsnr_verbose"],
894
+ multiprocessing_verbose=input_params["multiprocessing_verbose"],
895
+ pdet_kwargs=input_params["pdet_kwargs"],
896
+ mtot_min=input_params["mtot_min"],
897
+ mtot_max=input_params["mtot_max"],
898
+ ratio_min=input_params["ratio_min"],
899
+ ratio_max=input_params["ratio_max"],
900
+ spin_max=input_params["spin_max"],
901
+ mtot_resolution=input_params["mtot_resolution"],
902
+ ratio_resolution=input_params["ratio_resolution"],
903
+ spin_resolution=input_params["spin_resolution"],
904
+ batch_size_interpolation=input_params["batch_size_interpolation"],
905
+ interpolator_dir=input_params["interpolator_dir"],
906
+ create_new_interpolator=input_params["create_new_interpolator"],
907
+ sampling_frequency=input_params["sampling_frequency"],
908
+ waveform_approximant=input_params["waveform_approximant"],
909
+ frequency_domain_source_model=input_params["frequency_domain_source_model"],
910
+ minimum_frequency=input_params["minimum_frequency"],
911
+ reference_frequency=input_params["reference_frequency"],
912
+ duration_max=input_params["duration_max"],
913
+ duration_min=input_params["duration_min"],
914
+ fixed_duration=input_params["fixed_duration"],
915
+ mtot_cut=input_params["mtot_cut"],
916
+ psds=input_params["psds"],
917
+ ifos=input_params["ifos"],
918
+ noise_realization=input_params["noise_realization"],
919
+ ann_path_dict=input_params["ann_path_dict"],
920
+ snr_recalculation=input_params["snr_recalculation"],
921
+ snr_recalculation_range=input_params["snr_recalculation_range"],
922
+ snr_recalculation_waveform_approximant=input_params["snr_recalculation_waveform_approximant"],
923
+ )
881
924
 
882
- self.snr = gwsnr.snr
925
+ self.snr = gwsnr.optimal_snr
883
926
  self.list_of_detectors = gwsnr.detector_list
884
- self.snr_bilby = gwsnr.compute_bilby_snr
927
+ self.snr_bilby = gwsnr.optimal_snr_with_inner_product
885
928
  self.snr_calculator_dict["mtot_max"] = gwsnr.mtot_max
886
929
  self.snr_calculator_dict["psds"] = gwsnr.psds_list
887
- #self.pdet = gwsnr.pdet
930
+ self.pdet = gwsnr.pdet
888
931
 
889
932
  def store_ler_params(self, output_jsonfile="ler_params.json"):
890
933
  """
@@ -922,12 +965,12 @@ class LeR(LensGalaxyParameterDistribution):
922
965
  for key, value in snr_calculator_dict.items():
923
966
  snr_calculator_dict[key] = str(value)
924
967
  parameters_dict.update({"snr_calculator_dict": snr_calculator_dict})
925
-
926
- file_name = output_jsonfile
927
- append_json(self.ler_directory+"/"+file_name, parameters_dict, replace=True)
928
968
  except:
929
969
  # if snr_calculator is custom function
930
970
  pass
971
+
972
+ file_name = output_jsonfile
973
+ append_json(self.ler_directory+"/"+file_name, parameters_dict, replace=True)
931
974
 
932
975
  def unlensed_cbc_statistics(
933
976
  self, size=None, resume=False, save_batch=False, output_jsonfile=None,
@@ -965,7 +1008,7 @@ class LeR(LensGalaxyParameterDistribution):
965
1008
  output_jsonfile = output_jsonfile or self.json_file_names["unlensed_param"]
966
1009
  self.json_file_names["unlensed_param"] = output_jsonfile
967
1010
  output_path = os.path.join(self.ler_directory, output_jsonfile)
968
- print(f"unlensed params will be store in {output_path}")
1011
+ print(f"unlensed params will be stored in {output_path}")
969
1012
 
970
1013
  unlensed_param = batch_handler(
971
1014
  size=size,
@@ -1006,10 +1049,11 @@ class LeR(LensGalaxyParameterDistribution):
1006
1049
  # get gw params
1007
1050
  print("sampling gw source params...")
1008
1051
  unlensed_param = self.sample_gw_parameters(size=size)
1052
+
1009
1053
  # Get all of the signal to noise ratios
1010
1054
  if self.snr:
1011
1055
  print("calculating snrs...")
1012
- snrs = self.snr(gw_param_dict=unlensed_param)
1056
+ snrs = self.snr(gw_param_dict=unlensed_param.copy())
1013
1057
  unlensed_param.update(snrs)
1014
1058
  elif self.pdet:
1015
1059
  print("calculating pdet...")
@@ -1146,7 +1190,7 @@ class LeR(LensGalaxyParameterDistribution):
1146
1190
  dictionary of unlensed GW source parameters.
1147
1191
  """
1148
1192
 
1149
- snr_param = unlensed_param["optimal_snr_net"]
1193
+ snr_param = unlensed_param["snr_net"]
1150
1194
  idx_detectable = (snr_param > snr_threshold_recalculation[0]) & (snr_param < snr_threshold_recalculation[1])
1151
1195
  # reduce the size of the dict
1152
1196
  for key, value in unlensed_param.items():
@@ -1180,20 +1224,20 @@ class LeR(LensGalaxyParameterDistribution):
1180
1224
  """
1181
1225
 
1182
1226
  if self.snr:
1183
- if "optimal_snr_net" not in unlensed_param:
1184
- raise ValueError("'optimal_snr_net' not in unlensed param dict provided")
1227
+ if "snr_net" not in unlensed_param:
1228
+ raise ValueError("'snr_net' not in unlensed param dict provided")
1185
1229
  if detectability_condition == "step_function":
1186
- print("given detectability_condition == 'step_function'")
1187
- param = unlensed_param["optimal_snr_net"]
1230
+ #print("given detectability_condition == 'step_function'")
1231
+ param = unlensed_param["snr_net"]
1188
1232
  threshold = snr_threshold
1189
1233
  elif detectability_condition == "pdet":
1190
- print("given detectability_condition == 'pdet'")
1191
- param = 1 - norm.cdf(snr_threshold - unlensed_param["optimal_snr_net"])
1234
+ #print("given detectability_condition == 'pdet'")
1235
+ param = 1 - norm.cdf(snr_threshold - unlensed_param["snr_net"])
1192
1236
  unlensed_param["pdet_net"] = param
1193
1237
  threshold = pdet_threshold
1194
1238
  elif self.pdet:
1195
1239
  if "pdet_net" in unlensed_param:
1196
- print("given detectability_condition == 'pdet'")
1240
+ #print("given detectability_condition == 'pdet'")
1197
1241
  param = unlensed_param["pdet_net"]
1198
1242
  threshold = pdet_threshold
1199
1243
  else:
@@ -1409,11 +1453,12 @@ class LeR(LensGalaxyParameterDistribution):
1409
1453
  # check for invalid samples
1410
1454
  idx = lensed_param["n_images"] < 2
1411
1455
 
1412
- if np.sum(idx) == 0:
1413
- break
1414
- else:
1415
- print(f"Invalid sample found. Resampling {np.sum(idx)} lensed events...")
1416
- size = np.sum(idx)
1456
+ # if np.sum(idx) == 0:
1457
+ # break
1458
+ # else:
1459
+ # print(f"Invalid sample found. Resampling {np.sum(idx)} lensed events...")
1460
+ # size = np.sum(idx)
1461
+ break
1417
1462
 
1418
1463
  # Get all of the signal to noise ratios
1419
1464
  if self.snr:
@@ -1522,6 +1567,8 @@ class LeR(LensGalaxyParameterDistribution):
1522
1567
  # find index of detectable events
1523
1568
  snr_hit = self._find_detectable_index_lensed(lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition, combine_image_snr=combine_image_snr, snr_cut_for_combine_image_snr=snr_cut_for_combine_image_snr)
1524
1569
 
1570
+ # select according to time delay
1571
+
1525
1572
  # montecarlo integration
1526
1573
  total_rate = self.rate_function(np.sum(snr_hit), total_events, param_type="lensed")
1527
1574
 
@@ -1590,10 +1637,10 @@ class LeR(LensGalaxyParameterDistribution):
1590
1637
  snr_threshold_recalculation_max, _ = self._check_snr_threshold_lensed(snr_threshold_recalculation[1], num_img)
1591
1638
 
1592
1639
  # check optimal_snr_net is provided in dict
1593
- if "optimal_snr_net" not in lensed_param:
1640
+ if "snr_net" not in lensed_param:
1594
1641
  raise ValueError("optimal_snr_net not provided in lensed_param dict. Exiting...")
1595
1642
 
1596
- snr_param = lensed_param["optimal_snr_net"]
1643
+ snr_param = lensed_param["snr_net"]
1597
1644
  snr_param = -np.sort(-snr_param, axis=1) # sort snr in descending order
1598
1645
 
1599
1646
  # for each row: choose a threshold and check if the number of images above threshold. Sum over the images. If sum is greater than num_img, then snr_hit = True
@@ -1648,11 +1695,11 @@ class LeR(LensGalaxyParameterDistribution):
1648
1695
  boolean array to store the result of the threshold condition.
1649
1696
  """
1650
1697
 
1651
- print(f"given detectability_condition == {detectability_condition}")
1698
+ #print(f"given detectability_condition == {detectability_condition}")
1652
1699
  if detectability_condition == "step_function":
1653
- if "optimal_snr_net" not in lensed_param:
1654
- raise ValueError("'optimal_snr_net' not in lensed parm dict provided")
1655
- snr_param = lensed_param["optimal_snr_net"]
1700
+ if "snr_net" not in lensed_param:
1701
+ raise ValueError("'snr_net' not in lensed parm dict provided")
1702
+ snr_param = lensed_param["snr_net"]
1656
1703
  snr_param = -np.sort(-snr_param, axis=1) # sort snr in descending order
1657
1704
  snr_hit = np.full(len(snr_param), True) # boolean array to store the result of the threshold condition
1658
1705
 
@@ -1668,6 +1715,7 @@ class LeR(LensGalaxyParameterDistribution):
1668
1715
  for i, snr_th in enumerate(snr_threshold):
1669
1716
  idx_max = idx_max + num_img[i]
1670
1717
  snr_hit = snr_hit & (np.sum((snr_param[:,j:idx_max] > snr_th), axis=1) >= num_img[i])
1718
+ # select according to time delays
1671
1719
  j = idx_max
1672
1720
  else:
1673
1721
  # sqrt of the the sum of the squares of the snr of the images
@@ -1677,12 +1725,12 @@ class LeR(LensGalaxyParameterDistribution):
1677
1725
 
1678
1726
  elif detectability_condition == "pdet":
1679
1727
  if "pdet_net" not in lensed_param:
1680
- if "optimal_snr_net" not in lensed_param:
1681
- raise ValueError("'optimal_snr_net' or 'pdet_net' not in lensed parm dict provided")
1728
+ if "snr_net" not in lensed_param:
1729
+ raise ValueError("'snr_net' or 'pdet_net' not in lensed parm dict provided")
1682
1730
  else:
1683
- print("calculating pdet using 'optimal_snr_net'...")
1731
+ print("calculating pdet using 'snr_net'...")
1684
1732
  # pdet dimension is (size, n_max_images)
1685
- snr_param = lensed_param["optimal_snr_net"]
1733
+ snr_param = lensed_param["snr_net"]
1686
1734
  snr_param = -np.sort(-snr_param, axis=1) # sort snr in descending order
1687
1735
 
1688
1736
  # column index beyong np.sum(num_img)-1 are not considered
@@ -1841,13 +1889,17 @@ class LeR(LensGalaxyParameterDistribution):
1841
1889
  self,
1842
1890
  size=100,
1843
1891
  batch_size=None,
1892
+ stopping_criteria=dict(
1893
+ relative_diff_percentage=0.5,
1894
+ number_of_last_batches_to_check=4,
1895
+ ),
1844
1896
  snr_threshold=8.0,
1845
1897
  pdet_threshold=0.5,
1846
1898
  resume=False,
1847
1899
  output_jsonfile="n_unlensed_param_detectable.json",
1848
1900
  meta_data_file="meta_unlensed.json",
1849
1901
  detectability_condition="step_function",
1850
- trim_to_size=True,
1902
+ trim_to_size=False,
1851
1903
  snr_recalculation=False,
1852
1904
  snr_threshold_recalculation=[4, 12],
1853
1905
  ):
@@ -1907,7 +1959,8 @@ class LeR(LensGalaxyParameterDistribution):
1907
1959
  n, events_total, output_path, meta_data_path, buffer_file, batch_size = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
1908
1960
 
1909
1961
  # loop until n samples are collected
1910
- while n < size:
1962
+ continue_condition = True
1963
+ while continue_condition:
1911
1964
  # disable print statements
1912
1965
  with contextlib.redirect_stdout(None):
1913
1966
  self.dict_buffer = None # this is used to store the sampled unlensed_param in batches when running the sampling_routine
@@ -1932,7 +1985,31 @@ class LeR(LensGalaxyParameterDistribution):
1932
1985
  total_rate = self.rate_function(n, events_total, param_type="unlensed", verbose=False)
1933
1986
 
1934
1987
  # bookmark
1935
- self._append_meta_data(meta_data_path, n, events_total, total_rate)
1988
+ buffer_dict = self._append_meta_data(meta_data_path, n, events_total, total_rate)
1989
+
1990
+ if isinstance(stopping_criteria, dict):
1991
+ total_rates = np.array(buffer_dict['total_rate'])
1992
+ limit = stopping_criteria['relative_diff_percentage']
1993
+ num_a = stopping_criteria['number_of_last_batches_to_check']
1994
+ if len(total_rates)>num_a:
1995
+ num_a = int(-1*(num_a))
1996
+ # num_b = int(num_a)
1997
+ percentage_diff = np.abs((total_rates[num_a:]-total_rates[-1])/total_rates[-1])*100
1998
+ print(f"percentage difference for the last {abs(num_a)} batches = {percentage_diff}")
1999
+ if np.any(percentage_diff>limit):
2000
+ continue_condition &= True
2001
+ else:
2002
+ print(rf"stopping criteria of rate relative difference of {limit}% reached. If you want to collect more events, reduce stopping_criteria['relative_diff_percentage'] or put stopping_criteria=None.")
2003
+ continue_condition &= False
2004
+
2005
+ if isinstance(size, int):
2006
+ if n<size:
2007
+ continue_condition |= True
2008
+ else:
2009
+ print(rf"Given size={size} reached")
2010
+ continue_condition |= False
2011
+ if stopping_criteria is None:
2012
+ continue_condition &= False
1936
2013
 
1937
2014
  print(f"stored detectable unlensed params in {output_path}")
1938
2015
  print(f"stored meta data in {meta_data_path}")
@@ -1960,6 +2037,10 @@ class LeR(LensGalaxyParameterDistribution):
1960
2037
  def selecting_n_lensed_detectable_events(
1961
2038
  self,
1962
2039
  size=100,
2040
+ stopping_criteria=dict(
2041
+ relative_diff_percentage=0.5,
2042
+ number_of_last_batches_to_check=4,
2043
+ ),
1963
2044
  batch_size=None,
1964
2045
  snr_threshold=[8.0,8.0],
1965
2046
  pdet_threshold=0.5,
@@ -1970,7 +2051,7 @@ class LeR(LensGalaxyParameterDistribution):
1970
2051
  detectability_condition="step_function",
1971
2052
  output_jsonfile="n_lensed_params_detectable.json",
1972
2053
  meta_data_file="meta_lensed.json",
1973
- trim_to_size=True,
2054
+ trim_to_size=False,
1974
2055
  nan_to_num=False,
1975
2056
  snr_recalculation=False,
1976
2057
  snr_threshold_recalculation=[[4,4],[12,12]],
@@ -2039,7 +2120,8 @@ class LeR(LensGalaxyParameterDistribution):
2039
2120
  # re-analyse the provided snr_threshold and num_img
2040
2121
  snr_threshold, num_img = self._check_snr_threshold_lensed(snr_threshold, num_img)
2041
2122
 
2042
- while n < size:
2123
+ continue_condition = True
2124
+ while continue_condition:
2043
2125
  # disable print statements
2044
2126
  with contextlib.redirect_stdout(None):
2045
2127
  self.dict_buffer = None # this is used to store the sampled lensed_param in batches when running the sampling_routine
@@ -2064,7 +2146,30 @@ class LeR(LensGalaxyParameterDistribution):
2064
2146
  total_rate = self.rate_function(n, events_total, param_type="lensed", verbose=False)
2065
2147
 
2066
2148
  # save meta data
2067
- self._append_meta_data(meta_data_path, n, events_total, total_rate)
2149
+ buffer_dict = self._append_meta_data(meta_data_path, n, events_total, total_rate)
2150
+
2151
+ if isinstance(stopping_criteria, dict):
2152
+ total_rates = np.array(buffer_dict['total_rate'])
2153
+ limit = stopping_criteria['relative_diff_percentage']
2154
+ num_a = stopping_criteria['number_of_last_batches_to_check']
2155
+
2156
+ if len(total_rates)>num_a:
2157
+ num_a = int(-1*(num_a))
2158
+ # num_b = int(num_a)
2159
+ percentage_diff = np.abs((total_rates[num_a:]-total_rates[-1])/total_rates[-1])*100
2160
+ print(f"percentage difference for the last {abs(num_a)} batches = {percentage_diff}")
2161
+ if np.any(percentage_diff>limit):
2162
+ continue_condition &= True
2163
+ else:
2164
+ print(rf"stopping criteria of rate relative difference of {limit}% reached. If you want to collect more events, reduce stopping_criteria['relative_diff_percentage']")
2165
+ continue_condition &= False
2166
+
2167
+ if isinstance(size, int):
2168
+ if n<size:
2169
+ continue_condition |= True
2170
+ else:
2171
+ print(rf"Given size={size} reached")
2172
+ continue_condition |= False
2068
2173
 
2069
2174
  print(f"storing detectable lensed params in {output_path}")
2070
2175
  print(f"storing meta data in {meta_data_path}")
@@ -2084,7 +2189,7 @@ class LeR(LensGalaxyParameterDistribution):
2084
2189
  meta = get_param_from_json(meta_data_path)
2085
2190
  data["detectable_lensed_rate_per_year"] = meta["total_rate"][-1]
2086
2191
  data["detectability_condition_lensed"] = detectability_condition
2087
- append_json(self.ler_directory+"/"+self.json_file_names["ler_params"], data, replace=True)
2192
+ buffer_dict = append_json(self.ler_directory+"/"+self.json_file_names["ler_params"], data, replace=True)
2088
2193
 
2089
2194
  return param_final
2090
2195
 
@@ -2232,12 +2337,14 @@ class LeR(LensGalaxyParameterDistribution):
2232
2337
 
2233
2338
  if os.path.exists(meta_data_path):
2234
2339
  try:
2235
- append_json(meta_data_path, meta_data, replace=False)
2340
+ dict_ = append_json(meta_data_path, meta_data, replace=False)
2236
2341
  except:
2237
- append_json(meta_data_path, meta_data, replace=True)
2342
+ dict_ = append_json(meta_data_path, meta_data, replace=True)
2238
2343
  else:
2239
- append_json(meta_data_path, meta_data, replace=True)
2344
+ dict_ = append_json(meta_data_path, meta_data, replace=True)
2240
2345
 
2241
2346
  print("collected number of detectable events = ", n)
2242
2347
  print("total number of events = ", events_total)
2243
- print(f"total rate (yr^-1): {total_rate}")
2348
+ print(f"total rate (yr^-1): {total_rate}")
2349
+
2350
+ return dict_
ler/utils/__init__.py CHANGED
@@ -1,2 +1,4 @@
1
1
  from .utils import *
2
2
  from .plots import *
3
+ from .gwsnr_training_data_generator import *
4
+ from .function_interpolation import *