mxlpy 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. mxlpy/__init__.py +165 -0
  2. mxlpy/distributions.py +339 -0
  3. mxlpy/experimental/__init__.py +12 -0
  4. mxlpy/experimental/diff.py +226 -0
  5. mxlpy/fit.py +291 -0
  6. mxlpy/fns.py +191 -0
  7. mxlpy/integrators/__init__.py +19 -0
  8. mxlpy/integrators/int_assimulo.py +146 -0
  9. mxlpy/integrators/int_scipy.py +146 -0
  10. mxlpy/label_map.py +610 -0
  11. mxlpy/linear_label_map.py +303 -0
  12. mxlpy/mc.py +548 -0
  13. mxlpy/mca.py +280 -0
  14. mxlpy/meta/__init__.py +11 -0
  15. mxlpy/meta/codegen_latex.py +516 -0
  16. mxlpy/meta/codegen_modebase.py +110 -0
  17. mxlpy/meta/codegen_py.py +107 -0
  18. mxlpy/meta/source_tools.py +320 -0
  19. mxlpy/model.py +1737 -0
  20. mxlpy/nn/__init__.py +10 -0
  21. mxlpy/nn/_tensorflow.py +0 -0
  22. mxlpy/nn/_torch.py +129 -0
  23. mxlpy/npe.py +277 -0
  24. mxlpy/parallel.py +171 -0
  25. mxlpy/parameterise.py +27 -0
  26. mxlpy/paths.py +36 -0
  27. mxlpy/plot.py +875 -0
  28. mxlpy/py.typed +0 -0
  29. mxlpy/sbml/__init__.py +14 -0
  30. mxlpy/sbml/_data.py +77 -0
  31. mxlpy/sbml/_export.py +644 -0
  32. mxlpy/sbml/_import.py +599 -0
  33. mxlpy/sbml/_mathml.py +691 -0
  34. mxlpy/sbml/_name_conversion.py +52 -0
  35. mxlpy/sbml/_unit_conversion.py +74 -0
  36. mxlpy/scan.py +629 -0
  37. mxlpy/simulator.py +655 -0
  38. mxlpy/surrogates/__init__.py +31 -0
  39. mxlpy/surrogates/_poly.py +97 -0
  40. mxlpy/surrogates/_torch.py +196 -0
  41. mxlpy/symbolic/__init__.py +10 -0
  42. mxlpy/symbolic/strikepy.py +582 -0
  43. mxlpy/symbolic/symbolic_model.py +75 -0
  44. mxlpy/types.py +474 -0
  45. mxlpy-0.8.0.dist-info/METADATA +106 -0
  46. mxlpy-0.8.0.dist-info/RECORD +48 -0
  47. mxlpy-0.8.0.dist-info/WHEEL +4 -0
  48. mxlpy-0.8.0.dist-info/licenses/LICENSE +674 -0
@@ -0,0 +1,582 @@
1
+ # ruff: noqa: D100, D101, D102, D103, D104, D105, D106, D107, D200, D203, D400, D401
2
+
3
+ """Reimplementation of strikepy from.
4
+
5
+ StrikePy: https://github.com/afvillaverde/StrikePy
6
+ STRIKE-GOLDD: https://github.com/afvillaverde/strike-goldd
7
+
8
+ FIXME:
9
+ - no handling of derived variables
10
+ - performance issues of generic_rank
11
+ """
12
+
13
+ import textwrap
14
+ from concurrent.futures import ProcessPoolExecutor
15
+ from dataclasses import dataclass, field
16
+ from functools import partial
17
+ from math import ceil, inf
18
+ from time import time
19
+ from typing import cast
20
+
21
+ import numpy as np
22
+ import numpy.typing as npt
23
+ import symbtools as st
24
+ import sympy
25
+ import sympy as sym
26
+ import tqdm
27
+ from sympy import Matrix
28
+ from sympy.matrices import zeros
29
+
30
+ from mxlpy.model import Model
31
+
32
+ from .symbolic_model import SymbolicModel, to_symbolic_model
33
+
34
+ __all__ = [
35
+ "Options",
36
+ "Result",
37
+ "StrikepyModel",
38
+ "check_identifiability",
39
+ "strike_goldd",
40
+ ]
41
+
42
+
43
@dataclass
class Options:
    """Settings that steer a ``strike_goldd`` run."""

    # When true, `_elim_and_recalc` also tests the unmeasured states.
    check_observability: bool = True
    # Budget compared against differences of `time()`; exceeding it raises
    # TimeoutError in `_create_onx` and in the main loop of `strike_goldd`.
    max_lie_time: float = inf
    # Passed as `n_derivatives` to `_create_derivatives` for the known inputs:
    # entries of an input's derivative list beyond index n are replaced by 0.
    # `strike_goldd` requires at least one entry per known input.
    non_zero_known_input_derivatives: list[int] = field(default_factory=lambda: [100])
    # Same as above, for the unknown inputs.
    non_zero_unknown_input_derivatives: list[int] = field(default_factory=lambda: [100])
    # Parameters removed from `model.pars` before the analysis starts.
    prev_ident_pars: set[sympy.Symbol] = field(default_factory=set)
