Mesa 3.0.0rc0__py3-none-any.whl → 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic; see the package registry page for more details.

Files changed (33):
  1. mesa/__init__.py +1 -1
  2. mesa/agent.py +15 -3
  3. mesa/examples/__init__.py +2 -2
  4. mesa/examples/advanced/pd_grid/app.py +1 -1
  5. mesa/examples/advanced/pd_grid/model.py +1 -1
  6. mesa/examples/advanced/sugarscape_g1mt/model.py +3 -1
  7. mesa/examples/advanced/wolf_sheep/agents.py +53 -39
  8. mesa/examples/advanced/wolf_sheep/app.py +17 -6
  9. mesa/examples/advanced/wolf_sheep/model.py +68 -74
  10. mesa/examples/basic/boid_flockers/agents.py +49 -18
  11. mesa/examples/basic/boid_flockers/model.py +55 -19
  12. mesa/examples/basic/boltzmann_wealth_model/agents.py +23 -5
  13. mesa/examples/basic/boltzmann_wealth_model/app.py +8 -3
  14. mesa/examples/basic/boltzmann_wealth_model/model.py +48 -13
  15. mesa/examples/basic/boltzmann_wealth_model/st_app.py +2 -2
  16. mesa/examples/basic/schelling/agents.py +9 -5
  17. mesa/examples/basic/schelling/model.py +48 -26
  18. mesa/experimental/cell_space/cell_collection.py +14 -2
  19. mesa/experimental/cell_space/discrete_space.py +16 -2
  20. mesa/experimental/devs/simulator.py +59 -14
  21. mesa/model.py +4 -4
  22. mesa/time.py +4 -4
  23. mesa/visualization/__init__.py +1 -1
  24. mesa/visualization/components/matplotlib_components.py +1 -2
  25. mesa/visualization/mpl_space_drawing.py +42 -7
  26. mesa/visualization/solara_viz.py +133 -54
  27. {mesa-3.0.0rc0.dist-info → mesa-3.0.1.dist-info}/METADATA +6 -8
  28. {mesa-3.0.0rc0.dist-info → mesa-3.0.1.dist-info}/RECORD +33 -33
  29. {mesa-3.0.0rc0.dist-info → mesa-3.0.1.dist-info}/WHEEL +1 -1
  30. /mesa/visualization/{UserParam.py → user_param.py} +0 -0
  31. {mesa-3.0.0rc0.dist-info → mesa-3.0.1.dist-info}/entry_points.txt +0 -0
  32. {mesa-3.0.0rc0.dist-info → mesa-3.0.1.dist-info}/licenses/LICENSE +0 -0
  33. {mesa-3.0.0rc0.dist-info → mesa-3.0.1.dist-info}/licenses/NOTICE +0 -0
@@ -57,8 +57,20 @@ class Simulator:
57
57
  Args:
58
58
  model (Model): The model to simulate
59
59
 
60
+ Raises:
61
+ Exception if simulator.time is not equal to simulator.starttime
62
+ Exception if event list is not empty
63
+
60
64
  """
61
- self.event_list.clear()
65
+ if self.time != self.start_time:
66
+ raise ValueError(
67
+ "trying to setup model, but current time is not equal to start_time, Has the simulator been reset or freshly initialized?"
68
+ )
69
+ if not self.event_list.is_empty():
70
+ raise ValueError(
71
+ "trying to setup model, but events have already been scheduled. Call simulator.setup before any scheduling"
72
+ )
73
+
62
74
  self.model = model
63
75
 
64
76
  def reset(self):
@@ -68,7 +80,20 @@ class Simulator:
68
80
  self.time = self.start_time
69
81
 
70
82
  def run_until(self, end_time: int | float) -> None:
71
- """Run the simulator until the end time."""
83
+ """Run the simulator until the end time.
84
+
85
+ Args:
86
+ end_time (int | float): The end time for stopping the simulator
87
+
88
+ Raises:
89
+ Exception if simulator.setup() has not yet been called
90
+
91
+ """
92
+ if self.model is None:
93
+ raise Exception(
94
+ "simulator has not been setup, call simulator.setup(model) first"
95
+ )
96
+
72
97
  while True:
73
98
  try:
74
99
  event = self.event_list.pop_event()
@@ -84,6 +109,26 @@ class Simulator:
84
109
  self._schedule_event(event) # reschedule event
85
110
  break
86
111
 
112
+ def run_next_event(self):
113
+ """Execute the next event.
114
+
115
+ Raises:
116
+ Exception if simulator.setup() has not yet been called
117
+
118
+ """
119
+ if self.model is None:
120
+ raise Exception(
121
+ "simulator has not been setup, call simulator.setup(model) first"
122
+ )
123
+
124
+ try:
125
+ event = self.event_list.pop_event()
126
+ except IndexError: # event list is empty
127
+ return
128
+ else:
129
+ self.time = event.time
130
+ event.execute()
131
+
87
132
  def run_for(self, time_delta: int | float):
88
133
  """Run the simulator for the specified time delta.
