ler 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ler might be problematic.
- ler/__init__.py +26 -26
- ler/gw_source_population/__init__.py +1 -0
- ler/gw_source_population/cbc_source_parameter_distribution.py +1076 -818
- ler/gw_source_population/cbc_source_redshift_distribution.py +619 -295
- ler/gw_source_population/jit_functions.py +484 -9
- ler/gw_source_population/sfr_with_time_delay.py +107 -0
- ler/image_properties/image_properties.py +44 -13
- ler/image_properties/multiprocessing_routine.py +5 -209
- ler/lens_galaxy_population/__init__.py +2 -0
- ler/lens_galaxy_population/epl_shear_cross_section.py +0 -0
- ler/lens_galaxy_population/jit_functions.py +101 -9
- ler/lens_galaxy_population/lens_galaxy_parameter_distribution.py +817 -885
- ler/lens_galaxy_population/lens_param_data/density_profile_slope_sl.txt +5000 -0
- ler/lens_galaxy_population/lens_param_data/external_shear_sl.txt +2 -0
- ler/lens_galaxy_population/lens_param_data/number_density_zl_zs.txt +48 -0
- ler/lens_galaxy_population/lens_param_data/optical_depth_epl_shear_vd_ewoud.txt +48 -0
- ler/lens_galaxy_population/mp copy.py +554 -0
- ler/lens_galaxy_population/mp.py +736 -138
- ler/lens_galaxy_population/optical_depth.py +2248 -616
- ler/rates/__init__.py +1 -2
- ler/rates/gwrates.py +129 -75
- ler/rates/ler.py +257 -116
- ler/utils/__init__.py +2 -0
- ler/utils/function_interpolation.py +322 -0
- ler/utils/gwsnr_training_data_generator.py +233 -0
- ler/utils/plots.py +1 -1
- ler/utils/test.py +1078 -0
- ler/utils/utils.py +553 -125
- {ler-0.4.1.dist-info → ler-0.4.3.dist-info}/METADATA +22 -9
- ler-0.4.3.dist-info/RECORD +34 -0
- {ler-0.4.1.dist-info → ler-0.4.3.dist-info}/WHEEL +1 -1
- ler/rates/ler copy.py +0 -2097
- ler-0.4.1.dist-info/RECORD +0 -25
- {ler-0.4.1.dist-info → ler-0.4.3.dist-info/licenses}/LICENSE +0 -0
- {ler-0.4.1.dist-info → ler-0.4.3.dist-info}/top_level.txt +0 -0
ler/utils/test.py
ADDED
@@ -0,0 +1,1078 @@
# -*- coding: utf-8 -*-
"""
This module contains helper routines for other modules in the ler package.
"""

import os
import pickle
import h5py
import numpy as np
import json
from scipy.interpolate import interp1d
from scipy.interpolate import CubicSpline
from scipy.integrate import quad, cumtrapz
from numba import njit


class NumpyEncoder(json.JSONEncoder):
    """
    Class for storing a numpy.ndarray or any nested-list composition as a JSON file. This is required for dealing with np.nan and np.inf.

    Parameters
    ----------
    json.JSONEncoder : `class`
        class for encoding JSON file

    Returns
    ----------
    json.JSONEncoder.default : `function`
        function for encoding JSON file

    Example
    ----------
    >>> import numpy as np
    >>> import json
    >>> from ler import helperroutines as hr
    >>> # create a dictionary
    >>> param = {'a': np.array([1, 2, 3]), 'b': np.array([4, 5, 6])}
    >>> # save the dictionary as a json file
    >>> with open('param.json', 'w') as f:
    ...     json.dump(param, f, cls=hr.NumpyEncoder)
    >>> # load the dictionary from the json file
    >>> with open('param.json', 'r') as f:
    ...     param = json.load(f)
    >>> # print the dictionary
    >>> print(param)
    {'a': [1, 2, 3], 'b': [4, 5, 6]}
    """

    def default(self, obj):
        """function for encoding JSON file"""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

def load_pickle(file_name):
    """Load a pickle file.

    Parameters
    ----------
    file_name : `str`
        pickle file name for storing the parameters.

    Returns
    ----------
    param : `dict`
    """
    with open(file_name, "rb") as handle:
        param = pickle.load(handle)

    return param

def save_pickle(file_name, param):
    """Save a dictionary as a pickle file.

    Parameters
    ----------
    file_name : `str`
        pickle file name for storing the parameters.
    param : `dict`
        dictionary to be saved as a pickle file.
    """
    with open(file_name, "wb") as handle:
        pickle.dump(param, handle, protocol=pickle.HIGHEST_PROTOCOL)

# hdf5
def load_hdf5(file_name):
    """Load an hdf5 file.

    Parameters
    ----------
    file_name : `str`
        hdf5 file name for storing the parameters.

    Returns
    ----------
    param : `h5py.File`
        open hdf5 file handle; read datasets back with e.g. f['key'][...].
    """
    return h5py.File(file_name, 'r')

def save_hdf5(file_name, param):
    """Save a dictionary as an hdf5 file.

    Parameters
    ----------
    file_name : `str`
        hdf5 file name for storing the parameters.
    param : `dict`
        dictionary to be saved as an hdf5 file.
    """
    with h5py.File(file_name, 'w') as f:
        for key, value in param.items():
            f.create_dataset(key, data=value)

def load_json(file_name):
    """Load a json file.

    Parameters
    ----------
    file_name : `str`
        json file name for storing the parameters.

    Returns
    ----------
    param : `dict`
    """
    with open(file_name, "r", encoding="utf-8") as f:
        param = json.load(f)

    return param

def save_json(file_name, param):
    """Save a dictionary as a json file.

    Parameters
    ----------
    file_name : `str`
        json file name for storing the parameters.
    param : `dict`
        dictionary to be saved as a json file. Values must already be
        JSON-serializable (use NumpyEncoder or .tolist() for ndarrays).
    """
    with open(file_name, "w", encoding="utf-8") as write_file:
        json.dump(param, write_file)

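A minimal round-trip sketch for the I/O helpers above (illustrative only, not part of the diff; file names are placeholders). Note that save_json does not pass cls=NumpyEncoder, so arrays must be converted to lists first:

import numpy as np

param = {"zs": np.array([0.5, 1.0, 2.0]), "snr": np.array([8.0, 12.0, 20.0])}

save_hdf5("param.h5", param)
f = load_hdf5("param.h5")                  # returns an open h5py.File
zs = f["zs"][...]                          # read a dataset back as an ndarray
f.close()

save_json("param.json", {k: v.tolist() for k, v in param.items()})
print(load_json("param.json")["zs"])       # [0.5, 1.0, 2.0]
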
def append_json(file_name, new_dictionary, old_dictionary=None, replace=False):
    """
    Append (values with corresponding keys) and update a json file with a dictionary. There are four options:

    1. If old_dictionary is provided, the values of the new dictionary will be appended to the old dictionary and saved in the 'file_name' json file.
    2. If replace is True, replace the json file content (at 'file_name') with the new_dictionary.
    3. If the file does not exist, create a new one with the new_dictionary.
    4. If none of the above, append the new dictionary to the content of the json file.

    Parameters
    ----------
    file_name : `str`
        json file name for storing the parameters.
    new_dictionary : `dict`
        dictionary to be appended to the json file.
    old_dictionary : `dict`, optional
        If provided, the values of the new dictionary will be appended to the old dictionary and saved in the 'file_name' json file.
        Default is None.
    replace : `bool`, optional
        If True, replace the json file with the dictionary. Default is False.
    """

    # decide what the base content is
    if old_dictionary:
        data = old_dictionary
    elif replace:
        data = new_dictionary
    elif not os.path.exists(file_name):
        # the file does not exist; create a new one
        replace = True
        data = new_dictionary
    else:
        # get existing data from the file
        with open(file_name, "r", encoding="utf-8") as f:
            data = json.load(f)

    if not replace:
        data = add_dictionaries_together(data, new_dictionary)

    # save the dictionary
    with open(file_name, "w", encoding="utf-8") as write_file:
        json.dump(data, write_file, indent=4, cls=NumpyEncoder)

    return data

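A usage sketch for append_json and get_param_from_json (illustrative only; file names are placeholders):

import numpy as np

batch1 = {"m1": np.array([10.0, 20.0]), "zs": np.array([0.4, 0.9])}
batch2 = {"m1": np.array([30.0]), "zs": np.array([1.5])}

append_json("events.json", batch1, replace=True)   # start a fresh file
append_json("events.json", batch2)                 # appends to the stored content
print(get_param_from_json("events.json")["m1"])    # [10. 20. 30.]
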
def get_param_from_json(json_file):
    """
    Function to get the parameters from a json file.

    Parameters
    ----------
    json_file : `str`
        json file name for storing the parameters.

    Returns
    ----------
    param : `dict`
    """
    with open(json_file, "r", encoding="utf-8") as f:
        param = json.load(f)

    for key, value in param.items():
        param[key] = np.array(value)
    return param

def rejection_sample(pdf, xmin, xmax, size=100, chunk_size=10000):
    """
    Helper function for rejection sampling from a pdf with maximum and minimum arguments.

    Parameters
    ----------
    pdf : `function`
        pdf function.
    xmin : `float`
        minimum value of the pdf.
    xmax : `float`
        maximum value of the pdf.
    size : `int`, optional
        number of samples. Default is 100.
    chunk_size : `int`, optional
        chunk size for sampling. Default is 10000.

    Returns
    ----------
    x_sample : `numpy.ndarray`
        samples from the pdf.
    """
    x = np.linspace(xmin, xmax, chunk_size)
    y = pdf(x)
    # maximum value of the pdf
    ymax = np.max(y)

    # rejection sample in chunks
    x_sample = []
    while len(x_sample) < size:
        x_try = np.random.uniform(xmin, xmax, size=chunk_size)
        pdf_x_try = pdf(x_try)  # pdf at the proposed x values
        # uniform heights, compared with pdf(x_try) to accept or reject each sample
        y_try = np.random.uniform(0, ymax, size=chunk_size)

        # update the maximum value of the pdf
        ymax = max(ymax, np.max(pdf_x_try))

        # accept the samples below the pdf, while retaining the 1D shape of the list
        x_sample += list(x_try[y_try < pdf_x_try])

    # transform the samples to a 1D numpy array
    x_sample = np.array(x_sample).flatten()
    # return the correct number of samples
    return x_sample[:size]

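A minimal sketch of rejection_sample (illustrative only; the pdf need not be normalized since only its shape sets the acceptance ratio):

import numpy as np

pdf = lambda x: np.exp(-0.5 * ((x - 2.0) / 0.3) ** 2)   # Gaussian bump at 2
samples = rejection_sample(pdf, xmin=1.0, xmax=3.0, size=5000)
print(samples.shape, samples.mean())                    # (5000,), mean near 2.0
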
def rejection_sample2d(pdf, xmin, xmax, ymin, ymax, size=100, chunk_size=10000):
    """
    Helper function for rejection sampling from a 2D pdf with maximum and minimum arguments.

    Parameters
    ----------
    pdf : `function`
        2D pdf function.
    xmin : `float`
        minimum value of the pdf in the x-axis.
    xmax : `float`
        maximum value of the pdf in the x-axis.
    ymin : `float`
        minimum value of the pdf in the y-axis.
    ymax : `float`
        maximum value of the pdf in the y-axis.
    size : `int`, optional
        number of samples. Default is 100.
    chunk_size : `int`, optional
        chunk size for sampling. Default is 10000.

    Returns
    ----------
    x_sample : `numpy.ndarray`
        samples from the pdf in the x-axis.
    y_sample : `numpy.ndarray`
        samples from the pdf in the y-axis.
    """
    x = np.random.uniform(xmin, xmax, chunk_size)
    y = np.random.uniform(ymin, ymax, chunk_size)
    z = pdf(x, y)
    # maximum value of the pdf
    zmax = np.max(z)

    # rejection sample in chunks
    x_sample = []
    y_sample = []
    while len(x_sample) < size:
        x_try = np.random.uniform(xmin, xmax, size=chunk_size)
        y_try = np.random.uniform(ymin, ymax, size=chunk_size)
        pdf_xy_try = pdf(x_try, y_try)
        # uniform heights, compared with pdf(x_try, y_try) to accept or reject each sample
        z_try = np.random.uniform(0, zmax, size=chunk_size)

        # update the maximum value of the pdf
        zmax = max(zmax, np.max(pdf_xy_try))

        x_sample += list(x_try[z_try < pdf_xy_try])
        y_sample += list(y_try[z_try < pdf_xy_try])

    # transform the samples to 1D numpy arrays
    x_sample = np.array(x_sample).flatten()
    y_sample = np.array(y_sample).flatten()
    # return the correct number of samples
    return x_sample[:size], y_sample[:size]

def add_dictionaries_together(dictionary1, dictionary2):
    """
    Adds two dictionaries with the same keys together.

    Parameters
    ----------
    dictionary1 : `dict`
        dictionary to be added.
    dictionary2 : `dict`
        dictionary to be added.

    Returns
    ----------
    dictionary : `dict`
        dictionary with added values.
    """
    dictionary = {}
    # if either dictionary is empty, return the one with values
    if len(dictionary1) == 0:
        return dictionary2
    elif len(dictionary2) == 0:
        return dictionary1
    # check if the keys are the same
    if dictionary1.keys() != dictionary2.keys():
        raise ValueError("The dictionaries have different keys.")
    for key in dictionary1.keys():
        value1 = dictionary1[key]
        value2 = dictionary2[key]

        # check if either value is empty
        bool0 = len(value1) == 0 or len(value2) == 0
        # check whether the values are ndarrays, lists or dicts
        bool1 = isinstance(value1, np.ndarray) and isinstance(value2, np.ndarray)
        bool2 = isinstance(value1, list) and isinstance(value2, list)
        bool3 = isinstance(value1, np.ndarray) and isinstance(value2, list)
        bool4 = isinstance(value1, list) and isinstance(value2, np.ndarray)
        bool4 = bool4 or bool3
        bool5 = isinstance(value1, dict) and isinstance(value2, dict)

        if bool0:
            if len(value1) == 0 and len(value2) == 0:
                dictionary[key] = np.array([])
            elif len(value1) != 0 and len(value2) == 0:
                dictionary[key] = np.array(value1)
            elif len(value1) == 0 and len(value2) != 0:
                dictionary[key] = np.array(value2)
        elif bool1:
            dictionary[key] = np.concatenate((value1, value2))
        elif bool2:
            dictionary[key] = value1 + value2
        elif bool4:
            dictionary[key] = np.concatenate((np.array(value1), np.array(value2)))
        elif bool5:
            dictionary[key] = add_dictionaries_together(
                dictionary1[key], dictionary2[key]
            )
        else:
            raise ValueError(
                "The dictionary contains an item which is neither an ndarray, a list nor a dictionary."
            )
    return dictionary

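A small sketch of the merge semantics above, including the recursive nested-dictionary case (illustrative only):

import numpy as np

d1 = {"lensed": {"zs": np.array([1.0])}, "ids": [0]}
d2 = {"lensed": {"zs": np.array([2.0])}, "ids": [1]}
merged = add_dictionaries_together(d1, d2)
print(merged["lensed"]["zs"], merged["ids"])   # [1. 2.] [0, 1]
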
def trim_dictionary(dictionary, size):
    """
    Trims each array in an event dictionary down to the given size.

    Parameters
    ----------
    dictionary : `dict`
        dictionary to be trimmed.
    size : `int`
        size to trim the dictionary to.

    Returns
    ----------
    dictionary : `dict`
        trimmed dictionary.
    """
    for key in dictionary.keys():
        # check if the item is an ndarray
        if isinstance(dictionary[key], np.ndarray):
            dictionary[key] = dictionary[key][:size]  # trim the array
        # check if the item is a nested dictionary
        elif isinstance(dictionary[key], dict):
            dictionary[key] = trim_dictionary(
                dictionary[key], size
            )  # trim the nested dictionary
        else:
            raise ValueError(
                "The dictionary contains an item which is neither an ndarray nor a dictionary."
            )
    return dictionary

def create_func_pdf_invcdf(x, y, category="function"):
|
|
448
|
+
"""
|
|
449
|
+
Function to create a interpolated function, inverse function or inverse cdf from the input x and y.
|
|
450
|
+
|
|
451
|
+
Parameters
|
|
452
|
+
----------
|
|
453
|
+
x : `numpy.ndarray`
|
|
454
|
+
x values. This has to sorted in ascending order.
|
|
455
|
+
y : `numpy.ndarray`
|
|
456
|
+
y values. Corresponding to the x values.
|
|
457
|
+
category : `str`, optional
|
|
458
|
+
category of the function. Default is "function". Other options are "function_inverse", "pdf" and "inv_cdf".
|
|
459
|
+
|
|
460
|
+
Returns
|
|
461
|
+
----------
|
|
462
|
+
pdf : `pdf function`
|
|
463
|
+
interpolated pdf function.
|
|
464
|
+
inv_pdf : `function inverse`
|
|
465
|
+
interpolated inverse pdf function.
|
|
466
|
+
inv_cdf : `function`
|
|
467
|
+
interpolated inverse cdf.
|
|
468
|
+
"""
|
|
469
|
+
|
|
470
|
+
idx = np.argwhere(np.isnan(y))
|
|
471
|
+
x = np.delete(x, idx)
|
|
472
|
+
y = np.delete(y, idx)
|
|
473
|
+
|
|
474
|
+
# create pdf with interpolation
|
|
475
|
+
pdf_unorm = interp1d(x, y, kind="cubic", fill_value="extrapolate")
|
|
476
|
+
if category == "function":
|
|
477
|
+
return pdf_unorm
|
|
478
|
+
if category == "function_inverse":
|
|
479
|
+
# create inverse function
|
|
480
|
+
return interp1d(y, x, kind="cubic", fill_value="extrapolate")
|
|
481
|
+
|
|
482
|
+
min_, max_ = min(x), max(x)
|
|
483
|
+
norm = quad(pdf_unorm, min_, max_)[0]
|
|
484
|
+
y = y / norm
|
|
485
|
+
if category == "pdf" or category is None:
|
|
486
|
+
# normalize the pdf
|
|
487
|
+
pdf = interp1d(x, y, kind="cubic", fill_value="extrapolate")
|
|
488
|
+
return pdf
|
|
489
|
+
# cdf
|
|
490
|
+
cdf_values = cumtrapz(y, x, initial=0)
|
|
491
|
+
idx = np.argwhere(cdf_values > 0)[0][0]
|
|
492
|
+
cdf_values = cdf_values[idx:]
|
|
493
|
+
x = x[idx:]
|
|
494
|
+
inv_cdf = interp1d(cdf_values, x, kind="cubic", fill_value="extrapolate")
|
|
495
|
+
if category == "inv_cdf":
|
|
496
|
+
return inv_cdf
|
|
497
|
+
if category == "all":
|
|
498
|
+
return([pdf, inv_cdf])
|
|
499
|
+
|
|
500
|
+
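An inverse-cdf sampling sketch for create_func_pdf_invcdf (illustrative only; the pdf shape is arbitrary and need not be pre-normalized):

import numpy as np

x = np.linspace(0.01, 10.0, 500)
y = x**2 * np.exp(-x)                         # unnormalized pdf shape
inv_cdf = create_func_pdf_invcdf(x, y, category="inv_cdf")
u = np.random.uniform(0, 1, size=10000)
samples = inv_cdf(u)                          # inverse-transform sampling
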
def create_conditioned_pdf_invcdf(x, conditioned_y, pdf_func, category):
    """
    Function to create a list of interpolated functions, pdfs or inverse cdfs of x, one for each conditioning value y.
    pdf_func calculates the pdf of x given y; x is an array and the output of pdf_func is an array, so the (x, y) parameter plane is covered row by row.

    Parameters
    ----------
    x : `numpy.ndarray`
        x values.
    conditioned_y : `numpy.ndarray`
        conditioned y values.
    pdf_func : `function`
        function to calculate the pdf of x given y.
    category : `str`, optional
        category of the function. Default is "function". Other options are "function_inverse", "pdf" and "inv_cdf".
    """
    list_ = []
    for y in conditioned_y:
        phi = pdf_func(x, y)
        # append the interpolator for each y along the x-axis
        list_.append(create_func_pdf_invcdf(x, phi, category=category))

    return list_

def create_func(x, y):
    """
    Function to create spline-interpolation coefficients from the input x and y.

    Parameters
    ----------
    x : `numpy.ndarray`
        x values.
    y : `numpy.ndarray`
        y values.

    Returns
    ----------
    c : `numpy.ndarray`
        spline coefficients.
    x : `numpy.ndarray`
        x knots (with NaN entries removed).
    """
    idx = np.argwhere(np.isnan(y))
    x = np.delete(x, idx)
    y = np.delete(y, idx)
    return CubicSpline(x, y).c, x

def create_func_inv(x, y):
    """
    Function to create spline-interpolation coefficients of the inverse function from the input x and y.

    Parameters
    ----------
    x : `numpy.ndarray`
        x values.
    y : `numpy.ndarray`
        y values.

    Returns
    ----------
    c : `numpy.ndarray`
        spline coefficients.
    y : `numpy.ndarray`
        y knots (with NaN entries removed).
    """
    idx = np.argwhere(np.isnan(y))
    x = np.delete(x, idx)
    y = np.delete(y, idx)
    return CubicSpline(y, x).c, y

def create_pdf(x, y):
    """
    Function to create spline-interpolation coefficients of the normalized pdf from the input x and y.

    Parameters
    ----------
    x : `numpy.ndarray`
        x values.
    y : `numpy.ndarray`
        y values.

    Returns
    ----------
    c : `numpy.ndarray`
        spline coefficients.
    x : `numpy.ndarray`
        x knots (with NaN entries removed).
    """
    idx = np.argwhere(np.isnan(y))
    x = np.delete(x, idx)
    y = np.delete(y, idx)
    pdf_unorm = interp1d(x, y, kind="cubic", fill_value="extrapolate")
    min_, max_ = min(x), max(x)
    norm = quad(pdf_unorm, min_, max_)[0]
    y = y / norm
    return CubicSpline(x, y).c, x

def create_inv_cdf_array(x, y):
    """
    Function to create an inverse-cdf lookup array from the input x and y.

    Parameters
    ----------
    x : `numpy.ndarray`
        x values.
    y : `numpy.ndarray`
        y values.

    Returns
    ----------
    array : `numpy.ndarray`
        2D array [cdf_values, x], with the cdf normalized to end at 1.
    """
    idx = np.argwhere(np.isnan(y))
    x = np.delete(x, idx)
    y = np.delete(y, idx)
    cdf_values = cumtrapz(y, x, initial=0)
    cdf_values = cdf_values / cdf_values[-1]
    return np.array([cdf_values, x])

def create_conditioned_pdf(x, conditioned_y, pdf_func):
    """
    Function to create conditioned spline pdf coefficients from the input x and conditioning values.

    Parameters
    ----------
    x : `numpy.ndarray`
        x values.
    conditioned_y : `numpy.ndarray`
        conditioned y values.
    pdf_func : `function`
        function to calculate the pdf of x given y.

    Returns
    ----------
    array : `numpy.ndarray`
        array of (spline coefficients, knots) pairs, one per conditioned y.
    """
    list_ = []
    for y in conditioned_y:
        phi = pdf_func(x, y)
        list_.append(create_pdf(x, phi))

    return np.array(list_)

def create_conditioned_inv_cdf_array(x, conditioned_y, pdf_func):
    """
    Function to create conditioned inverse-cdf lookup arrays from the input x and conditioning values.

    Parameters
    ----------
    x : `numpy.ndarray`
        x values.
    conditioned_y : `numpy.ndarray`
        conditioned y values.
    pdf_func : `function`
        function to calculate the pdf of x given y.

    Returns
    ----------
    array : `numpy.ndarray`
        array of [cdf_values, x] lookup arrays, one per conditioned y.
    """
    list_ = []
    for y in conditioned_y:
        phi = pdf_func(x, y)
        list_.append(create_inv_cdf_array(x, phi))

    return np.array(list_)

def interpolator_from_pickle(
    param_dict_given, directory, sub_directory, name, x, pdf_func=None, y=None, conditioned_y=None, dimension=1, category="pdf", create_new=False
):
    """
    Function to decide which interpolator to use: load a pickled one if the parameters match, otherwise generate and pickle a new one.

    Parameters
    ----------
    param_dict_given : `dict`
        dictionary of parameters.
    directory : `str`
        directory to store the interpolator.
    sub_directory : `str`
        sub-directory to store the interpolator.
    name : `str`
        name of the interpolator.
    x : `numpy.ndarray`
        x values.
    pdf_func : `function`
        function to calculate the pdf of x given y.
    y : `numpy.ndarray`
        y values.
    conditioned_y : `numpy.ndarray`
        conditioned y values.
    dimension : `int`
        dimension of the interpolator. Default is 1.
    category : `str`
        category of the function. Default is "pdf".
    create_new : `bool`
        if True, create a new interpolator. Default is False.

    Returns
    ----------
    interpolator : `function`
        interpolator function.
    """
    # check first whether the directory, sub-directory and pickle exist
    path_inv_cdf, it_exist = interpolator_pickle_path(
        param_dict_given=param_dict_given,
        directory=directory,
        sub_directory=sub_directory,
        interpolator_name=name,
    )
    if create_new:
        it_exist = False
    if it_exist:
        print(f"{name} interpolator will be loaded from {path_inv_cdf}")
        # load the interpolator
        with open(path_inv_cdf, "rb") as handle:
            interpolator = pickle.load(handle)
        return interpolator
    else:
        print(f"{name} interpolator will be generated at {path_inv_cdf}")

        # create the interpolator
        if dimension == 1:
            if y is None:
                y = pdf_func(x)
            if category == "function":
                interpolator = create_func(x, y)
            elif category == "function_inverse":
                interpolator = create_func_inv(x, y)
            elif category == "pdf":
                interpolator = create_pdf(x, y)
            elif category == "inv_cdf":
                interpolator = create_inv_cdf_array(x, y)
            else:
                raise ValueError("The category given should be function, function_inverse, pdf or inv_cdf.")
        elif dimension == 2:
            if category == "pdf":
                interpolator = create_conditioned_pdf(x, conditioned_y, pdf_func)
            elif category == "inv_cdf":
                interpolator = create_conditioned_inv_cdf_array(x, conditioned_y, pdf_func)
            else:
                raise ValueError("The category given should be pdf or inv_cdf.")
        else:
            raise ValueError("The dimension is not supported.")
        # save the interpolator
        with open(path_inv_cdf, "wb") as handle:
            pickle.dump(interpolator, handle, protocol=pickle.HIGHEST_PROTOCOL)
        return interpolator

def interpolator_pickle_path(
    param_dict_given,
    directory,
    sub_directory,
    interpolator_name,
):
    """
    Function to create the interpolator pickle file path.

    Parameters
    ----------
    param_dict_given : `dict`
        dictionary of parameters.
    directory : `str`
        directory to store the interpolator.
    sub_directory : `str`
        sub-directory to store the interpolator.
    interpolator_name : `str`
        name of the interpolator.

    Returns
    ----------
    path_inv_cdf : `str`
        path of the interpolator pickle file.
    it_exist : `bool`
        if True, the interpolator exists.
    """
    # make sure the interpolator directory tree exists
    full_dir = directory + "/" + sub_directory
    if not os.path.exists(directory):
        os.makedirs(directory)
        os.makedirs(full_dir)
    else:
        if not os.path.exists(full_dir):
            os.makedirs(full_dir)

    # check if the list of stored parameter dicts exists
    path1 = full_dir + "/init_dict.pickle"
    if not os.path.exists(path1):
        dict_list = []
        with open(path1, "wb") as handle:
            pickle.dump(dict_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # check if the input dict matches one of the dicts inside the pickle file
    param_dict_stored = pickle.load(open(path1, "rb"))

    path2 = full_dir
    len_ = len(param_dict_stored)
    if param_dict_given in param_dict_stored:
        idx = param_dict_stored.index(param_dict_given)
        # load the interpolator
        path_inv_cdf = path2 + "/" + interpolator_name + "_" + str(idx) + ".pickle"
        # the file may have been deleted by mistake
        if os.path.exists(path_inv_cdf):
            it_exist = True
        else:
            it_exist = False
    else:
        it_exist = False
        path_inv_cdf = path2 + "/" + interpolator_name + "_" + str(len_) + ".pickle"
        param_dict_stored.append(param_dict_given)
        with open(path1, "wb") as handle:
            pickle.dump(param_dict_stored, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return path_inv_cdf, it_exist

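A usage sketch for the caching pair above (illustrative only; the directory, name and parameter dict are placeholders). On the first call the spline is generated and pickled; a second call with the same param_dict_given loads it from disk:

import numpy as np

toy_pdf = lambda x: np.exp(-x)  # unnormalized shape; create_pdf normalizes it

coeffs, knots = interpolator_from_pickle(
    param_dict_given={"model": "toy", "scale": 1.0},
    directory="./interpolator_pickle",
    sub_directory="toy_pdf",
    name="toy_pdf",
    x=np.linspace(0.01, 5.0, 200),
    pdf_func=toy_pdf,
    dimension=1,
    category="pdf",
)
# coeffs/knots plug directly into cubic_spline_interpolator (defined below)
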
def interpolator_pdf_conditioned(x, conditioned_y, y_array, interpolator_list):
    """
    Function to evaluate the pdf interpolator that corresponds to the conditioned y.

    Parameters
    ----------
    x : `numpy.ndarray`
        x values.
    conditioned_y : `float`
        conditioned y value.
    y_array : `numpy.ndarray`
        y values.
    interpolator_list : `list`
        list of interpolators.

    Returns
    ----------
    interpolator_list[idx](x) : `numpy.ndarray`
        pdf values from the selected interpolator.
    """
    # find the index of conditioned_y in y_array
    idx = np.searchsorted(y_array, conditioned_y)

    return interpolator_list[idx](x)

def interpolator_sampler_conditioned(conditioned_y, y_array, interpolator_list, size=1000):
    """
    Function to sample from the inverse-cdf interpolator that corresponds to the conditioned y.

    Parameters
    ----------
    conditioned_y : `float`
        conditioned y value.
    y_array : `numpy.ndarray`
        y values.
    interpolator_list : `list`
        list of interpolators (inverse cdfs).
    size : `int`
        number of samples.

    Returns
    ----------
    samples : `numpy.ndarray`
        samples drawn from the selected inverse cdf.
    """
    # find the index of conditioned_y in y_array
    idx = np.searchsorted(y_array, conditioned_y)
    u = np.random.uniform(0, 1, size=size)
    return interpolator_list[idx](u)

@njit
def cubic_spline_interpolator(xnew, coefficients, x):
    """
    Function to interpolate using a cubic spline.

    Parameters
    ----------
    xnew : `numpy.ndarray`
        new x values.
    coefficients : `numpy.ndarray`
        coefficients of the cubic spline, shape (4, len(x) - 1), highest order first, as returned by scipy's CubicSpline.
    x : `numpy.ndarray`
        x knots.

    Returns
    ----------
    result : `numpy.ndarray`
        interpolated values.
    """
    # handling extrapolation: clamp to the first/last interval
    i = np.searchsorted(x, xnew) - 1
    idx1 = xnew <= x[0]
    idx2 = xnew > x[-1]
    i[idx1] = 0
    i[idx2] = len(x) - 2

    # calculate the relative position within the interval
    dx = xnew - x[i]

    # evaluate the cubic polynomial; scipy stores coefficients highest order
    # first, so the value is a*dx**3 + b*dx**2 + c*dx + d
    a, b, c, d = coefficients[:, i]
    result = d + c*dx + b*dx**2 + a*dx**3
    return result

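A quick check of the spline evaluator against the coefficients produced by create_func above (illustrative only):

import numpy as np

x = np.linspace(0.0, np.pi, 100)
coeffs, knots = create_func(x, np.sin(x))              # scipy CubicSpline coefficients
xnew = np.array([0.3, 1.2, 2.5])
print(cubic_spline_interpolator(xnew, coeffs, knots))  # close to np.sin(xnew)
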
@njit
def inverse_transform_sampler(size, cdf, x):
    """
    Function to sample using the inverse transform method.

    Parameters
    ----------
    size : `int`
        number of samples.
    cdf : `numpy.ndarray`
        cdf values.
    x : `numpy.ndarray`
        x values.

    Returns
    ----------
    samples : `numpy.ndarray`
        samples from the cdf.
    """
    u = np.random.uniform(0, 1, size)
    idx = np.searchsorted(cdf, u)
    # linear interpolation between the two bracketing cdf points
    x1, x0, y1, y0 = cdf[idx], cdf[idx-1], x[idx], x[idx-1]
    samples = y0 + (y1 - y0) * (u - x0) / (x1 - x0)
    return samples

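An inverse-transform sampling sketch wiring create_inv_cdf_array into the sampler above (illustrative only):

import numpy as np

x = np.linspace(0.01, 10.0, 500)
cdf_values, x_knots = create_inv_cdf_array(x, x**2 * np.exp(-x))
samples = inverse_transform_sampler(10000, cdf_values, x_knots)
print(samples.mean())   # close to 3.0, the mean of the Gamma(3, 1) shape
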
def batch_handler(size, batch_size, sampling_routine, output_jsonfile, save_batch=True, resume=False, param_name='parameters'):
    """
    Function to run the sampling in batches.

    Parameters
    ----------
    size : `int`
        number of samples.
    batch_size : `int`
        batch size.
    sampling_routine : `function`
        sampling function. It should have 'size' as input and return a dictionary.
    output_jsonfile : `str`
        json file name for storing the parameters.
    save_batch : `bool`, optional
        if True, save sampled parameters in each iteration. Default is True.
    resume : `bool`, optional
        if True, resume sampling from the last batch. Default is False.
    param_name : `str`, optional
        name of the parameter. Default is 'parameters'.

    Returns
    ----------
    dict_buffer : `dict`
        dictionary of parameters.
    """
    # sampling in batches
    if resume and os.path.exists(output_jsonfile):
        # get the samples from the json file
        dict_buffer = get_param_from_json(output_jsonfile)
    else:
        dict_buffer = None

    # number of batches, rounding up if size is not a multiple of batch_size
    if size % batch_size == 0:
        num_batches = size // batch_size
    else:
        num_batches = size // batch_size + 1

    print(f"chosen batch size = {batch_size} with total size = {size}")
    print(f"There will be {num_batches} batch(es)")

    # note: frac_batches + (num_batches - 1) * batch_size = size
    if size > batch_size:
        frac_batches = size - (num_batches - 1) * batch_size
    # if size is less than batch_size
    else:
        frac_batches = size
    track_batches = 0  # to track the number of batches

    if not resume:
        # create the new first batch with frac_batches samples
        track_batches, dict_buffer = create_batch_params(sampling_routine, frac_batches, dict_buffer, save_batch, output_jsonfile, track_batches=track_batches)
    else:
        # check where to resume from:
        # identify the last batch and assign the current batch number.
        # try-except avoids errors when the file does not exist, is empty,
        # is corrupted or does not have the required key.
        try:
            print(f"resuming from {output_jsonfile}")
            len_ = len(list(dict_buffer.values())[0])
            track_batches = (len_ - frac_batches) // batch_size + 1
        except Exception:
            # create the new first batch with frac_batches samples
            track_batches, dict_buffer = create_batch_params(sampling_routine, frac_batches, dict_buffer, save_batch, output_jsonfile, track_batches=track_batches)

    # loop over the remaining batches
    min_, max_ = track_batches, num_batches
    save_param = False
    if min_ == max_:
        print(f"{param_name} already sampled.")
    elif min_ > max_:
        len_ = len(list(dict_buffer.values())[0])
        print(f"existing {param_name} size ({len_}) is more than the required size ({size}). It will be trimmed.")
        dict_buffer = trim_dictionary(dict_buffer, size)
        save_param = True
    else:
        for i in range(min_, max_):
            _, dict_buffer = create_batch_params(sampling_routine, batch_size, dict_buffer, save_batch, output_jsonfile, track_batches=i, resume=True)

        if save_batch:
            # if save_batch=True, dict_buffer holds only the last batch,
            # so reload the full set from the json file
            dict_buffer = get_param_from_json(output_jsonfile)
        else:  # don't save in batches
            # this condition is required even if there is nothing new to save
            save_param = True

    if save_param:
        # store all params in the json file
        print(f"saving all {param_name} in {output_jsonfile} ")
        append_json(output_jsonfile, dict_buffer, replace=True)

    return dict_buffer

def create_batch_params(sampling_routine, frac_batches, dict_buffer, save_batch, output_jsonfile, track_batches, resume=False):
    """
    Helper function for batch_handler. It creates batch parameters and stores them in a dictionary.

    Parameters
    ----------
    sampling_routine : `function`
        sampling function. It should have 'size' as input and return a dictionary.
    frac_batches : `int`
        batch size.
    dict_buffer : `dict`
        dictionary of parameters.
    save_batch : `bool`
        if True, save sampled parameters in each iteration.
    output_jsonfile : `str`
        json file name for storing the parameters.
    track_batches : `int`
        track the number of batches.
    resume : `bool`, optional
        if True, resume sampling from the last batch. Default is False.

    Returns
    ----------
    track_batches : `int`
        updated number of batches.
    dict_buffer : `dict`
        updated dictionary of parameters.
    """
    track_batches = track_batches + 1
    print(f"Batch no. {track_batches}")
    param = sampling_routine(size=frac_batches, save_batch=save_batch, output_jsonfile=output_jsonfile, resume=resume)

    # add the batch and hold it in the buffer
    if not save_batch:
        # in a fresh sampling run, dict_buffer starts as None
        if dict_buffer is None:
            dict_buffer = param
        else:
            for key, value in param.items():
                dict_buffer[key] = np.concatenate((dict_buffer[key], value))
    else:
        # store all params in the json file
        dict_buffer = append_json(file_name=output_jsonfile, new_dictionary=param, old_dictionary=dict_buffer, replace=not (resume))

    return track_batches, dict_buffer
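A toy end-to-end sketch of the batching machinery above (illustrative only; toy_sampler and the output file name are made up for the example):

import numpy as np

def toy_sampler(size, save_batch=False, output_jsonfile=None, resume=False):
    # any routine that takes 'size' and returns a dict of equal-length arrays works
    return {"m1": np.random.uniform(5, 50, size), "zs": np.random.uniform(0, 5, size)}

params = batch_handler(size=25000, batch_size=10000, sampling_routine=toy_sampler,
                       output_jsonfile="toy_params.json", save_batch=False)
print(len(params["m1"]))  # 25000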