50
+
51
+
52
@dataclass
class StrikepyModel:
    """Symbolic model description consumed by ``strike_goldd``."""

    # State symbols; they occupy the first columns of the matrix built in
    # `_create_onx`.
    states: list[sym.Symbol]
    # Parameter symbols; their columns follow the state columns. Note that
    # `strike_goldd` may replace this list (see `_remove_identified_parameters`).
    pars: list[sym.Symbol]
    # Right-hand sides, stacked as a column in `_create_xaug_faug`
    # (presumably one expression per state, in the same order - TODO confirm).
    eqs: list[sym.Expr]
    # Output expressions; a state listed here counts as directly measured
    # (see `_get_measured_states`).
    outputs: list[sym.Symbol]
    # Inputs whose derivatives enter via `_compute_extra_term`.
    known_inputs: list[sym.Symbol] = field(default_factory=list)
    # Inputs that are appended to the augmented state vector when present.
    unknown_inputs: list[sym.Symbol] = field(default_factory=list)
60
+
61
+
62
@dataclass
class Result:
    """Outcome of a ``strike_goldd`` run.

    ``rank`` is the generic rank last computed for the observability-
    identifiability matrix. The six lists hold the symbols that the analysis
    classified as identifiable/observable or not.
    """

    rank: int
    model: StrikepyModel
    is_fispo: bool
    par_ident: list
    par_unident: list
    state_obs: list
    state_unobs: list
    input_obs: list
    input_unobs: list

    def all_inputs_observable(self) -> bool:
        """Return True if all parameters are identified and unknown inputs exist."""
        if len(self.par_ident) != len(self.model.pars):
            return False
        return len(self.model.unknown_inputs) > 0

    def summary(self) -> str:
        """Return a multi-line, human-readable report of the classification."""
        verdict = "is" if self.is_fispo else "is not"
        report = f"""\
            Summary
            =======
            The model {verdict} FISPO.
            Identifiable parameters: {self.par_ident}
            Unidentifiable parameters: {self.par_unident}
            Identifiable variables: {self.state_obs}
            Unidentifiable variables: {self.state_unobs}
            Identifiable inputs: {self.input_obs}
            Unidentifiable inputs: {self.input_unobs}
            """
        return textwrap.dedent(report)
92
+
93
+
94
def _rationalize_all_numbers(expr: sym.Matrix) -> sym.Matrix:
    """Return ``expr`` with every numeric atom replaced by a ``sympy.Rational``."""
    replacements = [
        (number, sym.Rational(number)) for number in list(expr.atoms(sym.Number))
    ]
    return expr.subs(replacements)
99
+
100
+
101
def _calculate_num_rank(inp: tuple[int, list[int]], onx: Matrix) -> tuple[int, int]:
    """Return ``(label, rank)`` for ``onx`` restricted to the given columns.

    ``inp`` is ``(label, column_indices)``; the label is passed through
    unchanged so that results can be matched to their request.
    """
    label, kept_columns = inp
    reduced = onx.col(kept_columns)
    return label, st.generic_rank(reduced)
