Mesa 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic; see the package registry page for more details.

Files changed (41)
  1. {Mesa-1.1.0.dist-info → Mesa-1.2.0.dist-info}/LICENSE +1 -1
  2. {Mesa-1.1.0.dist-info → Mesa-1.2.0.dist-info}/METADATA +15 -14
  3. {Mesa-1.1.0.dist-info → Mesa-1.2.0.dist-info}/RECORD +41 -41
  4. {Mesa-1.1.0.dist-info → Mesa-1.2.0.dist-info}/WHEEL +1 -1
  5. mesa/__init__.py +8 -9
  6. mesa/agent.py +2 -3
  7. mesa/batchrunner.py +19 -28
  8. mesa/datacollection.py +15 -28
  9. mesa/main.py +4 -4
  10. mesa/model.py +2 -6
  11. mesa/space.py +379 -286
  12. mesa/time.py +21 -22
  13. mesa/visualization/ModularVisualization.py +11 -9
  14. mesa/visualization/TextVisualization.py +0 -3
  15. mesa/visualization/UserParam.py +8 -11
  16. mesa/visualization/__init__.py +0 -1
  17. mesa/visualization/modules/BarChartVisualization.py +7 -8
  18. mesa/visualization/modules/CanvasGridVisualization.py +1 -3
  19. mesa/visualization/modules/ChartVisualization.py +2 -3
  20. mesa/visualization/modules/HexGridVisualization.py +1 -3
  21. mesa/visualization/modules/NetworkVisualization.py +1 -2
  22. mesa/visualization/modules/PieChartVisualization.py +2 -6
  23. mesa/visualization/templates/js/GridDraw.js +6 -10
  24. mesa/visualization/templates/js/HexDraw.js +5 -9
  25. mesa/visualization/templates/js/InteractionHandler.js +0 -2
  26. tests/test_batchrunner.py +3 -4
  27. tests/test_batchrunnerMP.py +4 -4
  28. tests/test_datacollector.py +2 -2
  29. tests/test_examples.py +8 -5
  30. tests/test_grid.py +104 -37
  31. tests/test_import_namespace.py +0 -1
  32. tests/test_lifespan.py +4 -3
  33. tests/test_main.py +5 -1
  34. tests/test_scaffold.py +2 -1
  35. tests/test_space.py +13 -20
  36. tests/test_time.py +44 -14
  37. tests/test_tornado.py +4 -2
  38. tests/test_usersettableparam.py +4 -3
  39. tests/test_visualization.py +4 -8
  40. {Mesa-1.1.0.dist-info → Mesa-1.2.0.dist-info}/entry_points.txt +0 -0
  41. {Mesa-1.1.0.dist-info → Mesa-1.2.0.dist-info}/top_level.txt +0 -0
mesa/space.py CHANGED
@@ -4,10 +4,13 @@ Mesa Space Module
4
4
 
5
5
  Objects used to add a spatial component to a model.
6
6
 
7
- Grid: base grid, a simple list-of-lists.
8
- SingleGrid: grid which strictly enforces one object per cell.
9
- MultiGrid: extension to Grid where each cell is a set of objects.
10
-
7
+ Grid: base grid, which creates a rectangular grid.
8
+ SingleGrid: extension to Grid which strictly enforces one agent per cell.
9
+ MultiGrid: extension to Grid where each cell can contain a set of agents.
10
+ HexGrid: extension to Grid to handle hexagonal neighbors.
11
+ ContinuousSpace: a two-dimensional space where each agent has an arbitrary
12
+ position of `float`'s.
13
+ NetworkGrid: a network where each node contains zero or more agents.
11
14
  """
12
15
  # Instruction for PyLint to suppress variable name errors, since we have a
13
16
  # good reason to use one-character variable names for x and y.
@@ -17,18 +20,16 @@ MultiGrid: extension to Grid where each cell is a set of objects.
17
20
  # Remove this __future__ import once the oldest supported Python is 3.10
18
21
  from __future__ import annotations
19
22
 
23
+ import collections
20
24
  import itertools
21
25
  import math
22
- from warnings import warn
23
-
24
- import numpy as np
25
-
26
+ from numbers import Real
26
27
  from typing import (
27
28
  Any,
28
29
  Callable,
29
- List,
30
30
  Iterable,
31
31
  Iterator,
32
+ List,
32
33
  Sequence,
33
34
  Tuple,
34
35
  TypeVar,
@@ -36,11 +37,17 @@ from typing import (
36
37
  cast,
37
38
  overload,
38
39
  )
40
+ from warnings import warn
41
+
42
+ import networkx as nx
43
+ import numpy as np
44
+ import numpy.typing as npt
39
45
 
40
46
  # For Mypy
41
47
  from .agent import Agent
42
- from numbers import Real
43
- import numpy.typing as npt
48
+
49
+ # for better performance, we calculate the tuple to use in the is_integer function
50
+ _types_integer = (int, np.integer)
44
51
 
45
52
  Coordinate = Tuple[int, int]
46
53
  # used in ContinuousSpace
@@ -55,41 +62,35 @@ MultiGridContent = List[Agent]
55
62
  F = TypeVar("F", bound=Callable[..., Any])
56
63
 
57
64
 
58
- def clamp(x: float, lowest: float, highest: float) -> float:
59
- # much faster than np.clip for a scalar x.
60
- return lowest if x <= lowest else (highest if x >= highest else x)
61
-
62
-
63
65
  def accept_tuple_argument(wrapped_function: F) -> F:
64
66
  """Decorator to allow grid methods that take a list of (x, y) coord tuples
65
67
  to also handle a single position, by automatically wrapping tuple in
