Mesa 3.2.0.dev0__py3-none-any.whl → 3.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic; see the package's registry page for more details.

Files changed (58)
  1. mesa/__init__.py +1 -1
  2. mesa/agent.py +9 -7
  3. mesa/datacollection.py +1 -1
  4. mesa/examples/README.md +1 -1
  5. mesa/examples/__init__.py +2 -0
  6. mesa/examples/advanced/alliance_formation/Readme.md +50 -0
  7. mesa/examples/advanced/alliance_formation/__init__ .py +0 -0
  8. mesa/examples/advanced/alliance_formation/agents.py +20 -0
  9. mesa/examples/advanced/alliance_formation/app.py +71 -0
  10. mesa/examples/advanced/alliance_formation/model.py +184 -0
  11. mesa/examples/advanced/epstein_civil_violence/app.py +11 -11
  12. mesa/examples/advanced/pd_grid/Readme.md +4 -6
  13. mesa/examples/advanced/pd_grid/app.py +10 -11
  14. mesa/examples/advanced/sugarscape_g1mt/Readme.md +4 -5
  15. mesa/examples/advanced/sugarscape_g1mt/app.py +34 -16
  16. mesa/examples/advanced/wolf_sheep/Readme.md +2 -17
  17. mesa/examples/advanced/wolf_sheep/app.py +21 -18
  18. mesa/examples/basic/boid_flockers/Readme.md +6 -1
  19. mesa/examples/basic/boid_flockers/app.py +15 -11
  20. mesa/examples/basic/boltzmann_wealth_model/Readme.md +2 -12
  21. mesa/examples/basic/boltzmann_wealth_model/app.py +39 -32
  22. mesa/examples/basic/conways_game_of_life/Readme.md +1 -9
  23. mesa/examples/basic/conways_game_of_life/app.py +13 -16
  24. mesa/examples/basic/schelling/Readme.md +2 -10
  25. mesa/examples/basic/schelling/agents.py +9 -3
  26. mesa/examples/basic/schelling/app.py +50 -3
  27. mesa/examples/basic/schelling/model.py +2 -0
  28. mesa/examples/basic/schelling/resources/blue_happy.png +0 -0
  29. mesa/examples/basic/schelling/resources/blue_unhappy.png +0 -0
  30. mesa/examples/basic/schelling/resources/orange_happy.png +0 -0
  31. mesa/examples/basic/schelling/resources/orange_unhappy.png +0 -0
  32. mesa/examples/basic/virus_on_network/Readme.md +0 -4
  33. mesa/examples/basic/virus_on_network/app.py +31 -14
  34. mesa/experimental/__init__.py +2 -2
  35. mesa/experimental/continuous_space/continuous_space.py +1 -1
  36. mesa/experimental/meta_agents/__init__.py +25 -0
  37. mesa/experimental/meta_agents/meta_agent.py +387 -0
  38. mesa/model.py +3 -3
  39. mesa/space.py +4 -1
  40. mesa/visualization/__init__.py +2 -0
  41. mesa/visualization/backends/__init__.py +23 -0
  42. mesa/visualization/backends/abstract_renderer.py +97 -0
  43. mesa/visualization/backends/altair_backend.py +440 -0
  44. mesa/visualization/backends/matplotlib_backend.py +419 -0
  45. mesa/visualization/components/__init__.py +28 -8
  46. mesa/visualization/components/altair_components.py +86 -0
  47. mesa/visualization/components/matplotlib_components.py +4 -2
  48. mesa/visualization/components/portrayal_components.py +120 -0
  49. mesa/visualization/mpl_space_drawing.py +292 -129
  50. mesa/visualization/solara_viz.py +274 -32
  51. mesa/visualization/space_drawers.py +797 -0
  52. mesa/visualization/space_renderer.py +399 -0
  53. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/METADATA +13 -4
  54. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/RECORD +57 -40
  55. mesa/examples/advanced/sugarscape_g1mt/tests.py +0 -69
  56. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/WHEEL +0 -0
  57. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/LICENSE +0 -0
  58. {mesa-3.2.0.dev0.dist-info → mesa-3.3.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,797 @@
1
+ """Mesa visualization space drawers.
2
+
3
+ This module provides the core logic for drawing spaces in Mesa, supporting
4
+ orthogonal grids, hexagonal grids, networks, continuous spaces, and Voronoi grids.
5
+ It includes implementations for both Matplotlib and Altair backends.
6
+ """
7
+
8
+ import itertools
9
+ from itertools import pairwise
10
+
11
+ import altair as alt
12
+ import matplotlib.pyplot as plt
13
+ import networkx as nx
14
+ import numpy as np
15
+ import pandas as pd
16
+ from matplotlib.collections import LineCollection
17
+
18
+ import mesa
19
+ from mesa.discrete_space import (
20
+ OrthogonalMooreGrid,
21
+ OrthogonalVonNeumannGrid,
22
+ VoronoiGrid,
23
+ )
24
+ from mesa.space import (
25
+ ContinuousSpace,
26
+ HexMultiGrid,
27
+ HexSingleGrid,
28
+ MultiGrid,
29
+ NetworkGrid,
30
+ SingleGrid,
31
+ )
32
+
33
# Type aliases for the space classes each drawer accepts. Each alias combines
# classes from ``mesa.space`` with the matching ``mesa.discrete_space`` classes.
OrthogonalGrid = (
    SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
)
HexGrid = (
    HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
)
Network = (
    NetworkGrid | mesa.discrete_space.Network
)
36
+
37
+
38
class BaseSpaceDrawer:
    """Common state shared by every space drawer.

    Holds a reference to the space being drawn together with the four
    visualization limits, which start unset and are filled in by subclasses.
    """

    def __init__(self, space):
        """Store the space and reset the visualization limits.

        Args:
            space: Grid/Space type to draw.
        """
        self.space = space
        # Limits are unknown at this level; concrete drawers assign them.
        self.viz_xmin = self.viz_xmax = None
        self.viz_ymin = self.viz_ymax = None

    def get_viz_limits(self):
        """Return the visualization limits.

        Returns:
            A tuple ``(xmin, xmax, ymin, ymax)``.
        """
        limits = (self.viz_xmin, self.viz_xmax, self.viz_ymin, self.viz_ymax)
        return limits
65
+
66
+
67
class OrthogonalSpaceDrawer(BaseSpaceDrawer):
    """Drawer for orthogonal grid spaces (SingleGrid, MultiGrid, Moore, VonNeumann)."""

    def __init__(self, space: OrthogonalGrid):
        """Set up limits so every cell is a unit square centred on integer coordinates.

        Args:
            space: The orthogonal grid space to draw
        """
        super().__init__(space)
        width, height = self.space.width, self.space.height
        self.s_default = (180 / max(width, height)) ** 2

        # Cell (i, j) covers [i - 0.5, i + 0.5] x [j - 0.5, j + 0.5].
        self.viz_xmin, self.viz_xmax = -0.5, width - 0.5
        self.viz_ymin, self.viz_ymax = -0.5, height - 0.5

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the cell boundaries of the grid with matplotlib.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created when
                omitted.
            **space_kwargs: ``figsize`` (default ``(8, 8)``) and ``dpi`` (default
                100) are consumed for figure creation (and ignored when ``ax`` is
                given); every other key overrides the gridline style passed to
                ``axvline``/``axhline``, e.g. ``color="blue", linewidth=2``.

        Returns:
            The modified axes object
        """
        # Figure options are removed first so the remainder is pure line styling.
        figsize = space_kwargs.pop("figsize", (8, 8))
        dpi = space_kwargs.pop("dpi", 100)
        if ax is None:
            _, ax = plt.subplots(figsize=figsize, dpi=dpi)

        line_kwargs = {
            "color": "gray",
            "linestyle": ":",
            "linewidth": 1,
            "alpha": 1,
            **space_kwargs,
        }

        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)

        # One line on every cell boundary, from -0.5 up to width/height - 0.5.
        for boundary in np.arange(-0.5, self.space.width, 1):
            ax.axvline(boundary, **line_kwargs)
        for boundary in np.arange(-0.5, self.space.height, 1):
            ax.axhline(boundary, **line_kwargs)

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Build an Altair chart with invisible marks and a styled axis grid.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: ``xlabel``/``ylabel`` set the axis titles and
                ``grid_color``/``grid_dash``/``grid_width``/``grid_opacity`` style
                the axis grid; everything else goes to ``chart.properties()``
                (e.g. ``title="Grid"``) and may override the width/height.

        Returns:
            Altair chart object
        """
        x_title = chart_kwargs.pop("xlabel", "X")
        y_title = chart_kwargs.pop("ylabel", "Y")
        grid_style = dict(
            gridColor=chart_kwargs.pop("grid_color", "lightgray"),
            gridDash=chart_kwargs.pop("grid_dash", [2, 2]),
            gridWidth=chart_kwargs.pop("grid_width", 1),
            gridOpacity=chart_kwargs.pop("grid_opacity", 1),
        )
        chart_props = {"width": chart_width, "height": chart_height, **chart_kwargs}

        x_channel = alt.X(
            "X:Q",
            title=x_title,
            scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax], nice=False),
            axis=alt.Axis(grid=True, **grid_style),
        )
        y_channel = alt.Y(
            "Y:Q",
            title=y_title,
            scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax], nice=False),
            axis=alt.Axis(grid=True, **grid_style),
        )

        # A single fully transparent point gives Altair something to encode.
        base = alt.Chart(pd.DataFrame([{}])).mark_point(opacity=0)
        return base.encode(x=x_channel, y=y_channel).properties(**chart_props)
