ler-0.4.1-py3-none-any.whl → ler-0.4.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ler might be problematic; see the package's registry page for more details.

Files changed (35):
  1. ler/__init__.py +26 -26
  2. ler/gw_source_population/__init__.py +1 -0
  3. ler/gw_source_population/cbc_source_parameter_distribution.py +1076 -818
  4. ler/gw_source_population/cbc_source_redshift_distribution.py +619 -295
  5. ler/gw_source_population/jit_functions.py +484 -9
  6. ler/gw_source_population/sfr_with_time_delay.py +107 -0
  7. ler/image_properties/image_properties.py +44 -13
  8. ler/image_properties/multiprocessing_routine.py +5 -209
  9. ler/lens_galaxy_population/__init__.py +2 -0
  10. ler/lens_galaxy_population/epl_shear_cross_section.py +0 -0
  11. ler/lens_galaxy_population/jit_functions.py +101 -9
  12. ler/lens_galaxy_population/lens_galaxy_parameter_distribution.py +817 -885
  13. ler/lens_galaxy_population/lens_param_data/density_profile_slope_sl.txt +5000 -0
  14. ler/lens_galaxy_population/lens_param_data/external_shear_sl.txt +2 -0
  15. ler/lens_galaxy_population/lens_param_data/number_density_zl_zs.txt +48 -0
  16. ler/lens_galaxy_population/lens_param_data/optical_depth_epl_shear_vd_ewoud.txt +48 -0
  17. ler/lens_galaxy_population/mp copy.py +554 -0
  18. ler/lens_galaxy_population/mp.py +736 -138
  19. ler/lens_galaxy_population/optical_depth.py +2248 -616
  20. ler/rates/__init__.py +1 -2
  21. ler/rates/gwrates.py +129 -75
  22. ler/rates/ler.py +257 -116
  23. ler/utils/__init__.py +2 -0
  24. ler/utils/function_interpolation.py +322 -0
  25. ler/utils/gwsnr_training_data_generator.py +233 -0
  26. ler/utils/plots.py +1 -1
  27. ler/utils/test.py +1078 -0
  28. ler/utils/utils.py +553 -125
  29. {ler-0.4.1.dist-info → ler-0.4.3.dist-info}/METADATA +22 -9
  30. ler-0.4.3.dist-info/RECORD +34 -0
  31. {ler-0.4.1.dist-info → ler-0.4.3.dist-info}/WHEEL +1 -1
  32. ler/rates/ler copy.py +0 -2097
  33. ler-0.4.1.dist-info/RECORD +0 -25
  34. {ler-0.4.1.dist-info → ler-0.4.3.dist-info/licenses}/LICENSE +0 -0
  35. {ler-0.4.1.dist-info → ler-0.4.3.dist-info}/top_level.txt +0 -0
ler/rates/__init__.py CHANGED
@@ -1,3 +1,2 @@
1
1
  from .ler import *
2
- from .gwrates import *
3
-
2
+ from .gwrates import *
ler/rates/gwrates.py CHANGED
@@ -6,6 +6,8 @@ This module contains the main class for calculating the rates of detectable grav
6
6
  import os
7
7
  import warnings
8
8
  warnings.filterwarnings("ignore")
9
+ import logging
10
+ logging.getLogger('numexpr.utils').setLevel(logging.ERROR)
9
11
  import contextlib
10
12
  import numpy as np
11
13
  from scipy.stats import norm
@@ -51,7 +53,7 @@ class GWRATES(CBCSourceParameterDistribution):
51
53
  def snr_finder(gw_param_dict):
52
54
  ...
53
55
  return optimal_snr_dict
54
- where optimal_snr_dict.keys = ['optimal_snr_net']. Refer to `gwsnr` package's GWSNR.snr attribute for more details.
56
+ where optimal_snr_dict.keys = ['snr_net']. Refer to `gwsnr` package's GWSNR.snr attribute for more details.
55
57
  pdet_finder : `function`