66
68
  single-item list rather than forcing user to do it."""
67
69
 
68
- def wrapper(*args: Any) -> Any:
69
- if isinstance(args[1], tuple) and len(args[1]) == 2:
70
- return wrapped_function(args[0], [args[1]])
70
+ def wrapper(grid_instance, positions) -> Any:
71
+ if isinstance(positions, tuple) and len(positions) == 2:
72
+ return wrapped_function(grid_instance, [positions])
71
73
  else:
72
- return wrapped_function(*args)
74
+ return wrapped_function(grid_instance, positions)
73
75
 
74
76
  return cast(F, wrapper)
75
77
 
76
78
 
77
79
  def is_integer(x: Real) -> bool:
78
80
  # Check if x is either a CPython integer or Numpy integer.
79
- return isinstance(x, (int, np.integer))
81
+ return isinstance(x, _types_integer)
80
82
 
81
83
 
82
- class Grid:
83
- """Base class for a square grid.
84
+ class _Grid:
85
+ """Base class for a rectangular grid.
84
86
 
85
- Grid cells are indexed by [x][y], where [0][0] is assumed to be the
86
- bottom-left and [width-1][height-1] is the top-right. If a grid is
87
+ Grid cells are indexed by [x, y], where [0, 0] is assumed to be the
88
+ bottom-left and [width-1, height-1] is the top-right. If a grid is
87
89
  toroidal, the top and bottom, and left and right, edges wrap to each other
88
90
 
89
91
  Properties:
90
92
  width, height: The grid's width and height.
91
93
  torus: Boolean which determines whether to treat the grid as a torus.
92
- grid: Internal list-of-lists which holds the grid cells themselves.
93
94
  """
94
95
 
95
96
  def __init__(self, width: int, height: int, torus: bool) -> None:
@@ -102,26 +103,42 @@ class Grid:
102
103
  self.height = height
103
104
  self.width = width
104
105
  self.torus = torus
106
+ self.num_cells = height * width
105
107
 
106
- self.grid: list[list[GridContent]] = []
107
-
108
- for x in range(self.width):
109
- col: list[GridContent] = []
110
- for y in range(self.height):
111
- col.append(self.default_val())
112
- self.grid.append(col)
108
+ # Internal list-of-lists which holds the grid cells themselves
109
+ self._grid: list[list[GridContent]]
110
+ self._grid = [
111
+ [self.default_val() for _ in range(self.height)] for _ in range(self.width)
112
+ ]
113
113
 
114
- # Add all cells to the empties list.
115
- self.empties = set(itertools.product(*(range(self.width), range(self.height))))
114
+ # Flag to check if the empties set has been created. Better than initializing
115
+ # _empties as set() because in this case it would become impossible to discern
116
+ # if the set hasn't still being built or if it has become empty after creation.
117
+ self._empties_built = False
116
118
 
117
119
  # Neighborhood Cache
118
- self._neighborhood_cache: dict[Any, list[Coordinate]] = dict()
120
+ self._neighborhood_cache: dict[Any, list[Coordinate]] = {}
119
121
 
120
122
  @staticmethod
121
123
  def default_val() -> None:
122
124
  """Default value for new cell elements."""
123
125
  return None
124
126
 
127
+ @property
128
+ def empties(self) -> set:
129
+ if not self._empties_built:
130
+ self.build_empties()
131
+ return self._empties
132
+
133
+ def build_empties(self) -> None:
134
+ self._empties = set(
135
+ filter(
136
+ self.is_cell_empty,
137
+ itertools.product(range(self.width), range(self.height)),
138
+ )
139
+ )
140
+ self._empties_built = True
141
+
125
142
  @overload
126
143
  def __getitem__(self, index: int) -> list[GridContent]:
127
144
  ...
@@ -144,55 +161,45 @@ class Grid:
144
161
 
145
162
  if isinstance(index, int):
146
163
  # grid[x]
147
- return self.grid[index]
164
+ return self._grid[index]
148
165
  elif isinstance(index[0], tuple):
149
- # grid[(x1, y1), (x2, y2)]
166
+ # grid[(x1, y1), (x2, y2), ...]
150
167
  index = cast(Sequence[Coordinate], index)
151
-
152
- cells = []
153
- for pos in index:
154
- x1, y1 = self.torus_adj(pos)
155
- cells.append(self.grid[x1][y1])
156
- return cells
168
+ return [self._grid[x][y] for x, y in map(self.torus_adj, index)]
157
169
 
158
170
  x, y = index
171
+ x_int, y_int = is_integer(x), is_integer(y)
159
172
 
160
- if is_integer(x) and is_integer(y):
173
+ if x_int and y_int:
161
174
  # grid[x, y]
162
175
  index = cast(Coordinate, index)
163
176
  x, y = self.torus_adj(index)
164
- return self.grid[x][y]
165
-
166
- if is_integer(x):
177
+ return self._grid[x][y]
178
+ elif x_int:
167
179
  # grid[x, :]
168
180
  x, _ = self.torus_adj((x, 0))
169
- x = slice(x, x + 1)
170
-
171
- if is_integer(y):
181
+ y = cast(slice, y)
182
+ return self._grid[x][y]
183
+ elif y_int:
172
184
  # grid[:, y]
173
185
  _, y = self.torus_adj((0, y))
174
- y = slice(y, y + 1)
175
-
176
- # grid[:, :]
177
- x, y = (cast(slice, x), cast(slice, y))
178
- cells = []
179
- for rows in self.grid[x]:
180
- for cell in rows[y]:
181
- cells.append(cell)
182
- return cells
183
-
184
- raise IndexError
186
+ x = cast(slice, x)
187
+ return [rows[y] for rows in self._grid[x]]
188
+ else:
189
+ # grid[:, :]
190
+ x, y = (cast(slice, x), cast(slice, y))
191
+ return [cell for rows in self._grid[x] for cell in rows[y]]
185
192
 
186
193
  def __iter__(self) -> Iterator[GridContent]:
187
194
  """Create an iterator that chains the rows of the grid together
188
195
  as if it is one list:"""
189
- return itertools.chain(*self.grid)
196
+ return itertools.chain(*self._grid)
190
197
 
