jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxspec/model/abc.py CHANGED
@@ -1,37 +1,86 @@
1
1
  from __future__ import annotations
2
2
 
3
- from abc import ABC, abstractmethod
3
+ import operator
4
+
5
+ from abc import ABC
6
+ from functools import partial
4
7
  from uuid import uuid4
5
8
 
6
- import haiku as hk
9
+ import flax.nnx as nnx
7
10
  import jax
8
11
  import jax.numpy as jnp
12
+ import jax.scipy as jsp
9
13
  import networkx as nx
10
- import rich
11
14
 
12
- from haiku._src import base
13
- from jax.scipy.integrate import trapezoid
14
- from rich.table import Table
15
15
  from simpleeval import simple_eval
16
16
 
17
+ from jaxspec.util.typing import PriorDictType
18
+
19
+ from ._graph_util import compose, export_to_mermaid
20
+
21
+
22
+ def set_parameters(params: PriorDictType, state: nnx.State) -> nnx.State:
23
+ """
24
+ Set the parameters of a spectral model using `nnx` routines.
25
+
26
+ Parameters:
27
+ params: Dictionary of parameters to set.
28
+ state: The `nnx.State` of the spectral model whose parameters are updated.
29
+
30
+ Returns:
31
+ The `nnx.State` of the spectral model, updated with the newly set parameters.
32
+ """
33
+
34
+ state_dict = state.to_pure_dict() # haiku-like 2 level dictionary
35
+
36
+ for key, value in params.items():
37
+ # Split the key to extract the module name and parameter name
38
+ module_name, param_name = key.rsplit("_", 1)
39
+ state_dict["modules"][module_name][param_name] = value
40
+
41
+ state.replace_by_pure_dict(state_dict)
42
+
43
+ return state
17
44
 
18
- class SpectralModel:
45
+
46
+ class Composable(ABC):
19
47
  """
20
- This class is supposed to handle the composition of models through basic
21
- operations, and allows tracking of the operation graph and individual parameters.
48
+ Defines the set of operations between model components and spectral models
22
49
  """
23
50
 
24
- raw_graph: nx.DiGraph
25
- graph: nx.DiGraph
26
- labels: dict[str, str]
27
- n_parameters: int
51
+ def sanitize_inputs(self, other):
52
+ if isinstance(self, ModelComponent):
53
+ model_1 = SpectralModel.from_component(self)
54
+ else:
55
+ model_1 = self
56
+
57
+ if isinstance(other, ModelComponent):
58
+ model_2 = SpectralModel.from_component(other)
59
+ else:
60
+ model_2 = other
28
61
 
29
- def __init__(self, internal_graph, labels):
30
- self.raw_graph = internal_graph
31
- self.labels = labels
32
- self.graph = self.build_namespace()
62
+ return model_1, model_2
33
63
 
34
- self.n_parameters = hk.data_structures.tree_size(self.params)
64
+ def __add__(self, other):
65
+ model_1, model_2 = self.sanitize_inputs(other)
66
+ return model_1.compose(model_2, operation="add", operation_func=operator.add)
67
+
68
+ def __mul__(self, other):
69
+ model_1, model_2 = self.sanitize_inputs(other)
70
+ return model_1.compose(model_2, operation="mul", operation_func=operator.mul)
71
+
72
+
73
+ class SpectralModel(nnx.Module, Composable):
74
+ # graph: nx.DiGraph = eqx.field(static=True)
75
+ # modules: dict
76
+
77
+ def __init__(self, graph: nx.DiGraph):
78
+ self.graph = graph
79
+ self.modules = {}
80
+
81
+ for node, data in self.graph.nodes(data=True):
82
+ if "component" in data["type"]:
83
+ self.modules[data["name"]] = data["component"] # (**data['kwargs'])
35
84
 
36
85
  @classmethod
37
86
  def from_string(cls, string: str) -> SpectralModel:
@@ -67,10 +116,10 @@ class SpectralModel:
67
116
  >>> model.to_string()
68
117
  "Tbabs()*(Powerlaw() + Blackbody())"
69
118
  """
70
-
71
119
  return str(self)
72
120
 
73
- def __str__(self) -> SpectralModel:
121
+ """
122
+ def __str__(self) -> str:
74
123
  def build_expression(node_id):
75
124
  node = self.graph.nodes[node_id]