56
58
  default pdet_finder = None.
57
59
  The rate calculation uses either the pdet_finder or the snr_finder to calculate the detectable events. The custom pdet finder function should follow the following signature:
@@ -84,7 +86,7 @@ class GWRATES(CBCSourceParameterDistribution):
84
86
 
85
87
  Instance Attributes
86
88
  ----------
87
- LeR class has the following attributes, \n
89
+ LeR class has the following attributes:\n
88
90
  +-------------------------------------+----------------------------------+
89
91
  | Atrributes | Type |
90
92
  +=====================================+==================================+
@@ -121,14 +123,14 @@ class GWRATES(CBCSourceParameterDistribution):
121
123
 
122
124
  Instance Methods
123
125
  ----------
124
- LeR class has the following methods, \n
126
+ LeR class has the following methods:\n
125
127
  +-------------------------------------+----------------------------------+
126
128
  | Methods | Description |
127
129
  +=====================================+==================================+
128
130
  |:meth:`~class_initialization` | Function to initialize the |
129
131
  | | parent classes |
130
132
  +-------------------------------------+----------------------------------+
131
- |:meth:`~gwsnr_intialization` | Function to initialize the |
133
+ |:meth:`~gwsnr_initialization` | Function to initialize the |
132
134
  | | gwsnr class |
133
135
  +-------------------------------------+----------------------------------+
134
136
  |:meth:`~snr` | Function to get the snr with the |
@@ -282,11 +284,13 @@ class GWRATES(CBCSourceParameterDistribution):
282
284
  list_of_detectors=None,
283
285
  json_file_names=None,
284
286
  interpolator_directory="./interpolator_pickle",
287
+ create_new_interpolator=False,
285
288
  ler_directory="./ler_data",
286
289
  verbose=True,
287
290
  **kwargs,
288
291
  ):
289
292
 
293
+ print("\nInitializing GWRATES class...\n")
290
294
  self.npool = npool
291
295
  self.z_min = z_min
292
296
  self.z_max = z_max
@@ -298,6 +302,7 @@ class GWRATES(CBCSourceParameterDistribution):
298
302
  if json_file_names:
299
303
  self.json_file_names.update(json_file_names)
300
304
  self.interpolator_directory = interpolator_directory
305
+ kwargs["create_new_interpolator"] = create_new_interpolator
301
306
  self.ler_directory = ler_directory
302
307
  # create directory if not exists
303
308
  if not os.path.exists(ler_directory):
@@ -308,7 +313,7 @@ class GWRATES(CBCSourceParameterDistribution):
308
313
  self.class_initialization(params=kwargs)
309
314
  # initialization self.snr and self.pdet from GWSNR class
310
315
  if not snr_finder and not pdet_finder:
311
- self.gwsnr_intialization(params=kwargs)
316
+ self.gwsnr_initialization(params=kwargs)
312
317
  self.gwsnr = True
313
318
  self.pdet = pdet_finder
314
319
  else:
@@ -320,20 +325,21 @@ class GWRATES(CBCSourceParameterDistribution):
320
325
  # store all the gwrates input parameters
321
326
  self.store_gwrates_params(output_jsonfile=self.json_file_names["gwrates_params"])
322
327
 
328
+ # if not verbose, prevent anything from printing
323
329
  if verbose:
324
330
  initialization()
325
- self.print_all_params()
331
+ self.print_all_params_ler()
326
332
  else:
327
333
  with contextlib.redirect_stdout(None):
328
334
  initialization()
329
335
 
330
- def print_all_params(self):
336
+ def print_all_params_ler(self):
331
337
  """
332
338
  Function to print all the parameters.
333
339
  """
334
340
 
335
341
  # print all relevant functions and sampler priors
336
- print("\n GWRATES set up params:")
342
+ print("\n # GWRATES set up params:")
337
343
  print(f'npool = {self.npool},')
338
344
  print(f'z_min = {self.z_min},')