191
198
  def coord_iter(self) -> Iterator[tuple[GridContent, int, int]]:
192
199
  """An iterator that returns coordinates as well as cell contents."""
193
200
  for row in range(self.width):
194
201
  for col in range(self.height):
195
- yield self.grid[row][col], row, col # agent, x, y
202
+ yield self._grid[row][col], row, col # agent, x, y
196
203
 
197
204
  def neighbor_iter(self, pos: Coordinate, moore: bool = True) -> Iterator[Agent]:
198
205
  """Iterate over position neighbors.
@@ -230,7 +237,7 @@ class Grid:
230
237
  radius: radius, in cells, of neighborhood to get.
231
238
 
232
239
  Returns:
233
- A list of coordinate tuples representing the neighborhood. For
240
+ An iterator of coordinate tuples representing the neighborhood. For
234
241
  example with radius 1, it will return list with number of elements
235
242
  equals at most 9 (8) if Moore, 5 (4) if Von Neumann (if not
236
243
  including the center).
@@ -265,30 +272,54 @@ class Grid:
265
272
  cache_key = (pos, moore, include_center, radius)
266
273
  neighborhood = self._neighborhood_cache.get(cache_key, None)
267
274
 
268
- if neighborhood is None:
269
- coordinates: set[Coordinate] = set()
275
+ if neighborhood is not None:
276
+ return neighborhood
270
277
 
271
- x, y = pos
272
- for dy in range(-radius, radius + 1):
273
- for dx in range(-radius, radius + 1):
274
- if dx == 0 and dy == 0 and not include_center:
275
- continue
276
- # Skip coordinates that are outside manhattan distance
278
+ # We use a list instead of a dict for the neighborhood because it would
279
+ # be easier to port the code to Cython or Numba (for performance
280
+ # purpose), with minimal changes. To better understand how the
281
+ # algorithm was conceived, look at
282
+ # https://github.com/projectmesa/mesa/pull/1476#issuecomment-1306220403
283
+ # and the discussion in that PR in general.
284
+ neighborhood = []
285
+
286
+ x, y = pos
287
+ if self.torus:
288
+ x_max_radius, y_max_radius = self.width // 2, self.height // 2
289
+ x_radius, y_radius = min(radius, x_max_radius), min(radius, y_max_radius)
290
+
291
+ # For each dimension, in the edge case where the radius is as big as
292
+ # possible and the dimension is even, we need to shrink by one the range
293
+ # of values, to avoid duplicates in neighborhood. For example, if
294
+ # the width is 4, while x, x_radius, and x_max_radius are 2, then
295
+ # (x + dx) has a value from 0 to 4 (inclusive), but this means that
296
+ # the 0 position is repeated since 0 % 4 and 4 % 4 are both 0.
297
+ xdim_even, ydim_even = (self.width + 1) % 2, (self.height + 1) % 2
298
+ kx = int(x_radius == x_max_radius and xdim_even)
299
+ ky = int(y_radius == y_max_radius and ydim_even)
300
+
301
+ for dx in range(-x_radius, x_radius + 1 - kx):
302
+ for dy in range(-y_radius, y_radius + 1 - ky):
277
303
  if not moore and abs(dx) + abs(dy) > radius:
278
304
  continue
279
305
 
280
- coord = (x + dx, y + dy)
306
+ nx, ny = (x + dx) % self.width, (y + dy) % self.height
307
+ neighborhood.append((nx, ny))
308
+ else:
309
+ x_range = range(max(0, x - radius), min(self.width, x + radius + 1))
310
+ y_range = range(max(0, y - radius), min(self.height, y + radius + 1))
311
+
312
+ for nx in x_range:
313
+ for ny in y_range:
314
+ if not moore and abs(nx - x) + abs(ny - y) > radius:
315
+ continue
281
316
 
282
- if self.out_of_bounds(coord):
283
- # Skip if not a torus and new coords out of bounds.
284
- if not self.torus:
285
- continue
286
- coord = self.torus_adj(coord)
317
+ neighborhood.append((nx, ny))
287
318
 
288
- coordinates.add(coord)
319
+ if not include_center and neighborhood:
320
+ neighborhood.remove(pos)
289
321
 
290
- neighborhood = sorted(coordinates)
291
- self._neighborhood_cache[cache_key] = neighborhood
322
+ self._neighborhood_cache[cache_key] = neighborhood
292
323
 
293
324
  return neighborhood
294
325
 
@@ -366,34 +397,40 @@ class Grid:
366
397
  def iter_cell_list_contents(
367
398
  self, cell_list: Iterable[Coordinate]
368
399
  ) -> Iterator[Agent]:
369
- """Returns an iterator of the contents of the cells
370
- identified in cell_list.
400
+ """Returns an iterator of the agents contained in the cells identified
401
+ in `cell_list`; cells with empty content are excluded.
371
402
 
372
403
  Args:
373
404
  cell_list: Array-like of (x, y) tuples, or single tuple.
374
405
 
375
406
  Returns:
376
- An iterator of the contents of the cells identified in cell_list
407
+ An iterator of the agents contained in the cells identified in `cell_list`.
377
408
  """
378
- # Note: filter(None, iterator) filters away an element of iterator that
379
- # is falsy. Hence, iter_cell_list_contents returns only non-empty
380
- # contents.
381
- return filter(None, (self.grid[x][y] for x, y in cell_list))
409
+ # iter_cell_list_contents returns only non-empty contents.
410
+ return (
411
+ self._grid[x][y]
412
+ for x, y in itertools.filterfalse(self.is_cell_empty, cell_list)
413
+ )
382
414
 
383
415
  @accept_tuple_argument
384
416
  def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent]:
385
- """Returns a list of the contents of the cells
386
- identified in cell_list.
387
- Note: this method returns a list of `Agent`'s; `None` contents are excluded.
417
+ """Returns an iterator of the agents contained in the cells identified
418
+ in `cell_list`; cells with empty content are excluded.
388
419
 
