Mesa 3.2.0.dev0__py3-none-any.whl → 3.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of Mesa might be problematic. Click here for more details.
- mesa/__init__.py +1 -1
- mesa/agent.py +9 -7
- mesa/datacollection.py +1 -1
- mesa/examples/README.md +1 -1
- mesa/examples/__init__.py +2 -0
- mesa/examples/advanced/alliance_formation/Readme.md +50 -0
- mesa/examples/advanced/alliance_formation/__init__ .py +0 -0
- mesa/examples/advanced/alliance_formation/agents.py +20 -0
- mesa/examples/advanced/alliance_formation/app.py +71 -0
- mesa/examples/advanced/alliance_formation/model.py +184 -0
- mesa/examples/advanced/epstein_civil_violence/app.py +11 -11
- mesa/examples/advanced/pd_grid/Readme.md +4 -6
- mesa/examples/advanced/pd_grid/app.py +10 -11
- mesa/examples/advanced/sugarscape_g1mt/Readme.md +4 -5
- mesa/examples/advanced/sugarscape_g1mt/app.py +34 -16
- mesa/examples/advanced/wolf_sheep/Readme.md +2 -17
- mesa/examples/advanced/wolf_sheep/app.py +21 -18
- mesa/examples/basic/boid_flockers/Readme.md +6 -1
- mesa/examples/basic/boid_flockers/app.py +15 -11
- mesa/examples/basic/boltzmann_wealth_model/Readme.md +2 -12
- mesa/examples/basic/boltzmann_wealth_model/app.py +39 -32
- mesa/examples/basic/conways_game_of_life/Readme.md +1 -9
- mesa/examples/basic/conways_game_of_life/app.py +13 -16
- mesa/examples/basic/schelling/Readme.md +2 -10
- mesa/examples/basic/schelling/agents.py +9 -3
- mesa/examples/basic/schelling/app.py +50 -3
- mesa/examples/basic/schelling/model.py +2 -0
- mesa/examples/basic/schelling/resources/blue_happy.png +0 -0
- mesa/examples/basic/schelling/resources/blue_unhappy.png +0 -0
- mesa/examples/basic/schelling/resources/orange_happy.png +0 -0
- mesa/examples/basic/schelling/resources/orange_unhappy.png +0 -0
- mesa/examples/basic/virus_on_network/Readme.md +0 -4
- mesa/examples/basic/virus_on_network/app.py +31 -14
- mesa/experimental/__init__.py +2 -2
- mesa/experimental/continuous_space/continuous_space.py +1 -1
- mesa/experimental/meta_agents/__init__.py +25 -0
- mesa/experimental/meta_agents/meta_agent.py +387 -0
- mesa/model.py +3 -3
- mesa/space.py +4 -1
- mesa/visualization/__init__.py +2 -0
- mesa/visualization/backends/__init__.py +23 -0
- mesa/visualization/backends/abstract_renderer.py +97 -0
- mesa/visualization/backends/altair_backend.py +440 -0
- mesa/visualization/backends/matplotlib_backend.py +419 -0
- mesa/visualization/components/__init__.py +28 -8
- mesa/visualization/components/altair_components.py +86 -0
- mesa/visualization/components/matplotlib_components.py +4 -2
- mesa/visualization/components/portrayal_components.py +120 -0
- mesa/visualization/mpl_space_drawing.py +292 -129
- mesa/visualization/solara_viz.py +274 -32
- mesa/visualization/space_drawers.py +797 -0
- mesa/visualization/space_renderer.py +399 -0
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/METADATA +13 -4
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/RECORD +57 -40
- mesa/examples/advanced/sugarscape_g1mt/tests.py +0 -69
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/WHEEL +0 -0
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/LICENSE +0 -0
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,419 @@
|
|
|
1
|
+
# noqa: D100
|
|
2
|
+
import os
|
|
3
|
+
import warnings
|
|
4
|
+
from dataclasses import fields
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from matplotlib import pyplot as plt
|
|
8
|
+
from matplotlib.cm import ScalarMappable
|
|
9
|
+
from matplotlib.collections import PolyCollection
|
|
10
|
+
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
|
|
11
|
+
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
|
|
12
|
+
from PIL import Image
|
|
13
|
+
|
|
14
|
+
import mesa
|
|
15
|
+
from mesa.discrete_space import (
|
|
16
|
+
OrthogonalMooreGrid,
|
|
17
|
+
OrthogonalVonNeumannGrid,
|
|
18
|
+
)
|
|
19
|
+
from mesa.space import (
|
|
20
|
+
HexMultiGrid,
|
|
21
|
+
HexSingleGrid,
|
|
22
|
+
MultiGrid,
|
|
23
|
+
SingleGrid,
|
|
24
|
+
)
|
|
25
|
+
from mesa.visualization.backends.abstract_renderer import AbstractRenderer
|
|
26
|
+
|
|
27
|
+
# Union types used with isinstance() in draw_propertylayer to pick the drawing
# strategy. Each union mixes the legacy mesa.space grids with their
# mesa.discrete_space counterparts.
OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid


# Multiplier used by MatplotlibBackend._get_zoom_factor when turning
# "pixels per data unit / image pixel size" into an OffsetImage zoom for
# image-file agent markers.
CORRECTION_FACTOR_MARKER_ZOOM = 0.01
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class MatplotlibBackend(AbstractRenderer):
    """Matplotlib-based renderer for Mesa spaces.

    Provides visualization capabilities using Matplotlib for rendering
    space structures, agents, and property layers.
    """

    def __init__(self, space_drawer):
        """Initialize the Matplotlib backend.

        Args:
            space_drawer: An instance of a SpaceDrawer class that handles
                the drawing of the space structure.
        """
        super().__init__(space_drawer)

        # Every colorbar created by draw_propertylayer is appended here.
        self._active_colorbars = []

    def initialize_canvas(self, ax=None):
        """Initialize the matplotlib canvas.

        Sets ``self.ax`` and ``self.fig``. When an existing axes is supplied,
        ``self.fig`` is the figure that owns it.

        Args:
            ax (matplotlib.axes.Axes, optional): Existing axes to use.
                If None, creates new figure and axes.
        """
        if ax is None:
            fig, ax = plt.subplots(constrained_layout=True)
        else:
            fig = ax.get_figure()
        self.fig = fig
        self.ax = ax

    def draw_structure(self, **kwargs):
        """Draw the space structure using matplotlib.

        Args:
            **kwargs: Additional arguments passed to the space drawer.
                Checkout respective `SpaceDrawer` class on details how to pass **kwargs.

        Returns:
            The matplotlib axes with the drawn structure.
        """
        return self.space_drawer.draw_matplotlib(self.ax, **kwargs)

    def collect_agent_data(self, space, agent_portrayal, default_size=None):
        """Collect plotting data for all agents in the space.

        Args:
            space: The Mesa space containing agents.
            agent_portrayal (Callable): Function that returns AgentPortrayalStyle
                for each agent. Returning a dict is still accepted but emits a
                PendingDeprecationWarning.
            default_size (float, optional): Default marker size if not specified in portrayal.

        Returns:
            dict: Arrays keyed by "loc", "s", "c", "marker", "zorder", "alpha",
            "edgecolors" and "linewidths". "marker" is a 1-D object array so
            tuple markers stay single entries. "edgecolors" only receives
            entries for agents whose portrayal set them.
        """
        # Initialize data collection arrays
        arguments = {
            "loc": [],
            "s": [],
            "c": [],
            "marker": [],
            "zorder": [],
            "alpha": [],
            "edgecolors": [],
            "linewidths": [],
        }
        # Import here to prevent circular imports
        from mesa.visualization.components import AgentPortrayalStyle  # noqa: PLC0415

        # Get default values from AgentPortrayalStyle
        style_fields = {f.name: f.default for f in fields(AgentPortrayalStyle)}
        class_default_size = style_fields.get("size")

        for agent in space.agents:
            portray_input = agent_portrayal(agent)

            if isinstance(portray_input, dict):
                warnings.warn(
                    "Returning a dict from agent_portrayal is deprecated. "
                    "Please return an AgentPortrayalStyle instance instead.",
                    PendingDeprecationWarning,
                    stacklevel=2,
                )
                # Handle legacy dict input
                dict_data = portray_input.copy()
                x, y = self._get_agent_pos(agent, space)

                # Extract values with defaults
                aps = AgentPortrayalStyle(
                    x=x,
                    y=y,
                    size=dict_data.pop("size", style_fields.get("size")),
                    color=dict_data.pop("color", style_fields.get("color")),
                    marker=dict_data.pop("marker", style_fields.get("marker")),
                    zorder=dict_data.pop("zorder", style_fields.get("zorder")),
                    alpha=dict_data.pop("alpha", style_fields.get("alpha")),
                    edgecolors=dict_data.pop("edgecolors", None),
                    linewidths=dict_data.pop(
                        "linewidths", style_fields.get("linewidths")
                    ),
                )

                # Warn about unused keys
                if dict_data:
                    ignored_keys = list(dict_data.keys())
                    warnings.warn(
                        f"The following keys were ignored: {', '.join(ignored_keys)}",
                        UserWarning,
                        stacklevel=2,
                    )
            else:
                aps = portray_input
                # Fall back to the agent's own position. The result stays in
                # locals instead of being written onto ``aps``. If one style
                # object is shared by several agents, writing to it would pin
                # every later agent to the first agent's position.
                if aps.x is None and aps.y is None:
                    x, y = self._get_agent_pos(agent, space)
                else:
                    x, y = aps.x, aps.y

            # Collect agent data
            arguments["loc"].append((x, y))

            # Determine final size
            size_to_collect = aps.size or default_size or class_default_size
            arguments["s"].append(size_to_collect)
            arguments["c"].append(aps.color)
            arguments["marker"].append(aps.marker)
            arguments["zorder"].append(aps.zorder)
            arguments["alpha"].append(aps.alpha)
            if aps.edgecolors is not None:
                arguments["edgecolors"].append(aps.edgecolors)
            arguments["linewidths"].append(aps.linewidths)

        # Convert to numpy arrays (markers are handled separately below).
        data = {
            key: np.asarray(values)
            for key, values in arguments.items()
            if key != "marker"
        }

        # Fill a pre-sized 1-D object array so tuple markers stay whole entries
        # and do not become extra array dimensions.
        markers = np.empty(len(arguments["marker"]), dtype=object)
        markers[:] = arguments["marker"]
        data["marker"] = markers

        return data

    @staticmethod
    def _marker_mask(markers, mark):
        """Return a boolean mask of the entries in ``markers`` equal to ``mark``.

        The comparison is done element by element. ``markers == mark`` would
        let NumPy broadcast a tuple marker such as ``(5, 0, 0)`` against the
        array instead of comparing whole markers.
        """
        return np.fromiter(
            (candidate == mark for candidate in markers),
            dtype=bool,
            count=len(markers),
        )

    def _get_zoom_factor(self, ax, img):
        """Compute the base OffsetImage zoom for an image marker.

        The canvas is drawn first so the axes' window extent is current. The
        zoom relates pixels-per-data-unit to the image's pixel size, scaled by
        CORRECTION_FACTOR_MARKER_ZOOM. The caller is responsible for reusing
        the result.
        """
        fig = ax.get_figure()
        fig.canvas.draw()

        bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
        width, height = bbox.width * fig.dpi, bbox.height * fig.dpi

        x_min, x_max = ax.get_xlim()
        y_min, y_max = ax.get_ylim()

        x_pixel_per_data = width / (x_max - x_min)
        y_pixel_per_data = height / (y_max - y_min)

        zoom_x = (x_pixel_per_data / img.width) * CORRECTION_FACTOR_MARKER_ZOOM
        zoom_y = (y_pixel_per_data / img.height) * CORRECTION_FACTOR_MARKER_ZOOM

        return min(zoom_x, zoom_y)

    def draw_agents(self, arguments, **kwargs):
        """Draw agents on the backend's axes.

        Markers that are paths to existing files are drawn as images via
        AnnotationBbox. All other markers are drawn with ``Axes.scatter``, one
        call per (marker, zorder) group.

        Args:
            arguments: Dictionary containing agent data arrays, as produced by
                ``collect_agent_data``. The dictionary itself is not modified.
            **kwargs: Additional keyword arguments for customization.
                Checkout respective `SpaceDrawer` class on details how to pass **kwargs.

        Returns:
            matplotlib.axes.Axes: The Matplotlib Axes with the agents drawn upon
            it, or None when there are no agents.

        Raises:
            ValueError: If edgecolors/linewidths are given both in the agent
                portrayal and via ``kwargs``.
        """
        if arguments["loc"].size == 0:
            return None

        # Shallow copy: the pops below must not mutate the caller's dict.
        arguments = dict(arguments)

        loc = arguments.pop("loc")
        loc_x, loc_y = loc[:, 0], loc[:, 1]
        marker = arguments.pop("marker")
        zorder = arguments.pop("zorder")
        malpha = arguments["alpha"]
        msize = arguments["s"]

        # Validate edge arguments
        for entry in ("edgecolors", "linewidths"):
            if len(arguments[entry]) == 0:
                arguments.pop(entry)
            elif entry in kwargs:
                raise ValueError(
                    f"{entry} is specified in agent portrayal and via plotting kwargs, "
                    "you can only use one or the other"
                )

        # Separate image-file markers from regular matplotlib markers.
        image_markers = set()
        regular_markers = set()
        for mark in set(marker):
            if isinstance(mark, str | os.PathLike) and os.path.isfile(mark):
                image_markers.add(mark)
            else:
                regular_markers.add(mark)

        for mark in image_markers:
            image = Image.open(mark)
            # _get_zoom_factor redraws the canvas itself before measuring.
            base_zoom = self._get_zoom_factor(self.ax, image)
            # Agents sharing size and alpha reuse one OffsetImage.
            offset_images = {}

            for idx in np.flatnonzero(self._marker_mask(marker, mark)):
                key = (msize[idx], malpha[idx])
                im = offset_images.get(key)
                if im is None:
                    im = OffsetImage(image, zoom=base_zoom * msize[idx])
                    # Apply the portrayal alpha to the image itself.
                    im.image.set_alpha(malpha[idx])
                    im.image.axes = self.ax
                    offset_images[key] = im

                ab = AnnotationBbox(
                    im,
                    (loc_x[idx], loc_y[idx]),
                    frameon=False,
                    pad=0.0,
                    zorder=zorder[idx],
                    **kwargs,
                )
                self.ax.add_artist(ab)

        for mark in regular_markers:
            mark_mask = self._marker_mask(marker, mark)

            for z_order in np.unique(zorder[mark_mask]):
                zorder_mask = (zorder == z_order) & mark_mask

                scatter_args = {k: v[zorder_mask] for k, v in arguments.items()}

                self.ax.scatter(
                    loc_x[zorder_mask],
                    loc_y[zorder_mask],
                    marker=mark,
                    zorder=z_order,
                    **scatter_args,
                    **kwargs,
                )

        return self.ax

    def draw_propertylayer(self, space, property_layers, propertylayer_portrayal):
        """Draw property layers using matplotlib backend.

        Args:
            space: The Mesa space object.
            property_layers (dict): Dictionary of property layers to visualize.
            propertylayer_portrayal (Callable): Function that returns PropertyLayerStyle.

        Returns:
            tuple: (matplotlib.axes.Axes, colorbar). The colorbar belongs to the
            last layer drawn. It is None if that layer requested no colorbar or
            if no layer was drawn.

        Raises:
            ValueError: If a layer's portrayal has neither color nor colormap.
            NotImplementedError: If the space type is not supported.
        """
        # Defined before the loop so the return is valid even if no layer is drawn.
        cbar = None

        # Draw each layer
        for layer_name in property_layers:
            if layer_name == "empty":
                continue

            layer = property_layers.get(layer_name)
            portrayal = propertylayer_portrayal(layer)

            if portrayal is None:
                continue

            data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

            # Check dimensions
            if (space.width, space.height) != data.shape:
                warnings.warn(
                    f"Layer {layer_name} dimensions ({data.shape}) "
                    f"don't match space dimensions ({space.width}, {space.height})",
                    UserWarning,
                    stacklevel=2,
                )
                continue

            # Get portrayal parameters
            color = portrayal.color
            colormap = portrayal.colormap
            alpha = portrayal.alpha
            vmin = portrayal.vmin if portrayal.vmin is not None else np.min(data)
            vmax = portrayal.vmax if portrayal.vmax is not None else np.max(data)
            # Normalize maps a degenerate range (vmin == vmax, e.g. a constant
            # layer) to 0 instead of dividing by zero.
            norm = Normalize(vmin=vmin, vmax=vmax)

            # Set up colormap
            if color:
                rgba_color = to_rgba(color)
                cmap = LinearSegmentedColormap.from_list(
                    layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
                )
            elif colormap:
                cmap = colormap
                if isinstance(cmap, list):
                    cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
                elif isinstance(cmap, str):
                    cmap = plt.get_cmap(cmap)
            else:
                raise ValueError(
                    f"PropertyLayer {layer_name} must include 'color' or 'colormap'"
                )

            # Draw based on space type
            if isinstance(space, OrthogonalGrid):
                if color:
                    normalized = np.clip(np.asarray(norm(data.T), dtype=float), 0, 1)
                    rgba_data = np.full((*normalized.shape, 4), rgba_color)
                    rgba_data[..., 3] *= normalized * alpha
                    rgba_data = np.clip(rgba_data, 0, 1)
                    self.ax.imshow(rgba_data, origin="lower")
                else:
                    self.ax.imshow(
                        data.T,
                        cmap=cmap,
                        alpha=alpha,
                        vmin=vmin,
                        vmax=vmax,
                        origin="lower",
                    )
            elif isinstance(space, HexGrid):
                hexagons = self.space_drawer.hexagons
                values = data.ravel()

                if color:
                    normalized_values = np.clip(norm(values), 0, 1)
                    rgba_colors = np.full((len(values), 4), rgba_color)
                    rgba_colors[:, 3] = normalized_values * alpha
                else:
                    rgba_colors = cmap(norm(values))
                    rgba_colors[..., 3] *= alpha

                collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1)
                self.ax.add_collection(collection)
            else:
                raise NotImplementedError(
                    f"PropertyLayer visualization not implemented for {type(space)}"
                )

            # Add colorbar if requested
            cbar = None
            if portrayal.colorbar:
                sm = ScalarMappable(norm=norm, cmap=cmap)
                sm.set_array([])
                # Use the axes' own figure. plt.colorbar would go through
                # pyplot's current figure, which need not own self.ax.
                cbar = self.ax.get_figure().colorbar(sm, ax=self.ax, label=layer_name)
                self._active_colorbars.append(cbar)

        return self.ax, cbar