189
+
190
+
191
class HexSpaceDrawer(BaseSpaceDrawer):
    """Drawer for hexagonal grid spaces.

    Cells are drawn as pointy-topped hexagons of circumradius 1. Columns are
    sqrt(3) apart, rows 1.5 apart, and even-numbered rows are shifted right by
    half a column.
    """

    def __init__(self, space: HexGrid):
        """Initialize the hexagonal space drawer.

        Args:
            space: The hexagonal grid space to draw
        """
        super().__init__(space)
        self.s_default = (180 / max(self.space.width, self.space.height)) ** 2
        size = 1.0
        self.x_spacing = np.sqrt(3) * size
        self.y_spacing = 1.5 * size

        x_max = self.space.width * self.x_spacing + (self.space.height % 2) * (
            self.x_spacing / 2
        )
        y_max = self.space.height * self.y_spacing

        x_padding = size * np.sqrt(3) / 2
        y_padding = size

        # Vertex lists for every hexagon; shared by both drawing backends.
        self.hexagons = self._get_hexmesh(self.space.width, self.space.height, size)

        # Parameters for visualization limits
        self.viz_xmin = -1.8 * x_padding
        self.viz_xmax = x_max
        self.viz_ymin = -1.8 * y_padding
        self.viz_ymax = y_max

    def _get_hexmesh(
        self, width: int, height: int, size: float = 1.0
    ) -> list[list[tuple[float, float]]]:
        """Build the vertex list of every hexagon in a ``width`` x ``height`` grid.

        Args:
            width: Number of columns.
            height: Number of rows.
            size: Hexagon circumradius (centre-to-vertex distance).

        Returns:
            One list of six ``(x, y)`` vertices per hexagon, ordered row by row
            (all columns of row 0 first). Each hexagon's vertices run clockwise
            starting from the top vertex.
        """

        def _get_hex_vertices(
            center_x: float, center_y: float, size: float = 1.0
        ) -> list[tuple[float, float]]:
            """Get vertices for a hexagon centered at (center_x, center_y)."""
            half_width = size * np.sqrt(3) / 2
            return [
                (center_x, center_y + size),  # top
                (center_x + half_width, center_y + size / 2),  # top right
                (center_x + half_width, center_y - size / 2),  # bottom right
                (center_x, center_y - size),  # bottom
                (center_x - half_width, center_y - size / 2),  # bottom left
                (center_x - half_width, center_y + size / 2),  # top left
            ]

        x_spacing = np.sqrt(3) * size
        y_spacing = 1.5 * size
        hexagons = []

        for row, col in itertools.product(range(height), range(width)):
            # Even rows are offset by half a column so the hexagons interlock.
            x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
            y = row * y_spacing
            hexagons.append(_get_hex_vertices(x, y, size))

        return hexagons

    def _get_unique_edges(self):
        """Return the set of distinct hexagon edges.

        Neighbouring hexagons share edges. Vertices are rounded to 6 decimals
        and each edge's endpoints are sorted, so a shared edge is stored once.

        Returns:
            A set of ``((x1, y1), (x2, y2))`` tuples.
        """
        edges = set()
        for vertices in self.hexagons:
            # Connect each vertex to the next, closing the polygon.
            for v1, v2 in pairwise([*vertices, vertices[0]]):
                edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))]))
                edges.add(edge)
        return edges

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the hexagonal grid using matplotlib.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created when
                omitted.
            **space_kwargs: ``figsize`` (default ``(8, 8)``) and ``dpi`` (default
                100) are consumed for figure creation. Every other key overrides
                the ``LineCollection`` style, e.g. ``color="red", alpha=0.5``.

        Returns:
            The modified axes object
        """
        fig_kwargs = {
            "figsize": space_kwargs.pop("figsize", (8, 8)),
            "dpi": space_kwargs.pop("dpi", 100),
        }

        if ax is None:
            _, ax = plt.subplots(**fig_kwargs)

        line_kwargs = {
            "color": "black",
            "linestyle": ":",
            "linewidth": 1,
            "alpha": 0.8,
        }
        line_kwargs.update(space_kwargs)

        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)
        # Equal aspect keeps the hexagons regular.
        ax.set_aspect("equal", adjustable="box")

        edges = self._get_unique_edges()
        ax.add_collection(LineCollection(list(edges), **line_kwargs))
        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the hexagonal grid using Altair.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: ``color``, ``strokeDash``, ``strokeWidth`` and
                ``opacity`` style the line mark. Other kwargs (e.g. ``title``)
                go to ``chart.properties()`` and may override the width/height.

        Returns:
            Altair chart object representing the hexagonal grid.
        """
        mark_kwargs = {
            "color": chart_kwargs.pop("color", "black"),
            "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]),
            "strokeWidth": chart_kwargs.pop("strokeWidth", 1),
            "opacity": chart_kwargs.pop("opacity", 0.8),
        }

        chart_props = {
            "width": chart_width,
            "height": chart_height,
        }
        chart_props.update(chart_kwargs)

        # Two rows per edge; ``detail`` keeps each edge a separate line.
        edge_data = []
        for i, (p1, p2) in enumerate(self._get_unique_edges()):
            edge_data.append({"edge_id": i, "point_order": 0, "x": p1[0], "y": p1[1]})
            edge_data.append({"edge_id": i, "point_order": 1, "x": p2[0], "y": p2[1]})

        source = pd.DataFrame(edge_data)

        chart = (
            alt.Chart(source)
            .mark_line(**mark_kwargs)
            .encode(
                x=alt.X(
                    "x:Q",
                    scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax], zero=False),
                    axis=None,
                ),
                y=alt.Y(
                    "y:Q",
                    scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax], zero=False),
                    axis=None,
                ),
                detail="edge_id:N",
                order="point_order:Q",
            )
            .properties(**chart_props)
        )
        return chart