389
420
  Args:
390
421
  cell_list: Array-like of (x, y) tuples, or single tuple.
391
422
 
392
423
  Returns:
393
- A list of the contents of the cells identified in cell_list
424
+ A list of the agents contained in the cells identified in `cell_list`.
394
425
  """
395
426
  return list(self.iter_cell_list_contents(cell_list))
396
427
 
428
+ def place_agent(self, agent: Agent, pos: Coordinate) -> None:
429
+ ...
430
+
431
+ def remove_agent(self, agent: Agent) -> None:
432
+ ...
433
+
397
434
  def move_agent(self, agent: Agent, pos: Coordinate) -> None:
398
435
  """Move an agent from its current position to a new position.
399
436
 
@@ -404,63 +441,55 @@ class Grid:
404
441
  """
405
442
  pos = self.torus_adj(pos)
406
443
  self.remove_agent(agent)
407
- self._place_agent(agent, pos)
408
- agent.pos = pos
444
+ self.place_agent(agent, pos)
409
445
 
410
- def place_agent(self, agent: Agent, pos: Coordinate) -> None:
411
- """Position an agent on the grid, and set its pos variable."""
412
- self._place_agent(agent, pos)
413
- agent.pos = pos
446
+ def swap_pos(self, agent_a: Agent, agent_b: Agent) -> None:
447
+ """Swap agents positions"""
448
+ agents_no_pos = []
449
+ if (pos_a := agent_a.pos) is None:
450
+ agents_no_pos.append(agent_a)
451
+ if (pos_b := agent_b.pos) is None:
452
+ agents_no_pos.append(agent_b)
453
+ if agents_no_pos:
454
+ agents_no_pos = [f"<Agent id: {a.unique_id}>" for a in agents_no_pos]
455
+ raise Exception(f"{', '.join(agents_no_pos)} - not on the grid")
414
456
 
415
- def _place_agent(self, agent: Agent, pos: Coordinate) -> None:
416
- """Place the agent at the correct location."""
417
- x, y = pos
418
- self.grid[x][y] = agent
419
- self.empties.discard(pos)
457
+ if pos_a == pos_b:
458
+ return
420
459
 
421
- def remove_agent(self, agent: Agent) -> None:
422
- """Remove the agent from the grid and set its pos attribute to None."""
423
- pos = agent.pos
424
- x, y = pos
425
- self.grid[x][y] = self.default_val()
426
- self.empties.add(pos)
427
- agent.pos = None
460
+ self.remove_agent(agent_a)
461
+ self.remove_agent(agent_b)
462
+
463
+ self.place_agent(agent_a, pos_b)
464
+ self.place_agent(agent_b, pos_a)
428
465
 
429
466
  def is_cell_empty(self, pos: Coordinate) -> bool:
430
467
  """Returns a bool of the contents of a cell."""
431
468
  x, y = pos
432
- return self.grid[x][y] == self.default_val()
469
+ return self._grid[x][y] == self.default_val()
433
470
 
434
471
  def move_to_empty(
435
472
  self, agent: Agent, cutoff: float = 0.998, num_agents: int | None = None
436
473
  ) -> None:
437
474
  """Moves agent to a random empty cell, vacating agent's old cell."""
438
- if len(self.empties) == 0:
475
+ if num_agents is not None:
476
+ warn(
477
+ (
478
+ "`num_agents` is being deprecated since it's no longer used "
479
+ "inside `move_to_empty`. It shouldn't be passed as a parameter."
480
+ ),
481
+ DeprecationWarning,
482
+ )
483
+ num_empty_cells = len(self.empties)
484
+ if num_empty_cells == 0:
439
485
  raise Exception("ERROR: No empty cells")
440
- if num_agents is None:
441
- try:
442
- num_agents = agent.model.schedule.get_agent_count()
443
- except AttributeError:
444
- raise Exception(
445
- "Your agent is not attached to a model, and so Mesa is unable\n"
446
- "to figure out the total number of agents you have created.\n"
447
- "This number is required in order to calculate the threshold\n"
448
- "for using a much faster algorithm to find an empty cell.\n"
449
- "In this case, you must specify `num_agents`."
450
- )
451
- new_pos = (0, 0) # Initialize it with a starting value.
452
- # This method is based on Agents.jl's random_empty() implementation.
453
- # See https://github.com/JuliaDynamics/Agents.jl/pull/541.
454
- # For the discussion, see
455
- # https://github.com/projectmesa/mesa/issues/1052.
456
- # This switch assumes the worst case (for this algorithm) of one
457
- # agent per position, which is not true in general but is appropriate
458
- # here.
459
- if clamp(num_agents / (self.width * self.height), 0.0, 1.0) < cutoff:
460
- # The default cutoff value provided is the break-even comparison
461
- # with the time taken in the else branching point.
462
- # The number is measured to be 0.998 in Agents.jl, but since Mesa
463
- # run under different environment, the number is different here.
486
+
487
+ # This method is based on Agents.jl's random_empty() implementation. See
488
+ # https://github.com/JuliaDynamics/Agents.jl/pull/541. For the discussion, see
489
+ # https://github.com/projectmesa/mesa/issues/1052. The default cutoff value
490
+ # provided is the break-even comparison with the time taken in the else
491
+ # branching point.
492
+ if 1 - num_empty_cells / self.num_cells < cutoff:
464
493
  while True:
465
494
  new_pos = (
466
495
  agent.random.randrange(self.width),
@@ -471,8 +500,7 @@ class Grid:
471
500
  else:
472
501
  new_pos = agent.random.choice(sorted(self.empties))
473
502
  self.remove_agent(agent)
474
- self._place_agent(agent, new_pos)
475
- agent.pos = new_pos
503
+ self.place_agent(agent, new_pos)
476
504
 
477
505
  def find_empty(self) -> Coordinate | None:
478
506
  """Pick a random empty cell."""
@@ -499,56 +527,86 @@ class Grid:
499
527
  return len(self.empties) > 0
500
528
 
501
529
 
502
- class SingleGrid(Grid):
503
- """Grid where each cell contains exactly at most one object."""
530
+ class SingleGrid(_Grid):
531
+ """Rectangular grid where each cell contains exactly at most one agent.
504
532
 
505
- empties: set[Coordinate] = set()
533
+ Grid cells are indexed by [x, y], where [0, 0] is assumed to be the
534
+ bottom-left and [width-1, height-1] is the top-right. If a grid is
535
+ toroidal, the top and bottom, and left and right, edges wrap to each other.
536
+
537
+ Properties:
538
+ width, height: The grid's width and height.
539
+ torus: Boolean which determines whether to treat the grid as a torus.
540
+ """
506
541
 
507
542
  def position_agent(
508
543
  self, agent: Agent, x: int | str = "random", y: int | str = "random"
509
544
  ) -> None:
510
545
  """Position an agent on the grid.
511
- This is used when first placing agents! Use 'move_to_empty()'
512
- when you want agents to jump to an empty cell.
546
+ This is used when first placing agents! Setting either x or y to "random"
547
+ gives the same behavior as 'move_to_empty()' to get a random position.
548
+ If x or y are positive, they are used.
513
549
  Use 'swap_pos()' to swap agents positions.
514
- If x or y are positive, they are used, but if "random",
515
- we get a random position.
516
- Ensure this random position is not occupied (in Grid).
517
550
  """
551
+ warn(
552
+ (
553
+ "`position_agent` is being deprecated; use instead "
554
+ "`place_agent` to place an agent at a specified "
555
+ "location or `move_to_empty` to place an agent "
556
+ "at a random empty cell."
557
+ ),
558
+ DeprecationWarning,
559
+ )
560
+
561
+ if not (isinstance(x, int) or x == "random"):
562
+ raise Exception(
563
+ "x must be an integer or a string 'random'."
564
+ f" Actual type: {type(x)}. Actual value: {x}."
565
+ )
566
+ if not (isinstance(y, int) or y == "random"):
567
+ raise Exception(
568
+ "y must be an integer or a string 'random'."
569
+ f" Actual type: {type(y)}. Actual value: {y}."
570
+ )
571
+
518
572
  if x == "random" or y == "random":
519
- if len(self.empties) == 0:
520
- raise Exception("ERROR: Grid full")
521
- coords = agent.random.choice(sorted(self.empties))
573
+ self.move_to_empty(agent)
522
574
  else:
523
575
  coords = (x, y)
524
- agent.pos = coords
525
- self._place_agent(agent, coords)
576
+ self.place_agent(agent, coords)
526
577
 
527
- def _place_agent(self, agent: Agent, pos: Coordinate) -> None:
578
+ def place_agent(self, agent: Agent, pos: Coordinate) -> None:
579
+ """Place the agent at the specified location, and set its pos variable."""
528
580
  if self.is_cell_empty(pos):
529
- super()._place_agent(agent, pos)
581
+ x, y = pos
582
+ self._grid[x][y] = agent
583
+ if self._empties_built:
584
+ self._empties.discard(pos)
585
+ agent.pos = pos
530
586
  else:
531
587
  raise Exception("Cell not empty")
532
588
 
589
+ def remove_agent(self, agent: Agent) -> None:
590
+ """Remove the agent from the grid and set its pos attribute to None."""
591
+ if (pos := agent.pos) is None:
592
+ return
593
+ x, y = pos
594
+ self._grid[x][y] = self.default_val()
595
+ if self._empties_built:
596
+ self._empties.add(pos)
597
+ agent.pos = None
533
598
 
534
- class MultiGrid(Grid):
535
- """Grid where each cell can contain more than one object.
536
599
 
537
- Grid cells are indexed by [x][y], where [0][0] is assumed to be at
538
- bottom-left and [width-1][height-1] is the top-right. If a grid is
539
- toroidal, the top and bottom, and left and right, edges wrap to each other.
600
+ class MultiGrid(_Grid):
601
+ """Rectangular grid where each cell can contain more than one agent.
540
602
 
541
- Each grid cell holds a set object.
603
+ Grid cells are indexed by [x, y], where [0, 0] is assumed to be at
604
+ bottom-left and [width-1, height-1] is the top-right. If a grid is
605
+ toroidal, the top and bottom, and left and right, edges wrap to each other.
542
606
 
543
607
  Properties:
544
608
  width, height: The grid's width and height.
545
-
546
609
  torus: Boolean which determines whether to treat the grid as a torus.
547
-
548
- grid: Internal list-of-lists which holds the grid cells themselves.
549
-
550
- Methods:
551
- get_neighbors: Returns the objects surrounding a given cell.
552
610
  """
553
611
 
554
612
  grid: list[list[MultiGridContent]]
@@ -558,43 +616,45 @@ class MultiGrid(Grid):
558
616
  """Default value for new cell elements."""
559
617
  return []
560
618
 
561
- def _place_agent(self, agent: Agent, pos: Coordinate) -> None:
562
- """Place the agent at the correct location."""
619
+ def place_agent(self, agent: Agent, pos: Coordinate) -> None:
620
+ """Place the agent at the specified location, and set its pos variable."""
563
621
  x, y = pos
564
- if agent not in self.grid[x][y]:
565
- self.grid[x][y].append(agent)
566
- self.empties.discard(pos)
622
+ if agent.pos is None or agent not in self._grid[x][y]:
623
+ self._grid[x][y].append(agent)
624
+ agent.pos = pos
625
+ if self._empties_built:
626
+ self._empties.discard(pos)
567
627
 
568
628
  def remove_agent(self, agent: Agent) -> None:
569
629
  """Remove the agent from the given location and set its pos attribute to None."""
570
630
  pos = agent.pos
571
631
  x, y = pos
572
- self.grid[x][y].remove(agent)
573
- if self.is_cell_empty(pos):
574
- self.empties.add(pos)
632
+ self._grid[x][y].remove(agent)
633
+ if self._empties_built and self.is_cell_empty(pos):
634
+ self._empties.add(pos)
575
635
  agent.pos = None
576
636
 
577
637
  @accept_tuple_argument
578
638
  def iter_cell_list_contents(
579
639
  self, cell_list: Iterable[Coordinate]
580
- ) -> Iterator[MultiGridContent]:
581
- """Returns an iterator of the contents of the
582
- cells identified in cell_list.
640
+ ) -> Iterator[Agent]:
641
+ """Returns an iterator of the agents contained in the cells identified
642
+ in `cell_list`; cells with empty content are excluded.
583
643
 
584
644
  Args:
585
645
  cell_list: Array-like of (x, y) tuples, or single tuple.
586
646
 
587
647
  Returns:
588
- A iterator of the contents of the cells identified in cell_list
589
-
648
+ An iterator of the agents contained in the cells identified in `cell_list`.
590
649
  """