104
+
105
+
106
def _ranks_without_columns(
    onx: sym.Matrix,
    candidates: list[tuple[int, list[int]]],
) -> list[tuple[int, int]]:
    """Compute the generic rank of ``onx`` for every candidate column subset.

    Each candidate is ``(label, kept_column_indices)``. The ranks are computed
    in worker processes; no pool is started when there is nothing to do.
    """
    if not candidates:
        return []
    with ProcessPoolExecutor() as ppe:
        return list(ppe.map(partial(_calculate_num_rank, onx=onx), candidates))


def _elim_and_recalc(
    *,
    model: StrikepyModel,
    res: Result,
    options: Options,
    unmeas_xred_indices: list[int],
    onx: sym.Matrix,
    unidflag: bool,
    w1: list[sym.Symbol],
) -> None:
    """Classify parameters, states and unknown inputs by column elimination.

    For every candidate that is not yet classified as identifiable/observable,
    the corresponding column of ``onx`` is removed and the generic rank is
    recomputed. If the rank differs from ``res.rank`` the candidate is added
    to the identifiable/observable list. If it is unchanged and ``unidflag``
    is set, the candidate is added to the unidentifiable/unobservable list.

    Args:
        model: Model whose ``states`` and ``pars`` define the column layout
            (states first, then parameters, then the entries of ``w1``).
        res: Result that is updated in place; its ``rank`` is the reference.
        options: ``check_observability`` toggles the state test.
        unmeas_xred_indices: Indices into ``model.states`` of the states to test.
        onx: Matrix whose columns are eliminated; its numbers are rationalized
            before any rank is computed.
        unidflag: Whether unchanged rank may be recorded as "not identifiable".
        w1: Entries occupying the columns after the parameters.

    """
    onx = _rationalize_all_numbers(onx)
    n_cols = sym.shape(onx)[1]
    n_states = len(model.states)
    n_pars = len(model.pars)

    # The "positive" lists alias the lists stored on `res`, so appending to
    # them also extends what `res` currently holds.
    new_ident_pars = res.par_ident
    new_nonid_pars = []
    new_obs_states = res.state_obs
    new_unobs_states = []
    new_obs_in = res.input_obs
    new_unobs_in = []

    def all_columns_but(column: int) -> list[int]:
        kept = list(range(n_cols))
        kept.pop(column)
        return kept

    # At each iteration we try removing a different parameter column:
    candidates = [
        (idx, all_columns_but(n_states + idx))
        for idx, par in enumerate(model.pars)
        if par not in new_ident_pars
    ]
    for idx, num_rank in _ranks_without_columns(onx, candidates):
        if num_rank == res.rank:
            if unidflag:
                new_nonid_pars.append(model.pars[idx])
        else:
            new_ident_pars.append(model.pars[idx])

    # At each iteration we try removing a different state from 'xred':
    if options.check_observability:
        candidates = [
            (orig_idx, all_columns_but(orig_idx))
            for orig_idx in unmeas_xred_indices
            if model.states[orig_idx] not in new_obs_states
        ]
        for orig_idx, num_rank in _ranks_without_columns(onx, candidates):
            if num_rank == res.rank:
                if unidflag:
                    new_unobs_states.append(model.states[orig_idx])
            else:
                new_obs_states.append(model.states[orig_idx])

    # At each iteration we try removing a different column from onx:
    candidates = [
        (idx, all_columns_but(n_states + n_pars + idx))
        for idx, unknown_input in enumerate(w1)
        if unknown_input not in new_obs_in
    ]
    for idx, num_rank in _ranks_without_columns(onx, candidates):
        if num_rank == res.rank:
            if unidflag:
                new_unobs_in.append(w1[idx])
        else:
            new_obs_in.append(w1[idx])

    res.par_ident = new_ident_pars
    res.par_unident = new_nonid_pars
    res.state_obs = new_obs_states
    res.state_unobs = new_unobs_states
    res.input_obs = new_obs_in
    res.input_unobs = new_unobs_in
