jaxspec 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxspec/_fit/__init__.py +0 -0
- jaxspec/_fit/_build_model.py +63 -0
- jaxspec/analysis/_plot.py +166 -7
- jaxspec/analysis/results.py +238 -336
- jaxspec/data/instrument.py +47 -12
- jaxspec/data/obsconf.py +12 -2
- jaxspec/data/observation.py +68 -11
- jaxspec/data/ogip.py +32 -13
- jaxspec/data/util.py +5 -75
- jaxspec/fit.py +101 -140
- jaxspec/model/_graph_util.py +151 -0
- jaxspec/model/abc.py +275 -414
- jaxspec/model/additive.py +276 -289
- jaxspec/model/background.py +94 -87
- jaxspec/model/multiplicative.py +101 -85
- jaxspec/scripts/debug.py +1 -1
- jaxspec/util/__init__.py +0 -45
- jaxspec/util/misc.py +25 -0
- jaxspec/util/typing.py +0 -63
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/METADATA +36 -16
- jaxspec-0.2.0.dist-info/RECORD +34 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/WHEEL +1 -1
- jaxspec/data/grouping.py +0 -23
- jaxspec-0.1.3.dist-info/RECORD +0 -31
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/LICENSE.md +0 -0
- {jaxspec-0.1.3.dist-info → jaxspec-0.2.0.dist-info}/entry_points.txt +0 -0
jaxspec/fit.py
CHANGED
|
@@ -7,126 +7,31 @@ from functools import cached_property
|
|
|
7
7
|
from typing import Literal
|
|
8
8
|
|
|
9
9
|
import arviz as az
|
|
10
|
-
import haiku as hk
|
|
11
10
|
import jax
|
|
12
11
|
import jax.numpy as jnp
|
|
13
12
|
import matplotlib.pyplot as plt
|
|
14
|
-
import numpy as np
|
|
15
13
|
import numpyro
|
|
16
14
|
|
|
17
15
|
from jax import random
|
|
18
|
-
from jax.experimental.sparse import BCOO
|
|
19
16
|
from jax.random import PRNGKey
|
|
20
|
-
from jax.tree_util import tree_map
|
|
21
|
-
from jax.typing import ArrayLike
|
|
22
17
|
from numpyro.contrib.nested_sampling import NestedSampler
|
|
23
|
-
from numpyro.distributions import
|
|
18
|
+
from numpyro.distributions import Poisson, TransformedDistribution
|
|
24
19
|
from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
|
|
25
20
|
from numpyro.infer.inspect import get_model_relations
|
|
26
21
|
from numpyro.infer.reparam import TransformReparam
|
|
27
22
|
from numpyro.infer.util import log_density
|
|
28
23
|
|
|
29
|
-
from .
|
|
24
|
+
from ._fit._build_model import build_prior, forward_model
|
|
25
|
+
from .analysis._plot import (
|
|
26
|
+
_error_bars_for_observed_data,
|
|
27
|
+
_plot_binned_samples_with_error,
|
|
28
|
+
_plot_poisson_data_with_error,
|
|
29
|
+
)
|
|
30
30
|
from .analysis.results import FitResult
|
|
31
31
|
from .data import ObsConfiguration
|
|
32
32
|
from .model.abc import SpectralModel
|
|
33
33
|
from .model.background import BackgroundModel
|
|
34
|
-
from .util.typing import
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def build_prior(prior: PriorDictType, expand_shape: tuple = (), prefix=""):
|
|
38
|
-
"""
|
|
39
|
-
Transform a dictionary of prior distributions into a dictionary of parameters sampled from the prior.
|
|
40
|
-
Must be used within a numpyro model.
|
|
41
|
-
"""
|
|
42
|
-
parameters = dict(hk.data_structures.to_haiku_dict(prior))
|
|
43
|
-
|
|
44
|
-
for i, (m, n, sample) in enumerate(hk.data_structures.traverse(prior)):
|
|
45
|
-
if isinstance(sample, Distribution):
|
|
46
|
-
parameters[m][n] = jnp.ones(expand_shape) * numpyro.sample(f"{prefix}{m}_{n}", sample)
|
|
47
|
-
|
|
48
|
-
elif isinstance(sample, ArrayLike):
|
|
49
|
-
parameters[m][n] = jnp.ones(expand_shape) * sample
|
|
50
|
-
|
|
51
|
-
else:
|
|
52
|
-
raise ValueError(
|
|
53
|
-
f"Invalid prior type {type(sample)} for parameter {prefix}{m}_{n} : {sample}"
|
|
54
|
-
)
|
|
55
|
-
|
|
56
|
-
return parameters
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def build_numpyro_model_for_single_obs(
|
|
60
|
-
obs: ObsConfiguration,
|
|
61
|
-
model: SpectralModel,
|
|
62
|
-
background_model: BackgroundModel,
|
|
63
|
-
name: str = "",
|
|
64
|
-
sparse: bool = False,
|
|
65
|
-
) -> Callable:
|
|
66
|
-
"""
|
|
67
|
-
Build a numpyro model for a given observation and spectral model.
|
|
68
|
-
"""
|
|
69
|
-
|
|
70
|
-
def numpyro_model(prior_params, observed=True):
|
|
71
|
-
# prior_params = build_prior(prior_distributions, name=name)
|
|
72
|
-
transformed_model = hk.without_apply_rng(
|
|
73
|
-
hk.transform(lambda par: CountForwardModel(model, obs, sparse=sparse)(par))
|
|
74
|
-
)
|
|
75
|
-
|
|
76
|
-
if (getattr(obs, "folded_background", None) is not None) and (background_model is not None):
|
|
77
|
-
bkg_countrate = background_model.numpyro_model(
|
|
78
|
-
obs, model, name="bkg_" + name, observed=observed
|
|
79
|
-
)
|
|
80
|
-
elif (getattr(obs, "folded_background", None) is None) and (background_model is not None):
|
|
81
|
-
raise ValueError(
|
|
82
|
-
"Trying to fit a background model but no background is linked to this observation"
|
|
83
|
-
)
|
|
84
|
-
|
|
85
|
-
else:
|
|
86
|
-
bkg_countrate = 0.0
|
|
87
|
-
|
|
88
|
-
obs_model = jax.jit(lambda p: transformed_model.apply(None, p))
|
|
89
|
-
countrate = obs_model(prior_params)
|
|
90
|
-
|
|
91
|
-
# This is the case where we fit a model to a TOTAL spectrum as defined in OGIP standard
|
|
92
|
-
with numpyro.plate("obs_plate_" + name, len(obs.folded_counts)):
|
|
93
|
-
numpyro.sample(
|
|
94
|
-
"obs_" + name,
|
|
95
|
-
Poisson(countrate + bkg_countrate / obs.folded_backratio.data),
|
|
96
|
-
obs=obs.folded_counts.data if observed else None,
|
|
97
|
-
)
|
|
98
|
-
|
|
99
|
-
return numpyro_model
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
class CountForwardModel(hk.Module):
|
|
103
|
-
"""
|
|
104
|
-
A haiku module which allows to build the function that simulates the measured counts
|
|
105
|
-
"""
|
|
106
|
-
|
|
107
|
-
def __init__(self, model: SpectralModel, folding: ObsConfiguration, sparse=False):
|
|
108
|
-
super().__init__()
|
|
109
|
-
self.model = model
|
|
110
|
-
self.energies = jnp.asarray(folding.in_energies)
|
|
111
|
-
|
|
112
|
-
if (
|
|
113
|
-
sparse
|
|
114
|
-
): # folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
|
|
115
|
-
self.transfer_matrix = BCOO.from_scipy_sparse(
|
|
116
|
-
folding.transfer_matrix.data.to_scipy_sparse().tocsr()
|
|
117
|
-
)
|
|
118
|
-
|
|
119
|
-
else:
|
|
120
|
-
self.transfer_matrix = jnp.asarray(folding.transfer_matrix.data.todense())
|
|
121
|
-
|
|
122
|
-
def __call__(self, parameters):
|
|
123
|
-
"""
|
|
124
|
-
Compute the count functions for a given observation.
|
|
125
|
-
"""
|
|
126
|
-
|
|
127
|
-
expected_counts = self.transfer_matrix @ self.model.photon_flux(parameters, *self.energies)
|
|
128
|
-
|
|
129
|
-
return jnp.clip(expected_counts, a_min=1e-6)
|
|
34
|
+
from .util.typing import PriorDictType
|
|
130
35
|
|
|
131
36
|
|
|
132
37
|
class BayesianModel:
|
|
@@ -157,15 +62,16 @@ class BayesianModel:
|
|
|
157
62
|
self.model = model
|
|
158
63
|
self._observations = observations
|
|
159
64
|
self.background_model = background_model
|
|
160
|
-
self.pars = tree_map(lambda x: jnp.float64(x), self.model.params)
|
|
161
65
|
self.sparse = sparsify_matrix
|
|
162
66
|
|
|
163
67
|
if not callable(prior_distributions):
|
|
164
68
|
# Validate the entry with pydantic
|
|
165
|
-
prior = PriorDictModel.from_dict(prior_distributions).
|
|
69
|
+
# prior = PriorDictModel.from_dict(prior_distributions).
|
|
166
70
|
|
|
167
71
|
def prior_distributions_func():
|
|
168
|
-
return build_prior(
|
|
72
|
+
return build_prior(
|
|
73
|
+
prior_distributions, expand_shape=(len(self.observation_container),)
|
|
74
|
+
)
|
|
169
75
|
|
|
170
76
|
else:
|
|
171
77
|
prior_distributions_func = prior_distributions
|
|
@@ -173,6 +79,22 @@ class BayesianModel:
|
|
|
173
79
|
self.prior_distributions_func = prior_distributions_func
|
|
174
80
|
self.init_params = self.prior_samples()
|
|
175
81
|
|
|
82
|
+
# Check the priors are suited for the observations
|
|
83
|
+
split_parameters = [
|
|
84
|
+
(param, shape[-1])
|
|
85
|
+
for param, shape in jax.tree.map(lambda x: x.shape, self.init_params).items()
|
|
86
|
+
if (len(shape) > 1)
|
|
87
|
+
and not param.startswith("_")
|
|
88
|
+
and not param.startswith("bkg") # hardcoded for subtracted background
|
|
89
|
+
]
|
|
90
|
+
|
|
91
|
+
for parameter, proposed_number_of_obs in split_parameters:
|
|
92
|
+
if proposed_number_of_obs != len(self.observation_container):
|
|
93
|
+
raise ValueError(
|
|
94
|
+
f"Invalid splitting in the prior distribution. "
|
|
95
|
+
f"Expected {len(self.observation_container)} but got {proposed_number_of_obs} for {parameter}"
|
|
96
|
+
)
|
|
97
|
+
|
|
176
98
|
@cached_property
|
|
177
99
|
def observation_container(self) -> dict[str, ObsConfiguration]:
|
|
178
100
|
"""
|
|
@@ -197,22 +119,52 @@ class BayesianModel:
|
|
|
197
119
|
Build the numpyro model using the observed data, the prior distributions and the spectral model.
|
|
198
120
|
"""
|
|
199
121
|
|
|
200
|
-
def
|
|
122
|
+
def numpyro_model(observed=True):
|
|
123
|
+
# Instantiate and register the parameters of the spectral model and the background
|
|
201
124
|
prior_params = self.prior_distributions_func()
|
|
202
125
|
|
|
203
126
|
# Iterate over all the observations in our container and build a single numpyro model for each observation
|
|
204
|
-
for i, (
|
|
127
|
+
for i, (name, observation) in enumerate(self.observation_container.items()):
|
|
128
|
+
# Check that we can indeed fit a background
|
|
129
|
+
if (getattr(observation, "folded_background", None) is not None) and (
|
|
130
|
+
self.background_model is not None
|
|
131
|
+
):
|
|
132
|
+
# This call should register the parameter and observation of our background model
|
|
133
|
+
bkg_countrate = self.background_model.numpyro_model(
|
|
134
|
+
observation, name=name, observed=observed
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
elif (getattr(observation, "folded_background", None) is None) and (
|
|
138
|
+
self.background_model is not None
|
|
139
|
+
):
|
|
140
|
+
raise ValueError(
|
|
141
|
+
"Trying to fit a background model but no background is linked to this observation"
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
else:
|
|
145
|
+
bkg_countrate = 0.0
|
|
146
|
+
|
|
205
147
|
# We expect that prior_params contains an array of parameters for each observation
|
|
206
148
|
# They can be identical or different for each observation
|
|
207
|
-
params =
|
|
149
|
+
params = jax.tree.map(lambda x: x[i], prior_params)
|
|
208
150
|
|
|
209
|
-
|
|
210
|
-
|
|
151
|
+
# Forward model the observation and get the associated countrate
|
|
152
|
+
obs_model = jax.jit(
|
|
153
|
+
lambda par: forward_model(self.model, par, observation, sparse=self.sparse)
|
|
211
154
|
)
|
|
155
|
+
obs_countrate = obs_model(params)
|
|
212
156
|
|
|
213
|
-
|
|
157
|
+
# Register the observation as an observed site
|
|
158
|
+
with numpyro.plate("obs_plate_" + name, len(observation.folded_counts)):
|
|
159
|
+
numpyro.sample(
|
|
160
|
+
"obs_" + name,
|
|
161
|
+
Poisson(
|
|
162
|
+
obs_countrate + bkg_countrate
|
|
163
|
+
), # / observation.folded_backratio.data
|
|
164
|
+
obs=observation.folded_counts.data if observed else None,
|
|
165
|
+
)
|
|
214
166
|
|
|
215
|
-
return
|
|
167
|
+
return numpyro_model
|
|
216
168
|
|
|
217
169
|
@cached_property
|
|
218
170
|
def transformed_numpyro_model(self) -> Callable:
|
|
@@ -352,7 +304,9 @@ class BayesianModel:
|
|
|
352
304
|
return fakeit(key, parameters)
|
|
353
305
|
|
|
354
306
|
def prior_predictive_coverage(
|
|
355
|
-
self,
|
|
307
|
+
self,
|
|
308
|
+
key: PRNGKey = PRNGKey(0),
|
|
309
|
+
num_samples: int = 1000,
|
|
356
310
|
):
|
|
357
311
|
"""
|
|
358
312
|
Check if the prior distribution include the observed data.
|
|
@@ -362,24 +316,25 @@ class BayesianModel:
|
|
|
362
316
|
posterior_observations = self.mock_observations(prior_params, key=key_posterior)
|
|
363
317
|
|
|
364
318
|
for key, value in self.observation_container.items():
|
|
365
|
-
fig,
|
|
366
|
-
nrows=2, ncols=1, sharex=True, figsize=(
|
|
319
|
+
fig, ax = plt.subplots(
|
|
320
|
+
nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
|
|
367
321
|
)
|
|
368
322
|
|
|
369
|
-
|
|
370
|
-
|
|
323
|
+
y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
|
|
324
|
+
value.folded_counts.values, 1.0, "ct"
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
true_data_plot = _plot_poisson_data_with_error(
|
|
328
|
+
ax[0],
|
|
371
329
|
value.out_energies,
|
|
372
|
-
value
|
|
373
|
-
|
|
330
|
+
y_observed.value,
|
|
331
|
+
y_observed_low.value,
|
|
332
|
+
y_observed_high.value,
|
|
333
|
+
alpha=0.7,
|
|
374
334
|
)
|
|
375
335
|
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
edges=[*list(value.out_energies[0]), value.out_energies[1][-1]],
|
|
379
|
-
baseline=np.min(posterior_observations["obs_" + key], axis=0),
|
|
380
|
-
alpha=0.3,
|
|
381
|
-
fill=True,
|
|
382
|
-
color=(0.15, 0.25, 0.45),
|
|
336
|
+
prior_plot = _plot_binned_samples_with_error(
|
|
337
|
+
ax[0], value.out_energies, posterior_observations["obs_" + key], n_sigmas=3
|
|
383
338
|
)
|
|
384
339
|
|
|
385
340
|
# rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
|
|
@@ -393,22 +348,24 @@ class BayesianModel:
|
|
|
393
348
|
|
|
394
349
|
rank = (less_than_obs + 0.5 * equal_to_obs) / num_samples * 100
|
|
395
350
|
|
|
396
|
-
|
|
351
|
+
ax[1].stairs(rank, edges=[*list(value.out_energies[0]), value.out_energies[1][-1]])
|
|
397
352
|
|
|
398
|
-
|
|
353
|
+
ax[1].plot(
|
|
399
354
|
(value.out_energies.min(), value.out_energies.max()),
|
|
400
355
|
(50, 50),
|
|
401
356
|
color="black",
|
|
402
357
|
linestyle="--",
|
|
403
358
|
)
|
|
404
359
|
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
360
|
+
ax[1].set_xlabel("Energy (keV)")
|
|
361
|
+
ax[0].set_ylabel("Counts")
|
|
362
|
+
ax[1].set_ylabel("Rank (%)")
|
|
363
|
+
ax[1].set_ylim(0, 100)
|
|
364
|
+
ax[0].set_xlim(value.out_energies.min(), value.out_energies.max())
|
|
365
|
+
ax[0].loglog()
|
|
366
|
+
ax[0].legend(loc="upper right")
|
|
411
367
|
plt.suptitle(f"Prior Predictive coverage for {key}")
|
|
368
|
+
plt.tight_layout()
|
|
412
369
|
plt.show()
|
|
413
370
|
|
|
414
371
|
|
|
@@ -513,7 +470,11 @@ class BayesianModelFitter(BayesianModel, ABC):
|
|
|
513
470
|
predictive_parameters
|
|
514
471
|
]
|
|
515
472
|
|
|
516
|
-
parameters = [
|
|
473
|
+
parameters = [
|
|
474
|
+
x
|
|
475
|
+
for x in inference_data.posterior.keys()
|
|
476
|
+
if not x.endswith("_base") or x.startswith("_")
|
|
477
|
+
]
|
|
517
478
|
inference_data.posterior = inference_data.posterior[parameters]
|
|
518
479
|
inference_data.prior = inference_data.prior[parameters]
|
|
519
480
|
|
|
@@ -595,7 +556,6 @@ class MCMCFitter(BayesianModelFitter):
|
|
|
595
556
|
return FitResult(
|
|
596
557
|
self,
|
|
597
558
|
inference_data,
|
|
598
|
-
self.model.params,
|
|
599
559
|
background_model=self.background_model,
|
|
600
560
|
)
|
|
601
561
|
|
|
@@ -641,11 +601,13 @@ class NSFitter(BayesianModelFitter):
|
|
|
641
601
|
ns = NestedSampler(
|
|
642
602
|
bayesian_model,
|
|
643
603
|
constructor_kwargs=dict(
|
|
644
|
-
num_parallel_workers=1,
|
|
645
604
|
verbose=verbose,
|
|
646
605
|
difficult_model=True,
|
|
647
|
-
max_samples=
|
|
606
|
+
max_samples=1e5,
|
|
648
607
|
parameter_estimation=True,
|
|
608
|
+
gradient_guided=True,
|
|
609
|
+
devices=jax.devices(),
|
|
610
|
+
# init_efficiency_threshold=0.01,
|
|
649
611
|
num_live_points=num_live_points,
|
|
650
612
|
),
|
|
651
613
|
termination_kwargs=termination_kwargs if termination_kwargs else dict(),
|
|
@@ -664,6 +626,5 @@ class NSFitter(BayesianModelFitter):
|
|
|
664
626
|
return FitResult(
|
|
665
627
|
self,
|
|
666
628
|
inference_data,
|
|
667
|
-
self.model.params,
|
|
668
629
|
background_model=self.background_model,
|
|
669
630
|
)
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Helper functions to deal with the graph logic within model building"""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from uuid import uuid4
|
|
7
|
+
|
|
8
|
+
import networkx as nx
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_component_names(graph: nx.DiGraph):
    """
    Get the set of component names from the nodes of a graph.

    A node is treated as a component when the substring ``"component"`` occurs
    in its ``"type"`` attribute. Nodes without a ``"type"`` attribute (or with
    a ``None`` type) are ignored instead of raising.

    Parameters:
        graph: The graph to get the component names from.

    Returns:
        The set of ``"name"`` attributes of every component node.

    Raises:
        KeyError: If a component node has no ``"name"`` attribute.
    """
    return {
        data["name"]
        for _, data in graph.nodes(data=True)
        # `or ""` covers both a missing attribute and an explicit None, which
        # would otherwise make the `in` test fail with a TypeError.
        if "component" in (data.get("type") or "")
    }
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def increment_name(name: str, used_names: set):
    """
    Return a variant of a name that is not present in a set of used names.

    If `name` is not in `used_names` it is returned unchanged. Otherwise the
    name is split into a base and an optional trailing ``_<digits>`` suffix,
    and the number is incremented until ``f"{base}_{number}"`` is free. A name
    without a numeric suffix is counted as number 1, so its first free
    variant is ``base_2``.

    Parameters:
        name: The name to increment.
        used_names: The set of names that are already used.

    Returns:
        `name` itself when it is free, otherwise the first free
        ``base_<number>`` with a number greater than the current one.
    """
    if name not in used_names:
        return name

    # `fullmatch` with DOTALL always succeeds (both groups may be empty), so
    # names containing line breaks no longer yield `None` and crash below.
    match = re.fullmatch(r"(.*?)(?:_(\d+))?", name, flags=re.DOTALL)
    base_name, suffix = match.groups()

    # Start from 1 if there is no suffix
    number = 1 if suffix is None else int(suffix)

    new_name = name
    while new_name in used_names:
        number += 1
        new_name = f"{base_name}_{number}"

    return new_name
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def compose_with_rename(graph_1: nx.DiGraph, graph_2: nx.DiGraph):
    """
    Compose two graphs by updating the 'name' attributes of nodes in graph_2,
    and return the graph joined on the 'out' node.

    Component nodes of `graph_2` whose name is already taken (by a component
    of `graph_1` or by an earlier component of `graph_2`) are renamed with
    `increment_name`. The rename is written to the node attributes of
    `graph_2` itself, so `graph_2` is modified in place. Nodes without a
    ``"type"`` attribute are left untouched.

    Parameters:
        graph_1: The first graph to compose.
        graph_2: The second graph to compose.

    Returns:
        The result of ``nx.compose(graph_1, graph_2)`` after the renaming.
    """

    # Initialize the set of used names with names from graph_1
    used_names = get_component_names(graph_1)

    # Update the 'name' attributes in graph_2 to make them unique
    for _, data in graph_2.nodes(data=True):
        # `or ""` avoids a TypeError when a node carries no type at all
        if "component" not in (data.get("type") or ""):
            continue

        if data["name"] in used_names:
            data["name"] = increment_name(data["name"], used_names)

        # Whether renamed or not, the final name is now reserved
        used_names.add(data["name"])

    # Now the graphs can be composed without name clashes
    return nx.compose(graph_1, graph_2)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def compose(
    graph_1: nx.DiGraph,
    graph_2: nx.DiGraph,
    operation: str = "",
    operation_func: Callable = lambda x, y: None,
):
    """
    Join two graphs through a new operation node.

    The graphs are first merged with `compose_with_rename`. The merged 'out'
    node is relabelled with a fresh UUID and tagged as an
    ``"<operation>_operation"`` node carrying `operation_func` as its
    ``"operator"`` attribute. A new 'out' node is then added downstream of it
    and every node receives a ``"depth"`` attribute.

    Parameters:
        graph_1: The first graph to compose.
        graph_2: The second graph to compose.
        operation: The string describing the operation to perform.
        operation_func: The callable that performs the operation.

    Returns:
        The composed graph, ending on a new 'out' node.
    """

    operation_id = str(uuid4())

    # Merge both graphs, then turn the former 'out' node into the operation node
    merged = nx.relabel_nodes(compose_with_rename(graph_1, graph_2), {"out": operation_id})
    nx.set_node_attributes(merged, {operation_id: f"{operation}_operation"}, "type")
    nx.set_node_attributes(merged, {operation_id: operation_func}, "operator")

    # Fresh output node, fed by the operation node
    merged.add_node("out", type="out")
    merged.add_edge(operation_id, "out")

    # Depth = longest path length of the graph minus the node's shortest
    # distance to 'out'; computed for all nodes, then stored in a single call
    max_depth = nx.dag_longest_path_length(merged)
    depth_by_node = {
        node: max_depth - nx.shortest_path_length(merged, node, "out")
        for node in merged.nodes
    }
    nx.set_node_attributes(merged, depth_by_node, "depth")

    return merged
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def export_to_mermaid(graph, file=None):
    """
    Export a graph as Mermaid flowchart code (left-to-right layout).

    Nodes are rendered according to their ``"type"`` attribute:

    * ``"out"`` becomes a node labelled "Output";
    * ``"<kind>_component"`` becomes a node labelled from its ``"name"``
      attribute, e.g. ``"powerlaw_1"`` gives "Powerlaw (1)". A name without a
      ``_<suffix>`` part is rendered as the capitalized name alone;
    * ``"add_operation"`` / ``"mul_operation"`` become ``+`` / ``x`` nodes.

    Nodes of any other type are not declared, but their edges are still
    emitted.

    Parameters:
        graph: The graph to export; must provide ``nodes(data=True)`` and
            ``edges()``.
        file: Optional path. When given, the code is written to this file
            instead of being returned.

    Returns:
        The Mermaid code as a string when `file` is None, otherwise None.
    """
    lines = ["graph LR\n"]  # LR = left to right

    # Add nodes
    for node, attributes in graph.nodes(data=True):
        full_type = attributes["type"]

        if full_type == "out":
            lines.append(f' {node}("Output")\n')
            continue

        # Split on the LAST underscore only, so unexpected extra (or missing)
        # underscores cannot break the tuple unpacking.
        operation_type, _, node_type = full_type.rpartition("_")

        if node_type == "component":
            name, separator, number = attributes["name"].rpartition("_")
            if separator:
                label = f"{name.capitalize()} ({number})"
            else:
                # No suffix: rpartition leaves the whole name in the last slot
                label = number.capitalize()
            lines.append(f' {node}("{label}")\n')

        elif node_type == "operation":
            symbol = {"add": "+", "mul": "x"}.get(operation_type)
            if symbol is not None:
                lines.append(f" {node}{{**{symbol}**}}\n")

    # Draw connections between nodes
    lines.extend(f" {source} --> {target}\n" for source, target in graph.edges())

    mermaid_code = "".join(lines)

    if file is None:
        return mermaid_code

    with open(file, "w", encoding="utf-8") as f:
        f.write(mermaid_code)
|