591
650
  return itertools.chain.from_iterable(
592
- self[x][y] for x, y in cell_list if not self.is_cell_empty((x, y))
651
+ self._grid[x][y]
652
+ for x, y in itertools.filterfalse(self.is_cell_empty, cell_list)
593
653
  )
594
654
 
595
655
 
596
- class HexGrid(Grid):
597
- """Hexagonal Grid: Extends Grid to handle hexagonal neighbors.
656
+ class HexGrid(SingleGrid):
657
+ """Hexagonal Grid: Extends SingleGrid to handle hexagonal neighbors.
598
658
 
599
659
  Functions according to odd-q rules.
600
660
  See http://www.redblobgames.com/grids/hexagons/#coordinates for more.
@@ -611,11 +671,20 @@ class HexGrid(Grid):
611
671
  in the neighborhood of a certain point.
612
672
  """
613
673
 
614
- def iter_neighborhood(
674
+ def torus_adj_2d(self, pos: Coordinate) -> Coordinate:
675
+ return pos[0] % self.width, pos[1] % self.height
676
+
677
+ def get_neighborhood(
615
678
  self, pos: Coordinate, include_center: bool = False, radius: int = 1
616
- ) -> Iterator[Coordinate]:
617
- """Return an iterator over cell coordinates that are in the
618
- neighborhood of a certain point.
679
+ ) -> list[Coordinate]:
680
+ """Return a list of coordinates that are in the
681
+ neighborhood of a certain point. To calculate the neighborhood
682
+ for a HexGrid the parity of the x coordinate of the point is
683
+ important, the neighborhood can be sketched as:
684
+
685
+ Always: (0,-), (0,+)
686
+ When x is even: (-,+), (-,0), (+,+), (+,0)
687
+ When x is odd: (-,0), (-,-), (+,0), (+,-)
619
688
 
620
689
  Args:
621
690
  pos: Coordinate tuple for the neighborhood to get.
@@ -629,49 +698,69 @@ class HexGrid(Grid):
629
698
  equals at most 9 (8) if Moore, 5 (4) if Von Neumann (if not
630
699
  including the center).
631
700
  """
701
+ cache_key = (pos, include_center, radius)
702
+ neighborhood = self._neighborhood_cache.get(cache_key, None)
632
703
 
633
- def torus_adj_2d(pos: Coordinate) -> Coordinate:
634
- return (pos[0] % self.width, pos[1] % self.height)
704
+ if neighborhood is not None:
705
+ return neighborhood
635
706
 
707
+ queue = collections.deque()
708
+ queue.append(pos)
636
709
  coordinates = set()
637
710
 
638
- def find_neighbors(pos: Coordinate, radius: int) -> None:
639
- x, y = pos
640
-
641
- """
642
- Both: (0,-), (0,+)
643
-
644
- Even: (-,+), (-,0), (+,+), (+,0)
645
- Odd: (-,0), (-,-), (+,0), (+,-)
646
- """
647
- adjacent = [(x, y - 1), (x, y + 1)]
648
-
649
- if include_center:
650
- adjacent.append(pos)
651
-
652
- if x % 2 == 0:
653
- adjacent += [(x - 1, y + 1), (x - 1, y), (x + 1, y + 1), (x + 1, y)]
654
- else:
655
- adjacent += [(x - 1, y), (x - 1, y - 1), (x + 1, y), (x + 1, y - 1)]
656
-
657
- if self.torus is False:
658
- adjacent = list(
659
- filter(lambda coords: not self.out_of_bounds(coords), adjacent)
660
- )
661
- else:
662
- adjacent = [torus_adj_2d(coord) for coord in adjacent]
711
+ while radius > 0:
712
+ level_size = len(queue)
713
+ radius -= 1
714
+
715
+ for _i in range(level_size):
716
+ x, y = queue.pop()
717
+
718
+ if x % 2 == 0:
719
+ adjacent = [
720
+ (x, y - 1),
721
+ (x, y + 1),
722
+ (x - 1, y + 1),
723
+ (x - 1, y),
724
+ (x + 1, y + 1),
725
+ (x + 1, y),
726
+ ]
727
+ else:
728
+ adjacent = [
729
+ (x, y - 1),
730
+ (x, y + 1),
731
+ (x - 1, y),
732
+ (x - 1, y - 1),
733
+ (x + 1, y),
734
+ (x + 1, y - 1),
735
+ ]
736
+
737
+ if self.torus:
738
+ adjacent = [
739
+ coord
740
+ for coord in map(self.torus_adj_2d, adjacent)
741
+ if coord not in coordinates
742
+ ]
743
+ else:
744
+ adjacent = [
745
+ coord
746
+ for coord in adjacent
747
+ if not self.out_of_bounds(coord) and coord not in coordinates
748
+ ]
749
+
750
+ coordinates.update(adjacent)
751
+
752
+ if radius > 0:
753
+ queue.extendleft(adjacent)
663
754
 