339
345
  print(f'z_max = {self.z_max},')
@@ -346,10 +352,10 @@ class GWRATES(CBCSourceParameterDistribution):
346
352
  if self.pdet:
347
353
  print(f'pdet_finder = {self.pdet},')
348
354
  print(f'json_file_names = {self.json_file_names},')
349
- print(f'interpolator_directory = {self.interpolator_directory},')
350
- print(f'ler_directory = {self.ler_directory},')
355
+ print(f"interpolator_directory = '{self.interpolator_directory}',")
356
+ print(f"ler_directory = '{self.ler_directory}',")
351
357
 
352
- print("\n GWRATES also takes CBCSourceParameterDistribution params as kwargs, as follows:")
358
+ print("\n # GWRATES also takes CBCSourceParameterDistribution params as kwargs, as follows:")
353
359
  print(f"source_priors = {self.gw_param_sampler_dict['source_priors']},")
354
360
  print(f"source_priors_params = {self.gw_param_sampler_dict['source_priors_params']},")
355
361
  print(f"spin_zero = {self.gw_param_sampler_dict['spin_zero']},")
@@ -357,7 +363,7 @@ class GWRATES(CBCSourceParameterDistribution):
357
363
  print(f"create_new_interpolator = {self.gw_param_sampler_dict['create_new_interpolator']},")
358
364
 
359
365
  if self.gwsnr:
360
- print("\n LeR also takes gwsnr.GWSNR params as kwargs, as follows:")
366
+ print("\n # LeR also takes gwsnr.GWSNR params as kwargs, as follows:")
361
367
  print(f"mtot_min = {self.snr_calculator_dict['mtot_min']},")
362
368
  print(f"mtot_max = {self.snr_calculator_dict['mtot_max']},")
363
369
  print(f"ratio_min = {self.snr_calculator_dict['ratio_min']},")
@@ -506,7 +512,7 @@ class GWRATES(CBCSourceParameterDistribution):
506
512
  dictionary of parameters to initialize the parent classes
507
513
  """
508
514
 
509
- # initialization of CompactBinaryPopulation class
515
+ # initialization of CBCSourceParameterDistribution class
510
516
  # it also initializes the CBCSourceRedshiftDistribution class
511
517
  input_params = dict(
512
518
  z_min=self.z_min,
@@ -530,10 +536,10 @@ class GWRATES(CBCSourceParameterDistribution):
530
536
  self,
531
537
  z_min=input_params["z_min"],
532
538
  z_max=input_params["z_max"],
539
+ cosmology=input_params["cosmology"],
533
540
  event_type=input_params["event_type"],
534
541
  source_priors=input_params["source_priors"],
535
542
  source_priors_params=input_params["source_priors_params"],
536
- cosmology=input_params["cosmology"],
537
543
  spin_zero=input_params["spin_zero"],
538
544
  spin_precession=input_params["spin_precession"],
539
545
  directory=input_params["directory"],
@@ -543,7 +549,7 @@ class GWRATES(CBCSourceParameterDistribution):
543
549
  self.gw_param_sampler_dict["source_priors"]=self.gw_param_samplers.copy()
544
550
  self.gw_param_sampler_dict["source_priors_params"]=self.gw_param_samplers_params.copy()
545
551
 
546
- def gwsnr_intialization(self, params=None):
552
+ def gwsnr_initialization(self, params=None):
547
553
  """
548
554
  Function to initialize the GWSNR class from the `gwsnr` package.
549
555
 
@@ -557,23 +563,35 @@ class GWRATES(CBCSourceParameterDistribution):
557
563
  # initialization of GWSNR class
