Mesa 3.2.0.dev0__py3-none-any.whl → 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic. Click here for more details.

Files changed (58)
  1. mesa/__init__.py +1 -1
  2. mesa/agent.py +9 -7
  3. mesa/datacollection.py +1 -1
  4. mesa/examples/README.md +1 -1
  5. mesa/examples/__init__.py +2 -0
  6. mesa/examples/advanced/alliance_formation/Readme.md +50 -0
  7. mesa/examples/advanced/alliance_formation/__init__ .py +0 -0
  8. mesa/examples/advanced/alliance_formation/agents.py +20 -0
  9. mesa/examples/advanced/alliance_formation/app.py +71 -0
  10. mesa/examples/advanced/alliance_formation/model.py +184 -0
  11. mesa/examples/advanced/epstein_civil_violence/app.py +11 -11
  12. mesa/examples/advanced/pd_grid/Readme.md +4 -6
  13. mesa/examples/advanced/pd_grid/app.py +10 -11
  14. mesa/examples/advanced/sugarscape_g1mt/Readme.md +4 -5
  15. mesa/examples/advanced/sugarscape_g1mt/app.py +34 -16
  16. mesa/examples/advanced/wolf_sheep/Readme.md +2 -17
  17. mesa/examples/advanced/wolf_sheep/app.py +21 -18
  18. mesa/examples/basic/boid_flockers/Readme.md +6 -1
  19. mesa/examples/basic/boid_flockers/app.py +15 -11
  20. mesa/examples/basic/boltzmann_wealth_model/Readme.md +2 -12
  21. mesa/examples/basic/boltzmann_wealth_model/app.py +39 -32
  22. mesa/examples/basic/conways_game_of_life/Readme.md +1 -9
  23. mesa/examples/basic/conways_game_of_life/app.py +13 -16
  24. mesa/examples/basic/schelling/Readme.md +2 -10
  25. mesa/examples/basic/schelling/agents.py +9 -3
  26. mesa/examples/basic/schelling/app.py +50 -3
  27. mesa/examples/basic/schelling/model.py +2 -0
  28. mesa/examples/basic/schelling/resources/blue_happy.png +0 -0
  29. mesa/examples/basic/schelling/resources/blue_unhappy.png +0 -0
  30. mesa/examples/basic/schelling/resources/orange_happy.png +0 -0
  31. mesa/examples/basic/schelling/resources/orange_unhappy.png +0 -0
  32. mesa/examples/basic/virus_on_network/Readme.md +0 -4
  33. mesa/examples/basic/virus_on_network/app.py +31 -14
  34. mesa/experimental/__init__.py +2 -2
  35. mesa/experimental/continuous_space/continuous_space.py +1 -1
  36. mesa/experimental/meta_agents/__init__.py +25 -0
  37. mesa/experimental/meta_agents/meta_agent.py +387 -0
  38. mesa/model.py +3 -3
  39. mesa/space.py +4 -1
  40. mesa/visualization/__init__.py +2 -0
  41. mesa/visualization/backends/__init__.py +23 -0
  42. mesa/visualization/backends/abstract_renderer.py +97 -0
  43. mesa/visualization/backends/altair_backend.py +440 -0
  44. mesa/visualization/backends/matplotlib_backend.py +419 -0
  45. mesa/visualization/components/__init__.py +28 -8
  46. mesa/visualization/components/altair_components.py +86 -0
  47. mesa/visualization/components/matplotlib_components.py +4 -2
  48. mesa/visualization/components/portrayal_components.py +120 -0
  49. mesa/visualization/mpl_space_drawing.py +292 -129
  50. mesa/visualization/solara_viz.py +274 -32
  51. mesa/visualization/space_drawers.py +797 -0
  52. mesa/visualization/space_renderer.py +399 -0
  53. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/METADATA +13 -4
  54. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/RECORD +57 -40
  55. mesa/examples/advanced/sugarscape_g1mt/tests.py +0 -69
  56. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/WHEEL +0 -0
  57. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/LICENSE +0 -0
  58. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/NOTICE +0 -0
@@ -6,10 +6,11 @@ for a paper.
6
6
 
