mxlpy 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/mc.py CHANGED
@@ -24,9 +24,11 @@ import pandas as pd
24
24
  from mxlpy import mca, scan
25
25
  from mxlpy.parallel import Cache, parallelise
26
26
  from mxlpy.scan import (
27
+ ProtocolTimeCourseWorker,
27
28
  ProtocolWorker,
28
29
  SteadyStateWorker,
29
30
  TimeCourseWorker,
31
+ _protocol_time_course_worker,
30
32
  _protocol_worker,
31
33
  _steady_state_worker,
32
34
  _time_course_worker,
@@ -49,11 +51,12 @@ if TYPE_CHECKING:
49
51
  # Public API of mxlpy.mc.
  # NOTE(review): `time_course_over_protocol` was renamed to `protocol` and
  # removed from this list; callers using the old name will break unless an
  # alias is defined elsewhere — confirm this break is intended.
  __all__ = [
50
52
  "ParameterScanWorker",
51
53
  "parameter_elasticities",
54
+ "protocol",
55
+ "protocol_time_course",
52
56
  "response_coefficients",
53
57
  "scan_steady_state",
54
58
  "steady_state",
55
59
  "time_course",
56
- "time_course_over_protocol",
57
60
  "variable_elasticities",
58
61
  ]
59
62
 
@@ -229,7 +232,7 @@ def time_course(
229
232
  )
230
233
 
231
234
 
232
- def time_course_over_protocol(
235
+ def protocol(
233
236
  model: Model,
234
237
  *,
235
238
  protocol: pd.DataFrame,
@@ -287,6 +290,64 @@ def time_course_over_protocol(
287
290
  )
288
291
 
289
292
 
293
def protocol_time_course(
    model: Model,
    *,
    protocol: pd.DataFrame,
    time_points: Array,
    mc_to_scan: pd.DataFrame,
    y0: dict[str, float] | None = None,
    max_workers: int | None = None,
    cache: Cache | None = None,
    worker: ProtocolTimeCourseWorker = _protocol_time_course_worker,
    integrator: IntegratorType | None = None,
) -> ProtocolScan:
    """Run a protocol time course for every row of ``mc_to_scan``.

    Each ``(index, row)`` pair of ``mc_to_scan`` is dispatched through
    ``parallelise`` to ``_update_parameters_and_initial_conditions``, which
    receives ``model`` and a ``worker`` pre-bound to ``protocol``,
    ``time_points`` and ``integrator``.

    Args:
        model: Model to simulate. If ``y0`` is given,
            ``model.update_variables(y0)`` is called on it before the scan,
            i.e. the caller's model is modified in place.
        protocol: Protocol forwarded to ``worker`` and stored on the result.
        time_points: Time points forwarded to ``worker``.
        mc_to_scan: Monte-Carlo samples, one row per sample.
        y0: Optional initial conditions applied to ``model`` once, up front.
        max_workers: Forwarded to ``parallelise``.
        cache: Forwarded to ``parallelise``.
        worker: Function computing a single protocol time course.
        integrator: Forwarded to ``worker``.

    Returns:
        ``ProtocolScan`` with ``to_scan=mc_to_scan``, ``protocol`` and
        ``raw_results`` built as ``dict(...)`` from the ``parallelise``
        output (presumably keyed by the ``mc_to_scan`` index — confirm).

    Examples:
        >>> protocol_time_course(
        ...     model,
        ...     protocol=protocol,
        ...     time_points=time_points,
        ...     mc_to_scan=mc_to_scan,
        ... )

    """
    if y0 is not None:
        model.update_variables(y0)

    res = parallelise(
        partial(
            _update_parameters_and_initial_conditions,
            fn=partial(
                worker,
                protocol=protocol,
                time_points=time_points,
                integrator=integrator,
                # initial conditions were already applied to `model` above
                y0=None,
            ),
            model=model,
        ),
        inputs=list(mc_to_scan.iterrows()),
        max_workers=max_workers,
        cache=cache,
    )
    return ProtocolScan(
        to_scan=mc_to_scan,
        protocol=protocol,
        raw_results=dict(res),
    )
349
+
350
+
290
351
  def scan_steady_state(
291
352
  model: Model,
292
353
  *,
mxlpy/meta/__init__.py CHANGED
@@ -3,13 +3,18 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from .codegen_latex import generate_latex_code, to_tex_export
6
- from .codegen_model import generate_model_code_py, generate_model_code_rs
6
+ from .codegen_model import (
7
+ generate_model_code_py,
8
+ generate_model_code_rs,
9
+ generate_model_code_ts,
10
+ )
7
11
  from .codegen_mxlpy import generate_mxlpy_code
8
12
 
9
13
  # Names re-exported from the codegen submodules imported above.
  __all__ = [
10
14
  "generate_latex_code",
11
15
  "generate_model_code_py",
12
16
  "generate_model_code_rs",
17
+ "generate_model_code_ts",
13
18
  "generate_mxlpy_code",
14
19
  "to_tex_export",
15
20
  ]
@@ -6,6 +6,7 @@ from dataclasses import dataclass
6
6
  from typing import TYPE_CHECKING
7
7
 
8
8
  import sympy
9
+ from wadler_lindig import pformat
9
10
 
10
11
  from mxlpy.meta.sympy_tools import fn_to_sympy, list_of_symbols
11
12
  from mxlpy.types import Derived, RateFn
@@ -358,6 +359,10 @@ class TexReaction:
358
359
  fn: RateFn
359
360
  args: list[str]
360
361
 
362
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
365
+
361
366
 
362
367
  @dataclass
363
368
  class TexExport:
@@ -397,6 +402,10 @@ class TexExport:
397
402
  reactions: dict[str, TexReaction]
398
403
  diff_eqs: dict[str, Mapping[str, float | Derived]]
399
404
 
405
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
408
+
400
409
  @staticmethod
401
410
  def _diff_parameters(
402
411
  p1: dict[str, float],
@@ -9,6 +9,7 @@ from mxlpy.meta.sympy_tools import (
9
9
  fn_to_sympy,
10
10
  list_of_symbols,
11
11
  stoichiometries_to_sympy,
12
+ sympy_to_inline_js,
12
13
  sympy_to_inline_py,
13
14
  sympy_to_inline_rust,
14
15
  )
@@ -23,6 +24,7 @@ if TYPE_CHECKING:
23
24
  # Public code generators provided by this module.
  __all__ = [
24
25
  "generate_model_code_py",
25
26
  "generate_model_code_rs",
27
+ "generate_model_code_ts",
26
28
  ]
27
29
 
28
30
  # Module logger; used to warn when a reaction's rate function cannot be parsed.
  _LOGGER = logging.getLogger(__name__)
@@ -37,6 +39,7 @@ def _generate_model_code(
37
39
  assignment_template: str,
38
40
  sympy_inline_fn: Callable[[sympy.Expr], str],
39
41
  return_template: str,
42
+ custom_fns: dict[str, sympy.Expr],
40
43
  imports: list[str] | None = None,
41
44
  end: str | None = None,
42
45
  free_parameters: list[str] | None = None,
@@ -70,11 +73,13 @@ def _generate_model_code(
70
73
 
71
74
  # Derived
72
75
  for name, derived in model.get_raw_derived().items():
73
- expr = fn_to_sympy(
74
- derived.fn,
75
- origin=name,
76
- model_args=list_of_symbols(derived.args),
77
- )
76
+ expr = custom_fns.get(name)
77
+ if expr is None:
78
+ expr = fn_to_sympy(
79
+ derived.fn,
80
+ origin=name,
81
+ model_args=list_of_symbols(derived.args),
82
+ )
78
83
  if expr is None:
79
84
  msg = f"Unable to parse fn for derived value '{name}'"
80
85
  raise ValueError(msg)
@@ -82,11 +87,16 @@ def _generate_model_code(
82
87
 
83
88
  # Reactions
84
89
  for name, rxn in model.get_raw_reactions().items():
85
- expr = fn_to_sympy(
86
- rxn.fn,
87
- origin=name,
88
- model_args=list_of_symbols(rxn.args),
89
- )
90
+ expr = custom_fns.get(name)
91
+ if expr is None:
92
+ try:
93
+ expr = fn_to_sympy(
94
+ rxn.fn,
95
+ origin=name,
96
+ model_args=list_of_symbols(rxn.args),
97
+ )
98
+ except KeyError:
99
+ _LOGGER.warning("Failed to parse %s", name)
90
100
  if expr is None:
91
101
  msg = f"Unable to parse fn for reaction value '{name}'"
92
102
  raise ValueError(msg)
@@ -123,6 +133,7 @@ def _generate_model_code(
123
133
 
124
134
  def generate_model_code_py(
125
135
  model: Model,
136
+ custom_fns: dict[str, sympy.Expr] | None = None,
126
137
  free_parameters: list[str] | None = None,
127
138
  ) -> str:
128
139
  """Transform the model into a python function, inlining the function calls."""
@@ -137,21 +148,51 @@ def generate_model_code_py(
137
148
  return _generate_model_code(
138
149
  model,
139
150
  imports=[
151
+ "import math\n",
140
152
  "from collections.abc import Iterable\n",
141
153
  ],
142
154
  sized=False,
143
155
  model_fn=model_fn,
144
156
  variables_template=" {} = variables",
145
- assignment_template=" {k} = {v}",
157
+ assignment_template=" {k}: float = {v}",
146
158
  sympy_inline_fn=sympy_to_inline_py,
147
159
  return_template=" return {}",
148
160
  end=None,
149
161
  free_parameters=free_parameters,
162
+ custom_fns={} if custom_fns is None else custom_fns,
163
+ )
164
+
165
+
166
def generate_model_code_ts(
    model: Model,
    custom_fns: dict[str, sympy.Expr] | None = None,
    free_parameters: list[str] | None = None,
) -> str:
    """Transform the model into a typescript function, inlining the function calls.

    Args:
        model: Model to export.
        custom_fns: Optional sympy expressions, keyed by derived-value or
            reaction name, used instead of the expression otherwise parsed
            from the corresponding python function.
        free_parameters: Names appended as extra ``number`` arguments of the
            generated ``model`` function, after ``time`` and ``variables``.

    Returns:
        TypeScript source code defining ``function model(...)``.

    """
    # `time` and `variables` always come first; free parameters follow.
    # Joining one list avoids a dangling ", " when free_parameters is empty.
    params = ["time: number", "variables: number[]"]
    if free_parameters:
        params.extend(f"{name}: number" for name in free_parameters)
    model_fn = f"function model({', '.join(params)}) {{"

    return _generate_model_code(
        model,
        imports=[],
        sized=False,
        model_fn=model_fn,
        variables_template=" let [{}] = variables;",
        assignment_template=" let {k}: number = {v};",
        sympy_inline_fn=sympy_to_inline_js,
        return_template=" return [{}];",
        # A function declaration needs no trailing `;` (`};` would emit an
        # empty statement after the body).
        end="}",
        free_parameters=free_parameters,
        custom_fns={} if custom_fns is None else custom_fns,
    )
151
191
 
152
192
 
153
193
  def generate_model_code_rs(
154
194
  model: Model,
195
+ custom_fns: dict[str, sympy.Expr] | None = None,
155
196
  free_parameters: list[str] | None = None,
156
197
  ) -> str:
157
198
  """Transform the model into a rust function, inlining the function calls."""
@@ -172,4 +213,5 @@ def generate_model_code_rs(
172
213
  return_template=" return [{}]",
173
214
  end="}",
174
215
  free_parameters=free_parameters,
216
+ custom_fns={} if custom_fns is None else custom_fns,
175
217
  )
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
7
7
  from typing import TYPE_CHECKING, cast
8
8
 
9
9
  import sympy
10
+ from wadler_lindig import pformat
10
11
 
11
12
  from mxlpy.meta.sympy_tools import (
12
13
  fn_to_sympy,
@@ -43,6 +44,10 @@ class SymbolicFn:
43
44
  expr: sympy.Expr
44
45
  args: list[str]
45
46
 
47
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
50
+
46
51
 
47
52
  @dataclass
48
53
  class SymbolicVariable:
@@ -51,6 +56,10 @@ class SymbolicVariable:
51
56
  value: sympy.Float | SymbolicFn # initial assignment
52
57
  unit: Quantity | None
53
58
 
59
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
62
+
54
63
 
55
64
  @dataclass
56
65
  class SymbolicParameter:
@@ -59,6 +68,10 @@ class SymbolicParameter:
59
68
  value: sympy.Float | SymbolicFn # initial assignment
60
69
  unit: Quantity | None
61
70
 
71
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
74
+
62
75
 
63
76
  @dataclass
64
77
  class SymbolicReaction:
@@ -67,6 +80,10 @@ class SymbolicReaction:
67
80
  fn: SymbolicFn
68
81
  stoichiometry: dict[str, sympy.Float | str | SymbolicFn]
69
82
 
83
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
86
+
70
87
 
71
88
  @dataclass
72
89
  class SymbolicRepr:
@@ -77,6 +94,10 @@ class SymbolicRepr:
77
94
  derived: dict[str, SymbolicFn] = field(default_factory=dict)
78
95
  reactions: dict[str, SymbolicReaction] = field(default_factory=dict)
79
96
 
97
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
100
+
80
101
 
81
102
  def _fn_to_symbolic_repr(k: str, fn: Callable, model_args: list[str]) -> SymbolicFn:
82
103
  fn_name = fn.__name__
@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any, cast
15
15
  import dill
16
16
  import numpy as np
17
17
  import sympy
18
+ from wadler_lindig import pformat
18
19
 
19
20
  if TYPE_CHECKING:
20
21
  from collections.abc import Callable
@@ -174,6 +175,10 @@ class Context:
174
175
  modules: dict[str, ModuleType]
175
176
  fns: dict[str, Callable]
176
177
 
178
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
181
+
177
182
  def updated(
178
183
  self,
179
184
  symbols: dict[str, sympy.Symbol | sympy.Expr] | None = None,
mxlpy/meta/sympy_tools.py CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
  from typing import TYPE_CHECKING, cast
6
6
 
7
7
  import sympy
8
- from sympy.printing import rust_code
8
+ from sympy.printing import jscode, rust_code
9
9
  from sympy.printing.pycode import pycode
10
10
 
11
11
  from mxlpy.meta.source_tools import fn_to_sympy
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
17
17
  __all__ = [
18
18
  "list_of_symbols",
19
19
  "stoichiometries_to_sympy",
20
+ "sympy_to_inline_js",
20
21
  "sympy_to_inline_py",
21
22
  "sympy_to_inline_rust",
22
23
  "sympy_to_python_fn",
@@ -53,6 +54,11 @@ def sympy_to_inline_py(expr: sympy.Expr) -> str:
53
54
  return cast(str, pycode(expr, fully_qualified_modules=True, full_prec=False))
54
55
 
55
56
 
57
def sympy_to_inline_js(expr: sympy.Expr) -> str:
    """Create inline JavaScript code from a sympy expression.

    Uses ``sympy.printing.jscode`` with ``full_prec=False``; the TypeScript
    model generator uses this printer for its inlined expressions.
    """
    return cast(str, jscode(expr, full_prec=False))
60
+
61
+
56
62
  def sympy_to_inline_rust(expr: sympy.Expr) -> str:
57
63
  """Create rust code from sympy expression."""
58
64
  return cast(str, rust_code(expr, full_prec=False))
mxlpy/model.py CHANGED
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Self, cast
18
18
  import numpy as np
19
19
  import pandas as pd
20
20
  import sympy
21
+ from wadler_lindig import pformat
21
22
 
22
23
  from mxlpy import fns
23
24
  from mxlpy.meta.source_tools import fn_to_sympy
@@ -91,6 +92,10 @@ class MdText:
91
92
 
92
93
  content: list[str]
93
94
 
95
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
98
+
94
99
  def _repr_markdown_(self) -> str:
95
100
  return "\n".join(self.content)
96
101
 
@@ -101,6 +106,10 @@ class UnitCheck:
101
106
 
102
107
  per_variable: dict[str, dict[str, bool | Failure | None]]
103
108
 
109
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
112
+
104
113
  @staticmethod
105
114
  def _fmt_success(s: str) -> str:
106
115
  return f"<span style='color: green'>{s}</span>"
@@ -171,6 +180,10 @@ class Dependency:
171
180
  required: set[str]
172
181
  provided: set[str]
173
182
 
183
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
186
+
174
187
 
175
188
  class MissingDependenciesError(Exception):
176
189
  """Raised when dependencies cannot be sorted topologically.
@@ -374,6 +387,10 @@ class ModelCache:
374
387
 
375
388
  """
376
389
 
390
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
393
+
377
394
  order: list[str] # mostly for debug purposes
378
395
  var_names: list[str]
379
396
  dyn_order: list[str]
@@ -402,15 +419,19 @@ class Model:
402
419
  """
403
420
 
404
421
  _ids: dict[str, str] = field(default_factory=dict, repr=False)
422
+ _cache: ModelCache | None = field(default=None, repr=False)
405
423
  _variables: dict[str, Variable] = field(default_factory=dict)
406
424
  _parameters: dict[str, Parameter] = field(default_factory=dict)
407
425
  _derived: dict[str, Derived] = field(default_factory=dict)
408
426
  _readouts: dict[str, Readout] = field(default_factory=dict)
409
427
  _reactions: dict[str, Reaction] = field(default_factory=dict)
410
428
  _surrogates: dict[str, AbstractSurrogate] = field(default_factory=dict)
411
- _cache: ModelCache | None = None
412
429
  _data: dict[str, pd.Series | pd.DataFrame] = field(default_factory=dict)
413
430
 
431
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    return pformat(self)
434
+
414
435
  ###########################################################################
415
436
  # Cache
416
437
  ###########################################################################
@@ -2250,6 +2271,24 @@ class Model:
2250
2271
  dxdt[k] += n * dependent[flux]
2251
2272
  return tuple(dxdt[i] for i in cache.var_names)
2252
2273
 
2274
def _get_right_hand_side(
    self,
    *,
    args: dict[str, float],
    var_names: list[str],
    cache: ModelCache,
) -> pd.Series:
    """Assemble dx/dt for every variable from already-evaluated model arguments.

    Args:
        args: Values of model arguments by name; must contain every flux
            referenced by ``cache.stoich_by_cpds`` / ``cache.dyn_stoich_by_cpds``
            and every argument of the dynamic stoichiometry functions.
        var_names: Variable names; defines the index (and order) of the result.
        cache: Model cache providing ``stoich_by_cpds`` and ``dyn_stoich_by_cpds``.

    Returns:
        Float series of right-hand-side values indexed by ``var_names``.

    Raises:
        KeyError: If a stoichiometry refers to a compound not in ``var_names``
            or to a name missing from ``args``.

    """
    # Accumulate in a plain dict: element-wise Series.__getitem__/__setitem__
    # is slow, and this runs once per time point for time-course evaluation.
    dxdt = dict.fromkeys(var_names, 0.0)

    for cpd, stoich in cache.stoich_by_cpds.items():
        for flux, n in stoich.items():
            dxdt[cpd] += n * args[flux]

    # Dynamic stoichiometries: the coefficient itself is a function of args.
    for cpd, dyn_stoich in cache.dyn_stoich_by_cpds.items():
        for flux, dv in dyn_stoich.items():
            n = dv.fn(*(args[i] for i in dv.args))
            dxdt[cpd] += n * args[flux]

    return pd.Series([dxdt[name] for name in var_names], index=var_names, dtype=float)
2291
+
2253
2292
  def get_right_hand_side(
2254
2293
  self,
2255
2294
  variables: dict[str, float] | None = None,
@@ -2281,21 +2320,27 @@ class Model:
2281
2320
  if (cache := self._cache) is None:
2282
2321
  cache = self._create_cache()
2283
2322
  var_names = self.get_variable_names()
2284
- dependent = self._get_args(
2323
+ args = self._get_args(
2285
2324
  variables=self.get_initial_conditions() if variables is None else variables,
2286
2325
  time=time,
2287
2326
  cache=cache,
2288
2327
  )
2289
- dxdt = pd.Series(np.zeros(len(var_names), dtype=float), index=var_names)
2290
- for k, stoc in cache.stoich_by_cpds.items():
2291
- for flux, n in stoc.items():
2292
- dxdt[k] += n * dependent[flux]
2328
+ return self._get_right_hand_side(args=args, var_names=var_names, cache=cache)
2293
2329
 
2294
- for k, sd in cache.dyn_stoich_by_cpds.items():
2295
- for flux, dv in sd.items():
2296
- n = dv.fn(*(dependent[i] for i in dv.args))
2297
- dxdt[k] += n * dependent[flux]
2298
- return dxdt
2330
def get_right_hand_side_time_course(self, args: pd.DataFrame) -> pd.DataFrame:
    """Calculate the right-hand side of the differential equations for every row of ``args``.

    Args:
        args: One row per time point and one column per model argument; every
            flux and every argument of dynamic stoichiometries must be present.
            NOTE(review): presumably the output of a time-course argument
            evaluation — confirm against callers.

    Returns:
        Float DataFrame with the same index as ``args`` (including its name
        and any duplicate labels) and one column per variable, in
        ``get_variable_names()`` order.

    """
    if (cache := self._cache) is None:
        cache = self._create_cache()
    var_names = self.get_variable_names()

    # Build rows positionally instead of keying a dict by time: duplicate
    # index labels would otherwise silently collapse into a single row.
    rows = [
        self._get_right_hand_side(
            args=row.to_dict(),
            var_names=var_names,
            cache=cache,
        )
        for _, row in args.iterrows()
    ]
    # Explicit columns keep the variable columns even for an empty `args`.
    return pd.DataFrame(rows, index=args.index, columns=var_names, dtype=float)
2299
2344
 
2300
2345
  ##########################################################################
2301
2346
  # Check units
mxlpy/plot.py CHANGED
@@ -41,6 +41,7 @@ from matplotlib.figure import Figure
41
41
  from matplotlib.legend import Legend
42
42
  from matplotlib.patches import Patch
43
43
  from mpl_toolkits.mplot3d import Axes3D
44
+ from wadler_lindig import pformat
44
45
 
45
46
  from mxlpy.label_map import LabelMapper
46
47
 
@@ -105,6 +106,10 @@ class Axs:
105
106
  """Length of axes."""
106
107
  return len(self.axs.flatten())
107
108
 
109
def __repr__(self) -> str:
    """Return a pretty-printed representation produced by ``wadler_lindig.pformat``."""
    # NOTE(review): pformat is expected to fall back to ``repr(obj)`` for
    # objects it has no dedicated formatter for (e.g. non-dataclasses); if
    # ``Axs`` is not a dataclass this would recurse into ``__repr__`` — confirm.
    return pformat(self)
112
+
108
113
  @overload
109
114
  def __getitem__(self, row_col: int) -> Axes: ...
110
115
 
@@ -213,6 +218,28 @@ def _partition_by_order_of_magnitude(s: pd.Series) -> list[list[str]]:
213
218
  ]
214
219
 
215
220
 
221
def _combine_small_groups(
    groups: list[list[str]], min_group_size: int
) -> list[list[str]]:
    """Merge each group smaller than ``min_group_size`` with its successor.

    Groups are processed in order: a group that is too small absorbs the
    following group(s) until it reaches ``min_group_size``. A trailing group
    that is still too small is appended to the previous group, if any.

    Args:
        groups: Ordered groups of names. Not modified.
        min_group_size: Minimum number of names a group should contain.

    Returns:
        New list of (newly allocated) groups; empty if ``groups`` is empty.

    """
    if not groups:
        return []

    result: list[list[str]] = []
    # Copy so that extending never mutates the caller's lists.
    current = list(groups[0])

    for next_group in groups[1:]:
        if len(current) < min_group_size:
            current.extend(next_group)
        else:
            result.append(current)
            current = list(next_group)

    # Last group: merge into the previous one only if there is one; otherwise
    # keep it even if it is below the minimum size.
    if len(current) < min_group_size and result:
        result[-1].extend(current)
    else:
        result.append(current)
    return result
241
+
242
+
216
243
  def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]]:
217
244
  """Split groups larger than the given size into smaller groups."""
218
245
  return list(
@@ -516,7 +543,7 @@ def grid_layout(
516
543
  n_rows = math.ceil(n_groups / n_cols)
517
544
  figsize = (n_cols * col_width, n_rows * row_height)
518
545
 
519
- return _default_fig_axs(
546
+ fig, axs = _default_fig_axs(
520
547
  ncols=n_cols,
521
548
  nrows=n_rows,
522
549
  figsize=figsize,
@@ -525,6 +552,12 @@ def grid_layout(
525
552
  grid=grid,
526
553
  )
527
554
 
555
+ # Disable unused plots by default
556
+ axsl = list(axs)
557
+ for i in range(n_groups, len(axs)):
558
+ axsl[i].set_visible(False)
559
+ return fig, axs
560
+
528
561
 
529
562
  ##########################################################################
530
563
  # Plots
@@ -586,10 +619,6 @@ def bars_grouped(
586
619
  ylabel=ylabel,
587
620
  )
588
621
 
589
- axsl = list(axs)
590
- for i in range(len(groups), len(axs)):
591
- axsl[i].set_visible(False)
592
-
593
622
  return fig, axs
594
623
 
595
624
 
@@ -599,18 +628,20 @@ def bars_autogrouped(
599
628
  n_cols: int = 2,
600
629
  col_width: float = 4,
601
630
  row_height: float = 3,
631
+ min_group_size: int = 1,
602
632
  max_group_size: int = 6,
603
633
  grid: bool = True,
604
634
  xlabel: str | None = None,
605
635
  ylabel: str | None = None,
606
636
  ) -> FigAxs:
607
637
  """Plot a series or dataframe with lines grouped by order of magnitude."""
608
- group_names = _split_large_groups(
638
+ group_names = (
609
639
  _partition_by_order_of_magnitude(s)
610
640
  if isinstance(s, pd.Series)
611
- else _partition_by_order_of_magnitude(s.max()),
612
- max_size=max_group_size,
641
+ else _partition_by_order_of_magnitude(s.max())
613
642
  )
643
+ group_names = _combine_small_groups(group_names, min_group_size=min_group_size)
644
+ group_names = _split_large_groups(group_names, max_size=max_group_size)
614
645
 
615
646
  groups: list[pd.Series] | list[pd.DataFrame] = (
616
647
  [s.loc[group] for group in group_names]
@@ -714,10 +745,6 @@ def lines_grouped(
714
745
  ylabel=ylabel,
715
746
  )
716
747
 
717
- axsl = list(axs)
718
- for i in range(len(groups), len(axs)):
719
- axsl[i].set_visible(False)
720
-
721
748
  return fig, axs
722
749
 
723
750
 
@@ -727,6 +754,7 @@ def line_autogrouped(
727
754
  n_cols: int = 2,
728
755
  col_width: float = 4,
729
756
  row_height: float = 3,
757
+ min_group_size: int = 1,
730
758
  max_group_size: int = 6,
731
759
  grid: bool = True,
732
760
  xlabel: str | None = None,
@@ -736,12 +764,13 @@ def line_autogrouped(
736
764
  linestyle: Linestyle | None = None,
737
765
  ) -> FigAxs:
738
766
  """Plot a series or dataframe with lines grouped by order of magnitude."""
739
- group_names = _split_large_groups(
767
+ group_names = (
740
768
  _partition_by_order_of_magnitude(s)
741
769
  if isinstance(s, pd.Series)
742
- else _partition_by_order_of_magnitude(s.max()),
743
- max_size=max_group_size,
770
+ else _partition_by_order_of_magnitude(s.max())
744
771
  )
772
+ group_names = _combine_small_groups(group_names, min_group_size=min_group_size)
773
+ group_names = _split_large_groups(group_names, max_size=max_group_size)
745
774
 
746
775
  groups: list[pd.Series] | list[pd.DataFrame] = (
747
776
  [s.loc[group] for group in group_names]