558
564
  input_params = dict(
559
565
  npool=self.npool,
560
- mtot_min=2.0,
561
- mtot_max=200,
566
+ mtot_min=2*self.gw_param_samplers_params['source_frame_masses']['mminbh'],
567
+ mtot_max=2*self.gw_param_samplers_params['source_frame_masses']['mmaxbh']+10.0,
562
568
  ratio_min=0.1,
563
569
  ratio_max=1.0,
564
- mtot_resolution=500,
565
- ratio_resolution=50,
570
+ spin_max=0.99,
571
+ mtot_resolution=200,
572
+ ratio_resolution=20,
573
+ spin_resolution=10,
566
574
  sampling_frequency=2048.0,
567
575
  waveform_approximant="IMRPhenomD",
576
+ frequency_domain_source_model='lal_binary_black_hole',
568
577
  minimum_frequency=20.0,
578
+ duration_max=None,
579
+ duration_min=None,
569
580
  snr_type="interpolation",
570
581
  psds=None,
571
582
  ifos=None,
572
- interpolator_dir=self.interpolator_directory,
583
+ interpolator_dir="./interpolator_pickle",
573
584
  create_new_interpolator=False,
574
- gwsnr_verbose=False,
585
+ gwsnr_verbose=True,
575
586
  multiprocessing_verbose=True,
576
- mtot_cut=True,
587
+ mtot_cut=False,
588
+ pdet=False,
589
+ snr_th=8.0,
590
+ snr_th_net=8.0,
591
+ ann_path_dict=None,
592
+ snr_recalculation=False,
593
+ snr_recalculation_range=[4,12],
594
+ snr_recalculation_waveform_approximant="IMRPhenomXPHM",
577
595
  )
578
596
  # if self.event_type == "BNS":
579
597
  # input_params["mtot_max"]= 18.
@@ -582,25 +600,49 @@ class GWRATES(CBCSourceParameterDistribution):
582
600
  if key in input_params:
583
601
  input_params[key] = value
584
602
  self.snr_calculator_dict = input_params
603
+
604
+ # dealing with create_new_interpolator param
605
+ if isinstance(input_params["create_new_interpolator"], bool):
606
+ pass
607
+ elif isinstance(input_params["create_new_interpolator"], dict):
608
+ # check input_params["gwsnr"] exists
609
+ if "gwsnr" in input_params["create_new_interpolator"]:
610
+ if isinstance(input_params["create_new_interpolator"]["gwsnr"], bool):
611
+ input_params["create_new_interpolator"] = input_params["create_new_interpolator"]["gwsnr"]
612
+ else:
613
+ raise ValueError("create_new_interpolator['gwsnr'] should be a boolean.")
614
+ else:
615
+ input_params["create_new_interpolator"] = False
616
+
617
+ # initialization of GWSNR class
585
618
  gwsnr = GWSNR(
586
619
  npool=input_params["npool"],
587
620
  mtot_min=input_params["mtot_min"],
588
621
  mtot_max=input_params["mtot_max"],
589
622
  ratio_min=input_params["ratio_min"],
590
623
  ratio_max=input_params["ratio_max"],
624
+ spin_max=input_params["spin_max"],
591
625
  mtot_resolution=input_params["mtot_resolution"],
592
626
  ratio_resolution=input_params["ratio_resolution"],
627
+ spin_resolution=input_params["spin_resolution"],
593
628
  sampling_frequency=input_params["sampling_frequency"],
594
629
  waveform_approximant=input_params["waveform_approximant"],
630
+ frequency_domain_source_model=input_params["frequency_domain_source_model"],
595
631
  minimum_frequency=input_params["minimum_frequency"],
632
+ duration_max=input_params["duration_max"],
633
+ duration_min=input_params["duration_min"],
596
634
  snr_type=input_params["snr_type"],
597
635
  psds=input_params["psds"],
598
636
  ifos=input_params["ifos"],
599
637
  interpolator_dir=input_params["interpolator_dir"],
600
- # create_new_interpolator=input_params["create_new_interpolator"],
638
+ create_new_interpolator=input_params["create_new_interpolator"],
601
639
  gwsnr_verbose=input_params["gwsnr_verbose"],
602
640
  multiprocessing_verbose=input_params["multiprocessing_verbose"],
603
641
  mtot_cut=input_params["mtot_cut"],
642
+ pdet=input_params["pdet"],
643
+ snr_th=input_params["snr_th"],
644
+ snr_th_net=input_params["snr_th_net"],
645
+ ann_path_dict=input_params["ann_path_dict"],
604
646
  )