361
+
362
+
363
class NetworkSpaceDrawer(BaseSpaceDrawer):
    """Drawer for network-based spaces."""

    def __init__(
        self,
        space: Network,
        layout_alg=nx.spring_layout,
        layout_kwargs=None,
    ):
        """Initialize the network space drawer.

        Args:
            space: The network space to draw; its graph is read from ``space.G``.
            layout_alg: NetworkX layout algorithm to use. It is called as
                ``layout_alg(graph, **layout_kwargs)`` and must return a mapping
                of node -> 2D position.
            layout_kwargs: Keyword arguments for the layout algorithm. Defaults
                to ``{"seed": 0}`` so the default spring layout is reproducible.
        """
        super().__init__(space)
        self.layout_alg = layout_alg
        self.layout_kwargs = layout_kwargs if layout_kwargs is not None else {"seed": 0}

        # gather locations for nodes in network
        self.graph = self.space.G
        self.pos = self.layout_alg(self.graph, **self.layout_kwargs)

        x, y = list(zip(*self.pos.values())) if self.pos else ([0], [0])
        xmin, xmax = min(x), max(x)
        ymin, ymax = min(y), max(y)

        width = xmax - xmin
        height = ymax - ymin
        self.s_default = (
            (180 / max(width, height)) ** 2 if width > 0 or height > 0 else 1
        )

        # Pad the limits by 5% of the layout extent. A zero extent (a single
        # node, or all nodes on one vertical/horizontal line) would give
        # identical lower and upper limits, so fall back to a fixed padding.
        x_pad = width / 20 if width > 0 else 1
        y_pad = height / 20 if height > 0 else 1

        # Parameters for visualization limits
        self.viz_xmin = xmin - x_pad
        self.viz_xmax = xmax + x_pad
        self.viz_ymin = ymin - y_pad
        self.viz_ymax = ymax + y_pad

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the network using matplotlib.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created when
                omitted.
            **space_kwargs: Dictionaries of keyword arguments for styling. A
                ``zorder`` key in either dict is applied to the drawn artists
                rather than passed to networkx.
                * ``node_kwargs``: A dict passed to nx.draw_networkx_nodes.
                * ``edge_kwargs``: A dict passed to nx.draw_networkx_edges.

        Returns:
            The modified axes object.
        """
        if ax is None:
            _, ax = plt.subplots()

        ax.set_axis_off()
        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)

        node_kwargs = {"alpha": 0.5}
        edge_kwargs = {"alpha": 0.5, "style": "--"}

        node_kwargs.update(space_kwargs.get("node_kwargs", {}))
        edge_kwargs.update(space_kwargs.get("edge_kwargs", {}))

        # zorder is applied to the returned artists afterwards; by default
        # nodes are drawn above edges.
        node_zorder = node_kwargs.pop("zorder", 1)
        edge_zorder = edge_kwargs.pop("zorder", 0)

        nodes = nx.draw_networkx_nodes(self.graph, self.pos, ax=ax, **node_kwargs)
        edges = nx.draw_networkx_edges(self.graph, self.pos, ax=ax, **edge_kwargs)

        if nodes:
            nodes.set_zorder(node_zorder)
        # draw_networkx_edges may return a list of artists instead of a single
        # collection, so handle both shapes.
        if isinstance(edges, list):
            for edge_collection in edges:
                edge_collection.set_zorder(edge_zorder)
        elif edges:
            edges.set_zorder(edge_zorder)

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the network using Altair.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Dictionaries for styling the chart.
                * ``node_kwargs``: A dict of properties for the node's mark_point.
                * ``edge_kwargs``: A dict of properties for the edge's mark_rule.
                * Other kwargs (e.g., title, width) are passed to
                  ``chart.properties()`` and may override the width/height.

        Returns:
            Altair chart object representing the network.
        """
        # Build the node table row by row. This works for any hashable node
        # label, including tuples such as (row, col), which would otherwise
        # become a MultiIndex. It also works for an empty layout.
        nodes_df = pd.DataFrame(
            [(node, x, y) for node, (x, y) in self.pos.items()],
            columns=["node", "x", "y"],
        )

        edges_df = pd.DataFrame(list(self.graph.edges()), columns=["source", "target"])
        edge_positions = edges_df.merge(
            nodes_df, how="left", left_on="source", right_on="node"
        ).merge(
            nodes_df,
            how="left",
            left_on="target",
            right_on="node",
            suffixes=("_source", "_target"),
        )

        node_mark_kwargs = {"filled": True, "opacity": 0.5, "size": 500}
        edge_mark_kwargs = {"opacity": 0.5, "strokeDash": [5, 3]}

        node_mark_kwargs.update(chart_kwargs.pop("node_kwargs", {}))
        edge_mark_kwargs.update(chart_kwargs.pop("edge_kwargs", {}))

        # Whatever remains in chart_kwargs (e.g. title) is applied to the
        # layered chart and may override the default size.
        chart_props = {
            "width": chart_width,
            "height": chart_height,
        }
        chart_props.update(chart_kwargs)

        edge_plot = (
            alt.Chart(edge_positions)
            .mark_rule(**edge_mark_kwargs)
            .encode(
                x=alt.X(
                    "x_source:Q",
                    scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax]),
                    axis=None,
                ),
                y=alt.Y(
                    "y_source:Q",
                    scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax]),
                    axis=None,
                ),
                x2="x_target",
                y2="y_target",
            )
        )

        node_plot = (
            alt.Chart(nodes_df)
            .mark_point(**node_mark_kwargs)
            .encode(x="x:Q", y="y:Q", tooltip=["node"])
        )

        return (edge_plot + node_plot).properties(**chart_props)
