ler 0.3.9__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ler might be problematic; see the registry's advisory page for more details.

Files changed (31)
  1. {ler-0.3.9 → ler-0.4.1}/PKG-INFO +1 -1
  2. ler-0.4.1/ler/__init__.py +68 -0
  3. {ler-0.3.9 → ler-0.4.1}/ler/gw_source_population/cbc_source_parameter_distribution.py +2 -2
  4. {ler-0.3.9 → ler-0.4.1}/ler/image_properties/image_properties.py +2 -0
  5. {ler-0.3.9 → ler-0.4.1}/ler/lens_galaxy_population/lens_galaxy_parameter_distribution.py +2 -0
  6. {ler-0.3.9 → ler-0.4.1}/ler/rates/ler.py +74 -93
  7. {ler-0.3.9 → ler-0.4.1}/ler/utils/utils.py +141 -87
  8. {ler-0.3.9 → ler-0.4.1}/ler.egg-info/PKG-INFO +1 -1
  9. {ler-0.3.9 → ler-0.4.1}/ler.egg-info/SOURCES.txt +0 -1
  10. {ler-0.3.9 → ler-0.4.1}/ler.egg-info/requires.txt +3 -3
  11. {ler-0.3.9 → ler-0.4.1}/setup.py +4 -4
  12. ler-0.3.9/ler/__init__.py +0 -31
  13. ler-0.3.9/ler/rates/ler copy.py +0 -2097
  14. {ler-0.3.9 → ler-0.4.1}/LICENSE +0 -0
  15. {ler-0.3.9 → ler-0.4.1}/README.md +0 -0
  16. {ler-0.3.9 → ler-0.4.1}/ler/gw_source_population/__init__.py +0 -0
  17. {ler-0.3.9 → ler-0.4.1}/ler/gw_source_population/cbc_source_redshift_distribution.py +0 -0
  18. {ler-0.3.9 → ler-0.4.1}/ler/gw_source_population/jit_functions.py +0 -0
  19. {ler-0.3.9 → ler-0.4.1}/ler/image_properties/__init__.py +0 -0
  20. {ler-0.3.9 → ler-0.4.1}/ler/image_properties/multiprocessing_routine.py +0 -0
  21. {ler-0.3.9 → ler-0.4.1}/ler/lens_galaxy_population/__init__.py +0 -0
  22. {ler-0.3.9 → ler-0.4.1}/ler/lens_galaxy_population/jit_functions.py +0 -0
  23. {ler-0.3.9 → ler-0.4.1}/ler/lens_galaxy_population/mp.py +0 -0
  24. {ler-0.3.9 → ler-0.4.1}/ler/lens_galaxy_population/optical_depth.py +0 -0
  25. {ler-0.3.9 → ler-0.4.1}/ler/rates/__init__.py +0 -0
  26. {ler-0.3.9 → ler-0.4.1}/ler/rates/gwrates.py +0 -0
  27. {ler-0.3.9 → ler-0.4.1}/ler/utils/__init__.py +0 -0
  28. {ler-0.3.9 → ler-0.4.1}/ler/utils/plots.py +0 -0
  29. {ler-0.3.9 → ler-0.4.1}/ler.egg-info/dependency_links.txt +0 -0
  30. {ler-0.3.9 → ler-0.4.1}/ler.egg-info/top_level.txt +0 -0
  31. {ler-0.3.9 → ler-0.4.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ler
3
- Version: 0.3.9
3
+ Version: 0.4.1
4
4
  Summary: Gravitational waves Lensing Rates
5
5
  Home-page: https://github.com/hemantaph/ler
6
6
  Author: Hemantakumar
@@ -0,0 +1,68 @@
1
+ """
2
+ LeR
3
+ """
4
+
5
+ # mypackage/cli.py
6
+ import argparse
7
+ # import subprocess, os, sys, signal, warnings
8
+
9
+ ## import pycbc
10
+ import os
11
+ import multiprocessing as mp
12
+
13
def set_multiprocessing_start_method(method=None, verbose=True):
    """Set the global multiprocessing start method used by the package.

    By default 'fork' is chosen when ``os.name == 'posix'`` (Linux, macOS and
    other Unix-likes) and 'spawn' otherwise (e.g. Windows). The method is
    applied with ``force=True``, so it overrides any start method selected
    earlier in the process.

    Parameters
    ----------
    method : str or None, optional
        Start method to apply. If None (default), it is chosen from the
        platform as described above.
    verbose : bool, optional
        If True (default), print which start method is being set.

    Returns
    -------
    str or None
        The start method that was applied, or None if setting it failed.

    Raises
    ------
    ValueError
        If ``method`` is given explicitly but is not one of
        ``multiprocessing.get_all_start_methods()``.
    """
    if method is None:
        # NOTE(review): forcing 'fork' on macOS overrides Python's default
        # ('spawn' since 3.8), which the Python docs describe as potentially
        # unsafe there; kept for backward compatibility — confirm intent.
        method = 'fork' if os.name == 'posix' else 'spawn'
    elif method not in mp.get_all_start_methods():
        raise ValueError(
            f"multiprocessing start method {method!r} is not available; "
            f"choose one of {mp.get_all_start_methods()}"
        )

    if verbose:
        print(f"Setting multiprocessing start method to '{method}'")
    try:
        mp.set_start_method(method, force=True)
    except RuntimeError:
        # Best-effort, as before: keep whatever start method is already active.
        return None
    return method
28
+
29
# Pick the multiprocessing start method once, at import time, before the
# subpackages imported below (which use multiprocessing pools) are loaded.
# NOTE(review): 'fork' is preferred on POSIX presumably because worker pools
# rely on state inherited from the parent process, which 'spawn' does not
# preserve — TODO confirm.
set_multiprocessing_start_method()

__author__ = 'hemanta_ph <hemantaphurailatpam@gmail.com>'

# The version as used in the setup.py
__version__ = "0.4.1"

# Normalise this module's __file__ to an absolute path.
import os
__file__ = os.path.abspath(__file__)
59
+
60
+ from . import rates, gw_source_population, lens_galaxy_population, image_properties, utils
61
+
62
+ from .rates import ler, gwrates
63
+ from .gw_source_population import cbc_source_parameter_distribution, cbc_source_redshift_distribution
64
+ from .lens_galaxy_population import lens_galaxy_parameter_distribution, optical_depth
65
+ from .image_properties import image_properties
66
+ from .utils import utils, plots
67
+
68
+
@@ -58,8 +58,8 @@ class CBCSourceParameterDistribution(CBCSourceRedshiftDistribution):
58
58
  If True, spin parameters are completely ignore in the sampling.
59
59
  default: True
60
60
  spin_precession : `bool`
61
- If spin_zero=True and spin_precession=True, spin parameters are sampled for precessing binaries.
62
- if spin_zero=True and spin_precession=False, spin parameters are sampled for aligned/anti-aligned spin binaries.
61
+ If spin_zero=False and spin_precession=True, spin parameters are sampled for precessing binaries.
62
+ if spin_zero=False and spin_precession=False, spin parameters are sampled for aligned/anti-aligned spin binaries.
63
63
  default: False
64
64
  directory : `str`
65
65
  Directory to store the interpolator pickle files
@@ -6,6 +6,8 @@ The class inherits from the CompactBinaryPopulation class, which is used to samp
6
6
 
7
7
  import warnings
8
8
  warnings.filterwarnings("ignore")
9
+ import logging
10
+ logging.getLogger('numexpr.utils').setLevel(logging.ERROR)
9
11
  # for multiprocessing
10
12
  from multiprocessing import Pool
11
13
  from tqdm import tqdm
@@ -374,6 +374,8 @@ class LensGalaxyParameterDistribution(CBCSourceParameterDistribution, ImagePrope
374
374
  lens_model_list=["EPL_NUMBA", "SHEAR"],
375
375
  )
376
376
  input_params_image.update(params)
377
+
378
+ print("input_params_image", input_params_image)
377
379
  ImageProperties.__init__(
378
380
  self,
379
381
  npool=self.npool,
@@ -6,14 +6,22 @@ This module contains the main class for calculating the rates of detectable grav
6
6
  import os
7
7
  import warnings
8
8
  warnings.filterwarnings("ignore")
9
+ import logging
10
+ logging.getLogger('numexpr.utils').setLevel(logging.ERROR)
9
11
  import contextlib
10
12
  import numpy as np
11
13
  from scipy.stats import norm
12
14
  from astropy.cosmology import LambdaCDM
13
15
  from ..lens_galaxy_population import LensGalaxyParameterDistribution
14
- from ..utils import load_json, append_json, get_param_from_json, batch_handler, add_dict_values
16
+ from ..utils import load_json, append_json, get_param_from_json, batch_handler
15
17
 
16
18
 
19
+ # # multiprocessing guard code
20
+ # def main():
21
+ # obj = LeR()
22
+
23
+ # if __name__ == '__main__':
24
+
17
25
  class LeR(LensGalaxyParameterDistribution):
18
26
  """Class to sample of lensed and unlensed events and calculate it's rates. Please note that parameters of the simulated events are stored in json file but not as an attribute of the class. This saves RAM memory.
19
27
 
@@ -480,7 +488,7 @@ class LeR(LensGalaxyParameterDistribution):
480
488
  Function to print all the parameters.
481
489
  """
482
490
  # print all relevant functions and sampler priors
483
- print("\n LeR set up params:")
491
+ print("\n # LeR set up params:")
484
492
  print(f'npool = {self.npool},')
485
493
  print(f'z_min = {self.z_min},')
486
494
  print(f'z_max = {self.z_max},')
@@ -493,23 +501,23 @@ class LeR(LensGalaxyParameterDistribution):
493
501
  if self.pdet:
494
502
  print(f'pdet_finder = {self.pdet},')
495
503
  print(f'json_file_names = {self.json_file_names},')
496
- print(f'interpolator_directory = {self.interpolator_directory},')
497
- print(f'ler_directory = {self.ler_directory},')
504
+ print(f"interpolator_directory = '{self.interpolator_directory}',")
505
+ print(f"ler_directory = '{self.ler_directory}',")
498
506
 
499
- print("\n LeR also takes CBCSourceParameterDistribution class params as kwargs, as follows:")
507
+ print("\n # LeR also takes CBCSourceParameterDistribution class params as kwargs, as follows:")
500
508
  print(f"source_priors = {self.gw_param_sampler_dict['source_priors']},")
501
509
  print(f"source_priors_params = {self.gw_param_sampler_dict['source_priors_params']},")
502
510
  print(f"spin_zero = {self.gw_param_sampler_dict['spin_zero']},")
503
511
  print(f"spin_precession = {self.gw_param_sampler_dict['spin_precession']},")
504
512
  print(f"create_new_interpolator = {self.gw_param_sampler_dict['create_new_interpolator']},")
505
513
 
506
- print("\n LeR also takes LensGalaxyParameterDistribution class params as kwargs, as follows:")
514
+ print("\n # LeR also takes LensGalaxyParameterDistribution class params as kwargs, as follows:")
507
515
  print(f"lens_type = '{self.gw_param_sampler_dict['lens_type']}',")
508
516
  print(f"lens_functions = {self.gw_param_sampler_dict['lens_functions']},")
509
517
  print(f"lens_priors = {self.gw_param_sampler_dict['lens_priors']},")
510
518
  print(f"lens_priors_params = {self.gw_param_sampler_dict['lens_priors_params']},")
511
519
 
512
- print("\n LeR also takes ImageProperties class params as kwargs, as follows:")
520
+ print("\n # LeR also takes ImageProperties class params as kwargs, as follows:")
513
521
  print(f"n_min_images = {self.n_min_images},")
514
522
  print(f"n_max_images = {self.n_max_images},")
515
523
  print(f"geocent_time_min = {self.geocent_time_min},")
@@ -517,7 +525,7 @@ class LeR(LensGalaxyParameterDistribution):
517
525
  print(f"lens_model_list = {self.lens_model_list},")
518
526
 
519
527
  if self.gwsnr:
520
- print("\n LeR also takes gwsnr.GWSNR params as kwargs, as follows:")
528
+ print("\n # LeR also takes gwsnr.GWSNR params as kwargs, as follows:")
521
529
  print(f"mtot_min = {self.snr_calculator_dict['mtot_min']},")
522
530
  print(f"mtot_max = {self.snr_calculator_dict['mtot_max']},")
523
531
  print(f"ratio_min = {self.snr_calculator_dict['ratio_min']},")
@@ -531,7 +539,6 @@ class LeR(LensGalaxyParameterDistribution):
531
539
  print(f"psds = {self.snr_calculator_dict['psds']},")
532
540
  print(f"ifos = {self.snr_calculator_dict['ifos']},")
533
541
  print(f"interpolator_dir = '{self.snr_calculator_dict['interpolator_dir']}',")
534
- print(f"create_new_interpolator = {self.snr_calculator_dict['create_new_interpolator']},")
535
542
  print(f"gwsnr_verbose = {self.snr_calculator_dict['gwsnr_verbose']},")
536
543
  print(f"multiprocessing_verbose = {self.snr_calculator_dict['multiprocessing_verbose']},")
537
544
  print(f"mtot_cut = {self.snr_calculator_dict['mtot_cut']},")
@@ -712,6 +719,7 @@ class LeR(LensGalaxyParameterDistribution):
712
719
  # initialization of LensGalaxyParameterDistribution class
713
720
  # it also initializes the CBCSourceParameterDistribution and ImageProperties classes
714
721
  input_params = dict(
722
+ # LensGalaxyParameterDistribution class params
715
723
  z_min=self.z_min,
716
724
  z_max=self.z_max,
717
725
  cosmology=self.cosmo,
@@ -720,8 +728,13 @@ class LeR(LensGalaxyParameterDistribution):
720
728
  lens_functions= None,
721
729
  lens_priors=None,
722
730
  lens_priors_params=None,
731
+ # ImageProperties class params
732
+ n_min_images=2,
733
+ n_max_images=4,
723
734
  geocent_time_min=1126259462.4,
724
735
  geocent_time_max=1126259462.4+365*24*3600*20,
736
+ lens_model_list=['EPL_NUMBA', 'SHEAR'],
737
+ # CBCSourceParameterDistribution class params
725
738
  source_priors=None,
726
739
  source_priors_params=None,
727
740
  spin_zero=True,
@@ -745,8 +758,11 @@ class LeR(LensGalaxyParameterDistribution):
745
758
  lens_functions=input_params["lens_functions"],
746
759
  lens_priors=input_params["lens_priors"],
747
760
  lens_priors_params=input_params["lens_priors_params"],
761
+ n_min_images=input_params["n_min_images"],
762
+ n_max_images=input_params["n_max_images"],
748
763
  geocent_time_min=input_params["geocent_time_min"],
749
764
  geocent_time_max=input_params["geocent_time_max"],
765
+ lens_model_list=input_params["lens_model_list"],
750
766
  source_priors=input_params["source_priors"],
751
767
  source_priors_params=input_params["source_priors_params"],
752
768
  spin_zero=input_params["spin_zero"],
@@ -792,6 +808,10 @@ class LeR(LensGalaxyParameterDistribution):
792
808
  gwsnr_verbose=False,
793
809
  multiprocessing_verbose=True,
794
810
  mtot_cut=True,
811
+ pdet=False,
812
+ snr_th=8.0,
813
+ snr_th_net=8.0,
814
+ ann_path_dict=None,
795
815
  )
796
816
  # if self.event_type == "BNS":
797
817
  # input_params["mtot_max"]= 18.
@@ -819,6 +839,10 @@ class LeR(LensGalaxyParameterDistribution):
819
839
  gwsnr_verbose=input_params["gwsnr_verbose"],
820
840
  multiprocessing_verbose=input_params["multiprocessing_verbose"],
821
841
  mtot_cut=input_params["mtot_cut"],
842
+ pdet=input_params["pdet"],
843
+ snr_th=input_params["snr_th"],
844
+ snr_th_net=input_params["snr_th_net"],
845
+ ann_path_dict=input_params["ann_path_dict"],
822
846
  )
823
847
 
824
848
  self.snr = gwsnr.snr
@@ -909,36 +933,16 @@ class LeR(LensGalaxyParameterDistribution):
909
933
  output_path = os.path.join(self.ler_directory, output_jsonfile)
910
934
  print(f"unlensed params will be store in {output_path}")
911
935
 
912
- # sampling in batches
913
- if resume and os.path.exists(output_path):
914
- # get sample from json file
915
- self.dict_buffer = get_param_from_json(output_path)
916
- else:
917
- self.dict_buffer = None
918
-
919
- batch_handler(
936
+ unlensed_param = batch_handler(
920
937
  size=size,
921
938
  batch_size=self.batch_size,
922
939
  sampling_routine=self.unlensed_sampling_routine,
923
940
  output_jsonfile=output_path,
924
941
  save_batch=save_batch,
925
942
  resume=resume,
943
+ param_name="unlensed parameters",
926
944
  )
927
945
 
928
- if save_batch:
929
- unlensed_param = get_param_from_json(output_path)
930
- else:
931
- # this if condition is required if there is nothing to save
932
- if self.dict_buffer:
933
- unlensed_param = self.dict_buffer.copy()
934
- # store all params in json file
935
- print(f"saving all unlensed_params in {output_path} ")
936
- append_json(output_path, unlensed_param, replace=True)
937
- else:
938
- print("unlensed_params already sampled.")
939
- unlensed_param = get_param_from_json(output_path)
940
- self.dict_buffer = None # save memory
941
-
942
946
  return unlensed_param
943
947
 
944
948
  def unlensed_sampling_routine(self, size, output_jsonfile, resume=False, save_batch=True):
@@ -978,17 +982,6 @@ class LeR(LensGalaxyParameterDistribution):
978
982
  pdet = self.pdet(gw_param_dict=unlensed_param)
979
983
  unlensed_param.update(pdet)
980
984
 
981
- # adding batches
982
- if not save_batch:
983
- if self.dict_buffer is None:
984
- self.dict_buffer = unlensed_param
985
- else:
986
- for key, value in unlensed_param.items():
987
- self.dict_buffer[key] = np.concatenate((self.dict_buffer[key], value))
988
- else:
989
- # store all params in json file
990
- self.dict_buffer = append_json(file_name=output_jsonfile, new_dictionary=unlensed_param, old_dictionary=self.dict_buffer, replace=not (resume))
991
-
992
985
  return unlensed_param
993
986
 
994
987
  def unlensed_rate(
@@ -1323,36 +1316,16 @@ class LeR(LensGalaxyParameterDistribution):
1323
1316
  output_path = os.path.join(self.ler_directory, output_jsonfile)
1324
1317
  print(f"lensed params will be store in {output_path}")
1325
1318
 
1326
- # sampling in batches
1327
- if resume and os.path.exists(output_path):
1328
- # get sample from json file
1329
- self.dict_buffer = get_param_from_json(output_path)
1330
- else:
1331
- self.dict_buffer = None
1332
-
1333
- batch_handler(
1319
+ lensed_param = batch_handler(
1334
1320
  size=size,
1335
1321
  batch_size=self.batch_size,
1336
1322
  sampling_routine=self.lensed_sampling_routine,
1337
1323
  output_jsonfile=output_path,
1338
1324
  save_batch=save_batch,
1339
1325
  resume=resume,
1326
+ param_name="lensed parameters",
1340
1327
  )
1341
1328
 
1342
- if save_batch:
1343
- lensed_param = get_param_from_json(output_path)
1344
- else:
1345
- # this if condition is required if there is nothing to save
1346
- if self.dict_buffer:
1347
- lensed_param = self.dict_buffer.copy()
1348
- # store all params in json file
1349
- print(f"saving all lensed_params in {output_path} ")
1350
- append_json(output_path, lensed_param, replace=True)
1351
- else:
1352
- print("lensed_params already sampled.")
1353
- lensed_param = get_param_from_json(output_path)
1354
- self.dict_buffer = None # save memory
1355
-
1356
1329
  return lensed_param
1357
1330
 
1358
1331
  def lensed_sampling_routine(self, size, output_jsonfile, save_batch=True, resume=False):
@@ -1382,6 +1355,8 @@ class LeR(LensGalaxyParameterDistribution):
1382
1355
  print("sampling lensed params...")
1383
1356
  lensed_param = {}
1384
1357
 
1358
+ # Some of the sample lensed events may not satisfy the strong lensing condition
1359
+ # In that case, we will resample those events and replace the values with the corresponding indices
1385
1360
  while True:
1386
1361
  # get lensed params
1387
1362
  lensed_param_ = self.sample_lens_parameters(size=size)
@@ -1425,17 +1400,6 @@ class LeR(LensGalaxyParameterDistribution):
1425
1400
  )
1426
1401
  lensed_param.update(pdet)
1427
1402
 
1428
- # adding batches
1429
- if not save_batch:
1430
- if self.dict_buffer is None:
1431
- self.dict_buffer = lensed_param
1432
- else:
1433
- for key, value in lensed_param.items():
1434
- self.dict_buffer[key] = np.concatenate((self.dict_buffer[key], value))
1435
- else:
1436
- # store all params in json file
1437
- self.dict_buffer = append_json(file_name=output_jsonfile, new_dictionary=lensed_param, old_dictionary=self.dict_buffer, replace=not (resume))
1438
-
1439
1403
  return lensed_param
1440
1404
 
1441
1405
  def lensed_rate(
@@ -1447,6 +1411,8 @@ class LeR(LensGalaxyParameterDistribution):
1447
1411
  output_jsonfile=None,
1448
1412
  nan_to_num=True,
1449
1413
  detectability_condition="step_function",
1414
+ combine_image_snr=False,
1415
+ snr_cut_for_combine_image_snr=4.0,
1450
1416
  snr_recalculation=False,
1451
1417
  snr_threshold_recalculation=[[4,4], [20,20]],
1452
1418
  ):
@@ -1519,7 +1485,8 @@ class LeR(LensGalaxyParameterDistribution):
1519
1485
  if snr_recalculation:
1520
1486
  lensed_param = self._recalculate_snr_lensed(lensed_param, snr_threshold_recalculation, num_img, total_events)
1521
1487
 
1522
- snr_hit = self._find_detectable_index_lensed(lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition)
1488
+ # find index of detectable events
1489
+ snr_hit = self._find_detectable_index_lensed(lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition, combine_image_snr=combine_image_snr, snr_cut_for_combine_image_snr=snr_cut_for_combine_image_snr)
1523
1490
 
1524
1491
  # montecarlo integration
1525
1492
  total_rate = self.rate_function(np.sum(snr_hit), total_events, param_type="lensed")
@@ -1619,7 +1586,7 @@ class LeR(LensGalaxyParameterDistribution):
1619
1586
 
1620
1587
  return lensed_param
1621
1588
 
1622
- def _find_detectable_index_lensed(self, lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition):
1589
+ def _find_detectable_index_lensed(self, lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition, combine_image_snr=False, snr_cut_for_combine_image_snr=4.0):
1623
1590
  """
1624
1591
  Helper function to find the index of detectable events based on SNR or p_det.
1625
1592
 
@@ -1655,18 +1622,24 @@ class LeR(LensGalaxyParameterDistribution):
1655
1622
  snr_param = -np.sort(-snr_param, axis=1) # sort snr in descending order
1656
1623
  snr_hit = np.full(len(snr_param), True) # boolean array to store the result of the threshold condition
1657
1624
 
1658
- # for each row: choose a threshold and check if the number of images above threshold. Sum over the images. If sum is greater than num_img, then snr_hit = True
1659
- # algorithm:
1660
- # i) consider snr_threshold=[8,6] and num_img=[2,1] and first row of snr_param[0]=[12,8,6,1]. Note that the snr_param is sorted in descending order.
1661
- # ii) for loop runs wrt snr_threshold. idx_max = idx_max + num_img[i]
1662
- # iii) First iteration: snr_threshold=8 and num_img=2. In snr_param, column index 0 and 1 (i.e. 0:num_img[0]) are considered. The sum of snr_param[0, 0:2] > 8 is checked. If True, then snr_hit = True.
1663
- # v) Second iteration: snr_threshold=6 and num_img=1. In snr_param, column index 2 (i.e. num_img[0]:num_img[1]) is considered. The sum of snr_param[0, 0:1] > 6 is checked. If True, then snr_hit = True.
1664
- j = 0
1665
- idx_max = 0
1666
- for i, snr_th in enumerate(snr_threshold):
1667
- idx_max = idx_max + num_img[i]
1668
- snr_hit = snr_hit & (np.sum((snr_param[:,j:idx_max] > snr_th), axis=1) >= num_img[i])
1669
- j = idx_max
1625
+ if not combine_image_snr:
1626
+ # for each row: choose a threshold and check if the number of images above threshold. Sum over the images. If sum is greater than num_img, then snr_hit = True
1627
+ # algorithm:
1628
+ # i) consider snr_threshold=[8,6] and num_img=[2,1] and first row of snr_param[0]=[12,8,6,1]. Note that the snr_param is sorted in descending order.
1629
+ # ii) for loop runs wrt snr_threshold. idx_max = idx_max + num_img[i]
1630
+ # iii) First iteration: snr_threshold=8 and num_img=2. In snr_param, column index 0 and 1 (i.e. 0:num_img[0]) are considered. The sum of snr_param[0, 0:2] > 8 is checked. If True, then snr_hit = True.
1631
+ # v) Second iteration: snr_threshold=6 and num_img=1. In snr_param, column index 2 (i.e. num_img[0]:num_img[1]) is considered. The sum of snr_param[0, 0:1] > 6 is checked. If True, then snr_hit = True.
1632
+ j = 0
1633
+ idx_max = 0
1634
+ for i, snr_th in enumerate(snr_threshold):
1635
+ idx_max = idx_max + num_img[i]
1636
+ snr_hit = snr_hit & (np.sum((snr_param[:,j:idx_max] > snr_th), axis=1) >= num_img[i])
1637
+ j = idx_max
1638
+ else:
1639
+ # sqrt of the the sum of the squares of the snr of the images
1640
+ snr_param[snr_param<snr_cut_for_combine_image_snr] = 0.0 # images with snr below snr_cut_for_combine_image_snr are not considered
1641
+ snr_param = np.sqrt(np.sum(snr_param[:,:np.sum(num_img)]**2, axis=1))
1642
+ snr_hit = snr_param >= snr_threshold[0]
1670
1643
 
1671
1644
  elif detectability_condition == "pdet":
1672
1645
  if "pdet_net" not in lensed_param:
@@ -1705,6 +1678,8 @@ class LeR(LensGalaxyParameterDistribution):
1705
1678
  lensed_param=None,
1706
1679
  snr_threshold_lensed=[8.0,8.0],
1707
1680
  num_img=[1,1],
1681
+ combine_image_snr=False,
1682
+ snr_cut_for_combine_image_snr=4.0,
1708
1683
  output_jsonfile_lensed=None,
1709
1684
  nan_to_num=True,
1710
1685
  detectability_condition="step_function",
@@ -1779,6 +1754,8 @@ class LeR(LensGalaxyParameterDistribution):
1779
1754
  output_jsonfile=output_jsonfile_lensed,
1780
1755
  nan_to_num=nan_to_num,
1781
1756
  detectability_condition=detectability_condition,
1757
+ combine_image_snr=combine_image_snr,
1758
+ snr_cut_for_combine_image_snr=snr_cut_for_combine_image_snr,
1782
1759
  )
1783
1760
  # calculate rate ratio
1784
1761
  rate_ratio = self.rate_ratio()
@@ -1893,7 +1870,7 @@ class LeR(LensGalaxyParameterDistribution):
1893
1870
  """
1894
1871
 
1895
1872
  # initial setup
1896
- n, events_total, output_path, meta_data_path, buffer_file = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
1873
+ n, events_total, output_path, meta_data_path, buffer_file, batch_size = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
1897
1874
 
1898
1875
  # loop until n samples are collected
1899
1876
  while n < size:
@@ -1953,6 +1930,8 @@ class LeR(LensGalaxyParameterDistribution):
1953
1930
  snr_threshold=[8.0,8.0],
1954
1931
  pdet_threshold=0.5,
1955
1932
  num_img=[1,1],
1933
+ combine_image_snr=False,
1934
+ snr_cut_for_combine_image_snr=4.0,
1956
1935
  resume=False,
1957
1936
  detectability_condition="step_function",
1958
1937
  output_jsonfile="n_lensed_params_detectable.json",
@@ -2021,7 +2000,7 @@ class LeR(LensGalaxyParameterDistribution):
2021
2000
  """
2022
2001
 
2023
2002
  # initial setup
2024
- n, events_total, output_path, meta_data_path, buffer_file = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
2003
+ n, events_total, output_path, meta_data_path, buffer_file, batch_size = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
2025
2004
 
2026
2005
  # re-analyse the provided snr_threshold and num_img
2027
2006
  snr_threshold, num_img = self._check_snr_threshold_lensed(snr_threshold, num_img)
@@ -2041,7 +2020,7 @@ class LeR(LensGalaxyParameterDistribution):
2041
2020
  if snr_recalculation:
2042
2021
  lensed_param = self._recalculate_snr_lensed(lensed_param, snr_threshold_recalculation, num_img, total_events_in_this_iteration)
2043
2022
 
2044
- snr_hit = self._find_detectable_index_lensed(lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition)
2023
+ snr_hit = self._find_detectable_index_lensed(lensed_param, snr_threshold, pdet_threshold, num_img, detectability_condition, combine_image_snr=combine_image_snr, snr_cut_for_combine_image_snr=snr_cut_for_combine_image_snr)
2045
2024
 
2046
2025
  # store all params in json file
2047
2026
  self._save_detectable_params(output_jsonfile, lensed_param, snr_hit, key_file_name="n_lensed_detectable_events", nan_to_num=nan_to_num, verbose=False, replace_jsonfile=False)
@@ -2118,6 +2097,8 @@ class LeR(LensGalaxyParameterDistribution):
2118
2097
  if not resume:
2119
2098
  n = 0 # iterator
2120
2099
  events_total = 0
2100
+ # the following file will be removed if it exists
2101
+ print(f"removing {output_path} and {meta_data_path} if they exist")
2121
2102
  if os.path.exists(output_path):
2122
2103
  os.remove(output_path)
2123
2104
  if os.path.exists(meta_data_path):
@@ -2136,7 +2117,7 @@ class LeR(LensGalaxyParameterDistribution):
2136
2117
  buffer_file = "params_buffer.json"
2137
2118
  print("collected number of detectable events = ", n)
2138
2119
 
2139
- return n, events_total, output_path, meta_data_path, buffer_file
2120
+ return n, events_total, output_path, meta_data_path, buffer_file, batch_size
2140
2121
 
2141
2122
  def _trim_results_to_size(self, size, output_path, meta_data_path, param_type="unlensed"):
2142
2123
  """