Mesa 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic. Click here for more details.

Files changed (41) hide show
  1. {Mesa-1.1.1.dist-info → Mesa-1.2.0.dist-info}/LICENSE +1 -1
  2. {Mesa-1.1.1.dist-info → Mesa-1.2.0.dist-info}/METADATA +14 -12
  3. {Mesa-1.1.1.dist-info → Mesa-1.2.0.dist-info}/RECORD +41 -41
  4. {Mesa-1.1.1.dist-info → Mesa-1.2.0.dist-info}/WHEEL +1 -1
  5. mesa/__init__.py +8 -9
  6. mesa/agent.py +2 -3
  7. mesa/batchrunner.py +16 -23
  8. mesa/datacollection.py +15 -28
  9. mesa/main.py +4 -4
  10. mesa/model.py +2 -6
  11. mesa/space.py +298 -225
  12. mesa/time.py +25 -24
  13. mesa/visualization/ModularVisualization.py +5 -8
  14. mesa/visualization/TextVisualization.py +0 -3
  15. mesa/visualization/UserParam.py +8 -11
  16. mesa/visualization/__init__.py +0 -1
  17. mesa/visualization/modules/BarChartVisualization.py +7 -8
  18. mesa/visualization/modules/CanvasGridVisualization.py +1 -3
  19. mesa/visualization/modules/ChartVisualization.py +2 -3
  20. mesa/visualization/modules/HexGridVisualization.py +1 -3
  21. mesa/visualization/modules/NetworkVisualization.py +1 -2
  22. mesa/visualization/modules/PieChartVisualization.py +2 -6
  23. mesa/visualization/templates/js/GridDraw.js +5 -9
  24. mesa/visualization/templates/js/HexDraw.js +5 -9
  25. mesa/visualization/templates/js/InteractionHandler.js +0 -2
  26. tests/test_batchrunner.py +3 -4
  27. tests/test_batchrunnerMP.py +4 -4
  28. tests/test_datacollector.py +2 -2
  29. tests/test_examples.py +8 -5
  30. tests/test_grid.py +89 -36
  31. tests/test_import_namespace.py +0 -1
  32. tests/test_lifespan.py +4 -3
  33. tests/test_main.py +5 -1
  34. tests/test_scaffold.py +2 -1
  35. tests/test_space.py +13 -20
  36. tests/test_time.py +44 -14
  37. tests/test_tornado.py +4 -2
  38. tests/test_usersettableparam.py +4 -3
  39. tests/test_visualization.py +4 -8
  40. {Mesa-1.1.1.dist-info → Mesa-1.2.0.dist-info}/entry_points.txt +0 -0
  41. {Mesa-1.1.1.dist-info → Mesa-1.2.0.dist-info}/top_level.txt +0 -0
mesa/space.py CHANGED
@@ -4,10 +4,13 @@ Mesa Space Module
4
4
 
5
5
  Objects used to add a spatial component to a model.
6
6
 
7
- Grid: base grid, a simple list-of-lists.
8
- SingleGrid: grid which strictly enforces one object per cell.
9
- MultiGrid: extension to Grid where each cell is a set of objects.
10
-
7
+ Grid: base grid, which creates a rectangular grid.
8
+ SingleGrid: extension to Grid which strictly enforces one agent per cell.
9
+ MultiGrid: extension to Grid where each cell can contain a set of agents.
10
+ HexGrid: extension to Grid to handle hexagonal neighbors.
11
+ ContinuousSpace: a two-dimensional space where each agent has an arbitrary
12
+ position of `float`'s.
13
+ NetworkGrid: a network where each node contains zero or more agents.
11
14
  """
12
15
  # Instruction for PyLint to suppress variable name errors, since we have a
13
16
  # good reason to use one-character variable names for x and y.
@@ -17,19 +20,16 @@ MultiGrid: extension to Grid where each cell is a set of objects.
17
20
  # Remove this __future__ import once the oldest supported Python is 3.10
18
21
  from __future__ import annotations
19
22
 
20
- import itertools
21
23
  import collections
24
+ import itertools
22
25
  import math
23
- from warnings import warn
24
-
25
- import numpy as np
26
-
26
+ from numbers import Real
27
27
  from typing import (
28
28
  Any,
29
29
  Callable,
30
- List,
31
30
  Iterable,
32
31
  Iterator,
32
+ List,
33
33
  Sequence,
34
34
  Tuple,
35
35
  TypeVar,
@@ -37,11 +37,17 @@ from typing import (
37
37
  cast,
38
38
  overload,
39
39
  )
40
+ from warnings import warn
41
+
42
+ import networkx as nx
43
+ import numpy as np
44
+ import numpy.typing as npt
40
45
 
41
46
  # For Mypy
42
47
  from .agent import Agent
43
- from numbers import Real
44
- import numpy.typing as npt
48
+
49
+ # for better performance, we calculate the tuple to use in the is_integer function
50
+ _types_integer = (int, np.integer)
45
51
 
46
52
  Coordinate = Tuple[int, int]
47
53
  # used in ContinuousSpace
@@ -56,41 +62,35 @@ MultiGridContent = List[Agent]
56
62
  F = TypeVar("F", bound=Callable[..., Any])
57
63
 
58
64
 
59
- def clamp(x: float, lowest: float, highest: float) -> float:
60
- # much faster than np.clip for a scalar x.
61
- return lowest if x <= lowest else (highest if x >= highest else x)
62
-
63
-
64
65
  def accept_tuple_argument(wrapped_function: F) -> F:
65
66
  """Decorator to allow grid methods that take a list of (x, y) coord tuples