76
125
  if node["type"] == "component":
@@ -92,485 +141,297 @@ class SpectralModel:
92
141
  predecessors = list(self.graph.predecessors(node_id))
93
142
  return build_expression(predecessors[0])
94
143
 
95
- return build_expression("out")[1:-1]
96
-
97
- @property
98
- def transformed_func_photon(self):
99
- def func_to_transform(e_low, e_high, n_points=2):
100
- return self.flux(e_low, e_high, n_points=n_points, energy_flux=False)
101
-
102
- return hk.without_apply_rng(hk.transform(func_to_transform))
103
-
104
- @property
105
- def transformed_func_energy(self):
106
- def func_to_transform(e_low, e_high, n_points=2):
107
- return self.flux(e_low, e_high, n_points=n_points, energy_flux=True)
108
-
109
- return hk.without_apply_rng(hk.transform(func_to_transform))
110
-
111
- @property
112
- def params(self):
113
- return self.transformed_func_photon.init(None, jnp.ones(10), jnp.ones(10))
114
-
115
- def __rich_repr__(self):
116
- table = Table(title=str(self))
117
-
118
- table.add_column("Component", justify="right", style="bold", no_wrap=True)
119
- table.add_column("Parameter")
120
-
121
- params = self.params
122
-
123
- for component in params.keys():
124
- once = True
125
-
126
- for parameters in params[component].keys():
127
- table.add_row(component if once else "", parameters)
128
- once = False
129
-
130
- return table
131
-
132
- def __repr_html_(self):
133
- return self.__rich_repr__()
134
-
135
- def __repr__(self):
136
- if not base.frame_stack:
137
- rich.print(self.__rich_repr__())
138
- return ""
139
-
140
- def photon_flux(self, params, e_low, e_high, n_points=2):
141
- r"""
142
- Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model.
143
-
144
- $$ \Phi_{\text{photon}}\left(E_\min, ~E_\max\right) =
145
- \int _{E_\min}^{E_\max}\text{d}E ~ \mathcal{M}\left( E \right)
146
- \quad \left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$$
147
-
148
- Parameters:
149
- params : The parameters of the model.
150
- e_low : The lower bound of the energy bins.
151
- e_high : The upper bound of the energy bins.
152
- n_points : The number of points used to integrate the model in each bin.
153
-
154
- !!! info
155
- This method is internally used in the inference process and should not be used directly. See
156
- [`photon_flux`][jaxspec.analysis.results.FitResult.photon_flux] to compute
157
- the photon flux associated with a set of fitted parameters in a
158
- [`FitResult`][jaxspec.analysis.results.FitResult]
159
- instead.
160
- """
161
-
162
- params = jax.tree_map(lambda x: jnp.asarray(x), params)
163
- e_low = jnp.asarray(e_low)
164
- e_high = jnp.asarray(e_high)
165
-
166
- return self.transformed_func_photon.apply(params, e_low, e_high, n_points=n_points)
167
-
168
- def energy_flux(self, params, e_low, e_high, n_points=2):
169
- r"""
170
- Compute the expected energy flux between $E_\min$ and $E_\max$ by integrating the model.
171
-
172
- $$ \Phi_{\text{energy}}\left(E_\min, ~E_\max\right) =
173
- \int _{E_\min}^{E_\max}\text{d}E ~ E ~ \mathcal{M}\left( E \right)
174
- \quad \left[\frac{\text{keV}}{\text{cm}^2\text{s}}\right]$$
175
-
176
- Parameters:
177
- params : The parameters of the model.
178
- e_low : The lower bound of the energy bins.
179
- e_high : The upper bound of the energy bins.
180
- n_points : The number of points used to integrate the model in each bin.
181
-
182
- !!! info
183
- This method is internally used in the inference process and should not be used directly. See
184
- [`energy_flux`](/references/results/#jaxspec.analysis.results.FitResult.energy_flux) to compute
185
- the energy flux associated with a set of fitted parameters in a
186
- [`FitResult`](/references/results/#jaxspec.analysis.results.FitResult)
187
- instead.
188
- """
189
-
190
- params = jax.tree_map(lambda x: jnp.asarray(x), params)
191
- e_low = jnp.asarray(e_low)
192
- e_high = jnp.asarray(e_high)
193
-
194
- return self.transformed_func_energy.apply(params, e_low, e_high, n_points=n_points)
195
-
196
- def build_namespace(self):
197
- """
198
- This method build a namespace for the model components, to avoid name collision
199
- """
200
-
201
- name_space = []
202
- new_graph = self.raw_graph.copy()
203
-
204
- for node_id in nx.dag.topological_sort(new_graph):
205
- node = new_graph.nodes[node_id]
206
-
207
- if node and node["type"] == "component":
208
- name_space.append(node["name"])
209
- n = name_space.count(node["name"])
210
- nx.set_node_attributes(new_graph, {node_id: name_space[-1] + f"_{n}"}, "name")
211
-
212
- return new_graph
144
+ return "This must be changed" # build_expression("out")[1:-1]
145
+ """
213
146
 
214
- def flux(self, e_low, e_high, energy_flux=False, n_points=2):
147
+ def compose(self, other, operation=None, operation_func=None):
215
148
  """
