Mesa 3.2.0.dev0__py3-none-any.whl → 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic. Click here for more details.

Files changed (58) hide show
  1. mesa/__init__.py +1 -1
  2. mesa/agent.py +9 -7
  3. mesa/datacollection.py +1 -1
  4. mesa/examples/README.md +1 -1
  5. mesa/examples/__init__.py +2 -0
  6. mesa/examples/advanced/alliance_formation/Readme.md +50 -0
  7. mesa/examples/advanced/alliance_formation/__init__ .py +0 -0
  8. mesa/examples/advanced/alliance_formation/agents.py +20 -0
  9. mesa/examples/advanced/alliance_formation/app.py +71 -0
  10. mesa/examples/advanced/alliance_formation/model.py +184 -0
  11. mesa/examples/advanced/epstein_civil_violence/app.py +11 -11
  12. mesa/examples/advanced/pd_grid/Readme.md +4 -6
  13. mesa/examples/advanced/pd_grid/app.py +10 -11
  14. mesa/examples/advanced/sugarscape_g1mt/Readme.md +4 -5
  15. mesa/examples/advanced/sugarscape_g1mt/app.py +34 -16
  16. mesa/examples/advanced/wolf_sheep/Readme.md +2 -17
  17. mesa/examples/advanced/wolf_sheep/app.py +21 -18
  18. mesa/examples/basic/boid_flockers/Readme.md +6 -1
  19. mesa/examples/basic/boid_flockers/app.py +15 -11
  20. mesa/examples/basic/boltzmann_wealth_model/Readme.md +2 -12
  21. mesa/examples/basic/boltzmann_wealth_model/app.py +39 -32
  22. mesa/examples/basic/conways_game_of_life/Readme.md +1 -9
  23. mesa/examples/basic/conways_game_of_life/app.py +13 -16
  24. mesa/examples/basic/schelling/Readme.md +2 -10
  25. mesa/examples/basic/schelling/agents.py +9 -3
  26. mesa/examples/basic/schelling/app.py +50 -3
  27. mesa/examples/basic/schelling/model.py +2 -0
  28. mesa/examples/basic/schelling/resources/blue_happy.png +0 -0
  29. mesa/examples/basic/schelling/resources/blue_unhappy.png +0 -0
  30. mesa/examples/basic/schelling/resources/orange_happy.png +0 -0
  31. mesa/examples/basic/schelling/resources/orange_unhappy.png +0 -0
  32. mesa/examples/basic/virus_on_network/Readme.md +0 -4
  33. mesa/examples/basic/virus_on_network/app.py +31 -14
  34. mesa/experimental/__init__.py +2 -2
  35. mesa/experimental/continuous_space/continuous_space.py +1 -1
  36. mesa/experimental/meta_agents/__init__.py +25 -0
  37. mesa/experimental/meta_agents/meta_agent.py +387 -0
  38. mesa/model.py +3 -3
  39. mesa/space.py +4 -1
  40. mesa/visualization/__init__.py +2 -0
  41. mesa/visualization/backends/__init__.py +23 -0
  42. mesa/visualization/backends/abstract_renderer.py +97 -0
  43. mesa/visualization/backends/altair_backend.py +440 -0
  44. mesa/visualization/backends/matplotlib_backend.py +419 -0
  45. mesa/visualization/components/__init__.py +28 -8
  46. mesa/visualization/components/altair_components.py +86 -0
  47. mesa/visualization/components/matplotlib_components.py +4 -2
  48. mesa/visualization/components/portrayal_components.py +120 -0
  49. mesa/visualization/mpl_space_drawing.py +292 -129
  50. mesa/visualization/solara_viz.py +274 -32
  51. mesa/visualization/space_drawers.py +797 -0
  52. mesa/visualization/space_renderer.py +399 -0
  53. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/METADATA +13 -4
  54. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/RECORD +57 -40
  55. mesa/examples/advanced/sugarscape_g1mt/tests.py +0 -69
  56. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/WHEEL +0 -0
  57. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/LICENSE +0 -0
  58. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,419 @@
