jaxspec 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/_build_model.py +26 -103
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +219 -330
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +17 -4
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +5 -75
- jaxspec/fit.py +56 -44
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +3 -4
- jaxspec/model/multiplicative.py +101 -85
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/METADATA +12 -13
- jaxspec-0.2.0.dist-info/RECORD +34 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.4.dist-info/RECORD +0 -33
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.4.dist-info → jaxspec-0.2.0.dist-info}/entry_points.txt +0 -0
jaxspec/model/abc.py
CHANGED
|
@@ -1,37 +1,86 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
import operator
|
|
4
|
+
|
|
5
|
+
from abc import ABC
|
|
6
|
+
from functools import partial
|
|
4
7
|
from uuid import uuid4
|
|
5
8
|
|
|
6
|
-
import
|
|
9
|
+
import flax.nnx as nnx
|
|
7
10
|
import jax
|
|
8
11
|
import jax.numpy as jnp
|
|
12
|
+
import jax.scipy as jsp
|
|
9
13
|
import networkx as nx
|
|
10
|
-
import rich
|
|
11
14
|
|
|
12
|
-
from haiku._src import base
|
|
13
|
-
from jax.scipy.integrate import trapezoid
|
|
14
|
-
from rich.table import Table
|
|
15
15
|
from simpleeval import simple_eval
|
|
16
16
|
|
|
17
|
+
from jaxspec.util.typing import PriorDictType
|
|
18
|
+
|
|
19
|
+
from ._graph_util import compose, export_to_mermaid
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def set_parameters(params: PriorDictType, state: nnx.State) -> nnx.State:
|
|
23
|
+
"""
|
|
24
|
+
Set the parameters of a spectral model using `nnx'` routines.
|
|
25
|
+
|
|
26
|
+
Parameters:
|
|
27
|
+
params: Dictionary of parameters to set.
|
|
28
|
+
model: Spectral model.
|
|
29
|
+
|
|
30
|
+
Returns:
|
|
31
|
+
A spectral model with the newly set parameters.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
state_dict = state.to_pure_dict() # haiku-like 2 level dictionary
|
|
35
|
+
|
|
36
|
+
for key, value in params.items():
|
|
37
|
+
# Split the key to extract the module name and parameter name
|
|
38
|
+
module_name, param_name = key.rsplit("_", 1)
|
|
39
|
+
state_dict["modules"][module_name][param_name] = value
|
|
40
|
+
|
|
41
|
+
state.replace_by_pure_dict(state_dict)
|
|
42
|
+
|
|
43
|
+
return state
|
|
17
44
|
|
|
18
|
-
|
|
45
|
+
|
|
46
|
+
class Composable(ABC):
|
|
19
47
|
"""
|
|
20
|
-
|
|
21
|
-
operations, and allows tracking of the operation graph and individual parameters.
|
|
48
|
+
Defines the set of operations between model components and spectral models
|
|
22
49
|
"""
|
|
23
50
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
51
|
+
def sanitize_inputs(self, other):
|
|
52
|
+
if isinstance(self, ModelComponent):
|
|
53
|
+
model_1 = SpectralModel.from_component(self)
|
|
54
|
+
else:
|
|
55
|
+
model_1 = self
|
|
56
|
+
|
|
57
|
+
if isinstance(other, ModelComponent):
|
|
58
|
+
model_2 = SpectralModel.from_component(other)
|
|
59
|
+
else:
|
|
60
|
+
model_2 = other
|
|
28
61
|
|
|
29
|
-
|
|
30
|
-
self.raw_graph = internal_graph
|
|
31
|
-
self.labels = labels
|
|
32
|
-
self.graph = self.build_namespace()
|
|
62
|
+
return model_1, model_2
|
|
33
63
|
|
|
34
|
-
|
|
64
|
+
def __add__(self, other):
|
|
65
|
+
model_1, model_2 = self.sanitize_inputs(other)
|
|
66
|
+
return model_1.compose(model_2, operation="add", operation_func=operator.add)
|
|
67
|
+
|
|
68
|
+
def __mul__(self, other):
|
|
69
|
+
model_1, model_2 = self.sanitize_inputs(other)
|
|
70
|
+
return model_1.compose(model_2, operation="mul", operation_func=operator.mul)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class SpectralModel(nnx.Module, Composable):
|
|
74
|
+
# graph: nx.DiGraph = eqx.field(static=True)
|
|
75
|
+
# modules: dict
|
|
76
|
+
|
|
77
|
+
def __init__(self, graph: nx.DiGraph):
|
|
78
|
+
self.graph = graph
|
|
79
|
+
self.modules = {}
|
|
80
|
+
|
|
81
|
+
for node, data in self.graph.nodes(data=True):
|
|
82
|
+
if "component" in data["type"]:
|
|
83
|
+
self.modules[data["name"]] = data["component"] # (**data['kwargs'])
|
|
35
84
|
|
|
36
85
|
@classmethod
|
|
37
86
|
def from_string(cls, string: str) -> SpectralModel:
|
|
@@ -67,10 +116,10 @@ class SpectralModel:
|
|
|
67
116
|
>>> model.to_string()
|
|
68
117
|
"Tbabs()*(Powerlaw() + Blackbody())"
|
|
69
118
|
"""
|
|
70
|
-
|
|
71
119
|
return str(self)
|
|
72
120
|
|
|
73
|
-
|
|
121
|
+
"""
|
|
122
|
+
def __str__(self) -> str:
|
|
74
123
|
def build_expression(node_id):
|
|
75
124
|
node = self.graph.nodes[node_id]
|
|
76
125
|
if node["type"] == "component":
|
|
@@ -92,485 +141,297 @@ class SpectralModel:
|
|
|
92
141
|
predecessors = list(self.graph.predecessors(node_id))
|
|
93
142
|
return build_expression(predecessors[0])
|
|
94
143
|
|
|
95
|
-
return build_expression("out")[1:-1]
|
|
96
|
-
|
|
97
|
-
@property
|
|
98
|
-
def transformed_func_photon(self):
|
|
99
|
-
def func_to_transform(e_low, e_high, n_points=2):
|
|
100
|
-
return self.flux(e_low, e_high, n_points=n_points, energy_flux=False)
|
|
101
|
-
|
|
102
|
-
return hk.without_apply_rng(hk.transform(func_to_transform))
|
|
103
|
-
|
|
104
|
-
@property
|
|
105
|
-
def transformed_func_energy(self):
|
|
106
|
-
def func_to_transform(e_low, e_high, n_points=2):
|
|
107
|
-
return self.flux(e_low, e_high, n_points=n_points, energy_flux=True)
|
|
108
|
-
|
|
109
|
-
return hk.without_apply_rng(hk.transform(func_to_transform))
|
|
110
|
-
|
|
111
|
-
@property
|
|
112
|
-
def params(self):
|
|
113
|
-
return self.transformed_func_photon.init(None, jnp.ones(10), jnp.ones(10))
|
|
114
|
-
|
|
115
|
-
def __rich_repr__(self):
|
|
116
|
-
table = Table(title=str(self))
|
|
117
|
-
|
|
118
|
-
table.add_column("Component", justify="right", style="bold", no_wrap=True)
|
|
119
|
-
table.add_column("Parameter")
|
|
120
|
-
|
|
121
|
-
params = self.params
|
|
122
|
-
|
|
123
|
-
for component in params.keys():
|
|
124
|
-
once = True
|
|
125
|
-
|
|
126
|
-
for parameters in params[component].keys():
|
|
127
|
-
table.add_row(component if once else "", parameters)
|
|
128
|
-
once = False
|
|
129
|
-
|
|
130
|
-
return table
|
|
131
|
-
|
|
132
|
-
def __repr_html_(self):
|
|
133
|
-
return self.__rich_repr__()
|
|
134
|
-
|
|
135
|
-
def __repr__(self):
|
|
136
|
-
if not base.frame_stack:
|
|
137
|
-
rich.print(self.__rich_repr__())
|
|
138
|
-
return ""
|
|
139
|
-
|
|
140
|
-
def photon_flux(self, params, e_low, e_high, n_points=2):
|
|
141
|
-
r"""
|
|
142
|
-
Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model.
|
|
143
|
-
|
|
144
|
-
$$ \Phi_{\text{photon}}\left(E_\min, ~E_\max\right) =
|
|
145
|
-
\int _{E_\min}^{E_\max}\text{d}E ~ \mathcal{M}\left( E \right)
|
|
146
|
-
\quad \left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$$
|
|
147
|
-
|
|
148
|
-
Parameters:
|
|
149
|
-
params : The parameters of the model.
|
|
150
|
-
e_low : The lower bound of the energy bins.
|
|
151
|
-
e_high : The upper bound of the energy bins.
|
|
152
|
-
n_points : The number of points used to integrate the model in each bin.
|
|
153
|
-
|
|
154
|
-
!!! info
|
|
155
|
-
This method is internally used in the inference process and should not be used directly. See
|
|
156
|
-
[`photon_flux`][jaxspec.analysis.results.FitResult.photon_flux] to compute
|
|
157
|
-
the photon flux associated with a set of fitted parameters in a
|
|
158
|
-
[`FitResult`][jaxspec.analysis.results.FitResult]
|
|
159
|
-
instead.
|
|
160
|
-
"""
|
|
161
|
-
|
|
162
|
-
params = jax.tree_map(lambda x: jnp.asarray(x), params)
|
|
163
|
-
e_low = jnp.asarray(e_low)
|
|
164
|
-
e_high = jnp.asarray(e_high)
|
|
165
|
-
|
|
166
|
-
return self.transformed_func_photon.apply(params, e_low, e_high, n_points=n_points)
|
|
167
|
-
|
|
168
|
-
def energy_flux(self, params, e_low, e_high, n_points=2):
|
|
169
|
-
r"""
|
|
170
|
-
Compute the expected energy flux between $E_\min$ and $E_\max$ by integrating the model.
|
|
171
|
-
|
|
172
|
-
$$ \Phi_{\text{energy}}\left(E_\min, ~E_\max\right) =
|
|
173
|
-
\int _{E_\min}^{E_\max}\text{d}E ~ E ~ \mathcal{M}\left( E \right)
|
|
174
|
-
\quad \left[\frac{\text{keV}}{\text{cm}^2\text{s}}\right]$$
|
|
175
|
-
|
|
176
|
-
Parameters:
|
|
177
|
-
params : The parameters of the model.
|
|
178
|
-
e_low : The lower bound of the energy bins.
|
|
179
|
-
e_high : The upper bound of the energy bins.
|
|
180
|
-
n_points : The number of points used to integrate the model in each bin.
|
|
181
|
-
|
|
182
|
-
!!! info
|
|
183
|
-
This method is internally used in the inference process and should not be used directly. See
|
|
184
|
-
[`energy_flux`](/references/results/#jaxspec.analysis.results.FitResult.energy_flux) to compute
|
|
185
|
-
the energy flux associated with a set of fitted parameters in a
|
|
186
|
-
[`FitResult`](/references/results/#jaxspec.analysis.results.FitResult)
|
|
187
|
-
instead.
|
|
188
|
-
"""
|
|
189
|
-
|
|
190
|
-
params = jax.tree_map(lambda x: jnp.asarray(x), params)
|
|
191
|
-
e_low = jnp.asarray(e_low)
|
|
192
|
-
e_high = jnp.asarray(e_high)
|
|
193
|
-
|
|
194
|
-
return self.transformed_func_energy.apply(params, e_low, e_high, n_points=n_points)
|
|
195
|
-
|
|
196
|
-
def build_namespace(self):
|
|
197
|
-
"""
|
|
198
|
-
This method build a namespace for the model components, to avoid name collision
|
|
199
|
-
"""
|
|
200
|
-
|
|
201
|
-
name_space = []
|
|
202
|
-
new_graph = self.raw_graph.copy()
|
|
203
|
-
|
|
204
|
-
for node_id in nx.dag.topological_sort(new_graph):
|
|
205
|
-
node = new_graph.nodes[node_id]
|
|
206
|
-
|
|
207
|
-
if node and node["type"] == "component":
|
|
208
|
-
name_space.append(node["name"])
|
|
209
|
-
n = name_space.count(node["name"])
|
|
210
|
-
nx.set_node_attributes(new_graph, {node_id: name_space[-1] + f"_{n}"}, "name")
|
|
211
|
-
|
|
212
|
-
return new_graph
|
|
144
|
+
return "This must be changed" # build_expression("out")[1:-1]
|
|
145
|
+
"""
|
|
213
146
|
|
|
214
|
-
def
|
|
147
|
+
def compose(self, other, operation=None, operation_func=None):
|
|
215
148
|
"""
|
|
216
|
-
This
|
|
217
|
-
It
|
|
218
|
-
It
|
|
219
|
-
It
|
|
149
|
+
This function operate a composition between the operation graph of two models
|
|
150
|
+
1) It fuses the two graphs using which joins at the 'out' nodes and change components name to unique identifiers
|
|
151
|
+
2) It relabels the 'out' node with a unique identifier and labels it with the operation
|
|
152
|
+
3) It links the operation to a new 'out' node
|
|
220
153
|
"""
|
|
221
154
|
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
energies = jnp.hstack((e_low, e_high[-1]))
|
|
226
|
-
energies_to_integrate = jnp.stack((e_low, e_high))
|
|
227
|
-
|
|
228
|
-
else:
|
|
229
|
-
energies_to_integrate = jnp.linspace(e_low, e_high, n_points)
|
|
230
|
-
energies = energies_to_integrate
|
|
231
|
-
|
|
232
|
-
fine_structures_flux = jnp.zeros_like(e_low)
|
|
233
|
-
runtime_modules = {}
|
|
234
|
-
continuum = {}
|
|
235
|
-
|
|
236
|
-
# Iterate through the graph in topological order and
|
|
237
|
-
# compute the continuum contribution for each component
|
|
238
|
-
|
|
239
|
-
for node_id in nx.dag.topological_sort(self.graph):
|
|
240
|
-
node = self.graph.nodes[node_id]
|
|
241
|
-
|
|
242
|
-
# Instantiate the haiku modules
|
|
243
|
-
if node and node["type"] == "component":
|
|
244
|
-
runtime_modules[node_id] = node["component"](name=node["name"], **node["kwargs"])
|
|
245
|
-
continuum[node_id] = runtime_modules[node_id].continuum(energies)
|
|
246
|
-
|
|
247
|
-
elif node and node["type"] == "operation":
|
|
248
|
-
component_1 = list(self.graph.in_edges(node_id))[0][0] # noqa: RUF015
|
|
249
|
-
component_2 = list(self.graph.in_edges(node_id))[1][0]
|
|
250
|
-
continuum[node_id] = node["function"](
|
|
251
|
-
continuum[component_1], continuum[component_2]
|
|
252
|
-
)
|
|
253
|
-
|
|
254
|
-
if n_points == 2:
|
|
255
|
-
flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
|
|
256
|
-
flux = jnp.stack((flux_1D[:-1], flux_1D[1:]))
|
|
257
|
-
|
|
258
|
-
else:
|
|
259
|
-
flux = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
|
|
260
|
-
|
|
261
|
-
if energy_flux:
|
|
262
|
-
continuum_flux = trapezoid(
|
|
263
|
-
flux * energies_to_integrate**2,
|
|
264
|
-
x=jnp.log(energies_to_integrate),
|
|
265
|
-
axis=0,
|
|
266
|
-
)
|
|
267
|
-
|
|
268
|
-
else:
|
|
269
|
-
continuum_flux = trapezoid(
|
|
270
|
-
flux * energies_to_integrate, x=jnp.log(energies_to_integrate), axis=0
|
|
271
|
-
)
|
|
272
|
-
|
|
273
|
-
# Iterate from the root nodes to the output node and
|
|
274
|
-
# compute the fine structure contribution for each component
|
|
275
|
-
|
|
276
|
-
root_nodes = [
|
|
277
|
-
node_id
|
|
278
|
-
for node_id, in_degree in self.graph.in_degree(self.graph.nodes)
|
|
279
|
-
if in_degree == 0 and self.graph.nodes[node_id].get("component_type") == "additive"
|
|
280
|
-
]
|
|
281
|
-
|
|
282
|
-
for root_node_id in root_nodes:
|
|
283
|
-
path = nx.shortest_path(self.graph, source=root_node_id, target="out")
|
|
284
|
-
nodes_id_in_path = [node_id for node_id in path]
|
|
285
|
-
|
|
286
|
-
flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(
|
|
287
|
-
e_low, e_high
|
|
288
|
-
)
|
|
289
|
-
|
|
290
|
-
multiplicative_nodes = []
|
|
155
|
+
composed_graph = compose(
|
|
156
|
+
self.graph, other.graph, operation=operation, operation_func=operation_func
|
|
157
|
+
)
|
|
291
158
|
|
|
292
|
-
|
|
293
|
-
# and apply them at mean energy
|
|
294
|
-
for node_id in nodes_id_in_path[::-1]:
|
|
295
|
-
multiplicative_nodes.extend(
|
|
296
|
-
[node_id for node_id in self.find_multiplicative_components(node_id)]
|
|
297
|
-
)
|
|
159
|
+
return SpectralModel(composed_graph)
|
|
298
160
|
|
|
299
|
-
|
|
300
|
-
|
|
161
|
+
@classmethod
|
|
162
|
+
def from_component(cls, component):
|
|
163
|
+
node_id = str(uuid4())
|
|
164
|
+
graph = nx.DiGraph()
|
|
301
165
|
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
166
|
+
node_properties = {
|
|
167
|
+
"type": f"{component.type}_component",
|
|
168
|
+
"name": f"{component.__class__.__name__}_1".lower(),
|
|
169
|
+
"component": component,
|
|
170
|
+
"depth": 0,
|
|
171
|
+
}
|
|
308
172
|
|
|
309
|
-
|
|
310
|
-
|
|
173
|
+
graph.add_node(node_id, **node_properties)
|
|
174
|
+
graph.add_node("out", type="out", depth=1)
|
|
175
|
+
graph.add_edge(node_id, "out")
|
|
311
176
|
|
|
312
|
-
return
|
|
177
|
+
return cls(graph)
|
|
313
178
|
|
|
314
|
-
def
|
|
179
|
+
def _find_multiplicative_components(self, node_id):
|
|
315
180
|
"""
|
|
316
181
|
Recursively finds all the multiplicative components connected to the node with the given ID.
|
|
317
182
|
"""
|
|
318
183
|
node = self.graph.nodes[node_id]
|
|
319
184
|
multiplicative_nodes = []
|
|
320
185
|
|
|
321
|
-
if node.get("
|
|
186
|
+
if node.get("type") == "mul_operation":
|
|
322
187
|
# Recursively find all the multiplicative components using the predecessors
|
|
323
188
|
predecessors = self.graph.pred[node_id]
|
|
324
189
|
for node_id in predecessors:
|
|
325
|
-
if self.graph.nodes[node_id].get("
|
|
190
|
+
if "multiplicative_component" == self.graph.nodes[node_id].get("type"):
|
|
326
191
|
multiplicative_nodes.append(node_id)
|
|
327
|
-
elif self.graph.nodes[node_id].get("
|
|
328
|
-
multiplicative_nodes.extend(self.
|
|
192
|
+
elif "mul_operation" == self.graph.nodes[node_id].get("type"):
|
|
193
|
+
multiplicative_nodes.extend(self._find_multiplicative_components(node_id))
|
|
329
194
|
|
|
330
195
|
return multiplicative_nodes
|
|
331
196
|
|
|
332
|
-
|
|
333
|
-
|
|
197
|
+
@property
|
|
198
|
+
def root_nodes(self) -> list[str]:
|
|
199
|
+
return [
|
|
200
|
+
node_id
|
|
201
|
+
for node_id, in_degree in self.graph.in_degree(self.graph.nodes)
|
|
202
|
+
if in_degree == 0 and ("additive" in self.graph.nodes[node_id].get("type"))
|
|
203
|
+
]
|
|
334
204
|
|
|
335
|
-
@
|
|
336
|
-
def
|
|
337
|
-
|
|
338
|
-
Build a model from a single component
|
|
339
|
-
"""
|
|
205
|
+
@property
|
|
206
|
+
def branches(self) -> list[str]:
|
|
207
|
+
branches = []
|
|
340
208
|
|
|
341
|
-
|
|
209
|
+
for root_node_id in self.root_nodes:
|
|
210
|
+
root_node_name = self.graph.nodes[root_node_id].get("name")
|
|
211
|
+
path = nx.shortest_path(self.graph, source=root_node_id, target="out")
|
|
212
|
+
multiplicative_components = []
|
|
342
213
|
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
214
|
+
# Search all multiplicative components connected to this node
|
|
215
|
+
# and apply them at mean energy
|
|
216
|
+
for node_id in path[::-1]:
|
|
217
|
+
multiplicative_components.extend(
|
|
218
|
+
[node_id for node_id in self._find_multiplicative_components(node_id)]
|
|
219
|
+
)
|
|
346
220
|
|
|
347
|
-
|
|
221
|
+
branch = ""
|
|
348
222
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
+ component(**kwargs).emission_lines(e, e + 1)[0]
|
|
353
|
-
)
|
|
223
|
+
for multiplicative_node_id in multiplicative_components:
|
|
224
|
+
multiplicative_node_name = self.graph.nodes[multiplicative_node_id].get("name")
|
|
225
|
+
branch += f"{multiplicative_node_name}*"
|
|
354
226
|
|
|
355
|
-
|
|
227
|
+
branch += f"{root_node_name}"
|
|
228
|
+
branches.append(branch)
|
|
356
229
|
|
|
357
|
-
|
|
358
|
-
return component().continuum(e)
|
|
230
|
+
return branches
|
|
359
231
|
|
|
360
|
-
|
|
232
|
+
def turbo_flux(self, e_low, e_high, energy_flux=False, n_points=2, return_branches=False):
|
|
233
|
+
continuum = {}
|
|
361
234
|
|
|
362
|
-
|
|
363
|
-
|
|
235
|
+
## Evaluate the expected contribution for each component
|
|
236
|
+
for node_id in nx.dag.topological_sort(self.graph):
|
|
237
|
+
node = self.graph.nodes[node_id]
|
|
364
238
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
"name": component.__name__.lower(),
|
|
369
|
-
"component": component,
|
|
370
|
-
# "params": hk.transform(lam_func).init(None, jnp.ones(1)),
|
|
371
|
-
"fine_structure": False,
|
|
372
|
-
"kwargs": kwargs,
|
|
373
|
-
"depth": 0,
|
|
374
|
-
}
|
|
239
|
+
if node["type"] == "additive_component":
|
|
240
|
+
node_name = node["name"]
|
|
241
|
+
runtime_modules = self.modules[node_name]
|
|
375
242
|
|
|
376
|
-
|
|
243
|
+
if not energy_flux:
|
|
244
|
+
continuum[node_name] = runtime_modules._photon_flux(
|
|
245
|
+
e_low, e_high, n_points=n_points
|
|
246
|
+
)
|
|
377
247
|
|
|
378
|
-
|
|
379
|
-
|
|
248
|
+
else:
|
|
249
|
+
continuum[node_name] = runtime_modules._energy_flux(
|
|
250
|
+
e_low, e_high, n_points=n_points
|
|
251
|
+
)
|
|
380
252
|
|
|
381
|
-
|
|
382
|
-
|
|
253
|
+
elif node["type"] == "multiplicative_component":
|
|
254
|
+
node_name = node["name"]
|
|
255
|
+
runtime_modules = self.modules[node_name]
|
|
256
|
+
continuum[node_name] = runtime_modules.factor((e_low + e_high) / 2)
|
|
383
257
|
|
|
384
|
-
|
|
258
|
+
else:
|
|
259
|
+
pass
|
|
385
260
|
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
) -> SpectralModel:
|
|
389
|
-
"""
|
|
390
|
-
This function operate a composition between the operation graph of two models
|
|
391
|
-
1) It fuses the two graphs using which joins at the 'out' nodes
|
|
392
|
-
2) It relabels the 'out' node with a unique identifier and labels it with the operation
|
|
393
|
-
3) It links the operation to a new 'out' node
|
|
394
|
-
"""
|
|
261
|
+
## Propagate the absorption for each branch
|
|
262
|
+
root_nodes = self.root_nodes
|
|
395
263
|
|
|
396
|
-
|
|
397
|
-
# and add the operation node by overwriting the 'out' node
|
|
398
|
-
node_id = str(uuid4())
|
|
399
|
-
graph = nx.relabel_nodes(nx.compose(self.raw_graph, other.raw_graph), {"out": node_id})
|
|
400
|
-
nx.set_node_attributes(graph, {node_id: "operation"}, "type")
|
|
401
|
-
nx.set_node_attributes(graph, {node_id: operation}, "operation_type")
|
|
402
|
-
nx.set_node_attributes(graph, {node_id: function}, "function")
|
|
403
|
-
nx.set_node_attributes(graph, {node_id: name}, "operation_label")
|
|
404
|
-
|
|
405
|
-
# Merge label dictionaries
|
|
406
|
-
labels = self.labels | other.labels
|
|
407
|
-
labels[node_id] = operation
|
|
408
|
-
|
|
409
|
-
# Now add the output node and link it to the operation node
|
|
410
|
-
graph.add_node("out", type="out")
|
|
411
|
-
graph.add_edge(node_id, "out")
|
|
264
|
+
branches = {}
|
|
412
265
|
|
|
413
|
-
|
|
414
|
-
|
|
266
|
+
for root_node_id in root_nodes:
|
|
267
|
+
root_node_name = self.graph.nodes[root_node_id].get("name")
|
|
268
|
+
root_continuum = continuum[root_node_name]
|
|
415
269
|
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
graph,
|
|
419
|
-
{node: longest_path - nx.shortest_path_length(graph, node, "out")},
|
|
420
|
-
"depth",
|
|
421
|
-
)
|
|
270
|
+
path = nx.shortest_path(self.graph, source=root_node_id, target="out")
|
|
271
|
+
multiplicative_components = []
|
|
422
272
|
|
|
423
|
-
|
|
273
|
+
# Search all multiplicative components connected to this node
|
|
274
|
+
# and apply them at mean energy
|
|
275
|
+
for node_id in path[::-1]:
|
|
276
|
+
multiplicative_components.extend(
|
|
277
|
+
[node_id for node_id in self._find_multiplicative_components(node_id)]
|
|
278
|
+
)
|
|
424
279
|
|
|
425
|
-
|
|
426
|
-
return self.compose(other, operation="add", function=lambda x, y: x + y, name="+")
|
|
280
|
+
branch = ""
|
|
427
281
|
|
|
428
|
-
|
|
429
|
-
|
|
282
|
+
for multiplicative_node_id in multiplicative_components:
|
|
283
|
+
multiplicative_node_name = self.graph.nodes[multiplicative_node_id].get("name")
|
|
284
|
+
root_continuum *= continuum[multiplicative_node_name]
|
|
285
|
+
branch += f"{multiplicative_node_name}*"
|
|
430
286
|
|
|
431
|
-
|
|
432
|
-
|
|
287
|
+
branch += f"{root_node_name}"
|
|
288
|
+
branches[branch] = root_continuum
|
|
433
289
|
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
if attributes["type"] == "component":
|
|
437
|
-
name, number = attributes["name"].split("_")
|
|
438
|
-
mermaid_code += f' {node}("{name.capitalize()} ({number})")\n'
|
|
290
|
+
if return_branches:
|
|
291
|
+
return branches
|
|
439
292
|
|
|
440
|
-
|
|
441
|
-
if attributes["operation_type"] == "add":
|
|
442
|
-
mermaid_code += f" {node}{{+}}\n"
|
|
293
|
+
return sum(branches.values())
|
|
443
294
|
|
|
444
|
-
|
|
445
|
-
|
|
295
|
+
def to_mermaid(self, file: str | None = None):
|
|
296
|
+
"""
|
|
297
|
+
This method returns the mermaid representation of the model.
|
|
446
298
|
|
|
447
|
-
|
|
448
|
-
|
|
299
|
+
Parameters:
|
|
300
|
+
file : The file to write the mermaid representation to.
|
|
449
301
|
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
302
|
+
Returns:
|
|
303
|
+
A string containing the mermaid representation of the model.
|
|
304
|
+
"""
|
|
305
|
+
return export_to_mermaid(self.graph, file)
|
|
453
306
|
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
f.write(mermaid_code)
|
|
307
|
+
@partial(jax.jit, static_argnums=0, static_argnames=("n_points", "split_branches"))
|
|
308
|
+
def photon_flux(self, params, e_low, e_high, n_points=2, split_branches=False):
|
|
309
|
+
r"""
|
|
310
|
+
Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model.
|
|
459
311
|
|
|
460
|
-
|
|
461
|
-
|
|
312
|
+
$$ \Phi_{\text{photon}}\left(E_\min, ~E_\max\right) =
|
|
313
|
+
\int _{E_\min}^{E_\max}\text{d}E ~ \mathcal{M}\left( E \right)
|
|
314
|
+
\quad \left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$$
|
|
462
315
|
|
|
463
|
-
|
|
316
|
+
Parameters:
|
|
317
|
+
params : The parameters of the model.
|
|
318
|
+
e_low : The lower bound of the energy bins.
|
|
319
|
+
e_high : The upper bound of the energy bins.
|
|
320
|
+
n_points : The number of points used to integrate the model in each bin.
|
|
464
321
|
|
|
465
|
-
|
|
322
|
+
!!! info
|
|
323
|
+
This method is internally used in the inference process and should not be used directly. See
|
|
324
|
+
[`photon_flux`][jaxspec.analysis.results.FitResult.photon_flux] to compute
|
|
325
|
+
the photon flux associated with a set of fitted parameters in a
|
|
326
|
+
[`FitResult`][jaxspec.analysis.results.FitResult]
|
|
327
|
+
instead.
|
|
328
|
+
"""
|
|
466
329
|
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
nx.draw_networkx_edges(self.graph, pos, width=1.0)
|
|
330
|
+
graphdef, state = nnx.split(self)
|
|
331
|
+
state = set_parameters(params, state)
|
|
470
332
|
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
labels=nx.get_node_attributes(self.graph, "name"),
|
|
475
|
-
font_size=12,
|
|
476
|
-
font_color="black",
|
|
477
|
-
bbox={"fc": "tab:red", "boxstyle": "round", "pad": 0.3},
|
|
478
|
-
)
|
|
479
|
-
nx.draw_networkx_labels(
|
|
480
|
-
self.graph,
|
|
481
|
-
pos,
|
|
482
|
-
labels=nx.get_node_attributes(self.graph, "operation_label"),
|
|
483
|
-
font_size=12,
|
|
484
|
-
font_color="black",
|
|
485
|
-
bbox={"fc": "tab:blue", "boxstyle": "circle", "pad": 0.3},
|
|
486
|
-
)
|
|
333
|
+
return nnx.call((graphdef, state)).turbo_flux(
|
|
334
|
+
e_low, e_high, n_points=n_points, return_branches=split_branches
|
|
335
|
+
)[0]
|
|
487
336
|
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
337
|
+
@partial(jax.jit, static_argnums=0, static_argnames="n_points")
|
|
338
|
+
def energy_flux(self, params, e_low, e_high, n_points=2):
|
|
339
|
+
r"""
|
|
340
|
+
Compute the expected energy flux between $E_\min$ and $E_\max$ by integrating the model.
|
|
492
341
|
|
|
342
|
+
$$ \Phi_{\text{energy}}\left(E_\min, ~E_\max\right) =
|
|
343
|
+
\int _{E_\min}^{E_\max}\text{d}E ~ E ~ \mathcal{M}\left( E \right)
|
|
344
|
+
\quad \left[\frac{\text{keV}}{\text{cm}^2\text{s}}\right]$$
|
|
493
345
|
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
346
|
+
Parameters:
|
|
347
|
+
params : The parameters of the model.
|
|
348
|
+
e_low : The lower bound of the energy bins.
|
|
349
|
+
e_high : The upper bound of the energy bins.
|
|
350
|
+
n_points : The number of points used to integrate the model in each bin.
|
|
499
351
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
352
|
+
!!! info
|
|
353
|
+
This method is internally used in the inference process and should not be used directly. See
|
|
354
|
+
[`energy_flux`](/references/results/#jaxspec.analysis.results.FitResult.energy_flux) to compute
|
|
355
|
+
the energy flux associated with a set of fitted parameters in a
|
|
356
|
+
[`FitResult`](/references/results/#jaxspec.analysis.results.FitResult)
|
|
357
|
+
instead.
|
|
504
358
|
"""
|
|
505
359
|
|
|
506
|
-
|
|
507
|
-
|
|
360
|
+
graphdef, state = nnx.split(self)
|
|
361
|
+
state = set_parameters(params, state)
|
|
508
362
|
|
|
509
|
-
|
|
510
|
-
|
|
363
|
+
return nnx.call((graphdef, state)).turbo_flux(
|
|
364
|
+
e_low, e_high, n_points=n_points, energy_flux=True
|
|
365
|
+
)[0]
|
|
511
366
|
|
|
512
367
|
|
|
513
|
-
class ModelComponent(
|
|
368
|
+
class ModelComponent(nnx.Module, Composable, ABC):
|
|
514
369
|
"""
|
|
515
370
|
Abstract class for model components
|
|
516
371
|
"""
|
|
517
372
|
|
|
518
|
-
|
|
373
|
+
...
|
|
519
374
|
|
|
520
|
-
def __init__(self, *args, **kwargs):
|
|
521
|
-
super().__init__(*args, **kwargs)
|
|
522
375
|
|
|
523
|
-
|
|
524
|
-
class AdditiveComponent(ModelComponent, ABC):
|
|
376
|
+
class AdditiveComponent(ModelComponent):
|
|
525
377
|
type = "additive"
|
|
526
378
|
|
|
527
379
|
def continuum(self, energy):
|
|
528
|
-
"""
|
|
529
|
-
|
|
530
|
-
By default, this is set to 0, which means that the model has no continuum.
|
|
531
|
-
This should be overloaded by the user if the model has a continuum.
|
|
532
|
-
"""
|
|
380
|
+
r"""
|
|
381
|
+
Compute the continuum of the component.
|
|
533
382
|
|
|
383
|
+
Parameters:
|
|
384
|
+
energy : The energy at which to compute the continuum.
|
|
385
|
+
"""
|
|
534
386
|
return jnp.zeros_like(energy)
|
|
535
387
|
|
|
536
|
-
def
|
|
537
|
-
"""
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
388
|
+
def integrated_continuum(self, e_low, e_high):
|
|
389
|
+
r"""
|
|
390
|
+
Compute the integrated continuum between $E_\min$ and $E_\max$.
|
|
391
|
+
|
|
392
|
+
Parameters:
|
|
393
|
+
e_low: Lower bound of the energy bin.
|
|
394
|
+
e_high: Upper bound of the energy bin.
|
|
541
395
|
"""
|
|
396
|
+
return jnp.zeros_like((e_low + e_high) / 2)
|
|
542
397
|
|
|
543
|
-
|
|
398
|
+
def _photon_flux(self, e_low, e_high, n_points=2):
|
|
399
|
+
energy = jnp.linspace(e_low, e_high, n_points, axis=-1)
|
|
400
|
+
continuum = self.continuum(energy)
|
|
401
|
+
integrated_continuum = self.integrated_continuum(e_low, e_high)
|
|
544
402
|
|
|
545
|
-
|
|
546
|
-
def integral(self, e_min, e_max):
|
|
547
|
-
r"""
|
|
548
|
-
Method for integrating an additive model between two energies. It relies on
|
|
549
|
-
double exponential quadrature for finite intervals to compute an approximation
|
|
550
|
-
of the integral of a model.
|
|
403
|
+
return jsp.integrate.trapezoid(continuum, energy, axis=-1) + integrated_continuum
|
|
551
404
|
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
* $Tanh-sinh quadrature <https://en.wikipedia.org/wiki/Tanh-sinh_quadrature>$_ from Wikipedia
|
|
405
|
+
def _energy_flux(self, e_low, e_high, n_points=2):
|
|
406
|
+
energy = jnp.linspace(e_low, e_high, n_points, axis=-1)
|
|
407
|
+
continuum = self.continuum(energy)
|
|
408
|
+
integrated_continuum = self.integrated_continuum(e_low, e_high)
|
|
557
409
|
|
|
558
|
-
|
|
410
|
+
return jsp.integrate.trapezoid(
|
|
411
|
+
continuum * energy**2, jnp.log(energy), axis=-1
|
|
412
|
+
) + integrated_continuum * (e_high - e_low)
|
|
559
413
|
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
x = (e_max - e_min) / 2 * phi + (e_max + e_min) / 2
|
|
566
|
-
dx = (e_max - e_min) / 2 * dphi
|
|
414
|
+
@partial(jax.jit, static_argnums=0, static_argnames="n_points")
|
|
415
|
+
def photon_flux(self, params, e_low, e_high, n_points=2):
|
|
416
|
+
return SpectralModel.from_component(self).photon_flux(
|
|
417
|
+
params, e_low, e_high, n_points=n_points
|
|
418
|
+
)
|
|
567
419
|
|
|
568
|
-
|
|
569
|
-
|
|
420
|
+
@partial(jax.jit, static_argnums=0, static_argnames="n_points")
|
|
421
|
+
def energy_flux(self, params, e_low, e_high, n_points=2):
|
|
422
|
+
return SpectralModel.from_component(self).energy_flux(
|
|
423
|
+
params, e_low, e_high, n_points=n_points
|
|
424
|
+
)
|
|
570
425
|
|
|
571
426
|
|
|
572
|
-
class MultiplicativeComponent(ModelComponent
|
|
427
|
+
class MultiplicativeComponent(ModelComponent):
|
|
573
428
|
type = "multiplicative"
|
|
574
429
|
|
|
575
|
-
|
|
576
|
-
|
|
430
|
+
def factor(self, energy):
|
|
431
|
+
"""
|
|
432
|
+
Absorption factor applied for a given energy
|
|
433
|
+
|
|
434
|
+
Parameters:
|
|
435
|
+
energy : The energy at which to compute the factor.
|
|
436
|
+
"""
|
|
437
|
+
return jnp.ones_like(energy)
|