216
- This method return the expected counts between e_low and e_high by integrating the model.
217
- It contains most of the "usine à gaz" which makes jaxspec works.
218
- It evaluates the graph of operations and returns the result.
219
- It should be transformed using haiku.
149
+ This function performs a composition between the operation graphs of two models
150
+ 1) It fuses the two graphs, joining them at their 'out' nodes, and changes component names to unique identifiers
151
+ 2) It relabels the 'out' node with a unique identifier and labels it with the operation
152
+ 3) It links the operation to a new 'out' node
220
153
  """
221
154
 
222
- # TODO : enable interpolation and integration with more than 2 points for the continuum
223
-
224
- if n_points == 2:
225
- energies = jnp.hstack((e_low, e_high[-1]))
226
- energies_to_integrate = jnp.stack((e_low, e_high))
227
-
228
- else:
229
- energies_to_integrate = jnp.linspace(e_low, e_high, n_points)
230
- energies = energies_to_integrate
231
-
232
- fine_structures_flux = jnp.zeros_like(e_low)
233
- runtime_modules = {}
234
- continuum = {}
235
-
236
- # Iterate through the graph in topological order and
237
- # compute the continuum contribution for each component
238
-
239
- for node_id in nx.dag.topological_sort(self.graph):
240
- node = self.graph.nodes[node_id]
241
-
242
- # Instantiate the haiku modules
243
- if node and node["type"] == "component":
244
- runtime_modules[node_id] = node["component"](name=node["name"], **node["kwargs"])
245
- continuum[node_id] = runtime_modules[node_id].continuum(energies)
246
-
247
- elif node and node["type"] == "operation":
248
- component_1 = list(self.graph.in_edges(node_id))[0][0] # noqa: RUF015
249
- component_2 = list(self.graph.in_edges(node_id))[1][0]
250
- continuum[node_id] = node["function"](
251
- continuum[component_1], continuum[component_2]
252
- )
253
-
254
- if n_points == 2:
255
- flux_1D = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
256
- flux = jnp.stack((flux_1D[:-1], flux_1D[1:]))
257
-
258
- else:
259
- flux = continuum[list(self.graph.in_edges("out"))[0][0]] # noqa: RUF015
260
-
261
- if energy_flux:
262
- continuum_flux = trapezoid(
263
- flux * energies_to_integrate**2,
264
- x=jnp.log(energies_to_integrate),
265
- axis=0,
266
- )
267
-
268
- else:
269
- continuum_flux = trapezoid(
270
- flux * energies_to_integrate, x=jnp.log(energies_to_integrate), axis=0
271
- )
272
-
273
- # Iterate from the root nodes to the output node and
274
- # compute the fine structure contribution for each component
275
-
276
- root_nodes = [
277
- node_id
278
- for node_id, in_degree in self.graph.in_degree(self.graph.nodes)
279
- if in_degree == 0 and self.graph.nodes[node_id].get("component_type") == "additive"
280
- ]
281
-
282
- for root_node_id in root_nodes:
283
- path = nx.shortest_path(self.graph, source=root_node_id, target="out")
284
- nodes_id_in_path = [node_id for node_id in path]
285
-
286
- flux_from_component, mean_energy = runtime_modules[root_node_id].emission_lines(
287
- e_low, e_high
288
- )
289
-
290
- multiplicative_nodes = []
155
+ composed_graph = compose(
156
+ self.graph, other.graph, operation=operation, operation_func=operation_func
157
+ )
291
158
 
292
- # Search all multiplicative components connected to this node
293
- # and apply them at mean energy
294
- for node_id in nodes_id_in_path[::-1]:
295
- multiplicative_nodes.extend(
296
- [node_id for node_id in self.find_multiplicative_components(node_id)]
297
- )
159
+ return SpectralModel(composed_graph)
298
160
 
299
- for mul_node in multiplicative_nodes:
300
- flux_from_component *= runtime_modules[mul_node].continuum(mean_energy)
161
+ @classmethod
162
+ def from_component(cls, component):
163
+ node_id = str(uuid4())
164
+ graph = nx.DiGraph()
301
165
 
302
- if energy_flux:
303
- fine_structures_flux += trapezoid(
304
- flux_from_component * energies_to_integrate,
305
- x=jnp.log(energies_to_integrate),
306
- axis=0,
307
- )
166
+ node_properties = {
167
+ "type": f"{component.type}_component",
168
+ "name": f"{component.__class__.__name__}_1".lower(),
169
+ "component": component,
170
+ "depth": 0,
171
+ }
308
172
 
309
- else:
310
- fine_structures_flux += flux_from_component
173
+ graph.add_node(node_id, **node_properties)
174
+ graph.add_node("out", type="out", depth=1)
175
+ graph.add_edge(node_id, "out")
311
176
 
312
- return continuum_flux + fine_structures_flux
177
+ return cls(graph)
313
178
 
314
- def find_multiplicative_components(self, node_id):
179
+ def _find_multiplicative_components(self, node_id):
315
180
  """