517
+
518
+
519
class ContinuousSpaceDrawer(BaseSpaceDrawer):
    """Drawer for continuous spaces."""

    def __init__(self, space: ContinuousSpace):
        """Initialize the continuous space drawer.

        The visualization limits are the space bounds padded by 5% of the
        space's width/height on each side.

        Args:
            space: The continuous space to draw
        """
        super().__init__(space)
        width = self.space.x_max - self.space.x_min
        height = self.space.y_max - self.space.y_min
        self.s_default = (
            (180 / max(width, height)) ** 2 if width > 0 or height > 0 else 1
        )

        x_padding = width / 20
        y_padding = height / 20

        self.viz_xmin = self.space.x_min - x_padding
        self.viz_xmax = self.space.x_max + x_padding
        self.viz_ymin = self.space.y_min - y_padding
        self.viz_ymax = self.space.y_max + y_padding

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the continuous space using matplotlib.

        The axis frame (spines) is drawn solid, or dashed when the space is a
        torus.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created when
                omitted.
            **space_kwargs: Keyword arguments applied to every spine, overriding
                the defaults, e.g. ``linewidth=3, color="green"``.

        Returns:
            The modified axes object
        """
        if ax is None:
            _, ax = plt.subplots()

        # A dashed border signals that the space wraps around.
        border_style = "solid" if not self.space.torus else (0, (5, 10))
        spine_kwargs = {"linewidth": 1.5, "color": "black", "linestyle": border_style}
        spine_kwargs.update(space_kwargs)

        for spine in ax.spines.values():
            spine.set(**spine_kwargs)

        ax.set_xlim(self.viz_xmin, self.viz_xmax)
        ax.set_ylim(self.viz_ymin, self.viz_ymax)

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the continuous space using Altair.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: Keyword arguments merged into the chart's
                ``properties()`` call (e.g. ``title``); they may override the
                width/height.

        Returns:
            An Altair Chart object representing the space.
        """
        chart_props = {"width": chart_width, "height": chart_height}
        chart_props.update(chart_kwargs)

        chart = (
            alt.Chart(pd.DataFrame([{}]))
            .mark_rect(color="transparent")
            .encode(
                x=alt.X(scale=alt.Scale(domain=[self.viz_xmin, self.viz_xmax])),
                y=alt.Y(scale=alt.Scale(domain=[self.viz_ymin, self.viz_ymax])),
            )
            .properties(**chart_props)
        )

        return chart