186
+
187
+
188
+ def _remove_identified_parameters(model: StrikepyModel, options: Options) -> None:
189
+ if len(options.prev_ident_pars) != 0:
190
+ model.pars = [i for i in model.pars if i not in options.prev_ident_pars]
191
+
192
+
193
+ def _get_measured_states(model: StrikepyModel) -> tuple[list[sym.Symbol], list[int]]:
194
+ # Check which states are directly measured, if any.
195
+ # Basically it is checked if any state is directly on the output,
196
+ # then that state is directly measurable.
197
+ is_measured: list[bool] = [False for i in range(len(model.states))]
198
+ for i, state in enumerate(model.states):
199
+ if state in model.outputs:
200
+ is_measured[i] = True
201
+
202
+ measured_state_idxs: list[int] = [i for i, j in enumerate(is_measured) if j]
203
+ unmeasured_state_idxs = [i for i, j in enumerate(is_measured) if not j]
204
+ measured_state_names = [model.states[i] for i in measured_state_idxs]
205
+ return measured_state_names, unmeasured_state_idxs
206
+
207
+
208
def _create_derivatives(
    elements: list[sym.Symbol], n_min_lie_derivatives: int, n_derivatives: list[int]
) -> list[list[sym.Symbol]]:
    """Build, per element, the list ``[x, x_d1, ..., x_dN]`` of derivative symbols.

    ``N`` is ``n_min_lie_derivatives``. For the element at position ``i``,
    entries beyond index ``n_derivatives[i]`` are replaced by the number 0.
    """
    derivatives: list[list[float | sym.Symbol]] = []
    for ind_u, element in enumerate(elements):
        series: list[float | sym.Symbol] = [sym.Symbol(f"{element}")]
        series.extend(
            sym.Symbol(f"{element}_d{order}")
            for order in range(1, n_min_lie_derivatives + 1)
        )
        derivatives.append(series)

        # All rows have the same length, so the first row is used as reference.
        first_zero = n_derivatives[ind_u] + 1
        if len(derivatives[0]) >= first_zero:
            for offset in range(len(derivatives[0][first_zero:])):
                derivatives[ind_u][first_zero + offset] = 0
    return derivatives  # type: ignore
222
+
223
+
224
+ def _create_w1_vector(
225
+ model: StrikepyModel, w_der: list[list[sym.Symbol]]
226
+ ) -> tuple[list[sym.Symbol], list[sym.Symbol]]:
227
+ w1vector = []
228
+ w1vector_dot = []
229
+
230
+ if len(model.unknown_inputs) == 0:
231
+ return w1vector, w1vector_dot
232
+
233
+ w1vector.extend(w_der[:-1])
234
+ w1vector_dot.extend(w_der[1:])
235
+
236
+ # -- Include as states only nonzero inputs / derivatives:
237
+ nzi = []
238
+ nzj = []
239
+ nz_w1vec = []
240
+ for fila in range(len(w1vector)):
241
+ if w1vector[fila][0] != 0:
242
+ nzi.append([fila])
243
+ nzj.append([1])
244
+ nz_w1vec.append(w1vector[fila])
245
+
246
+ w1vector = nz_w1vec
247
+ w1vector_dot = w1vector_dot[0 : len(nzi)]
248
+ return w1vector, w1vector_dot
249
+
250
+
251
def _create_xaug_faug(
    model: StrikepyModel,
    w1vector: list[sym.Symbol],
    w1vector_dot: list[sym.Symbol],
) -> tuple[npt.NDArray, npt.NDArray]:
    """Assemble the augmented state vector and its right-hand side.

    ``xaug`` is the states followed by the parameters and, if present, the
    entries of ``w1vector``. ``faug`` is the column of ``model.eqs`` followed
    by one zero row per parameter and, if present, ``w1vector_dot``.
    """
    has_unknown_inputs = len(w1vector) != 0

    xaug = np.append(np.array(model.states), model.pars, axis=0)  # type: ignore
    if has_unknown_inputs:
        xaug = np.append(xaug, w1vector, axis=0)  # type: ignore

    eq_column = np.atleast_2d(np.array(model.eqs, dtype=object)).T
    # Parameters get zero rows in the right-hand side.
    faug = np.append(eq_column, zeros(len(model.pars), 1), axis=0)
    if has_unknown_inputs:
        faug = np.append(faug, w1vector_dot, axis=0)  # type: ignore
    return xaug, faug
