modelbase2 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -211,7 +211,7 @@ def generate_model_code_py(model: Model) -> str:
211
211
  stoich_source = []
212
212
  for variable, stoich in stoichiometries.items():
213
213
  stoich_source.append(
214
- f" d{variable}dt = {conditional_join(stoich, lambda x: x.startswith("-"), " ", " + ")}"
214
+ f" d{variable}dt = {conditional_join(stoich, lambda x: x.startswith('-'), ' ', ' + ')}"
215
215
  )
216
216
 
217
217
  # Surrogates
@@ -329,7 +329,9 @@ class TexExport:
329
329
  parameters={gls.get(k, k): v for k, v in self.parameters.items()},
330
330
  variables={gls.get(k, k): v for k, v in self.variables.items()},
331
331
  derived={
332
- gls.get(k, k): Derived(fn=v.fn, args=[gls.get(i, i) for i in v.args])
332
+ gls.get(k, k): Derived(
333
+ name=k, fn=v.fn, args=[gls.get(i, i) for i in v.args]
334
+ )
333
335
  for k, v in self.derived.items()
334
336
  },
335
337
  reactions={
@@ -508,14 +510,10 @@ def get_model_tex_diff(
508
510
  gls = default_init(gls)
509
511
  section_label = "sec:model-diff"
510
512
 
511
- return f"""{' start autogenerated ':%^60}
513
+ return f"""{" start autogenerated ":%^60}
512
514
  {_clearpage()}
513
- {_subsubsection('Model changes')}{_label(section_label)}
514
- {(
515
- (_to_tex_export(m1) - _to_tex_export(m2))
516
- .rename_with_glossary(gls)
517
- .export_all()
518
- )}
515
+ {_subsubsection("Model changes")}{_label(section_label)}
516
+ {((_to_tex_export(m1) - _to_tex_export(m2)).rename_with_glossary(gls).export_all())}
519
517
  {_clearpage()}
520
- {' end autogenerated ':%^60}
518
+ {" end autogenerated ":%^60}
521
519
  """
@@ -287,10 +287,12 @@ class LinearLabelMapper:
287
287
  stoichiometry = {}
288
288
  if substrate != "EXT":
289
289
  stoichiometry[substrate] = Derived(
290
- _neg_one_div, [substrate.split("__")[0]]
290
+ name=substrate, fn=_neg_one_div, args=[substrate.split("__")[0]]
291
291
  )
292
292
  if product != "EXT":
293
- stoichiometry[product] = Derived(_one_div, [product.split("__")[0]])
293
+ stoichiometry[product] = Derived(
294
+ name=product, fn=_one_div, args=[product.split("__")[0]]
295
+ )
294
296
 
295
297
  m.add_reaction(
296
298
  name=f"{rxn_name}__{i}",
modelbase2/model.py CHANGED
@@ -20,12 +20,17 @@ from modelbase2 import fns
20
20
  from modelbase2.types import (
21
21
  Array,
22
22
  Derived,
23
- Float,
24
23
  Reaction,
25
24
  Readout,
26
25
  )
27
26
 
28
- __all__ = ["ArityMismatchError", "Model", "ModelCache", "SortError"]
27
+ __all__ = [
28
+ "ArityMismatchError",
29
+ "CircularDependencyError",
30
+ "MissingDependenciesError",
31
+ "Model",
32
+ "ModelCache",
33
+ ]
29
34
 
30
35
  if TYPE_CHECKING:
31
36
  from collections.abc import Iterable, Mapping
@@ -34,19 +39,38 @@ if TYPE_CHECKING:
34
39
  from modelbase2.types import AbstractSurrogate, Callable, Param, RateFn, RetType
35
40
 
36
41
 
37
- class SortError(Exception):
42
+ class MissingDependenciesError(Exception):
38
43
  """Raised when dependencies cannot be sorted topologically.
39
44
 
40
45
  This typically indicates circular dependencies in model components.
41
46
  """
42
47
 
43
- def __init__(self, unsorted: list[str], order: list[str]) -> None:
48
+ def __init__(self, not_solvable: dict[str, list[str]]) -> None:
44
49
  """Initialise exception."""
50
+ missing_by_module = "\n".join(f"\t{k}: {v}" for k, v in not_solvable.items())
45
51
  msg = (
46
- f"Exceeded max iterations on sorting derived. "
47
- "Check if there are circular references.\n"
48
- f"Unsorted: {unsorted}\n"
49
- f"Order: {order}"
52
+ f"Dependencies cannot be solved. Missing dependencies:\n{missing_by_module}"
53
+ )
54
+ super().__init__(msg)
55
+
56
+
57
+ class CircularDependencyError(Exception):
58
+ """Raised when dependencies cannot be sorted topologically.
59
+
60
+ This typically indicates circular dependencies in model components.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ missing: dict[str, set[str]],
66
+ ) -> None:
67
+ """Initialise exception."""
68
+ missing_by_module = "\n".join(f"\t{k}: {v}" for k, v in missing.items())
69
+ msg = (
70
+ f"Exceeded max iterations on sorting dependencies.\n"
71
+ "Check if there are circular references. "
72
+ "Missing dependencies:\n"
73
+ f"{missing_by_module}"
50
74
  )
51
75
  super().__init__(msg)
52
76
 
@@ -119,6 +143,24 @@ def _invalidate_cache(method: Callable[Param, RetType]) -> Callable[Param, RetTy
119
143
  return wrapper # type: ignore
120
144
 
121
145
 
146
+ def _check_if_is_sortable(
147
+ available: set[str],
148
+ elements: list[tuple[str, set[str]]],
149
+ ) -> None:
150
+ all_available = available.copy()
151
+ for name, _ in elements:
152
+ all_available.add(name)
153
+
154
+ # Check if it can be sorted in the first place
155
+ not_solvable = {}
156
+ for name, args in elements:
157
+ if not args.issubset(all_available):
158
+ not_solvable[name] = sorted(args.difference(all_available))
159
+
160
+ if not_solvable:
161
+ raise MissingDependenciesError(not_solvable=not_solvable)
162
+
163
+
122
164
  def _sort_dependencies(
123
165
  available: set[str], elements: list[tuple[str, set[str]]]
124
166
  ) -> list[str]:
@@ -137,6 +179,8 @@ def _sort_dependencies(
137
179
  """
138
180
  from queue import Empty, SimpleQueue
139
181
 
182
+ _check_if_is_sortable(available, elements)
183
+
140
184
  order = []
141
185
  # FIXME: what is the worst case here?
142
186
  max_iterations = len(elements) ** 2
@@ -170,7 +214,10 @@ def _sort_dependencies(
170
214
  unsorted.append(queue.get_nowait()[0])
171
215
  except Empty:
172
216
  break
173
- raise SortError(unsorted=unsorted, order=order)
217
+
218
+ mod_to_args: dict[str, set[str]] = dict(elements)
219
+ missing = {k: mod_to_args[k].difference(available) for k in unsorted}
220
+ raise CircularDependencyError(missing=missing)
174
221
  return order
175
222
 
176
223
 
@@ -190,6 +237,7 @@ class ModelCache:
190
237
  """
191
238
 
192
239
  var_names: list[str]
240
+ order: list[str]
193
241
  all_parameter_values: dict[str, float]
194
242
  derived_parameter_names: list[str]
195
243
  derived_variable_names: list[str]
@@ -245,22 +293,26 @@ class Model:
245
293
  # Sanity checks
246
294
  for name, el in it.chain(
247
295
  self._derived.items(),
248
- self._readouts.items(),
249
296
  self._reactions.items(),
297
+ self._readouts.items(),
250
298
  ):
251
299
  if not _check_function_arity(el.fn, len(el.args)):
252
300
  raise ArityMismatchError(name, el.fn, el.args)
253
301
 
254
- # Sort derived
255
- derived_order = _sort_dependencies(
302
+ # Sort derived & reactions
303
+ to_sort = self._derived | self._reactions | self._surrogates
304
+ order = _sort_dependencies(
256
305
  available=set(self._parameters) | set(self._variables) | {"time"},
257
- elements=[(k, set(v.args)) for k, v in self._derived.items()],
306
+ elements=[(k, set(v.args)) for k, v in to_sort.items()],
258
307
  )
259
308
 
260
309
  # Split derived into parameters and variables
310
+ # for user convenience
261
311
  derived_variable_names: list[str] = []
262
312
  derived_parameter_names: list[str] = []
263
- for name in derived_order:
313
+ for name in order:
314
+ if name in self._reactions or name in self._surrogates:
315
+ continue
264
316
  derived = self._derived[name]
265
317
  if all(i in all_parameter_names for i in derived.args):
266
318
  all_parameter_names.add(name)
@@ -300,6 +352,7 @@ class Model:
300
352
 
301
353
  self._cache = ModelCache(
302
354
  var_names=var_names,
355
+ order=order,
303
356
  all_parameter_values=all_parameter_values,
304
357
  stoich_by_cpds=stoich_by_compounds,
305
358
  dyn_stoich_by_cpds=dyn_stoich_by_compounds,
@@ -838,7 +891,7 @@ class Model:
838
891
 
839
892
  """
840
893
  self._insert_id(name=name, ctx="derived")
841
- self._derived[name] = Derived(fn, args)
894
+ self._derived[name] = Derived(name=name, fn=fn, args=args)
842
895
  return self
843
896
 
844
897
  def get_derived_parameter_names(self) -> list[str]:
@@ -947,7 +1000,7 @@ class Model:
947
1000
  """
948
1001
  if (cache := self._cache) is None:
949
1002
  cache = self._create_cache()
950
- args = self.get_args(concs=concs, time=time)
1003
+ args = self.get_dependent(concs=concs, time=time)
951
1004
 
952
1005
  stoich_by_cpds = copy.deepcopy(cache.stoich_by_cpds)
953
1006
  for cpd, stoich in cache.dyn_stoich_by_cpds.items():
@@ -988,10 +1041,12 @@ class Model:
988
1041
  self._insert_id(name=name, ctx="reaction")
989
1042
 
990
1043
  stoich: dict[str, Derived | float] = {
991
- k: Derived(fns.constant, [v]) if isinstance(v, str) else v
1044
+ k: Derived(name=k, fn=fns.constant, args=[v]) if isinstance(v, str) else v
992
1045
  for k, v in stoichiometry.items()
993
1046
  }
994
- self._reactions[name] = Reaction(fn=fn, stoichiometry=stoich, args=args)
1047
+ self._reactions[name] = Reaction(
1048
+ name=name, fn=fn, stoichiometry=stoich, args=args
1049
+ )
995
1050
  return self
996
1051
 
997
1052
  def get_reaction_names(self) -> list[str]:
@@ -1040,7 +1095,9 @@ class Model:
1040
1095
 
1041
1096
  if stoichiometry is not None:
1042
1097
  stoich = {
1043
- k: Derived(fns.constant, [v]) if isinstance(v, str) else v
1098
+ k: Derived(name=k, fn=fns.constant, args=[v])
1099
+ if isinstance(v, str)
1100
+ else v
1044
1101
  for k, v in stoichiometry.items()
1045
1102
  }
1046
1103
  rxn.stoichiometry = stoich
@@ -1114,7 +1171,7 @@ class Model:
1114
1171
 
1115
1172
  """
1116
1173
  self._insert_id(name=name, ctx="readout")
1117
- self._readouts[name] = Readout(fn, args)
1174
+ self._readouts[name] = Readout(name=name, fn=fn, args=args)
1118
1175
  return self
1119
1176
 
1120
1177
  def get_readout_names(self) -> list[str]:
@@ -1235,26 +1292,31 @@ class Model:
1235
1292
  return self
1236
1293
 
1237
1294
  ##########################################################################
1238
- # Get args
1295
+ # Get dependent values. This includes
1296
+ # - derived parameters
1297
+ # - derived variables
1298
+ # - fluxes
1299
+ # - readouts
1239
1300
  ##########################################################################
1240
1301
 
1241
- def _get_args(
1302
+ def _get_dependent(
1242
1303
  self,
1243
1304
  concs: dict[str, float],
1244
1305
  time: float = 0.0,
1245
1306
  *,
1246
- include_readouts: bool,
1307
+ cache: ModelCache,
1247
1308
  ) -> dict[str, float]:
1248
- """Generate a dictionary of arguments for model calculations.
1309
+ """Generate a dictionary of model components dependent on other components.
1249
1310
 
1250
1311
  Examples:
1251
- >>> model._get_args({"x1": 1.0, "x2": 2.0}, time=0.0)
1312
+ >>> model._get_dependent({"x1": 1.0, "x2": 2.0}, time=0.0)
1252
1313
  {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}
1253
1314
 
1254
1315
  Args:
1255
1316
  concs: A dictionary of concentrations with keys as the names of the substances
1256
1317
  and values as their respective concentrations.
1257
1318
  time: The time point for the calculation
1319
+ cache: A ModelCache object containing precomputed values and dependencies.
1258
1320
  include_readouts: A flag indicating whether to include readout values in the returned dictionary.
1259
1321
 
1260
1322
  Returns:
@@ -1263,23 +1325,16 @@ class Model:
1263
1325
  with their respective names as keys and their calculated values as values.
1264
1326
 
1265
1327
  """
1266
- if (cache := self._cache) is None:
1267
- cache = self._create_cache()
1268
-
1269
1328
  args: dict[str, float] = cache.all_parameter_values | concs
1270
1329
  args["time"] = time
1271
1330
 
1272
- derived = self._derived
1273
- for name in cache.derived_variable_names:
1274
- dv = derived[name]
1275
- args[name] = cast(float, dv.fn(*(args[arg] for arg in dv.args)))
1331
+ containers = self._derived | self._reactions | self._surrogates
1332
+ for name in cache.order:
1333
+ containers[name].calculate_inpl(args)
1276
1334
 
1277
- if include_readouts:
1278
- for name, ro in self._readouts.items():
1279
- args[name] = cast(float, ro.fn(*(args[arg] for arg in ro.args)))
1280
1335
  return args
1281
1336
 
1282
- def get_args(
1337
+ def get_dependent(
1283
1338
  self,
1284
1339
  concs: dict[str, float] | None = None,
1285
1340
  time: float = 0.0,
@@ -1310,16 +1365,22 @@ class Model:
1310
1365
  A pandas Series containing the generated arguments with float dtype.
1311
1366
 
1312
1367
  """
1313
- return pd.Series(
1314
- self._get_args(
1315
- concs=self.get_initial_conditions() if concs is None else concs,
1316
- time=time,
1317
- include_readouts=include_readouts,
1318
- ),
1319
- dtype=float,
1368
+ if (cache := self._cache) is None:
1369
+ cache = self._create_cache()
1370
+
1371
+ args = self._get_dependent(
1372
+ concs=self.get_initial_conditions() if concs is None else concs,
1373
+ time=time,
1374
+ cache=cache,
1320
1375
  )
1321
1376
 
1322
- def get_args_time_course(
1377
+ if include_readouts:
1378
+ for ro in self._readouts.values(): # FIXME: order?
1379
+ ro.calculate_inpl(args)
1380
+
1381
+ return pd.Series(args, dtype=float)
1382
+
1383
+ def get_dependent_time_course(
1323
1384
  self,
1324
1385
  concs: pd.DataFrame,
1325
1386
  *,
@@ -1362,10 +1423,9 @@ class Model:
1362
1423
  args = pd.concat((concs, pars_df), axis=1)
1363
1424
  args["time"] = args.index
1364
1425
 
1365
- derived = self._derived
1366
- for name in cache.derived_variable_names:
1367
- dv = derived[name]
1368
- args[name] = dv.fn(*args.loc[:, dv.args].to_numpy().T)
1426
+ containers = self._derived | self._reactions | self._surrogates
1427
+ for name in cache.order:
1428
+ containers[name].calculate_inpl_time_course(args)
1369
1429
 
1370
1430
  if include_readouts:
1371
1431
  for name, ro in self._readouts.items():
@@ -1373,48 +1433,91 @@ class Model:
1373
1433
  return args
1374
1434
 
1375
1435
  ##########################################################################
1376
- # Get full concs
1436
+ # Get args
1377
1437
  ##########################################################################
1378
1438
 
1379
- def get_full_concs(
1439
+ def get_args(
1380
1440
  self,
1381
1441
  concs: dict[str, float] | None = None,
1382
1442
  time: float = 0.0,
1383
1443
  *,
1384
- include_readouts: bool = True,
1444
+ include_derived: bool = True,
1445
+ include_readouts: bool = False,
1385
1446
  ) -> pd.Series:
1386
- """Get the full concentrations as a pandas Series.
1447
+ """Generate a pandas Series of arguments for the model.
1387
1448
 
1388
1449
  Examples:
1389
- >>> model.get_full_concs({"x1": 1.0, "x2": 2.0}, time=0.0)
1390
- pd.Series({
1391
- "x1": 1.0,
1392
- "x2": 2.0,
1393
- "d1": 3.0,
1394
- "d2": 4.0,
1395
- "r1": 0.1,
1396
- "r2": 0.2,
1397
- "energy_state": 0.5,
1398
- })
1450
+ # Using initial conditions
1451
+ >>> model.get_args()
1452
+ {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}
1453
+
1454
+ # With custom concentrations
1455
+ >>> model.get_args({"x1": 1.0, "x2": 2.0})
1456
+ {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 0.0}
1457
+
1458
+ # With custom concentrations and time
1459
+ >>> model.get_args({"x1": 1.0, "x2": 2.0}, time=1.0)
1460
+ {"x1": 1.0, "x2": 2.0, "k1": 0.1, "time": 1.0}
1399
1461
 
1400
1462
  Args:
1401
- concs (dict[str, float]): A dictionary of concentrations with variable names as keys and their corresponding values as floats.
1402
- time (float, optional): The time point at which to get the concentrations. Default is 0.0.
1403
- include_readouts (bool, optional): Whether to include readout variables in the result. Default is True.
1463
+ concs: A dictionary where keys are the names of the concentrations and values are their respective float values.
1464
+ time: The time point at which the arguments are generated.
1465
+ include_derived: Whether to include derived variables in the arguments.
1466
+ include_readouts: Whether to include readouts in the arguments.
1404
1467
 
1405
1468
  Returns:
1406
- pd.Series: A pandas Series containing the full concentrations for the specified variables.
1469
+ A pandas Series containing the generated arguments with float dtype.
1407
1470
 
1408
1471
  """
1409
- names = self.get_variable_names() + self.get_derived_variable_names()
1472
+ names = self.get_variable_names()
1473
+ if include_derived:
1474
+ names.extend(self.get_derived_variable_names())
1410
1475
  if include_readouts:
1411
- names.extend(self.get_readout_names())
1476
+ names.extend(self._readouts)
1412
1477
 
1413
- return self.get_args(
1414
- concs=concs,
1415
- time=time,
1416
- include_readouts=include_readouts,
1417
- ).loc[names]
1478
+ args = self.get_dependent(
1479
+ concs=concs, time=time, include_readouts=include_readouts
1480
+ )
1481
+ return args.loc[names]
1482
+
1483
+ def get_args_time_course(
1484
+ self,
1485
+ concs: pd.DataFrame,
1486
+ *,
1487
+ include_derived: bool = True,
1488
+ include_readouts: bool = False,
1489
+ ) -> pd.DataFrame:
1490
+ """Generate a DataFrame containing time course arguments for model evaluation.
1491
+
1492
+ Examples:
1493
+ >>> model.get_args_time_course(
1494
+ ... pd.DataFrame({"x1": [1.0, 2.0], "x2": [2.0, 3.0]}
1495
+ ... )
1496
+ pd.DataFrame({
1497
+ "x1": [1.0, 2.0],
1498
+ "x2": [2.0, 3.0],
1499
+ "k1": [0.1, 0.1],
1500
+ "time": [0.0, 1.0]},
1501
+ )
1502
+
1503
+ Args:
1504
+ concs: A DataFrame containing concentration data with time as the index.
1505
+ include_derived: Whether to include derived variables in the arguments.
1506
+ include_readouts: If True, include readout variables in the resulting DataFrame.
1507
+
1508
+ Returns:
1509
+ A DataFrame containing the combined concentration data, parameter values,
1510
+ derived variables, and optionally readout variables, with time as an additional column.
1511
+
1512
+ """
1513
+ names = self.get_variable_names()
1514
+ if include_derived:
1515
+ names.extend(self.get_derived_variable_names())
1516
+
1517
+ args = self.get_dependent_time_course(
1518
+ concs=concs, include_readouts=include_readouts
1519
+ )
1520
+ return args.loc[:, names]
1418
1521
 
1419
1522
  ##########################################################################
1420
1523
  # Get fluxes
@@ -1470,19 +1573,16 @@ class Model:
1470
1573
  Fluxes: A pandas Series containing the fluxes for each reaction.
1471
1574
 
1472
1575
  """
1473
- args = self.get_args(
1576
+ names = self.get_reaction_names()
1577
+ for surrogate in self._surrogates.values():
1578
+ names.extend(surrogate.stoichiometries)
1579
+
1580
+ args = self.get_dependent(
1474
1581
  concs=concs,
1475
1582
  time=time,
1476
1583
  include_readouts=False,
1477
1584
  )
1478
-
1479
- fluxes: dict[str, float] = {}
1480
- for name, rxn in self._reactions.items():
1481
- fluxes[name] = cast(float, rxn.fn(*args.loc[rxn.args]))
1482
-
1483
- for surrogate in self._surrogates.values():
1484
- fluxes |= surrogate.predict(args.loc[surrogate.args].to_numpy())
1485
- return pd.Series(fluxes, dtype=float)
1585
+ return args.loc[names]
1486
1586
 
1487
1587
  def get_fluxes_time_course(self, args: pd.DataFrame) -> pd.DataFrame:
1488
1588
  """Generate a time course of fluxes for the given reactions and surrogates.
@@ -1506,20 +1606,15 @@ class Model:
1506
1606
  the index of the input arguments.
1507
1607
 
1508
1608
  """
1509
- fluxes: dict[str, Float] = {}
1510
- for name, rate in self._reactions.items():
1511
- fluxes[name] = rate.fn(*args.loc[:, rate.args].to_numpy().T)
1512
-
1513
- # Create df here already to avoid having to play around with
1514
- # shape of surrogate outputs
1515
- flux_df = pd.DataFrame(fluxes, index=args.index)
1609
+ names = self.get_reaction_names()
1516
1610
  for surrogate in self._surrogates.values():
1517
- outputs = pd.DataFrame(
1518
- [surrogate.predict(y) for y in args.loc[:, surrogate.args].to_numpy()],
1519
- index=args.index,
1520
- )
1521
- flux_df = pd.concat((flux_df, outputs), axis=1)
1522
- return flux_df
1611
+ names.extend(surrogate.stoichiometries)
1612
+
1613
+ args = self.get_dependent_time_course(
1614
+ concs=args,
1615
+ include_readouts=False,
1616
+ )
1617
+ return args.loc[:, names]
1523
1618
 
1524
1619
  ##########################################################################
1525
1620
  # Get rhs
@@ -1553,22 +1648,21 @@ class Model:
1553
1648
  strict=True,
1554
1649
  )
1555
1650
  )
1556
- args: dict[str, float] = self._get_args(
1651
+ dependent: dict[str, float] = self._get_dependent(
1557
1652
  concs=concsd,
1558
1653
  time=time,
1559
- include_readouts=False,
1654
+ cache=cache,
1560
1655
  )
1561
- fluxes: dict[str, float] = self._get_fluxes(args)
1562
1656
 
1563
1657
  dxdt = cache.dxdt
1564
1658
  dxdt[:] = 0
1565
1659
  for k, stoc in cache.stoich_by_cpds.items():
1566
1660
  for flux, n in stoc.items():
1567
- dxdt[k] += n * fluxes[flux]
1661
+ dxdt[k] += n * dependent[flux]
1568
1662
  for k, sd in cache.dyn_stoich_by_cpds.items():
1569
1663
  for flux, dv in sd.items():
1570
- n = dv.fn(*(args[i] for i in dv.args))
1571
- dxdt[k] += n * fluxes[flux]
1664
+ n = dv.calculate(dependent)
1665
+ dxdt[k] += n * dependent[flux]
1572
1666
  return cast(Array, dxdt.to_numpy())
1573
1667
 
1574
1668
  def get_right_hand_side(
@@ -1602,19 +1696,18 @@ class Model:
1602
1696
  if (cache := self._cache) is None:
1603
1697
  cache = self._create_cache()
1604
1698
  var_names = self.get_variable_names()
1605
- args = self._get_args(
1699
+ dependent = self._get_dependent(
1606
1700
  concs=self.get_initial_conditions() if concs is None else concs,
1607
1701
  time=time,
1608
- include_readouts=False,
1702
+ cache=cache,
1609
1703
  )
1610
- fluxes = self._get_fluxes(args)
1611
1704
  dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
1612
1705
  for k, stoc in cache.stoich_by_cpds.items():
1613
1706
  for flux, n in stoc.items():
1614
- dxdt[k] += n * fluxes[flux]
1707
+ dxdt[k] += n * dependent[flux]
1615
1708
 
1616
1709
  for k, sd in cache.dyn_stoich_by_cpds.items():
1617
1710
  for flux, dv in sd.items():
1618
- n = dv.fn(*(args[i] for i in dv.args))
1619
- dxdt[k] += n * fluxes[flux]
1711
+ n = dv.fn(*(dependent[i] for i in dv.args))
1712
+ dxdt[k] += n * dependent[flux]
1620
1713
  return dxdt
modelbase2/plot.py CHANGED
@@ -818,7 +818,7 @@ def relative_label_distribution(
818
818
  isos = mapper.get_isotopomers_of_at_position(name, i)
819
819
  labels = cast(pd.DataFrame, concs.loc[:, isos])
820
820
  total = concs.loc[:, f"{name}__total"]
821
- ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i+1}")
821
+ ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i + 1}")
822
822
  ax.set_title(name)
823
823
  ax.legend()
824
824
  else:
@@ -827,6 +827,6 @@ def relative_label_distribution(
827
827
  ):
828
828
  ax.plot(concs.index, concs.loc[:, isos])
829
829
  ax.set_title(name)
830
- ax.legend([f"C{i+1}" for i in range(len(isos))])
830
+ ax.legend([f"C{i + 1}" for i in range(len(isos))])
831
831
 
832
832
  return fig, axs
@@ -492,6 +492,16 @@ def _codgen(name: str, sbml: Parser) -> Path:
492
492
  else:
493
493
  variables[k] = v.size
494
494
 
495
+ # Ensure non-zero value for initial assignments
496
+ # EXPLAIN: we need to do this for the first round of get_dependent to work
497
+ # otherwise we run into a ton of ZeroDivisionError exceptions.
498
+ # Since the values are overwritten afterwards, it doesn't really matter anyway
499
+ for k in sbml.initial_assignment:
500
+ if k in parameters and parameters[k] == 0:
501
+ parameters[k] = 1
502
+ if k in variables and variables[k] == 0:
503
+ variables[k] = 1
504
+
495
505
  derived_str = "\n ".join(
496
506
  f"m.add_derived('{k}', fn={k}, args={v.args})" for k, v in sbml.derived.items()
497
507
  )
@@ -507,7 +517,11 @@ def _codgen(name: str, sbml: Parser) -> Path:
507
517
 
508
518
  # Initial assignments
509
519
  initial_assignment_order = _sort_dependencies(
510
- available=set(sbml.initial_assignment) ^ set(parameters) ^ set(variables),
520
+ available=set(sbml.initial_assignment)
521
+ ^ set(parameters)
522
+ ^ set(variables)
523
+ ^ set(sbml.derived)
524
+ | {"time"},
511
525
  elements=[(k, set(v.args)) for k, v in sbml.initial_assignment.items()],
512
526
  )
513
527
 
@@ -535,7 +549,7 @@ def get_model() -> Model:
535
549
  {variables_str}
536
550
  {derived_str}
537
551
  {rxn_str}
538
- args = m.get_args()
552
+ args = m.get_dependent()
539
553
  {initial_assignment_source}
540
554
  return m
541
555
  """
modelbase2/scan.py CHANGED
@@ -467,7 +467,7 @@ def steady_state(
467
467
  )
468
468
  concs.index = idx
469
469
  fluxes.index = idx
470
- return SteadyStates(concs, fluxes, parameters=parameters)
470
+ return SteadyStates(concs=concs, fluxes=fluxes, parameters=parameters)
471
471
 
472
472
 
473
473
  def time_course(