1
+ # noqa: D100
2
+ import os
3
+ import warnings
4
+ from dataclasses import fields
5
+
6
+ import numpy as np
7
+ from matplotlib import pyplot as plt
8
+ from matplotlib.cm import ScalarMappable
9
+ from matplotlib.collections import PolyCollection
10
+ from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
11
+ from matplotlib.offsetbox import AnnotationBbox, OffsetImage
12
+ from PIL import Image
13
+
14
+ import mesa
15
+ from mesa.discrete_space import (
16
+ OrthogonalMooreGrid,
17
+ OrthogonalVonNeumannGrid,
18
+ )
19
+ from mesa.space import (
20
+ HexMultiGrid,
21
+ HexSingleGrid,
22
+ MultiGrid,
23
+ SingleGrid,
24
+ )
25
+ from mesa.visualization.backends.abstract_renderer import AbstractRenderer
26
+
27
# Grid classes whose property layers are drawn as rectangular cells (imshow)
# by MatplotlibBackend.draw_propertylayer; ``isinstance`` accepts ``X | Y``.
OrthogonalGrid = (
    SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
)
# Grid classes whose property layers are drawn as hexagons (PolyCollection).
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid

# Multiplier applied in MatplotlibBackend._get_zoom_factor when converting an
# image's pixels-per-data-unit ratio into an OffsetImage zoom.
CORRECTION_FACTOR_MARKER_ZOOM = 0.01
32
+
33
+
34
class MatplotlibBackend(AbstractRenderer):
    """Matplotlib-based renderer for Mesa spaces.

    Provides visualization capabilities using Matplotlib for rendering
    space structures, agents, and property layers.
    """

    def __init__(self, space_drawer):
        """Initialize the Matplotlib backend.

        Args:
            space_drawer: An instance of a SpaceDrawer class that handles
                the drawing of the space structure.
        """
        super().__init__(space_drawer)

        # Every colorbar created by ``draw_propertylayer`` is appended here.
        self._active_colorbars = []

    def initialize_canvas(self, ax=None):
        """Initialize the matplotlib canvas.

        Args:
            ax (matplotlib.axes.Axes, optional): Existing axes to use.
                If None, creates new figure and axes.
        """
        if ax is None:
            fig, ax = plt.subplots(constrained_layout=True)
        else:
            # Keep ``self.fig`` consistent with the supplied axes instead of
            # leaving it unset (or pointing at a stale figure).
            fig = ax.get_figure()
        self.fig = fig
        self.ax = ax

    def draw_structure(self, **kwargs):
        """Draw the space structure using matplotlib.

        Args:
            **kwargs: Additional arguments passed to the space drawer.
                Check the respective `SpaceDrawer` class for details on how to pass **kwargs.

        Returns:
            The matplotlib axes with the drawn structure.
        """
        return self.space_drawer.draw_matplotlib(self.ax, **kwargs)

    def collect_agent_data(self, space, agent_portrayal, default_size=None):
        """Collect plotting data for all agents in the space.

        Args:
            space: The Mesa space containing agents.
            agent_portrayal (Callable): Function that returns AgentPortrayalStyle for each agent.
                Returning a dict is still accepted but emits a PendingDeprecationWarning.
            default_size (float, optional): Default marker size if not specified in portrayal.

        Returns:
            dict: Maps ``loc``, ``s``, ``c``, ``marker``, ``zorder``, ``alpha``,
            ``edgecolors`` and ``linewidths`` to NumPy arrays. ``edgecolors`` only
            holds entries for agents whose portrayal set it, so it may be empty.
            ``marker`` is a 1-D object array so tuple markers stay intact.
        """
        # Import here to prevent circular imports
        from mesa.visualization.components import AgentPortrayalStyle  # noqa: PLC0415

        # Default values declared on the AgentPortrayalStyle dataclass.
        style_fields = {f.name: f.default for f in fields(AgentPortrayalStyle)}
        class_default_size = style_fields.get("size")

        arguments = {
            "loc": [],
            "s": [],
            "c": [],
            "marker": [],
            "zorder": [],
            "alpha": [],
            "edgecolors": [],
            "linewidths": [],
        }

        for agent in space.agents:
            portray_input = agent_portrayal(agent)

            if isinstance(portray_input, dict):
                warnings.warn(
                    "Returning a dict from agent_portrayal is deprecated. "
                    "Please return an AgentPortrayalStyle instance instead.",
                    PendingDeprecationWarning,
                    stacklevel=2,
                )
                # Handle legacy dict input; popped keys are consumed, anything
                # left over is reported as ignored below.
                dict_data = portray_input.copy()
                agent_x, agent_y = self._get_agent_pos(agent, space)

                aps = AgentPortrayalStyle(
                    x=agent_x,
                    y=agent_y,
                    size=dict_data.pop("size", style_fields.get("size")),
                    color=dict_data.pop("color", style_fields.get("color")),
                    marker=dict_data.pop("marker", style_fields.get("marker")),
                    zorder=dict_data.pop("zorder", style_fields.get("zorder")),
                    alpha=dict_data.pop("alpha", style_fields.get("alpha")),
                    edgecolors=dict_data.pop("edgecolors", None),
                    linewidths=dict_data.pop(
                        "linewidths", style_fields.get("linewidths")
                    ),
                )

                if dict_data:
                    ignored_keys = list(dict_data.keys())
                    warnings.warn(
                        f"The following keys were ignored: {', '.join(ignored_keys)}",
                        UserWarning,
                        stacklevel=2,
                    )
            else:
                aps = portray_input
                # Fall back to the agent's own position when none was given.
                if aps.x is None and aps.y is None:
                    aps.x, aps.y = self._get_agent_pos(agent, space)

            arguments["loc"].append((aps.x, aps.y))
            arguments["s"].append(aps.size or default_size or class_default_size)
            arguments["c"].append(aps.color)
            arguments["marker"].append(aps.marker)
            arguments["zorder"].append(aps.zorder)
            arguments["alpha"].append(aps.alpha)
            if aps.edgecolors is not None:
                arguments["edgecolors"].append(aps.edgecolors)
            arguments["linewidths"].append(aps.linewidths)

        # Fill a pre-allocated 1-D object array so tuple markers such as
        # (numsides, style, angle) are stored as single elements instead of
        # being expanded into an extra dimension.
        markers = np.empty(len(arguments["marker"]), dtype=object)
        markers[:] = arguments["marker"]

        return {
            key: markers if key == "marker" else np.asarray(values)
            for key, values in arguments.items()
        }

    def _get_zoom_factor(self, ax, img):
        """Return the base OffsetImage zoom for drawing ``img`` as a marker.

        The figure is drawn first so the axes extent is current. The zoom is
        the smaller of the x/y pixels-per-data-unit ratios divided by the
        image's pixel width/height, times ``CORRECTION_FACTOR_MARKER_ZOOM``.
        The result is not cached; each call redraws the canvas.
        """
        fig = ax.get_figure()
        fig.canvas.draw()

        bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
        width, height = bbox.width * fig.dpi, bbox.height * fig.dpi

        xr = ax.get_xlim()
        yr = ax.get_ylim()

        x_pixel_per_data = width / (xr[1] - xr[0])
        y_pixel_per_data = height / (yr[1] - yr[0])

        zoom_x = (x_pixel_per_data / img.width) * CORRECTION_FACTOR_MARKER_ZOOM
        zoom_y = (y_pixel_per_data / img.height) * CORRECTION_FACTOR_MARKER_ZOOM

        return min(zoom_x, zoom_y)

    @staticmethod
    def _marker_mask(marker, mark):
        """Return a boolean mask of the entries of ``marker`` equal to ``mark``.

        Compares element by element: ``marker == mark`` would make NumPy
        broadcast a tuple marker such as ``(5, 0, 0)`` as an array.
        """
        return np.fromiter(
            (entry == mark for entry in marker), dtype=bool, count=len(marker)
        )

    def draw_agents(self, arguments, **kwargs):
        """Draw agents on the backend's axes.

        Markers that are paths to existing files are drawn as images
        (``AnnotationBbox``); all other markers are drawn with
        ``Axes.scatter``, one call per (marker, zorder) group.

        Args:
            arguments: Dictionary containing agent data arrays, as returned by
                ``collect_agent_data``. The dictionary itself is not modified.
            **kwargs: Additional keyword arguments for customization.
                Check the respective `SpaceDrawer` class for details on how to pass **kwargs.

        Returns:
            matplotlib.axes.Axes | None: The Matplotlib Axes with the agents
            drawn upon it, or None when there are no agents.

        Raises:
            ValueError: If ``edgecolors``/``linewidths`` come both from the
                portrayal and from ``kwargs``, or are set for only some of the
                agents that are drawn with scatter markers.
        """
        if arguments["loc"].size == 0:
            return None

        # Shallow copy so the pops below do not strip keys from the caller's
        # dict (which would make a second draw with the same data fail).
        arguments = dict(arguments)
        loc = arguments.pop("loc")
        marker = arguments.pop("marker")
        zorder = arguments.pop("zorder")
        malpha = arguments["alpha"]
        msize = arguments["s"]
        n_agents = len(loc)

        # Validate edge arguments
        for entry in ("edgecolors", "linewidths"):
            if len(arguments[entry]) == 0:
                arguments.pop(entry)
            elif entry in kwargs:
                raise ValueError(
                    f"{entry} is specified in agent portrayal and via plotting kwargs, "
                    "you can only use one or the other"
                )

        # Separate image markers (existing files) from regular markers.
        image_markers = set()
        regular_markers = set()
        for mark in set(marker):
            if isinstance(mark, str | os.PathLike) and os.path.isfile(mark):
                image_markers.add(mark)
            else:
                regular_markers.add(mark)

        for mark in image_markers:
            # Copy the pixel data so the file handle opened by PIL is closed.
            with Image.open(mark) as source:
                image = source.copy()
            # _get_zoom_factor draws the canvas itself, so no separate
            # canvas.draw() is needed here.
            base_zoom = self._get_zoom_factor(self.ax, image)
            mask_marker = self._marker_mask(marker, mark)

            for m_size in np.unique(msize[mask_marker]):
                size_marker_mask = mask_marker & (msize == m_size)

                for z_order in np.unique(zorder[size_marker_mask]):
                    for m_alpha in np.unique(malpha[size_marker_mask]):
                        final_mask = (
                            size_marker_mask & (zorder == z_order) & (malpha == m_alpha)
                        )
                        positions = loc[final_mask]
                        if len(positions) == 0:
                            continue

                        # One OffsetImage per (size, alpha) group so the
                        # group's alpha is actually applied to the image.
                        im = OffsetImage(image, zoom=base_zoom * m_size)
                        im.image.axes = self.ax
                        im.image.set_alpha(m_alpha)

                        for x, y in positions:
                            ab = AnnotationBbox(
                                im,
                                (x, y),
                                frameon=False,
                                pad=0.0,
                                zorder=z_order,
                                **kwargs,
                            )
                            self.ax.add_artist(ab)

        if regular_markers:
            # Per-agent style arrays are indexed with per-agent masks below,
            # so partially specified entries cannot be aligned.
            for entry in ("edgecolors", "linewidths"):
                if entry in arguments and len(arguments[entry]) != n_agents:
                    raise ValueError(
                        f"{entry} must be specified for every agent or for none of them"
                    )

        loc_x, loc_y = loc[:, 0], loc[:, 1]
        for mark in regular_markers:
            mask_marker = self._marker_mask(marker, mark)

            for z_order in np.unique(zorder[mask_marker]):
                zorder_mask = (zorder == z_order) & mask_marker
                scatter_args = {k: v[zorder_mask] for k, v in arguments.items()}

                self.ax.scatter(
                    loc_x[zorder_mask],
                    loc_y[zorder_mask],
                    marker=mark,
                    zorder=z_order,
                    **scatter_args,
                    **kwargs,
                )

        return self.ax

    def draw_propertylayer(self, space, property_layers, propertylayer_portrayal):
        """Draw property layers using matplotlib backend.

        Args:
            space: The Mesa space object.
            property_layers (dict): Dictionary of property layers to visualize.
                The entry named "empty" is always skipped.
            propertylayer_portrayal (Callable): Function that returns PropertyLayerStyle
                (or None to skip the layer).

        Returns:
            tuple: ``(ax, cbar)`` - the matplotlib axes and the colorbar added for
            the last drawn layer, or None if that layer requested no colorbar or
            no layer was drawn.

        Raises:
            ValueError: If a portrayal provides neither ``color`` nor ``colormap``.
            NotImplementedError: If ``space`` is not an orthogonal or hex grid.
        """
        # Bound before the loop so the return is valid even when every layer
        # is skipped.
        cbar = None

        for layer_name in property_layers:
            if layer_name == "empty":
                continue

            layer = property_layers.get(layer_name)
            portrayal = propertylayer_portrayal(layer)

            if portrayal is None:
                continue

            data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

            # Check dimensions
            if (space.width, space.height) != data.shape:
                warnings.warn(
                    f"Layer {layer_name} dimensions ({data.shape}) "
                    f"don't match space dimensions ({space.width}, {space.height})",
                    UserWarning,
                    stacklevel=2,
                )
                continue

            # Get portrayal parameters
            color = portrayal.color
            colormap = portrayal.colormap
            alpha = portrayal.alpha
            vmin = portrayal.vmin if portrayal.vmin is not None else np.min(data)
            vmax = portrayal.vmax if portrayal.vmax is not None else np.max(data)
            # Normalize maps vmin..vmax onto 0..1 and yields 0 instead of
            # dividing by zero when the layer is constant (vmin == vmax).
            norm = Normalize(vmin=vmin, vmax=vmax)

            # Set up colormap
            if color:
                rgba_color = to_rgba(color)
                cmap = LinearSegmentedColormap.from_list(
                    layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
                )
            elif colormap:
                cmap = colormap
                if isinstance(cmap, list):
                    cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
                elif isinstance(cmap, str):
                    cmap = plt.get_cmap(cmap)
            else:
                raise ValueError(
                    f"PropertyLayer {layer_name} must include 'color' or 'colormap'"
                )

            # Draw based on space type
            if isinstance(space, OrthogonalGrid):
                if color:
                    data = data.T
                    # Clip before scaling so values beyond vmin/vmax saturate,
                    # matching the hex-grid branch below.
                    normalized_data = np.clip(np.asarray(norm(data)), 0, 1)
                    rgba_data = np.full((*data.shape, 4), rgba_color)
                    rgba_data[..., 3] *= normalized_data * alpha
                    rgba_data = np.clip(rgba_data, 0, 1)
                    self.ax.imshow(rgba_data, origin="lower")
                else:
                    self.ax.imshow(
                        data.T,
                        cmap=cmap,
                        alpha=alpha,
                        vmin=vmin,
                        vmax=vmax,
                        origin="lower",
                    )
            elif isinstance(space, HexGrid):
                hexagons = self.space_drawer.hexagons
                colors = data.ravel()

                if color:
                    normalized_colors = np.clip(norm(colors), 0, 1)
                    rgba_colors = np.full((len(colors), 4), rgba_color)
                    rgba_colors[:, 3] = normalized_colors * alpha
                else:
                    rgba_colors = cmap(norm(colors))
                    rgba_colors[..., 3] *= alpha

                collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1)
                self.ax.add_collection(collection)
            else:
                raise NotImplementedError(
                    f"PropertyLayer visualization not implemented for {type(space)}"
                )

            # Add colorbar if requested (reset per layer so the returned value
            # always belongs to the last drawn layer).
            cbar = None
            if portrayal.colorbar:
                sm = ScalarMappable(norm=norm, cmap=cmap)
                sm.set_array([])
                cbar = plt.colorbar(sm, ax=self.ax, label=layer_name)
                self._active_colorbars.append(cbar)

        return self.ax, cbar