266
+
267
+
268
def _compute_extra_term(
    extra_term: npt.NDArray,
    ind: int,
    past_lie: sym.Matrix,
    input_der: list[list[sym.Symbol]],
    zero_input_der_dummy_name: sym.Symbol,
) -> npt.NDArray:
    """Accumulate the input-derivative contributions onto ``extra_term``.

    For each ``i < ind`` that has a successor entry in ``input_der``, the
    Jacobian of ``past_lie`` with respect to ``input_der[i]`` is multiplied
    with ``input_der[i + 1]`` and added to the running term. An entry equal to
    0 is replaced by ``zero_input_der_dummy_name`` before differentiating.
    A falsy ``extra_term`` is replaced by the contribution instead of summed.
    """
    last_index = len(input_der) - 1
    for i in range(ind):
        if i >= last_index:
            continue
        lower = input_der[i]
        if lower == 0:
            lower = zero_input_der_dummy_name
        lower_vec = np.array([lower])
        higher_vec = Matrix([input_der[i + 1]])
        contribution = past_lie.jacobian(lower_vec) * higher_vec
        extra_term = extra_term + contribution if extra_term else contribution
    return extra_term
287
+
288
+
289
+ def _compute_n_min_lie_derivatives(model: StrikepyModel) -> int:
290
+ n_outputs = len(model.outputs)
291
+ n_states = len(model.states)
292
+ n_unknown_pars = len(model.pars)
293
+ n_unknown_inp = len(model.unknown_inputs)
294
+ n_vars_to_observe = n_states + n_unknown_pars + n_unknown_inp
295
+ return ceil((n_vars_to_observe - n_outputs) / n_outputs)
296
+
297
+
298
+ def _test_fispo(
299
+ model: StrikepyModel,
300
+ res: Result,
301
+ measured_state_names: list[sym.Symbol],
302
+ ) -> bool:
303
+ if len(res.par_ident) == len(model.pars) and (
304
+ len(res.state_obs) + len(measured_state_names)
305
+ ) == len(model.states):
306
+ res.is_fispo = True
307
+ res.state_obs = model.states
308
+ res.input_obs = model.unknown_inputs
309
+ res.par_ident = model.pars
310
+
311
+ return res.is_fispo
312
+
313
+
314
def _create_onx(
    model: StrikepyModel,
    n_min_lie_derivatives: int,
    options: Options,
    w1vector: list[sym.Symbol],
    xaug: npt.ArrayLike,
    faug: npt.ArrayLike,
    input_der: list[list[sym.Symbol]],
    zero_input_der_dummy_name: sym.Symbol,
) -> tuple[npt.NDArray, sym.Matrix]:
    """Build the initial observability-identifiability matrix.

    The matrix has one block of ``len(model.outputs)`` rows per derivative
    order and one column per entry of ``xaug``. Each block is the Jacobian,
    with respect to ``xaug``, of the previous block multiplied with ``faug``
    plus the extra term from the known-input derivatives.

    Returns:
        The matrix and the last (extended) Lie derivative that was computed.

    Raises:
        TimeoutError: If ``options.max_lie_time`` is exceeded.

    """
    n_out = len(model.outputs)
    n_cols = len(model.states) + len(model.pars) + len(w1vector)
    onx = np.array(zeros(n_out * (1 + n_min_lie_derivatives), n_cols))

    past_Lie = sym.Matrix(model.outputs)

    # first row(s) of onx (derivative of the output with respect to the vector states+unknown parameters).
    onx[0:n_out] = np.array(past_Lie.jacobian(xaug))

    extra_term = np.array(0)

    # loop as long as I don't complete the preset Lie derivatives or go over the maximum time
    t_start = time()

    # NOTE(review): the second block is filled with the same output Jacobian
    # as the first one and the loop starts at 1 - confirm this is intended.
    onx[n_out : 2 * n_out] = past_Lie.jacobian(xaug)
    for ind in range(1, n_min_lie_derivatives):
        if (time() - t_start) > options.max_lie_time:
            msg = "More Lie derivatives would be needed to analyse the model."
            raise TimeoutError(msg)

        current_block = onx[(ind * n_out) : (ind + 1) * n_out][:]
        lie_derivatives = Matrix(current_block.dot(faug))
        extra_term = _compute_extra_term(
            extra_term,
            ind=ind,
            past_lie=past_Lie,
            input_der=input_der,
            zero_input_der_dummy_name=zero_input_der_dummy_name,
        )

        ext_Lie = lie_derivatives + extra_term if extra_term else lie_derivatives
        past_Lie = ext_Lie
        onx[((ind + 1) * n_out) : (ind + 2) * n_out] = sym.Matrix(ext_Lie).jacobian(
            xaug
        )
    return onx, cast(sym.Matrix, past_Lie)
