Mesa 3.0.0b2__py3-none-any.whl → 3.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic. Click here for more details.

Files changed (28)
  1. mesa/__init__.py +1 -1
  2. mesa/batchrunner.py +26 -1
  3. mesa/examples/README.md +11 -11
  4. mesa/examples/advanced/epstein_civil_violence/agents.py +44 -38
  5. mesa/examples/advanced/epstein_civil_violence/app.py +29 -28
  6. mesa/examples/advanced/epstein_civil_violence/model.py +33 -65
  7. mesa/examples/advanced/pd_grid/app.py +8 -4
  8. mesa/examples/advanced/sugarscape_g1mt/app.py +5 -13
  9. mesa/examples/advanced/wolf_sheep/app.py +25 -18
  10. mesa/examples/basic/boid_flockers/app.py +2 -2
  11. mesa/examples/basic/boltzmann_wealth_model/app.py +14 -10
  12. mesa/examples/basic/conways_game_of_life/app.py +15 -3
  13. mesa/examples/basic/schelling/app.py +5 -5
  14. mesa/examples/basic/virus_on_network/app.py +25 -47
  15. mesa/space.py +0 -30
  16. mesa/visualization/__init__.py +16 -5
  17. mesa/visualization/components/__init__.py +83 -0
  18. mesa/visualization/components/{altair.py → altair_components.py} +34 -2
  19. mesa/visualization/components/matplotlib_components.py +176 -0
  20. mesa/visualization/mpl_space_drawing.py +558 -0
  21. mesa/visualization/solara_viz.py +30 -20
  22. {mesa-3.0.0b2.dist-info → mesa-3.0.0rc0.dist-info}/METADATA +1 -1
  23. {mesa-3.0.0b2.dist-info → mesa-3.0.0rc0.dist-info}/RECORD +27 -25
  24. mesa/visualization/components/matplotlib.py +0 -386
  25. {mesa-3.0.0b2.dist-info → mesa-3.0.0rc0.dist-info}/WHEEL +0 -0
  26. {mesa-3.0.0b2.dist-info → mesa-3.0.0rc0.dist-info}/entry_points.txt +0 -0
  27. {mesa-3.0.0b2.dist-info → mesa-3.0.0rc0.dist-info}/licenses/LICENSE +0 -0
  28. {mesa-3.0.0b2.dist-info → mesa-3.0.0rc0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,558 @@
1
+ """Helper functions for drawing mesa spaces with matplotlib.
2
+
3
+ These functions are used by the provided matplotlib components, but can also be used to quickly visualize
4
+ a space with matplotlib for example when creating a mp4 of a movie run or when needing a figure
5
+ for a paper.
6
+
7
+ """
8
+
9
+ import itertools
10
+ import math
11
+ import warnings
12
+ from collections.abc import Callable
13
+ from typing import Any
14
+
15
+ import networkx as nx
16
+ import numpy as np
17
+ from matplotlib import pyplot as plt
18
+ from matplotlib.axes import Axes
19
+ from matplotlib.cm import ScalarMappable
20
+ from matplotlib.collections import PatchCollection
21
+ from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
22
+ from matplotlib.patches import RegularPolygon
23
+
24
+ import mesa
25
+ from mesa.experimental.cell_space import (
26
+ OrthogonalMooreGrid,
27
+ OrthogonalVonNeumannGrid,
28
+ VoronoiGrid,
29
+ )
30
+ from mesa.space import (
31
+ ContinuousSpace,
32
+ HexMultiGrid,
33
+ HexSingleGrid,
34
+ MultiGrid,
35
+ NetworkGrid,
36
+ PropertyLayer,
37
+ SingleGrid,
38
+ )
39
+
40
# Union types grouping the old-style (mesa.space) and new-style
# (mesa.experimental.cell_space) space classes that share one drawing routine.
OrthogonalGrid = (
    SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
)
HexGrid = (
    HexSingleGrid | HexMultiGrid | mesa.experimental.cell_space.HexGrid
)
Network = (
    NetworkGrid | mesa.experimental.cell_space.Network
)
43
+
44
+
45
def collect_agent_data(
    space: OrthogonalGrid | HexGrid | Network | ContinuousSpace | VoronoiGrid,
    agent_portrayal: Callable,
    color="tab:blue",
    size=25,
    marker="o",
    zorder: int = 1,
):
    """Collect the plotting data for all agents in the space.

    Args:
        space: The space containing the Agents.
        agent_portrayal: A callable that is called with the agent and returns a dict
        color: default color
        size: default size
        marker: default marker
        zorder: default zorder

    Returns:
        A dict mapping "s", "c", "marker", "zorder" and "loc" to numpy arrays,
        with one entry per agent in the iteration order of ``space.agents``.

    agent_portrayal should return a dict, limited to size (size of marker), color (color of marker), zorder (z-order),
    and marker (marker style). Any other key is ignored and reported with a UserWarning.
    The returned dict is copied before use, so agent_portrayal may safely return the
    same (e.g. module-level) dict for every agent.

    """
    arguments = {"s": [], "c": [], "marker": [], "zorder": [], "loc": []}

    for agent in space.agents:
        # Copy before popping: popping directly from the returned dict would
        # mutate a style dict shared between agents, so later agents would
        # silently fall back to the defaults.
        portray = dict(agent_portrayal(agent))

        # Agents without a ``pos`` (None) are located through their cell instead.
        loc = agent.pos
        if loc is None:
            loc = agent.cell.coordinate

        arguments["loc"].append(loc)
        arguments["s"].append(portray.pop("size", size))
        arguments["c"].append(portray.pop("color", color))
        arguments["marker"].append(portray.pop("marker", marker))
        arguments["zorder"].append(portray.pop("zorder", zorder))

        # Whatever is left over was not consumed above.
        if portray:
            msg = ", ".join(portray)
            warnings.warn(
                f"the following fields are not used in agent portrayal and thus ignored: {msg}.",
                stacklevel=2,
            )

    return {k: np.asarray(v) for k, v in arguments.items()}
90
+
91
+
92
+ def draw_space(
93
+ space,
94
+ agent_portrayal: Callable,
95
+ propertylayer_portrayal: dict | None = None,
96
+ ax: Axes | None = None,
97
+ **space_drawing_kwargs,
98
+ ):
99
+ """Draw a Matplotlib-based visualization of the space.
100
+
101
+ Args:
102
+ space: the space of the mesa model
103
+ agent_portrayal: A callable that returns a dict specifying how to show the agent
104
+ propertylayer_portrayal: a dict specifying how to show propertylayer(s)
105
+ ax: the axes upon which to draw the plot
106
+ post_process: a callable called with the Axes instance
107
+ space_drawing_kwargs: any additional keyword arguments to be passed on to the underlying function for drawing the space.
108
+
109
+ Returns:
110
+ Returns the Axes object with the plot drawn onto it.
111
+
112
+ ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
113
+ "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
114
+
115
+ """
116
+ if ax is None:
117
+ fig, ax = plt.subplots()
118
+
119
+ # https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching
120
+ match space:
121
+ case mesa.space._Grid() | OrthogonalMooreGrid() | OrthogonalVonNeumannGrid():
122
+ draw_orthogonal_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
123
+ case HexSingleGrid() | HexMultiGrid() | mesa.experimental.cell_space.HexGrid():
124
+ draw_hex_grid(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
125
+ case mesa.space.NetworkGrid() | mesa.experimental.cell_space.Network():
126
+ draw_network(space, agent_portrayal, ax=ax, **space_drawing_kwargs)
127
+ case mesa.space.ContinuousSpace():
128
+ draw_continuous_space(space, agent_portrayal, ax=ax)
129
+ case VoronoiGrid():
130
+ draw_voroinoi_grid(space, agent_portrayal, ax=ax)
131
+
132
+ if propertylayer_portrayal:
133
+ draw_property_layers(space, propertylayer_portrayal, ax=ax)
134
+
135
+ return ax
136
+
137
+
138
def draw_property_layers(
    space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes
):
    """Draw PropertyLayers on the given axes.

    Args:
        space (mesa.space._Grid): The space containing the PropertyLayers.
        propertylayer_portrayal (dict): the key is the name of the layer, the value is a dict with
            fields specifying how the layer is to be portrayed
        ax (matplotlib.axes.Axes): The axes to draw on.

    Raises:
        ValueError: if a layer portrayal contains neither "color" nor "colormap".

    Notes:
        valid fields in the inner dict of propertylayer_portrayal are "alpha", "vmin", "vmax", "color" or "colormap", and "colorbar"
        so you can do `{"some_layer":{"colormap":'viridis', 'alpha':.25, "colorbar":False}}`

        Names that do not resolve to a PropertyLayer in the space are skipped.

    """
    try:
        # old style spaces
        property_layers = space.properties
    except AttributeError:
        # new style spaces
        property_layers = space.property_layers

    for layer_name, portrayal in propertylayer_portrayal.items():
        layer = property_layers.get(layer_name, None)
        if not isinstance(layer, PropertyLayer):
            continue

        data = layer.data.astype(float) if layer.data.dtype == bool else layer.data
        # ``space`` cannot be None here: it was dereferenced above.
        width, height = space.width, space.height

        if data.shape != (width, height):
            warnings.warn(
                f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({width}, {height}).",
                UserWarning,
                stacklevel=2,
            )

        # Get portrayal properties, or use defaults
        alpha = portrayal.get("alpha", 1)
        vmin = portrayal.get("vmin", np.min(data))
        vmax = portrayal.get("vmax", np.max(data))
        colorbar = portrayal.get("colorbar", True)

        # Draw the layer
        if "color" in portrayal:
            rgba_color = to_rgba(portrayal["color"])
            value_range = vmax - vmin
            if value_range == 0:
                # A constant layer (e.g. freshly initialised to zeros) would give
                # 0/0 = NaN. Map it to 0, as matplotlib.colors.Normalize does.
                normalized_data = np.zeros(data.shape, dtype=float)
            else:
                normalized_data = (data - vmin) / value_range
            rgba_data = np.full((*data.shape, 4), rgba_color)
            # Scale only the alpha channel by the normalised value.
            rgba_data[..., 3] *= normalized_data * alpha
            rgba_data = np.clip(rgba_data, 0, 1)
            # data is indexed [x, y]; imshow expects [row, col], hence the transpose.
            ax.imshow(
                rgba_data.transpose(1, 0, 2),
                origin="lower",
            )
            if colorbar:
                cmap = LinearSegmentedColormap.from_list(
                    layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
                )
                norm = Normalize(vmin=vmin, vmax=vmax)
                sm = ScalarMappable(norm=norm, cmap=cmap)
                sm.set_array([])
                ax.figure.colorbar(
                    sm, ax=ax, orientation="vertical", label=layer_name
                )

        elif "colormap" in portrayal:
            cmap = portrayal.get("colormap", "viridis")
            if isinstance(cmap, list):
                cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
            im = ax.imshow(
                data.T,
                cmap=cmap,
                alpha=alpha,
                vmin=vmin,
                vmax=vmax,
                origin="lower",
            )
            if colorbar:
                # Attach to the axes' own figure rather than pyplot's "current"
                # figure; they differ when ax comes from a non-pyplot Figure.
                ax.figure.colorbar(im, ax=ax, label=layer_name)
        else:
            raise ValueError(
                f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
            )
220
+
221
+
222
def draw_orthogonal_grid(
    space: OrthogonalGrid,
    agent_portrayal: Callable,
    ax: Axes | None = None,
    draw_grid: bool = True,
    **kwargs,
):
    """Visualize an orthogonal grid.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a dict
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        draw_grid: whether to draw dotted lines between the cells
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
    "size", "marker", and "zorder". Other field are ignored and will result in a user warning.

    """
    if ax is None:
        _, ax = plt.subplots()

    n_cols, n_rows = space.width, space.height

    # The default marker area shrinks as the larger grid dimension grows.
    default_marker_size = (180 / max(n_cols, n_rows)) ** 2
    agent_data = collect_agent_data(space, agent_portrayal, size=default_marker_size)
    _scatter(ax, agent_data, **kwargs)

    # Agents are plotted at integer coordinates, so the view and the cell
    # boundaries sit half a unit away from them.
    ax.set_xlim(-0.5, n_cols - 0.5)
    ax.set_ylim(-0.5, n_rows - 0.5)

    if draw_grid:
        for boundary_x in np.arange(-0.5, n_cols - 0.5, 1):
            ax.axvline(boundary_x, color="gray", linestyle=":")
        for boundary_y in np.arange(-0.5, n_rows - 0.5, 1):
            ax.axhline(boundary_y, color="gray", linestyle=":")

    return ax
267
+
268
+
269
def draw_hex_grid(
    space: HexGrid,
    agent_portrayal: Callable,
    ax: Axes | None = None,
    draw_grid: bool = True,
    **kwargs,
):
    """Visualize a hex grid.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a dict
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        draw_grid: whether to draw the grid
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
    "size", "marker", and "zorder". Other field are ignored and will result in a user warning.

    """
    if ax is None:
        _, ax = plt.subplots()

    # gather data
    s_default = (180 / max(space.width, space.height)) ** 2
    arguments = collect_agent_data(space, agent_portrayal, size=s_default)

    # For hex grids, convert logical coordinates to visual coordinates, assuming
    # a distance of 1 between the centres of neighbouring hexes:
    # - even rows are shifted half a hex to the right in x,
    # - every row's y is scaled by sqrt(0.75) (= sqrt(3)/2), the row spacing.
    offset = math.sqrt(0.75)

    loc = arguments["loc"].astype(float)
    if loc.size == 0:
        # No agents: np.asarray([]) is 1-D. Give it the (0, 2) shape that the
        # column indexing below (and in _scatter) requires.
        loc = loc.reshape(0, 2)

    even_rows = np.mod(loc[:, 1], 2) == 0
    loc[even_rows, 0] += 0.5
    loc[:, 1] *= offset
    arguments["loc"] = loc

    # plot the agents
    _scatter(ax, arguments, **kwargs)

    # further styling and adding of grid
    ax.set_xlim(-1, space.width + 0.5)
    ax.set_ylim(-offset, space.height * offset)

    def setup_hexmesh(
        width,
        height,
    ):
        """Helper function for creating the hexmesh."""
        # fixme: this should be done once, rather than in each update
        # fixme check coordinate system in hexgrid (see https://www.redblobgames.com/grids/hexagons/#coordinates-offset)

        patches = []
        for x, y in itertools.product(range(width), range(height)):
            # Same logical -> visual transform as applied to the agents above.
            x_center = x + 0.5 if y % 2 == 0 else x
            y_center = y * offset
            hexagon = RegularPolygon(
                (x_center, y_center),
                numVertices=6,
                radius=math.sqrt(1 / 3),
                orientation=np.radians(120),
            )
            patches.append(hexagon)
        mesh = PatchCollection(
            patches, edgecolor="k", facecolor=(1, 1, 1, 0), linestyle="dotted", lw=1
        )
        return mesh

    if draw_grid:
        # add grid
        ax.add_collection(
            setup_hexmesh(
                space.width,
                space.height,
            )
        )
    return ax
356
+
357
+
358
def draw_network(
    space: Network,
    agent_portrayal: Callable,
    ax: Axes | None = None,
    draw_grid: bool = True,
    layout_alg=nx.spring_layout,
    layout_kwargs=None,
    **kwargs,
):
    """Visualize a network space.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a dict
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        draw_grid: whether to draw the edges of the network
        layout_alg: a networkx layout algorithm or other callable with the same behavior
        layout_kwargs: a dictionary of keyword arguments for the layout algorithm
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
    "size", "marker", and "zorder". Other field are ignored and will result in a user warning.

    """
    if ax is None:
        _, ax = plt.subplots()
    if layout_kwargs is None:
        layout_kwargs = {"seed": 0}

    # gather locations for nodes in network; ``pos`` maps node -> (x, y)
    graph = space.G
    pos = layout_alg(graph, **layout_kwargs)
    x, y = list(zip(*pos.values()))
    xmin, xmax = min(x), max(x)
    ymin, ymax = min(y), max(y)

    width = xmax - xmin
    height = ymax - ymin
    x_padding = width / 20
    y_padding = height / 20

    # gather agent data; a degenerate layout (e.g. a single node) has zero
    # extent, so fall back to a unit extent instead of dividing by zero
    s_default = (180 / (max(width, height) or 1)) ** 2
    arguments = collect_agent_data(space, agent_portrayal, size=s_default)

    # Each agent's "loc" is the node it sits on. Look the node up in the layout
    # dict rather than indexing an array of pos.values(): that array was only
    # correct when the nodes were the integers 0..n-1 in insertion order.
    # ``tolist`` turns numpy scalars back into Python values. Tuple node ids
    # (e.g. from nx.grid_2d_graph) come back as lists and are re-tupled.
    node_ids = arguments["loc"].tolist()
    agent_xy = [
        pos[tuple(node) if isinstance(node, list) else node] for node in node_ids
    ]
    arguments["loc"] = (
        np.asarray(agent_xy, dtype=float) if agent_xy else np.empty((0, 2))
    )

    # plot the agents
    _scatter(ax, arguments, **kwargs)

    # further styling
    ax.set_axis_off()
    ax.set_xlim(xmin=xmin - x_padding, xmax=xmax + x_padding)
    ax.set_ylim(ymin=ymin - y_padding, ymax=ymax + y_padding)

    if draw_grid:
        # fixme we need to draw the empty nodes as well
        edge_artists = nx.draw_networkx_edges(
            graph, pos, ax=ax, alpha=0.5, style="--"
        )
        # networkx returns a single LineCollection for undirected graphs but a
        # list of FancyArrowPatch for directed ones; handle both.
        if not isinstance(edge_artists, list):
            edge_artists = [edge_artists]
        for artist in edge_artists:
            artist.set_zorder(0)

    return ax
427
+
428
+
429
def draw_continuous_space(
    space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
):
    """Visualize a continuous space.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a dict
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
    "size", "marker", and "zorder". Other field are ignored and will result in a user warning.

    """
    if ax is None:
        _, ax = plt.subplots()

    # Extent of the space, with a 5% margin on every side.
    x_extent = space.x_max - space.x_min
    x_margin = x_extent / 20
    y_extent = space.y_max - space.y_min
    y_margin = y_extent / 20

    marker_size = (180 / max(x_extent, y_extent)) ** 2
    agent_data = collect_agent_data(space, agent_portrayal, size=marker_size)
    _scatter(ax, agent_data, **kwargs)

    # A torus is outlined with a dashed border, a bounded space with a solid one.
    if space.torus:
        border_style = (0, (5, 10))
    else:
        border_style = "solid"
    for spine in ax.spines.values():
        spine.set_linewidth(1.5)
        spine.set_color("black")
        spine.set_linestyle(border_style)

    ax.set_xlim(space.x_min - x_margin, space.x_max + x_margin)
    ax.set_ylim(space.y_min - y_margin, space.y_max + y_margin)

    return ax
474
+
475
+
476
def draw_voroinoi_grid(
    space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
):
    """Visualize a voronoi grid.

    Args:
        space: the space to visualize
        agent_portrayal: a callable that is called with the agent and returns a dict
        ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
        kwargs: additional keyword arguments passed to ax.scatter

    Returns:
        Returns the Axes object with the plot drawn onto it.

    Each cell's polygon is filled in red, with its alpha taken from the cell
    property named by ``space.cell_coloring_property`` (capped at 1), and
    outlined in black.

    ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
    "size", "marker", and "zorder". Other field are ignored and will result in a user warning.

    """
    if ax is None:
        _, ax = plt.subplots()

    # Bounding box of the cell centroids, padded by 5% on each side.
    centroids = space.centroids_coordinates
    xs = [point[0] for point in centroids]
    ys = [point[1] for point in centroids]
    x_lo, x_hi = min(xs), max(xs)
    y_lo, y_hi = min(ys), max(ys)

    x_span = x_hi - x_lo
    y_span = y_hi - y_lo
    x_pad = x_span / 20
    y_pad = y_span / 20

    marker_size = (180 / max(x_span, y_span)) ** 2
    agent_data = collect_agent_data(space, agent_portrayal, size=marker_size)

    ax.set_xlim(x_lo - x_pad, x_hi + x_pad)
    ax.set_ylim(y_lo - y_pad, y_hi + y_pad)

    _scatter(ax, agent_data, **kwargs)

    for cell in space.all_cells:
        outline = cell.properties["polygon"]
        shade = min(1, cell.properties[space.cell_coloring_property])
        # filled polygon, kept beneath the agents
        ax.fill(*zip(*outline), alpha=shade, c="red", zorder=0)
        # polygon edges in black
        ax.plot(*zip(*outline), color="black")

    return ax
528
+
529
+
530
def _scatter(ax: Axes, arguments, **kwargs):
    """Helper function for plotting the agents.

    ``ax.scatter`` takes a single marker and a single zorder per call, so agents
    are grouped by (marker, zorder). Every other per-agent field left in
    ``arguments`` (e.g. "s" and "c") is sliced to the group and passed through.

    Args:
        ax: a Matplotlib Axes instance
        arguments: the agent specific arguments for plotting, as produced by
            collect_agent_data. "loc" holds one (x, y) row per agent. The
            "loc", "marker" and "zorder" entries are popped from this dict.
        kwargs: additional keyword arguments for ax.scatter

    """
    loc = arguments.pop("loc")
    marker = arguments.pop("marker")
    zorder = arguments.pop("zorder")

    if loc.size == 0:
        # No agents: np.asarray([]) is 1-D, so loc[:, 0] would raise.
        return

    x = loc[:, 0]
    y = loc[:, 1]

    for mark in np.unique(marker):
        mark_mask = marker == mark
        # Only iterate zorders that occur with this marker, so no scatter
        # call (and no empty artist) is made for an absent combination.
        for z_order in np.unique(zorder[mark_mask]):
            logical = mark_mask & (zorder == z_order)
            ax.scatter(
                x[logical],
                y[logical],
                marker=mark,
                zorder=z_order,
                **{k: v[logical] for k, v in arguments.items()},
                **kwargs,
            )
@@ -25,13 +25,14 @@ from __future__ import annotations
25
25
 
26
26
  import asyncio
27
27
  import copy
28
+ import inspect
28
29
  from collections.abc import Callable
29
30
  from typing import TYPE_CHECKING, Literal
30
31
 
31
32
  import reacton.core
32
33
  import solara
33
34
 
34
- import mesa.visualization.components.altair as components_altair
35
+ import mesa.visualization.components.altair_components as components_altair
35
36
  from mesa.visualization.UserParam import Slider
36
37
  from mesa.visualization.utils import force_update, update_counter
37
38
 
@@ -299,9 +300,12 @@ def ModelCreator(model, model_params, seed=1):
299
300
  - The component provides an interface for adjusting user-defined parameters and reseeding the model.
300
301
 
301
302
  """
302
- user_params, fixed_params = split_model_params(model_params)
303
+ solara.use_effect(
304
+ lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
305
+ [model.value],
306
+ )
303
307
 
304
- reactive_seed = solara.use_reactive(seed)
308
+ user_params, fixed_params = split_model_params(model_params)
305
309
 
306
310
  model_parameters, set_model_parameters = solara.use_state(
307
311
  {
@@ -310,29 +314,35 @@ def ModelCreator(model, model_params, seed=1):
310
314
  }
311
315
  )
312
316
 
313
- def do_reseed():
314
- """Update the random seed for the model."""
315
- reactive_seed.value = model.value.random.random()
316
-
317
317
  def on_change(name, value):
318
- set_model_parameters({**model_parameters, name: value})
318
+ new_model_parameters = {**model_parameters, name: value}
319
+ model.value = model.value.__class__(**new_model_parameters)
320
+ set_model_parameters(new_model_parameters)
319
321
 
320
- def create_model():
321
- model.value = model.value.__class__(**model_parameters)
322
- model.value._seed = reactive_seed.value
322
+ UserInputs(user_params, on_change=on_change)
323
323
 
324
- solara.use_effect(create_model, [model_parameters, reactive_seed.value])
325
324
 
326
- with solara.Row(justify="space-between"):
327
- solara.InputText(
328
- label="Seed",
329
- value=reactive_seed,
330
- continuous_update=True,
331
- )
325
def _check_model_params(init_func, model_params):
    """Check if model parameters are valid for the model's initialization function.

    Args:
        init_func: Model initialization function
        model_params: Dictionary of model parameters

    Raises:
        ValueError: If a parameter is not valid for the model's initialization function

    Notes:
        ``self`` and variadic parameters (``*args`` / ``**kwargs``) are never
        reported as missing. If ``init_func`` accepts ``**kwargs``, any extra
        parameter name is accepted.
    """
    signature_params = inspect.signature(init_func).parameters
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    accepts_any_keyword = any(
        param.kind is inspect.Parameter.VAR_KEYWORD
        for param in signature_params.values()
    )

    for name, param in signature_params.items():
        if name == "self" or param.kind in variadic_kinds:
            continue
        # Compare with ``is``: a default such as a numpy array overloads ``==``,
        # which would make the truth test ambiguous.
        if param.default is inspect.Parameter.empty and name not in model_params:
            raise ValueError(f"Missing required model parameter: {name}")

    if accepts_any_keyword:
        return
    for name in model_params:
        if name not in signature_params:
            raise ValueError(f"Invalid model parameter: {name}")
336
346
 
337
347
 
338
348
  @solara.component
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: Mesa
3
- Version: 3.0.0b2
3
+ Version: 3.0.0rc0
4
4
  Summary: Agent-based modeling (ABM) in Python
5
5
  Project-URL: homepage, https://github.com/projectmesa/mesa
6
6
  Project-URL: repository, https://github.com/projectmesa/mesa