316
181
  Recursively finds all the multiplicative components connected to the node with the given ID.
317
182
  """
318
183
  node = self.graph.nodes[node_id]
319
184
  multiplicative_nodes = []
320
185
 
321
- if node.get("operation_type") == "mul":
186
+ if node.get("type") == "mul_operation":
322
187
  # Recursively find all the multiplicative components using the predecessors
323
188
  predecessors = self.graph.pred[node_id]
324
189
  for node_id in predecessors:
325
- if self.graph.nodes[node_id].get("component_type") == "multiplicative":
190
+ if "multiplicative_component" == self.graph.nodes[node_id].get("type"):
326
191
  multiplicative_nodes.append(node_id)
327
- elif self.graph.nodes[node_id].get("operation_type") == "mul":
328
- multiplicative_nodes.extend(self.find_multiplicative_components(node_id))
192
+ elif "mul_operation" == self.graph.nodes[node_id].get("type"):
193
+ multiplicative_nodes.extend(self._find_multiplicative_components(node_id))
329
194
 
330
195
  return multiplicative_nodes
331
196
 
332
- def __call__(self, pars, e_low, e_high, **kwargs):
333
- return self.photon_flux(pars, e_low, e_high, **kwargs)
197
+ @property
198
+ def root_nodes(self) -> list[str]:
199
+ return [
200
+ node_id
201
+ for node_id, in_degree in self.graph.in_degree(self.graph.nodes)
202
+ if in_degree == 0 and ("additive" in self.graph.nodes[node_id].get("type"))
203
+ ]
334
204
 
335
- @classmethod
336
- def from_component(cls, component, **kwargs) -> SpectralModel:
337
- """
338
- Build a model from a single component
339
- """
205
+ @property
206
+ def branches(self) -> list[str]:
207
+ branches = []
340
208
 
341
- graph = nx.DiGraph()
209
+ for root_node_id in self.root_nodes:
210
+ root_node_name = self.graph.nodes[root_node_id].get("name")
211
+ path = nx.shortest_path(self.graph, source=root_node_id, target="out")
212
+ multiplicative_components = []
342
213
 
343
- # Add the component node
344
- # Random static node id to keep it trackable in the graph
345
- node_id = str(uuid4())
214
+ # Search all multiplicative components connected to this node
215
+ # and apply them at mean energy
216
+ for node_id in path[::-1]:
217
+ multiplicative_components.extend(
218
+ [node_id for node_id in self._find_multiplicative_components(node_id)]
219
+ )
346
220
 
347
- if component.type == "additive":
221
+ branch = ""
348
222
 
349
- def lam_func(e):
350
- return (
351
- component(**kwargs).continuum(e)
352
- + component(**kwargs).emission_lines(e, e + 1)[0]
353
- )
223
+ for multiplicative_node_id in multiplicative_components:
224
+ multiplicative_node_name = self.graph.nodes[multiplicative_node_id].get("name")
225
+ branch += f"{multiplicative_node_name}*"
354
226
 
355
- elif component.type == "multiplicative":
227
+ branch += f"{root_node_name}"
228
+ branches.append(branch)
356
229
 
357
- def lam_func(e):
358
- return component().continuum(e)
230
+ return branches
359
231
 
360
- else:
232
+ def turbo_flux(self, e_low, e_high, energy_flux=False, n_points=2, return_branches=False):
233
+ continuum = {}
361
234
 
362
- def lam_func(e):
363
- return print("Some components are not working at this stage")
235
+ ## Evaluate the expected contribution for each component
236
+ for node_id in nx.dag.topological_sort(self.graph):
237
+ node = self.graph.nodes[node_id]
364
238
 
365
- node_properties = {
366
- "type": "component",
367
- "component_type": component.type,
368
- "name": component.__name__.lower(),
369
- "component": component,
370
- # "params": hk.transform(lam_func).init(None, jnp.ones(1)),
371
- "fine_structure": False,
372
- "kwargs": kwargs,
373
- "depth": 0,
374
- }
239
+ if node["type"] == "additive_component":
240
+ node_name = node["name"]
241
+ runtime_modules = self.modules[node_name]
375
242
 
376
- graph.add_node(node_id, **node_properties)
243
+ if not energy_flux:
244
+ continuum[node_name] = runtime_modules._photon_flux(
245
+ e_low, e_high, n_points=n_points
246
+ )
377
247
 
378
- # Add the output node
379
- labels = {node_id: component.__name__.lower(), "out": "out"}
248
+ else:
249
+ continuum[node_name] = runtime_modules._energy_flux(
250
+ e_low, e_high, n_points=n_points
251
+ )
380
252
 
381
- graph.add_node("out", type="out", depth=1)
382
- graph.add_edge(node_id, "out")
253
+ elif node["type"] == "multiplicative_component":
254
+ node_name = node["name"]
255
+ runtime_modules = self.modules[node_name]
256
+ continuum[node_name] = runtime_modules.factor((e_low + e_high) / 2)
383
257
 
384
- return cls(graph, labels)
258
+ else:
259
+ pass
385
260
 
386
- def compose(
387
- self, other: SpectralModel, operation=None, function=None, name=None
388
- ) -> SpectralModel:
389
- """
390
- This function operate a composition between the operation graph of two models
391
- 1) It fuses the two graphs using which joins at the 'out' nodes
392
- 2) It relabels the 'out' node with a unique identifier and labels it with the operation
393
- 3) It links the operation to a new 'out' node
394
- """
261
+ ## Propagate the absorption for each branch
262
+ root_nodes = self.root_nodes
395
263
 
396
- # Compose the two graphs with their output as common node
397
- # and add the operation node by overwriting the 'out' node
398
- node_id = str(uuid4())
399
- graph = nx.relabel_nodes(nx.compose(self.raw_graph, other.raw_graph), {"out": node_id})
400
- nx.set_node_attributes(graph, {node_id: "operation"}, "type")
401
- nx.set_node_attributes(graph, {node_id: operation}, "operation_type")
402
- nx.set_node_attributes(graph, {node_id: function}, "function")
403
- nx.set_node_attributes(graph, {node_id: name}, "operation_label")
404
-
405
- # Merge label dictionaries
406
- labels = self.labels | other.labels
407
- labels[node_id] = operation
408
-
409
- # Now add the output node and link it to the operation node
410
- graph.add_node("out", type="out")
411
- graph.add_edge(node_id, "out")
264
+ branches = {}
412
265
 
413
- # Compute the new depth of each node
414
- longest_path = nx.dag_longest_path_length(graph)
266
+ for root_node_id in root_nodes:
267
+ root_node_name = self.graph.nodes[root_node_id].get("name")
268
+ root_continuum = continuum[root_node_name]
415
269
 
416
- for node in graph.nodes:
417
- nx.set_node_attributes(
418
- graph,
419
- {node: longest_path - nx.shortest_path_length(graph, node, "out")},
420
- "depth",
421
- )
270
+ path = nx.shortest_path(self.graph, source=root_node_id, target="out")
271
+ multiplicative_components = []
422
272
 
423
- return SpectralModel(graph, labels)
273
+ # Search all multiplicative components connected to this node
274
+ # and apply them at mean energy
275
+ for node_id in path[::-1]:
276
+ multiplicative_components.extend(
277
+ [node_id for node_id in self._find_multiplicative_components(node_id)]
278
+ )
424
279
 
425
- def __add__(self, other: SpectralModel) -> SpectralModel:
426
- return self.compose(other, operation="add", function=lambda x, y: x + y, name="+")
280
+ branch = ""
427
281
 
428
- def __mul__(self, other: SpectralModel) -> SpectralModel:
429
- return self.compose(other, operation="mul", function=lambda x, y: x * y, name=r"*")
282
+ for multiplicative_node_id in multiplicative_components:
283
+ multiplicative_node_name = self.graph.nodes[multiplicative_node_id].get("name")
284
+ root_continuum *= continuum[multiplicative_node_name]
285
+ branch += f"{multiplicative_node_name}*"
430
286
 
431
- def export_to_mermaid(self, file=None):
432
- mermaid_code = "graph LR\n" # LR = left to right
287
+ branch += f"{root_node_name}"
288
+ branches[branch] = root_continuum
433
289
 
434
- # Add nodes
435
- for node, attributes in self.graph.nodes(data=True):
436
- if attributes["type"] == "component":
437
- name, number = attributes["name"].split("_")
438
- mermaid_code += f' {node}("{name.capitalize()} ({number})")\n'
290
+ if return_branches:
291
+ return branches
439
292
 
440
- if attributes["type"] == "operation":
441
- if attributes["operation_type"] == "add":
442
- mermaid_code += f" {node}{{+}}\n"
293
+ return sum(branches.values())
443
294
 
444
- if attributes["operation_type"] == "mul":
445
- mermaid_code += f" {node}{{x}}\n"
295
+ def to_mermaid(self, file: str | None = None):
296
+ """
297
+ This method returns the mermaid representation of the model.
446
298
 