@@ -1,15 +1,32 @@
1
- """custom solara components."""
1
+ """Custom visualization components."""
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
5
  from collections.abc import Callable
6
6
 
7
- from .altair_components import SpaceAltair, make_altair_space
7
+ from .altair_components import (
8
+ SpaceAltair,
9
+ make_altair_plot_component,
10
+ make_altair_space,
11
+ )
8
12
  from .matplotlib_components import (
9
13
  SpaceMatplotlib,
10
14
  make_mpl_plot_component,
11
15
  make_mpl_space_component,
12
16
  )
17
+ from .portrayal_components import AgentPortrayalStyle, PropertyLayerStyle
18
+
19
# Public names exported by ``from mesa.visualization.components import *``.
# The Altair plot factory is listed alongside its matplotlib counterpart.
__all__ = [
    "AgentPortrayalStyle",
    "PropertyLayerStyle",
    "SpaceAltair",
    "SpaceMatplotlib",
    "make_altair_plot_component",
    "make_altair_space",
    "make_mpl_plot_component",
    "make_mpl_space_component",
    "make_plot_component",
    "make_space_component",
]
13
30
 
14
31
 
15
32
  def make_space_component(
@@ -57,6 +74,7 @@ def make_plot_component(
57
74
  measure: str | dict[str, str] | list[str] | tuple[str],
58
75
  post_process: Callable | None = None,
59
76
  backend: str = "matplotlib",
77
+ page: int = 0,
60
78
  **plot_drawing_kwargs,
61
79
  ):
62
80
  """Create a plotting function for a specified measure using the specified backend.
@@ -65,18 +83,20 @@ def make_plot_component(
65
83
  measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
66
84
  post_process: a user-specified callable to do post-processing called with the Axes instance.
67
85
  backend: the backend to use {"matplotlib", "altair"}
86
+ page: Page number where the plot should be displayed (default 0).
68
87
  plot_drawing_kwargs: additional keyword arguments to pass onto the backend specific function for making a plotting component
69
88
 
70
- Notes:
71
- altair plotting backend is not yet implemented and planned for mesa 3.1.
72
-
73
89
  Returns:
74
- function: A function that creates a plot component
90
+ (function, page): A tuple of a function and page number that creates a plot component on that specific page.
75
91
  """
76
92
  if backend == "matplotlib":
77
- return make_mpl_plot_component(measure, post_process, **plot_drawing_kwargs)
93
+ return make_mpl_plot_component(
94
+ measure, post_process, page, **plot_drawing_kwargs
95
+ )
78
96
  elif backend == "altair":
79
- raise NotImplementedError("altair line plots are not yet implemented")
97
+ return make_altair_plot_component(
98
+ measure, post_process, page, **plot_drawing_kwargs
99
+ )
80
100
  else:
81
101
  raise ValueError(
82
102
  f"unknown backend {backend}, must be one of matplotlib, altair"
@@ -1,6 +1,7 @@
1
1
  """Altair based solara components for visualization mesa spaces."""
2
2
 
3
3
  import warnings
4
+ from collections.abc import Callable
4
5
 
5
6
  import altair as alt
6
7
  import numpy as np
@@ -448,3 +449,88 @@ def chart_property_layers(space, propertylayer_portrayal, chart_width, chart_hei
448
449
  f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
449
450
  )
450
451
  return base, bar_chart
452
+
453
+
454
+ def make_altair_plot_component(
455
+ measure: str | dict[str, str] | list[str] | tuple[str],
456
+ post_process: Callable | None = None,
457
+ page: int = 0,
458
+ grid=False,
459
+ ):
460
+ """Create a plotting function for a specified measure.
461
+
462
+ Args:
463
+ measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
464
+ post_process: a user-specified callable to do post-processing called with the Axes instance.
465
+ page: Page number where the plot should be displayed.
466
+ grid: Bool to draw grid or not.
467
+
468
+ Returns:
469
+ (function, page): A tuple of a function that creates a PlotAltair component and a page number.
470
+ """
471
+
472
+ def MakePlotAltair(model):
473
+ return PlotAltair(model, measure, post_process=post_process, grid=grid)
474
+
475
+ return (MakePlotAltair, page)
476
+
477
+
478
@solara.component
def PlotAltair(model, measure, post_process: Callable | None = None, grid=False):
    """Create an Altair-based line plot for a measure or measures.

    Args:
        model (mesa.Model): The model instance; the model-variables dataframe of
            its ``datacollector`` is plotted against the step index.
        measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
            If a dict is given, keys are measure names and values are colors.
        post_process: A user-specified callable for post-processing, called
            with the Altair Chart instance; its return value is rendered.
        grid: Bool to draw grid or not.

    Returns:
        solara.FigureAltair: A component for rendering the plot.

    Raises:
        TypeError: If ``measure`` is not a str, list, tuple or dict.
    """
    # Reading the counter subscribes this component to model updates.
    update_counter.get()
    df = model.datacollector.get_model_vars_dataframe().reset_index()
    df = df.rename(columns={"index": "Step"})

    y_title = "Value"
    if isinstance(measure, str):
        measures_to_plot = [measure]
        y_title = measure
    elif isinstance(measure, list | tuple):
        measures_to_plot = list(measure)
    elif isinstance(measure, dict):
        measures_to_plot = list(measure.keys())
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise TypeError(
            "measure must be a str, list, tuple or dict, "
            f"got {type(measure).__name__}"
        )

    # Long format: one row per (Step, Measure) so Altair can color by Measure.
    df_long = df.melt(
        id_vars=["Step"],
        value_vars=measures_to_plot,
        var_name="Measure",
        value_name="Value",
    )

    chart = (
        alt.Chart(df_long)
        .mark_line()
        .encode(
            x=alt.X("Step:Q", axis=alt.Axis(tickMinStep=1, title="Step", grid=grid)),
            y=alt.Y("Value:Q", axis=alt.Axis(title=y_title, grid=grid)),
            tooltip=["Step", "Measure", "Value"],
        )
        .properties(width=450, height=350)
        .interactive()
    )

    if measures_to_plot:
        color_args = {}
        if isinstance(measure, dict):
            # Use the user-supplied color for each measure name.
            color_args["scale"] = alt.Scale(
                domain=list(measure.keys()), range=list(measure.values())
            )
        chart = chart.encode(color=alt.Color("Measure:N", **color_args))

    if post_process is not None:
        chart = post_process(chart)

    return solara.FigureAltair(chart)
@@ -107,6 +107,7 @@ def make_plot_measure(*args, **kwargs): # noqa: D103
107
107
  def make_mpl_plot_component(
108
108
  measure: str | dict[str, str] | list[str] | tuple[str],
109
109
  post_process: Callable | None = None,
110
+ page: int = 0,
110
111
  save_format="png",
111
112
  ):
112
113
  """Create a plotting function for a specified measure.
@@ -114,10 +115,11 @@ def make_mpl_plot_component(
114
115
  Args:
115
116
  measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
116
117
  post_process: a user-specified callable to do post-processing called with the Axes instance.
118
+ page: Page number where the plot should be displayed.
117
119
  save_format: save format of figure in solara backend
118
120
 
119
121
  Returns:
120
- function: A function that creates a PlotMatplotlib component.
122
+ (function, page): A tuple of a function that creates a PlotMatplotlib component and a page number.
121
123
  """
122
124
 
123
125
  def MakePlotMatplotlib(model):
@@ -125,7 +127,7 @@ def make_mpl_plot_component(
125
127
  model, measure, post_process=post_process, save_format=save_format
126
128
  )
127
129
 
128
- return MakePlotMatplotlib
130
+ return (MakePlotMatplotlib, page)
129
131
 
130
132
 
131
133
  @solara.component