ler 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ler might be problematic. Click here for more details.

ler/__init__.py CHANGED
@@ -2,19 +2,56 @@
2
2
  LeR
3
3
  """
4
4
 
5
- #import pycbc
5
+ # mypackage/cli.py
6
+ import argparse
7
+ # import subprocess, os, sys, signal, warnings
8
+
9
+ ## import pycbc
10
+ import os
6
11
  import multiprocessing as mp
7
12
 
8
- try:
9
- mp.set_start_method('spawn', force=True)
10
- except RuntimeError:
11
- pass
12
- #from . import rates, gw_source_population, lens_galaxy_population, image_properties, utils
13
+ def set_multiprocessing_start_method():
14
+ if os.name == 'posix': # posix indicates the program is run on Unix/Linux/Mac
15
+ print("Setting multiprocessing start method to 'fork'")
16
+ try:
17
+ mp.set_start_method('fork', force=True)
18
+ except RuntimeError:
19
+ # The start method can only be set once and must be set before any process starts
20
+ pass
21
+ else:
22
+ print("Setting multiprocessing start method to 'spawn'")
23
+ # For Windows and other operating systems, use 'spawn'
24
+ try:
25
+ mp.set_start_method('spawn', force=True)
26
+ except RuntimeError:
27
+ pass
28
+
29
+ set_multiprocessing_start_method()
30
+
31
+ # try:
32
+ # mp.set_start_method('fork', force=True)
33
+ # except RuntimeError:
34
+ # pass
35
+
36
+ # if sys.platform == 'darwin':
37
+ # HAVE_OMP = False
38
+
39
+ # # MacosX after python3.7 switched to 'spawn', however, this does not
40
+ # # preserve common state information which we have relied on when using
41
+ # # multiprocessing based pools.
42
+ # import multiprocessing
43
+ # if multiprocessing.get_start_method(allow_none=True) is None:
44
+ # if hasattr(multiprocessing, 'set_start_method'):
45
+ # multiprocessing.set_start_method('fork')
46
+ # elif multiprocessing.get_start_method() != 'fork':
47
+ # warnings.warn("PyCBC requires the use of the 'fork' start method for multiprocessing, it is currently set to {}".format(multiprocessing.get_start_method()))
48
+ # else:
49
+ # HAVE_OMP = True
13
50
 
14
51
  __author__ = 'hemanta_ph <hemantaphurailatpam@gmail.com>'
15
52
 
16
53
  # The version as used in the setup.py
17
- __version__ = "0.4.0"
54
+ __version__ = "0.4.2"
18
55
 
19
56
  # add __file__
20
57
  import os
@@ -46,7 +46,7 @@ class CBCSourceParameterDistribution(CBCSourceRedshiftDistribution):
46
46
  Dictionary of prior sampler functions and its input parameters.
47
47
  Check for available priors and corresponding input parameters by running,
48
48
  >>> from ler.gw_source_population import CBCSourceParameterDistribution
49
- >>> cbc = CompactBinaryPopulation()
49
+ >>> cbc = CBCSourceParameterDistribution()
50
50
  >>> cbc.available_gw_prior_list_and_its_params()
51
51
  # To check the current chosen priors and its parameters, run,
52
52
  >>> print("default priors=",cbc.gw_param_samplers)
@@ -58,8 +58,8 @@ class CBCSourceParameterDistribution(CBCSourceRedshiftDistribution):
58
58
  If True, spin parameters are completely ignore in the sampling.
59
59
  default: True
60
60
  spin_precession : `bool`
61
- If spin_zero=True and spin_precession=True, spin parameters are sampled for precessing binaries.
62
- if spin_zero=True and spin_precession=False, spin parameters are sampled for aligned/anti-aligned spin binaries.
61
+ If spin_zero=False and spin_precession=True, spin parameters are sampled for precessing binaries.
62
+ if spin_zero=False and spin_precession=False, spin parameters are sampled for aligned/anti-aligned spin binaries.
63
63
  default: False
64
64
  directory : `str`
65
65
  Directory to store the interpolator pickle files
@@ -77,7 +77,7 @@ class CBCSourceParameterDistribution(CBCSourceRedshiftDistribution):
77
77
 
78
78
  Instance Attributes
79
79
  ----------
80
- CompactBinaryPopulation has the following instance attributes:\n
80
+ CBCSourceParameterDistribution has the following instance attributes:\n
81
81
  +-------------------------------------+----------------------------------+
82
82
  | Attributes | Type |
83
83
  +=====================================+==================================+
@@ -114,7 +114,7 @@ class CBCSourceParameterDistribution(CBCSourceRedshiftDistribution):
114
114
 
115
115
  Instance Methods
116
116
  ----------
117
- CompactBinaryPopulation has the following instance methods:\n
117
+ CBCSourceParameterDistribution has the following instance methods:\n
118
118
  +-------------------------------------+----------------------------------+
119
119
  | Methods | Type |
120
120
  +=====================================+==================================+
@@ -1136,7 +1136,7 @@ class CBCSourceParameterDistribution(CBCSourceRedshiftDistribution):
1136
1136
  Examples
1137
1137
  ----------
1138
1138
  >>> from ler.gw_source_population import CBCSourceParameterDistribution
1139
- >>> cbc = CompactBinaryPopulation()
1139
+ >>> cbc = CBCSourceParameterDistribution()
1140
1140
  >>> priors = cbc.available_gw_prior_list_and_its_params
1141
1141
  >>> priors.keys() # type of priors
1142
1142
  dict_keys(['merger_rate_density', 'source_frame_masses', 'spin', 'geocent_time', 'ra', 'phase', 'psi', 'theta_jn'])
@@ -25,7 +25,7 @@ from .jit_functions import merger_rate_density_bbh_popI_II_oguri2018, star_forma
25
25
 
26
26
  class CBCSourceRedshiftDistribution(object):
27
27
  """Class to generate a population of source galaxies.
28
- This class is inherited by :class:`~ler.ler.CompactBinaryPopulation` and :class:`~ler.ler.LensGalaxyParameterDistribution` class.
28
+ This class is inherited by :class:`~ler.ler.CBCSourceParameterDistribution` and :class:`~ler.ler.LensGalaxyParameterDistribution` class.
29
29
 
30
30
  Parameters
31
31
  ----------
@@ -1,11 +1,13 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  """
3
3
  This module contains the LensGalaxyPopulation class, which is used to sample lens galaxy parameters, source parameters conditioned on the source being strongly lensed, image properties, and lensed SNRs. \n
4
- The class inherits from the CompactBinaryPopulation class, which is used to sample source parameters. \n
4
+ The class inherits from the CBCSourceParameterDistribution class, which is used to sample source parameters. \n
5
5
  """
6
6
 
7
7
  import warnings
8
8
  warnings.filterwarnings("ignore")
9
+ import logging
10
+ logging.getLogger('numexpr.utils').setLevel(logging.ERROR)
9
11
  # for multiprocessing
10
12
  from multiprocessing import Pool
11
13
  from tqdm import tqdm
@@ -339,6 +341,8 @@ class ImageProperties():
339
341
  ####################################
340
342
  lens_parameters.update(image_parameters)
341
343
  lens_parameters["n_images"] = n_images
344
+ lens_parameters["x_source"] = x_source
345
+ lens_parameters["y_source"] = y_source
342
346
 
343
347
  return lens_parameters
344
348
 
@@ -2,7 +2,7 @@
2
2
  """
3
3
  This module contains the LensGalaxyPopulation class, which is used to sample lens galaxy parameters, source parameters conditioned on the source being strongly lensed. \n
4
4
  The class inherits from the ImageProperties class, which is used to calculate image properties (magnification, timedelays, source position, image position, morse phase). \n
5
- Either the class takes in initialized CompactBinaryPopulation class as input or inherits the CompactBinaryPopulation class with default params (if no input) \n
5
+ Either the class takes in initialized CBCSourceParameterDistribution class as input or inherits the CBCSourceParameterDistribution class with default params (if no input) \n
6
6
  """
7
7
 
8
8
  import warnings