66
67
  to also handle a single position, by automatically wrapping tuple in
67
68
  single-item list rather than forcing user to do it."""
68
69
 
69
- def wrapper(*args: Any) -> Any:
70
- if isinstance(args[1], tuple) and len(args[1]) == 2:
71
- return wrapped_function(args[0], [args[1]])
70
+ def wrapper(grid_instance, positions) -> Any:
71
+ if isinstance(positions, tuple) and len(positions) == 2:
72
+ return wrapped_function(grid_instance, [positions])
72
73
  else:
73
- return wrapped_function(*args)
74
+ return wrapped_function(grid_instance, positions)
74
75
 
75
76
  return cast(F, wrapper)
76
77
 
77
78
 
78
79
  def is_integer(x: Real) -> bool:
79
80
  # Check if x is either a CPython integer or Numpy integer.
80
- return isinstance(x, (int, np.integer))
81
+ return isinstance(x, _types_integer)
81
82
 
82
83
 
83
- class Grid:
84
- """Base class for a square grid.
84
+ class _Grid:
85
+ """Base class for a rectangular grid.
85
86
 
86
- Grid cells are indexed by [x][y], where [0][0] is assumed to be the
87
- bottom-left and [width-1][height-1] is the top-right. If a grid is
87
+ Grid cells are indexed by [x, y], where [0, 0] is assumed to be the
88
+ bottom-left and [width-1, height-1] is the top-right. If a grid is
88
89
  toroidal, the top and bottom, and left and right, edges wrap to each other
89
90
 
90
91
  Properties:
91
92
  width, height: The grid's width and height.
92
93
  torus: Boolean which determines whether to treat the grid as a torus.
93
- grid: Internal list-of-lists which holds the grid cells themselves.
94
94
  """
95
95
 
96
96
  def __init__(self, width: int, height: int, torus: bool) -> None:
@@ -103,23 +103,42 @@ class Grid:
103
103
  self.height = height
104
104
  self.width = width
105
105
  self.torus = torus
106
+ self.num_cells = height * width
106
107
 