605
647
 
606
648
  self.snr = gwsnr.snr
@@ -610,6 +652,16 @@ class GWRATES(CBCSourceParameterDistribution):
610
652
  self.snr_calculator_dict["psds"] = gwsnr.psds_list
611
653
  #self.pdet = gwsnr.pdet
612
654
 
655
+ # warm-up the snr calculator
656
+ if input_params["snr_type"]!="inner_product":
657
+ with contextlib.suppress(Exception):
658
+ # if snr_type is not inner_product, then we can warm up the snr calculator
659
+ # this is useful to avoid the first call to snr being slow
660
+ mass_1 = np.array([10.])
661
+ ratio = np.array([0.8])
662
+ dl = np.array([1000.0]) # in Mpc
663
+ snr = self.snr(mass_1=mass_1, mass_2=mass_1*ratio, luminosity_distance=dl)
664
+
613
665
  def store_gwrates_params(self, output_jsonfile="gwrates_params.json"):
614
666
  """
615
667
  Function to store the all the necessary parameters. This is useful for reproducing the results. All the parameters stored are in string format to make it json compatible.
@@ -646,13 +698,13 @@ class GWRATES(CBCSourceParameterDistribution):
646
698
  for key, value in snr_calculator_dict.items():
647
699
  snr_calculator_dict[key] = str(value)
648
700
  parameters_dict.update({"snr_calculator_dict": snr_calculator_dict})
649
-
650
- file_name = output_jsonfile
651
- append_json(self.ler_directory+"/"+file_name, parameters_dict, replace=True)
652
701
  except:
653
702
  # if snr_calculator is custom function
654
703
  pass
655
704
 
705
+ file_name = output_jsonfile
706
+ append_json(self.ler_directory+"/"+file_name, parameters_dict, replace=True)
707
+
656
708
  def gw_cbc_statistics(
657
709
  self, size=None, resume=False, save_batch=False, output_jsonfile=None,
658
710
  ):
@@ -691,36 +743,16 @@ class GWRATES(CBCSourceParameterDistribution):
691
743
  output_path = os.path.join(self.ler_directory, output_jsonfile)
692
744
  print(f"Simulated GW params will be stored in {output_path}")
693
745
 
694
- # sampling in batches
695
- if resume and os.path.exists(output_path):
696
- # get sample from json file
697
- self.dict_buffer = get_param_from_json(output_path)
698
- else:
699
- self.dict_buffer = None
700
-
701
- batch_handler(
746
+ gw_param = batch_handler(
702
747
  size=size,
703
748
  batch_size=self.batch_size,
704
749
  sampling_routine=self.gw_sampling_routine,
705
750
  output_jsonfile=output_path,
706
751
  save_batch=save_batch,
707
752
  resume=resume,
753
+ param_name="gw parameters",
708
754
  )
709
755
 
710
- if save_batch:
711
- gw_param = get_param_from_json(output_path)
712
- else:
713
- # this if condition is required if there is nothing to save
714
- if self.dict_buffer:
715
- gw_param = self.dict_buffer.copy()
716
- # store all params in json file
717
- print(f"saving all gw_params in {output_path}...")
718
- append_json(output_path, gw_param, replace=True)
719
- else:
720
- print("gw_params already sampled.")
721
- gw_param = get_param_from_json(output_path)
722
- self.dict_buffer = None # save memory
723
-
724
756
  return gw_param
725
757
 
726
758
  def gw_sampling_routine(self, size, output_jsonfile, resume=False, save_batch=True):
@@ -760,17 +792,6 @@ class GWRATES(CBCSourceParameterDistribution):
760
792
  pdet = self.pdet(gw_param_dict=gw_param)
761
793
  gw_param.update(pdet)
762
794
 
763
- # adding batches
764
- if not save_batch:
765
- if self.dict_buffer is None:
766
- self.dict_buffer = gw_param
767
- else:
768
- for key, value in gw_param.items():
769
- self.dict_buffer[key] = np.concatenate((self.dict_buffer[key], value))
770
- else:
771
- # store all params in json file
772
- self.dict_buffer = append_json(file_name=output_jsonfile, new_dictionary=gw_param, old_dictionary=self.dict_buffer, replace=not (resume))
773
-
774
795
  return gw_param
775
796
 
776
797
  def gw_rate(
@@ -897,7 +918,7 @@ class GWRATES(CBCSourceParameterDistribution):
897
918
  dictionary of GW source parameters.
898
919
  """