@@ -181,8 +181,8 @@ class LensGalaxyParameterDistribution(CBCSourceParameterDistribution, ImagePrope
181
181
 
182
182
  # Attributes
183
183
  cbc_pop = None
184
- """:class:`~CompactBinaryPopulation` class\n
185
- This is an already initialized class that contains a function (CompactBinaryPopulation.sample_gw_parameters) that actually samples the source parameters.
184
+ """:class:`~CBCSourceParameterDistribution` class\n
185
+ This is an already initialized class that contains a function (CBCSourceParameterDistribution.sample_gw_parameters) that actually samples the source parameters.
186
186
  """
187
187
 
188
188
  z_min = None
@@ -313,7 +313,7 @@ class LensGalaxyParameterDistribution(CBCSourceParameterDistribution, ImagePrope
313
313
  dictionary of parameters to initialize the parent classes
314
314
  """
315
315
 
316
- # initialization of CompactBinaryPopulation class
316
+ # initialization of CBCSourceParameterDistribution class
317
317
  # it also initializes the CBCSourceRedshiftDistribution class
318
318
  # list of relevant initialized instances,
319
319
  # 1. self.sample_source_redshift
@@ -374,6 +374,8 @@ class LensGalaxyParameterDistribution(CBCSourceParameterDistribution, ImagePrope
374
374
  lens_model_list=["EPL_NUMBA", "SHEAR"],
375
375
  )
376
376
  input_params_image.update(params)
377
+
378
+ print("input_params_image", input_params_image)
377
379
  ImageProperties.__init__(
378
380
  self,
379
381
  npool=self.npool,
@@ -637,7 +639,7 @@ class LensGalaxyParameterDistribution(CBCSourceParameterDistribution, ImagePrope
637
639
  def zs_function(zs_sl):
638
640
  # get zs
639
641
  # self.sample_source_redshifts from CBCSourceRedshiftDistribution class
640
- zs = self.sample_zs(size) # this function is from CompactBinaryPopulation class
642
+ zs = self.sample_zs(size) # this function is from CBCSourceParameterDistribution class
641
643
  # put strong lensing condition with optical depth
642
644
  tau = self.strong_lensing_optical_depth(zs)
643
645
  tau_max = self.strong_lensing_optical_depth(np.array([z_max]))[0] # tau increases with z
ler/rates/gwrates.py CHANGED
@@ -84,7 +84,7 @@ class GWRATES(CBCSourceParameterDistribution):
84
84
 
85
85
  Instance Attributes
86
86
  ----------
87
- LeR class has the following attributes, \n
87
+ LeR class has the following attributes:\n
88
88
  +-------------------------------------+----------------------------------+
89
89
  | Attributes | Type |
90
90
  +=====================================+==================================+
@@ -121,7 +121,7 @@ class GWRATES(CBCSourceParameterDistribution):
121
121
 
122
122
  Instance Methods
123
123
  ----------
124
- LeR class has the following methods, \n
124
+ LeR class has the following methods:\n
125
125
  +-------------------------------------+----------------------------------+
126
126
  | Methods | Description |
127
127
  +=====================================+==================================+
@@ -506,7 +506,7 @@ class GWRATES(CBCSourceParameterDistribution):
506
506
  dictionary of parameters to initialize the parent classes
507
507
  """
508
508
 
509
- # initialization of CompactBinaryPopulation class
509
+ # initialization of CBCSourceParameterDistribution class
510
510
  # it also initializes the CBCSourceRedshiftDistribution class
511
511
  input_params = dict(
512
512
  z_min=self.z_min,
ler/rates/ler.py CHANGED
@@ -6,14 +6,22 @@ This module contains the main class for calculating the rates of detectable grav
6
6
  import os
7
7
  import warnings
8
8
  warnings.filterwarnings("ignore")
9
+ import logging
10
+ logging.getLogger('numexpr.utils').setLevel(logging.ERROR)
9
11
  import contextlib
10
12
  import numpy as np
11
13
  from scipy.stats import norm
12
14
  from astropy.cosmology import LambdaCDM
13
15
  from ..lens_galaxy_population import LensGalaxyParameterDistribution
14
- from ..utils import load_json, append_json, get_param_from_json, batch_handler, add_dict_values
16
+ from ..utils import load_json, append_json, get_param_from_json, batch_handler
15
17
 
16
18
 
19
+ # # multiprocessing guard code
20
+ # def main():
21
+ # obj = LeR()
22
+
23
+ # if __name__ == '__main__':
24
+
17
25
  class LeR(LensGalaxyParameterDistribution):
18
26
  """Class to sample of lensed and unlensed events and calculate it's rates. Please note that parameters of the simulated events are stored in json file but not as an attribute of the class. This saves RAM memory.
19
27
 
@@ -68,6 +76,23 @@ class LeR(LensGalaxyParameterDistribution):
68
76
  interpolator_directory : `str`
69
77
  directory to store the interpolators.
70
78
  default interpolator_directory = './interpolator_pickle'. This is used for storing the various interpolators related to `ler` and `gwsnr` package.
79
+ create_new_interpolator : `bool` or `dict`
80
+ default create_new_interpolator = False.
81
+ if True, all interpolators (including `gwsnr`'s) will be created again.
82
+ if False, the interpolators will be loaded from the interpolator_directory if they exist.
83
+ if dict, you can specify which interpolators to create new. Complete example (change any of them to True), create_new_interpolator = dict(
84
+ redshift_distribution=dict(create_new=False, resolution=1000),
85
+ z_to_luminosity_distance=dict(create_new=False, resolution=1000),
86
+ velocity_dispersion=dict(create_new=False, resolution=1000),
87
+ axis_ratio=dict(create_new=False, resolution=1000),
88
+ optical_depth=dict(create_new=False, resolution=200),
89
+ z_to_Dc=dict(create_new=False, resolution=1000),
90
+ Dc_to_z=dict(create_new=False, resolution=1000),
91
+ angular_diameter_distance=dict(create_new=False, resolution=1000),
92
+ differential_comoving_volume=dict(create_new=False, resolution=1000),
93
+ Dl_to_z=dict(create_new=False, resolution=1000),
94
+ gwsnr=False,
95
+ )
71
96
  ler_directory : `str`
72
97
  directory to store the parameters.
73
98
  default ler_directory = './ler_data'. This is used for storing the parameters of the simulated events.
@@ -90,7 +115,7 @@ class LeR(LensGalaxyParameterDistribution):
90
115
 
91
116
  Instance Attributes
92
117
  ----------
93
- LeR class has the following attributes, \n
118
+ LeR class has the following attributes: \n
94
119
  +-------------------------------------+----------------------------------+
95
120
  | Attributes | Type |
96
121
  +=====================================+==================================+
@@ -133,7 +158,7 @@ class LeR(LensGalaxyParameterDistribution):
133
158
 
134
159
  Instance Methods
135
160
  ----------
136
- LeR class has the following methods, \n
161
+ LeR class has the following methods:\n
137
162
  +-------------------------------------+----------------------------------+
138
163
  | Methods | Description |
139
164
  +=====================================+==================================+
@@ -186,7 +211,7 @@ class LeR(LensGalaxyParameterDistribution):
186
211
  | | ratio between lensed and |
187
212
  | | unlensed events. |
188
213
  +-------------------------------------+----------------------------------+
189
- |:meth:`~rate_comparision_with_rate_calculation |
214
+ |:meth:`~rate_comparison_with_rate_calculation` |
190
215
  +-------------------------------------+----------------------------------+
191
216
  | | Function to calculate rates for |
192
217
  | | unleesed and lensed events and |
@@ -429,6 +454,7 @@ class LeR(LensGalaxyParameterDistribution):
429
454
  list_of_detectors=None,
430
455
  json_file_names=None,
431
456
  interpolator_directory="./interpolator_pickle",
457
+ create_new_interpolator=False,
432
458
  ler_directory="./ler_data",
433
459
  verbose=True,
434
460
  **kwargs,
@@ -445,6 +471,7 @@ class LeR(LensGalaxyParameterDistribution):
445
471
  if json_file_names:
446
472
  self.json_file_names.update(json_file_names)
447
473
  self.interpolator_directory = interpolator_directory
474
+ kwargs["create_new_interpolator"] = create_new_interpolator
448
475
  self.ler_directory = ler_directory
449
476
  # create directory if not exists
450
477
  if not os.path.exists(ler_directory):
@@ -480,7 +507,7 @@ class LeR(LensGalaxyParameterDistribution):
480
507
  Function to print all the parameters.
481
508
  """
482
509
  # print all relevant functions and sampler priors
483
- print("\n LeR set up params:")
510
+ print("\n # LeR set up params:")
484
511
  print(f'npool = {self.npool},')
485
512
  print(f'z_min = {self.z_min},')
486
513
  print(f'z_max = {self.z_max},')
@@ -493,23 +520,23 @@ class LeR(LensGalaxyParameterDistribution):
493
520
  if self.pdet:
494
521
  print(f'pdet_finder = {self.pdet},')
495
522
  print(f'json_file_names = {self.json_file_names},')
496
- print(f'interpolator_directory = {self.interpolator_directory},')
497
- print(f'ler_directory = {self.ler_directory},')
523
+ print(f"interpolator_directory = '{self.interpolator_directory}',")
524
+ print(f"ler_directory = '{self.ler_directory}',")
498
525
 
499
- print("\n LeR also takes CBCSourceParameterDistribution class params as kwargs, as follows:")
526
+ print("\n # LeR also takes CBCSourceParameterDistribution class params as kwargs, as follows:")
500
527
  print(f"source_priors = {self.gw_param_sampler_dict['source_priors']},")
501
528
  print(f"source_priors_params = {self.gw_param_sampler_dict['source_priors_params']},")
502
529
  print(f"spin_zero = {self.gw_param_sampler_dict['spin_zero']},")
503
530
  print(f"spin_precession = {self.gw_param_sampler_dict['spin_precession']},")
504
531
  print(f"create_new_interpolator = {self.gw_param_sampler_dict['create_new_interpolator']},")
505
532
 
506
- print("\n LeR also takes LensGalaxyParameterDistribution class params as kwargs, as follows:")
533
+ print("\n # LeR also takes LensGalaxyParameterDistribution class params as kwargs, as follows:")
507
534
  print(f"lens_type = '{self.gw_param_sampler_dict['lens_type']}',")
508
535
  print(f"lens_functions = {self.gw_param_sampler_dict['lens_functions']},")
509
536
  print(f"lens_priors = {self.gw_param_sampler_dict['lens_priors']},")
510
537
  print(f"lens_priors_params = {self.gw_param_sampler_dict['lens_priors_params']},")
511
538
 
512
- print("\n LeR also takes ImageProperties class params as kwargs, as follows:")
539
+ print("\n # LeR also takes ImageProperties class params as kwargs, as follows:")
513
540
  print(f"n_min_images = {self.n_min_images},")
514
541
  print(f"n_max_images = {self.n_max_images},")
515
542
  print(f"geocent_time_min = {self.geocent_time_min},")
@@ -517,7 +544,7 @@ class LeR(LensGalaxyParameterDistribution):
517
544
  print(f"lens_model_list = {self.lens_model_list},")
518
545
 
519
546
  if self.gwsnr:
520
- print("\n LeR also takes gwsnr.GWSNR params as kwargs, as follows:")
547
+ print("\n # LeR also takes gwsnr.GWSNR params as kwargs, as follows:")
521
548
  print(f"mtot_min = {self.snr_calculator_dict['mtot_min']},")
522
549
  print(f"mtot_max = {self.snr_calculator_dict['mtot_max']},")
523
550
  print(f"ratio_min = {self.snr_calculator_dict['ratio_min']},")
@@ -531,7 +558,6 @@ class LeR(LensGalaxyParameterDistribution):
531
558
  print(f"psds = {self.snr_calculator_dict['psds']},")
532
559
  print(f"ifos = {self.snr_calculator_dict['ifos']},")
533
560
  print(f"interpolator_dir = '{self.snr_calculator_dict['interpolator_dir']}',")
534
- print(f"create_new_interpolator = {self.snr_calculator_dict['create_new_interpolator']},")
535
561
  print(f"gwsnr_verbose = {self.snr_calculator_dict['gwsnr_verbose']},")
536
562
  print(f"multiprocessing_verbose = {self.snr_calculator_dict['multiprocessing_verbose']},")
537
563
  print(f"mtot_cut = {self.snr_calculator_dict['mtot_cut']},")
@@ -712,6 +738,7 @@ class LeR(LensGalaxyParameterDistribution):
712
738
  # initialization of LensGalaxyParameterDistribution class
713
739
  # it also initializes the CBCSourceParameterDistribution and ImageProperties classes
714
740
  input_params = dict(
741
+ # LensGalaxyParameterDistribution class params
715
742
  z_min=self.z_min,
716
743
  z_max=self.z_max,
717
744
  cosmology=self.cosmo,
@@ -720,8 +747,13 @@ class LeR(LensGalaxyParameterDistribution):
720
747
  lens_functions= None,
721
748
  lens_priors=None,
722
749
  lens_priors_params=None,
750
+ # ImageProperties class params
751
+ n_min_images=2,
752
+ n_max_images=4,
723
753
  geocent_time_min=1126259462.4,
724
754
  geocent_time_max=1126259462.4+365*24*3600*20,
755
+ lens_model_list=['EPL_NUMBA', 'SHEAR'],
756
+ # CBCSourceParameterDistribution class params
725
757
  source_priors=None,
726
758
  source_priors_params=None,
727
759
  spin_zero=True,
@@ -745,8 +777,11 @@ class LeR(LensGalaxyParameterDistribution):
745
777
  lens_functions=input_params["lens_functions"],
746
778
  lens_priors=input_params["lens_priors"],
747
779
  lens_priors_params=input_params["lens_priors_params"],
780
+ n_min_images=input_params["n_min_images"],
781
+ n_max_images=input_params["n_max_images"],
748
782
  geocent_time_min=input_params["geocent_time_min"],
749
783
  geocent_time_max=input_params["geocent_time_max"],
784
+ lens_model_list=input_params["lens_model_list"],
750
785
  source_priors=input_params["source_priors"],
751
786
  source_priors_params=input_params["source_priors_params"],
752
787
  spin_zero=input_params["spin_zero"],
@@ -792,6 +827,10 @@ class LeR(LensGalaxyParameterDistribution):
792
827
  gwsnr_verbose=False,
793
828
  multiprocessing_verbose=True,
794
829
  mtot_cut=True,
830
+ pdet=False,
831
+ snr_th=8.0,
832
+ snr_th_net=8.0,
833
+ ann_path_dict=None,
795
834
  )
796
835
  # if self.event_type == "BNS":
797
836
  # input_params["mtot_max"]= 18.
@@ -800,6 +839,21 @@ class LeR(LensGalaxyParameterDistribution):
800
839
  if key in input_params:
801
840
  input_params[key] = value
802
841
  self.snr_calculator_dict = input_params
842
+
843
+ # dealing with create_new_interpolator param
844
+ if isinstance(input_params["create_new_interpolator"], bool):
845
+ pass
846
+ elif isinstance(input_params["create_new_interpolator"], dict):
847
+ # check input_params["gwsnr"] exists
848
+ if "gwsnr" in input_params["create_new_interpolator"]:
849
+ if isinstance(input_params["create_new_interpolator"]["gwsnr"], bool):
850
+ input_params["create_new_interpolator"] = input_params["create_new_interpolator"]["gwsnr"]
851
+ else:
852
+ raise ValueError("create_new_interpolator['gwsnr'] should be a boolean.")
853
+ else:
854
+ raise ValueError("create_new_interpolator should be a boolean or a dictionary with 'gwsnr' key.")
855
+
856
+ # initialization of GWSNR class
803
857
  gwsnr = GWSNR(
804
858
  npool=input_params["npool"],
805
859
  mtot_min=input_params["mtot_min"],
@@ -815,10 +869,14 @@ class LeR(LensGalaxyParameterDistribution):
815
869
  psds=input_params["psds"],
816
870
  ifos=input_params["ifos"],
817
871
  interpolator_dir=input_params["interpolator_dir"],
818
- # create_new_interpolator=input_params["create_new_interpolator"],
872
+ create_new_interpolator=input_params["create_new_interpolator"],
819
873
  gwsnr_verbose=input_params["gwsnr_verbose"],
820
874
  multiprocessing_verbose=input_params["multiprocessing_verbose"],
821
875
  mtot_cut=input_params["mtot_cut"],
876
+ pdet=input_params["pdet"],
877
+ snr_th=input_params["snr_th"],
878
+ snr_th_net=input_params["snr_th_net"],
879
+ ann_path_dict=input_params["ann_path_dict"],
822
880
  )
823
881
 
824
882
  self.snr = gwsnr.snr
@@ -886,7 +944,7 @@ class LeR(LensGalaxyParameterDistribution):
886
944
  resume = False (default) or True.
887
945
  if True, the function will resume from the last batch.
888
946
  save_batch : `bool`
889
- if True, the function will save the parameters in batches. if False, the function will save all the parameters at the end of sampling. save_batch=False is faster.
947
+ if True, the function will save the parameters in batches. if False (default), the function will save all the parameters at the end of sampling. save_batch=False is faster.
890
948
  output_jsonfile : `str`
891
949
  json file name for storing the parameters.
892
950
  default output_jsonfile = 'unlensed_params.json'. Note that this file will be stored in the self.ler_directory.
@@ -909,36 +967,16 @@ class LeR(LensGalaxyParameterDistribution):
909
967
  output_path = os.path.join(self.ler_directory, output_jsonfile)
910
968
  print(f"unlensed params will be store in {output_path}")
911
969
 
912
- # sampling in batches
913
- if resume and os.path.exists(output_path):
914
- # get sample from json file
915
- self.dict_buffer = get_param_from_json(output_path)
916
- else:
917
- self.dict_buffer = None
918
-
919
- batch_handler(
970
+ unlensed_param = batch_handler(
920
971
  size=size,
921
972
  batch_size=self.batch_size,
922
973
  sampling_routine=self.unlensed_sampling_routine,
923
974
  output_jsonfile=output_path,
924
975
  save_batch=save_batch,
925
976
  resume=resume,
977
+ param_name="unlensed parameters",
926
978
  )
927
979
 
928
- if save_batch:
929
- unlensed_param = get_param_from_json(output_path)
930
- else:
931
- # this if condition is required if there is nothing to save
932
- if self.dict_buffer:
933
- unlensed_param = self.dict_buffer.copy()
934
- # store all params in json file
935
- print(f"saving all unlensed_params in {output_path} ")
936
- append_json(output_path, unlensed_param, replace=True)
937
- else:
938
- print("unlensed_params already sampled.")
939
- unlensed_param = get_param_from_json(output_path)
940
- self.dict_buffer = None # save memory
941
-
942
980
  return unlensed_param
943
981
 
944
982
  def unlensed_sampling_routine(self, size, output_jsonfile, resume=False, save_batch=True):
@@ -978,17 +1016,6 @@ class LeR(LensGalaxyParameterDistribution):
978
1016
  pdet = self.pdet(gw_param_dict=unlensed_param)
979
1017
  unlensed_param.update(pdet)
980
1018
 
981
- # adding batches
982
- if not save_batch:
983
- if self.dict_buffer is None:
984
- self.dict_buffer = unlensed_param
985
- else:
986
- for key, value in unlensed_param.items():
987
- self.dict_buffer[key] = np.concatenate((self.dict_buffer[key], value))
988
- else:
989
- # store all params in json file
990
- self.dict_buffer = append_json(file_name=output_jsonfile, new_dictionary=unlensed_param, old_dictionary=self.dict_buffer, replace=not (resume))
991
-
992
1019
  return unlensed_param
993
1020
 
994
1021
  def unlensed_rate(
@@ -1323,36 +1350,16 @@ class LeR(LensGalaxyParameterDistribution):
1323
1350
  output_path = os.path.join(self.ler_directory, output_jsonfile)
1324
1351
  print(f"lensed params will be store in {output_path}")
1325
1352
 
1326
- # sampling in batches
1327
- if resume and os.path.exists(output_path):
1328
- # get sample from json file
1329
- self.dict_buffer = get_param_from_json(output_path)
1330
- else:
1331
- self.dict_buffer = None
1332
-
1333
- batch_handler(
1353
+ lensed_param = batch_handler(
1334
1354
  size=size,
1335
1355
  batch_size=self.batch_size,
1336
1356
  sampling_routine=self.lensed_sampling_routine,
1337
1357
  output_jsonfile=output_path,
1338
1358
  save_batch=save_batch,
1339
1359
  resume=resume,
1360
+ param_name="lensed parameters",
1340
1361
  )
1341
1362
 
1342
- if save_batch:
1343
- lensed_param = get_param_from_json(output_path)
1344
- else:
1345
- # this if condition is required if there is nothing to save
1346
- if self.dict_buffer:
1347
- lensed_param = self.dict_buffer.copy()
1348
- # store all params in json file
1349
- print(f"saving all lensed_params in {output_path} ")
1350
- append_json(output_path, lensed_param, replace=True)
1351
- else:
1352
- print("lensed_params already sampled.")
1353
- lensed_param = get_param_from_json(output_path)
1354
- self.dict_buffer = None # save memory
1355
-
1356
1363
  return lensed_param
1357
1364
 
1358
1365
  def lensed_sampling_routine(self, size, output_jsonfile, save_batch=True, resume=False):
@@ -1382,6 +1389,8 @@ class LeR(LensGalaxyParameterDistribution):
1382
1389
  print("sampling lensed params...")
1383
1390
  lensed_param = {}
1384
1391
 
1392
+ # Some of the sample lensed events may not satisfy the strong lensing condition
1393
+ # In that case, we will resample those events and replace the values with the corresponding indices
1385
1394
  while True:
1386
1395
  # get lensed params
1387
1396
  lensed_param_ = self.sample_lens_parameters(size=size)
@@ -1425,17 +1434,6 @@ class LeR(LensGalaxyParameterDistribution):
1425
1434
  )
1426
1435
  lensed_param.update(pdet)
1427
1436
 
1428
- # adding batches
1429
- if not save_batch:
1430
- if self.dict_buffer is None:
1431
- self.dict_buffer = lensed_param
1432
- else:
1433
- for key, value in lensed_param.items():
1434
- self.dict_buffer[key] = np.concatenate((self.dict_buffer[key], value))
1435
- else:
1436
- # store all params in json file
1437
- self.dict_buffer = append_json(file_name=output_jsonfile, new_dictionary=lensed_param, old_dictionary=self.dict_buffer, replace=not (resume))
1438
-
1439
1437
  return lensed_param
1440
1438
 
1441
1439
  def lensed_rate(
@@ -1447,6 +1445,8 @@ class LeR(LensGalaxyParameterDistribution):
1447
1445
  output_jsonfile=None,
1448
1446
  nan_to_num=True,
1449
1447
  detectability_condition="step_function",
1448
+ combine_image_snr=False,
1449
+ snr_cut_for_combine_image_snr=4.0,
1450
1450
  snr_recalculation=False,
1451
1451
  snr_threshold_recalculation=[[4,4], [20,20]],
1452
1452
  ):
@@ -1519,7 +1519,8 @@ class LeR(LensGalaxyParameterDistribution):
1519
1519
  if snr_recalculation:
1520
1520
  lensed_param = self._recalculate_snr_lensed(lensed_param, snr_threshold_recalculation, num_img, total_events)
1521
1521
 
1522
- snr_hit = self._find_detectable_index_lensed(lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition)
1522
+ # find index of detectable events
1523
+ snr_hit = self._find_detectable_index_lensed(lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition, combine_image_snr=combine_image_snr, snr_cut_for_combine_image_snr=snr_cut_for_combine_image_snr)
1523
1524
 
1524
1525
  # montecarlo integration
1525
1526
  total_rate = self.rate_function(np.sum(snr_hit), total_events, param_type="lensed")
@@ -1619,7 +1620,7 @@ class LeR(LensGalaxyParameterDistribution):
1619
1620
 
1620
1621
  return lensed_param
1621
1622
 
1622
- def _find_detectable_index_lensed(self, lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition):
1623
+ def _find_detectable_index_lensed(self, lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition, combine_image_snr=False, snr_cut_for_combine_image_snr=4.0):
1623
1624
  """
1624
1625
  Helper function to find the index of detectable events based on SNR or p_det.
1625
1626
 
@@ -1655,18 +1656,24 @@ class LeR(LensGalaxyParameterDistribution):
1655
1656
  snr_param = -np.sort(-snr_param, axis=1) # sort snr in descending order
1656
1657
  snr_hit = np.full(len(snr_param), True) # boolean array to store the result of the threshold condition
1657
1658
 
1658
- # for each row: choose a threshold and check if the number of images above threshold. Sum over the images. If sum is greater than num_img, then snr_hit = True
1659
- # algorithm:
1660
- # i) consider snr_threshold=[8,6] and num_img=[2,1] and first row of snr_param[0]=[12,8,6,1]. Note that the snr_param is sorted in descending order.
1661
- # ii) for loop runs wrt snr_threshold. idx_max = idx_max + num_img[i]
1662
- # iii) First iteration: snr_threshold=8 and num_img=2. In snr_param, column index 0 and 1 (i.e. 0:num_img[0]) are considered. The sum of snr_param[0, 0:2] > 8 is checked. If True, then snr_hit = True.
1663
- # v) Second iteration: snr_threshold=6 and num_img=1. In snr_param, column index 2 (i.e. num_img[0]:num_img[1]) is considered. The sum of snr_param[0, 0:1] > 6 is checked. If True, then snr_hit = True.
1664
- j = 0
1665
- idx_max = 0
1666
- for i, snr_th in enumerate(snr_threshold):
1667
- idx_max = idx_max + num_img[i]
1668
- snr_hit = snr_hit & (np.sum((snr_param[:,j:idx_max] > snr_th), axis=1) >= num_img[i])
1669
- j = idx_max
1659
+ if not combine_image_snr:
1660
+ # for each row: choose a threshold and check if the number of images above threshold. Sum over the images. If sum is greater than num_img, then snr_hit = True
1661
+ # algorithm:
1662
+ # i) consider snr_threshold=[8,6] and num_img=[2,1] and first row of snr_param[0]=[12,8,6,1]. Note that the snr_param is sorted in descending order.
1663
+ # ii) for loop runs wrt snr_threshold. idx_max = idx_max + num_img[i]
1664
+ # iii) First iteration: snr_threshold=8 and num_img=2. In snr_param, column index 0 and 1 (i.e. 0:num_img[0]) are considered. The sum of snr_param[0, 0:2] > 8 is checked. If True, then snr_hit = True.
1665
+ # v) Second iteration: snr_threshold=6 and num_img=1. In snr_param, column index 2 (i.e. num_img[0]:num_img[1]) is considered. The sum of snr_param[0, 0:1] > 6 is checked. If True, then snr_hit = True.
1666
+ j = 0
1667
+ idx_max = 0
1668
+ for i, snr_th in enumerate(snr_threshold):
1669
+ idx_max = idx_max + num_img[i]
1670
+ snr_hit = snr_hit & (np.sum((snr_param[:,j:idx_max] > snr_th), axis=1) >= num_img[i])
1671
+ j = idx_max
1672
+ else:
1673
+ # sqrt of the the sum of the squares of the snr of the images
1674
+ snr_param[snr_param<snr_cut_for_combine_image_snr] = 0.0 # images with snr below snr_cut_for_combine_image_snr are not considered
1675
+ snr_param = np.sqrt(np.sum(snr_param[:,:np.sum(num_img)]**2, axis=1))
1676
+ snr_hit = snr_param >= snr_threshold[0]
1670
1677
 
1671
1678
  elif detectability_condition == "pdet":
1672
1679
  if "pdet_net" not in lensed_param:
@@ -1697,7 +1704,7 @@ class LeR(LensGalaxyParameterDistribution):
1697
1704
 
1698
1705
  return snr_hit
1699
1706
 
1700
- def rate_comparision_with_rate_calculation(
1707
+ def rate_comparison_with_rate_calculation(
1701
1708
  self,
1702
1709
  unlensed_param=None,
1703
1710
  snr_threshold_unlensed=8.0,
@@ -1705,6 +1712,8 @@ class LeR(LensGalaxyParameterDistribution):
1705
1712
  lensed_param=None,
1706
1713
  snr_threshold_lensed=[8.0,8.0],
1707
1714
  num_img=[1,1],
1715
+ combine_image_snr=False,
1716
+ snr_cut_for_combine_image_snr=4.0,
1708
1717
  output_jsonfile_lensed=None,
1709
1718
  nan_to_num=True,
1710
1719
  detectability_condition="step_function",
@@ -1758,7 +1767,7 @@ class LeR(LensGalaxyParameterDistribution):
1758
1767
  >>> ler = LeR()
1759
1768
  >>> ler.unlensed_cbc_statistics();
1760
1769
  >>> ler.lensed_cbc_statistics();
1761
- >>> rate_ratio, unlensed_param, lensed_param = ler.rate_comparision_with_rate_calculation()
1770
+ >>> rate_ratio, unlensed_param, lensed_param = ler.rate_comparison_with_rate_calculation()
1762
1771
  """
1763
1772
 
1764
1773
  # call json_file_ler_param and add the results
@@ -1779,6 +1788,8 @@ class LeR(LensGalaxyParameterDistribution):
1779
1788
  output_jsonfile=output_jsonfile_lensed,
1780
1789
  nan_to_num=nan_to_num,
1781
1790
  detectability_condition=detectability_condition,
1791
+ combine_image_snr=combine_image_snr,
1792
+ snr_cut_for_combine_image_snr=snr_cut_for_combine_image_snr,
1782
1793
  )
1783
1794
  # calculate rate ratio
1784
1795
  rate_ratio = self.rate_ratio()
@@ -1893,7 +1904,7 @@ class LeR(LensGalaxyParameterDistribution):
1893
1904
  """
1894
1905
 
1895
1906
  # initial setup
1896
- n, events_total, output_path, meta_data_path, buffer_file = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
1907
+ n, events_total, output_path, meta_data_path, buffer_file, batch_size = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
1897
1908
 
1898
1909
  # loop until n samples are collected
1899
1910
  while n < size:
@@ -1953,6 +1964,8 @@ class LeR(LensGalaxyParameterDistribution):
1953
1964
  snr_threshold=[8.0,8.0],
1954
1965
  pdet_threshold=0.5,
1955
1966
  num_img=[1,1],
1967
+ combine_image_snr=False,
1968
+ snr_cut_for_combine_image_snr=4.0,
1956
1969
  resume=False,
1957
1970
  detectability_condition="step_function",
1958
1971
  output_jsonfile="n_lensed_params_detectable.json",
@@ -2021,7 +2034,7 @@ class LeR(LensGalaxyParameterDistribution):
2021
2034
  """
2022
2035
 
2023
2036
  # initial setup
2024
- n, events_total, output_path, meta_data_path, buffer_file = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
2037
+ n, events_total, output_path, meta_data_path, buffer_file, batch_size = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
2025
2038
 
2026
2039
  # re-analyse the provided snr_threshold and num_img
2027
2040
  snr_threshold, num_img = self._check_snr_threshold_lensed(snr_threshold, num_img)
@@ -2041,7 +2054,7 @@ class LeR(LensGalaxyParameterDistribution):
2041
2054
  if snr_recalculation:
2042
2055
  lensed_param = self._recalculate_snr_lensed(lensed_param, snr_threshold_recalculation, num_img, total_events_in_this_iteration)
2043
2056
 
2044
- snr_hit = self._find_detectable_index_lensed(lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition)
2057
+ snr_hit = self._find_detectable_index_lensed(lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition, combine_image_snr=combine_image_snr, snr_cut_for_combine_image_snr=snr_cut_for_combine_image_snr)
2045
2058
 
2046
2059
  # store all params in json file
2047
2060
  self._save_detectable_params(output_jsonfile, lensed_param, snr_hit, key_file_name="n_lensed_detectable_events", nan_to_num=nan_to_num, verbose=False, replace_jsonfile=False)
@@ -2118,6 +2131,8 @@ class LeR(LensGalaxyParameterDistribution):
2118
2131
  if not resume:
2119
2132
  n = 0 # iterator
2120
2133
  events_total = 0
2134
+ # the following file will be removed if it exists
2135
+ print(f"removing {output_path} and {meta_data_path} if they exist")
2121
2136
  if os.path.exists(output_path):
2122
2137
  os.remove(output_path)
2123
2138
  if os.path.exists(meta_data_path):
@@ -2136,7 +2151,7 @@ class LeR(LensGalaxyParameterDistribution):
2136
2151
  buffer_file = "params_buffer.json"
2137
2152
  print("collected number of detectable events = ", n)
2138
2153
 
2139
- return n, events_total, output_path, meta_data_path, buffer_file
2154
+ return n, events_total, output_path, meta_data_path, buffer_file, batch_size
2140
2155
 
2141
2156
  def _trim_results_to_size(self, size, output_path, meta_data_path, param_type="unlensed"):
2142
2157
  """
ler/utils/utils.py CHANGED
@@ -5,6 +5,7 @@ This module contains helper routines for other modules in the ler package.
5
5
 
6
6
  import os
7
7
  import pickle
8
+ import h5py
8
9
  import numpy as np
9
10
  import json
10
11
  from scipy.interpolate import interp1d
@@ -52,6 +53,66 @@ class NumpyEncoder(json.JSONEncoder):
52
53
  return obj.tolist()
53
54
  return json.JSONEncoder.default(self, obj)
54
55
 
56
+ def load_pickle(file_name):
57
+ """Load a pickle file.
58
+
59
+ Parameters
60
+ ----------
61
+ file_name : `str`
62
+ pickle file name for storing the parameters.
63
+
64
+ Returns
65
+ ----------
66
+ param : `dict`
67
+ """
68
+ with open(file_name, "rb") as handle:
69
+ param = pickle.load(handle)
70
+
71
+ return param
72
+
73
+ def save_pickle(file_name, param):
74
+ """Save a dictionary as a pickle file.
75
+
76
+ Parameters
77
+ ----------
78
+ file_name : `str`
79
+ pickle file name for storing the parameters.
80
+ param : `dict`
81
+ dictionary to be saved as a pickle file.
82
+ """
83
+ with open(file_name, "wb") as handle:
84
+ pickle.dump(param, handle, protocol=pickle.HIGHEST_PROTOCOL)
85
+
86
+ # hdf5
87
+ def load_hdf5(file_name):
88
+ """Load a hdf5 file.
89
+
90
+ Parameters
91
+ ----------
92
+ file_name : `str`
93
+ hdf5 file name for storing the parameters.
94
+
95
+ Returns
96
+ ----------
97
+ param : `dict`
98
+ """
99
+
100
+ return h5py.File(file_name, 'r')
101
+
102
+ def save_hdf5(file_name, param):
103
+ """Save a dictionary as a hdf5 file.
104
+
105
+ Parameters
106
+ ----------
107
+ file_name : `str`
108
+ hdf5 file name for storing the parameters.
109
+ param : `dict`
110
+ dictionary to be saved as a hdf5 file.
111
+ """
112
+ with h5py.File(file_name, 'w') as f:
113
+ for key, value in param.items():
114
+ f.create_dataset(key, data=value)
115
+
55
116
  def load_json(file_name):
56
117
  """Load a json file.
57
118
 
@@ -125,7 +186,7 @@ def append_json(file_name, new_dictionary, old_dictionary=None, replace=False):
125
186
 
126
187
  # start = datetime.datetime.now()
127
188
  if not replace:
128
- data = add_dict_values(data, new_dictionary)
189
+ data = add_dictionaries_together(data, new_dictionary)
129
190
  # data_key = data.keys()
130
191
  # for key, value in new_dictionary.items():
131
192
  # if key in data_key:
@@ -143,25 +204,25 @@ def append_json(file_name, new_dictionary, old_dictionary=None, replace=False):
143
204
 
144
205
  return data
145
206
 
146
- def add_dict_values(dict1, dict2):
147
- """Adds the values of two dictionaries together.
207
+ # def add_dict_values(dict1, dict2):
208
+ # """Adds the values of two dictionaries together.
148
209
 
149
- Parameters
150
- ----------
151
- dict1 : `dict`
152
- dictionary to be added.
153
- dict2 : `dict`
154
- dictionary to be added.
210
+ # Parameters
211
+ # ----------
212
+ # dict1 : `dict`
213
+ # dictionary to be added.
214
+ # dict2 : `dict`
215
+ # dictionary to be added.
155
216
 
156
- Returns
157
- ----------
158
- dict1 : `dict`
159
- dictionary with added values.
160
- """
161
- data_key = dict1.keys()
162
- for key, value in dict2.items():
163
- if key in data_key:
164
- dict1[key] = np.concatenate((dict1[key], value))
217
+ # Returns
218
+ # ----------
219
+ # dict1 : `dict`
220
+ # dictionary with added values.
221
+ # """
222
+ # data_key = dict1.keys()
223
+ # for key, value in dict2.items():
224
+ # if key in data_key:
225
+ # dict1[key] = np.concatenate((dict1[key], value))
165
226
 
166
227
  return dict1
167
228
 
@@ -315,13 +376,33 @@ def add_dictionaries_together(dictionary1, dictionary2):
315
376
  if dictionary1.keys() != dictionary2.keys():
316
377
  raise ValueError("The dictionaries have different keys.")
317
378
  for key in dictionary1.keys():
318
- # Check if the item is an ndarray
319
- if isinstance(dictionary1[key], np.ndarray):
320
- dictionary[key] = np.concatenate((dictionary1[key], dictionary2[key]))
321
- elif isinstance(dictionary1[key], list):
322
- dictionary[key] = dictionary1[key] + dictionary2[key]
323
- # Check if the item is a nested dictionary
324
- elif isinstance(dictionary1[key], dict):
379
+ value1 = dictionary1[key]
380
+ value2 = dictionary2[key]
381
+
382
+ # check if the value is empty
383
+ bool0 = len(value1) == 0 or len(value2) == 0
384
+ # check if the value is an ndarray or a list
385
+ bool1 = isinstance(value1, np.ndarray) and isinstance(value2, np.ndarray)
386
+ bool2 = isinstance(value1, list) and isinstance(value2, list)
387
+ bool3 = isinstance(value1, np.ndarray) and isinstance(value2, list)
388
+ bool4 = isinstance(value1, list) and isinstance(value2, np.ndarray)
389
+ bool4 = bool4 or bool3
390
+ bool5 = isinstance(value1, dict) and isinstance(value2, dict)
391
+
392
+ if bool0:
393
+ if len(value1) == 0 and len(value2) == 0:
394
+ dictionary[key] = np.array([])
395
+ elif len(value1) != 0 and len(value2) == 0:
396
+ dictionary[key] = np.array(value1)
397
+ elif len(value1) == 0 and len(value2) != 0:
398
+ dictionary[key] = np.array(value2)
399
+ elif bool1:
400
+ dictionary[key] = np.concatenate((value1, value2))
401
+ elif bool2:
402
+ dictionary[key] = value1 + value2
403
+ elif bool4:
404
+ dictionary[key] = np.concatenate((np.array(value1), np.array(value2)))
405
+ elif bool5:
325
406
  dictionary[key] = add_dictionaries_together(
326
407
  dictionary1[key], dictionary2[key]
327
408
  )
@@ -370,9 +451,9 @@ def create_func_pdf_invcdf(x, y, category="function"):
370
451
  Parameters
371
452
  ----------
372
453
  x : `numpy.ndarray`
373
- x values.
454
+ x values. This has to sorted in ascending order.
374
455
  y : `numpy.ndarray`
375
- y values.
456
+ y values. Corresponding to the x values.
376
457
  category : `str`, optional
377
458
  category of the function. Default is "function". Other options are "function_inverse", "pdf" and "inv_cdf".
378
459
 
@@ -853,8 +934,7 @@ def inverse_transform_sampler(size, cdf, x):
853
934
  samples = y0 + (y1 - y0) * (u - x0) / (x1 - x0)
854
935
  return samples
855
936
 
856
- def batch_handler(size, batch_size, sampling_routine, output_jsonfile, save_batch=True, resume=False,
857
- ):
937
+ def batch_handler(size, batch_size, sampling_routine, output_jsonfile, save_batch=True, resume=False, param_name='parameters'):
858
938
  """
859
939
  Function to run the sampling in batches.
860
940
 
@@ -865,14 +945,28 @@ def batch_handler(size, batch_size, sampling_routine, output_jsonfile, save_batc
865
945
  batch_size : `int`
866
946
  batch size.
867
947
  sampling_routine : `function`
868
- function to sample the parameters.
869
- e.g. unlensed_sampling_routine() or lensed_sampling_routine()
948
+ sampling function. It should have 'size' as input and return a dictionary.
870
949
  output_jsonfile : `str`
871
- name of the json file to store the parameters.
872
- resume : `bool`
873
- if True, it will resume the sampling from the last batch.
874
- default resume = False.
950
+ json file name for storing the parameters.
951
+ save_batch : `bool`, optional
952
+ if True, save sampled parameters in each iteration. Default is True.
953
+ resume : `bool`, optional
954
+ if True, resume sampling from the last batch. Default is False.
955
+ param_name : `str`, optional
956
+ name of the parameter. Default is 'parameters'.
957
+
958
+ Returns
959
+ ----------
960
+ dict_buffer : `dict`
961
+ dictionary of parameters.
875
962
  """
963
+
964
+ # sampling in batches
965
+ if resume and os.path.exists(output_jsonfile):
966
+ # get sample from json file
967
+ dict_buffer = get_param_from_json(output_jsonfile)
968
+ else:
969
+ dict_buffer = None
876
970
 
877
971
  # if size is multiple of batch_size
878
972
  if size % batch_size == 0:
@@ -895,69 +989,90 @@ def batch_handler(size, batch_size, sampling_routine, output_jsonfile, save_batc
895
989
  track_batches = 0 # to track the number of batches
896
990
 
897
991
  if not resume:
898
- track_batches = track_batches + 1
899
- print(f"Batch no. {track_batches}")
900
- # new first batch with the frac_batches
901
- sampling_routine(size=frac_batches, save_batch=save_batch, output_jsonfile=output_jsonfile);
992
+ # create new first batch with the frac_batches
993
+ track_batches, dict_buffer = create_batch_params(sampling_routine, frac_batches, dict_buffer, save_batch, output_jsonfile, track_batches=track_batches)
902
994
  else:
903
995
  # check where to resume from
996
+ # identify the last batch and assign current batch number
997
+ # try-except is added to avoid the error when the file does not exist or if the file is empty or corrupted or does not have the required key.
904
998
  try:
905
999
  print(f"resuming from {output_jsonfile}")
906
- with open(output_jsonfile, "r", encoding="utf-8") as f:
907
- data = json.load(f)
908
- track_batches = (len(data["zs"]) - frac_batches) // batch_size + 1
1000
+ len_ = len(list(dict_buffer.values())[0])
1001
+ track_batches = (len_ - frac_batches) // batch_size + 1
909
1002
  except:
910
- track_batches = track_batches + 1
911
- print(f"Batch no. {track_batches}")
912
- # new first batch with the frac_batches
913
- sampling_routine(size=frac_batches, save_batch=save_batch, output_jsonfile=output_jsonfile);
1003
+ # create new first batch with the frac_batches
1004
+ track_batches, dict_buffer = create_batch_params(sampling_routine, frac_batches, dict_buffer, save_batch, output_jsonfile, track_batches=track_batches)
914
1005
 
915
- # ---------------------------------------------------#
1006
+ # loop over the remaining batches
916
1007
  min_, max_ = track_batches, num_batches
917
- for i in range(min_, max_):
918
- track_batches = track_batches + 1
919
- print(f"Batch no. {track_batches}")
920
- sampling_routine(size=batch_size, save_batch=save_batch, output_jsonfile=output_jsonfile, resume=True);
921
- # ---------------------------------------------------#
1008
+ # print(f"min_ = {min_}, max_ = {max_}")
1009
+ save_param = False
1010
+ if min_ == max_:
1011
+ print(f"{param_name} already sampled.")
1012
+ elif min_ > max_:
1013
+ len_ = len(list(dict_buffer.values())[0])
1014
+ print(f"existing {param_name} size is {len_} is more than the required size={size}. It will be trimmed.")
1015
+ dict_buffer = trim_dictionary(dict_buffer, size)
1016
+ save_param = True
1017
+ else:
1018
+ for i in range(min_, max_):
1019
+ _, dict_buffer = create_batch_params(sampling_routine, batch_size, dict_buffer, save_batch, output_jsonfile, track_batches=i, resume=True)
1020
+
1021
+ if save_batch:
1022
+ # if save_batch=True, then dict_buffer is only the last batch
1023
+ dict_buffer = get_param_from_json(output_jsonfile)
1024
+ else: # dont save in batches
1025
+ # this if condition is required if there is nothing to save
1026
+ save_param = True
1027
+
1028
+ if save_param:
1029
+ # store all params in json file
1030
+ print(f"saving all {param_name} in {output_jsonfile} ")
1031
+ append_json(output_jsonfile, dict_buffer, replace=True)
922
1032
 
923
- return None
1033
+ return dict_buffer
924
1034
 
925
- # def batch_handler(size, batch_size, sampling_routine, output_jsonfile, save_batch=True, resume=False):
926
- # """
927
- # Function to run the sampling in batches.
928
-
929
- # Parameters
930
- # ----------
931
- # size : `int`
932
- # number of samples.
933
- # batch_size : `int`
934
- # batch size.
935
- # sampling_routine : `function`
936
- # function to sample the parameters.
937
- # e.g. unlensed_sampling_routine() or lensed_sampling_routine()
938
- # output_jsonfile : `str`
939
- # name of the json file to store the parameters.
940
- # resume : `bool`
941
- # if True, it will resume the sampling from the last batch.
942
- # default resume = False.
943
- # """
1035
+ def create_batch_params(sampling_routine, frac_batches, dict_buffer, save_batch, output_jsonfile, track_batches, resume=False):
1036
+ """
1037
+ Helper function to batch_handler. It create batch parameters and store in a dictionary.
944
1038
 
945
- # num_batches = (size + batch_size - 1) // batch_size
946
- # print(f"Chosen batch size = {batch_size} with total size = {size}")
947
- # print(f"There will be {num_batches} batch(es)")
1039
+ Parameters
1040
+ ----------
1041
+ sampling_routine : `function`
1042
+ sampling function. It should have 'size' as input and return a dictionary.
1043
+ frac_batches : `int`
1044
+ batch size.
1045
+ dict_buffer : `dict`
1046
+ dictionary of parameters.
1047
+ save_batch : `bool`
1048
+ if True, save sampled parameters in each iteration.
1049
+ output_jsonfile : `str`
1050
+ json file name for storing the parameters.
1051
+ track_batches : `int`
1052
+ track the number of batches.
1053
+ resume : `bool`, optional
1054
+ if True, resume sampling from the last batch. Default is False.
948
1055
 
949
- # if not resume:
950
- # first_batch_size = size % batch_size or batch_size
951
- # else:
952
- # with open(output_jsonfile, "r", encoding="utf-8") as f:
953
- # data = json.load(f)
954
- # first_batch_size = len(data["zs"]) % batch_size or batch_size
1056
+ Returns
1057
+ ----------
1058
+ track_batches : `int`
1059
+ track the number of batches.
1060
+ """
955
1061
 
956
- # for batch_num in range(1, num_batches + 1):
957
- # print(f"Batch no. {batch_num}")
958
- # current_batch_size = first_batch_size if batch_num == 1 else batch_size
959
- # sampling_routine(size=current_batch_size, output_jsonfile=output_jsonfile, resume=resume, save_batch=save_batch)
960
- # resume = True # Resume for subsequent batches
1062
+ track_batches = track_batches + 1
1063
+ print(f"Batch no. {track_batches}")
1064
+ param = sampling_routine(size=frac_batches, save_batch=save_batch, output_jsonfile=output_jsonfile, resume=resume)
961
1065
 
962
- # return None
1066
+ # adding batches and hold it in the buffer attribute.
1067
+ if not save_batch:
1068
+ # in the new batch (new sampling run), dict_buffer will be None
1069
+ if dict_buffer is None:
1070
+ dict_buffer = param
1071
+ else:
1072
+ for key, value in param.items():
1073
+ dict_buffer[key] = np.concatenate((dict_buffer[key], value))
1074
+ else:
1075
+ # store all params in json file
1076
+ dict_buffer = append_json(file_name=output_jsonfile, new_dictionary=param, old_dictionary=dict_buffer, replace=not (resume))
963
1077
 
1078
+ return track_batches, dict_buffer
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ler
3
- Version: 0.4.0
4
- Summary: Gravitational waves Lensing Rates
3
+ Version: 0.4.2
4
+ Summary: LVK (LIGO-Virgo-KAGRA collaboration) Event (compact-binary mergers) Rate calculator and simulator
5
5
  Home-page: https://github.com/hemantaph/ler
6
6
  Author: Hemantakumar
7
7
  Author-email: hemantaphurailatpam@gmail.com
@@ -22,7 +22,7 @@ Requires-Dist: astropy >=5.1
22
22
  Requires-Dist: tqdm >=4.64.1
23
23
  Requires-Dist: pointpats >=2.3
24
24
  Requires-Dist: shapely >=2.0.1
25
- Requires-Dist: gwcosmo
25
+ Requires-Dist: gwcosmo ==2.1.0
26
26
 
27
27
  # LeR
28
28
  [![DOI](https://zenodo.org/badge/626733473.svg)](https://zenodo.org/badge/latestdoi/626733473) [![PyPI version](https://badge.fury.io/py/ler.svg)](https://badge.fury.io/py/ler) [![DOCS](https://readthedocs.org/projects/ler/badge/?version=latest)](https://ler.readthedocs.io/en/latest/)
@@ -1,25 +1,25 @@
1
- ler/__init__.py,sha256=Ip4IJByGm_ESuIsV_ICkM50WwUIjnk_G1DE-S1TjDPA,804
1
+ ler/__init__.py,sha256=KUIsDaNuHiBdV2gljkp4QbqUCwyMxx5gdZk3r7mVrK8,2174
2
2
  ler/gw_source_population/__init__.py,sha256=HG0ve5wTpBDN2fNMxHLnoOqTz-S0jXM_DsWEJ5PEHAw,126
3
- ler/gw_source_population/cbc_source_parameter_distribution.py,sha256=e0-Sqcx7WblWkqX1WQg6sCzrpMbkQ5tUcMu9sbBcp9Y,67455
4
- ler/gw_source_population/cbc_source_redshift_distribution.py,sha256=o2qAM_-9SeLxxfGwqXrdVWTCeEAaXVan_OPDd4jrplg,28559
3
+ ler/gw_source_population/cbc_source_parameter_distribution.py,sha256=A6tEvCMB7sqEQPlEBR6g7nqQMN8uPkiWiqT7LQLocfw,67485
4
+ ler/gw_source_population/cbc_source_redshift_distribution.py,sha256=F2MBdYyXxh0kr5No778wNc4esBs7MES3ej2FrLaZD0M,28566
5
5
  ler/gw_source_population/jit_functions.py,sha256=aQV9mv3IY5b3OLiPeXmoLWJ_TbFUS9M1OgnIyIY3eX4,8668
6
6
  ler/image_properties/__init__.py,sha256=XfJFlyZuOrKODT-z9WxjR9mI8eT399YJV-jzcJKTqGo,71
7
- ler/image_properties/image_properties.py,sha256=QmZ27y4CFR-DvzBxJewgaH3kEAXW6UDPxbyI7zwjdP4,25302
7
+ ler/image_properties/image_properties.py,sha256=CI-0_Jj4u5vf9KhZcbA9Nmc1naxEfbFYCMuRjUWi-oA,25477
8
8
  ler/image_properties/multiprocessing_routine.py,sha256=hYnQTM7PSIj3X-5YNDqMxH9UgeXHUPPdLG70h_r6sEY,18333
9
9
  ler/lens_galaxy_population/__init__.py,sha256=TXk1nwiYy0tvTpKs35aYK0-ZK63g2JLPyGG_yfxD0YU,126
10
10
  ler/lens_galaxy_population/jit_functions.py,sha256=tCTcr4FWyQXH7SQlHsUWeZBpv4jnG00DsBIljdWFs5M,8472
11
- ler/lens_galaxy_population/lens_galaxy_parameter_distribution.py,sha256=_o_kaIIqps9i-5RYQ8PIaofblAwTBSIokjYPEbo1Rh4,48278
11
+ ler/lens_galaxy_population/lens_galaxy_parameter_distribution.py,sha256=tRNy4DtSUPGfm5AdmQ4tbD-hHDwQ_HeOgq_mkDN5qe8,48377
12
12
  ler/lens_galaxy_population/mp.py,sha256=TPnFDEzojEqJzE3b0g39emZasHeeaeXN2q7JtMcgihk,6387
13
13
  ler/lens_galaxy_population/optical_depth.py,sha256=rZ_Inpw7ChpFdDLp3kJrCmA0PL3RxN6T_W_NTFhj_ko,42542
14
14
  ler/rates/__init__.py,sha256=N4li9NouSVjZl5HIhyuiKKRyrpUgQkBZaUeDgL1m4ic,43
15
- ler/rates/gwrates.py,sha256=akw6rKAkETr_ERmymVJx3APRXs0XqqFccZ-LIzAV4jM,58465
15
+ ler/rates/gwrates.py,sha256=Ek_wqAEFV1Uffu5TtiihdLJuSaEqw2KJhNpNFUWdjSw,58470
16
16
  ler/rates/ler copy.py,sha256=BlnGlRISUwiWUhUNwp32_lvh7tHdT-d1VDhFelwKO_c,101873
17
- ler/rates/ler.py,sha256=pBFTeYlranu5A81yENC3lUC5ebalUmr-g8neyWJUhgk,106112
17
+ ler/rates/ler.py,sha256=38z4EGRs33-B6AUaIVAbcWXYRBZW0l0EixyZ1ngrsH4,107884
18
18
  ler/utils/__init__.py,sha256=JWF9SKoqj1BThpV_ynfoyUeU06NQQ45DHCUGaaMSp_8,42
19
19
  ler/utils/plots.py,sha256=uq-usKRnEymtOSAPeHFOfMQW1XX76_WP2aBkT40RvLo,15664
20
- ler/utils/utils.py,sha256=HzRgpDjxXqaZ0jUjYU79IRzaFFK66rAhNAoqXdUHJJo,28976
21
- ler-0.4.0.dist-info/LICENSE,sha256=9LeXXC3WaHBpiUGhLVgOVnz0F12olPma1RX5zgpfp8Q,1081
22
- ler-0.4.0.dist-info/METADATA,sha256=T9JDeo2k-a0XmLpHG5U1gKy3Wlhake6ebe-Xy6EYGO0,6520
23
- ler-0.4.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
24
- ler-0.4.0.dist-info/top_level.txt,sha256=VWeWLF_gNMjzquGmqrLXqp2J5WegY86apTUimMTh68I,4
25
- ler-0.4.0.dist-info/RECORD,,
20
+ ler/utils/utils.py,sha256=wxcHWHNdjdU9fi4RftDez4PuIeu5m7OCL-lS-L6j5b8,32798
21
+ ler-0.4.2.dist-info/LICENSE,sha256=9LeXXC3WaHBpiUGhLVgOVnz0F12olPma1RX5zgpfp8Q,1081
22
+ ler-0.4.2.dist-info/METADATA,sha256=FL3i3pflIZDyQ4ZpXT945zepKway86BfCgh8DUogG8g,6592
23
+ ler-0.4.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
24
+ ler-0.4.2.dist-info/top_level.txt,sha256=VWeWLF_gNMjzquGmqrLXqp2J5WegY86apTUimMTh68I,4
25
+ ler-0.4.2.dist-info/RECORD,,
File without changes
File without changes