366
+
367
+
368
def strike_goldd(model: StrikepyModel, options: Options | None = None) -> Result:
    """Run the structural identifiability / observability analysis on ``model``.

    The observability-identifiability matrix is built from Lie derivatives of
    the outputs and its generic rank is compared with the size of the
    augmented state vector. While the rank keeps growing, further derivatives
    are appended; afterwards columns are eliminated one by one to classify
    parameters, states and unknown inputs.

    Args:
        model: Model to analyse. Its ``pars`` list is replaced when
            ``options.prev_ident_pars`` is not empty.
        options: Analysis settings; defaults to ``Options()``.

    Returns:
        The classification result; ``result.model`` is ``model``.

    Raises:
        ValueError: If there are more known (or unknown) inputs than entries
            in the corresponding ``non_zero_*_input_derivatives`` option.
        TimeoutError: If ``options.max_lie_time`` is exceeded.

    """
    options = Options() if options is None else options

    # Check if the size of nnzDerU and nnzDerW are appropriate
    if len(model.known_inputs) > len(options.non_zero_known_input_derivatives):
        msg = (
            "The number of known inputs is higher than the size of nnzDerU "
            "and must have the same size."
        )
        raise ValueError(msg)
    if len(model.unknown_inputs) > len(options.non_zero_unknown_input_derivatives):
        msg = (
            "The number of unknown inputs is higher than the size of nnzDerW "
            "and must have the same size."
        )
        raise ValueError(msg)

    _remove_identified_parameters(model, options)

    res = Result(
        rank=0,
        is_fispo=False,
        model=model,
        par_ident=[],
        par_unident=[],
        state_obs=[],
        state_unobs=[],
        input_obs=[],
        input_unobs=[],
    )

    lastrank = None
    unidflag = False
    skip_elim: bool = False

    n_min_lie_derivatives = _compute_n_min_lie_derivatives(model)
    measured_state_names, unmeasured_state_idxs = _get_measured_states(model)

    input_der = _create_derivatives(
        model.known_inputs,
        n_min_lie_derivatives=n_min_lie_derivatives,
        n_derivatives=options.non_zero_known_input_derivatives,
    )
    zero_input_der_dummy_name = sym.Symbol("zero_input_der_dummy_name")

    w_der: list[list[sym.Symbol]] = _create_derivatives(
        model.unknown_inputs,
        n_min_lie_derivatives=n_min_lie_derivatives,
        n_derivatives=options.non_zero_unknown_input_derivatives,
    )

    w1vector, w1vector_dot = _create_w1_vector(
        model,
        w_der=w_der,
    )

    xaug, faug = _create_xaug_faug(
        model,
        w1vector=w1vector,
        w1vector_dot=w1vector_dot,
    )

    onx, past_Lie = _create_onx(
        model,
        n_min_lie_derivatives=n_min_lie_derivatives,
        options=options,
        w1vector=w1vector,
        xaug=xaug,
        faug=faug,
        input_der=input_der,
        zero_input_der_dummy_name=zero_input_der_dummy_name,
    )

    t_start = time()

    # The context manager closes the progress bar on every exit path,
    # including the TimeoutError raised below.
    with tqdm.tqdm(desc="Main loop") as pbar:
        while True:
            pbar.update(1)
            if time() - t_start > options.max_lie_time:
                msg = "More Lie derivatives would be needed to see if the model is structurally unidentifiable as a whole."
                raise TimeoutError(msg)

            # FIXME: For some problems this starts to be really slow
            # can't directly be fixed by using numpy.linalg.matrix_rank because
            # that can't handle the symbolic stuff
            res.rank = st.generic_rank(_rationalize_all_numbers(Matrix(onx)))

            # If the onx matrix already has full rank... all is observable and identifiable
            if res.rank == len(xaug):
                res.state_obs = model.states
                res.input_obs = model.unknown_inputs
                res.par_ident = model.pars
                break

            # If there are unknown inputs, we may want to check id/obs of (x,p,w) and not of dw/dt:
            if len(model.unknown_inputs) > 0:
                _elim_and_recalc(
                    model=model,
                    res=res,
                    options=options,
                    unmeas_xred_indices=unmeasured_state_idxs,
                    onx=Matrix(onx),
                    unidflag=unidflag,
                    w1=w1vector,
                )

                if _test_fispo(
                    model=model, res=res, measured_state_names=measured_state_names
                ):
                    break

            # If possible (& necessary), calculate one more Lie derivative and retry:
            if n_min_lie_derivatives < len(xaug) and res.rank != lastrank:
                ind = n_min_lie_derivatives
                n_min_lie_derivatives = (
                    n_min_lie_derivatives + 1
                )  # One is added to the number of derivatives already made
                extra_term = np.array(0)  # reset for each new Lie derivative
                # - Known input derivatives: ----------------------------------
                # Extra terms of extended Lie derivatives
                # may have to add extra input derivatives (note that 'nd' has grown):
                if len(model.known_inputs) > 0:
                    input_der = _create_derivatives(
                        model.known_inputs,
                        n_min_lie_derivatives=n_min_lie_derivatives,
                        n_derivatives=options.non_zero_known_input_derivatives,
                    )
                    extra_term = _compute_extra_term(
                        extra_term,
                        ind=ind,
                        past_lie=past_Lie,
                        input_der=input_der,
                        zero_input_der_dummy_name=zero_input_der_dummy_name,
                    )

                # add new derivatives, if they are not zero
                if len(model.unknown_inputs) > 0:
                    prev_size = len(w1vector)
                    w_der = _create_derivatives(
                        model.unknown_inputs,
                        n_min_lie_derivatives=n_min_lie_derivatives + 1,
                        n_derivatives=options.non_zero_unknown_input_derivatives,
                    )
                    w1vector, w1vector_dot = _create_w1_vector(
                        model,
                        w_der=w_der,
                    )
                    xaug, faug = _create_xaug_faug(
                        model, w1vector=w1vector, w1vector_dot=w1vector_dot
                    )

                    # Augment size of the Obs-Id matrix if needed
                    new_size = len(w1vector)
                    onx = np.append(
                        onx,
                        zeros((ind + 1) * len(model.outputs), new_size - prev_size),
                        axis=1,
                    )

                newLie = Matrix(
                    (
                        onx[(ind * len(model.outputs)) : (ind + 1) * len(model.outputs)][:]  # type: ignore
                    ).dot(faug)  # type: ignore
                )
                past_Lie = newLie + extra_term if extra_term else newLie
                newOnx = sym.Matrix(past_Lie).jacobian(xaug)
                onx = np.append(onx, newOnx, axis=0)
                lastrank = res.rank

            # If that is not possible, there are several possible causes:
            # This is the case when you have onx with all possible derivatives done
            # and it is not full rank, the maximum time for the next derivative has passed
            # or the matrix no longer increases in rank as derivatives are increased.
            else:
                # The maximum number of Lie derivatives has been reached
                if n_min_lie_derivatives >= len(xaug):
                    unidflag = True
                elif res.rank == lastrank:
                    onx = onx[0 : (-1 - (len(model.outputs) - 1))]  # type: ignore
                    # It is indicated that the number of derivatives needed was
                    # one less than the number of derivatives made
                    n_min_lie_derivatives = n_min_lie_derivatives - 1
                    unidflag = True

                if not skip_elim and not res.is_fispo:
                    # Eliminate columns one by one to check identifiability
                    # of the associated parameters
                    _elim_and_recalc(
                        model=model,
                        res=res,
                        options=options,
                        unmeas_xred_indices=unmeasured_state_idxs,
                        onx=Matrix(onx),
                        unidflag=unidflag,
                        w1=w1vector,
                    )

                    if _test_fispo(
                        model=model, res=res, measured_state_names=measured_state_names
                    ):
                        break

                break
    return res