447
- if attributes["type"] == "out":
448
- mermaid_code += f' {node}("Output")\n'
299
+ Parameters:
300
+ file : The file to write the mermaid representation to.
449
301
 
450
- # Draw connexion between nodes
451
- for source, target in self.graph.edges():
452
- mermaid_code += f" {source} --> {target}\n"
302
+ Returns:
303
+ A string containing the mermaid representation of the model.
304
+ """
305
+ return export_to_mermaid(self.graph, file)
453
306
 
454
- if file is None:
455
- return mermaid_code
456
- else:
457
- with open(file, "w") as f:
458
- f.write(mermaid_code)
307
+ @partial(jax.jit, static_argnums=0, static_argnames=("n_points", "split_branches"))
308
+ def photon_flux(self, params, e_low, e_high, n_points=2, split_branches=False):
309
+ r"""
310
+ Compute the expected counts between $E_\min$ and $E_\max$ by integrating the model.
459
311
 
460
- def plot(self, figsize=(8, 8)):
461
- import matplotlib.pyplot as plt
312
+ $$ \Phi_{\text{photon}}\left(E_\min, ~E_\max\right) =
313
+ \int _{E_\min}^{E_\max}\text{d}E ~ \mathcal{M}\left( E \right)
314
+ \quad \left[\frac{\text{photons}}{\text{cm}^2\text{s}}\right]$$
462
315
 
