Mesa 3.2.0.dev0__py3-none-any.whl → 3.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of Mesa might be problematic. Click here for more details.
- mesa/__init__.py +1 -1
- mesa/agent.py +9 -7
- mesa/datacollection.py +1 -1
- mesa/examples/README.md +1 -1
- mesa/examples/__init__.py +2 -0
- mesa/examples/advanced/alliance_formation/Readme.md +50 -0
- mesa/examples/advanced/alliance_formation/__init__ .py +0 -0
- mesa/examples/advanced/alliance_formation/agents.py +20 -0
- mesa/examples/advanced/alliance_formation/app.py +71 -0
- mesa/examples/advanced/alliance_formation/model.py +184 -0
- mesa/examples/advanced/epstein_civil_violence/app.py +11 -11
- mesa/examples/advanced/pd_grid/Readme.md +4 -6
- mesa/examples/advanced/pd_grid/app.py +10 -11
- mesa/examples/advanced/sugarscape_g1mt/Readme.md +4 -5
- mesa/examples/advanced/sugarscape_g1mt/app.py +34 -16
- mesa/examples/advanced/wolf_sheep/Readme.md +2 -17
- mesa/examples/advanced/wolf_sheep/app.py +21 -18
- mesa/examples/basic/boid_flockers/Readme.md +6 -1
- mesa/examples/basic/boid_flockers/app.py +15 -11
- mesa/examples/basic/boltzmann_wealth_model/Readme.md +2 -12
- mesa/examples/basic/boltzmann_wealth_model/app.py +39 -32
- mesa/examples/basic/conways_game_of_life/Readme.md +1 -9
- mesa/examples/basic/conways_game_of_life/app.py +13 -16
- mesa/examples/basic/schelling/Readme.md +2 -10
- mesa/examples/basic/schelling/agents.py +9 -3
- mesa/examples/basic/schelling/app.py +50 -3
- mesa/examples/basic/schelling/model.py +2 -0
- mesa/examples/basic/schelling/resources/blue_happy.png +0 -0
- mesa/examples/basic/schelling/resources/blue_unhappy.png +0 -0
- mesa/examples/basic/schelling/resources/orange_happy.png +0 -0
- mesa/examples/basic/schelling/resources/orange_unhappy.png +0 -0
- mesa/examples/basic/virus_on_network/Readme.md +0 -4
- mesa/examples/basic/virus_on_network/app.py +31 -14
- mesa/experimental/__init__.py +2 -2
- mesa/experimental/continuous_space/continuous_space.py +1 -1
- mesa/experimental/meta_agents/__init__.py +25 -0
- mesa/experimental/meta_agents/meta_agent.py +387 -0
- mesa/model.py +3 -3
- mesa/space.py +4 -1
- mesa/visualization/__init__.py +2 -0
- mesa/visualization/backends/__init__.py +23 -0
- mesa/visualization/backends/abstract_renderer.py +97 -0
- mesa/visualization/backends/altair_backend.py +440 -0
- mesa/visualization/backends/matplotlib_backend.py +419 -0
- mesa/visualization/components/__init__.py +28 -8
- mesa/visualization/components/altair_components.py +86 -0
- mesa/visualization/components/matplotlib_components.py +4 -2
- mesa/visualization/components/portrayal_components.py +120 -0
- mesa/visualization/mpl_space_drawing.py +292 -129
- mesa/visualization/solara_viz.py +274 -32
- mesa/visualization/space_drawers.py +797 -0
- mesa/visualization/space_renderer.py +399 -0
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/METADATA +13 -4
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/RECORD +57 -40
- mesa/examples/advanced/sugarscape_g1mt/tests.py +0 -69
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/WHEEL +0 -0
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/LICENSE +0 -0
- {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,797 @@
|
|
|
1
|
+
"""Mesa visualization space drawers.
|
|
2
|
+
|
|
3
|
+
This module provides the core logic for drawing spaces in Mesa, supporting
|
|
4
|
+
orthogonal grids, hexagonal grids, networks, continuous spaces, and Voronoi grids.
|
|
5
|
+
It includes implementations for both Matplotlib and Altair backends.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import itertools
|
|
9
|
+
from itertools import pairwise
|
|
10
|
+
|
|
11
|
+
import altair as alt
|
|
12
|
+
import matplotlib.pyplot as plt
|
|
13
|
+
import networkx as nx
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
from matplotlib.collections import LineCollection
|
|
17
|
+
|
|
18
|
+
import mesa
|
|
19
|
+
from mesa.discrete_space import (
|
|
20
|
+
OrthogonalMooreGrid,
|
|
21
|
+
OrthogonalVonNeumannGrid,
|
|
22
|
+
VoronoiGrid,
|
|
23
|
+
)
|
|
24
|
+
from mesa.space import (
|
|
25
|
+
ContinuousSpace,
|
|
26
|
+
HexMultiGrid,
|
|
27
|
+
HexSingleGrid,
|
|
28
|
+
MultiGrid,
|
|
29
|
+
NetworkGrid,
|
|
30
|
+
SingleGrid,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Type aliases naming the concrete space classes each drawer accepts. Both the
# legacy ``mesa.space`` classes and their ``mesa.discrete_space`` counterparts
# are included.
OrthogonalGrid = (
    SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
)
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
Network = NetworkGrid | mesa.discrete_space.Network
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class BaseSpaceDrawer:
    """Common state shared by every space drawer.

    Holds the space being drawn plus the four visualization limits
    (``viz_xmin``, ``viz_xmax``, ``viz_ymin``, ``viz_ymax``), which all start
    as ``None`` and are filled in by concrete subclasses.
    """

    def __init__(self, space):
        """Store the space and reset the visualization limits.

        Args:
            space: Grid/Space type to draw.
        """
        self.space = space
        # Limits start unset; concrete drawers assign them in their own __init__.
        self.viz_xmin = self.viz_xmax = None
        self.viz_ymin = self.viz_ymax = None

    def get_viz_limits(self):
        """Return the current visualization limits.

        Returns:
            A tuple ``(xmin, xmax, ymin, ymax)``.
        """
        limits = (self.viz_xmin, self.viz_xmax, self.viz_ymin, self.viz_ymax)
        return limits
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class OrthogonalSpaceDrawer(BaseSpaceDrawer):
    """Drawer for orthogonal grid spaces (SingleGrid, MultiGrid, Moore, VonNeumann)."""

    def __init__(self, space: OrthogonalGrid):
        """Set up the drawer for a rectangular grid.

        Cells are centred on integer coordinates, so the plotting area spans
        ``-0.5 .. width - 0.5`` horizontally and ``-0.5 .. height - 0.5``
        vertically.

        Args:
            space: The orthogonal grid space to draw
        """
        super().__init__(space)
        n_cols, n_rows = self.space.width, self.space.height
        # Default marker size shrinks as the larger grid dimension grows.
        self.s_default = (180 / max(n_cols, n_rows)) ** 2

        # Parameters for visualization limits
        self.viz_xmin, self.viz_xmax = -0.5, n_cols - 0.5
        self.viz_ymin, self.viz_ymax = -0.5, n_rows - 0.5

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the grid lines of the orthogonal grid with matplotlib.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created
                when ``None``.
            **space_kwargs: Styling options. ``figsize`` and ``dpi`` are
                consumed for figure creation; everything else overrides the
                default gridline style and is passed to ``axvline``/``axhline``.

        Examples:
            figsize=(10, 10), color="blue", linewidth=2.

        Returns:
            The modified axes object
        """
        # Always remove figure options so they never reach the line calls.
        figsize = space_kwargs.pop("figsize", (8, 8))
        dpi = space_kwargs.pop("dpi", 100)
        if ax is None:
            _, ax = plt.subplots(figsize=figsize, dpi=dpi)

        # Dotted gray gridlines by default; caller kwargs take precedence.
        line_kwargs = {
            "color": "gray",
            "linestyle": ":",
            "linewidth": 1,
            "alpha": 1,
            **space_kwargs,
        }

        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)

        # Lines sit on cell borders: -0.5, 0.5, ..., width - 0.5 (same for y).
        for x_border in np.arange(-0.5, self.space.width, 1):
            ax.axvline(x_border, **line_kwargs)
        for y_border in np.arange(-0.5, self.space.height, 1):
            ax.axhline(y_border, **line_kwargs)

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the orthogonal grid (axes and gridlines only) with Altair.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Styling options. ``xlabel``, ``ylabel``,
                ``grid_color``, ``grid_dash``, ``grid_width`` and
                ``grid_opacity`` style the axes; the rest go to
                ``chart.properties()``.

        Examples:
            width=500, height=500, title="Grid".

        Returns:
            Altair chart object
        """
        x_title = chart_kwargs.pop("xlabel", "X")
        y_title = chart_kwargs.pop("ylabel", "Y")
        grid_style = {
            "grid": True,
            "gridColor": chart_kwargs.pop("grid_color", "lightgray"),
            "gridDash": chart_kwargs.pop("grid_dash", [2, 2]),
            "gridWidth": chart_kwargs.pop("grid_width", 1),
            "gridOpacity": chart_kwargs.pop("grid_opacity", 1),
        }

        # Whatever is left in chart_kwargs overrides the default size.
        chart_props = {"width": chart_width, "height": chart_height, **chart_kwargs}

        x_encoding = alt.X(
            "X:Q",
            title=x_title,
            scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax], nice=False),
            axis=alt.Axis(**grid_style),
        )
        y_encoding = alt.Y(
            "Y:Q",
            title=y_title,
            scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax], nice=False),
            axis=alt.Axis(**grid_style),
        )

        # A single fully transparent point: only the axes and gridlines show.
        base = alt.Chart(pd.DataFrame([{}])).mark_point(opacity=0)
        return base.encode(x=x_encoding, y=y_encoding).properties(**chart_props)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class HexSpaceDrawer(BaseSpaceDrawer):
    """Drawer for hexagonal grid spaces.

    Hexagons are drawn pointy-topped with a circumradius of 1.0. Row ``r`` is
    centred at ``y = 1.5 * r`` and even rows are shifted right by half a
    hexagon width (see ``_get_hexmesh``).
    """

    def __init__(self, space: HexGrid):
        """Initialize the hexagonal space drawer.

        Precomputes the vertices of every hexagon and the visualization limits.

        Args:
            space: The hexagonal grid space to draw
        """
        super().__init__(space)
        self.s_default = (180 / max(self.space.width, self.space.height)) ** 2
        size = 1.0
        # Centre-to-centre distances for pointy-topped hexagons of circumradius `size`.
        self.x_spacing = np.sqrt(3) * size
        self.y_spacing = 1.5 * size

        x_max = self.space.width * self.x_spacing + (self.space.height % 2) * (
            self.x_spacing / 2
        )
        y_max = self.space.height * self.y_spacing

        # Half a hexagon's width, and its circumradius.
        x_padding = size * np.sqrt(3) / 2
        y_padding = size

        self.hexagons = self._get_hexmesh(self.space.width, self.space.height, size)

        # Parameters for visualization limits
        self.viz_xmin = -1.8 * x_padding
        self.viz_xmax = x_max
        self.viz_ymin = -1.8 * y_padding
        self.viz_ymax = y_max

    def _get_hexmesh(
        self, width: int, height: int, size: float = 1.0
    ) -> list[list[tuple[float, float]]]:
        """Return the vertex list of every hexagon in the mesh.

        Args:
            width: Number of columns.
            height: Number of rows.
            size: Circumradius of each hexagon.

        Returns:
            A list with one entry per cell, in row-major order (rows outer,
            columns inner). Each entry is the list of six ``(x, y)`` vertices
            of that hexagon.
        """

        def _get_hex_vertices(
            center_x: float, center_y: float, size: float = 1.0
        ) -> list[tuple[float, float]]:
            """Return the six vertices of a hexagon centred at (center_x, center_y)."""
            half_width = size * np.sqrt(3) / 2
            return [
                (center_x, center_y + size),  # top
                (center_x + half_width, center_y + size / 2),  # top right
                (center_x + half_width, center_y - size / 2),  # bottom right
                (center_x, center_y - size),  # bottom
                (center_x - half_width, center_y - size / 2),  # bottom left
                (center_x - half_width, center_y + size / 2),  # top left
            ]

        x_spacing = np.sqrt(3) * size
        y_spacing = 1.5 * size

        hexagons = []
        for row, col in itertools.product(range(height), range(width)):
            # Even rows are shifted right by half a hexagon width.
            x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
            y = row * y_spacing
            hexagons.append(_get_hex_vertices(x, y, size))

        return hexagons

    def _get_unique_edges(self):
        """Return the set of distinct hexagon edges.

        Each edge is a pair of vertices rounded to 6 decimals and sorted.
        Sorting means an edge shared by two neighbouring hexagons appears
        only once.
        """
        edges = set()
        for vertices in self.hexagons:
            # Connect each vertex to the next, closing the ring back to the first.
            for v1, v2 in pairwise([*vertices, vertices[0]]):
                edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))]))
                edges.add(edge)
        return edges

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the hexagonal grid using matplotlib.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created
                when ``None``.
            **space_kwargs: Styling options. ``figsize`` and ``dpi`` are
                consumed for figure creation; the rest override the default
                line style and are passed to ``LineCollection``.

        Examples:
            figsize=(8, 8), color="red", alpha=0.5.

        Returns:
            The modified axes object
        """
        # Pop figure options unconditionally so they never reach LineCollection.
        fig_kwargs = {
            "figsize": space_kwargs.pop("figsize", (8, 8)),
            "dpi": space_kwargs.pop("dpi", 100),
        }

        if ax is None:
            _, ax = plt.subplots(**fig_kwargs)

        line_kwargs = {
            "color": "black",
            "linestyle": ":",
            "linewidth": 1,
            "alpha": 0.8,
        }
        line_kwargs.update(space_kwargs)

        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)
        # Equal aspect keeps the hexagons regular.
        ax.set_aspect("equal", adjustable="box")

        edges = self._get_unique_edges()
        ax.add_collection(LineCollection(list(edges), **line_kwargs))
        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the hexagonal grid using Altair.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Additional keyword arguments for styling the chart.

        Examples:
            * Line properties like color, strokeDash, strokeWidth, opacity.
            * Other kwargs (e.g., width, title) apply to the chart.

        Returns:
            Altair chart object representing the hexagonal grid.
        """
        mark_kwargs = {
            "color": chart_kwargs.pop("color", "black"),
            "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]),
            "strokeWidth": chart_kwargs.pop("strokeWidth", 1),
            "opacity": chart_kwargs.pop("opacity", 0.8),
        }

        chart_props = {
            "width": chart_width,
            "height": chart_height,
        }
        chart_props.update(chart_kwargs)

        # Two rows per edge; `detail` keeps each edge a separate line and
        # `order` fixes the drawing direction between its endpoints.
        edge_data = []
        for i, (p1, p2) in enumerate(self._get_unique_edges()):
            edge_data.append({"edge_id": i, "point_order": 0, "x": p1[0], "y": p1[1]})
            edge_data.append({"edge_id": i, "point_order": 1, "x": p2[0], "y": p2[1]})

        source = pd.DataFrame(edge_data)

        chart = (
            alt.Chart(source)
            .mark_line(**mark_kwargs)
            .encode(
                x=alt.X(
                    "x:Q",
                    scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax], zero=False),
                    axis=None,
                ),
                y=alt.Y(
                    "y:Q",
                    scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax], zero=False),
                    axis=None,
                ),
                detail="edge_id:N",
                order="point_order:Q",
            )
            .properties(**chart_props)
        )
        return chart
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class NetworkSpaceDrawer(BaseSpaceDrawer):
    """Drawer for network-based spaces."""

    def __init__(
        self,
        space: Network,
        layout_alg=nx.spring_layout,
        layout_kwargs=None,
    ):
        """Initialize the network space drawer.

        Computes node positions once with ``layout_alg`` and derives the
        visualization limits from them.

        Args:
            space: The network space to draw
            layout_alg: NetworkX layout algorithm to use
            layout_kwargs: Keyword arguments for the layout algorithm.
                Defaults to ``{"seed": 0}``. Pass an explicit dict (e.g.
                ``{}``) for layout functions that do not take ``seed``.
        """
        super().__init__(space)
        self.layout_alg = layout_alg
        self.layout_kwargs = layout_kwargs if layout_kwargs is not None else {"seed": 0}

        # gather locations for nodes in network
        self.graph = self.space.G
        self.pos = self.layout_alg(self.graph, **self.layout_kwargs)

        x, y = list(zip(*self.pos.values())) if self.pos else ([0], [0])
        xmin, xmax = min(x), max(x)
        ymin, ymax = min(y), max(y)

        width = xmax - xmin
        height = ymax - ymin
        self.s_default = (
            (180 / max(width, height)) ** 2 if width > 0 or height > 0 else 1
        )

        # 5% padding per side. If all nodes share a coordinate on an axis
        # (single node, empty graph, collinear layout), use a fixed pad so
        # the limits never collapse to a zero-width range.
        x_pad = width / 20 if width > 0 else 0.5
        y_pad = height / 20 if height > 0 else 0.5

        # Parameters for visualization limits
        self.viz_xmin = xmin - x_pad
        self.viz_xmax = xmax + x_pad
        self.viz_ymin = ymin - y_pad
        self.viz_ymax = ymax + y_pad

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the network using matplotlib.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created
                when ``None``.
            **space_kwargs: Dictionaries of keyword arguments for styling.
                Each may include ``zorder``, which is applied to the
                resulting artists after drawing.
                * ``node_kwargs``: A dict passed to nx.draw_networkx_nodes.
                * ``edge_kwargs``: A dict passed to nx.draw_networkx_edges.

        Returns:
            The modified axes object.
        """
        if ax is None:
            _, ax = plt.subplots()

        ax.set_axis_off()
        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)

        node_kwargs = {"alpha": 0.5}
        edge_kwargs = {"alpha": 0.5, "style": "--"}

        node_kwargs.update(space_kwargs.get("node_kwargs", {}))
        edge_kwargs.update(space_kwargs.get("edge_kwargs", {}))

        # zorder is removed here and set on the returned artists afterwards.
        node_zorder = node_kwargs.pop("zorder", 1)
        edge_zorder = edge_kwargs.pop("zorder", 0)

        nodes = nx.draw_networkx_nodes(self.graph, self.pos, ax=ax, **node_kwargs)
        edges = nx.draw_networkx_edges(self.graph, self.pos, ax=ax, **edge_kwargs)

        if nodes is not None:
            nodes.set_zorder(node_zorder)
        # Edges may come back as a list of artists (e.g. arrow patches)
        # instead of a single collection.
        if isinstance(edges, list):
            for edge_collection in edges:
                edge_collection.set_zorder(edge_zorder)
        elif edges is not None:
            edges.set_zorder(edge_zorder)

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the network using Altair.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Dictionaries for styling the chart.
                * ``node_kwargs``: A dict of properties for the node's mark_point.
                * ``edge_kwargs``: A dict of properties for the edge's mark_rule.
                * Other kwargs (e.g., title, width) are passed to chart.properties().

        Returns:
            Altair chart object representing the network.
        """
        # Build the frames row by row. This works for any hashable node label,
        # including tuples (which would otherwise become a MultiIndex), and
        # for graphs with no nodes or edges.
        nodes_df = pd.DataFrame(
            [
                {"node": node, "x": float(xy[0]), "y": float(xy[1])}
                for node, xy in self.pos.items()
            ],
            columns=["node", "x", "y"],
        )
        edges_df = pd.DataFrame(
            [{"source": u, "target": v} for u, v in self.graph.edges()],
            columns=["source", "target"],
        )

        # Attach endpoint coordinates. Overlapping columns from the two merges
        # get the _source / _target suffixes (x_source, y_target, ...).
        edge_positions = edges_df.merge(
            nodes_df, how="left", left_on="source", right_on="node"
        ).merge(
            nodes_df,
            how="left",
            left_on="target",
            right_on="node",
            suffixes=("_source", "_target"),
        )

        node_mark_kwargs = {"filled": True, "opacity": 0.5, "size": 500}
        edge_mark_kwargs = {"opacity": 0.5, "strokeDash": [5, 3]}

        node_mark_kwargs.update(chart_kwargs.pop("node_kwargs", {}))
        edge_mark_kwargs.update(chart_kwargs.pop("edge_kwargs", {}))

        # Separate dict so the caller's remaining chart kwargs are not discarded.
        chart_props = {
            "width": chart_width,
            "height": chart_height,
        }
        chart_props.update(chart_kwargs)

        edge_plot = (
            alt.Chart(edge_positions)
            .mark_rule(**edge_mark_kwargs)
            .encode(
                x=alt.X(
                    "x_source:Q",
                    scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax]),
                    axis=None,
                ),
                y=alt.Y(
                    "y_source:Q",
                    scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax]),
                    axis=None,
                ),
                x2="x_target",
                y2="y_target",
            )
        )

        node_plot = (
            alt.Chart(nodes_df)
            .mark_point(**node_mark_kwargs)
            .encode(x="x:Q", y="y:Q", tooltip=["node"])
        )

        chart = (edge_plot + node_plot).properties(**chart_props)
        return chart
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
class ContinuousSpaceDrawer(BaseSpaceDrawer):
    """Drawer for continuous spaces."""

    def __init__(self, space: ContinuousSpace):
        """Initialize the continuous space drawer.

        The visualization limits are the space bounds, padded by 5% of the
        space's extent on each side.

        Args:
            space: The continuous space to draw
        """
        super().__init__(space)
        width = self.space.x_max - self.space.x_min
        height = self.space.y_max - self.space.y_min
        self.s_default = (
            (180 / max(width, height)) ** 2 if width > 0 or height > 0 else 1
        )

        x_padding = width / 20
        y_padding = height / 20

        self.viz_xmin = self.space.x_min - x_padding
        self.viz_xmax = self.space.x_max + x_padding
        self.viz_ymin = self.space.y_min - y_padding
        self.viz_ymax = self.space.y_max + y_padding

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the continuous space using matplotlib.

        The axes spines mark the space border. They are solid by default and
        dashed when the space is a torus.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created
                when ``None``.
            **space_kwargs: Keyword arguments for styling the axis frame,
                passed to ``Spine.set`` for every spine.

        Examples:
            linewidth=3, color="green"

        Returns:
            The modified axes object
        """
        if ax is None:
            _, ax = plt.subplots()

        # Dashed border signals that the space wraps around (torus).
        border_style = "solid" if not self.space.torus else (0, (5, 10))
        spine_kwargs = {"linewidth": 1.5, "color": "black", "linestyle": border_style}
        spine_kwargs.update(space_kwargs)

        for spine in ax.spines.values():
            spine.set(**spine_kwargs)

        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the continuous space using Altair.

        Produces a transparent chart whose x/y scale domains are the
        visualization limits.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Additional chart properties (e.g. ``title``,
                ``width``), passed to ``chart.properties()``; they override
                the default width and height.

        Returns:
            An Altair Chart object representing the space.
        """
        chart_props = {"width": chart_width, "height": chart_height}
        chart_props.update(chart_kwargs)

        chart = (
            alt.Chart(pd.DataFrame([{}]))
            .mark_rect(color="transparent")
            .encode(
                x=alt.X(scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax])),
                y=alt.Y(scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax])),
            )
            .properties(**chart_props)
        )

        return chart
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
class VoronoiSpaceDrawer(BaseSpaceDrawer):
    """Drawer for Voronoi diagram spaces."""

    def __init__(self, space: VoronoiGrid):
        """Initialize the Voronoi space drawer.

        The visualization limits are the bounding box of the centroids,
        padded by 5% of its extent on each side.

        Args:
            space: The Voronoi grid space to draw
        """
        super().__init__(space)
        centroids = self.space.centroids_coordinates
        # Explicit emptiness check: plain truthiness is ambiguous for NumPy arrays.
        if centroids is not None and len(centroids) > 0:
            x_list = [c[0] for c in centroids]
            y_list = [c[1] for c in centroids]
            x_max, x_min = max(x_list), min(x_list)
            y_max, y_min = max(y_list), min(y_list)
        else:
            x_max, x_min, y_max, y_min = 1, 0, 1, 0

        width = x_max - x_min
        height = y_max - y_min
        self.s_default = (
            (180 / max(width, height)) ** 2 if width > 0 or height > 0 else 1
        )

        # 5% padding per side. If all centroids share a coordinate (e.g. a
        # single centroid), use a fixed pad so the limits and the clip box
        # never collapse to a zero-width range.
        x_pad = width / 20 if width > 0 else 0.5
        y_pad = height / 20 if height > 0 else 0.5

        # Parameters for visualization limits
        self.viz_xmin = x_min - x_pad
        self.viz_xmax = x_max + x_pad
        self.viz_ymin = y_min - y_pad
        self.viz_ymax = y_max + y_pad

    def _clip_line(self, p1, p2, box):
        """Clip a line segment to a box using the Cohen-Sutherland algorithm.

        Args:
            p1: First endpoint ``(x, y)``.
            p2: Second endpoint ``(x, y)``.
            box: ``(min_x, min_y, max_x, max_y)``.

        Returns:
            The clipped segment ``((x1, y1), (x2, y2))``, or None if it lies
            entirely outside the box.
        """
        x1, y1 = p1
        x2, y2 = p2
        min_x, min_y, max_x, max_y = box

        # Define region codes (bit flags)
        INSIDE, LEFT, RIGHT, BOTTOM, TOP = 0, 1, 2, 4, 8  # noqa: N806

        def compute_outcode(x, y):
            code = INSIDE
            if x < min_x:
                code |= LEFT
            elif x > max_x:
                code |= RIGHT
            if y < min_y:
                code |= BOTTOM
            elif y > max_y:
                code |= TOP
            return code

        outcode1 = compute_outcode(x1, y1)
        outcode2 = compute_outcode(x2, y2)

        while True:
            if not (outcode1 | outcode2):  # Both points inside
                return (x1, y1), (x2, y2)
            elif outcode1 & outcode2:  # Both points share an outside region
                return None
            else:
                # Move one outside endpoint onto a box edge along the segment.
                outcode_out = outcode1 if outcode1 else outcode2
                x, y = 0.0, 0.0

                # Top/bottom intersection (undefined for a horizontal line)
                if y1 != y2:
                    if outcode_out & TOP:
                        x = x1 + (x2 - x1) * (max_y - y1) / (y2 - y1)
                        y = max_y
                    elif outcode_out & BOTTOM:
                        x = x1 + (x2 - x1) * (min_y - y1) / (y2 - y1)
                        y = min_y

                # Right/left intersection (undefined for a vertical line);
                # when it applies it takes precedence over the one above.
                if x1 != x2:
                    if outcode_out & RIGHT:
                        y = y1 + (y2 - y1) * (max_x - x1) / (x2 - x1)
                        x = max_x
                    elif outcode_out & LEFT:
                        y = y1 + (y2 - y1) * (min_x - x1) / (x2 - x1)
                        x = min_x

                if outcode_out == outcode1:
                    x1, y1 = x, y
                    outcode1 = compute_outcode(x1, y1)
                else:
                    x2, y2 = x, y
                    outcode2 = compute_outcode(x2, y2)

    def _get_clipped_segments(self):
        """Collect, de-duplicate and clip all Voronoi cell edges.

        Returns:
            A tuple ``(segments, clip_box)``. ``segments`` is a list of
            clipped ``(p1, p2)`` pairs. ``clip_box`` is
            ``(xmin, ymin, xmax, ymax)`` taken from the visualization limits.
        """
        clip_box = (
            self.viz_xmin,
            self.viz_ymin,
            self.viz_xmax,
            self.viz_ymax,
        )

        unique_segments = set()
        for cell in self.space.all_cells.cells:
            vertices = [tuple(v) for v in cell.properties["polygon"]]
            if not vertices:
                # Nothing to draw for a cell without polygon vertices.
                continue
            for p1, p2 in pairwise([*vertices, vertices[0]]):
                # Sort to avoid duplicate segments going in opposite directions
                unique_segments.add(tuple(sorted((p1, p2))))

        # Clip each unique segment
        final_segments = []
        for p1, p2 in unique_segments:
            clipped_segment = self._clip_line(p1, p2, clip_box)
            if clipped_segment:
                final_segments.append(clipped_segment)

        return final_segments, clip_box

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the Voronoi diagram using matplotlib.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created
                when ``None``.
            **space_kwargs: Keyword arguments passed to matplotlib's LineCollection.

        Examples:
            lw=2, alpha=0.5, colors='red'

        Returns:
            The modified axes object
        """
        if ax is None:
            _, ax = plt.subplots()

        final_segments, clip_box = self._get_clipped_segments()

        ax.set_xlim(clip_box[0], clip_box[2])
        ax.set_ylim(clip_box[1], clip_box[3])

        if final_segments:
            # Define default styles for the plot
            style_args = {"colors": "k", "linestyle": "dotted", "lw": 1}
            style_args.update(space_kwargs)

            # Create the LineCollection with the final styles
            lc = LineCollection(final_segments, **style_args)
            ax.add_collection(lc)

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the Voronoi diagram using Altair.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Additional keyword arguments for styling the chart.

        Examples:
            * Line properties like color, strokeDash, strokeWidth, opacity.
            * Other kwargs (e.g., width, title) apply to the chart.

        Returns:
            An Altair Chart object representing the Voronoi diagram.
        """
        final_segments, clip_box = self._get_clipped_segments()

        # Two rows per segment, grouped by line_id so each is drawn separately.
        final_data = []
        for i, (p1, p2) in enumerate(final_segments):
            final_data.append({"x": p1[0], "y": p1[1], "line_id": i})
            final_data.append({"x": p2[0], "y": p2[1], "line_id": i})

        df = pd.DataFrame(final_data)

        # Define default properties for the mark
        mark_kwargs = {
            "color": chart_kwargs.pop("color", "black"),
            "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]),
            "strokeWidth": chart_kwargs.pop("strokeWidth", 1),
            "opacity": chart_kwargs.pop("opacity", 0.8),
        }

        chart_props = {"width": chart_width, "height": chart_height}
        chart_props.update(chart_kwargs)

        chart = (
            alt.Chart(df)
            .mark_line(**mark_kwargs)
            .encode(
                x=alt.X(
                    "x:Q", scale=alt.Scale(domain=[clip_box[0], clip_box[2]]), axis=None
                ),
                y=alt.Y(
                    "y:Q", scale=alt.Scale(domain=[clip_box[1], clip_box[3]]), axis=None
                ),
                detail="line_id:N",
            )
            .properties(**chart_props)
        )
        return chart
|