7
7
  """
8
8
 
9
- import contextlib
10
9
  import itertools
10
+ import os
11
11
  import warnings
12
12
  from collections.abc import Callable
13
+ from dataclasses import fields
13
14
  from functools import lru_cache
14
15
  from itertools import pairwise
15
16
  from typing import Any
@@ -21,7 +22,9 @@ from matplotlib.axes import Axes
21
22
  from matplotlib.cm import ScalarMappable
22
23
  from matplotlib.collections import LineCollection, PatchCollection, PolyCollection
23
24
  from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
25
+ from matplotlib.offsetbox import AnnotationBbox, OffsetImage
24
26
  from matplotlib.patches import Polygon
27
+ from PIL import Image
25
28
 
26
29
  import mesa
27
30
  from mesa.discrete_space import (
@@ -35,10 +38,12 @@ from mesa.space import (
35
38
  HexSingleGrid,
36
39
  MultiGrid,
37
40
  NetworkGrid,
38
- PropertyLayer,
39
41
  SingleGrid,
40
42
  )
41
43
 
44
+ CORRECTION_FACTOR_MARKER_ZOOM = 0.6
45
+ DEFAULT_MARKER_SIZE = 50
46
+
42
47
  OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
43
48
  HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
44
49
  Network = NetworkGrid | mesa.discrete_space.Network
@@ -47,59 +52,123 @@ Network = NetworkGrid | mesa.discrete_space.Network
47
52
  def collect_agent_data(
48
53
  space: OrthogonalGrid | HexGrid | Network | ContinuousSpace | VoronoiGrid,
49
54
  agent_portrayal: Callable,
50
- color="tab:blue",
51
- size=25,
52
- marker="o",
53
- zorder: int = 1,
54
- ):
55
+ default_size: float | None = None,
56
+ ) -> dict:
55
57
  """Collect the plotting data for all agents in the space.
56
58
 
57
59
  Args:
58
60
  space: The space containing the Agents.
59
- agent_portrayal: A callable that is called with the agent and returns a dict
60
- color: default color
61
- size: default size
62
- marker: default marker
63
- zorder: default zorder
64
-
65
- agent_portrayal should return a dict, limited to size (size of marker), color (color of marker), zorder (z-order),
66
- marker (marker style), alpha, linewidths, and edgecolors
61
+ agent_portrayal: A callable that is called with the agent and returns an AgentPortrayalStyle
62
+ default_size: default size
67
63
 