463
- plt.figure(figsize=figsize)
316
+ Parameters:
317
+ params : The parameters of the model.
318
+ e_low : The lower bound of the energy bins.
319
+ e_high : The upper bound of the energy bins.
320
+ n_points : The number of points used to integrate the model in each bin.
464
321
 
465
- pos = nx.multipartite_layout(self.graph, subset_key="depth", scale=1)
322
+ !!! info
323
+ This method is internally used in the inference process and should not be used directly. See
324
+ [`photon_flux`][jaxspec.analysis.results.FitResult.photon_flux] to compute
325
+ the photon flux associated with a set of fitted parameters in a
326
+ [`FitResult`][jaxspec.analysis.results.FitResult]
327
+ instead.
328
+ """
466
329
 
467
- nodes_out = [x for x, y in self.graph.nodes(data=True) if y["type"] == "out"]
468
- nx.draw_networkx_nodes(self.graph, pos, nodelist=nodes_out, node_color="tab:green")
469
- nx.draw_networkx_edges(self.graph, pos, width=1.0)
330
+ graphdef, state = nnx.split(self)
331
+ state = set_parameters(params, state)
470
332
 
471
- nx.draw_networkx_labels(
472
- self.graph,
473
- pos,
474
- labels=nx.get_node_attributes(self.graph, "name"),
475
- font_size=12,
476
- font_color="black",
477
- bbox={"fc": "tab:red", "boxstyle": "round", "pad": 0.3},
478
- )
479
- nx.draw_networkx_labels(
480
- self.graph,
481
- pos,
482
- labels=nx.get_node_attributes(self.graph, "operation_label"),
483
- font_size=12,
484
- font_color="black",
485
- bbox={"fc": "tab:blue", "boxstyle": "circle", "pad": 0.3},
486
- )
333
+ return nnx.call((graphdef, state)).turbo_flux(
334
+ e_low, e_high, n_points=n_points, return_branches=split_branches
335
+ )[0]
487
336
 
488
- plt.axis("equal")
489
- plt.axis("off")
490
- plt.tight_layout()
491
- plt.show()
337
+ @partial(jax.jit, static_argnums=0, static_argnames="n_points")
338
+ def energy_flux(self, params, e_low, e_high, n_points=2):
339
+ r"""
340
+ Compute the expected energy flux between $E_\min$ and $E_\max$ by integrating the model.
492
341
 