89
134
 
@@ -92,6 +137,7 @@ class Simulator:
92
137
  plus the time delta
93
138
 
94
139
  """
140
+ # fixme, raise initialization error or something like it if model.setup has not been called
95
141
  end_time = self.time + time_delta
96
142
  self.run_until(end_time)
97
143
 
@@ -228,7 +274,7 @@ class ABMSimulator(Simulator):
228
274
 
229
275
  """
230
276
  super().setup(model)
231
- self.schedule_event_now(self.model.step, priority=Priority.HIGH)
277
+ self.schedule_event_next_tick(self.model.step, priority=Priority.HIGH)
232
278
 
233
279
  def check_time_unit(self, time) -> bool:
234
280
  """Check whether the time is of the correct unit.
@@ -277,7 +323,15 @@ class ABMSimulator(Simulator):
277
323
  Args:
278
324
  end_time (float| int): The end_time delta. The simulator is until the specified end time
279
325
 
326
+ Raises:
327
+ Exception if simulator.setup() has not yet been called
328
+
280
329
  """
330
+ if self.model is None:
331
+ raise Exception(
332
+ "simulator has not been setup, call simulator.setup(model) first"
333
+ )
334
+
281
335
  while True:
282
336
  try:
283
337
  event = self.event_list.pop_event()
@@ -285,6 +339,8 @@ class ABMSimulator(Simulator):
285
339
  self.time = end_time
286
340
  break
287
341
 
342
+ # fixme: the alternative would be to wrap model.step with an annotation which
343
+ # handles this scheduling.
288
344
  if event.time <= end_time:
289
345
  self.time = event.time
290
346
  if event.fn() == self.model.step:
@@ -298,17 +354,6 @@ class ABMSimulator(Simulator):
298
354
  self._schedule_event(event)
299
355
  break
300
356
 
301
- def run_for(self, time_delta: int):
302
- """Run the simulator for the specified time delta.
303
-
304
- Args:
305
- time_delta (float| int): The time delta. The simulator is run from the current time to the current time
306
- plus the time delta
307
-
308
- """
309
- end_time = self.time + time_delta - 1
310
- self.run_until(end_time)
311
-
312
357
 
313
358
  class DEVSimulator(Simulator):
314
359
  """A simulator where the unit of time is a float.
mesa/model.py CHANGED
@@ -114,7 +114,7 @@ class Model:
114
114
 
115
115
  def next_id(self) -> int: # noqa: D102
116
116
  warnings.warn(
117
- "using model.next_id() is deprecated. Agents track their unique ID automatically",
117
+ "using model.next_id() is deprecated and will be removed in Mesa 3.1. Agents track their unique ID automatically",
118
118
  DeprecationWarning,
119
119
  stacklevel=2,
120
120
  )
@@ -146,8 +146,8 @@ class Model:
146
146
  def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet:
147
147
  """Deprecated: Retrieves an AgentSet containing all agents of the specified type."""
148
148
  warnings.warn(
149
- f"Model.get_agents_of_type() is deprecated, please replace get_agents_of_type({agenttype})"
150
- f"with the property agents_by_type[{agenttype}].",
149
+ f"Model.get_agents_of_type() is deprecated and will be removed in Mesa 3.1."
150
+ f"Please replace get_agents_of_type({agenttype}) with the property agents_by_type[{agenttype}].",
151
151
  DeprecationWarning,
152
152
  stacklevel=2,
153
153
  )
@@ -262,7 +262,7 @@ class Model:
262
262
 
263
263
  """