107
- self.grid: list[list[GridContent]]
108
- self.grid = [
108
+ # Internal list-of-lists which holds the grid cells themselves
109
+ self._grid: list[list[GridContent]]
110
+ self._grid = [
109
111
  [self.default_val() for _ in range(self.height)] for _ in range(self.width)
110
112
  ]
111
113
 
112
- # Add all cells to the empties list.
113
- self.empties = set(itertools.product(range(self.width), range(self.height)))
114
+ # Flag to check if the empties set has been created. Better than initializing
115
+ # _empties as set() because in this case it would become impossible to discern
116
+ # if the set hasn't still being built or if it has become empty after creation.
117
+ self._empties_built = False
114
118
 
115
119
  # Neighborhood Cache
116
- self._neighborhood_cache: dict[Any, list[Coordinate]] = dict()
120
+ self._neighborhood_cache: dict[Any, list[Coordinate]] = {}
117
121
 
118
122
  @staticmethod
119
123
  def default_val() -> None:
120
124
  """Default value for new cell elements."""
121
125
  return None
122
126
 
127
+ @property
128
+ def empties(self) -> set:
129
+ if not self._empties_built:
130
+ self.build_empties()
131
+ return self._empties
132
+
133
+ def build_empties(self) -> None:
134
+ self._empties = set(
135
+ filter(
136
+ self.is_cell_empty,
137
+ itertools.product(range(self.width), range(self.height)),
138
+ )
139
+ )
140
+ self._empties_built = True
141
+
123
142
  @overload
124
143
  def __getitem__(self, index: int) -> list[GridContent]:
125
144
  ...
@@ -142,55 +161,45 @@ class Grid:
142
161
 
143
162
  if isinstance(index, int):
144
163
  # grid[x]
145
- return self.grid[index]
164
+ return self._grid[index]
146
165
  elif isinstance(index[0], tuple):
147
- # grid[(x1, y1), (x2, y2)]
166
+ # grid[(x1, y1), (x2, y2), ...]
148
167
  index = cast(Sequence[Coordinate], index)
149
-
150
- cells = []
151
- for pos in index:
152
- x1, y1 = self.torus_adj(pos)
153
- cells.append(self.grid[x1][y1])
154
- return cells
168
+ return [self._grid[x][y] for x, y in map(self.torus_adj, index)]
155
169
 
156
170
  x, y = index
171
+ x_int, y_int = is_integer(x), is_integer(y)
157
172
 
158
- if is_integer(x) and is_integer(y):
173
+ if x_int and y_int:
159
174
  # grid[x, y]
160
175
  index = cast(Coordinate, index)
161
176
  x, y = self.torus_adj(index)
162
- return self.grid[x][y]
163
-
164
- if is_integer(x):
177
+ return self._grid[x][y]
178
+ elif x_int:
165
179
  # grid[x, :]
166
180
  x, _ = self.torus_adj((x, 0))
167
- x = slice(x, x + 1)
168
-
169
- if is_integer(y):
181
+ y = cast(slice, y)
182
+ return self._grid[x][y]
183
+ elif y_int:
170
184
  # grid[:, y]
171
185
  _, y = self.torus_adj((0, y))
172
- y = slice(y, y + 1)
173
-
174
- # grid[:, :]
175
- x, y = (cast(slice, x), cast(slice, y))
176
- cells = []
177
- for rows in self.grid[x]:
178
- for cell in rows[y]:
179
- cells.append(cell)
180
- return cells
181
-
182
- raise IndexError
186
+ x = cast(slice, x)
187
+ return [rows[y] for rows in self._grid[x]]
188
+ else:
189
+ # grid[:, :]
190
+ x, y = (cast(slice, x), cast(slice, y))
191
+ return [cell for rows in self._grid[x] for cell in rows[y]]
183
192
 
184
193
  def __iter__(self) -> Iterator[GridContent]:
185
194
  """Create an iterator that chains the rows of the grid together
186
195
  as if it is one list:"""
187
- return itertools.chain(*self.grid)
196
+ return itertools.chain(*self._grid)
188
197
 
189
198
  def coord_iter(self) -> Iterator[tuple[GridContent, int, int]]:
190
199
  """An iterator that returns coordinates as well as cell contents."""
191
200
  for row in range(self.width):
192
201
  for col in range(self.height):
193
- yield self.grid[row][col], row, col # agent, x, y
202
+ yield self._grid[row][col], row, col # agent, x, y
194
203
 
195
204
  def neighbor_iter(self, pos: Coordinate, moore: bool = True) -> Iterator[Agent]:
196
205
  """Iterate over position neighbors.
@@ -266,29 +275,50 @@ class Grid:
266
275
  if neighborhood is not None:
267
276
  return neighborhood
268
277
 
269
- coordinates: set[Coordinate] = set()
278
+ # We use a list instead of a dict for the neighborhood because it would
279
+ # be easier to port the code to Cython or Numba (for performance
280
+ # purpose), with minimal changes. To better understand how the
281
+ # algorithm was conceived, look at
282
+ # https://github.com/projectmesa/mesa/pull/1476#issuecomment-1306220403
283
+ # and the discussion in that PR in general.
284
+ neighborhood = []
270
285
 
271
286
  x, y = pos
272
- for dy in range(-radius, radius + 1):
273
- for dx in range(-radius, radius + 1):
274
- # Skip coordinates that are outside manhattan distance
275
- if not moore and abs(dx) + abs(dy) > radius:
276
- continue
287
+ if self.torus:
288
+ x_max_radius, y_max_radius = self.width // 2, self.height // 2
289
+ x_radius, y_radius = min(radius, x_max_radius), min(radius, y_max_radius)
290
+
291
+ # For each dimension, in the edge case where the radius is as big as
292
+ # possible and the dimension is even, we need to shrink by one the range
293
+ # of values, to avoid duplicates in neighborhood. For example, if
294
+ # the width is 4, while x, x_radius, and x_max_radius are 2, then
295
+ # (x + dx) has a value from 0 to 4 (inclusive), but this means that
296
+ # the 0 position is repeated since 0 % 4 and 4 % 4 are both 0.
297
+ xdim_even, ydim_even = (self.width + 1) % 2, (self.height + 1) % 2
298
+ kx = int(x_radius == x_max_radius and xdim_even)
299
+ ky = int(y_radius == y_max_radius and ydim_even)
300
+
301
+ for dx in range(-x_radius, x_radius + 1 - kx):
302
+ for dy in range(-y_radius, y_radius + 1 - ky):
303
+ if not moore and abs(dx) + abs(dy) > radius:
304
+ continue
277
305
 
278
- coord = (x + dx, y + dy)
306
+ nx, ny = (x + dx) % self.width, (y + dy) % self.height
307
+ neighborhood.append((nx, ny))
308
+ else:
309
+ x_range = range(max(0, x - radius), min(self.width, x + radius + 1))
310
+ y_range = range(max(0, y - radius), min(self.height, y + radius + 1))
279
311
 
280
- if self.out_of_bounds(coord):
281
- # Skip if not a torus and new coords out of bounds.
282
- if not self.torus:
312
+ for nx in x_range:
313
+ for ny in y_range:
314
+ if not moore and abs(nx - x) + abs(ny - y) > radius:
283
315
  continue
284
- coord = self.torus_adj(coord)
285
316
 
286
- coordinates.add(coord)
317
+ neighborhood.append((nx, ny))
287
318
 
288
- if not include_center:
289
- coordinates.discard(pos)
319
+ if not include_center and neighborhood:
320
+ neighborhood.remove(pos)
290
321
 
291
- neighborhood = sorted(coordinates)
292
322
  self._neighborhood_cache[cache_key] = neighborhood
293
323
 
294
324
  return neighborhood
@@ -367,34 +397,40 @@ class Grid:
367
397
  def iter_cell_list_contents(
368
398
  self, cell_list: Iterable[Coordinate]
369
399
  ) -> Iterator[Agent]:
370
- """Returns an iterator of the contents of the cells
371
- identified in cell_list.
400
+ """Returns an iterator of the agents contained in the cells identified
401
+ in `cell_list`; cells with empty content are excluded.
372
402
 
373
403
  Args:
374
404
  cell_list: Array-like of (x, y) tuples, or single tuple.
375
405
 
376
406
  Returns:
377
- An iterator of the contents of the cells identified in cell_list
407
+ An iterator of the agents contained in the cells identified in `cell_list`.
378
408
  """
379
- # Note: filter(None, iterator) filters away an element of iterator that
380
- # is falsy. Hence, iter_cell_list_contents returns only non-empty
381
- # contents.
382
- return filter(None, (self.grid[x][y] for x, y in cell_list))
409
+ # iter_cell_list_contents returns only non-empty contents.
410
+ return (
411
+ self._grid[x][y]
412
+ for x, y in itertools.filterfalse(self.is_cell_empty, cell_list)
413
+ )
383
414
 
384
415
  @accept_tuple_argument
385
416
  def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent]:
386
- """Returns a list of the contents of the cells
387
- identified in cell_list.
388
- Note: this method returns a list of `Agent`'s; `None` contents are excluded.
417
+ """Returns an iterator of the agents contained in the cells identified
418
+ in `cell_list`; cells with empty content are excluded.
389
419
 
390
420
  Args:
391
421
  cell_list: Array-like of (x, y) tuples, or single tuple.
392
422
 
393
423
  Returns:
394
- A list of the contents of the cells identified in cell_list
424
+ A list of the agents contained in the cells identified in `cell_list`.
395
425
  """
396
426
  return list(self.iter_cell_list_contents(cell_list))
397
427
 
428
+ def place_agent(self, agent: Agent, pos: Coordinate) -> None:
429
+ ...
430
+
431
+ def remove_agent(self, agent: Agent) -> None:
432
+ ...
433
+
398
434
  def move_agent(self, agent: Agent, pos: Coordinate) -> None:
399
435
  """Move an agent from its current position to a new position.
400
436
 
@@ -407,57 +443,53 @@ class Grid:
407
443
  self.remove_agent(agent)
408
444
  self.place_agent(agent, pos)
409
445
 
410
- def place_agent(self, agent: Agent, pos: Coordinate) -> None:
411
- """Place the agent at the specified location, and set its pos variable."""
412
- x, y = pos
413
- self.grid[x][y] = agent
414
- self.empties.discard(pos)
415
- agent.pos = pos
416
-
417
- def remove_agent(self, agent: Agent) -> None:
418
- """Remove the agent from the grid and set its pos attribute to None."""
419
- if (pos := agent.pos) is None:
446
+ def swap_pos(self, agent_a: Agent, agent_b: Agent) -> None:
447
+ """Swap agents positions"""
448
+ agents_no_pos = []
449
+ if (pos_a := agent_a.pos) is None:
450
+ agents_no_pos.append(agent_a)
451
+ if (pos_b := agent_b.pos) is None:
452
+ agents_no_pos.append(agent_b)
453
+ if agents_no_pos:
454
+ agents_no_pos = [f"<Agent id: {a.unique_id}>" for a in agents_no_pos]
455
+ raise Exception(f"{', '.join(agents_no_pos)} - not on the grid")
456
+
457
+ if pos_a == pos_b:
420
458
  return
421
- x, y = pos
422
- self.grid[x][y] = self.default_val()
423
- self.empties.add(pos)
424
- agent.pos = None
459
+
460
+ self.remove_agent(agent_a)
461
+ self.remove_agent(agent_b)
462
+
463
+ self.place_agent(agent_a, pos_b)
464
+ self.place_agent(agent_b, pos_a)
425
465
 
426
466
  def is_cell_empty(self, pos: Coordinate) -> bool:
427
467
  """Returns a bool of the contents of a cell."""
428
468
  x, y = pos
429
- return self.grid[x][y] == self.default_val()
469
+ return self._grid[x][y] == self.default_val()
430
470
 
431
471
  def move_to_empty(
432
472
  self, agent: Agent, cutoff: float = 0.998, num_agents: int | None = None
433
473
  ) -> None:
434
474
  """Moves agent to a random empty cell, vacating agent's old cell."""
435
- if len(self.empties) == 0:
475
+ if num_agents is not None:
476
+ warn(
477
+ (
478
+ "`num_agents` is being deprecated since it's no longer used "
479
+ "inside `move_to_empty`. It shouldn't be passed as a parameter."
480
+ ),
481
+ DeprecationWarning,
482
+ )
483
+ num_empty_cells = len(self.empties)
484
+ if num_empty_cells == 0:
436
485
  raise Exception("ERROR: No empty cells")