|
|
@@ -1,15 +1,32 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Custom visualization components."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from collections.abc import Callable
|
|
6
6
|
|
|
7
|
-
from .altair_components import
|
|
7
|
+
from .altair_components import (
|
|
8
|
+
SpaceAltair,
|
|
9
|
+
make_altair_plot_component,
|
|
10
|
+
make_altair_space,
|
|
11
|
+
)
|
|
8
12
|
from .matplotlib_components import (
|
|
9
13
|
SpaceMatplotlib,
|
|
10
14
|
make_mpl_plot_component,
|
|
11
15
|
make_mpl_space_component,
|
|
12
16
|
)
|
|
17
|
+
from .portrayal_components import AgentPortrayalStyle, PropertyLayerStyle
|
|
18
|
+
|
|
19
|
+
# Public names of this package for ``from ... import *``: the portrayal style
# classes, the Altair/Matplotlib space and plot component helpers imported
# above, and the backend-dispatching factories defined below.
# NOTE(review): make_altair_plot_component is imported above but not listed
# here, while its counterpart make_mpl_plot_component is. Confirm whether the
# omission is intentional.
__all__ = [
    "AgentPortrayalStyle",
    "PropertyLayerStyle",
    "SpaceAltair",
    "SpaceMatplotlib",
    "make_altair_space",
    "make_mpl_plot_component",
    "make_mpl_space_component",
    "make_plot_component",
    "make_space_component",
]
|
|
13
30
|
|
|
14
31
|
|
|
15
32
|
def make_space_component(
|
|
@@ -57,6 +74,7 @@ def make_plot_component(
|
|
|
57
74
|
measure: str | dict[str, str] | list[str] | tuple[str],
|
|
58
75
|
post_process: Callable | None = None,
|
|
59
76
|
backend: str = "matplotlib",
|
|
77
|
+
page: int = 0,
|
|
60
78
|
**plot_drawing_kwargs,
|
|
61
79
|
):
|
|
62
80
|
"""Create a plotting function for a specified measure using the specified backend.
|
|
@@ -65,18 +83,20 @@ def make_plot_component(
|
|
|
65
83
|
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
|
|
66
84
|
post_process: a user-specified callable to do post-processing called with the Axes instance.
|
|
67
85
|
backend: the backend to use {"matplotlib", "altair"}
|
|
86
|
+
page: Page number where the plot should be displayed (default 0).
|
|
68
87
|
plot_drawing_kwargs: additional keyword arguments to pass onto the backend specific function for making a plotting component
|
|
69
88
|
|
|
70
|
-
Notes:
|
|
71
|
-
altair plotting backend is not yet implemented and planned for mesa 3.1.
|
|
72
|
-
|
|
73
89
|
Returns:
|
|
74
|
-
function: A function that creates a plot component
|
|
90
|
+
(function, page): A tuple of a function and page number that creates a plot component on that specific page.
|
|
75
91
|
"""
|
|
76
92
|
if backend == "matplotlib":
|
|
77
|
-
return make_mpl_plot_component(
|
|
93
|
+
return make_mpl_plot_component(
|
|
94
|
+
measure, post_process, page, **plot_drawing_kwargs
|
|
95
|
+
)
|
|
78
96
|
elif backend == "altair":
|
|
79
|
-
|
|
97
|
+
return make_altair_plot_component(
|
|
98
|
+
measure, post_process, page, **plot_drawing_kwargs
|
|
99
|
+
)
|
|
80
100
|
else:
|
|
81
101
|
raise ValueError(
|
|
82
102
|
f"unknown backend {backend}, must be one of matplotlib, altair"
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Altair based solara components for visualization mesa spaces."""
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
|
+
from collections.abc import Callable
|
|
4
5
|
|
|
5
6
|
import altair as alt
|
|
6
7
|
import numpy as np
|
|
@@ -448,3 +449,88 @@ def chart_property_layers(space, propertylayer_portrayal, chart_width, chart_hei
|
|
|
448
449
|
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
|
|
449
450
|
)
|
|
450
451
|
return base, bar_chart
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def make_altair_plot_component(
|
|
455
|
+
measure: str | dict[str, str] | list[str] | tuple[str],
|
|
456
|
+
post_process: Callable | None = None,
|
|
457
|
+
page: int = 0,
|
|
458
|
+
grid=False,
|
|
459
|
+
):
|
|
460
|
+
"""Create a plotting function for a specified measure.
|
|
461
|
+
|
|
462
|
+
Args:
|
|
463
|
+
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
|
|
464
|
+
post_process: a user-specified callable to do post-processing called with the Axes instance.
|
|
465
|
+
page: Page number where the plot should be displayed.
|
|
466
|
+
grid: Bool to draw grid or not.
|
|
467
|
+
|
|
468
|
+
Returns:
|
|
469
|
+
(function, page): A tuple of a function that creates a PlotAltair component and a page number.
|
|
470
|
+
"""
|
|
471
|
+
|
|
472
|
+
def MakePlotAltair(model):
|
|
473
|
+
return PlotAltair(model, measure, post_process=post_process, grid=grid)
|
|
474
|
+
|
|
475
|
+
return (MakePlotAltair, page)
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
@solara.component
def PlotAltair(model, measure, post_process: Callable | None = None, grid=False):
    """Create an Altair-based plot for a measure or measures.

    Args:
        model (mesa.Model): The model instance.
        measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
            If a dict is given, keys are measure names and values are colors.
        post_process: A user-specified callable for post-processing, called
            with the Altair Chart instance; its return value is the chart rendered.
        grid: Bool to draw grid or not.

    Returns:
        solara.FigureAltair: A component for rendering the plot.

    Raises:
        TypeError: If ``measure`` is not a str, list, tuple, or dict.
    """
    # Reading the counter makes this component depend on it, so it is
    # re-rendered when the counter changes (presumably once per model step).
    update_counter.get()
    df = model.datacollector.get_model_vars_dataframe().reset_index()
    df = df.rename(columns={"index": "Step"})

    y_title = "Value"
    if isinstance(measure, str):
        measures_to_plot = [measure]
        y_title = measure
    elif isinstance(measure, list | tuple):
        measures_to_plot = list(measure)
    elif isinstance(measure, dict):
        measures_to_plot = list(measure.keys())
    else:
        # Fail early with a clear message rather than an UnboundLocalError below.
        raise TypeError(
            "measure must be a str, list, tuple, or dict of measure names, "
            f"got {type(measure).__name__}"
        )

    # Long format: one row per (Step, Measure) so Altair can color by Measure.
    df_long = df.melt(
        id_vars=["Step"],
        value_vars=measures_to_plot,
        var_name="Measure",
        value_name="Value",
    )

    chart = (
        alt.Chart(df_long)
        .mark_line()
        .encode(
            x=alt.X("Step:Q", axis=alt.Axis(tickMinStep=1, title="Step", grid=grid)),
            y=alt.Y("Value:Q", axis=alt.Axis(title=y_title, grid=grid)),
            tooltip=["Step", "Measure", "Value"],
        )
        .properties(width=450, height=350)
        .interactive()
    )

    if len(measures_to_plot) > 0:
        color_args = {}
        if isinstance(measure, dict):
            # Pin each measure to the user-supplied color.
            color_args["scale"] = alt.Scale(
                domain=list(measure.keys()), range=list(measure.values())
            )
        chart = chart.encode(color=alt.Color("Measure:N", **color_args))

    if post_process is not None:
        chart = post_process(chart)

    return solara.FigureAltair(chart)
|
|
@@ -107,6 +107,7 @@ def make_plot_measure(*args, **kwargs): # noqa: D103
|
|
|
107
107
|
def make_mpl_plot_component(
|
|
108
108
|
measure: str | dict[str, str] | list[str] | tuple[str],
|
|
109
109
|
post_process: Callable | None = None,
|
|
110
|
+
page: int = 0,
|
|
110
111
|
save_format="png",
|
|
111
112
|
):
|
|
112
113
|
"""Create a plotting function for a specified measure.
|
|
@@ -114,10 +115,11 @@ def make_mpl_plot_component(
|
|
|
114
115
|
Args:
|
|
115
116
|
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
|
|
116
117
|
post_process: a user-specified callable to do post-processing called with the Axes instance.
|
|
118
|
+
page: Page number where the plot should be displayed.
|
|
117
119
|
save_format: save format of figure in solara backend
|
|
118
120
|
|
|
119
121
|
Returns:
|
|
120
|
-
function: A function that creates a PlotMatplotlib component.
|
|
122
|
+
(function, page): A tuple of a function that creates a PlotMatplotlib component and a page number.
|
|
121
123
|
"""
|
|
122
124
|
|
|
123
125
|
def MakePlotMatplotlib(model):
|
|
@@ -125,7 +127,7 @@ def make_mpl_plot_component(
|
|
|
125
127
|
model, measure, post_process=post_process, save_format=save_format
|
|
126
128
|
)
|
|
127
129
|
|
|
128
|
-
return MakePlotMatplotlib
|
|
130
|
+
return (MakePlotMatplotlib, page)
|
|
129
131
|
|
|
130
132
|
|
|
131
133
|
@solara.component
|