342
+ $$ \Phi_{\text{energy}}\left(E_\min, ~E_\max\right) =
343
+ \int _{E_\min}^{E_\max}\text{d}E ~ E ~ \mathcal{M}\left( E \right)
344
+ \quad \left[\frac{\text{keV}}{\text{cm}^2\text{s}}\right]$$
493
345
 
494
- class ComponentMetaClass(type(hk.Module)):
495
- """
496
- This metaclass enable the construction of model from components with a simple
497
- syntax while style enabling the components to be used as haiku modules.
498
- """
346
+ Parameters:
347
+ params : The parameters of the model.
348
+ e_low : The lower bound of the energy bins.
349
+ e_high : The upper bound of the energy bins.
350
+ n_points : The number of points used to integrate the model in each bin.
499
351
 
500
- def __call__(self, **kwargs) -> SpectralModel:
501
- """
502
- This method enable to use model components as haiku modules when folded in a haiku transform
503
- function and also to instantiate them as SpectralModel when out of a haiku transform
352
+ !!! info
353
+ This method is internally used in the inference process and should not be used directly. See
354
+ [`energy_flux`](/references/results/#jaxspec.analysis.results.FitResult.energy_flux) to compute
355
+ the energy flux associated with a set of fitted parameters in a
356
+ [`FitResult`](/references/results/#jaxspec.analysis.results.FitResult)
357
+ instead.
504
358
  """
505
359
 
506
- if not base.frame_stack:
507
- return SpectralModel.from_component(self, **kwargs)
360
+ graphdef, state = nnx.split(self)
361
+ state = set_parameters(params, state)
508
362
 
509
- else:
510
- return super().__call__(**kwargs)
363
+ return nnx.call((graphdef, state)).turbo_flux(
364
+ e_low, e_high, n_points=n_points, energy_flux=True
365
+ )[0]
511
366
 
512
367
 
513
- class ModelComponent(hk.Module, ABC, metaclass=ComponentMetaClass):
368
+ class ModelComponent(nnx.Module, Composable, ABC):
514
369
  """
515
370
  Abstract class for model components
516
371
  """
517
372
 
518
- type: str
373
+ ...
519
374
 
520
- def __init__(self, *args, **kwargs):
521
- super().__init__(*args, **kwargs)
522
375
 
523
-
524
- class AdditiveComponent(ModelComponent, ABC):
376
+ class AdditiveComponent(ModelComponent):
525
377
  type = "additive"
526
378
 
527
379
  def continuum(self, energy):
528
- """
529
- Method for computing the continuum associated to the model.
530
- By default, this is set to 0, which means that the model has no continuum.
531
- This should be overloaded by the user if the model has a continuum.
532
- """
380
+ r"""
381
+ Compute the continuum of the component.
533
382
 
383
+ Parameters:
384
+ energy : The energy at which to compute the continuum.
385
+ """
534
386
  return jnp.zeros_like(energy)
