jaxspec 0.2.2.dev0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/analysis/_plot.py +5 -5
- jaxspec/analysis/results.py +38 -25
- jaxspec/data/obsconf.py +9 -3
- jaxspec/data/observation.py +3 -1
- jaxspec/data/ogip.py +9 -2
- jaxspec/data/util.py +17 -11
- jaxspec/experimental/interpolator.py +74 -0
- jaxspec/experimental/interpolator_jax.py +79 -0
- jaxspec/experimental/intrument_models.py +159 -0
- jaxspec/experimental/nested_sampler.py +78 -0
- jaxspec/experimental/tabulated.py +264 -0
- jaxspec/fit/__init__.py +3 -0
- jaxspec/{fit.py → fit/_bayesian_model.py} +86 -338
- jaxspec/{_fit → fit}/_build_model.py +42 -6
- jaxspec/fit/_fitter.py +255 -0
- jaxspec/model/abc.py +52 -80
- jaxspec/model/additive.py +14 -5
- jaxspec/model/background.py +17 -14
- jaxspec/model/instrument.py +81 -0
- jaxspec/model/list.py +4 -1
- jaxspec/model/multiplicative.py +32 -12
- jaxspec/util/integrate.py +17 -5
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/METADATA +9 -9
- jaxspec-0.3.0.dist-info/RECORD +42 -0
- jaxspec-0.2.2.dev0.dist-info/RECORD +0 -34
- /jaxspec/{_fit → experimental}/__init__.py +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/WHEEL +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/entry_points.txt +0 -0
- {jaxspec-0.2.2.dev0.dist-info → jaxspec-0.3.0.dist-info}/licenses/LICENSE.md +0 -0
jaxspec/analysis/_plot.py
CHANGED
|
@@ -59,8 +59,8 @@ def _plot_poisson_data_with_error(
|
|
|
59
59
|
y,
|
|
60
60
|
xerr=np.abs(x_bins - np.sqrt(x_bins[0] * x_bins[1])),
|
|
61
61
|
yerr=[
|
|
62
|
-
y - y_low,
|
|
63
|
-
y_high - y,
|
|
62
|
+
np.maximum(y - y_low, 0),
|
|
63
|
+
np.maximum(y_high - y, 0),
|
|
64
64
|
],
|
|
65
65
|
color=color,
|
|
66
66
|
linestyle=linestyle,
|
|
@@ -149,13 +149,13 @@ def _compute_effective_area(
|
|
|
149
149
|
mid_bins_arf = obsconf.in_energies.mean(axis=0) * u.keV
|
|
150
150
|
mid_bins_arf = mid_bins_arf.to(x_unit, u.spectral())
|
|
151
151
|
e_grid = np.linspace(*xbins, 10)
|
|
152
|
-
interpolated_arf = np.interp(e_grid, mid_bins_arf, obsconf.area)
|
|
152
|
+
interpolated_arf = np.interp(e_grid.value, mid_bins_arf.value, obsconf.area)
|
|
153
153
|
integrated_arf = (
|
|
154
|
-
trapezoid(interpolated_arf, x=e_grid, axis=0)
|
|
154
|
+
trapezoid(interpolated_arf, x=e_grid.value, axis=0)
|
|
155
155
|
/ (
|
|
156
156
|
np.abs(
|
|
157
157
|
xbins[1] - xbins[0]
|
|
158
|
-
) # Must fold in abs because some units reverse the ordering of the bins
|
|
158
|
+
).value # Must fold in abs because some units reverse the ordering of the bins
|
|
159
159
|
)
|
|
160
160
|
* u.cm**2
|
|
161
161
|
)
|
jaxspec/analysis/results.py
CHANGED
|
@@ -42,6 +42,11 @@ V = TypeVar("V")
|
|
|
42
42
|
T = TypeVar("T")
|
|
43
43
|
|
|
44
44
|
|
|
45
|
+
def auto_in_axes(pytree, axis=0):
    """Build a ``jax.vmap``-style ``in_axes`` specification mirroring ``pytree``.

    A leaf that exposes an ``ndim`` attribute greater than zero is treated as batched and
    mapped to ``axis``. Every other leaf (0-d arrays, Python scalars, objects without
    ``ndim``) is mapped to ``None``.

    Parameters:
        pytree: Any JAX pytree.
        axis: Axis index assigned to batched leaves.

    Returns:
        A pytree with the same structure whose leaves are ``axis`` or ``None``.
    """

    def _leaf_axis(leaf):
        if not hasattr(leaf, "ndim"):
            return None
        return axis if leaf.ndim > 0 else None

    return jax.tree.map(_leaf_axis, pytree)
|
|
48
|
+
|
|
49
|
+
|
|
45
50
|
class FitResult:
|
|
46
51
|
"""
|
|
47
52
|
Container for the result of a fit using any ModelFitter class.
|
|
@@ -54,17 +59,17 @@ class FitResult:
|
|
|
54
59
|
inference_data: az.InferenceData,
|
|
55
60
|
background_model: BackgroundModel = None,
|
|
56
61
|
):
|
|
57
|
-
self.model = bayesian_fitter.
|
|
62
|
+
self.model = bayesian_fitter.spectral_model
|
|
58
63
|
self.bayesian_fitter = bayesian_fitter
|
|
59
64
|
self.inference_data = inference_data
|
|
60
|
-
self.obsconfs = bayesian_fitter.
|
|
65
|
+
self.obsconfs = bayesian_fitter._observation_container
|
|
61
66
|
self.background_model = background_model
|
|
62
67
|
|
|
63
68
|
# Add the model used in fit to the metadata
|
|
64
69
|
for group in self.inference_data.groups():
|
|
65
70
|
group_name = group.split("/")[-1]
|
|
66
71
|
metadata = getattr(self.inference_data, group_name).attrs
|
|
67
|
-
metadata["model"] = str(self.model)
|
|
72
|
+
# metadata["model"] = str(self.model)
|
|
68
73
|
# TODO : Store metadata about observations used in the fitting process
|
|
69
74
|
|
|
70
75
|
@property
|
|
@@ -78,6 +83,7 @@ class FitResult:
|
|
|
78
83
|
def _ppc_folded_branches(self, obs_id):
|
|
79
84
|
obs = self.obsconfs[obs_id]
|
|
80
85
|
|
|
86
|
+
# Slice the parameters corresponding to the current ObsID
|
|
81
87
|
if len(next(iter(self.input_parameters.values())).shape) > 2:
|
|
82
88
|
idx = list(self.obsconfs.keys()).index(obs_id)
|
|
83
89
|
obs_parameters = jax.tree.map(lambda x: x[..., idx], self.input_parameters)
|
|
@@ -85,7 +91,7 @@ class FitResult:
|
|
|
85
91
|
else:
|
|
86
92
|
obs_parameters = self.input_parameters
|
|
87
93
|
|
|
88
|
-
if self.bayesian_fitter.sparse:
|
|
94
|
+
if self.bayesian_fitter.settings.get("sparse", False):
|
|
89
95
|
transfer_matrix = BCOO.from_scipy_sparse(
|
|
90
96
|
obs.transfer_matrix.data.to_scipy_sparse().tocsr()
|
|
91
97
|
)
|
|
@@ -98,6 +104,7 @@ class FitResult:
|
|
|
98
104
|
flux_func = jax.jit(
|
|
99
105
|
jax.vmap(jax.vmap(lambda p: self.model.photon_flux(p, *energies, split_branches=True)))
|
|
100
106
|
)
|
|
107
|
+
|
|
101
108
|
convolve_func = jax.jit(
|
|
102
109
|
jax.vmap(jax.vmap(lambda flux: jnp.clip(transfer_matrix @ flux, a_min=1e-6)))
|
|
103
110
|
)
|
|
@@ -124,13 +131,14 @@ class FitResult:
|
|
|
124
131
|
|
|
125
132
|
for key, value in input_parameters.items():
|
|
126
133
|
module, parameter = key.rsplit("_", 1)
|
|
134
|
+
key_to_search = f"mod/~/{module}_{parameter}"
|
|
127
135
|
|
|
128
|
-
if
|
|
136
|
+
if key_to_search in posterior.keys():
|
|
129
137
|
# We add as extra dimension as there might be different values per observation
|
|
130
|
-
if posterior[
|
|
131
|
-
to_set = posterior[
|
|
138
|
+
if posterior[key_to_search].shape == samples_shape:
|
|
139
|
+
to_set = posterior[key_to_search][..., None]
|
|
132
140
|
else:
|
|
133
|
-
to_set = posterior[
|
|
141
|
+
to_set = posterior[key_to_search]
|
|
134
142
|
|
|
135
143
|
input_parameters[f"{module}_{parameter}"] = to_set
|
|
136
144
|
|
|
@@ -299,7 +307,7 @@ class FitResult:
|
|
|
299
307
|
|
|
300
308
|
return value
|
|
301
309
|
|
|
302
|
-
def to_chain(self, name: str) -> Chain:
|
|
310
|
+
def to_chain(self, name: str, parameter_kind="mod") -> Chain:
|
|
303
311
|
"""
|
|
304
312
|
Return a ChainConsumer Chain object from the posterior distribution of the parameters_type.
|
|
305
313
|
|
|
@@ -308,9 +316,7 @@ class FitResult:
|
|
|
308
316
|
"""
|
|
309
317
|
|
|
310
318
|
keys_to_drop = [
|
|
311
|
-
key
|
|
312
|
-
for key in self.inference_data.posterior.keys()
|
|
313
|
-
if (key.startswith("_") or key.startswith("bkg"))
|
|
319
|
+
key for key in self.inference_data.posterior.keys() if not key.startswith("mod")
|
|
314
320
|
]
|
|
315
321
|
|
|
316
322
|
reduced_id = az.extract(
|
|
@@ -338,6 +344,8 @@ class FitResult:
|
|
|
338
344
|
|
|
339
345
|
df = pd.concat(df_list, axis=1)
|
|
340
346
|
|
|
347
|
+
df = df.rename(columns=lambda x: x.split("/~/")[-1])
|
|
348
|
+
|
|
341
349
|
return Chain(samples=df, name=name)
|
|
342
350
|
|
|
343
351
|
@property
|
|
@@ -450,7 +458,7 @@ class FitResult:
|
|
|
450
458
|
legend_labels = []
|
|
451
459
|
|
|
452
460
|
count = az.extract(
|
|
453
|
-
self.inference_data, var_names=f"
|
|
461
|
+
self.inference_data, var_names=f"obs/~/{obs_id}", group="posterior_predictive"
|
|
454
462
|
).values.T
|
|
455
463
|
|
|
456
464
|
xbins, exposure, integrated_arf = _compute_effective_area(obsconf, x_unit)
|
|
@@ -465,7 +473,9 @@ class FitResult:
|
|
|
465
473
|
case "photon_flux_density":
|
|
466
474
|
denominator = (xbins[1] - xbins[0]) * integrated_arf * exposure
|
|
467
475
|
|
|
468
|
-
y_samples =
|
|
476
|
+
y_samples = count * u.ct / denominator
|
|
477
|
+
|
|
478
|
+
y_samples = y_samples.to(y_units)
|
|
469
479
|
|
|
470
480
|
y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
|
|
471
481
|
obsconf.folded_counts.data, denominator, y_units
|
|
@@ -491,8 +501,8 @@ class FitResult:
|
|
|
491
501
|
alpha=0.7,
|
|
492
502
|
)
|
|
493
503
|
|
|
494
|
-
lowest_y =
|
|
495
|
-
highest_y =
|
|
504
|
+
lowest_y = np.nanmin(y_observed)
|
|
505
|
+
highest_y = np.nanmax(y_observed)
|
|
496
506
|
|
|
497
507
|
legend_plots.append((true_data_plot,))
|
|
498
508
|
legend_labels.append("Observed")
|
|
@@ -522,7 +532,10 @@ class FitResult:
|
|
|
522
532
|
count.reshape((count.shape[0] * count.shape[1], -1))
|
|
523
533
|
* u.ct
|
|
524
534
|
/ denominator
|
|
525
|
-
)
|
|
535
|
+
)
|
|
536
|
+
|
|
537
|
+
y_samples = y_samples.to(y_units)
|
|
538
|
+
|
|
526
539
|
component_plot = _plot_binned_samples_with_error(
|
|
527
540
|
ax[0],
|
|
528
541
|
xbins.value,
|
|
@@ -545,7 +558,7 @@ class FitResult:
|
|
|
545
558
|
if self.background_model is None
|
|
546
559
|
else az.extract(
|
|
547
560
|
self.inference_data,
|
|
548
|
-
var_names=f"
|
|
561
|
+
var_names=f"bkg/~/{obs_id}",
|
|
549
562
|
group="posterior_predictive",
|
|
550
563
|
).values.T
|
|
551
564
|
)
|
|
@@ -577,18 +590,18 @@ class FitResult:
|
|
|
577
590
|
alpha=0.7,
|
|
578
591
|
)
|
|
579
592
|
|
|
580
|
-
lowest_y =
|
|
581
|
-
highest_y =
|
|
593
|
+
# lowest_y = np.nanmin(lowest_y.min, np.nanmin(y_observed_bkg.value).astype(float))
|
|
594
|
+
# highest_y = np.nanmax(highest_y.value.astype(float), np.nanmax(y_observed_bkg.value).astype(float))
|
|
582
595
|
|
|
583
596
|
legend_plots.append((true_bkg_plot,))
|
|
584
597
|
legend_labels.append("Observed (bkg)")
|
|
585
598
|
legend_plots += model_bkg_plot
|
|
586
599
|
legend_labels.append("Model (bkg)")
|
|
587
600
|
|
|
588
|
-
max_residuals = np.
|
|
601
|
+
max_residuals = min(3.5, np.nanmax(np.abs(residual_samples)))
|
|
589
602
|
|
|
590
603
|
ax[0].loglog()
|
|
591
|
-
ax[1].set_ylim(-
|
|
604
|
+
ax[1].set_ylim(-np.nanmax([3.5, max_residuals]), +np.nanmax([3.5, max_residuals]))
|
|
592
605
|
ax[0].set_ylabel(f"Folded spectrum\n [{y_units:latex_inline}]")
|
|
593
606
|
ax[1].set_ylabel("Residuals \n" + r"[$\sigma$]")
|
|
594
607
|
|
|
@@ -635,9 +648,9 @@ class FitResult:
|
|
|
635
648
|
|
|
636
649
|
fig.align_ylabels()
|
|
637
650
|
plt.subplots_adjust(hspace=0.0)
|
|
651
|
+
fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
|
|
638
652
|
fig.tight_layout()
|
|
639
653
|
figure_list.append(fig)
|
|
640
|
-
fig.suptitle(f"Posterior predictive - {obs_id}" if title is None else title)
|
|
641
654
|
# fig.show()
|
|
642
655
|
|
|
643
656
|
plt.tight_layout()
|
|
@@ -651,9 +664,9 @@ class FitResult:
|
|
|
651
664
|
"""
|
|
652
665
|
|
|
653
666
|
consumer = ChainConsumer()
|
|
654
|
-
consumer.add_chain(self.to_chain(
|
|
667
|
+
consumer.add_chain(self.to_chain("Model"))
|
|
655
668
|
|
|
656
|
-
return consumer.analysis.get_latex_table(caption="
|
|
669
|
+
return consumer.analysis.get_latex_table(caption="Fit result", label="tab:results")
|
|
657
670
|
|
|
658
671
|
def plot_corner(
|
|
659
672
|
self,
|
jaxspec/data/obsconf.py
CHANGED
|
@@ -85,13 +85,20 @@ class ObsConfiguration(xr.Dataset):
|
|
|
85
85
|
|
|
86
86
|
from .util import data_path_finder
|
|
87
87
|
|
|
88
|
-
arf_path_default, rmf_path_default, bkg_path_default = data_path_finder(
|
|
88
|
+
arf_path_default, rmf_path_default, bkg_path_default = data_path_finder(
|
|
89
|
+
pha_path,
|
|
90
|
+
require_arf=(arf_path is None) and (arf_path != ""),
|
|
91
|
+
require_rmf=rmf_path is None,
|
|
92
|
+
require_bkg=bkg_path is None,
|
|
93
|
+
)
|
|
89
94
|
|
|
90
95
|
arf_path = arf_path_default if arf_path is None else arf_path
|
|
91
96
|
rmf_path = rmf_path_default if rmf_path is None else rmf_path
|
|
92
97
|
bkg_path = bkg_path_default if bkg_path is None else bkg_path
|
|
93
98
|
|
|
94
|
-
instrument = Instrument.from_ogip_file(
|
|
99
|
+
instrument = Instrument.from_ogip_file(
|
|
100
|
+
rmf_path, arf_path=arf_path if arf_path != "" else None
|
|
101
|
+
)
|
|
95
102
|
observation = Observation.from_pha_file(pha_path, bkg_path=bkg_path)
|
|
96
103
|
|
|
97
104
|
return cls.from_instrument(
|
|
@@ -141,7 +148,6 @@ class ObsConfiguration(xr.Dataset):
|
|
|
141
148
|
transfer_matrix = grouping @ (redistribution * area * exposure)
|
|
142
149
|
|
|
143
150
|
# Exclude bins out of the considered energy range, and bins without contribution from the RMF
|
|
144
|
-
|
|
145
151
|
row_idx = (e_min > low_energy) & (e_max < high_energy) & (grouping.sum(axis=1) > 0)
|
|
146
152
|
col_idx = (e_min_unfolded > 0) & (redistribution.sum(axis=0) > 0)
|
|
147
153
|
|
jaxspec/data/observation.py
CHANGED
|
@@ -164,7 +164,9 @@ class Observation(xr.Dataset):
|
|
|
164
164
|
"""
|
|
165
165
|
from .util import data_path_finder
|
|
166
166
|
|
|
167
|
-
arf_path, rmf_path, bkg_path_default = data_path_finder(
|
|
167
|
+
arf_path, rmf_path, bkg_path_default = data_path_finder(
|
|
168
|
+
pha_path, require_arf=False, require_rmf=False, require_bkg=False
|
|
169
|
+
)
|
|
168
170
|
bkg_path = bkg_path_default if bkg_path is None else bkg_path
|
|
169
171
|
|
|
170
172
|
pha = DataPHA.from_file(pha_path)
|
jaxspec/data/ogip.py
CHANGED
|
@@ -109,7 +109,7 @@ class DataPHA:
|
|
|
109
109
|
raise ValueError("No QUALITY column found in the PHA file.")
|
|
110
110
|
|
|
111
111
|
if "BACKSCAL" in header:
|
|
112
|
-
backscal = header["BACKSCAL"] * np.ones_like(data["CHANNEL"])
|
|
112
|
+
backscal = header["BACKSCAL"] * np.ones_like(data["CHANNEL"], dtype=float)
|
|
113
113
|
elif "BACKSCAL" in data.colnames:
|
|
114
114
|
backscal = data["BACKSCAL"]
|
|
115
115
|
else:
|
|
@@ -138,7 +138,14 @@ class DataPHA:
|
|
|
138
138
|
"flags": flags,
|
|
139
139
|
}
|
|
140
140
|
|
|
141
|
-
|
|
141
|
+
if "COUNTS" in data.colnames:
|
|
142
|
+
counts = data["COUNTS"]
|
|
143
|
+
elif "RATE" in data.colnames:
|
|
144
|
+
counts = data["RATE"] * header["EXPOSURE"]
|
|
145
|
+
else:
|
|
146
|
+
raise ValueError("No COUNTS or RATE column found in the PHA file.")
|
|
147
|
+
|
|
148
|
+
return cls(data["CHANNEL"], counts, header["EXPOSURE"], **kwargs)
|
|
142
149
|
|
|
143
150
|
|
|
144
151
|
class DataARF:
|
jaxspec/data/util.py
CHANGED
|
@@ -228,12 +228,17 @@ def fakeit_for_multiple_parameters(
|
|
|
228
228
|
return fakeits[0] if len(fakeits) == 1 else fakeits
|
|
229
229
|
|
|
230
230
|
|
|
231
|
-
def data_path_finder(
|
|
231
|
+
def data_path_finder(
|
|
232
|
+
pha_path: str, require_arf: bool = True, require_rmf: bool = True, require_bkg: bool = False
|
|
233
|
+
) -> tuple[str | None, str | None, str | None]:
|
|
232
234
|
"""
|
|
233
235
|
Function which tries its best to find the ARF, RMF and BKG files associated with a given PHA file.
|
|
234
236
|
|
|
235
237
|
Parameters:
|
|
236
238
|
pha_path: The PHA file path.
|
|
239
|
+
require_arf: Whether to raise an error if the ARF file is not found.
|
|
240
|
+
require_rmf: Whether to raise an error if the RMF file is not found.
|
|
241
|
+
require_bkg: Whether to raise an error if the BKG file is not found.
|
|
237
242
|
|
|
238
243
|
Returns:
|
|
239
244
|
arf_path: The ARF file path.
|
|
@@ -241,23 +246,24 @@ def data_path_finder(pha_path: str) -> tuple[str | None, str | None, str | None]
|
|
|
241
246
|
bkg_path: The BKG file path.
|
|
242
247
|
"""
|
|
243
248
|
|
|
244
|
-
def find_path(file_name: str, directory: str) -> str | None:
|
|
245
|
-
if
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
+
def find_path(file_name: str, directory: str, raise_err: bool = True) -> str | None:
    """
    Resolve a file name read from the PHA header against ``directory``.

    Parameters:
        file_name: Header value (ANCRFILE, RESPFILE or BACKFILE). ``"none"`` (any case)
            or an empty string means no file is associated.
        directory: Directory containing the PHA file.
        raise_err: Whether a referenced but missing file should raise
            ``FileNotFoundError`` (forwarded to ``find_file_or_compressed_in_dir``).

    Returns:
        The path found by ``find_file_or_compressed_in_dir``, or None when the header
        references no file. When ``raise_err`` is False and the file is missing, this
        returns whatever ``find_file_or_compressed_in_dir`` returns in that case
        (presumably None — it only raises when ``raise_err`` is True).
    """
    if file_name.lower() == "none" or file_name == "":
        return None
    # Always search. ``raise_err`` (the caller's ``require_*`` flag) only decides whether
    # a missing file is an error. Skipping the search when it is False meant optional
    # files, e.g. BACKFILE with ``require_bkg=False``, were never picked up.
    return find_file_or_compressed_in_dir(file_name, directory, raise_err)
|
|
249
255
|
|
|
250
256
|
header = fits.getheader(pha_path, "SPECTRUM")
|
|
251
257
|
directory = str(Path(pha_path).parent)
|
|
252
258
|
|
|
253
|
-
arf_path = find_path(header.get("ANCRFILE", "none"), directory)
|
|
254
|
-
rmf_path = find_path(header.get("RESPFILE", "none"), directory)
|
|
255
|
-
bkg_path = find_path(header.get("BACKFILE", "none"), directory)
|
|
259
|
+
arf_path = find_path(header.get("ANCRFILE", "none"), directory, require_arf)
|
|
260
|
+
rmf_path = find_path(header.get("RESPFILE", "none"), directory, require_rmf)
|
|
261
|
+
bkg_path = find_path(header.get("BACKFILE", "none"), directory, require_bkg)
|
|
256
262
|
|
|
257
263
|
return arf_path, rmf_path, bkg_path
|
|
258
264
|
|
|
259
265
|
|
|
260
|
-
def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> str:
|
|
266
|
+
def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path, raise_err: bool) -> str:
|
|
261
267
|
"""
|
|
262
268
|
Try to find a file or its .gz compressed version in a given directory and return
|
|
263
269
|
the full path of the file.
|
|
@@ -275,5 +281,5 @@ def find_file_or_compressed_in_dir(path: str | Path, directory: str | Path) -> s
|
|
|
275
281
|
if file.suffix == ".gz":
|
|
276
282
|
return str(file)
|
|
277
283
|
|
|
278
|
-
|
|
284
|
+
elif raise_err:
|
|
279
285
|
raise FileNotFoundError(f"Can't find {path}(.gz) in {directory}.")
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from scipy.interpolate import RegularGridInterpolator
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class RegularGridInterpolatorWithGrad(RegularGridInterpolator):
    """
    A subclass of SciPy's RegularGridInterpolator that also returns the gradient
    of each interpolated output with respect to input coordinates.

    Supports:
        - Linear interpolation (the default ``method``)
        - Out-of-bounds handling (``bounds_error=False``, ``fill_value=0`` by default)
        - Multi-dimensional output (e.g., RGB or vector fields)
        - Batched or single-point evaluation

    Gradients are finite differences across the grid cell that contains each query
    coordinate, i.e. the exact partial derivatives of a multilinear interpolant inside
    the grid. For coordinates outside the grid the first/last cell is used, while the
    returned value is ``fill_value``.

    NOTE(review): the cell lookup uses ``np.searchsorted`` on the grids as given, so each
    grid axis is assumed to be strictly ascending — confirm for callers.

    Returns (from ``__call__``):
        - values: shape (n_points, *output_shape), or output_shape for a single point
        - gradients: shape (n_points, input_dim, *output_shape), or
          (input_dim, *output_shape) for a single point
    """

    def __init__(self, points, values, **kwargs):
        """
        Parameters:
            points: Sequence of 1-D grid coordinate arrays, one per input dimension.
            values: Array-like of shape ``(len(points[0]), ..., len(points[-1]), *output_shape)``.
            **kwargs: Forwarded to ``scipy.interpolate.RegularGridInterpolator``;
                ``method``, ``bounds_error`` and ``fill_value`` default to
                ``"linear"``, ``False`` and ``0.0``.
        """
        kwargs.setdefault("method", "linear")
        kwargs.setdefault("bounds_error", False)
        kwargs.setdefault("fill_value", 0.0)
        self.points = [np.asarray(p) for p in points]
        self.input_dim = len(self.points)

        # Accept any array-like (lists, JAX arrays, ...) for the tabulated values.
        values = np.asarray(values)
        self.output_shape = values.shape[self.input_dim :]
        # Flatten the trailing output dimensions so SciPy interpolates one vector per node.
        values_reshaped = values.reshape(*[len(p) for p in self.points], -1)

        super().__init__(self.points, values_reshaped, **kwargs)

    def __call__(self, xi, return_gradient=True):
        """
        Evaluate the interpolant (and optionally its gradient) at ``xi``.

        Parameters:
            xi: Query point(s), shape ``(input_dim,)`` or ``(n_points, input_dim)``.
            return_gradient: If False, only the interpolated values are returned.

        Returns:
            ``values`` or ``(values, gradients)``; see the class docstring for shapes.
            A leading axis of length one is dropped.

        Raises:
            ValueError: If the number of coordinates per point differs from ``input_dim``.
        """
        xi = np.atleast_2d(xi).astype(float)
        n_points, n_dims = xi.shape
        # Explicit check instead of ``assert`` so it survives ``python -O``; SciPy would
        # raise ValueError for this mismatch anyway.
        if n_dims != self.input_dim:
            raise ValueError(
                f"Dim mismatch: expected {self.input_dim} coordinates per point, got {n_dims}"
            )

        # Interpolate values; SciPy returns shape (n_points, n_outputs).
        flat_vals = super().__call__(xi)
        values = flat_vals.reshape(n_points, *self.output_shape)

        if not return_gradient:
            return values[0] if n_points == 1 else values

        # ``np.prod(())`` is the float 1.0, which ``np.zeros`` rejects as a dimension, so
        # cast to int (scalar-valued tables have ``output_shape == ()``).
        n_outputs = int(np.prod(self.output_shape))
        gradients = np.zeros((n_points, self.input_dim, n_outputs), dtype=float)

        for d, grid in enumerate(self.points):
            # Locate the cell [grid[lower], grid[upper]] containing each query coordinate,
            # clamped to the first/last cell for out-of-range coordinates.
            idx_upper = np.searchsorted(grid, xi[:, d], side="right")
            idx_lower = np.clip(idx_upper - 1, 0, len(grid) - 2)
            idx_upper = np.clip(idx_upper, 1, len(grid) - 1)

            # Evaluate at both cell edges along dimension d, other coordinates unchanged.
            xi_low = xi.copy()
            xi_high = xi.copy()
            xi_low[:, d] = grid[idx_lower]
            xi_high[:, d] = grid[idx_upper]

            f_low = super().__call__(xi_low)
            f_high = super().__call__(xi_high)
            delta = (grid[idx_upper] - grid[idx_lower])[:, np.newaxis]

            gradients[:, d, :] = np.where(delta != 0, (f_high - f_low) / delta, 0.0)

        gradients = gradients.reshape(n_points, self.input_dim, *self.output_shape)

        if n_points == 1:
            return values[0], gradients[0]
        return values, gradients
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from itertools import product
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
|
|
5
|
+
from jax.scipy.interpolate import RegularGridInterpolator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RegularGridInterpolatorWithGrad(RegularGridInterpolator):
    """
    A subclass of JAX's ``jax.scipy.interpolate.RegularGridInterpolator`` with a
    re-implemented linear evaluation path, ``value_and_grad``.

    Defaults differ from the parent: ``method="linear"``, ``bounds_error=False`` and
    ``fill_value=0.0`` (out-of-bounds points evaluate to 0).

    NOTE(review): despite its name, ``value_and_grad`` currently returns only the
    interpolated values, shaped ``(*batch_shape, *output_shape)``, not gradients. It is
    built from ``jax.numpy`` operations, so gradients can presumably be obtained by
    wrapping it with ``jax.jacfwd``/``jax.grad`` — TODO confirm the intended API.
    """

    def _ndim_coords_from_arrays(self, points):
        """Convert coordinate input to a ``(..., ndim)``-shaped array.

        Accepts either a tuple of broadcastable coordinate arrays (one per grid axis) or
        a single array-like whose last axis indexes the grid dimensions; a 1-D array is
        reshaped to ``(-1, ndim)``.
        """
        ndim = len(self.grid)

        if isinstance(points, tuple) and len(points) == 1:
            # handle argument tuple
            points = points[0]
        if isinstance(points, tuple):
            p = jnp.broadcast_arrays(*points)
            for p_other in p[1:]:
                if p_other.shape != p[0].shape:
                    raise ValueError("coordinate arrays do not have the same shape")
            points = jnp.empty((*p[0].shape, len(points)), dtype=float)
            for j, item in enumerate(p):
                points = points.at[..., j].set(item)
        else:
            points = jnp.asarray(points)  # SciPy: asanyarray(points)
            if points.ndim == 1:
                points = points.reshape(-1, ndim)
        return points

    def __init__(self, points, values, **kwargs):
        kwargs.setdefault("method", "linear")
        kwargs.setdefault("bounds_error", False)
        kwargs.setdefault("fill_value", 0.0)

        super().__init__(points, values, **kwargs)

    def value_and_grad(self, xi):
        """Multilinearly interpolate the tabulated values at ``xi``.

        Parameters:
            xi: Query points, either an array whose last axis has length ``len(self.grid)``
                or a tuple of coordinate arrays.

        Returns:
            Interpolated values of shape ``xi.shape[:-1] + self.values.shape[ndim:]``;
            out-of-bounds points are set to ``fill_value``.
        """
        ndim = len(self.grid)
        xi = self._ndim_coords_from_arrays(xi)
        xi_shape = xi.shape
        xi = xi.reshape(-1, xi_shape[-1])

        indices, norm_distances, out_of_bounds = self._find_indices(xi.T)

        # Broadcast per-point weights over any trailing output dimensions of ``values``.
        vslice = (slice(None),) + (None,) * (self.values.ndim - len(indices))

        # Sum over the 2**ndim corners of each cell; along every axis a corner takes
        # either the lower index ``i`` or the upper index ``i + 1``.
        result = jnp.asarray(0.0)
        for edge_indices in product(*[[i, i + 1] for i in indices]):
            weight = jnp.asarray(1.0)
            for ei, i, yi in zip(edge_indices, indices, norm_distances):
                # Lower corner (ei == i) gets weight 1 - t, upper corner gets t.
                weight *= jnp.where(ei == i, 1 - yi, yi)
            result += self.values[edge_indices] * weight[vslice]

        if not self.bounds_error and self.fill_value is not None:
            bc_shp = result.shape[:1] + (1,) * (result.ndim - 1)
            result = jnp.where(out_of_bounds.reshape(bc_shp), self.fill_value, result)

        return result.reshape(xi_shape[:-1] + self.values.shape[ndim:])
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import numpy as np
|
|
3
|
+
import numpyro
|
|
4
|
+
import numpyro.distributions as dist
|
|
5
|
+
|
|
6
|
+
from scipy.interpolate import BSpline
|
|
7
|
+
from tinygp import GaussianProcess, kernels
|
|
8
|
+
|
|
9
|
+
from ..model.instrument import GainModel, ShiftModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def bspline_basis(n_basis: int, degree: int = 3, interval=(0.0, 1.0)):
    """
    Construct an open-uniform B-spline basis on a given interval.

    Parameters
    ----------
    n_basis : int
        Number of basis functions. Must satisfy ``n_basis >= degree + 1`` for an
        open-uniform knot vector.
    degree : int, optional
        Polynomial degree of the splines (default 3 → cubic). Must be non-negative.
    interval : tuple(float, float), optional
        The (start, end) of the domain (default (0, 1)); ``start`` must be smaller
        than ``end``.

    Returns
    -------
    basis : list[BSpline]
        List of ``n_basis`` BSpline objects forming a basis over the interval. They are
        built with ``extrapolate=False`` and therefore evaluate to NaN outside it.
    knots : ndarray
        The full knot vector (``n_basis + degree + 1`` entries), with both endpoints
        repeated ``degree + 1`` times.

    Raises
    ------
    ValueError
        If ``degree`` is negative, ``n_basis < degree + 1``, or the interval is empty
        or reversed.
    """
    a, b = interval
    p = degree
    if p < 0:
        raise ValueError(f"The spline degree must be non-negative (got {p}).")
    if n_basis < p + 1:
        raise ValueError(f"Need at least {p+1} basis functions (got {n_basis}).")
    if not a < b:
        raise ValueError(f"The interval must satisfy start < end (got {interval}).")

    # Number of *internal* knots (not counting the repeated endpoints).
    n_internal = n_basis - p - 1

    # Equally spaced internal knots; dropping both ends yields an empty array when
    # n_internal == 0.
    internal_knots = np.linspace(a, b, n_internal + 2)[1:-1]

    # Open-uniform knot vector: endpoints repeated p+1 times
    knots = np.concatenate((np.full(p + 1, a), internal_knots, np.full(p + 1, b)))

    # Coefficient matrix: row i of the identity selects basis spline i
    coeffs = np.eye(n_basis)

    basis = [BSpline(knots, coeffs[i], p, extrapolate=False) for i in range(n_basis)]
    return basis, knots
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class GaussianProcessGain(GainModel):
    """
    Energy-dependent gain modelled with a Gaussian-process prior.

    The GP (``tinygp`` Matern-5/2 kernel) is defined on ``n_nodes`` evenly spaced nodes of
    ``[0, 1]``, which are mapped linearly onto ``[e_min, e_max]`` when the gain is
    evaluated. Outside that energy range the gain is 1.

    Parameters:
        e_min: Lower energy bound of the GP support.
        e_max: Upper energy bound of the GP support.
        n_nodes: Number of GP nodes.
    """

    def __init__(self, e_min, e_max, n_nodes=30):
        self.e_min = e_min
        self.e_max = e_max
        self.n_nodes = n_nodes
        # Kernel class; instantiated with the sampled length scale in ``numpyro_model``.
        self.kernel = kernels.Matern52

    def numpyro_model(self, observation_name: str):
        """
        Sample the GP hyper-parameters and node values for one observation.

        Sample sites, in order: ``ins/~/_{observation_name}_mean``, ``..._sigma``,
        ``..._rho`` and ``ins/~/{observation_name}_gain``.

        Returns:
            A function ``gain(energy)`` that linearly interpolates the sampled node
            values at ``energy.mean(axis=0)``.
        """
        gp_mean = numpyro.sample(f"ins/~/_{observation_name}_mean", dist.Normal(1.0, 0.3))

        amplitude = numpyro.sample(f"ins/~/_{observation_name}_sigma", dist.HalfNormal(3.0))
        length_scale = numpyro.sample(f"ins/~/_{observation_name}_rho", dist.HalfNormal(10.0))

        unit_nodes = jnp.linspace(0, 1, self.n_nodes)
        process = GaussianProcess(
            amplitude**2 * self.kernel(length_scale),
            unit_nodes,
            # small constant diagonal term (likely jitter for numerical stability)
            diag=1e-5 * jnp.ones_like(unit_nodes),
            mean=gp_mean,
        )

        node_gains = numpyro.sample(f"ins/~/{observation_name}_gain", process.numpyro_dist())

        def gain(energy):
            # presumably ``energy`` stacks (low, high) bin edges on axis 0 — TODO confirm
            energy_nodes = unit_nodes * (self.e_max - self.e_min) + self.e_min
            return jnp.interp(
                energy.mean(axis=0),
                energy_nodes,
                node_gains,
                left=1.0,
                right=1.0,
            )

        return gain
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class BsplineGain(GainModel):
    """
    Energy-dependent gain expressed as a linear combination of cubic B-splines.

    ``n_nodes`` basis splines span ``[e_min, e_max]`` and are pre-evaluated on a regular
    grid of ``grid_size`` energies. The gain is linearly interpolated from that grid and
    equals 1 outside it.
    """

    def __init__(self, e_min, e_max, n_nodes=6, grid_size=30):
        self.e_min = e_min
        self.e_max = e_max
        self.n_nodes = n_nodes
        self.egrid = jnp.linspace(e_min, e_max, grid_size)

        splines, _ = bspline_basis(n_nodes, 3, (e_min, e_max))

        # One row per basis spline, one column per grid energy.
        self.gridded_basis = jnp.asarray([spline(self.egrid) for spline in splines])

    def numpyro_model(self, observation_name: str):
        """
        Sample the spline weights for one observation (site ``ins/~/_{observation_name}_coeff``,
        each weight uniform on [0, 2]) and return the resulting gain function.
        """
        weights = numpyro.sample(
            f"ins/~/_{observation_name}_coeff",
            dist.Uniform(0 * jnp.ones(self.n_nodes), 2 * jnp.ones(self.n_nodes)),
        )

        def gain(energy):
            gain_on_grid = jnp.dot(weights, self.gridded_basis)
            # presumably ``energy`` stacks (low, high) bin edges on axis 0 — TODO confirm
            bin_energy = energy.mean(axis=0)
            return jnp.interp(bin_energy, self.egrid, gain_on_grid, left=1.0, right=1.0)

        return gain
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class PolynomialGain(GainModel):
    """
    Gain given by a polynomial in energy, with coefficients drawn from a user prior.

    A scalar prior yields a constant gain. A prior with a leading dimension yields
    ``jnp.polyval(coefficients, energy.mean(axis=0))`` (highest power first).
    """

    def __init__(self, prior_distribution):
        self.prior_distribution = prior_distribution
        prior_shape = prior_distribution.shape()
        # NOTE: despite its name, ``degree`` holds the number of coefficients (the prior's
        # leading dimension), or 0 for a scalar prior. Within this class it only selects
        # the constant branch.
        self.degree = prior_shape[0] if prior_shape else 0

    def numpyro_model(self, observation_name: str):
        """
        Sample the coefficients (site ``ins/~/gain_{observation_name}``) and return the
        gain function.
        """
        coefficients = numpyro.sample(f"ins/~/gain_{observation_name}", self.prior_distribution)

        def constant_gain(energy):
            return coefficients

        def polynomial_gain(energy):
            return jnp.polyval(coefficients, energy.mean(axis=0))

        return constant_gain if self.degree == 0 else polynomial_gain
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class PolynomialShift(ShiftModel):
    """
    Energy shift given by the polynomial mapping ``new_energy = polyval(coefficients, energy)``.

    A scalar prior is interpreted as an additive offset, ``new_energy = energy + c``.
    """

    def __init__(self, prior_distribution):
        self.prior_distribution = prior_distribution
        prior_shape = prior_distribution.shape()
        # Number of coefficients (leading dimension of the prior), or 0 for a scalar prior.
        self.degree = prior_shape[0] if prior_shape else 0

    def numpyro_model(self, observation_name: str):
        """
        Sample the coefficients (site ``ins/~/shift_{observation_name}``) and return the
        energy-mapping function.
        """
        sampled = numpyro.sample(f"ins/~/shift_{observation_name}", self.prior_distribution)

        # Scalar prior: coefficients [1, c], so that polyval gives energy + c.
        coefficients = jnp.asarray([1.0, sampled]) if self.degree == 0 else sampled

        def shift(energy):
            return jnp.polyval(coefficients, energy)

        return shift
|