Mesa 3.2.0__py3-none-any.whl → 3.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of Mesa might be problematic. Click here for more details.
- mesa/__init__.py +1 -1
- mesa/agent.py +3 -3
- mesa/datacollection.py +1 -1
- mesa/examples/advanced/epstein_civil_violence/app.py +11 -11
- mesa/examples/advanced/pd_grid/app.py +10 -11
- mesa/examples/advanced/sugarscape_g1mt/app.py +34 -16
- mesa/examples/advanced/wolf_sheep/app.py +21 -18
- mesa/examples/basic/boid_flockers/app.py +15 -11
- mesa/examples/basic/boltzmann_wealth_model/app.py +39 -32
- mesa/examples/basic/conways_game_of_life/app.py +13 -16
- mesa/examples/basic/schelling/Readme.md +2 -2
- mesa/examples/basic/schelling/agents.py +9 -3
- mesa/examples/basic/schelling/app.py +50 -3
- mesa/examples/basic/schelling/model.py +2 -0
- mesa/examples/basic/schelling/resources/blue_happy.png +0 -0
- mesa/examples/basic/schelling/resources/blue_unhappy.png +0 -0
- mesa/examples/basic/schelling/resources/orange_happy.png +0 -0
- mesa/examples/basic/schelling/resources/orange_unhappy.png +0 -0
- mesa/examples/basic/virus_on_network/app.py +31 -14
- mesa/experimental/continuous_space/continuous_space.py +1 -1
- mesa/space.py +4 -1
- mesa/visualization/__init__.py +2 -0
- mesa/visualization/backends/__init__.py +23 -0
- mesa/visualization/backends/abstract_renderer.py +97 -0
- mesa/visualization/backends/altair_backend.py +440 -0
- mesa/visualization/backends/matplotlib_backend.py +419 -0
- mesa/visualization/components/__init__.py +28 -8
- mesa/visualization/components/altair_components.py +86 -0
- mesa/visualization/components/matplotlib_components.py +4 -2
- mesa/visualization/components/portrayal_components.py +120 -0
- mesa/visualization/mpl_space_drawing.py +292 -129
- mesa/visualization/solara_viz.py +274 -32
- mesa/visualization/space_drawers.py +797 -0
- mesa/visualization/space_renderer.py +399 -0
- {mesa-3.2.0.dist-info → mesa-3.3.0.dist-info}/METADATA +13 -4
- {mesa-3.2.0.dist-info → mesa-3.3.0.dist-info}/RECORD +39 -29
- mesa/examples/advanced/sugarscape_g1mt/tests.py +0 -69
- {mesa-3.2.0.dist-info → mesa-3.3.0.dist-info}/WHEEL +0 -0
- {mesa-3.2.0.dist-info → mesa-3.3.0.dist-info}/licenses/LICENSE +0 -0
- {mesa-3.2.0.dist-info → mesa-3.3.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -10,18 +10,19 @@ from mesa.examples.basic.virus_on_network.model import (
|
|
|
10
10
|
from mesa.visualization import (
|
|
11
11
|
Slider,
|
|
12
12
|
SolaraViz,
|
|
13
|
+
SpaceRenderer,
|
|
13
14
|
make_plot_component,
|
|
14
|
-
make_space_component,
|
|
15
15
|
)
|
|
16
|
+
from mesa.visualization.components import AgentPortrayalStyle
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
def agent_portrayal(agent):
    """Return the portrayal style for a single network-node agent.

    The node colour encodes the agent's infection state and every node is
    drawn with marker size 20. A ``KeyError`` propagates if ``agent.state``
    is not one of the three states listed below.
    """
    state_colors = {
        State.INFECTED: "red",
        State.SUSCEPTIBLE: "green",
        State.RESISTANT: "gray",
    }
    node_color = state_colors[agent.state]
    return AgentPortrayalStyle(color=node_color, size=20)
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def get_resistant_susceptible_ratio(model):
|
|
@@ -92,24 +93,40 @@ model_params = {
|
|
|
92
93
|
}
|
|
93
94
|
|
|
94
95
|
|
|
95
|
-
def post_process_lineplot(chart):
    """Resize the state line chart and style its legend.

    Args:
        chart: The Altair chart built by the state plot component.

    Returns:
        The chart resized to 400x400 with a black-bordered, rounded,
        light-grey legend placed on the right.
    """
    legend_style = {
        "strokeColor": "black",
        "fillColor": "#ECE9E9",
        "orient": "right",
        "cornerRadius": 5,
        "padding": 10,
        "strokeWidth": 1,
    }
    resized = chart.properties(width=400, height=400)
    return resized.configure_legend(**legend_style)
|
|
109
|
+
|
|
99
110
|
|
|
111
|
+
# Model instance handed to both the renderer and the SolaraViz page.
model1 = VirusOnNetwork()

# Render the network space with the Altair backend: first the underlying
# network structure (unfilled black nodes, dashed edges), then the agents
# styled by ``agent_portrayal``.
renderer = SpaceRenderer(model1, backend="altair")
renderer.draw_structure(
    edge_kwargs={"strokeDash": [6, 1]},
    node_kwargs={"color": "black", "filled": False, "strokeWidth": 5},
)
renderer.draw_agents(agent_portrayal)

# Plot components can also be in altair and support post_process
_state_line_colors = {"Infected": "red", "Susceptible": "green", "Resistant": "gray"}
StatePlot = make_plot_component(
    _state_line_colors,
    backend="altair",
    post_process=post_process_lineplot,
)
|
109
126
|
page = SolaraViz(
|
|
110
127
|
model1,
|
|
128
|
+
renderer,
|
|
111
129
|
components=[
|
|
112
|
-
SpacePlot,
|
|
113
130
|
StatePlot,
|
|
114
131
|
get_resistant_susceptible_ratio,
|
|
115
132
|
],
|
|
@@ -117,7 +117,7 @@ class ContinuousSpace:
|
|
|
117
117
|
if self._agent_positions.shape[0] <= index:
|
|
118
118
|
# we are out of space
|
|
119
119
|
fraction = 0.2 # we add 20% Fixme
|
|
120
|
-
n =
|
|
120
|
+
n = round(fraction * self._n_agents, None)
|
|
121
121
|
self._agent_positions = np.vstack(
|
|
122
122
|
[
|
|
123
123
|
self._agent_positions,
|
mesa/space.py
CHANGED
|
@@ -1571,7 +1571,10 @@ class NetworkGrid:
|
|
|
1571
1571
|
)
|
|
1572
1572
|
if not include_center:
|
|
1573
1573
|
del neighbors_with_distance[node_id]
|
|
1574
|
-
|
|
1574
|
+
neighbors_with_distance = sorted(
|
|
1575
|
+
neighbors_with_distance.items(), key=lambda item: item[1]
|
|
1576
|
+
)
|
|
1577
|
+
neighborhood = [node_id for node_id, _ in neighbors_with_distance]
|
|
1575
1578
|
return neighborhood
|
|
1576
1579
|
|
|
1577
1580
|
def get_neighbors(
|
mesa/visualization/__init__.py
CHANGED
|
@@ -13,6 +13,7 @@ from .command_console import CommandConsole
|
|
|
13
13
|
from .components import make_plot_component, make_space_component
|
|
14
14
|
from .components.altair_components import make_space_altair
|
|
15
15
|
from .solara_viz import JupyterViz, SolaraViz
|
|
16
|
+
from .space_renderer import SpaceRenderer
|
|
16
17
|
from .user_param import Slider
|
|
17
18
|
|
|
18
19
|
__all__ = [
|
|
@@ -20,6 +21,7 @@ __all__ = [
|
|
|
20
21
|
"JupyterViz",
|
|
21
22
|
"Slider",
|
|
22
23
|
"SolaraViz",
|
|
24
|
+
"SpaceRenderer",
|
|
23
25
|
"draw_space",
|
|
24
26
|
"make_plot_component",
|
|
25
27
|
"make_space_altair",
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Visualization backends for Mesa space rendering.
|
|
2
|
+
|
|
3
|
+
This module provides different backend implementations for visualizing
|
|
4
|
+
Mesa agent-based model spaces and components.
|
|
5
|
+
|
|
6
|
+
Note:
|
|
7
|
+
These backends are used internally by the space renderer and are not intended for
|
|
8
|
+
direct use by end users. See `SpaceRenderer` for actual usage and setting up
|
|
9
|
+
visualizations.
|
|
10
|
+
|
|
11
|
+
Available Backends:
|
|
12
|
+
1. AltairBackend
|
|
13
|
+
2. MatplotlibBackend
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from .altair_backend import AltairBackend
|
|
18
|
+
from .matplotlib_backend import MatplotlibBackend
|
|
19
|
+
|
|
20
|
+
# Backend classes re-exported by this package (imported just above).
__all__ = ["AltairBackend", "MatplotlibBackend"]
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Abstract base class for visualization backends in Mesa.
|
|
2
|
+
|
|
3
|
+
This module provides the foundational interface for implementing various
|
|
4
|
+
visualization backends for Mesa agent-based models.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
|
|
9
|
+
import mesa
|
|
10
|
+
from mesa.discrete_space import (
|
|
11
|
+
OrthogonalMooreGrid,
|
|
12
|
+
OrthogonalVonNeumannGrid,
|
|
13
|
+
)
|
|
14
|
+
from mesa.space import (
|
|
15
|
+
HexMultiGrid,
|
|
16
|
+
HexSingleGrid,
|
|
17
|
+
MultiGrid,
|
|
18
|
+
NetworkGrid,
|
|
19
|
+
SingleGrid,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Union types used for ``isinstance`` dispatch on the space being rendered.
# Each groups the ``mesa.space`` classes with their ``mesa.discrete_space``
# counterparts. The right-hand sides reference ``mesa.discrete_space``
# attributes directly, so the assignment order does not matter.
Network = NetworkGrid | mesa.discrete_space.Network
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
OrthogonalGrid = (
    SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AbstractRenderer(ABC):
    """Abstract base class for visualization backends.

    Defines the interface for rendering Mesa spaces and agents: canvas setup,
    drawing the space structure, collecting and drawing agent data, and
    drawing property layers. See the concrete backend implementations for
    the details of each method.
    """

    def __init__(self, space_drawer):
        """Initialize the renderer.

        Args:
            space_drawer: Object responsible for drawing space elements. See
                `SpaceDrawer` for details on the drawing functions it provides.
        """
        self.space_drawer = space_drawer
        # Backend-specific drawing surface; starts unset.
        self._canvas = None

    def _get_agent_pos(self, agent, space):
        """Return the plotting position of ``agent`` within ``space``.

        Args:
            agent: The agent to locate.
            space: The space the agent lives in.

        Returns:
            tuple: For network spaces, the agent's position value repeated for
            both coordinates (``agent.pos`` for ``NetworkGrid``,
            ``agent.cell.coordinate`` for discrete-space networks). For all
            other spaces, ``(x, y)`` taken from ``agent.pos`` when it is set,
            otherwise from ``agent.cell.coordinate``.
        """
        # NetworkGrid is also a member of the ``Network`` union; it must be
        # checked first because it stores positions on ``agent.pos``.
        if isinstance(space, NetworkGrid):
            return agent.pos, agent.pos
        if isinstance(space, Network):
            return agent.cell.coordinate, agent.cell.coordinate

        # Prefer ``pos``; fall back to the cell coordinate for agents that
        # live in a ``mesa.discrete_space`` grid and have no ``pos`` set.
        position = agent.pos if agent.pos is not None else agent.cell.coordinate
        return position[0], position[1]

    @abstractmethod
    def initialize_canvas(self):
        """Set up the drawing canvas."""

    @abstractmethod
    def draw_structure(self, **kwargs):
        """Draw the space structure.

        Args:
            **kwargs: Structure drawing configuration options.
        """

    @abstractmethod
    def collect_agent_data(self, space, agent_portrayal, default_size=None):
        """Collect plotting data for all agents in the space.

        Args:
            space: The Mesa space containing agents.
            agent_portrayal (Callable): Function that returns AgentPortrayalStyle for each agent.
            default_size (float, optional): Default marker size if not specified in portrayal.

        Returns:
            dict: Mapping from data names (for example ``"loc"``, ``"size"``,
            ``"color"``) to per-agent arrays. The exact set of keys is
            backend-specific and is consumed by the same backend's
            ``draw_agents``.
        """

    @abstractmethod
    def draw_agents(self, arguments, **kwargs):
        """Draw agents on the space.

        Args:
            arguments (dict): Dictionary containing agent data, as returned by
                ``collect_agent_data``.
            **kwargs: Additional drawing configuration options.
        """

    @abstractmethod
    def draw_propertylayer(self, space, property_layers, propertylayer_portrayal):
        """Draw property layers on the visualization.

        Args:
            space: The model's space object.
            property_layers (dict): Dictionary of property layers to visualize.
            propertylayer_portrayal (Callable): Function that returns PropertyLayerStyle.
        """
|
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
# noqa: D100
|
|
2
|
+
import warnings
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from dataclasses import fields
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import altair as alt
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
from matplotlib.colors import to_rgb
|
|
11
|
+
|
|
12
|
+
import mesa
|
|
13
|
+
from mesa.discrete_space import (
|
|
14
|
+
OrthogonalMooreGrid,
|
|
15
|
+
OrthogonalVonNeumannGrid,
|
|
16
|
+
)
|
|
17
|
+
from mesa.space import (
|
|
18
|
+
HexMultiGrid,
|
|
19
|
+
HexSingleGrid,
|
|
20
|
+
MultiGrid,
|
|
21
|
+
NetworkGrid,
|
|
22
|
+
SingleGrid,
|
|
23
|
+
)
|
|
24
|
+
from mesa.visualization.backends.abstract_renderer import AbstractRenderer
|
|
25
|
+
|
|
26
|
+
# Space-type unions (the same definitions appear in abstract_renderer).
# Each pairs ``mesa.space`` classes with their ``mesa.discrete_space``
# counterparts; the RHS references ``mesa.discrete_space`` attributes
# directly, so these assignments are order-independent.
Network = NetworkGrid | mesa.discrete_space.Network
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
OrthogonalGrid = (
    SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AltairBackend(AbstractRenderer):
    """Altair-based renderer for Mesa spaces.

    Draws Mesa model spaces, agents, and property layers as Altair charts,
    which provide interactive features such as tooltips.
    """

    # Matplotlib marker codes that have a direct Altair (Vega-Lite) shape.
    _MARKER_TO_SHAPE = {
        "o": "circle",
        "s": "square",
        "D": "diamond",
        "^": "triangle-up",
        "v": "triangle-down",
        "<": "triangle-left",
        ">": "triangle-right",
        "+": "cross",
        "x": "cross",  # Both '+' and 'x' map to cross in Altair
        ".": "circle",  # Small point becomes circle
        "1": "triangle-down",
        "2": "triangle-up",
        "3": "triangle-left",
        "4": "triangle-right",
    }

    # Agent-data columns that collect_agent_data converts to float arrays.
    _FLOAT_KEYS = frozenset({"size", "order", "opacity", "strokeWidth"})

    def initialize_canvas(self) -> None:
        """Initialize the Altair canvas.

        This backend keeps no canvas object; the attribute is reset to None.
        """
        self._canvas = None

    def draw_structure(self, **kwargs) -> alt.Chart:
        """Draw the space structure using Altair.

        Args:
            **kwargs: Additional arguments passed to the space drawer.
                Checkout respective `SpaceDrawer` class on details how to pass **kwargs.

        Returns:
            alt.Chart: The Altair chart representing the space structure.
        """
        return self.space_drawer.draw_altair(**kwargs)

    @classmethod
    def _marker_to_altair_shape(cls, marker):
        """Translate a marker specification into an Altair shape value.

        Matplotlib codes listed in ``_MARKER_TO_SHAPE`` are converted. Other
        multi-character strings are passed through unchanged because
        Vega-Lite accepts named shapes (e.g. ``"triangle"``, ``"wedge"``) as
        well as custom SVG path strings. Anything else has no Altair
        equivalent and would be misread as an SVG path, so a warning is
        emitted and ``"circle"`` is used. This covers ``None``, non-string
        markers, unmapped single-character Matplotlib codes such as ``"*"``
        or ``"p"``, and mathtext ``"$...$"`` markers.
        """
        if isinstance(marker, str):
            mapped = cls._MARKER_TO_SHAPE.get(marker)
            if mapped is not None:
                return mapped
            if len(marker) > 1 and not marker.startswith("$"):
                return marker
        warnings.warn(
            f"Marker '{marker}' is not supported in Altair. "
            "Using 'circle' as default.",
            UserWarning,
            # helper -> collect_agent_data -> caller of collect_agent_data
            stacklevel=3,
        )
        return "circle"

    def collect_agent_data(
        self, space, agent_portrayal: Callable, default_size: float | None = None
    ):
        """Collect plotting data for all agents in the space for Altair.

        Args:
            space: The Mesa space containing agents.
            agent_portrayal: Callable that returns AgentPortrayalStyle for each
                agent. Returning a dict is still accepted but emits a
                PendingDeprecationWarning; unrecognized dict keys are ignored
                with a UserWarning.
            default_size: Default marker size if not specified in portrayal.
                If both are missing, the AgentPortrayalStyle class default is
                used.

        Returns:
            dict: Per-agent data arrays keyed by ``"loc"``, ``"size"``,
            ``"color"``, ``"shape"``, ``"order"``, ``"opacity"``, ``"stroke"``,
            ``"strokeWidth"`` and ``"filled"``.
        """
        # Import here to avoid circular import issues
        from mesa.visualization.components import AgentPortrayalStyle  # noqa: PLC0415

        # Class-level defaults of AgentPortrayalStyle, used to fill gaps.
        style_fields = {f.name: f.default for f in fields(AgentPortrayalStyle)}
        class_default_size = style_fields.get("size")

        # Initialize data collection arrays
        arguments = {
            "loc": [],
            "size": [],
            "color": [],
            "shape": [],
            "order": [],  # z-order
            "opacity": [],
            "stroke": [],  # Stroke color
            "strokeWidth": [],
            "filled": [],
        }

        for agent in space.agents:
            portray_input = agent_portrayal(agent)
            aps: AgentPortrayalStyle

            if isinstance(portray_input, dict):
                warnings.warn(
                    "Returning a dict from agent_portrayal is deprecated. "
                    "Please return an AgentPortrayalStyle instance instead.",
                    PendingDeprecationWarning,
                    stacklevel=2,
                )
                dict_data = portray_input.copy()
                agent_x, agent_y = self._get_agent_pos(agent, space)

                aps = AgentPortrayalStyle(
                    x=agent_x,
                    y=agent_y,
                    size=dict_data.pop("size", style_fields.get("size")),
                    color=dict_data.pop("color", style_fields.get("color")),
                    marker=dict_data.pop("marker", style_fields.get("marker")),
                    zorder=dict_data.pop("zorder", style_fields.get("zorder")),
                    alpha=dict_data.pop("alpha", style_fields.get("alpha")),
                    edgecolors=dict_data.pop(
                        "edgecolors", style_fields.get("edgecolors")
                    ),
                    linewidths=dict_data.pop(
                        "linewidths", style_fields.get("linewidths")
                    ),
                )
                if dict_data:
                    ignored_keys = list(dict_data.keys())
                    warnings.warn(
                        f"The following keys were ignored from dict portrayal: {', '.join(ignored_keys)}",
                        UserWarning,
                        stacklevel=2,
                    )
                x, y = aps.x, aps.y
            else:
                aps = portray_input
                if aps.x is None and aps.y is None:
                    # Keep the computed position local rather than writing it
                    # onto ``aps``: a portrayal function may return the same
                    # style instance for many agents, and mutating it would
                    # pin every later agent to this agent's position.
                    x, y = self._get_agent_pos(agent, space)
                else:
                    x, y = aps.x, aps.y

            arguments["loc"].append((x, y))

            size_to_collect = aps.size if aps.size is not None else default_size
            if size_to_collect is None:
                size_to_collect = class_default_size
            arguments["size"].append(size_to_collect)

            arguments["color"].append(
                aps.color if aps.color is not None else style_fields.get("color")
            )

            raw_marker = (
                aps.marker if aps.marker is not None else style_fields.get("marker")
            )
            arguments["shape"].append(self._marker_to_altair_shape(raw_marker))

            arguments["order"].append(
                aps.zorder if aps.zorder is not None else style_fields.get("zorder")
            )
            arguments["opacity"].append(
                aps.alpha if aps.alpha is not None else style_fields.get("alpha")
            )
            arguments["stroke"].append(aps.edgecolors)
            arguments["strokeWidth"].append(
                aps.linewidths
                if aps.linewidths is not None
                else style_fields.get("linewidths")
            )

            # FIXME: Make filled user-controllable
            arguments["filled"].append(True)

        final_data = {}
        for key, values in arguments.items():
            if key == "shape":
                # Keep shapes as an object array of the collected values.
                shape_arr = np.empty(len(values), dtype=object)
                shape_arr[:] = values
                final_data[key] = shape_arr
            elif key in self._FLOAT_KEYS:
                final_data[key] = np.asarray(values, dtype=float)
            else:
                final_data[key] = np.asarray(values)

        return final_data

    def draw_agents(
        self, arguments, chart_width: int = 450, chart_height: int = 350, **kwargs
    ):
        """Draw agents using Altair backend.

        Args:
            arguments: Dictionary containing agent data arrays.
            chart_width: Width of the chart.
            chart_height: Height of the chart.
            **kwargs: Additional keyword arguments for customization.
                Recognized keys: ``title``, ``xlabel``, ``ylabel``, ``cmap``
                (default ``"viridis"``), ``vmin`` and ``vmax``. Other keys are
                ignored. Checkout respective `SpaceDrawer` class on details how
                to pass **kwargs.

        Returns:
            alt.Chart: The Altair chart representing the agents, or None if no agents.
        """
        if arguments["loc"].size == 0:
            return None

        # The strokeWidth encoding below uses a [0, 1] scale domain, so
        # widths are divided by 10 to bring typical linewidths into range.
        stroke_width = [width / 10 for width in arguments["strokeWidth"]]

        # Agent data preparation
        df_data = {
            "x": arguments["loc"][:, 0],
            "y": arguments["loc"][:, 1],
            "size": arguments["size"],
            "shape": arguments["shape"],
            "opacity": arguments["opacity"],
            "strokeWidth": stroke_width,
            "original_color": arguments["color"],
            "is_filled": arguments["filled"],
            "original_stroke": arguments["stroke"],
        }
        df = pd.DataFrame(df_data)

        # To ensure distinct shapes according to agent portrayal
        unique_shape_names_in_data = df["shape"].unique().tolist()

        # Filled markers use the agent color as fill and the edge color (if a
        # string) as stroke; unfilled markers use the agent color as stroke.
        fill_colors = []
        stroke_colors = []
        for filled, main_color, stroke in zip(
            df["is_filled"], df["original_color"], df["original_stroke"]
        ):
            stroke_spec = stroke if isinstance(stroke, str) else None
            if filled:
                fill_colors.append(main_color)
                stroke_colors.append(stroke_spec)
            else:
                fill_colors.append(None)
                stroke_colors.append(main_color)
        df["viz_fill_color"] = fill_colors
        df["viz_stroke_color"] = stroke_colors

        # Extract additional parameters from kwargs
        # FIXME: Add more parameters to kwargs
        title = kwargs.pop("title", "")
        xlabel = kwargs.pop("xlabel", "")
        ylabel = kwargs.pop("ylabel", "")

        # Tooltip list for interactivity
        # FIXME: Add more fields to tooltip (preferably from agent_portrayal)
        tooltip_list = ["x", "y"]

        # Handle custom colormapping
        cmap = kwargs.pop("cmap", "viridis")
        vmin = kwargs.pop("vmin", None)
        vmax = kwargs.pop("vmax", None)

        color_is_numeric = np.issubdtype(df["original_color"].dtype, np.number)
        if color_is_numeric:
            # Numeric colors are mapped through a continuous color scheme.
            color_min = vmin if vmin is not None else df["original_color"].min()
            color_max = vmax if vmax is not None else df["original_color"].max()

            fill_encoding = alt.Fill(
                "original_color:Q",
                scale=alt.Scale(scheme=cmap, domain=[color_min, color_max]),
            )
        else:
            # Color strings are used verbatim (scale=None disables mapping).
            fill_encoding = alt.Fill(
                "viz_fill_color:N",
                scale=None,
                title="Color",
            )

        # Determine space dimensions
        xmin, xmax, ymin, ymax = self.space_drawer.get_viz_limits()

        chart = (
            alt.Chart(df)
            .mark_point()
            .encode(
                x=alt.X(
                    "x:Q",
                    title=xlabel,
                    scale=alt.Scale(type="linear", domain=[xmin, xmax]),
                    axis=None,
                ),
                y=alt.Y(
                    "y:Q",
                    title=ylabel,
                    scale=alt.Scale(type="linear", domain=[ymin, ymax]),
                    axis=None,
                ),
                size=alt.Size("size:Q", legend=None, scale=alt.Scale(domain=[0, 50])),
                shape=alt.Shape(
                    "shape:N",
                    scale=alt.Scale(
                        domain=unique_shape_names_in_data,
                        range=unique_shape_names_in_data,
                    ),
                    title="Shape",
                ),
                opacity=alt.Opacity(
                    "opacity:Q",
                    title="Opacity",
                    scale=alt.Scale(domain=[0, 1], range=[0, 1]),
                ),
                fill=fill_encoding,
                stroke=alt.Stroke("viz_stroke_color:N", scale=None),
                strokeWidth=alt.StrokeWidth(
                    "strokeWidth:Q", scale=alt.Scale(domain=[0, 1])
                ),
                tooltip=tooltip_list,
            )
            .properties(title=title, width=chart_width, height=chart_height)
        )

        return chart

    def draw_propertylayer(
        self,
        space,
        property_layers: dict[str, Any],
        propertylayer_portrayal: Callable,
        chart_width: int = 450,
        chart_height: int = 350,
    ):
        """Draw property layers using Altair backend.

        Layers named ``"empty"``, layers whose portrayal is None, and layers
        whose data shape differs from ``(space.width, space.height)`` (with a
        UserWarning) are skipped.

        Args:
            space: The Mesa space object containing the property layers.
            property_layers: A dictionary of property layers to draw.
            propertylayer_portrayal: A function that returns PropertyLayerStyle
                that contains the visualization parameters.
            chart_width: The width of the chart.
            chart_height: The height of the chart.

        Returns:
            alt.LayerChart: All drawn layers combined with independent color
            scales. If every layer was skipped, the layer chart is empty.

        Raises:
            ValueError: If a portrayal specifies neither ``color`` nor
                ``colormap``.
        """
        main_charts = []

        for layer_name, layer in property_layers.items():
            if layer_name == "empty":
                continue

            portrayal = propertylayer_portrayal(layer)
            if portrayal is None:
                continue

            data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

            # Check dimensions
            if (space.width, space.height) != data.shape:
                warnings.warn(
                    f"Layer {layer_name} dimensions ({data.shape}) "
                    f"don't match space dimensions ({space.width}, {space.height})",
                    UserWarning,
                    stacklevel=2,
                )
                continue

            # Get portrayal parameters
            color = portrayal.color
            colormap = portrayal.colormap
            alpha = portrayal.alpha
            vmin = portrayal.vmin if portrayal.vmin is not None else np.min(data)
            vmax = portrayal.vmax if portrayal.vmax is not None else np.max(data)

            # One row per cell: x is the first array axis, y the second.
            df = pd.DataFrame(
                {
                    "x": np.repeat(np.arange(data.shape[0]), data.shape[1]),
                    "y": np.tile(np.arange(data.shape[1]), data.shape[0]),
                    "value": data.flatten(),
                }
            )

            if color:
                # For a single color gradient, we define the range from transparent to solid.
                # round() rather than int(): the float product c * 255 can land
                # just below an integer, which truncation would turn off by one.
                r, g, b = (round(c * 255) for c in to_rgb(color))

                min_color = f"rgba({r},{g},{b},0)"
                max_color = f"rgba({r},{g},{b},{alpha})"
                opacity = 1
                color_scale = alt.Scale(
                    range=[min_color, max_color], domain=[vmin, vmax]
                )
            elif colormap:
                color_scale = alt.Scale(scheme=colormap, domain=[vmin, vmax])
                opacity = alpha
            else:
                raise ValueError(
                    f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
                )

            main_charts.append(
                alt.Chart(df)
                .mark_rect(opacity=opacity)
                .encode(
                    x=alt.X("x:O", axis=None),
                    y=alt.Y("y:O", axis=None),
                    color=alt.Color(
                        "value:Q",
                        scale=color_scale,
                        title=layer_name,
                        legend=alt.Legend(title=layer_name, orient="bottom")
                        if portrayal.colorbar
                        else None,
                    ),
                )
                .properties(width=chart_width, height=chart_height)
            )

        return alt.layer(*main_charts).resolve_scale(color="independent")
|