535
387
 
536
- def emission_lines(self, e_min, e_max) -> (jax.Array, jax.Array):
537
- """
538
- Method for computing the fine structure of an additive model between two energies.
539
- By default, this is set to 0, which means that the model has no emission lines.
540
- This should be overloaded by the user if the model has a fine structure.
388
+ def integrated_continuum(self, e_low, e_high):
389
+ r"""
390
+ Compute the integrated continuum between $E_\min$ and $E_\max$.
391
+
392
+ Parameters:
393
+ e_low: Lower bound of the energy bin.
394
+ e_high: Upper bound of the energy bin.
541
395
  """
396
+ return jnp.zeros_like((e_low + e_high) / 2)
542
397
 
543
- return jnp.zeros_like(e_min), (e_min + e_max) / 2
398
+ def _photon_flux(self, e_low, e_high, n_points=2):
399
+ energy = jnp.linspace(e_low, e_high, n_points, axis=-1)
400
+ continuum = self.continuum(energy)
401
+ integrated_continuum = self.integrated_continuum(e_low, e_high)
544
402
 
545
- '''
546
- def integral(self, e_min, e_max):
547
- r"""
548
- Method for integrating an additive model between two energies. It relies on
549
- double exponential quadrature for finite intervals to compute an approximation
550
- of the integral of a model.
403
+ return jsp.integrate.trapezoid(continuum, energy, axis=-1) + integrated_continuum
551
404
 
552
- references
553
- ----------
554
- * $Takahasi and Mori (1974) <https://ems.press/journals/prims/articles/2686>$_
555
- * $Mori and Sugihara (2001) <https://doi.org/10.1016/S0377-0427(00)00501-X>$_
556
- * $Tanh-sinh quadrature <https://en.wikipedia.org/wiki/Tanh-sinh_quadrature>$_ from Wikipedia
405
+ def _energy_flux(self, e_low, e_high, n_points=2):
406
+ energy = jnp.linspace(e_low, e_high, n_points, axis=-1)
407
+ continuum = self.continuum(energy)
408
+ integrated_continuum = self.integrated_continuum(e_low, e_high)
557
409
 
558
- """
410
+ return jsp.integrate.trapezoid(
411
+ continuum * energy**2, jnp.log(energy), axis=-1
412
+ ) + integrated_continuum * (e_high - e_low)
559
413
 
560
- t = jnp.linspace(-4, 4, 71) # The number of points used is hardcoded and this is not ideal
561
- # Quadrature nodes as defined in reference
562
- phi = jnp.tanh(jnp.pi / 2 * jnp.sinh(t))
563
- dphi = jnp.pi / 2 * jnp.cosh(t) * (1 / jnp.cosh(jnp.pi / 2 * jnp.sinh(t)) ** 2)
564
- # Change of variable to turn the integral from E_min to E_max into an integral from -1 to 1
565
- x = (e_max - e_min) / 2 * phi + (e_max + e_min) / 2
566
- dx = (e_max - e_min) / 2 * dphi
414
+ @partial(jax.jit, static_argnums=0, static_argnames="n_points")
415
+ def photon_flux(self, params, e_low, e_high, n_points=2):
416
+ return SpectralModel.from_component(self).photon_flux(
417
+ params, e_low, e_high, n_points=n_points
418
+ )
567
419
 
568
- return jnp.trapz(self(x) * dx, x=t)
569
- '''
420
+ @partial(jax.jit, static_argnums=0, static_argnames="n_points")
421
+ def energy_flux(self, params, e_low, e_high, n_points=2):
422
+ return SpectralModel.from_component(self).energy_flux(
423
+ params, e_low, e_high, n_points=n_points
424
+ )
570
425
 
571
426
 
572
- class MultiplicativeComponent(ModelComponent, ABC):
427
+ class MultiplicativeComponent(ModelComponent):
573
428
  type = "multiplicative"
574
429
 
575
- @abstractmethod
576
- def continuum(self, energy): ...
430
+ def factor(self, energy):
431
+ """
432
+ Absorption factor applied for a given energy
433
+
434
+ Parameters:
435
+ energy : The energy at which to compute the factor.
436
+ """
437
+ return jnp.ones_like(energy)