597
+
598
+
599
class VoronoiSpaceDrawer(BaseSpaceDrawer):
    """Drawer for Voronoi diagram spaces.

    Cell outlines are taken from each cell's ``properties["polygon"]`` and are
    clipped to the visualization limits before drawing.
    """

    def __init__(self, space: VoronoiGrid):
        """Initialize the Voronoi space drawer.

        The visualization limits are the bounding box of the cell centroids,
        padded by 5% of its width/height on each side. A space without
        centroids falls back to the unit square.

        Args:
            space: The Voronoi grid space to draw
        """
        super().__init__(space)
        if self.space.centroids_coordinates:
            x_list = [i[0] for i in self.space.centroids_coordinates]
            y_list = [i[1] for i in self.space.centroids_coordinates]
            x_max, x_min = max(x_list), min(x_list)
            y_max, y_min = max(y_list), min(y_list)
        else:
            x_max, x_min, y_max, y_min = 1, 0, 1, 0

        width = x_max - x_min
        height = y_max - y_min
        self.s_default = (
            (180 / max(width, height)) ** 2 if width > 0 or height > 0 else 1
        )

        # Parameters for visualization limits
        self.viz_xmin = x_min - width / 20
        self.viz_xmax = x_max + width / 20
        self.viz_ymin = y_min - height / 20
        self.viz_ymax = y_max + height / 20

    def _clip_line(self, p1, p2, box):
        """Clip a line segment to an axis-aligned box (Cohen-Sutherland).

        Args:
            p1: First endpoint ``(x, y)``.
            p2: Second endpoint ``(x, y)``.
            box: Clip rectangle as ``(min_x, min_y, max_x, max_y)``.

        Returns:
            The clipped segment ``((x1, y1), (x2, y2))``, or ``None`` if the
            segment lies entirely outside the box.
        """
        x1, y1 = p1
        x2, y2 = p2
        min_x, min_y, max_x, max_y = box

        # Region codes (bit flags)
        INSIDE, LEFT, RIGHT, BOTTOM, TOP = 0, 1, 2, 4, 8  # noqa: N806

        def compute_outcode(x, y):
            code = INSIDE
            if x < min_x:
                code |= LEFT
            elif x > max_x:
                code |= RIGHT
            if y < min_y:
                code |= BOTTOM
            elif y > max_y:
                code |= TOP
            return code

        outcode1 = compute_outcode(x1, y1)
        outcode2 = compute_outcode(x2, y2)

        while True:
            if not (outcode1 | outcode2):  # Both points inside
                return (x1, y1), (x2, y2)
            if outcode1 & outcode2:  # Both points share an outside region
                return None

            outcode_out = outcode1 if outcode1 else outcode2

            # Move the outside endpoint onto one box edge. The divisions are
            # safe. If the outside point is RIGHT/LEFT of the box, the other
            # endpoint is not on that same side (that case was rejected above),
            # so x1 != x2. Likewise y1 != y2 for TOP/BOTTOM. For a point beyond
            # a corner, the left/right edge is used first.
            if outcode_out & RIGHT:
                y = y1 + (y2 - y1) * (max_x - x1) / (x2 - x1)
                x = max_x
            elif outcode_out & LEFT:
                y = y1 + (y2 - y1) * (min_x - x1) / (x2 - x1)
                x = min_x
            elif outcode_out & TOP:
                x = x1 + (x2 - x1) * (max_y - y1) / (y2 - y1)
                y = max_y
            else:  # BOTTOM is the only remaining possibility
                x = x1 + (x2 - x1) * (min_y - y1) / (y2 - y1)
                y = min_y

            if outcode_out == outcode1:
                x1, y1 = x, y
                outcode1 = compute_outcode(x1, y1)
            else:
                x2, y2 = x, y
                outcode2 = compute_outcode(x2, y2)

    def _get_clipped_segments(self):
        """Collect every distinct polygon edge and clip it to the view.

        Returns:
            A tuple ``(segments, clip_box)``. ``segments`` is a list of clipped
            ``((x1, y1), (x2, y2))`` pairs, and ``clip_box`` is
            ``(xmin, ymin, xmax, ymax)``.
        """
        clip_box = (
            self.viz_xmin,
            self.viz_ymin,
            self.viz_xmax,
            self.viz_ymax,
        )

        unique_segments = set()
        for cell in self.space.all_cells.cells:
            # Round vertices so that an edge shared by two neighbouring
            # polygons is recognised as one segment even when the two copies
            # differ only by floating-point noise. This matches the hex drawer.
            vertices = [tuple(np.round(v, 6)) for v in cell.properties["polygon"]]
            if len(vertices) < 2:
                continue  # nothing to draw for an empty/degenerate polygon
            for p1, p2 in pairwise([*vertices, vertices[0]]):
                if p1 != p2:  # skip zero-length segments
                    # Sort so a segment traversed in either direction is stored once.
                    unique_segments.add(tuple(sorted((p1, p2))))

        final_segments = []
        for p1, p2 in unique_segments:
            clipped_segment = self._clip_line(p1, p2, clip_box)
            if clipped_segment:
                final_segments.append(clipped_segment)

        return final_segments, clip_box

    def draw_matplotlib(self, ax=None, **space_kwargs):
        """Draw the Voronoi diagram using matplotlib.

        Args:
            ax: Matplotlib axes object to draw on; a new figure is created when
                omitted.
            **space_kwargs: Keyword arguments passed to matplotlib's
                LineCollection, overriding the defaults, e.g.
                ``lw=2, alpha=0.5, colors='red'``.

        Returns:
            The modified axes object
        """
        if ax is None:
            _, ax = plt.subplots()

        final_segments, clip_box = self._get_clipped_segments()

        ax.set_xlim(clip_box[0], clip_box[2])
        ax.set_ylim(clip_box[1], clip_box[3])

        if final_segments:
            style_args = {"colors": "k", "linestyle": "dotted", "lw": 1}
            style_args.update(space_kwargs)
            ax.add_collection(LineCollection(final_segments, **style_args))

        return ax

    def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
        """Draw the Voronoi diagram using Altair.

        Args:
            chart_width: Width for the shown chart
            chart_height: Height for the shown chart
            **chart_kwargs: ``color``, ``strokeDash``, ``strokeWidth`` and
                ``opacity`` style the line mark. Other kwargs (e.g. ``title``)
                go to ``chart.properties()`` and may override the width/height.

        Returns:
            An Altair Chart object representing the Voronoi diagram.
        """
        final_segments, clip_box = self._get_clipped_segments()

        # Two rows per segment; ``detail`` keeps each segment a separate line.
        final_data = []
        for i, (p1, p2) in enumerate(final_segments):
            final_data.append({"x": p1[0], "y": p1[1], "line_id": i})
            final_data.append({"x": p2[0], "y": p2[1], "line_id": i})

        df = pd.DataFrame(final_data)

        mark_kwargs = {
            "color": chart_kwargs.pop("color", "black"),
            "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]),
            "strokeWidth": chart_kwargs.pop("strokeWidth", 1),
            "opacity": chart_kwargs.pop("opacity", 0.8),
        }

        chart_props = {"width": chart_width, "height": chart_height}
        chart_props.update(chart_kwargs)

        chart = (
            alt.Chart(df)
            .mark_line(**mark_kwargs)
            .encode(
                x=alt.X(
                    "x:Q", scale=alt.Scale(domain=[clip_box[0], clip_box[2]]), axis=None
                ),
                y=alt.Y(
                    "y:Q", scale=alt.Scale(domain=[clip_box[1], clip_box[3]]), axis=None
                ),
                detail="line_id:N",
            )
            .properties(**chart_props)
        )
        return chart