64
+ agent_portrayal should return a AgentPortrayalStyle, limited to size (size of marker), color (color of marker), zorder (z-order),
65
+ marker (marker style), alpha, linewidths, and edgecolors.
68
66
  """
67
+
68
+ def get_agent_pos(agent, space):
69
+ """Helper function to get the agent position depending on the grid type."""
70
+ if isinstance(space, NetworkGrid):
71
+ agent_x, agent_y = agent.pos, agent.pos
72
+ elif isinstance(space, Network):
73
+ agent_x, agent_y = agent.cell.coordinate, agent.cell.coordinate
74
+ else:
75
+ agent_x = (
76
+ agent.pos[0] if agent.pos is not None else agent.cell.coordinate[0]
77
+ )
78
+ agent_y = (
79
+ agent.pos[1] if agent.pos is not None else agent.cell.coordinate[1]
80
+ )
81
+ return agent_x, agent_y
82
+
69
83
  arguments = {
84
+ "loc": [],
70
85
  "s": [],
71
86
  "c": [],
72
87
  "marker": [],
73
88
  "zorder": [],
74
- "loc": [],
75
89
  "alpha": [],
76
90
  "edgecolors": [],
77
91
  "linewidths": [],
78
92
  }
79
93
 
94
+ # Importing AgentPortrayalStyle inside the function to prevent circular imports
95
+ from mesa.visualization.components import AgentPortrayalStyle # noqa: PLC0415
96
+
97
+ # Get AgentPortrayalStyle defaults
98
+ style_fields = {f.name: f.default for f in fields(AgentPortrayalStyle)}
99
+ class_default_size = style_fields.get("size")
100
+
80
101
  for agent in space.agents:
81
- portray = agent_portrayal(agent)
82
- loc = agent.pos
83
- if loc is None:
84
- loc = agent.cell.coordinate
85
-
86
- arguments["loc"].append(loc)
87
- arguments["s"].append(portray.pop("size", size))
88
- arguments["c"].append(portray.pop("color", color))
89
- arguments["marker"].append(portray.pop("marker", marker))
90
- arguments["zorder"].append(portray.pop("zorder", zorder))
91
-
92
- for entry in ["alpha", "edgecolors", "linewidths"]:
93
- with contextlib.suppress(KeyError):
94
- arguments[entry].append(portray.pop(entry))
95
-
96
- if len(portray) > 0:
97
- ignored_fields = list(portray.keys())
98
- msg = ", ".join(ignored_fields)
102
+ portray_input = agent_portrayal(agent)
103
+ aps: AgentPortrayalStyle
104
+
105
+ if isinstance(portray_input, dict):
99
106
  warnings.warn(
100
- f"the following fields are not used in agent portrayal and thus ignored: {msg}.",
107
+ "Returning a dict from agent_portrayal is deprecated and will be removed "
108
+ "in a future version. Please return an AgentPortrayalStyle instance instead.",
109
+ PendingDeprecationWarning,
101
110
  stacklevel=2,
102
111
  )
112
+ dict_data = portray_input.copy()
113
+
114
+ agent_x, agent_y = get_agent_pos(agent, space)
115
+
116
+ # Extract values from the dict, using defaults if not provided
117
+ size_val = dict_data.pop("size", style_fields.get("size"))
118
+ color_val = dict_data.pop("color", style_fields.get("color"))
119
+ marker_val = dict_data.pop("marker", style_fields.get("marker"))
120
+ zorder_val = dict_data.pop("zorder", style_fields.get("zorder"))
121
+ alpha_val = dict_data.pop("alpha", style_fields.get("alpha"))
122
+ edgecolors_val = dict_data.pop("edgecolors", None)
123
+ linewidths_val = dict_data.pop("linewidths", style_fields.get("linewidths"))
124
+
125
+ aps = AgentPortrayalStyle(
126
+ x=agent_x,
127
+ y=agent_y,
128
+ size=size_val,
129
+ color=color_val,
130
+ marker=marker_val,
131
+ zorder=zorder_val,
132
+ alpha=alpha_val,
133
+ edgecolors=edgecolors_val,
134
+ linewidths=linewidths_val,
135
+ )
136
+
137
+ # Report list of unused data
138
+ if dict_data:
139
+ ignored_keys = list(dict_data.keys())
140
+ warnings.warn(
141
+ f"The following keys from the returned dict were ignored: {', '.join(ignored_keys)}",
142
+ UserWarning,
143
+ stacklevel=2,
144
+ )
145
+ else:
146
+ aps = portray_input
147
+ # default to agent's color if not provided
148
+ if aps.edgecolors is None:
149
+ aps.edgecolors = aps.color
150
+ # get position if not specified
151
+ if aps.x is None and aps.y is None:
152
+ aps.x, aps.y = get_agent_pos(agent, space)
153
+
154
+ # Collect common data from the AgentPortrayalStyle instance
155
+ arguments["loc"].append((aps.x, aps.y))
156
+
157
+ # Determine final size for collection
158
+ size_to_collect = aps.size
159
+ if size_to_collect is None:
160
+ size_to_collect = default_size
161
+ if size_to_collect is None:
162
+ size_to_collect = class_default_size
163
+
164
+ arguments["s"].append(size_to_collect)
165
+ arguments["c"].append(aps.color)
166
+ arguments["marker"].append(aps.marker)
167
+ arguments["zorder"].append(aps.zorder)
168
+ arguments["alpha"].append(aps.alpha)
169
+ if aps.edgecolors is not None:
170
+ arguments["edgecolors"].append(aps.edgecolors)
171
+ arguments["linewidths"].append(aps.linewidths)
103
172
 
104
173
  data = {
105
174
  k: (np.asarray(v, dtype=object) if k == "marker" else np.asarray(v))
@@ -115,7 +184,7 @@ def collect_agent_data(
115
184
  def draw_space(
116
185
  space,
117
186
  agent_portrayal: Callable,
118
- propertylayer_portrayal: dict | None = None,
187
+ propertylayer_portrayal: Callable | None = None,
119
188
  ax: Axes | None = None,
120
189
  **space_drawing_kwargs,
121
190
  ):
@@ -123,15 +192,15 @@ def draw_space(
123
192
 
124
193
  Args:
125
194
  space: the space of the mesa model
126
- agent_portrayal: A callable that returns a dict specifying how to show the agent
127
- propertylayer_portrayal: a dict specifying how to show propertylayer(s)
195
+ agent_portrayal: A callable that returns an AgentPortrayalStyle specifying how to show the agent
196
+ propertylayer_portrayal: A callable that returns a PropertyLayerStyle specifying how to show the property layer
128
197
  ax: the axes upon which to draw the plot
129
198
  space_drawing_kwargs: any additional keyword arguments to be passed on to the underlying function for drawing the space.
130
199
 
131
200
  Returns:
132
201
  Returns the Axes object with the plot drawn onto it.
133
202
 
134
- ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
203
+ ``agent_portrayal`` is called with an agent and should return an AgentPortrayalStyle. Valid fields in this object are "color",
135
204
  "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
136
205
 
137
206
  """