899
920
 
900
- snr_param = gw_param["optimal_snr_net"]
921
+ snr_param = gw_param["snr_net"]
901
922
  idx_detectable = (snr_param > snr_threshold_recalculation[0]) & (snr_param < snr_threshold_recalculation[1])
902
923
  # reduce the size of the dict
903
924
  for key, value in gw_param.items():
@@ -931,15 +952,15 @@ class GWRATES(CBCSourceParameterDistribution):
931
952
  """
932
953
 
933
954
  if self.snr:
934
- if "optimal_snr_net" not in gw_param:
935
- raise ValueError("'optimal_snr_net' not in gw param dict provided")
955
+ if "snr_net" not in gw_param:
956
+ raise ValueError("'snr_net' not in gw param dict provided")
936
957
  if detectability_condition == "step_function":
937
958
  print("given detectability_condition == 'step_function'")
938
- param = gw_param["optimal_snr_net"]
959
+ param = gw_param["snr_net"]
939
960
  threshold = snr_threshold
940
961
  elif detectability_condition == "pdet":
941
962
  print("given detectability_condition == 'pdet'")
942
- param = 1 - norm.cdf(snr_threshold - gw_param["optimal_snr_net"])
963
+ param = 1 - norm.cdf(snr_threshold - gw_param["snr_net"])
943
964
  gw_param["pdet_net"] = param
944
965
  threshold = pdet_threshold
945
966
  elif self.pdet:
@@ -1061,13 +1082,17 @@ class GWRATES(CBCSourceParameterDistribution):
1061
1082
  self,
1062
1083
  size=100,
1063
1084
  batch_size=None,
1085
+ stopping_criteria=dict(
1086
+ relative_diff_percentage=0.5,
1087
+ number_of_last_batches_to_check=4,
1088
+ ),
1064
1089
  snr_threshold=8.0,
1065
1090
  pdet_threshold=0.5,
1066
1091
  resume=False,
1067
1092
  output_jsonfile="gw_params_n_detectable.json",
1068
1093
  meta_data_file="meta_gw.json",
1069
1094
  detectability_condition="step_function",
1070
- trim_to_size=True,
1095
+ trim_to_size=False,
1071
1096
  snr_recalculation=False,
1072
1097
  snr_threshold_recalculation=[4, 12],
1073
1098
  ):
@@ -1124,10 +1149,11 @@ class GWRATES(CBCSourceParameterDistribution):
1124
1149
  """
1125
1150
 
1126
1151
  # initial setup
1127
- n, events_total, output_path, meta_data_path, buffer_file = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
1152
+ n, events_total, output_path, meta_data_path, buffer_file, batch_size = self._initial_setup_for_n_event_selection(meta_data_file, output_jsonfile, resume, batch_size)
1128
1153
 
1129
1154
  # loop until n samples are collected
1130
- while n < size:
1155
+ continue_condition = True
1156
+ while continue_condition:
1131
1157
  # disable print statements
1132
1158
  with contextlib.redirect_stdout(None):
1133
1159
  self.dict_buffer = None # this is used to store the sampled gw_param in batches when running the sampling_routine