437
- if num_agents is None:
438
- try:
439
- num_agents = agent.model.schedule.get_agent_count()
440
- except AttributeError:
441
- raise Exception(
442
- "Your agent is not attached to a model, and so Mesa is unable\n"
443
- "to figure out the total number of agents you have created.\n"
444
- "This number is required in order to calculate the threshold\n"
445
- "for using a much faster algorithm to find an empty cell.\n"
446
- "In this case, you must specify `num_agents`."
447
- )
448
- new_pos = (0, 0) # Initialize it with a starting value.
449
- # This method is based on Agents.jl's random_empty() implementation.
450
- # See https://github.com/JuliaDynamics/Agents.jl/pull/541.
451
- # For the discussion, see
452
- # https://github.com/projectmesa/mesa/issues/1052.
453
- # This switch assumes the worst case (for this algorithm) of one
454
- # agent per position, which is not true in general but is appropriate
455
- # here.
456
- if clamp(num_agents / (self.width * self.height), 0.0, 1.0) < cutoff:
457
- # The default cutoff value provided is the break-even comparison
458
- # with the time taken in the else branching point.
459
- # The number is measured to be 0.998 in Agents.jl, but since Mesa
460
- # run under different environment, the number is different here.
486
+
487
+ # This method is based on Agents.jl's random_empty() implementation. See
488
+ # https://github.com/JuliaDynamics/Agents.jl/pull/541. For the discussion, see
489
+ # https://github.com/projectmesa/mesa/issues/1052. The default cutoff value
490
+ # provided is the break-even comparison with the time taken in the else
491
+ # branching point.
492
+ if 1 - num_empty_cells / self.num_cells < cutoff:
461
493
  while True:
462
494
  new_pos = (
463
495
  agent.random.randrange(self.width),
@@ -495,10 +527,17 @@ class Grid:
495
527
  return len(self.empties) > 0
496
528
 
497
529
 
498
- class SingleGrid(Grid):
499
- """Grid where each cell contains exactly at most one object."""
530
+ class SingleGrid(_Grid):
531
+ """Rectangular grid where each cell contains exactly at most one agent.
500
532
 
501
- empties: set[Coordinate] = set()
533
+ Grid cells are indexed by [x, y], where [0, 0] is assumed to be the
534
+ bottom-left and [width-1, height-1] is the top-right. If a grid is
535
+ toroidal, the top and bottom, and left and right, edges wrap to each other.
536
+
537
+ Properties:
538
+ width, height: The grid's width and height.
539
+ torus: Boolean which determines whether to treat the grid as a torus.
540
+ """
502
541
 
503
542
  def position_agent(
504
543
  self, agent: Agent, x: int | str = "random", y: int | str = "random"
@@ -509,39 +548,65 @@ class SingleGrid(Grid):
509
548
  If x or y are positive, they are used.
510
549
  Use 'swap_pos()' to swap agents positions.
511
550
  """
551
+ warn(
552
+ (
553
+ "`position_agent` is being deprecated; use instead "
554
+ "`place_agent` to place an agent at a specified "
555
+ "location or `move_to_empty` to place an agent "
556
+ "at a random empty cell."
557
+ ),
558
+ DeprecationWarning,
559
+ )
560
+
561
+ if not (isinstance(x, int) or x == "random"):
562
+ raise Exception(
563
+ "x must be an integer or a string 'random'."
564
+ f" Actual type: {type(x)}. Actual value: {x}."
565
+ )
566
+ if not (isinstance(y, int) or y == "random"):
567
+ raise Exception(
568
+ "y must be an integer or a string 'random'."
569
+ f" Actual type: {type(y)}. Actual value: {y}."
570
+ )
571
+
512
572
  if x == "random" or y == "random":
513
- if len(self.empties) == 0:
514
- raise Exception("ERROR: Grid full")
515
573
  self.move_to_empty(agent)
516
574
  else:
517
575
  coords = (x, y)
518
576
  self.place_agent(agent, coords)
519
577
 
520
578
  def place_agent(self, agent: Agent, pos: Coordinate) -> None:
579
+ """Place the agent at the specified location, and set its pos variable."""
521
580
  if self.is_cell_empty(pos):
522
- super().place_agent(agent, pos)
581
+ x, y = pos
582
+ self._grid[x][y] = agent
583
+ if self._empties_built:
584
+ self._empties.discard(pos)
585
+ agent.pos = pos
523
586
  else:
524
587
  raise Exception("Cell not empty")
525
588
 
589
+ def remove_agent(self, agent: Agent) -> None:
590
+ """Remove the agent from the grid and set its pos attribute to None."""
591
+ if (pos := agent.pos) is None:
592
+ return
593
+ x, y = pos
594
+ self._grid[x][y] = self.default_val()
595
+ if self._empties_built:
596
+ self._empties.add(pos)
597
+ agent.pos = None
598
+
526
599
 
527
- class MultiGrid(Grid):
528
- """Grid where each cell can contain more than one object.
600
+ class MultiGrid(_Grid):
601
+ """Rectangular grid where each cell can contain more than one agent.
529
602
 
530
- Grid cells are indexed by [x][y], where [0][0] is assumed to be at
531
- bottom-left and [width-1][height-1] is the top-right. If a grid is
603
+ Grid cells are indexed by [x, y], where [0, 0] is assumed to be at
604
+ bottom-left and [width-1, height-1] is the top-right. If a grid is
532
605
  toroidal, the top and bottom, and left and right, edges wrap to each other.
533
606
 
534
- Each grid cell holds a set object.
535
-
536
607
  Properties:
537
608
  width, height: The grid's width and height.
538
-
539
609
  torus: Boolean which determines whether to treat the grid as a torus.
540
-
541
- grid: Internal list-of-lists which holds the grid cells themselves.
542
-
543
- Methods:
544
- get_neighbors: Returns the objects surrounding a given cell.
545
610
  """
546
611
 
547
612
  grid: list[list[MultiGridContent]]
@@ -554,41 +619,42 @@ class MultiGrid(Grid):
554
619
  def place_agent(self, agent: Agent, pos: Coordinate) -> None:
555
620
  """Place the agent at the specified location, and set its pos variable."""
556
621
  x, y = pos
557
- if agent not in self.grid[x][y]:
558
- self.grid[x][y].append(agent)
559
- self.empties.discard(pos)
560
- agent.pos = pos
622
+ if agent.pos is None or agent not in self._grid[x][y]:
623
+ self._grid[x][y].append(agent)
624
+ agent.pos = pos
625
+ if self._empties_built:
626
+ self._empties.discard(pos)
561
627
 
562
628
  def remove_agent(self, agent: Agent) -> None:
563
629
  """Remove the agent from the given location and set its pos attribute to None."""
564
630
  pos = agent.pos
565
631
  x, y = pos
566
- self.grid[x][y].remove(agent)
567
- if self.is_cell_empty(pos):
568
- self.empties.add(pos)
632
+ self._grid[x][y].remove(agent)
633
+ if self._empties_built and self.is_cell_empty(pos):
634
+ self._empties.add(pos)
569
635
  agent.pos = None
570
636
 
571
637
  @accept_tuple_argument
572
638
  def iter_cell_list_contents(
573
639
  self, cell_list: Iterable[Coordinate]
574
- ) -> Iterator[MultiGridContent]:
575
- """Returns an iterator of the contents of the
576
- cells identified in cell_list.
640
+ ) -> Iterator[Agent]:
641
+ """Returns an iterator of the agents contained in the cells identified
642
+ in `cell_list`; cells with empty content are excluded.
577
643
 
578
644
  Args:
579
645
  cell_list: Array-like of (x, y) tuples, or single tuple.
580
646
 
581
647
  Returns:
582
- A iterator of the contents of the cells identified in cell_list
583
-
648
+ An iterator of the agents contained in the cells identified in `cell_list`.
584
649
  """
585
650
  return itertools.chain.from_iterable(
586
- self[x][y] for x, y in cell_list if not self.is_cell_empty((x, y))
651
+ self._grid[x][y]
652
+ for x, y in itertools.filterfalse(self.is_cell_empty, cell_list)
587
653
  )
588
654
 
589
655
 
590
- class HexGrid(Grid):
591
- """Hexagonal Grid: Extends Grid to handle hexagonal neighbors.
656
+ class HexGrid(SingleGrid):
657
+ """Hexagonal Grid: Extends SingleGrid to handle hexagonal neighbors.
592
658
 
593
659
  Functions according to odd-q rules.
594
660
  See http://www.redblobgames.com/grids/hexagons/#coordinates for more.
@@ -608,10 +674,10 @@ class HexGrid(Grid):
608
674
  def torus_adj_2d(self, pos: Coordinate) -> Coordinate:
609
675
  return pos[0] % self.width, pos[1] % self.height
610
676
 
611
- def iter_neighborhood(
677
+ def get_neighborhood(
612
678
  self, pos: Coordinate, include_center: bool = False, radius: int = 1
613
- ) -> Iterator[Coordinate]:
614
- """Return an iterator over cell coordinates that are in the
679
+ ) -> list[Coordinate]:
680
+ """Return a list of coordinates that are in the
615
681
  neighborhood of a certain point. To calculate the neighborhood
616
682
  for a HexGrid the parity of the x coordinate of the point is
617
683
  important, the neighborhood can be sketched as:
@@ -627,7 +693,7 @@ class HexGrid(Grid):
627
693
  radius: radius, in cells, of neighborhood to get.
628
694
 
629
695
  Returns:
630
- An iterator of coordinate tuples representing the neighborhood. For
696
+ A list of coordinate tuples representing the neighborhood. For
631
697
  example with radius 1, it will return list with number of elements
632
698
  equals at most 9 (8) if Moore, 5 (4) if Von Neumann (if not
633
699
  including the center).
@@ -636,20 +702,17 @@ class HexGrid(Grid):
636
702
  neighborhood = self._neighborhood_cache.get(cache_key, None)
637
703
 
638
704
  if neighborhood is not None:
639
- yield from neighborhood
640
- return
705
+ return neighborhood
641
706
 
642
707
  queue = collections.deque()
643
708
  queue.append(pos)
644
709
  coordinates = set()
645
710
 
646
711
  while radius > 0:
647
-
648
712
  level_size = len(queue)
649
713
  radius -= 1
650
714
 
651
- for i in range(level_size):
652
-
715
+ for _i in range(level_size):
653
716
  x, y = queue.pop()
654
717
 
655
718
  if x % 2 == 0:
@@ -697,7 +760,7 @@ class HexGrid(Grid):
697
760
  neighborhood = sorted(coordinates)
698
761
  self._neighborhood_cache[cache_key] = neighborhood
699
762
 
700
- yield from neighborhood
763
+ return neighborhood
701
764
 
702
765
  def neighbor_iter(self, pos: Coordinate) -> Iterator[Agent]:
703
766
  """Iterate over position neighbors.
@@ -712,11 +775,11 @@ class HexGrid(Grid):
712
775
  )
