ler 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ler might be problematic.

Files changed (35)
  1. ler/__init__.py +26 -26
  2. ler/gw_source_population/__init__.py +1 -0
  3. ler/gw_source_population/cbc_source_parameter_distribution.py +1073 -815
  4. ler/gw_source_population/cbc_source_redshift_distribution.py +618 -294
  5. ler/gw_source_population/jit_functions.py +484 -9
  6. ler/gw_source_population/sfr_with_time_delay.py +107 -0
  7. ler/image_properties/image_properties.py +41 -12
  8. ler/image_properties/multiprocessing_routine.py +5 -209
  9. ler/lens_galaxy_population/__init__.py +2 -0
  10. ler/lens_galaxy_population/epl_shear_cross_section.py +0 -0
  11. ler/lens_galaxy_population/jit_functions.py +101 -9
  12. ler/lens_galaxy_population/lens_galaxy_parameter_distribution.py +813 -881
  13. ler/lens_galaxy_population/lens_param_data/density_profile_slope_sl.txt +5000 -0
  14. ler/lens_galaxy_population/lens_param_data/external_shear_sl.txt +2 -0
  15. ler/lens_galaxy_population/lens_param_data/number_density_zl_zs.txt +48 -0
  16. ler/lens_galaxy_population/lens_param_data/optical_depth_epl_shear_vd_ewoud.txt +48 -0
  17. ler/lens_galaxy_population/mp copy.py +554 -0
  18. ler/lens_galaxy_population/mp.py +736 -138
  19. ler/lens_galaxy_population/optical_depth.py +2248 -616
  20. ler/rates/__init__.py +1 -2
  21. ler/rates/gwrates.py +126 -72
  22. ler/rates/ler.py +218 -111
  23. ler/utils/__init__.py +2 -0
  24. ler/utils/function_interpolation.py +322 -0
  25. ler/utils/gwsnr_training_data_generator.py +233 -0
  26. ler/utils/plots.py +1 -1
  27. ler/utils/test.py +1078 -0
  28. ler/utils/utils.py +492 -125
  29. {ler-0.4.2.dist-info → ler-0.4.3.dist-info}/METADATA +30 -17
  30. ler-0.4.3.dist-info/RECORD +34 -0
  31. {ler-0.4.2.dist-info → ler-0.4.3.dist-info}/WHEEL +1 -1
  32. ler/rates/ler copy.py +0 -2097
  33. ler-0.4.2.dist-info/RECORD +0 -25
  34. {ler-0.4.2.dist-info → ler-0.4.3.dist-info/licenses}/LICENSE +0 -0
  35. {ler-0.4.2.dist-info → ler-0.4.3.dist-info}/top_level.txt +0 -0
ler/utils/test.py ADDED
@@ -0,0 +1,1078 @@
+ # -*- coding: utf-8 -*-
+ """
+ This module contains helper routines for other modules in the ler package.
+ """
+
+ import os
+ import pickle
+ import h5py
+ import numpy as np
+ import json
+ from scipy.interpolate import interp1d
+ from scipy.interpolate import CubicSpline
+ from scipy.integrate import quad
+ # `cumtrapz` was renamed `cumulative_trapezoid` in SciPy 1.6 and removed in
+ # SciPy 1.14; fall back to the old name on older SciPy versions.
+ try:
+     from scipy.integrate import cumulative_trapezoid as cumtrapz
+ except ImportError:
+     from scipy.integrate import cumtrapz
+ from numba import njit
+
+
+ class NumpyEncoder(json.JSONEncoder):
+     """
+     Class for storing a numpy.ndarray or any nested-list composition as a JSON file. This is required for dealing with np.nan and np.inf.
+
+     Parameters
+     ----------
+     json.JSONEncoder : `class`
+         class for encoding JSON file
+
+     Returns
+     ----------
+     json.JSONEncoder.default : `function`
+         function for encoding JSON file
+
+     Example
+     ----------
+     >>> import numpy as np
+     >>> import json
+     >>> from ler import helperroutines as hr
+     >>> # create a dictionary
+     >>> param = {'a': np.array([1,2,3]), 'b': np.array([4,5,6])}
+     >>> # save the dictionary as json file
+     >>> with open('param.json', 'w') as f:
+     ...     json.dump(param, f, cls=hr.NumpyEncoder)
+     >>> # load the dictionary from json file
+     >>> with open('param.json', 'r') as f:
+     ...     param = json.load(f)
+     >>> # print the dictionary
+     >>> print(param)
+     {'a': [1, 2, 3], 'b': [4, 5, 6]}
+     """
+
+     def default(self, obj):
+         """Encode numpy arrays as lists so json can serialize them."""
+         if isinstance(obj, np.ndarray):
+             return obj.tolist()
+         return json.JSONEncoder.default(self, obj)
+
+ def load_pickle(file_name):
+     """Load a pickle file.
+
+     Parameters
+     ----------
+     file_name : `str`
+         pickle file name for storing the parameters.
+
+     Returns
+     ----------
+     param : `dict`
+     """
+     with open(file_name, "rb") as handle:
+         param = pickle.load(handle)
+
+     return param
+
+ def save_pickle(file_name, param):
+     """Save a dictionary as a pickle file.
+
+     Parameters
+     ----------
+     file_name : `str`
+         pickle file name for storing the parameters.
+     param : `dict`
+         dictionary to be saved as a pickle file.
+     """
+     with open(file_name, "wb") as handle:
+         pickle.dump(param, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+ # hdf5
+ def load_hdf5(file_name):
+     """Open an hdf5 file in read mode.
+
+     Parameters
+     ----------
+     file_name : `str`
+         hdf5 file name for storing the parameters.
+
+     Returns
+     ----------
+     param : `h5py.File`
+         open file handle (not a plain dictionary).
+     """
+
+     return h5py.File(file_name, 'r')
+
+ def save_hdf5(file_name, param):
+     """Save a dictionary as an hdf5 file.
+
+     Parameters
+     ----------
+     file_name : `str`
+         hdf5 file name for storing the parameters.
+     param : `dict`
+         dictionary to be saved as an hdf5 file.
+     """
+     with h5py.File(file_name, 'w') as f:
+         for key, value in param.items():
+             f.create_dataset(key, data=value)
+
+ def load_json(file_name):
+     """Load a json file.
+
+     Parameters
+     ----------
+     file_name : `str`
+         json file name for storing the parameters.
+
+     Returns
+     ----------
+     param : `dict`
+     """
+     with open(file_name, "r", encoding="utf-8") as f:
+         param = json.load(f)
+
+     return param
+
+ def save_json(file_name, param):
+     """Save a dictionary as a json file.
+
+     Parameters
+     ----------
+     file_name : `str`
+         json file name for storing the parameters.
+     param : `dict`
+         dictionary to be saved as a json file.
+     """
+     with open(file_name, "w", encoding="utf-8") as write_file:
+         json.dump(param, write_file)
+
+ def append_json(file_name, new_dictionary, old_dictionary=None, replace=False):
+     """
+     Append (values with corresponding keys) and update a json file with a dictionary. There are four options:
+
+     1. If old_dictionary is provided, the values of the new dictionary will be appended to the old dictionary and saved in the 'file_name' json file.
+     2. If replace is True, replace the json file (with the 'file_name') content with the new_dictionary.
+     3. If the file does not exist, create a new one with the new_dictionary.
+     4. If none of the above, append the new dictionary to the content of the json file.
+
+     Parameters
+     ----------
+     file_name : `str`
+         json file name for storing the parameters.
+     new_dictionary : `dict`
+         dictionary to be appended to the json file.
+     old_dictionary : `dict`, optional
+         If provided, the values of the new dictionary will be appended to the old dictionary and saved in the 'file_name' json file.
+         Default is None.
+     replace : `bool`, optional
+         If True, replace the json file with the dictionary. Default is False.
+
+     Returns
+     ----------
+     data : `dict`
+         the updated dictionary that was written to the json file.
+     """
+
+     # decide where the base data comes from
+     if old_dictionary:
+         data = old_dictionary
+     elif replace:
+         data = new_dictionary
+     elif not os.path.exists(file_name):
+         # the file does not exist; create a new one
+         replace = True
+         data = new_dictionary
+     else:
+         # get the existing data from the file
+         with open(file_name, "r", encoding="utf-8") as f:
+             data = json.load(f)
+
+     if not replace:
+         data = add_dictionaries_together(data, new_dictionary)
+
+     # save the dictionary
+     with open(file_name, "w", encoding="utf-8") as write_file:
+         json.dump(data, write_file, indent=4, cls=NumpyEncoder)
+
+     return data
+
+ # def add_dict_values(dict1, dict2):
+ #     """Adds the values of two dictionaries together.
+
+ #     Parameters
+ #     ----------
+ #     dict1 : `dict`
+ #         dictionary to be added.
+ #     dict2 : `dict`
+ #         dictionary to be added.
+
+ #     Returns
+ #     ----------
+ #     dict1 : `dict`
+ #         dictionary with added values.
+ #     """
+ #     data_key = dict1.keys()
+ #     for key, value in dict2.items():
+ #         if key in data_key:
+ #             dict1[key] = np.concatenate((dict1[key], value))
+
+ #     return dict1
+
+ def get_param_from_json(json_file):
+     """
+     Function to get the parameters from a json file.
+
+     Parameters
+     ----------
+     json_file : `str`
+         json file name for storing the parameters.
+
+     Returns
+     ----------
+     param : `dict`
+     """
+     with open(json_file, "r", encoding="utf-8") as f:
+         param = json.load(f)
+
+     for key, value in param.items():
+         param[key] = np.array(value)
+     return param
+
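A minimal usage sketch for the json helpers above (the file name 'params.json' and the parameter keys are hypothetical):

    import numpy as np

    batch1 = {'zs': np.array([0.1, 0.2]), 'mass_1': np.array([30.0, 25.0])}
    batch2 = {'zs': np.array([0.3]), 'mass_1': np.array([40.0])}

    append_json('params.json', batch1, replace=True)  # create or overwrite the file
    append_json('params.json', batch2)                # append to the existing content
    param = get_param_from_json('params.json')        # values come back as numpy arrays
    print(param['zs'])  # [0.1 0.2 0.3]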
+ def rejection_sample(pdf, xmin, xmax, size=100, chunk_size=10000):
+     """
+     Helper function for rejection sampling from a pdf with maximum and minimum arguments.
+
+     Parameters
+     ----------
+     pdf : `function`
+         pdf function.
+     xmin : `float`
+         minimum value of the pdf.
+     xmax : `float`
+         maximum value of the pdf.
+     size : `int`, optional
+         number of samples. Default is 100.
+     chunk_size : `int`, optional
+         chunk size for sampling. Default is 10000.
+
+     Returns
+     ----------
+     x_sample : `numpy.ndarray`
+         samples from the pdf.
+     """
+     # initial estimate of the maximum value of the pdf on a regular grid
+     x = np.linspace(xmin, xmax, chunk_size)
+     y = pdf(x)
+     ymax = np.max(y)
+
+     # Rejection sample in chunks
+     x_sample = []
+     while len(x_sample) < size:
+         x_try = np.random.uniform(xmin, xmax, size=chunk_size)
+         pdf_x_try = pdf(x_try)  # pdf at the random x values
+
+         # Update the maximum value of the pdf before drawing the comparison
+         # heights, so the acceptance test in this chunk is not biased
+         ymax = max(ymax, np.max(pdf_x_try))
+
+         # uniform heights used to accept or reject each sample
+         y_try = np.random.uniform(0, ymax, size=chunk_size)
+
+         # accept the samples that fall under the pdf,
+         # while retaining the 1D shape of the list
+         x_sample += list(x_try[y_try < pdf_x_try])
+
+     # Transform the samples to a 1D numpy array
+     x_sample = np.array(x_sample).flatten()
+     # Return the correct number of samples
+     return x_sample[:size]
+
+
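A short sketch of rejection_sample in use (the pdf shown, an unnormalized truncated Gaussian, is hypothetical):

    import numpy as np

    pdf = lambda x: np.exp(-0.5 * (x - 2.0) ** 2)  # need not be normalized
    samples = rejection_sample(pdf, xmin=0.0, xmax=5.0, size=1000)
    print(samples.shape)  # (1000,)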
+ def rejection_sample2d(pdf, xmin, xmax, ymin, ymax, size=100, chunk_size=10000):
+     """
+     Helper function for rejection sampling from a 2D pdf with maximum and minimum arguments.
+
+     Parameters
+     ----------
+     pdf : `function`
+         2D pdf function.
+     xmin : `float`
+         minimum value of the pdf in the x-axis.
+     xmax : `float`
+         maximum value of the pdf in the x-axis.
+     ymin : `float`
+         minimum value of the pdf in the y-axis.
+     ymax : `float`
+         maximum value of the pdf in the y-axis.
+     size : `int`, optional
+         number of samples. Default is 100.
+     chunk_size : `int`, optional
+         chunk size for sampling. Default is 10000.
+
+     Returns
+     ----------
+     x_sample : `numpy.ndarray`
+         samples from the pdf in the x-axis.
+     y_sample : `numpy.ndarray`
+         samples from the pdf in the y-axis.
+     """
+
+     # initial estimate of the maximum value of the pdf
+     x = np.random.uniform(xmin, xmax, chunk_size)
+     y = np.random.uniform(ymin, ymax, chunk_size)
+     z = pdf(x, y)
+     zmax = np.max(z)
+
+     # Rejection sample in chunks
+     x_sample = []
+     y_sample = []
+     while len(x_sample) < size:
+         x_try = np.random.uniform(xmin, xmax, size=chunk_size)
+         y_try = np.random.uniform(ymin, ymax, size=chunk_size)
+         pdf_xy_try = pdf(x_try, y_try)
+
+         # Update the maximum value of the pdf before drawing the comparison
+         # heights, so the acceptance test in this chunk is not biased
+         zmax = max(zmax, np.max(pdf_xy_try))
+
+         # uniform heights used to accept or reject each sample
+         z_try = np.random.uniform(0, zmax, size=chunk_size)
+
+         x_sample += list(x_try[z_try < pdf_xy_try])
+         y_sample += list(y_try[z_try < pdf_xy_try])
+
+     # Transform the samples to 1D numpy arrays
+     x_sample = np.array(x_sample).flatten()
+     y_sample = np.array(y_sample).flatten()
+     # Return the correct number of samples
+     return x_sample[:size], y_sample[:size]
+
+
+ def add_dictionaries_together(dictionary1, dictionary2):
+     """
+     Adds two dictionaries with the same keys together.
+
+     Parameters
+     ----------
+     dictionary1 : `dict`
+         dictionary to be added.
+     dictionary2 : `dict`
+         dictionary to be added.
+
+     Returns
+     ----------
+     dictionary : `dict`
+         dictionary with added values.
+     """
+     dictionary = {}
+     # If either dictionary is empty, return the one with values
+     if len(dictionary1) == 0:
+         return dictionary2
+     elif len(dictionary2) == 0:
+         return dictionary1
+     # Check if the keys are the same
+     if dictionary1.keys() != dictionary2.keys():
+         raise ValueError("The dictionaries have different keys.")
+     for key in dictionary1.keys():
+         value1 = dictionary1[key]
+         value2 = dictionary2[key]
+
+         # check if either value is empty
+         bool0 = len(value1) == 0 or len(value2) == 0
+         # check whether the values are ndarrays, lists or dicts
+         bool1 = isinstance(value1, np.ndarray) and isinstance(value2, np.ndarray)
+         bool2 = isinstance(value1, list) and isinstance(value2, list)
+         bool3 = isinstance(value1, np.ndarray) and isinstance(value2, list)
+         bool4 = isinstance(value1, list) and isinstance(value2, np.ndarray)
+         bool4 = bool4 or bool3
+         bool5 = isinstance(value1, dict) and isinstance(value2, dict)
+
+         if bool0:
+             if len(value1) == 0 and len(value2) == 0:
+                 dictionary[key] = np.array([])
+             elif len(value1) != 0 and len(value2) == 0:
+                 dictionary[key] = np.array(value1)
+             elif len(value1) == 0 and len(value2) != 0:
+                 dictionary[key] = np.array(value2)
+         elif bool1:
+             dictionary[key] = np.concatenate((value1, value2))
+         elif bool2:
+             dictionary[key] = value1 + value2
+         elif bool4:
+             dictionary[key] = np.concatenate((np.array(value1), np.array(value2)))
+         elif bool5:
+             dictionary[key] = add_dictionaries_together(
+                 dictionary1[key], dictionary2[key]
+             )
+         else:
+             raise ValueError(
+                 "The dictionary contains an item which is neither an ndarray nor a dictionary."
+             )
+     return dictionary
+
+
+ def trim_dictionary(dictionary, size):
+     """
+     Trims each array in an event dictionary down to the given size.
+
+     Parameters
+     ----------
+     dictionary : `dict`
+         dictionary to be trimmed.
+     size : `int`
+         size to trim the dictionary to.
+
+     Returns
+     ----------
+     dictionary : `dict`
+         trimmed dictionary.
+     """
+     for key in dictionary.keys():
+         # Check if the item is an ndarray
+         if isinstance(dictionary[key], np.ndarray):
+             dictionary[key] = dictionary[key][:size]  # Trim the array
+         # Check if the item is a nested dictionary
+         elif isinstance(dictionary[key], dict):
+             dictionary[key] = trim_dictionary(
+                 dictionary[key], size
+             )  # Trim the nested dictionary
+         else:
+             raise ValueError(
+                 "The dictionary contains an item which is neither an ndarray nor a dictionary."
+             )
+     return dictionary
+
+ def create_func_pdf_invcdf(x, y, category="function"):
+     """
+     Function to create an interpolated function, inverse function, pdf or inverse cdf from the input x and y.
+
+     Parameters
+     ----------
+     x : `numpy.ndarray`
+         x values. This has to be sorted in ascending order.
+     y : `numpy.ndarray`
+         y values. Corresponding to the x values.
+     category : `str`, optional
+         category of the function. Default is "function". Other options are "function_inverse", "pdf", "inv_cdf" and "all".
+
+     Returns
+     ----------
+     pdf : `pdf function`
+         interpolated pdf function.
+     inv_pdf : `function inverse`
+         interpolated inverse pdf function.
+     inv_cdf : `function`
+         interpolated inverse cdf.
+     """
+
+     # remove NaN values before interpolating
+     idx = np.argwhere(np.isnan(y))
+     x = np.delete(x, idx)
+     y = np.delete(y, idx)
+
+     # create pdf with interpolation
+     pdf_unorm = interp1d(x, y, kind="cubic", fill_value="extrapolate")
+     if category == "function":
+         return pdf_unorm
+     if category == "function_inverse":
+         # create inverse function
+         return interp1d(y, x, kind="cubic", fill_value="extrapolate")
+
+     # normalize the pdf
+     min_, max_ = min(x), max(x)
+     norm = quad(pdf_unorm, min_, max_)[0]
+     y = y / norm
+     pdf = interp1d(x, y, kind="cubic", fill_value="extrapolate")
+     if category == "pdf" or category is None:
+         return pdf
+     # cdf; drop the leading flat (zero) part so the inverse is well defined
+     cdf_values = cumtrapz(y, x, initial=0)
+     idx = np.argwhere(cdf_values > 0)[0][0]
+     cdf_values = cdf_values[idx:]
+     x = x[idx:]
+     inv_cdf = interp1d(cdf_values, x, kind="cubic", fill_value="extrapolate")
+     if category == "inv_cdf":
+         return inv_cdf
+     if category == "all":
+         # `pdf` is built above so that both objects can be returned here
+         return [pdf, inv_cdf]
+
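A sketch of drawing samples through an inverse cdf built by create_func_pdf_invcdf (the power-law pdf is hypothetical):

    import numpy as np

    x = np.linspace(0.0, 1.0, 500)
    inv_cdf = create_func_pdf_invcdf(x, x ** 2, category="inv_cdf")
    u = np.random.uniform(0, 1, size=1000)
    samples = inv_cdf(u)  # distributed as p(x) = 3 x^2 on [0, 1]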
+ def create_conditioned_pdf_invcdf(x, conditioned_y, pdf_func, category):
+     """
+     Function to create a list of interpolated objects, one per conditioned y
+     value, on the x-y parameter plane. pdf_func calculates the pdf of x given
+     y; x is an array and the output of pdf_func is an array.
+
+     Parameters
+     ----------
+     x : `numpy.ndarray`
+         x values.
+     conditioned_y : `numpy.ndarray`
+         conditioned y values.
+     pdf_func : `function`
+         function to calculate the pdf of x given y.
+     category : `str`
+         category of the function. Options are "function", "function_inverse", "pdf" and "inv_cdf".
+     """
+
+     list_ = []
+     for y in conditioned_y:
+         phi = pdf_func(x, y)
+         # append the interpolated object for each y along the x-axis
+         list_.append(create_func_pdf_invcdf(x, phi, category=category))
+
+     return list_
+
+ def create_func(x, y):
+     """
+     Function to create cubic spline coefficients from the input x and y,
+     for use with cubic_spline_interpolator.
+
+     Parameters
+     ----------
+     x : `numpy.ndarray`
+         x values.
+     y : `numpy.ndarray`
+         y values.
+
+     Returns
+     ----------
+     c : `numpy.ndarray`
+         spline coefficients.
+     x : `numpy.ndarray`
+         x values with NaNs removed (the spline knots).
+     """
+
+     idx = np.argwhere(np.isnan(y))
+     x = np.delete(x, idx)
+     y = np.delete(y, idx)
+     return CubicSpline(x, y).c, x
+
+ def create_func_inv(x, y):
+     """
+     Function to create cubic spline coefficients for the inverse function
+     (x as a function of y).
+
+     Parameters
+     ----------
+     x : `numpy.ndarray`
+         x values.
+     y : `numpy.ndarray`
+         y values.
+
+     Returns
+     ----------
+     c : `numpy.ndarray`
+         spline coefficients.
+     y : `numpy.ndarray`
+         y values with NaNs removed (the spline knots).
+     """
+
+     idx = np.argwhere(np.isnan(y))
+     x = np.delete(x, idx)
+     y = np.delete(y, idx)
+     return CubicSpline(y, x).c, y
+
+ def create_pdf(x, y):
+     """
+     Function to create cubic spline coefficients of a normalized pdf from the input x and y.
+
+     Parameters
+     ----------
+     x : `numpy.ndarray`
+         x values.
+     y : `numpy.ndarray`
+         y values.
+
+     Returns
+     ----------
+     c : `numpy.ndarray`
+         spline coefficients.
+     x : `numpy.ndarray`
+         x values with NaNs removed (the spline knots).
+     """
+     idx = np.argwhere(np.isnan(y))
+     x = np.delete(x, idx)
+     y = np.delete(y, idx)
+     pdf_unorm = interp1d(x, y, kind="cubic", fill_value="extrapolate")
+     min_, max_ = min(x), max(x)
+     norm = quad(pdf_unorm, min_, max_)[0]
+     y = y / norm
+     return CubicSpline(x, y).c, x
+
+ def create_inv_cdf_array(x, y):
+     """
+     Function to create an inverse cdf lookup table from the input x and y.
+
+     Parameters
+     ----------
+     x : `numpy.ndarray`
+         x values.
+     y : `numpy.ndarray`
+         y values.
+
+     Returns
+     ----------
+     array : `numpy.ndarray`
+         2D array of shape (2, n) holding the normalized cdf values and the
+         corresponding x values, for use with inverse_transform_sampler.
+     """
+
+     idx = np.argwhere(np.isnan(y))
+     x = np.delete(x, idx)
+     y = np.delete(y, idx)
+     cdf_values = cumtrapz(y, x, initial=0)
+     cdf_values = cdf_values / cdf_values[-1]
+     return np.array([cdf_values, x])
+
+ def create_conditioned_pdf(x, conditioned_y, pdf_func):
+     """
+     Function to create a conditioned pdf (one set of spline coefficients per y value).
+
+     Parameters
+     ----------
+     x : `numpy.ndarray`
+         x values.
+     conditioned_y : `numpy.ndarray`
+         conditioned y values.
+     pdf_func : `function`
+         function to calculate the pdf of x given y.
+
+     Returns
+     ----------
+     list_ : `numpy.ndarray`
+         array of (coefficients, knots) pairs, one per conditioned y value.
+     """
+     list_ = []
+     for y in conditioned_y:
+         phi = pdf_func(x, y)
+         list_.append(create_pdf(x, phi))
+
+     # dtype=object because the coefficients and the knots have different shapes
+     return np.array(list_, dtype=object)
+
+ def create_conditioned_inv_cdf_array(x, conditioned_y, pdf_func):
+     """
+     Function to create a conditioned inv_cdf lookup table from the input x and y.
+
+     Parameters
+     ----------
+     x : `numpy.ndarray`
+         x values.
+     conditioned_y : `numpy.ndarray`
+         conditioned y values.
+     pdf_func : `function`
+         function to calculate the pdf of x given y.
+
+     Returns
+     ----------
+     list_ : `numpy.ndarray`
+         array of inverse-cdf lookup tables, one per conditioned y value.
+     """
+
+     list_ = []
+     for y in conditioned_y:
+         phi = pdf_func(x, y)
+         list_.append(create_inv_cdf_array(x, phi))
+
+     return np.array(list_)
+
+ def interpolator_from_pickle(
+     param_dict_given, directory, sub_directory, name, x, pdf_func=None, y=None, conditioned_y=None, dimension=1, category="pdf", create_new=False
+ ):
+     """
+     Function to decide which interpolator to use, and to load it from (or create it at) its pickle file.
+
+     Parameters
+     ----------
+     param_dict_given : `dict`
+         dictionary of parameters.
+     directory : `str`
+         directory to store the interpolator.
+     sub_directory : `str`
+         sub-directory to store the interpolator.
+     name : `str`
+         name of the interpolator.
+     x : `numpy.ndarray`
+         x values.
+     pdf_func : `function`
+         function to calculate the pdf of x given y.
+     y : `numpy.ndarray`
+         y values.
+     conditioned_y : `numpy.ndarray`
+         conditioned y values.
+     dimension : `int`
+         dimension of the interpolator. Default is 1.
+     category : `str`
+         category of the function. Default is "pdf".
+     create_new : `bool`
+         if True, create a new interpolator. Default is False.
+
+     Returns
+     ----------
+     interpolator : `function`
+         interpolator function.
+     """
+
+     # check first whether the directory, sub-directory and pickle exist
+     path_inv_cdf, it_exist = interpolator_pickle_path(
+         param_dict_given=param_dict_given,
+         directory=directory,
+         sub_directory=sub_directory,
+         interpolator_name=name,
+     )
+     if create_new:
+         it_exist = False
+     if it_exist:
+         print(f"{name} interpolator will be loaded from {path_inv_cdf}")
+         # load the interpolator
+         with open(path_inv_cdf, "rb") as handle:
+             interpolator = pickle.load(handle)
+         return interpolator
+     else:
+         print(f"{name} interpolator will be generated at {path_inv_cdf}")
+
+         # create the interpolator
+         if dimension == 1:
+             if y is None:
+                 y = pdf_func(x)
+             if category == "function":
+                 interpolator = create_func(x, y)
+             elif category == "function_inverse":
+                 interpolator = create_func_inv(x, y)
+             elif category == "pdf":
+                 interpolator = create_pdf(x, y)
+             elif category == "inv_cdf":
+                 interpolator = create_inv_cdf_array(x, y)
+             else:
+                 raise ValueError("The category given should be function, function_inverse, pdf or inv_cdf.")
+         elif dimension == 2:
+             if category == "pdf":
+                 interpolator = create_conditioned_pdf(x, conditioned_y, pdf_func)
+             elif category == "inv_cdf":
+                 interpolator = create_conditioned_inv_cdf_array(x, conditioned_y, pdf_func)
+             else:
+                 raise ValueError("The category given should be pdf or inv_cdf.")
+         else:
+             raise ValueError("The dimension is not supported.")
+         # save the interpolator
+         with open(path_inv_cdf, "wb") as handle:
+             pickle.dump(interpolator, handle, protocol=pickle.HIGHEST_PROTOCOL)
+         return interpolator
+
+ def interpolator_pickle_path(
+     param_dict_given,
+     directory,
+     sub_directory,
+     interpolator_name,
+ ):
+     """
+     Function to create the interpolator pickle file path.
+
+     Parameters
+     ----------
+     param_dict_given : `dict`
+         dictionary of parameters.
+     directory : `str`
+         directory to store the interpolator.
+     sub_directory : `str`
+         sub-directory to store the interpolator.
+     interpolator_name : `str`
+         name of the interpolator.
+
+     Returns
+     ----------
+     path_inv_cdf : `str`
+         path of the interpolator pickle file.
+     it_exist : `bool`
+         if True, the interpolator exists.
+     """
+
+     # make sure the directory and sub-directory exist
+     full_dir = directory + "/" + sub_directory
+     if not os.path.exists(full_dir):
+         os.makedirs(full_dir)
+
+     # check if the list of stored parameter dicts exists
+     path1 = full_dir + "/init_dict.pickle"
+     if not os.path.exists(path1):
+         dict_list = []
+         with open(path1, "wb") as handle:
+             pickle.dump(dict_list, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+     # check if the input dict matches one of the dicts inside the pickle file
+     with open(path1, "rb") as handle:
+         param_dict_stored = pickle.load(handle)
+
+     path2 = full_dir
+     len_ = len(param_dict_stored)
+     if param_dict_given in param_dict_stored:
+         idx = param_dict_stored.index(param_dict_given)
+         # load the interpolator
+         path_inv_cdf = path2 + "/" + interpolator_name + "_" + str(idx) + ".pickle"
+         # the file may have been deleted by mistake
+         it_exist = os.path.exists(path_inv_cdf)
+     else:
+         it_exist = False
+         path_inv_cdf = path2 + "/" + interpolator_name + "_" + str(len_) + ".pickle"
+         param_dict_stored.append(param_dict_given)
+         with open(path1, "wb") as handle:
+             pickle.dump(param_dict_stored, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+     return path_inv_cdf, it_exist
+
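A sketch of the caching round trip (the directory, name and pdf are hypothetical; the first call builds and pickles the lookup table, and later calls with the same param_dict_given load it from disk):

    import numpy as np

    inv_cdf_table = interpolator_from_pickle(
        param_dict_given={'alpha': 2.0, 'xmin': 0.0, 'xmax': 1.0},
        directory='./interpolator_pickle',
        sub_directory='demo',
        name='power_law_inv_cdf',
        x=np.linspace(0.0, 1.0, 500),
        pdf_func=lambda x: x ** 2.0,
        dimension=1,
        category='inv_cdf',
    )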
+ def interpolator_pdf_conditioned(x, conditioned_y, y_array, interpolator_list):
+     """
+     Function to evaluate the pdf interpolator that matches the conditioned y.
+
+     Parameters
+     ----------
+     x : `numpy.ndarray`
+         x values.
+     conditioned_y : `float`
+         conditioned y value.
+     y_array : `numpy.ndarray`
+         y values.
+     interpolator_list : `list`
+         list of interpolators.
+
+     Returns
+     ----------
+     interpolator_list[idx](x) : `numpy.ndarray`
+         pdf values at x from the selected interpolator.
+     """
+     # find the index of conditioned_y in y_array
+     idx = np.searchsorted(y_array, conditioned_y)
+
+     return interpolator_list[idx](x)
+
+ def interpolator_sampler_conditioned(conditioned_y, y_array, interpolator_list, size=1000):
+     """
+     Function to sample from the inverse-cdf interpolator that matches the conditioned y.
+
+     Parameters
+     ----------
+     conditioned_y : `float`
+         conditioned y value.
+     y_array : `numpy.ndarray`
+         y values.
+     interpolator_list : `list`
+         list of interpolators.
+     size : `int`
+         number of samples.
+
+     Returns
+     ----------
+     samples : `numpy.ndarray`
+         samples drawn through the selected inverse-cdf interpolator.
+     """
+
+     # find the index of conditioned_y in y_array
+     idx = np.searchsorted(y_array, conditioned_y)
+     u = np.random.uniform(0, 1, size=size)
+     return interpolator_list[idx](u)
+
+ @njit
+ def cubic_spline_interpolator(xnew, coefficients, x):
+     """
+     Function to interpolate using cubic spline.
+
+     Parameters
+     ----------
+     xnew : `numpy.ndarray`
+         new x values.
+     coefficients : `numpy.ndarray`
+         coefficients of the cubic spline.
+     x : `numpy.ndarray`
+         x values (spline knots).
+
+     Returns
+     ----------
+     result : `numpy.ndarray`
+         interpolated values.
+     """
+
+     # Handling extrapolation: clamp to the first/last interval
+     i = np.searchsorted(x, xnew) - 1
+     idx1 = xnew <= x[0]
+     idx2 = xnew > x[-1]
+     i[idx1] = 0
+     i[idx2] = len(x) - 2
+
+     # Calculate the relative position within the interval
+     dx = xnew - x[i]
+
+     # Calculate the interpolated value.
+     # scipy's CubicSpline stores coefficients highest order first, so `a` is
+     # the dx^3 coefficient and `d` is the constant term:
+     # result = a*dx^3 + b*dx^2 + c*dx + d
+     a, b, c, d = coefficients[:, i]
+     result = d + c*dx + b*dx**2 + a*dx**3
+     return result
+
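A short sketch pairing create_func with the jit-compiled evaluator (the sine curve is hypothetical):

    import numpy as np

    x = np.linspace(0.0, 2.0 * np.pi, 100)
    coeff, knots = create_func(x, np.sin(x))
    xnew = np.linspace(0.0, 2.0 * np.pi, 1000)
    ynew = cubic_spline_interpolator(xnew, coeff, knots)  # close to np.sin(xnew)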
+ @njit
+ def inverse_transform_sampler(size, cdf, x):
+     """
+     Function to sample from a tabulated cdf with the inverse transform method.
+
+     Parameters
+     ----------
+     size : `int`
+         number of samples.
+     cdf : `numpy.ndarray`
+         cdf values.
+     x : `numpy.ndarray`
+         x values.
+
+     Returns
+     ----------
+     samples : `numpy.ndarray`
+         samples from the cdf.
+     """
+
+     u = np.random.uniform(0, 1, size)
+     idx = np.searchsorted(cdf, u)
+     # linear interpolation between the two bracketing cdf points
+     x1, x0, y1, y0 = cdf[idx], cdf[idx-1], x[idx], x[idx-1]
+     samples = y0 + (y1 - y0) * (u - x0) / (x1 - x0)
+     return samples
+
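Together with create_inv_cdf_array, this gives a fast sampling path (the pdf below is hypothetical):

    import numpy as np

    x = np.linspace(0.0, 1.0, 500)
    cdf_values, x_grid = create_inv_cdf_array(x, 3.0 * x ** 2)
    samples = inverse_transform_sampler(10000, cdf_values, x_grid)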
+ def batch_handler(size, batch_size, sampling_routine, output_jsonfile, save_batch=True, resume=False, param_name='parameters'):
+     """
+     Function to run the sampling in batches.
+
+     Parameters
+     ----------
+     size : `int`
+         number of samples.
+     batch_size : `int`
+         batch size.
+     sampling_routine : `function`
+         sampling function. It should have 'size' as input and return a dictionary.
+     output_jsonfile : `str`
+         json file name for storing the parameters.
+     save_batch : `bool`, optional
+         if True, save sampled parameters in each iteration. Default is True.
+     resume : `bool`, optional
+         if True, resume sampling from the last batch. Default is False.
+     param_name : `str`, optional
+         name of the parameter. Default is 'parameters'.
+
+     Returns
+     ----------
+     dict_buffer : `dict`
+         dictionary of parameters.
+     """
+
+     # sampling in batches
+     if resume and os.path.exists(output_jsonfile):
+         # get the existing samples from the json file
+         dict_buffer = get_param_from_json(output_jsonfile)
+     else:
+         dict_buffer = None
+
+     # number of batches; round up if size is not a multiple of batch_size
+     if size % batch_size == 0:
+         num_batches = size // batch_size
+     else:
+         num_batches = size // batch_size + 1
+
+     print(
+         f"chosen batch size = {batch_size} with total size = {size}"
+     )
+     print(f"There will be {num_batches} batch(es)")
+
+     # note: frac_batches + (num_batches - 1) * batch_size = size
+     if size > batch_size:
+         frac_batches = size - (num_batches - 1) * batch_size
+     # if size is less than batch_size
+     else:
+         frac_batches = size
+     track_batches = 0  # to track the number of batches
+
+     if not resume:
+         # create the first batch with frac_batches samples
+         track_batches, dict_buffer = create_batch_params(sampling_routine, frac_batches, dict_buffer, save_batch, output_jsonfile, track_batches=track_batches)
+     else:
+         # check where to resume from:
+         # identify the last batch and assign the current batch number.
+         # try-except avoids errors when the file does not exist, is empty,
+         # is corrupted or does not have the required key.
+         try:
+             print(f"resuming from {output_jsonfile}")
+             len_ = len(list(dict_buffer.values())[0])
+             track_batches = (len_ - frac_batches) // batch_size + 1
+         except Exception:
+             # create a new first batch with frac_batches samples
+             track_batches, dict_buffer = create_batch_params(sampling_routine, frac_batches, dict_buffer, save_batch, output_jsonfile, track_batches=track_batches)
+
+     # loop over the remaining batches
+     min_, max_ = track_batches, num_batches
+     save_param = False
+     if min_ == max_:
+         print(f"{param_name} already sampled.")
+     elif min_ > max_:
+         len_ = len(list(dict_buffer.values())[0])
+         print(f"existing {param_name} size ({len_}) is more than the required size ({size}). It will be trimmed.")
+         dict_buffer = trim_dictionary(dict_buffer, size)
+         save_param = True
+     else:
+         for i in range(min_, max_):
+             _, dict_buffer = create_batch_params(sampling_routine, batch_size, dict_buffer, save_batch, output_jsonfile, track_batches=i, resume=True)
+
+         if save_batch:
+             # if save_batch=True, dict_buffer holds only the last batch,
+             # so reload the full set from the json file
+             dict_buffer = get_param_from_json(output_jsonfile)
+         else:  # don't save in batches
+             # this condition is required so everything is saved at the end
+             save_param = True
+
+     if save_param:
+         # store all params in the json file
+         print(f"saving all {param_name} in {output_jsonfile} ")
+         append_json(output_jsonfile, dict_buffer, replace=True)
+
+     return dict_buffer
+
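A minimal sketch of batch_handler in use (the sampling routine and file name are hypothetical; note that the routine must accept the size, save_batch, output_jsonfile and resume keywords that create_batch_params passes through):

    import numpy as np

    def sampling_routine(size, save_batch=True, output_jsonfile='params.json', resume=False):
        return {'zs': np.random.uniform(0, 10, size)}

    params = batch_handler(
        size=25000,
        batch_size=10000,
        sampling_routine=sampling_routine,
        output_jsonfile='params.json',
        save_batch=False,
    )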
+ def create_batch_params(sampling_routine, frac_batches, dict_buffer, save_batch, output_jsonfile, track_batches, resume=False):
+     """
+     Helper function for batch_handler. It creates batch parameters and stores them in a dictionary.
+
+     Parameters
+     ----------
+     sampling_routine : `function`
+         sampling function. It should have 'size' as input and return a dictionary.
+     frac_batches : `int`
+         batch size.
+     dict_buffer : `dict`
+         dictionary of parameters.
+     save_batch : `bool`
+         if True, save sampled parameters in each iteration.
+     output_jsonfile : `str`
+         json file name for storing the parameters.
+     track_batches : `int`
+         track the number of batches.
+     resume : `bool`, optional
+         if True, resume sampling from the last batch. Default is False.
+
+     Returns
+     ----------
+     track_batches : `int`
+         track the number of batches.
+     dict_buffer : `dict`
+         updated dictionary of parameters.
+     """
+
+     track_batches = track_batches + 1
+     print(f"Batch no. {track_batches}")
+     param = sampling_routine(size=frac_batches, save_batch=save_batch, output_jsonfile=output_jsonfile, resume=resume)
+
+     # add the batch and hold it in the buffer
+     if not save_batch:
+         # in a new sampling run, dict_buffer will be None
+         if dict_buffer is None:
+             dict_buffer = param
+         else:
+             for key, value in param.items():
+                 dict_buffer[key] = np.concatenate((dict_buffer[key], value))
+     else:
+         # store all params in the json file
+         dict_buffer = append_json(file_name=output_jsonfile, new_dictionary=param, old_dictionary=dict_buffer, replace=not (resume))
+
+     return track_batches, dict_buffer