573
+
574
+
575
def check_identifiability(model: SymbolicModel, outputs: list[sympy.Symbol]) -> Result:
    """Analyse a ``SymbolicModel`` with ``strike_goldd`` using default options.

    The model's variables become the states and its parameters the unknown
    parameters; ``outputs`` are the measured quantities.
    """
    states = list(model.variables.values())
    pars = list(model.parameters.values())
    strike_model = StrikepyModel(
        states=states,
        pars=pars,
        eqs=model.eqs,
        outputs=outputs,
    )
    return strike_goldd(strike_model)
@@ -0,0 +1,75 @@
1
+ # ruff: noqa: D100, D101, D102, D103, D104, D105, D106, D107, D200, D203, D400, D401
2
+
3
+
4
+ from collections.abc import Iterable
5
+ from dataclasses import dataclass
6
+
7
+ import sympy
8
+
9
+ from mxlpy.meta.source_tools import fn_to_sympy
10
+ from mxlpy.model import Model
11
+
12
+ __all__ = ["SymbolicModel", "to_symbolic_model"]
13
+
14
+
15
@dataclass
class SymbolicModel:
    """Symbolic counterpart of a model: symbols, equations and numeric values."""

    variables: dict[str, sympy.Symbol]
    parameters: dict[str, sympy.Symbol]
    eqs: list[sympy.Expr]
    initial_conditions: dict[str, float]
    parameter_values: dict[str, float]

    def jacobian(self) -> sympy.Matrix:
        """Return the Jacobian of ``eqs`` with respect to the variable symbols."""
        # FIXME: don't rely on ordering of variables
        rhs = sympy.Matrix(self.eqs)
        wrt = sympy.Matrix(list(self.variables.values()))
        return rhs.jacobian(wrt)