@@ -1152,9 +1178,33 @@ class GWRATES(CBCSourceParameterDistribution):
1152
1178
  total_rate = self.rate_function(n, events_total, verbose=False)
1153
1179
 
1154
1180
  # bookmark
1155
- self._append_meta_data(meta_data_path, n, events_total, total_rate)
1156
-
1157
- print(f"stored detectable gw params in {output_path}")
1181
+ buffer_dict = self._append_meta_data(meta_data_path, n, events_total, total_rate)
1182
+
1183
+ if isinstance(stopping_criteria, dict):
1184
+ total_rates = np.array(buffer_dict['total_rate'])
1185
+ limit = stopping_criteria['relative_diff_percentage']
1186
+ num_a = stopping_criteria['number_of_last_batches_to_check']
1187
+ if len(total_rates)>num_a:
1188
+ num_a = int(-1*(num_a))
1189
+ # num_b = int(num_a)
1190
+ percentage_diff = np.abs((total_rates[num_a:]-total_rates[-1])/total_rates[-1])*100
1191
+ print(f"percentage difference for the last {abs(num_a)} batches = {percentage_diff}")
1192
+ if np.any(percentage_diff>limit):
1193
+ continue_condition &= True
1194
+ else:
1195
+ print(rf"stopping criteria of rate relative difference of {limit}% reached. If you want to collect more events, reduce stopping_criteria['relative_diff_percentage'] or put stopping_criteria=None.")
1196
+ continue_condition &= False
1197
+
1198
+ if isinstance(size, int):
1199
+ if n<size:
1200
+ continue_condition |= True
1201
+ else:
1202
+ print(rf"Given size={size} reached")
1203
+ continue_condition |= False
1204
+ if stopping_criteria is None:
1205
+ continue_condition &= False
1206
+
1207
+ print(f"stored detectable unlensed params in {output_path}")
1158
1208
  print(f"stored meta data in {meta_data_path}")
1159
1209
 
1160
1210
  if trim_to_size:
@@ -1221,6 +1271,8 @@ class GWRATES(CBCSourceParameterDistribution):
1221
1271
  if not resume:
1222
1272
  n = 0 # iterator
1223
1273
  events_total = 0
1274
+ # the following file will be removed if it exists
1275
+ print(f"removing {output_path} and {meta_data_path} if they exist")
1224
1276
  if os.path.exists(output_path):
1225
1277
  os.remove(output_path)
1226
1278
  if os.path.exists(meta_data_path):
@@ -1239,7 +1291,7 @@ class GWRATES(CBCSourceParameterDistribution):
1239
1291
  buffer_file = "params_buffer.json"
1240
1292
  print("collected number of detectable events = ", n)
1241
1293
 
1242
- return n, events_total, output_path, meta_data_path, buffer_file
1294
+ return n, events_total, output_path, meta_data_path, buffer_file, batch_size
1243
1295
 
1244
1296
  def _trim_results_to_size(self, size, output_path, meta_data_path):
1245
1297
  """
@@ -1316,12 +1368,14 @@ class GWRATES(CBCSourceParameterDistribution):
1316
1368
 
1317
1369
  if os.path.exists(meta_data_path):
1318
1370
  try:
1319
- append_json(meta_data_path, meta_data, replace=False)
1371
+ dict_ = append_json(meta_data_path, meta_data, replace=False)
1320
1372
  except:
1321
- append_json(meta_data_path, meta_data, replace=True)
1373
+ dict_ = append_json(meta_data_path, meta_data, replace=True)
1322
1374
  else:
1323
- append_json(meta_data_path, meta_data, replace=True)
1375
+ dict_ = append_json(meta_data_path, meta_data, replace=True)
1324
1376
 
1325
1377
  print("collected number of detectable events = ", n)
1326
1378
  print("total number of events = ", events_total)
1327
- print(f"total rate (yr^-1): {total_rate}")
1379
+ print(f"total rate (yr^-1): {total_rate}")
1380
+
1381
+ return dict_