713
776
  return self.iter_neighbors(pos)
714
777
 
715
- def get_neighborhood(
778
+ def iter_neighborhood(
716
779
  self, pos: Coordinate, include_center: bool = False, radius: int = 1
717
- ) -> list[Coordinate]:
718
- """Return a list of cells that are in the neighborhood of a
719
- certain point.
780
+ ) -> Iterator[Coordinate]:
781
+ """Return an iterator over cell coordinates that are in the
782
+ neighborhood of a certain point.
720
783
 
721
784
  Args:
722
785
  pos: Coordinate tuple for the neighborhood to get.
@@ -725,10 +788,9 @@ class HexGrid(Grid):
725
788
  radius: radius, in cells, of neighborhood to get.
726
789
 
727
790
  Returns:
728
- A list of coordinate tuples representing the neighborhood;
729
- With radius 1
791
+ An iterator of coordinate tuples representing the neighborhood.
730
792
  """
731
- return list(self.iter_neighborhood(pos, include_center, radius))
793
+ yield from self.get_neighborhood(pos, include_center, radius)
732
794
 
733
795
  def iter_neighbors(
734
796
  self, pos: Coordinate, include_center: bool = False, radius: int = 1
@@ -745,7 +807,7 @@ class HexGrid(Grid):
745
807
  Returns:
746
808
  An iterator of non-None objects in the given neighborhood
747
809
  """
748
- neighborhood = self.iter_neighborhood(pos, include_center, radius)
810
+ neighborhood = self.get_neighborhood(pos, include_center, radius)
749
811
  return self.iter_cell_list_contents(neighborhood)
750
812
 
751
813
  def get_neighbors(
@@ -769,16 +831,14 @@ class HexGrid(Grid):
769
831
  class ContinuousSpace:
770
832
  """Continuous space where each agent can have an arbitrary position.
771
833
 
772
- Assumes that all agents are point objects, and have a pos property storing
773
- their position as an (x, y) tuple.
834
+ Assumes that all agents have a pos property storing their position as
835
+ an (x, y) tuple.
774
836
 
775
- This class uses a numpy array internally to store agent objects, to speed
837
+ This class uses a numpy array internally to store agents in order to speed
776
838
  up neighborhood lookups. This array is calculated on the first neighborhood
777
- lookup, and is reused (and updated) until agents are added or removed.
839
+ lookup, and is updated if agents are added or removed.
778
840
  """
779
841
 
780
- _grid = None
781
-
782
842
  def __init__(
783
843
  self,
784
844
  x_max: float,
@@ -811,18 +871,16 @@ class ContinuousSpace:
811
871
  self._agent_to_index: dict[Agent, int | None] = {}
812
872
 
813
873
  def _build_agent_cache(self):
814
- """Cache Agent positions to speed up neighbors calculations."""
874
+ """Cache agents positions to speed up neighbors calculations."""
815
875
  self._index_to_agent = {}
816
- agents = self._agent_to_index.keys()
817
- for idx, agent in enumerate(agents):
876
+ for idx, agent in enumerate(self._agent_to_index):
818
877
  self._agent_to_index[agent] = idx
819
878
  self._index_to_agent[idx] = agent
820
- self._agent_points = np.array(
821
- [self._index_to_agent[idx].pos for idx in range(len(agents))]
822
- )
879
+ # Since dicts are ordered by insertion, we can iterate through agents keys
880
+ self._agent_points = np.array([agent.pos for agent in self._agent_to_index])
823
881
 
824
882
  def _invalidate_agent_cache(self):
825
- """Clear cached data of Agents and positions in the space."""
883
+ """Clear cached data of agents and positions in the space."""
826
884
  self._agent_points = None
827
885
  self._index_to_agent = {}
828
886
 
@@ -852,18 +910,17 @@ class ContinuousSpace:
852
910
  # instead of invalidating the full cache,
853
911
  # apply the move to the cached values
854
912
  idx = self._agent_to_index[agent]
855
- self._agent_points[idx, 0] = pos[0]
856
- self._agent_points[idx, 1] = pos[1]
913
+ self._agent_points[idx] = pos
857
914
 
858
915
  def remove_agent(self, agent: Agent) -> None:
859
- """Remove an agent from the simulation.
916
+ """Remove an agent from the space.
860
917
 
861
918
  Args:
862
919
  agent: The agent object to remove
863
920
  """
864
921
  if agent not in self._agent_to_index:
865
922
  raise Exception("Agent does not exist in the space")
866
- self._agent_to_index.pop(agent)
923
+ del self._agent_to_index[agent]
867
924
 
868
925
  self._invalidate_agent_cache()
869
926
  agent.pos = None
@@ -871,7 +928,7 @@ class ContinuousSpace:
871
928
  def get_neighbors(
872
929
  self, pos: FloatCoordinate, radius: float, include_center: bool = True
873
930
  ) -> list[Agent]:
874
- """Get all objects within a certain radius.
931
+ """Get all agents within a certain radius.
875
932
 
876
933
  Args:
877
934
  pos: (x,y) coordinate tuple to center the search at.
@@ -898,7 +955,9 @@ class ContinuousSpace:
898
955
  def get_heading(
899
956
  self, pos_1: FloatCoordinate, pos_2: FloatCoordinate
900
957
  ) -> FloatCoordinate:
901
- """Get the heading angle between two points, accounting for toroidal space.
958
+ """Get the heading vector between two points, accounting for toroidal space.
959
+ It is possible to calculate the heading angle by applying the atan2 function to the
960
+ result.
902
961
 
903
962
  Args:
904
963
  pos_1, pos_2: Coordinate tuples for both points.
@@ -960,29 +1019,45 @@ class ContinuousSpace:
960
1019
  class NetworkGrid:
961
1020
  """Network Grid where each node contains zero or more agents."""
962
1021
 
963
- def __init__(self, G: Any) -> None:
964
- self.G = G
1022
+ def __init__(self, g: Any) -> None:
1023
+ """Create a new network.
1024
+
1025
+ Args:
1026
+ G: a NetworkX graph instance.
1027
+ """
1028
+ self.G = g
965
1029
  for node_id in self.G.nodes:
966
- G.nodes[node_id]["agent"] = list()
1030
+ g.nodes[node_id]["agent"] = self.default_val()
967
1031
 
968
- def place_agent(self, agent: Agent, node_id: int) -> None:
969
- """Place a agent in a node."""
1032
+ @staticmethod
1033
+ def default_val() -> list:
1034
+ """Default value for a new node."""
1035
+ return []
970
1036
 
1037
+ def place_agent(self, agent: Agent, node_id: int) -> None:
1038
+ """Place an agent in a node."""
971
1039
  self.G.nodes[node_id]["agent"].append(agent)
972
1040
  agent.pos = node_id
973
1041
 
974
- def get_neighbors(self, node_id: int, include_center: bool = False) -> list[int]:
975
- """Get all adjacent nodes"""
976
-
977
- neighbors = list(self.G.neighbors(node_id))
978
- if include_center:
979
- neighbors.append(node_id)
980
-
1042
+ def get_neighbors(
1043
+ self, node_id: int, include_center: bool = False, radius: int = 1
1044
+ ) -> list[int]:
1045
+ """Get all adjacent nodes within a certain radius"""
1046
+ if radius == 1:
1047
+ neighbors = list(self.G.neighbors(node_id))
1048
+ if include_center:
1049
+ neighbors.append(node_id)
1050
+ else:
1051
+ neighbors_with_distance = nx.single_source_shortest_path_length(
1052
+ self.G, node_id, radius
1053
+ )
1054
+ if not include_center:
1055
+ del neighbors_with_distance[node_id]
1056
+ neighbors = sorted(neighbors_with_distance.keys())
981
1057
  return neighbors
982
1058
 
983
1059
  def move_agent(self, agent: Agent, node_id: int) -> None:
984
1060
  """Move an agent from its current node to a new node."""
985
-
986
1061
  self.remove_agent(agent)
987
1062
  self.place_agent(agent, node_id)
988
1063
 
@@ -994,25 +1069,23 @@ class NetworkGrid:
994
1069
 
995
1070
  def is_cell_empty(self, node_id: int) -> bool:
996
1071
  """Returns a bool of the contents of a cell."""
997
- return not self.G.nodes[node_id]["agent"]
1072
+ return self.G.nodes[node_id]["agent"] == self.default_val()
998
1073
 
999
- def get_cell_list_contents(self, cell_list: list[int]) -> list[GridContent]:
1000
- """Returns the contents of a list of cells ((x,y) tuples)
1001
- Note: this method returns a list of `Agent`'s; `None` contents are excluded.
1074
+ def get_cell_list_contents(self, cell_list: list[int]) -> list[Agent]:
1075
+ """Returns a list of the agents contained in the nodes identified
1076
+ in `cell_list`; nodes with empty content are excluded.
1002
1077
  """
1003
1078
  return list(self.iter_cell_list_contents(cell_list))
1004
1079
 
1005
- def get_all_cell_contents(self) -> list[GridContent]:
1006
- """Returns a list of the contents of the cells
1007
- identified in cell_list."""
1008
- return list(self.iter_cell_list_contents(self.G))
1080
+ def get_all_cell_contents(self) -> list[Agent]:
1081
+ """Returns a list of all the agents in the network."""
1082
+ return self.get_cell_list_contents(self.G)
1009
1083
 
1010
- def iter_cell_list_contents(self, cell_list: list[int]) -> list[GridContent]:
1011
- """Returns an iterator of the contents of the cells
1012
- identified in cell_list."""
1013
- list_of_lists = [
1084
+ def iter_cell_list_contents(self, cell_list: list[int]) -> Iterator[Agent]:
1085
+ """Returns an iterator of the agents contained in the nodes identified
1086
+ in `cell_list`; nodes with empty content are excluded.
1087
+ """
1088
+ return itertools.chain.from_iterable(
1014
1089
  self.G.nodes[node_id]["agent"]
1015
- for node_id in cell_list
1016
- if not self.is_cell_empty(node_id)
1017
- ]
1018
- return [item for sublist in list_of_lists for item in sublist]
1090
+ for node_id in itertools.filterfalse(self.is_cell_empty, cell_list)
1091
+ )