664
- coordinates.update(adjacent)
665
-
666
- if radius > 1:
667
- [find_neighbors(coords, radius - 1) for coords in adjacent]
668
-
669
- find_neighbors(pos, radius)
755
+ if include_center:
756
+ coordinates.add(pos)
757
+ else:
758
+ coordinates.discard(pos)
670
759
 
671
- if not include_center and pos in coordinates:
672
- coordinates.remove(pos)
760
+ neighborhood = sorted(coordinates)
761
+ self._neighborhood_cache[cache_key] = neighborhood
673
762
 
674
- yield from coordinates
763
+ return neighborhood
675
764
 
676
765
  def neighbor_iter(self, pos: Coordinate) -> Iterator[Agent]:
677
766
  """Iterate over position neighbors.
@@ -686,11 +775,11 @@ class HexGrid(Grid):
686
775
  )
687
776
  return self.iter_neighbors(pos)
688
777
 
689
- def get_neighborhood(
778
+ def iter_neighborhood(
690
779
  self, pos: Coordinate, include_center: bool = False, radius: int = 1
691
- ) -> list[Coordinate]:
692
- """Return a list of cells that are in the neighborhood of a
693
- certain point.
780
+ ) -> Iterator[Coordinate]:
781
+ """Return an iterator over cell coordinates that are in the
782
+ neighborhood of a certain point.
694
783
 
695
784
  Args:
696
785
  pos: Coordinate tuple for the neighborhood to get.
@@ -699,10 +788,9 @@ class HexGrid(Grid):
699
788
  radius: radius, in cells, of neighborhood to get.
700
789
 
701
790
  Returns:
702
- A list of coordinate tuples representing the neighborhood;
703
- With radius 1
791
+ An iterator of coordinate tuples representing the neighborhood.
704
792
  """
705
- return list(self.iter_neighborhood(pos, include_center, radius))
793
+ yield from self.get_neighborhood(pos, include_center, radius)
706
794
 