@@ -203,21 +272,52 @@ def _get_hexmesh(
203
272
 
204
273
 
205
274
  def draw_property_layers(
206
- space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes
275
+ space, propertylayer_portrayal: dict[str, dict[str, Any]] | Callable, ax: Axes
207
276
  ):
208
277
  """Draw PropertyLayers on the given axes.
209
278
 
210
279
  Args:
211
280
  space (mesa.space._Grid): The space containing the PropertyLayers.
212
- propertylayer_portrayal (dict): the key is the name of the layer, the value is a dict with
213
- fields specifying how the layer is to be portrayed
281
+ propertylayer_portrayal (Callable): A function that accepts a property layer object
282
+ and returns either a `PropertyLayerStyle` object defining its visualization,
283
+ or `None` to skip drawing this particular layer.
214
284
  ax (matplotlib.axes.Axes): The axes to draw on.
215
285
 
216
- Notes:
217
- valid fields in in the inner dict of propertylayer_portrayal are "alpha", "vmin", "vmax", "color" or "colormap", and "colorbar"
218
- so you can do `{"some_layer":{"colormap":'viridis', 'alpha':.25, "colorbar":False}}`
219
-
220
286
  """
287
+ # Importing here to avoid circular import issues
288
+ from mesa.visualization.components import PropertyLayerStyle # noqa: PLC0415
289
+
290
+ def _propertylayer_portryal_dict_to_callable(
291
+ propertylayer_portrayal: dict[str, dict[str, Any]],
292
+ ):
293
+ """Helper function to convert a propertylayer_portrayal dict to a callable that returns a PropertyLayerStyle."""
294
+
295
+ def style_callable(layer_object: Any):
296
+ layer_name = layer_object.name
297
+ params = propertylayer_portrayal.get(layer_name)
298
+
299
+ warnings.warn(
300
+ "The propertylayer_portrayal dict is deprecated. Use a callable that returns PropertyLayerStyle instead.",
301
+ PendingDeprecationWarning,
302
+ stacklevel=2,
303
+ )
304
+
305
+ if params is None:
306
+ return None # Layer not specified in the dict, so skip.
307
+
308
+ return PropertyLayerStyle(
309
+ color=params.get("color"),
310
+ colormap=params.get("colormap"),
311
+ alpha=params.get(
312
+ "alpha", PropertyLayerStyle.alpha
313
+ ), # Use defaults defined in the dataclass itself
314
+ vmin=params.get("vmin"),
315
+ vmax=params.get("vmax"),
316
+ colorbar=params.get("colorbar", PropertyLayerStyle.colorbar),
317
+ )
318
+
319
+ return style_callable
320
+
221
321
  try:
222
322
  # old style spaces
223
323
  property_layers = space.properties
@@ -225,12 +325,24 @@ def draw_property_layers(
225
325
  # new style spaces
226
326
  property_layers = space._mesa_property_layers
227
327
 
228
- for layer_name, portrayal in propertylayer_portrayal.items():
328
+ callable_portrayal: Callable[[Any], PropertyLayerStyle | None]
329
+ if isinstance(propertylayer_portrayal, dict):
330
+ callable_portrayal = _propertylayer_portryal_dict_to_callable(
331
+ propertylayer_portrayal
332
+ )
333
+ else:
334
+ callable_portrayal = propertylayer_portrayal
335
+
336
+ for layer_name in property_layers:
337
+ if layer_name == "empty":
338
+ # Skipping empty layer, automatically generated
339
+ continue
340
+
229
341
  layer = property_layers.get(layer_name, None)
230
- if not isinstance(
231
- layer,
232
- PropertyLayer | mesa.discrete_space.property_layer.PropertyLayer,
233
- ):
342
+ portrayal = callable_portrayal(layer)
343
+
344
+ if portrayal is None:
345
+ # Not visualizing layers that do not have a defined visual encoding.
234
346
  continue
235
347
 
236
348
  data = layer.data.astype(float) if layer.data.dtype == bool else layer.data
@@ -242,20 +354,19 @@ def draw_property_layers(
242
354
  stacklevel=2,
243
355
  )
244
356
 
245
- # Get portrayal properties, or use defaults
246
- alpha = portrayal.get("alpha", 1)
247
- vmin = portrayal.get("vmin", np.min(data))
248
- vmax = portrayal.get("vmax", np.max(data))
249
- colorbar = portrayal.get("colorbar", True)
357
+ color = portrayal.color
358
+ colormap = portrayal.colormap
359
+ alpha = portrayal.alpha
360
+ vmin = portrayal.vmin if portrayal.vmin else np.min(data)
361
+ vmax = portrayal.vmax if portrayal.vmax else np.max(data)
250
362
 
251
- # Prepare colormap
252
- if "color" in portrayal:
253
- rgba_color = to_rgba(portrayal["color"])
363
+ if color:
364
+ rgba_color = to_rgba(color)
254
365
  cmap = LinearSegmentedColormap.from_list(
255
366
  layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
256
367
  )
257
- elif "colormap" in portrayal:
258
- cmap = portrayal.get("colormap", "viridis")
368
+ elif colormap:
369
+ cmap = colormap
259
370
  if isinstance(cmap, list):
260
371
  cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
261
372
  elif isinstance(cmap, str):
@@ -266,7 +377,7 @@ def draw_property_layers(
266
377
  )
267
378
 
268
379
  if isinstance(space, OrthogonalGrid):
269
- if "color" in portrayal:
380
+ if color:
270
381
  data = data.T
271
382
  normalized_data = (data - vmin) / (vmax - vmin)
272
383
  rgba_data = np.full((*data.shape, 4), rgba_color)
@@ -282,36 +393,26 @@ def draw_property_layers(
282
393
  vmax=vmax,
283
394
  origin="lower",
284
395
  )
285
-
286
396
  elif isinstance(space, HexGrid):
287
397
  width, height = data.shape
288
-
289
- # Generate hexagon mesh
290
398
  hexagons = _get_hexmesh(width, height)
291
-
292
- # Normalize colors
293
399
  norm = Normalize(vmin=vmin, vmax=vmax)
294
- colors = data.ravel() # flatten data to 1D array
400
+ colors = data.ravel()
295
401
 
296
- if "color" in portrayal:
402
+ if color:
297
403
  normalized_colors = np.clip(norm(colors), 0, 1)
298
404
  rgba_colors = np.full((len(colors), 4), rgba_color)
299
405
  rgba_colors[:, 3] = normalized_colors * alpha
300
406
  else:
301
407
  rgba_colors = cmap(norm(colors))
302
408
  rgba_colors[..., 3] *= alpha
303
-
304
- # Draw hexagons
305
409
  collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1)
306
410
  ax.add_collection(collection)
307
-
308
411
  else:
309
412
  raise NotImplementedError(
310
413
  f"PropertyLayer visualization not implemented for {type(space)}."
311
414
  )
312
-
313
- # Add colorbar if requested
314
- if colorbar:
415
+ if portrayal.colorbar:
315
416
  norm = Normalize(vmin=vmin, vmax=vmax)
316
417
  sm = ScalarMappable(norm=norm, cmap=cmap)
317
418
  sm.set_array([])
@@ -329,7 +430,7 @@ def draw_orthogonal_grid(
329
430
 
330
431
  Args:
331
432
  space: the space to visualize
332
- agent_portrayal: a callable that is called with the agent and returns a dict
433
+ agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
333
434
  ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
334
435
  draw_grid: whether to draw the grid
335
436
  kwargs: additional keyword arguments passed to ax.scatter
@@ -337,8 +438,8 @@ def draw_orthogonal_grid(
337
438
  Returns:
338
439
  Returns the Axes object with the plot drawn onto it.
339
440
 
340
- ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
341
- "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
441
+ ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
442
+ "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
342
443
 
343
444
  """
344
445
  if ax is None:
@@ -346,15 +447,15 @@ def draw_orthogonal_grid(
346
447
 
347
448
  # gather agent data
348
449
  s_default = (180 / max(space.width, space.height)) ** 2
349
- arguments = collect_agent_data(space, agent_portrayal, size=s_default)
350
-
351
- # plot the agents
352
- _scatter(ax, arguments, **kwargs)
450
+ arguments = collect_agent_data(space, agent_portrayal, default_size=s_default)
353
451
 
354
452
  # further styling
355
453
  ax.set_xlim(-0.5, space.width - 0.5)
356
454
  ax.set_ylim(-0.5, space.height - 0.5)
357
455
 
456
+ # plot the agents
457
+ _scatter(ax, arguments, **kwargs)
458
+
358
459
  if draw_grid:
359
460
  # Draw grid lines
360
461
  for x in np.arange(-0.5, space.width - 0.5, 1):
@@ -376,33 +477,28 @@ def draw_hex_grid(
376
477
 
377
478
  Args:
378
479
  space: the space to visualize
379
- agent_portrayal: a callable that is called with the agent and returns a dict
480
+ agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
380
481
  ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
381
482
  draw_grid: whether to draw the grid
382
483
  kwargs: additional keyword arguments passed to ax.scatter
484
+ Returns:
485
+ Returns the Axes object with the plot drawn onto it.
486
+
487
+ ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
488
+ "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
383
489
  """
384
490
  if ax is None:
385
491
  fig, ax = plt.subplots()
386
492
 
387
493
  # gather data
388
494
  s_default = (180 / max(space.width, space.height)) ** 2
389
- arguments = collect_agent_data(space, agent_portrayal, size=s_default)
495
+ arguments = collect_agent_data(space, agent_portrayal, default_size=s_default)
390
496
 
391
497
  # Parameters for hexagon grid
392
498
  size = 1.0
393
499
  x_spacing = np.sqrt(3) * size
394
500
  y_spacing = 1.5 * size
395
501
 
396
- loc = arguments["loc"].astype(float)
397
- # Calculate hexagon centers for agents if agents are present and plot them.
398
- if loc.size > 0:
399
- loc[:, 0] = loc[:, 0] * x_spacing + ((loc[:, 1] - 1) % 2) * (x_spacing / 2)
400
- loc[:, 1] = loc[:, 1] * y_spacing
401
- arguments["loc"] = loc
402
-
403
- # plot the agents
404
- _scatter(ax, arguments, **kwargs)
405
-
406
502
  # Calculate proper bounds that account for the full hexagon width and height
407
503
  x_max = space.width * x_spacing + (space.height % 2) * (x_spacing / 2)
408
504
  y_max = space.height * y_spacing
@@ -418,6 +514,16 @@ def draw_hex_grid(
418
514
  ax.set_xlim(-2 * x_padding, x_max + x_padding)
419
515
  ax.set_ylim(-2 * y_padding, y_max + y_padding)
420
516
 
517
+ loc = arguments["loc"].astype(float)
518
+ # Calculate hexagon centers for agents if agents are present and plot them.
519
+ if loc.size > 0:
520
+ loc[:, 0] = loc[:, 0] * x_spacing + ((loc[:, 1] - 1) % 2) * (x_spacing / 2)
521
+ loc[:, 1] = loc[:, 1] * y_spacing
522
+ arguments["loc"] = loc
523
+
524
+ # plot the agents
525
+ _scatter(ax, arguments, **kwargs)
526
+
421
527
  def setup_hexmesh(width, height):
422
528
  """Helper function for creating the hexmesh with unique edges."""
423
529
  edges = set()
@@ -452,7 +558,7 @@ def draw_network(
452
558
 
453
559
  Args:
454
560
  space: the space to visualize
455
- agent_portrayal: a callable that is called with the agent and returns a dict
561
+ agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
456
562
  ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
457
563
  draw_grid: whether to draw the grid
458
564
  layout_alg: a networkx layout algorithm or other callable with the same behavior
@@ -462,8 +568,8 @@ def draw_network(
462
568
  Returns:
463
569
  Returns the Axes object with the plot drawn onto it.
464
570
 
465
- ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
466
- "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
571
+ ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
572
+ "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
467
573
 
468
574
  """
469
575
  if ax is None:
@@ -485,21 +591,28 @@ def draw_network(
485
591
 
486
592
  # gather agent data
487
593
  s_default = (180 / max(width, height)) ** 2
488
- arguments = collect_agent_data(space, agent_portrayal, size=s_default)
594
+ arguments = collect_agent_data(space, agent_portrayal, default_size=s_default)
489
595
 
490
596
  # this assumes that nodes are identified by an integer
491
597
  # which is true for default nx graphs but might be user-changeable
492
598
  pos = np.asarray(list(pos.values()))
493
- arguments["loc"] = pos[arguments["loc"]]
599
+ loc = arguments["loc"]
494
600
 
495
- # plot the agents
496
- _scatter(ax, arguments, **kwargs)
601
+ # For network only one of x and y contains the correct coordinates
602
+ x = loc[:, 0]
603
+ if x is None:
604
+ x = loc[:, 1]
605
+
606
+ arguments["loc"] = pos[x]
497
607
 
498
608
  # further styling
499
609
  ax.set_axis_off()
500
610
  ax.set_xlim(xmin=xmin - x_padding, xmax=xmax + x_padding)
501
611
  ax.set_ylim(ymin=ymin - y_padding, ymax=ymax + y_padding)
502
612
 
613
+ # plot the agents
614
+ _scatter(ax, arguments, **kwargs)
615
+
503
616
  if draw_grid:
504
617
  # fixme we need to draw the empty nodes as well
505
618
  edge_collection = nx.draw_networkx_edges(
@@ -517,15 +630,15 @@ def draw_continuous_space(
517
630
 
518
631
  Args:
519
632
  space: the space to visualize
520
- agent_portrayal: a callable that is called with the agent and returns a dict
633
+ agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
521
634
  ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
522
635
  kwargs: additional keyword arguments passed to ax.scatter
523
636
 
524
637
  Returns:
525
638
  Returns the Axes object with the plot drawn onto it.
526
639
 
527
- ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
528
- "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
640
+ ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
641
+ "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
529
642
 
530
643
  """
531
644
  if ax is None:
@@ -539,10 +652,7 @@ def draw_continuous_space(
539
652
 
540
653
  # gather agent data
541
654
  s_default = (180 / max(width, height)) ** 2
542
- arguments = collect_agent_data(space, agent_portrayal, size=s_default)
543
-
544
- # plot the agents
545
- _scatter(ax, arguments, **kwargs)
655
+ arguments = collect_agent_data(space, agent_portrayal, default_size=s_default)
546
656
 
547
657
  # further visual styling
548
658
  border_style = "solid" if not space.torus else (0, (5, 10))
@@ -554,6 +664,9 @@ def draw_continuous_space(
554
664
  ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding)
555
665
  ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding)
556
666
 
667
+ # plot the agents
668
+ _scatter(ax, arguments, **kwargs)
669
+
557
670
  return ax
558
671
 
559
672
 
@@ -568,7 +681,7 @@ def draw_voronoi_grid(
568
681
 
569
682
  Args:
570
683
  space: the space to visualize
571
- agent_portrayal: a callable that is called with the agent and returns a dict
684
+ agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle
572
685
  ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
573
686
  draw_grid: whether to draw the grid or not
574
687
  kwargs: additional keyword arguments passed to ax.scatter
@@ -576,8 +689,8 @@ def draw_voronoi_grid(
576
689
  Returns:
577
690
  Returns the Axes object with the plot drawn onto it.
578
691
 
579
- ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
580
- "size", "marker", and "zorder". Other field are ignored and will result in a user warning.
692
+ ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color",
693
+ "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning.
581
694
 
582
695
  """
583
696
  if ax is None:
@@ -596,7 +709,7 @@ def draw_voronoi_grid(
596
709
  y_padding = height / 20
597
710
 
598
711
  s_default = (180 / max(width, height)) ** 2
599
- arguments = collect_agent_data(space, agent_portrayal, size=s_default)
712
+ arguments = collect_agent_data(space, agent_portrayal, default_size=s_default)
600
713
 
601
714
  ax.set_xlim(x_min - x_padding, x_max + x_padding)
602
715
  ax.set_ylim(y_min - y_padding, y_max + y_padding)
@@ -618,26 +731,50 @@ def draw_voronoi_grid(
618
731
  return ax
619
732
 
620
733
 
734
+ def _get_zoom_factor(ax, img):
735
+ ax.get_figure().canvas.draw()
736
+ bbox = ax.get_window_extent().transformed(
737
+ ax.get_figure().dpi_scale_trans.inverted()
738
+ ) # in inches
739
+ width, height = (
740
+ bbox.width * ax.get_figure().dpi,
741
+ bbox.height * ax.get_figure().dpi,
742
+ ) # in pixel
743
+
744
+ xr = ax.get_xlim()
745
+ yr = ax.get_ylim()
746
+
747
+ x_pixel_per_data = width / (xr[1] - xr[0])
748
+ y_pixel_per_data = height / (yr[1] - yr[0])
749
+
750
+ zoom_x = (x_pixel_per_data / img.width) * CORRECTION_FACTOR_MARKER_ZOOM
751
+ zoom_y = (y_pixel_per_data / img.height) * CORRECTION_FACTOR_MARKER_ZOOM
752
+
753
+ return min(zoom_x, zoom_y)
754
+
755
+
621
756
  def _scatter(ax: Axes, arguments, **kwargs):
622
757
  """Helper function for plotting the agents.
623
758
 
624
759
  Args:
625
760
  ax: a Matplotlib Axes instance
626
- arguments: the agents specific arguments for platting
761
+ arguments: the agents specific arguments for plotting
627
762
  kwargs: additional keyword arguments for ax.scatter
628
763
 
629
764
  """
630
765
  loc = arguments.pop("loc")
631
766
 
632
- x = loc[:, 0]
633
- y = loc[:, 1]
767
+ loc_x = loc[:, 0]
768
+ loc_y = loc[:, 1]
634
769
  marker = arguments.pop("marker")
635
770
  zorder = arguments.pop("zorder")
771
+ malpha = arguments.pop("alpha")
772
+ msize = arguments.pop("s")
636
773
 
637
774
  # we check if edgecolor, linewidth, and alpha are specified
638
775
  # at the agent level, if not, we remove them from the arguments dict
639
776
  # and fallback to the default value in ax.scatter / use what is passed via **kwargs
640
- for entry in ["edgecolors", "linewidths", "alpha"]:
777
+ for entry in ["edgecolors", "linewidths"]:
641
778
  if len(arguments[entry]) == 0:
642
779
  arguments.pop(entry)
643
780
  else:
@@ -646,17 +783,43 @@ def _scatter(ax: Axes, arguments, **kwargs):
646
783
  f"{entry} is specified in agent portrayal and via plotting kwargs, you can only use one or the other"
647
784
  )
648
785
 
786
+ ax.get_figure().canvas.draw()
649
787
  for mark in set(marker):
650
- mark_mask = [m == mark for m in list(marker)]
651
- for z_order in np.unique(zorder):
652
- zorder_mask = z_order == zorder
653
- logical = mark_mask & zorder_mask
654
-
655
- ax.scatter(
656
- x[logical],
657
- y[logical],
658
- marker=mark,
659
- zorder=z_order,
660
- **{k: v[logical] for k, v in arguments.items()},
661
- **kwargs,
662
- )
788
+ if isinstance(mark, (str | os.PathLike)) and os.path.isfile(mark):
789
+ # images
790
+ for m_size in np.unique(msize):
791
+ image = Image.open(mark)
792
+ im = OffsetImage(
793
+ image,
794
+ zoom=_get_zoom_factor(ax, image) * m_size / DEFAULT_MARKER_SIZE,
795
+ )
796
+ im.image.axes = ax
797
+
798
+ mask_marker = [m == mark for m in list(marker)] & (m_size == msize)
799
+ for z_order in np.unique(zorder[mask_marker]):
800
+ for m_alpha in np.unique(malpha[mask_marker]):
801
+ mask = (z_order == zorder) & (m_alpha == malpha) & mask_marker
802
+ for x, y in zip(loc_x[mask], loc_y[mask]):
803
+ ab = AnnotationBbox(
804
+ im,
805
+ (x, y),
806
+ frameon=False,
807
+ pad=0.0,
808
+ zorder=z_order,
809
+ **kwargs,
810
+ )
811
+ ax.add_artist(ab)
812
+
813
+ else:
814
+ # ordinary markers
815
+ mask_marker = [m == mark for m in list(marker)]
816
+ for z_order in np.unique(zorder[mask_marker]):
817
+ zorder_mask = z_order == zorder & mask_marker
818
+ ax.scatter(
819
+ loc_x[zorder_mask],
820
+ loc_y[zorder_mask],
821
+ marker=mark,
822
+ zorder=z_order,
823
+ **{k: v[zorder_mask] for k, v in arguments.items()},
824
+ **kwargs,
825
+ )