264
264
  warnings.warn(
265
- "initialize_data_collector() is deprecated. Please use the DataCollector class directly. "
265
+ "initialize_data_collector() is deprecated and will be removed in Mesa 3.1. Please use the DataCollector class directly. "
266
266
  "by using `self.datacollector = DataCollector(...)`.",
267
267
  DeprecationWarning,
268
268
  stacklevel=2,
mesa/time.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """Mesa Time Module.
2
2
 
3
3
  .. warning::
4
- The time module and all its Schedulers are deprecated and will be removed in a future version.
4
+ The time module and all its Schedulers are deprecated and will be removed in Mesa 3.1.
5
5
  They can be replaced with AgentSet functionality. See the migration guide for details:
6
6
  https://mesa.readthedocs.io/latest/migration_guide.html#time-and-schedulers
7
7
 
@@ -63,7 +63,7 @@ class BaseScheduler:
63
63
 
64
64
  """
65
65
  warnings.warn(
66
- "The time module and all its Schedulers are deprecated and will be removed in a future version. "
66
+ "The time module and all its Schedulers are deprecated and will be removed in Mesa 3.1. "
67
67
  "They can be replaced with AgentSet functionality. See the migration guide for details. "
68
68
  "https://mesa.readthedocs.io/latest/migration_guide.html#time-and-schedulers",
69
69
  DeprecationWarning,
@@ -375,7 +375,7 @@ class RandomActivationByType(BaseScheduler):
375
375
 
376
376
 
377
377
  class DiscreteEventScheduler(BaseScheduler):
378
- """This class has been deprecated and replaced by the functionality provided by experimental.devs."""
378
+ """This class has been removed and replaced by the functionality provided by experimental.devs."""
379
379
 
380
380
  def __init__(self, model: Model, time_step: TimeT = 1) -> None:
381
381
  """Initialize DiscreteEventScheduler.
@@ -387,5 +387,5 @@ class DiscreteEventScheduler(BaseScheduler):
387
387
  """
388
388
  super().__init__(model)
389
389
  raise Exception(
390
- "DiscreteEventScheduler is deprecated in favor of the functionality provided by experimental.devs"
390
+ "DiscreteEventScheduler is removed in favor of the functionality provided by experimental.devs"
391
391
  )
@@ -13,7 +13,7 @@ from mesa.visualization.mpl_space_drawing import (
13
13
  from .components import make_plot_component, make_space_component
14
14
  from .components.altair_components import make_space_altair
15
15
  from .solara_viz import JupyterViz, SolaraViz
16
- from .UserParam import Slider
16
+ from .user_param import Slider
17
17
 
18
18
  __all__ = [
19
19
  "JupyterViz",
@@ -38,8 +38,7 @@ def make_mpl_space_component(
38
38
  the functions for drawing the various spaces for further details.
39
39
 
40
40
  ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
41
- "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
42
-
41
+ "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
43
42
 
44
43
  Returns:
45
44
  function: A function that creates a SpaceMatplotlib component
@@ -6,6 +6,7 @@ for a paper.
6
6
 
7
7
  """
8
8
 
9
+ import contextlib
9
10
  import itertools
10
11
  import math
11
12
  import warnings
@@ -61,10 +62,19 @@ def collect_agent_data(
61
62
  zorder: default zorder
62
63
 
63
64
  agent_portrayal should return a dict, limited to size (size of marker), color (color of marker), zorder (z-order),
64
- and marker (marker style)
65
+ marker (marker style), alpha, linewidths, and edgecolors
65
66
 
66
67
  """
67
- arguments = {"s": [], "c": [], "marker": [], "zorder": [], "loc": []}
68
+ arguments = {
69
+ "s": [],
70
+ "c": [],
71
+ "marker": [],
72
+ "zorder": [],
73
+ "loc": [],
74
+ "alpha": [],
75
+ "edgecolors": [],
76
+ "linewidths": [],
77
+ }
68
78
 
69
79
  for agent in space.agents:
70
80
  portray = agent_portrayal(agent)
@@ -78,6 +88,10 @@ def collect_agent_data(
78
88
  arguments["marker"].append(portray.pop("marker", marker))
79
89
  arguments["zorder"].append(portray.pop("zorder", zorder))
80
90
 
91
+ for entry in ["alpha", "edgecolors", "linewidths"]:
92
+ with contextlib.suppress(KeyError):
93
+ arguments[entry].append(portray.pop(entry))
94
+
81
95
  if len(portray) > 0:
82
96
  ignored_fields = list(portray.keys())
83
97
  msg = ", ".join(ignored_fields)
@@ -110,7 +124,7 @@ def draw_space(
110
124
  Returns the Axes object with the plot drawn onto it.
111
125
 
112
126
  ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
113
- "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
127
+ "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
114
128
 
115
129
  """
116
130
  if ax is None:
@@ -118,16 +132,24 @@ def draw_space(
118
132
 
119
133
  # https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching
120
134
  match space:
121
- case mesa.space._Grid() | OrthogonalMooreGrid() | OrthogonalVonNeumannGrid():
122
- draw_orthogonal_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
135
+ # order matters here given the class structure of old-style grid spaces
123
136
  case HexSingleGrid() | HexMultiGrid() | mesa.experimental.cell_space.HexGrid():
124
137
  draw_hex_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
138
+ case (
139
+ mesa.space.SingleGrid()
140
+ | OrthogonalMooreGrid()
141
+ | OrthogonalVonNeumannGrid()
142
+ | mesa.space.MultiGrid()
143
+ ):
144
+ draw_orthogonal_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
125
145
  case mesa.space.NetworkGrid() | mesa.experimental.cell_space.Network():
126
146
  draw_network(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
127
147
  case mesa.space.ContinuousSpace():
128
148
  draw_continuous_space(space, agent_portrayal, ax=ax)
129
149
  case VoronoiGrid():
130
- draw_voroinoi_grid(space, agent_portrayal, ax=ax)
150
+ draw_voronoi_grid(space, agent_portrayal, ax=ax)
151
+ case _:
152
+ raise ValueError(f"Unknown space type: {type(space)}")
131
153
 
132
154
  if propertylayer_portrayal:
133
155
  draw_property_layers(space, propertylayer_portrayal, ax=ax)
@@ -473,7 +495,7 @@ def draw_continuous_space(
473
495
  return ax
474
496
 
475
497
 
476
- def draw_voroinoi_grid(
498
+ def draw_voronoi_grid(
477
499
  space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
478
500
  ):
479
501
  """Visualize a voronoi grid.
@@ -543,11 +565,24 @@ def _scatter(ax: Axes, arguments, **kwargs):
543
565
  marker = arguments.pop("marker")
544
566
  zorder = arguments.pop("zorder")
545
567
 
568
+ # we check if edgecolor, linewidth, and alpha are specified
569
+ # at the agent level, if not, we remove them from the arguments dict
570
+ # and fallback to the default value in ax.scatter / use what is passed via **kwargs
571
+ for entry in ["edgecolors", "linewidths", "alpha"]:
572
+ if len(arguments[entry]) == 0:
573
+ arguments.pop(entry)
574
+ else:
575
+ if entry in kwargs:
576
+ raise ValueError(
577
+ f"{entry} is specified in agent portrayal and via plotting kwargs, you can only use one or the other"
578
+ )
579
+
546
580
  for mark in np.unique(marker):
547
581
  mark_mask = marker == mark
548
582
  for z_order in np.unique(zorder):
549
583
  zorder_mask = z_order == zorder
550
584
  logical = mark_mask & zorder_mask
585
+
551
586
  ax.scatter(
552
587
  x[logical],
553
588
  y[logical],
@@ -24,7 +24,6 @@ See the Visualization Tutorial and example models for more details.
24
24
  from __future__ import annotations
25
25
 
26
26
  import asyncio
27
- import copy
28
27
  import inspect
29
28
  from collections.abc import Callable
30
29
  from typing import TYPE_CHECKING, Literal
@@ -33,7 +32,8 @@ import reacton.core
33
32
  import solara
34
33
 
35
34
  import mesa.visualization.components.altair_components as components_altair
36
- from mesa.visualization.UserParam import Slider
35
+ from mesa.experimental.devs.simulator import Simulator
36
+ from mesa.visualization.user_param import Slider
37
37
  from mesa.visualization.utils import force_update, update_counter
38
38
 
39
39
  if TYPE_CHECKING:
@@ -43,12 +43,13 @@ if TYPE_CHECKING:
43
43
  @solara.component
44
44
  def SolaraViz(
45
45
  model: Model | solara.Reactive[Model],
46
+ *,
46
47
  components: list[reacton.core.Component]
47
48
  | list[Callable[[Model], reacton.core.Component]]
48
49
  | Literal["default"] = "default",
49
50
  play_interval: int = 100,
51
+ simulator: Simulator | None = None,
50
52
  model_params=None,
51
- seed: float = 0,
52
53
  name: str | None = None,
53
54
  ):
54
55
  """Solara visualization component.
@@ -67,10 +68,9 @@ def SolaraViz(
67
68
  Defaults to "default", which uses the default Altair space visualization.
68
69
  play_interval (int, optional): Interval for playing the model steps in milliseconds.
69
70
  This controls the speed of the model's automatic stepping. Defaults to 100 ms.
71
+ simulator: A simulator that controls the model (optional)
70
72
  model_params (dict, optional): Parameters for (re-)instantiating a model.
71
73
  Can include user-adjustable parameters and fixed parameters. Defaults to None.
72
- seed (int, optional): Seed for the random number generator. This ensures reproducibility
73
- of the model's behavior. Defaults to 0.
74
74
  name (str | None, optional): Name of the visualization. Defaults to the models class name.
75
75
 
76
76
  Returns:
@@ -88,41 +88,39 @@ def SolaraViz(
88
88
  value results in faster stepping, while a higher value results in slower stepping.
89
89
  """
90
90
  if components == "default":
91
- components = [components_altair.make_space_altair()]
91
+ components = [components_altair.make_altair_space()]
92
+ if model_params is None:
93
+ model_params = {}
92
94
 
93
95
  # Convert model to reactive
94
96
  if not isinstance(model, solara.Reactive):
95
97
  model = solara.use_reactive(model) # noqa: SH102, RUF100
96
98
 
97
- def connect_to_model():
98
- # Patch the step function to force updates
99
- original_step = model.value.step
100
-
101
- def step():
102
- original_step()
103
- force_update()
104
-
105
- model.value.step = step
106
- # Add a trigger to model itself
107
- model.value.force_update = force_update
108
- force_update()
109
-
110
- solara.use_effect(connect_to_model, [model.value])
99
+ # set up reactive model_parameters shared by ModelCreator and ModelController
100
+ reactive_model_parameters = solara.use_reactive({})
111
101
 
112
102
  with solara.AppBar():
113
103
  solara.AppBarTitle(name if name else model.value.__class__.__name__)
114
104
 
115
105
  with solara.Sidebar(), solara.Column():
116
106
  with solara.Card("Controls"):
117
- ModelController(model, play_interval)
118
-
119
- if model_params is not None:
120
- with solara.Card("Model Parameters"):
121
- ModelCreator(
107
+ if not isinstance(simulator, Simulator):
108
+ ModelController(
122
109
  model,
123
- model_params,
124
- seed=seed,
110
+ model_parameters=reactive_model_parameters,
111
+ play_interval=play_interval,
125
112
  )
113
+ else:
114
+ SimulatorController(
115
+ model,
116
+ simulator,
117
+ model_parameters=reactive_model_parameters,
118
+ play_interval=play_interval,
119
+ )
120
+ with solara.Card("Model Parameters"):
121
+ ModelCreator(
122
+ model, model_params, model_parameters=reactive_model_parameters
123
+ )
126
124
  with solara.Card("Information"):
127
125
  ShowSteps(model.value)
128
126
 
@@ -173,24 +171,89 @@ JupyterViz = SolaraViz
173
171
 
174
172
 
175
173
  @solara.component
176
- def ModelController(model: solara.Reactive[Model], play_interval=100):
174
+ def ModelController(
175
+ model: solara.Reactive[Model],
176
+ *,
177
+ model_parameters: dict | solara.Reactive[dict] = None,
178
+ play_interval: int = 100,
179
+ ):
177
180
  """Create controls for model execution (step, play, pause, reset).
178
181
 
179
182
  Args:
180
- model (solara.Reactive[Model]): Reactive model instance
181
- play_interval (int, optional): Interval for playing the model steps in milliseconds.
183
+ model: Reactive model instance
184
+ model_parameters: Reactive parameters for (re-)instantiating a model.
185
+ play_interval: Interval for playing the model steps in milliseconds.
186
+
182
187
  """
183
188
  playing = solara.use_reactive(False)
184
189
  running = solara.use_reactive(True)
185
- original_model = solara.use_reactive(None)
190
+ if model_parameters is None:
191
+ model_parameters = {}
192
+ model_parameters = solara.use_reactive(model_parameters)
186
193
 
187
- def save_initial_model():
188
- """Save the initial model for comparison."""
189
- original_model.set(copy.deepcopy(model.value))
190
- playing.value = False
194
+ async def step():
195
+ while playing.value and running.value:
196
+ await asyncio.sleep(play_interval / 1000)
197
+ do_step()
198
+
199
+ solara.lab.use_task(
200
+ step, dependencies=[playing.value, running.value], prefer_threaded=False
201
+ )
202
+
203
+ def do_step():
204
+ """Advance the model by one step."""
205
+ model.value.step()
206
+ running.value = model.value.running
191
207
  force_update()
192
208
 
193
- solara.use_effect(save_initial_model, [model.value])
209
+ def do_reset():
210
+ """Reset the model to its initial state."""
211
+ playing.value = False
212
+ running.value = True
213
+ model.value = model.value = model.value.__class__(**model_parameters.value)
214
+
215
+ def do_play_pause():
216
+ """Toggle play/pause."""
217
+ playing.value = not playing.value
218
+
219
+ with solara.Row(justify="space-between"):
220
+ solara.Button(label="Reset", color="primary", on_click=do_reset)
221
+ solara.Button(
222
+ label="▶" if not playing.value else "❚❚",
223
+ color="primary",
224
+ on_click=do_play_pause,
225
+ disabled=not running.value,
226
+ )
227
+ solara.Button(
228
+ label="Step",
229
+ color="primary",
230
+ on_click=do_step,
231
+ disabled=playing.value or not running.value,
232
+ )
233
+
234
+
235
+ @solara.component
236
+ def SimulatorController(
237
+ model: solara.Reactive[Model],
238
+ simulator,
239
+ *,
240
+ model_parameters: dict | solara.Reactive[dict] = None,
241
+ play_interval: int = 100,
242
+ ):
243
+ """Create controls for model execution (step, play, pause, reset).
244
+
245
+ Args:
246
+ model: Reactive model instance
247
+ simulator: Simulator instance
248
+ model_parameters: Reactive parameters for (re-)instantiating a model.
249
+ play_interval: Interval for playing the model steps in milliseconds.
250
+
251
+ """
252
+ playing = solara.use_reactive(False)
253
+ running = solara.use_reactive(True)
254
+ if model_parameters is None:
255
+ model_parameters = {}
256
+ model_parameters = solara.use_reactive(model_parameters)
194
257
 
195
258
  async def step():
196
259
  while playing.value and running.value:
@@ -203,14 +266,18 @@ def ModelController(model: solara.Reactive[Model], play_interval=100):
203
266
 
204
267
  def do_step():
205
268
  """Advance the model by one step."""
206
- model.value.step()
269
+ simulator.run_for(1)
207
270
  running.value = model.value.running
271
+ force_update()
208
272
 
209
273
  def do_reset():
210
274
  """Reset the model to its initial state."""
211
275
  playing.value = False
212
276
  running.value = True
213
- model.value = copy.deepcopy(original_model.value)
277
+ simulator.reset()
278
+ model.value = model.value = model.value.__class__(
279
+ simulator, **model_parameters.value
280
+ )
214
281
 
215
282
  def do_play_pause():
216
283
  """Toggle play/pause."""
@@ -269,7 +336,12 @@ def check_param_is_fixed(param):
269
336
 
270
337
 
271
338
  @solara.component
272
- def ModelCreator(model, model_params, seed=1):
339
+ def ModelCreator(
340
+ model: solara.Reactive[Model],
341
+ user_params: dict,
342
+ *,
343
+ model_parameters: dict | solara.Reactive[dict] = None,
344
+ ):
273
345
  """Solara component for creating and managing a model instance with user-defined parameters.
274
346
 
275
347
  This component allows users to create a model instance with specified parameters and seed.
@@ -277,9 +349,9 @@ def ModelCreator(model, model_params, seed=1):
277
349
  number generator.
278
350
 
279
351
  Args:
280
- model (solara.Reactive[Model]): A reactive model instance. This is the main model to be created and managed.
281
- model_params (dict): Dictionary of model parameters. This includes both user-adjustable parameters and fixed parameters.
282
- seed (int, optional): Initial seed for the random number generator. Defaults to 1.
352
+ model: A reactive model instance. This is the main model to be created and managed.
353
+ user_params: Parameters for (re-)instantiating a model. Can include user-adjustable parameters and fixed parameters. Defaults to None.
354
+ model_parameters: reactive parameters for reinitializing the model
283
355
 
284
356
  Returns:
285
357
  solara.component: A Solara component that renders the model creation and management interface.
@@ -300,24 +372,24 @@ def ModelCreator(model, model_params, seed=1):
300
372
  - The component provides an interface for adjusting user-defined parameters and reseeding the model.
301
373
 
302
374
  """
375
+ if model_parameters is None:
376
+ model_parameters = {}
377
+ model_parameters = solara.use_reactive(model_parameters)
378
+
303
379
  solara.use_effect(
304
380
  lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
305
381
  [model.value],
306
382
  )
383
+ user_params, fixed_params = split_model_params(user_params)
307
384
 
308
- user_params, fixed_params = split_model_params(model_params)
309
-
310
- model_parameters, set_model_parameters = solara.use_state(
311
- {
312
- **fixed_params,
313
- **{k: v.get("value") for k, v in user_params.items()},
314
- }
315
- )
385
+ # set model_parameters to the default values for all parameters
386
+ model_parameters.value = {
387
+ **fixed_params,
388
+ **{k: v.get("value") for k, v in user_params.items()},
389
+ }
316
390
 
317
391
  def on_change(name, value):
318
- new_model_parameters = {**model_parameters, name: value}
319
- model.value = model.value.__class__(**new_model_parameters)
320
- set_model_parameters(new_model_parameters)
392
+ model_parameters.value = {**model_parameters.value, name: value}
321
393
 
322
394
  UserInputs(user_params, on_change=on_change)
323
395
 
@@ -338,10 +410,11 @@ def _check_model_params(init_func, model_params):
338
410
  model_parameters[name].default == inspect.Parameter.empty
339
411
  and name not in model_params
340
412
  and name != "self"
413
+ and name != "kwargs"
341
414
  ):
342
415
  raise ValueError(f"Missing required model parameter: {name}")
343
416
  for name in model_params:
344
- if name not in model_parameters:
417
+ if name not in model_parameters and "kwargs" not in model_parameters:
345
418
  raise ValueError(f"Invalid model parameter: {name}")
346
419
 
347
420
 
@@ -409,6 +482,12 @@ def UserInputs(user_params, on_change=None):
409
482
  on_value=change_handler,
410
483
  value=options.get("value"),
411
484
  )
485
+ elif input_type == "InputText":
486
+ solara.InputText(
487
+ label=label,
488
+ on_value=change_handler,
489
+ value=options.get("value"),
490
+ )
412
491
  else:
413
492
  raise ValueError(f"{input_type} is not a supported input type")
414
493