707
795
  def iter_neighbors(
708
796
  self, pos: Coordinate, include_center: bool = False, radius: int = 1
@@ -719,7 +807,7 @@ class HexGrid(Grid):
719
807
  Returns:
720
808
  An iterator of non-None objects in the given neighborhood
721
809
  """
722
- neighborhood = self.iter_neighborhood(pos, include_center, radius)
810
+ neighborhood = self.get_neighborhood(pos, include_center, radius)
723
811
  return self.iter_cell_list_contents(neighborhood)
724
812
 
725
813
  def get_neighbors(
@@ -743,16 +831,14 @@ class HexGrid(Grid):
743
831
  class ContinuousSpace:
744
832
  """Continuous space where each agent can have an arbitrary position.
745
833
 
746
- Assumes that all agents are point objects, and have a pos property storing
747
- their position as an (x, y) tuple.
834
+ Assumes that all agents have a pos property storing their position as
835
+ an (x, y) tuple.
748
836
 
749
- This class uses a numpy array internally to store agent objects, to speed
837
+ This class uses a numpy array internally to store agents in order to speed
750
838
  up neighborhood lookups. This array is calculated on the first neighborhood
751
- lookup, and is reused (and updated) until agents are added or removed.
839
+ lookup, and is updated if agents are added or removed.
752
840
  """
753
841
 
754
- _grid = None
755
-
756
842
  def __init__(
757
843
  self,
758
844
  x_max: float,
@@ -785,18 +871,16 @@ class ContinuousSpace:
785
871
  self._agent_to_index: dict[Agent, int | None] = {}
786
872
 
787
873
  def _build_agent_cache(self):
788
- """Cache Agent positions to speed up neighbors calculations."""
874
+ """Cache agents positions to speed up neighbors calculations."""
789
875
  self._index_to_agent = {}
790
- agents = self._agent_to_index.keys()
791
- for idx, agent in enumerate(agents):
876
+ for idx, agent in enumerate(self._agent_to_index):
792
877
  self._agent_to_index[agent] = idx
793
878
  self._index_to_agent[idx] = agent
794
- self._agent_points = np.array(
795
- [self._index_to_agent[idx].pos for idx in range(len(agents))]
796
- )
879
+ # Since dicts are ordered by insertion, we can iterate through agents keys
880
+ self._agent_points = np.array([agent.pos for agent in self._agent_to_index])
797
881
 
798
882
  def _invalidate_agent_cache(self):
799
- """Clear cached data of Agents and positions in the space."""
883
+ """Clear cached data of agents and positions in the space."""
800
884
  self._agent_points = None
801
885
  self._index_to_agent = {}
802
886
 
@@ -826,18 +910,17 @@ class ContinuousSpace:
826
910
  # instead of invalidating the full cache,
827
911
  # apply the move to the cached values
828
912
  idx = self._agent_to_index[agent]
829
- self._agent_points[idx, 0] = pos[0]
830
- self._agent_points[idx, 1] = pos[1]
913
+ self._agent_points[idx] = pos
831
914
 
832
915
  def remove_agent(self, agent: Agent) -> None:
833
- """Remove an agent from the simulation.
916
+ """Remove an agent from the space.
834
917
 
835
918
  Args:
836
919
  agent: The agent object to remove
837
920
  """
838
921
  if agent not in self._agent_to_index:
839
922
  raise Exception("Agent does not exist in the space")
840
- self._agent_to_index.pop(agent)
923
+ del self._agent_to_index[agent]
841
924
 
842
925
  self._invalidate_agent_cache()
843
926
  agent.pos = None
@@ -845,7 +928,7 @@ class ContinuousSpace:
845
928
  def get_neighbors(
846
929
  self, pos: FloatCoordinate, radius: float, include_center: bool = True
847
930
  ) -> list[Agent]:
848
- """Get all objects within a certain radius.
931
+ """Get all agents within a certain radius.
849
932
 
850
933
  Args:
851
934
  pos: (x,y) coordinate tuple to center the search at.
@@ -872,7 +955,9 @@ class ContinuousSpace:
872
955
  def get_heading(
873
956
  self, pos_1: FloatCoordinate, pos_2: FloatCoordinate
874
957
  ) -> FloatCoordinate:
875
- """Get the heading angle between two points, accounting for toroidal space.
958
+ """Get the heading vector between two points, accounting for toroidal space.
959
+ It is possible to calculate the heading angle by applying the atan2 function to the
960
+ result.
876
961
 
877
962
  Args:
878
963
  pos_1, pos_2: Coordinate tuples for both points.
@@ -934,37 +1019,47 @@ class ContinuousSpace:
934
1019
  class NetworkGrid:
935
1020
  """Network Grid where each node contains zero or more agents."""
936
1021
 
937
- def __init__(self, G: Any) -> None:
938
- self.G = G
1022
+ def __init__(self, g: Any) -> None:
1023
+ """Create a new network.
1024
+
1025
+ Args:
1026
+ G: a NetworkX graph instance.
1027
+ """
1028
+ self.G = g
939
1029
  for node_id in self.G.nodes:
940
- G.nodes[node_id]["agent"] = list()
1030
+ g.nodes[node_id]["agent"] = self.default_val()
941
1031
 
942
- def place_agent(self, agent: Agent, node_id: int) -> None:
943
- """Place a agent in a node."""
1032
+ @staticmethod
1033
+ def default_val() -> list:
1034
+ """Default value for a new node."""
1035
+ return []
944
1036
 
945
- self._place_agent(agent, node_id)
1037
+ def place_agent(self, agent: Agent, node_id: int) -> None:
1038
+ """Place an agent in a node."""
1039
+ self.G.nodes[node_id]["agent"].append(agent)
946
1040
  agent.pos = node_id
947
1041
 
948
- def get_neighbors(self, node_id: int, include_center: bool = False) -> list[int]:
949
- """Get all adjacent nodes"""
950
-
951
- neighbors = list(self.G.neighbors(node_id))
952
- if include_center:
953
- neighbors.append(node_id)
954
-
1042
+ def get_neighbors(
1043
+ self, node_id: int, include_center: bool = False, radius: int = 1
1044
+ ) -> list[int]:
1045
+ """Get all adjacent nodes within a certain radius"""
1046
+ if radius == 1:
1047
+ neighbors = list(self.G.neighbors(node_id))
1048
+ if include_center:
1049
+ neighbors.append(node_id)
1050
+ else:
1051
+ neighbors_with_distance = nx.single_source_shortest_path_length(
1052
+ self.G, node_id, radius
1053
+ )
1054
+ if not include_center:
1055
+ del neighbors_with_distance[node_id]
1056
+ neighbors = sorted(neighbors_with_distance.keys())
955
1057
  return neighbors
956
1058
 
957
1059
  def move_agent(self, agent: Agent, node_id: int) -> None:
958
1060
  """Move an agent from its current node to a new node."""
959
-
960
1061
  self.remove_agent(agent)
961
- self._place_agent(agent, node_id)
962
- agent.pos = node_id
963
-
964
- def _place_agent(self, agent: Agent, node_id: int) -> None:
965
- """Place the agent at the correct node."""
966
-
967
- self.G.nodes[node_id]["agent"].append(agent)
1062
+ self.place_agent(agent, node_id)
968
1063
 
969
1064
  def remove_agent(self, agent: Agent) -> None:
970
1065
  """Remove the agent from the network and set its pos attribute to None."""
@@ -974,25 +1069,23 @@ class NetworkGrid:
974
1069
 
975
1070
  def is_cell_empty(self, node_id: int) -> bool:
976
1071
  """Returns a bool of the contents of a cell."""
977
- return not self.G.nodes[node_id]["agent"]
1072
+ return self.G.nodes[node_id]["agent"] == self.default_val()
978
1073
 
979
- def get_cell_list_contents(self, cell_list: list[int]) -> list[GridContent]:
980
- """Returns the contents of a list of cells ((x,y) tuples)
981
- Note: this method returns a list of `Agent`'s; `None` contents are excluded.
1074
+ def get_cell_list_contents(self, cell_list: list[int]) -> list[Agent]:
1075
+ """Returns a list of the agents contained in the nodes identified
1076
+ in `cell_list`; nodes with empty content are excluded.
982
1077
  """
983
1078
  return list(self.iter_cell_list_contents(cell_list))
984
1079
 
985
- def get_all_cell_contents(self) -> list[GridContent]:
986
- """Returns a list of the contents of the cells
987
- identified in cell_list."""
988
- return list(self.iter_cell_list_contents(self.G))
1080
+ def get_all_cell_contents(self) -> list[Agent]:
1081
+ """Returns a list of all the agents in the network."""
1082
+ return self.get_cell_list_contents(self.G)
989
1083
 
990
- def iter_cell_list_contents(self, cell_list: list[int]) -> list[GridContent]:
991
- """Returns an iterator of the contents of the cells
992
- identified in cell_list."""
993
- list_of_lists = [
1084
+ def iter_cell_list_contents(self, cell_list: list[int]) -> Iterator[Agent]:
1085
+ """Returns an iterator of the agents contained in the nodes identified
1086
+ in `cell_list`; nodes with empty content are excluded.
1087
+ """
1088
+ return itertools.chain.from_iterable(
994
1089
  self.G.nodes[node_id]["agent"]
995
- for node_id in cell_list
996
- if not self.is_cell_empty(node_id)
997
- ]
998
- return [item for sublist in list_of_lists for item in sublist]
1090
+ for node_id in itertools.filterfalse(self.is_cell_empty, cell_list)
1091
+ )