28
+
29
+
30
def _list_of_symbols(args: Iterable[str]) -> list[sympy.Symbol]:
    """Create one ``sympy.Symbol`` per name, preserving the iteration order."""
    return list(map(sympy.Symbol, args))
32
+
33
+
34
def to_symbolic_model(model: Model) -> SymbolicModel:
    """Translate ``model`` into sympy symbols and right-hand-side expressions.

    Variables and parameters become symbols, derived quantities and reaction
    rates are converted with ``fn_to_sympy``, and each compound's equation is
    the sum of ``stoichiometry * rate`` over its reactions. This covers both
    constant stoichiometries and those computed by a function.

    Returns:
        A ``SymbolicModel`` whose ``eqs`` follow the order of the cache's
        ``var_names``, together with copies of the model's variable and
        parameter values.

    """
    cache = model._create_cache()  # noqa: SLF001

    variables: dict[str, sympy.Symbol] = dict(
        zip(model.variables, _list_of_symbols(model.variables), strict=True)
    )
    parameters: dict[str, sympy.Symbol] = dict(
        zip(model.parameters, _list_of_symbols(model.parameters), strict=True)
    )
    symbols: dict[str, sympy.Symbol | sympy.Expr] = variables | parameters  # type: ignore

    # Insert derived into symbols
    for k, v in model.derived.items():
        symbols[k] = fn_to_sympy(v.fn, [symbols[i] for i in v.args])

    # Insert derived into reaction via args
    rxns = {
        k: fn_to_sympy(v.fn, [symbols[i] for i in v.args])
        for k, v in model.reactions.items()
    }

    eqs: dict[str, sympy.Expr] = {}
    for cpd, stoich in cache.stoich_by_cpds.items():
        for rxn, stoich_value in stoich.items():
            eqs[cpd] = (
                eqs.get(cpd, sympy.Float(0.0)) + sympy.Float(stoich_value) * rxns[rxn]  # type: ignore
            )

    for cpd, dstoich in cache.dyn_stoich_by_cpds.items():
        for rxn, der in dstoich.items():
            # The symbolic stoichiometry scales the rate; the multiplication
            # must be applied to the converted expression, not to the
            # argument list passed to `fn_to_sympy`.
            dyn_stoich = fn_to_sympy(der.fn, [symbols[i] for i in der.args])
            eqs[cpd] = eqs.get(cpd, sympy.Float(0.0)) + dyn_stoich * rxns[rxn]  # type: ignore

    return SymbolicModel(
        variables=variables,
        parameters=parameters,
        eqs=[eqs[i] for i in cache.var_names],
        initial_conditions=model.variables.copy(),
        parameter_values=model.parameters.copy(),
    )