Mesa 3.0.0a3__py3-none-any.whl → 3.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic; further details were linked from the original registry page.

Files changed (40):
  1. mesa/__init__.py +2 -3
  2. mesa/agent.py +193 -75
  3. mesa/batchrunner.py +18 -23
  4. mesa/cookiecutter-mesa/hooks/post_gen_project.py +2 -0
  5. mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/__init__.py +1 -0
  6. mesa/datacollection.py +138 -27
  7. mesa/experimental/UserParam.py +67 -0
  8. mesa/experimental/__init__.py +5 -1
  9. mesa/experimental/cell_space/__init__.py +7 -0
  10. mesa/experimental/cell_space/cell.py +61 -20
  11. mesa/experimental/cell_space/cell_agent.py +12 -7
  12. mesa/experimental/cell_space/cell_collection.py +54 -17
  13. mesa/experimental/cell_space/discrete_space.py +16 -5
  14. mesa/experimental/cell_space/grid.py +19 -8
  15. mesa/experimental/cell_space/network.py +9 -7
  16. mesa/experimental/cell_space/voronoi.py +26 -33
  17. mesa/experimental/components/altair.py +81 -0
  18. mesa/experimental/components/matplotlib.py +242 -0
  19. mesa/experimental/devs/__init__.py +2 -0
  20. mesa/experimental/devs/eventlist.py +36 -15
  21. mesa/experimental/devs/examples/epstein_civil_violence.py +71 -39
  22. mesa/experimental/devs/examples/wolf_sheep.py +43 -44
  23. mesa/experimental/devs/simulator.py +55 -15
  24. mesa/experimental/solara_viz.py +453 -0
  25. mesa/main.py +6 -4
  26. mesa/model.py +64 -61
  27. mesa/space.py +154 -123
  28. mesa/time.py +57 -67
  29. mesa/visualization/UserParam.py +19 -6
  30. mesa/visualization/__init__.py +14 -2
  31. mesa/visualization/components/altair.py +18 -1
  32. mesa/visualization/components/matplotlib.py +26 -2
  33. mesa/visualization/solara_viz.py +231 -225
  34. mesa/visualization/utils.py +9 -0
  35. {mesa-3.0.0a3.dist-info → mesa-3.0.0a5.dist-info}/METADATA +2 -1
  36. mesa-3.0.0a5.dist-info/RECORD +44 -0
  37. mesa-3.0.0a3.dist-info/RECORD +0 -39
  38. {mesa-3.0.0a3.dist-info → mesa-3.0.0a5.dist-info}/WHEEL +0 -0
  39. {mesa-3.0.0a3.dist-info → mesa-3.0.0a5.dist-info}/entry_points.txt +0 -0
  40. {mesa-3.0.0a3.dist-info → mesa-3.0.0a5.dist-info}/licenses/LICENSE +0 -0
mesa/datacollection.py CHANGED
@@ -1,20 +1,19 @@
1
- """
2
- Mesa Data Collection Module
3
- ===========================
1
+ """Mesa Data Collection Module.
4
2
 
5
3
  DataCollector is meant to provide a simple, standard way to collect data
6
- generated by a Mesa model. It collects three types of data: model-level data,
7
- agent-level data, and tables.
4
+ generated by a Mesa model. It collects four types of data: model-level data,
5
+ agent-level data, agent-type-level data, and tables.
8
6
 
9
- A DataCollector is instantiated with two dictionaries of reporter names and
10
- associated variable names or functions for each, one for model-level data and
11
- one for agent-level data; a third dictionary provides table names and columns.
12
- Variable names are converted into functions which retrieve attributes of that
13
- name.
7
+ A DataCollector is instantiated with three dictionaries of reporter names and
8
+ associated variable names or functions for each, one for model-level data,
9
+ one for agent-level data, and one for agent-type-level data; a fourth dictionary
10
+ provides table names and columns. Variable names are converted into functions
11
+ which retrieve attributes of that name.
14
12
 
15
13
  When the collect() method is called, each model-level function is called, with
16
14
  the model as the argument, and the results associated with the relevant
17
- variable. Then the agent-level functions are called on each agent.
15
+ variable. Then the agent-level functions are called on each agent, and the
16
+ agent-type-level functions are called on each agent of the specified type.
18
17
 
19
18
  Additionally, other objects can write directly to tables by passing in an
20
19
  appropriate dictionary object for a table row.
@@ -23,19 +22,18 @@ The DataCollector then stores the data it collects in dictionaries:
23
22
  * model_vars maps each reporter to a list of its values
24
23
  * tables maps each table to a dictionary, with each column as a key with a
25
24
  list as its value.
26
- * _agent_records maps each model step to a list of each agents id
25
+ * _agent_records maps each model step to a list of each agent's id
27
26
  and its values.
27
+ * _agenttype_records maps each model step to a dictionary of agent types,
28
+ each containing a list of each agent's id and its values.
28
29
 
29
30
  Finally, DataCollector can create a pandas DataFrame from each collection.
30
-
31
- The default DataCollector here makes several assumptions:
32
- * The model has an agent list called agents
33
- * For collecting agent-level variables, agents must have a unique_id
34
31
  """
35
32
 
36
33
  import contextlib
37
34
  import itertools
38
35
  import types
36
+ import warnings
39
37
  from copy import deepcopy
40
38
  from functools import partial
41
39
 
@@ -46,24 +44,25 @@ with contextlib.suppress(ImportError):
46
44
  class DataCollector:
47
45
  """Class for collecting data generated by a Mesa model.
48
46
 
49
- A DataCollector is instantiated with dictionaries of names of model- and
50
- agent-level variables to collect, associated with attribute names or
51
- functions which actually collect them. When the collect(...) method is
52
- called, it collects these attributes and executes these functions one by
53
- one and stores the results.
47
+ A DataCollector is instantiated with dictionaries of names of model-,
48
+ agent-, and agent-type-level variables to collect, associated with
49
+ attribute names or functions which actually collect them. When the
50
+ collect(...) method is called, it collects these attributes and executes
51
+ these functions one by one and stores the results.
54
52
  """
55
53
 
56
54
  def __init__(
57
55
  self,
58
56
  model_reporters=None,
59
57
  agent_reporters=None,
58
+ agenttype_reporters=None,
60
59
  tables=None,
61
60
  ):
62
- """
63
- Instantiate a DataCollector with lists of model and agent reporters.
64
- Both model_reporters and agent_reporters accept a dictionary mapping a
65
- variable name to either an attribute name, a function, a method of a class/instance,
66
- or a function with parameters placed in a list.
61
+ """Instantiate a DataCollector with lists of model, agent, and agent-type reporters.
62
+
63
+ Both model_reporters, agent_reporters, and agenttype_reporters accept a
64
+ dictionary mapping a variable name to either an attribute name, a function,
65
+ a method of a class/instance, or a function with parameters placed in a list.
67
66
 
68
67
  Model reporters can take four types of arguments:
69
68
  1. Lambda function:
@@ -87,6 +86,10 @@ class DataCollector:
87
86
  4. Functions with parameters placed in a list:
88
87
  {"Agent_Function": [function, [param_1, param_2]]}
89
88
 
89
+ Agenttype reporters take a dictionary mapping agent types to dictionaries
90
+ of reporter names and attributes/funcs/methods, similar to agent_reporters:
91
+ {Wolf: {"energy": lambda a: a.energy}}
92
+
90
93
  The tables arg accepts a dictionary mapping names of tables to lists of
91
94
  columns. For example, if we want to allow agents to write their age
92
95
  when they are destroyed (to keep track of lifespans), it might look
@@ -96,6 +99,8 @@ class DataCollector:
96
99
  Args:
97
100
  model_reporters: Dictionary of reporter names and attributes/funcs/methods.
98
101
  agent_reporters: Dictionary of reporter names and attributes/funcs/methods.
102
+ agenttype_reporters: Dictionary of agent types to dictionaries of
103
+ reporter names and attributes/funcs/methods.
99
104
  tables: Dictionary of table names to lists of column names.
100
105
 
101
106
  Notes:
@@ -105,9 +110,11 @@ class DataCollector:
105
110
  """
106
111
  self.model_reporters = {}
107
112
  self.agent_reporters = {}
113
+ self.agenttype_reporters = {}
108
114
 
109
115
  self.model_vars = {}
110
116
  self._agent_records = {}
117
+ self._agenttype_records = {}
111
118
  self.tables = {}
112
119
 
113
120
  if model_reporters is not None:
@@ -118,6 +125,11 @@ class DataCollector:
118
125
  for name, reporter in agent_reporters.items():
119
126
  self._new_agent_reporter(name, reporter)
120
127
 
128
+ if agenttype_reporters is not None:
129
+ for agent_type, reporters in agenttype_reporters.items():
130
+ for name, reporter in reporters.items():
131
+ self._new_agenttype_reporter(agent_type, name, reporter)
132
+
121
133
  if tables is not None:
122
134
  for name, columns in tables.items():
123
135
  self._new_table(name, columns)
@@ -165,6 +177,38 @@ class DataCollector:
165
177
 
166
178
  self.agent_reporters[name] = reporter
167
179
 
180
+ def _new_agenttype_reporter(self, agent_type, name, reporter):
181
+ """Add a new agent-type-level reporter to collect.
182
+
183
+ Args:
184
+ agent_type: The type of agent to collect data for.
185
+ name: Name of the agent-type-level variable to collect.
186
+ reporter: Attribute string, function object, method of a class/instance, or
187
+ function with parameters placed in a list that returns the
188
+ variable when given an agent instance.
189
+ """
190
+ if agent_type not in self.agenttype_reporters:
191
+ self.agenttype_reporters[agent_type] = {}
192
+
193
+ # Use the same logic as _new_agent_reporter
194
+ if isinstance(reporter, str):
195
+ attribute_name = reporter
196
+
197
+ def attr_reporter(agent):
198
+ return getattr(agent, attribute_name, None)
199
+
200
+ reporter = attr_reporter
201
+
202
+ elif isinstance(reporter, list):
203
+ func, params = reporter[0], reporter[1]
204
+
205
+ def func_with_params(agent):
206
+ return func(agent, *params)
207
+
208
+ reporter = func_with_params
209
+
210
+ self.agenttype_reporters[agent_type][name] = reporter
211
+
168
212
  def _new_table(self, table_name, table_columns):
169
213
  """Add a new table that objects can write to.
170
214
 
@@ -192,6 +236,34 @@ class DataCollector:
192
236
  )
193
237
  return agent_records
194
238
 
239
+ def _record_agenttype(self, model, agent_type):
240
+ """Record agent-type data in a mapping of functions and agents."""
241
+ rep_funcs = self.agenttype_reporters[agent_type].values()
242
+
243
+ def get_reports(agent):
244
+ _prefix = (agent.model.steps, agent.unique_id)
245
+ reports = tuple(rep(agent) for rep in rep_funcs)
246
+ return _prefix + reports
247
+
248
+ agent_types = model.agent_types
249
+ if agent_type in agent_types:
250
+ agents = model.agents_by_type[agent_type]
251
+ else:
252
+ from mesa import Agent
253
+
254
+ if issubclass(agent_type, Agent):
255
+ agents = [
256
+ agent for agent in model.agents if isinstance(agent, agent_type)
257
+ ]
258
+ else:
259
+ # Raise error if agent_type is not in model.agent_types
260
+ raise ValueError(
261
+ f"Agent type {agent_type} is not recognized as an Agent type in the model or Agent subclass. Use an Agent (sub)class, like {agent_types}."
262
+ )
263
+
264
+ agenttype_records = map(get_reports, agents)
265
+ return agenttype_records
266
+
195
267
  def collect(self, model):
196
268
  """Collect all the data for the given model object."""
197
269
  if self.model_reporters:
@@ -210,7 +282,6 @@ class DataCollector:
210
282
  elif isinstance(reporter, list):
211
283
  self.model_vars[var].append(deepcopy(reporter[0](*reporter[1])))
212
284
  # Assume it's a callable otherwise (e.g., method)
213
- # TODO: Check if method of a class explicitly
214
285
  else:
215
286
  self.model_vars[var].append(deepcopy(reporter()))
216
287
 
@@ -218,6 +289,14 @@ class DataCollector:
218
289
  agent_records = self._record_agents(model)
219
290
  self._agent_records[model.steps] = list(agent_records)
220
291
 
292
+ if self.agenttype_reporters:
293
+ self._agenttype_records[model.steps] = {}
294
+ for agent_type in self.agenttype_reporters:
295
+ agenttype_records = self._record_agenttype(model, agent_type)
296
+ self._agenttype_records[model.steps][agent_type] = list(
297
+ agenttype_records
298
+ )
299
+
221
300
  def add_table_row(self, table_name, row, ignore_missing=False):
222
301
  """Add a row dictionary to a specific table.
223
302
 
@@ -274,6 +353,38 @@ class DataCollector:
274
353
  )
275
354
  return df
276
355
 
356
+ def get_agenttype_vars_dataframe(self, agent_type):
357
+ """Create a pandas DataFrame from the agent-type variables for a specific agent type.
358
+
359
+ The DataFrame has one column for each variable, with two additional
360
+ columns for tick and agent_id.
361
+
362
+ Args:
363
+ agent_type: The type of agent to get the data for.
364
+ """
365
+ # Check if self.agenttype_reporters dictionary is empty for this agent type, if so return empty DataFrame
366
+ if agent_type not in self.agenttype_reporters:
367
+ warnings.warn(
368
+ f"No agent-type reporters have been defined for {agent_type} in the DataCollector, returning empty DataFrame.",
369
+ UserWarning,
370
+ stacklevel=2,
371
+ )
372
+ return pd.DataFrame()
373
+
374
+ all_records = itertools.chain.from_iterable(
375
+ records[agent_type]
376
+ for records in self._agenttype_records.values()
377
+ if agent_type in records
378
+ )
379
+ rep_names = list(self.agenttype_reporters[agent_type])
380
+
381
+ df = pd.DataFrame.from_records(
382
+ data=all_records,
383
+ columns=["Step", "AgentID", *rep_names],
384
+ index=["Step", "AgentID"],
385
+ )
386
+ return df
387
+
277
388
  def get_table_dataframe(self, table_name):
278
389
  """Create a pandas DataFrame from a particular table.
279
390
 
@@ -0,0 +1,67 @@
1
+ """helper classes."""
2
+
3
+
4
+ class UserParam: # noqa: D101
5
+ _ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'"
6
+
7
+ def maybe_raise_error(self, param_type, valid): # noqa: D102
8
+ if valid:
9
+ return
10
+ msg = self._ERROR_MESSAGE.format(param_type, self.label)
11
+ raise ValueError(msg)
12
+
13
+
14
+ class Slider(UserParam):
15
+ """A number-based slider input with settable increment.
16
+
17
+ Example:
18
+ slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1)
19
+
20
+ Args:
21
+ label: The displayed label in the UI
22
+ value: The initial value of the slider
23
+ min: The minimum possible value of the slider
24
+ max: The maximum possible value of the slider
25
+ step: The step between min and max for a range of possible values
26
+ dtype: either int or float
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ label="",
32
+ value=None,
33
+ min=None,
34
+ max=None,
35
+ step=1,
36
+ dtype=None,
37
+ ):
38
+ """Slider class.
39
+
40
+ Args:
41
+ label: The displayed label in the UI
42
+ value: The initial value of the slider
43
+ min: The minimum possible value of the slider
44
+ max: The maximum possible value of the slider
45
+ step: The step between min and max for a range of possible values
46
+ dtype: either int or float
47
+ """
48
+ self.label = label
49
+ self.value = value
50
+ self.min = min
51
+ self.max = max
52
+ self.step = step
53
+
54
+ # Validate option type to make sure values are supplied properly
55
+ valid = not (self.value is None or self.min is None or self.max is None)
56
+ self.maybe_raise_error("slider", valid)
57
+
58
+ if dtype is None:
59
+ self.is_float_slider = self._check_values_are_float(value, min, max, step)
60
+ else:
61
+ self.is_float_slider = dtype is float
62
+
63
+ def _check_values_are_float(self, value, min, max, step):
64
+ return any(isinstance(n, float) for n in (value, min, max, step))
65
+
66
+ def get(self, attr): # noqa: D102
67
+ return getattr(self, attr)
@@ -1,3 +1,7 @@
1
+ """Experimental init."""
2
+
1
3
  from mesa.experimental import cell_space
2
4
 
3
- __all__ = ["cell_space"]
5
+ from .solara_viz import JupyterViz, Slider, SolaraViz, make_text
6
+
7
+ __all__ = ["cell_space", "JupyterViz", "SolaraViz", "make_text", "Slider"]
@@ -1,3 +1,10 @@
1
+ """Cell spaces.
2
+
3
+ Cell spaces offer an alternative API for discrete spaces. It is experimental and under development. The API is more
4
+ expressive that the default grids available in `mesa.space`.
5
+
6
+ """
7
+
1
8
  from mesa.experimental.cell_space.cell import Cell
2
9
  from mesa.experimental.cell_space.cell_agent import CellAgent
3
10
  from mesa.experimental.cell_space.cell_collection import CellCollection
@@ -1,14 +1,19 @@
1
+ """The Cell in a cell space."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
- from functools import cache
5
+ from functools import cache, cached_property
4
6
  from random import Random
5
7
  from typing import TYPE_CHECKING
6
8
 
7
9
  from mesa.experimental.cell_space.cell_collection import CellCollection
8
10
 
9
11
  if TYPE_CHECKING:
12
+ from mesa.agent import Agent
10
13
  from mesa.experimental.cell_space.cell_agent import CellAgent
11
14
 
15
+ Coordinate = tuple[int, ...]
16
+
12
17
 
13
18
  class Cell:
14
19
  """The cell represents a position in a discrete space.
@@ -24,11 +29,12 @@ class Cell:
24
29
 
25
30
  __slots__ = [
26
31
  "coordinate",
27
- "_connections",
32
+ "connections",
28
33
  "agents",
29
34
  "capacity",
30
35
  "properties",
31
36
  "random",
37
+ "__dict__",
32
38
  ]
33
39
 
34
40
  # def __new__(cls,
@@ -42,34 +48,39 @@ class Cell:
42
48
 
43
49
  def __init__(
44
50
  self,
45
- coordinate: tuple[int, ...],
46
- capacity: float | None = None,
51
+ coordinate: Coordinate,
52
+ capacity: int | None = None,
47
53
  random: Random | None = None,
48
54
  ) -> None:
49
- """ "
55
+ """Initialise the cell.
50
56
 
51
57
  Args:
52
- coordinate:
58
+ coordinate: coordinates of the cell
53
59
  capacity (int) : the capacity of the cell. If None, the capacity is infinite
54
60
  random (Random) : the random number generator to use
55
61
 
56
62
  """
57
63
  super().__init__()
58
64
  self.coordinate = coordinate
59
- self._connections: list[Cell] = [] # TODO: change to CellCollection?
60
- self.agents = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, )
61
- self.capacity = capacity
62
- self.properties: dict[str, object] = {}
65
+ self.connections: dict[Coordinate, Cell] = {}
66
+ self.agents: list[
67
+ Agent
68
+ ] = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, )
69
+ self.capacity: int = capacity
70
+ self.properties: dict[Coordinate, object] = {}
63
71
  self.random = random
64
72
 
65
- def connect(self, other: Cell) -> None:
73
+ def connect(self, other: Cell, key: Coordinate | None = None) -> None:
66
74
  """Connects this cell to another cell.
67
75
 
68
76
  Args:
69
77
  other (Cell): other cell to connect to
78
+ key (Tuple[int, ...]): key for the connection. Should resemble a relative coordinate
70
79
 
71
80
  """
72
- self._connections.append(other)
81
+ if key is None:
82
+ key = other.coordinate
83
+ self.connections[key] = other
73
84
 
74
85
  def disconnect(self, other: Cell) -> None:
75
86
  """Disconnects this cell from another cell.
@@ -78,7 +89,9 @@ class Cell:
78
89
  other (Cell): other cell to remove from connections
79
90
 
80
91
  """
81
- self._connections.remove(other)
92
+ keys_to_remove = [k for k, v in self.connections.items() if v == other]
93
+ for key in keys_to_remove:
94
+ del self.connections[key]
82
95
 
83
96
  def add_agent(self, agent: CellAgent) -> None:
84
97
  """Adds an agent to the cell.
@@ -116,34 +129,62 @@ class Cell:
116
129
  """Returns a bool of the contents of a cell."""
117
130
  return len(self.agents) == self.capacity
118
131
 
119
- def __repr__(self):
132
+ def __repr__(self): # noqa
120
133
  return f"Cell({self.coordinate}, {self.agents})"
121
134
 
135
+ @cached_property
136
+ def neighborhood(self) -> CellCollection:
137
+ """Returns the direct neighborhood of the cell.
138
+
139
+ This is equivalent to cell.get_neighborhood(radius=1)
140
+
141
+ """
142
+ return self.get_neighborhood()
143
+
122
144
  # FIXME: Revisit caching strategy on methods
123
145
  @cache # noqa: B019
124
- def neighborhood(self, radius=1, include_center=False):
125
- return CellCollection(
146
+ def get_neighborhood(
147
+ self, radius: int = 1, include_center: bool = False
148
+ ) -> CellCollection:
149
+ """Returns a list of all neighboring cells for the given radius.
150
+
151
+ For getting the direct neighborhood (i.e., radius=1) you can also use
152
+ the `neighborhood` property.
153
+
154
+ Args:
155
+ radius (int): the radius of the neighborhood
156
+ include_center (bool): include the center of the neighborhood
157
+
158
+ Returns:
159
+ a list of all neighboring cells
160
+
161
+ """
162
+ return CellCollection[Cell](
126
163
  self._neighborhood(radius=radius, include_center=include_center),
127
164
  random=self.random,
128
165
  )
129
166
 
130
167
  # FIXME: Revisit caching strategy on methods
131
168
  @cache # noqa: B019
132
- def _neighborhood(self, radius=1, include_center=False):
169
+ def _neighborhood(
170
+ self, radius: int = 1, include_center: bool = False
171
+ ) -> dict[Cell, list[Agent]]:
133
172
  # if radius == 0:
134
173
  # return {self: self.agents}
135
174
  if radius < 1:
136
175
  raise ValueError("radius must be larger than one")
137
176
  if radius == 1:
138
- neighborhood = {neighbor: neighbor.agents for neighbor in self._connections}
177
+ neighborhood = {
178
+ neighbor: neighbor.agents for neighbor in self.connections.values()
179
+ }
139
180
  if not include_center:
140
181
  return neighborhood
141
182
  else:
142
183
  neighborhood[self] = self.agents
143
184
  return neighborhood
144
185
  else:
145
- neighborhood = {}
146
- for neighbor in self._connections:
186
+ neighborhood: dict[Cell, list[Agent]] = {}
187
+ for neighbor in self.connections.values():
147
188
  neighborhood.update(
148
189
  neighbor._neighborhood(radius - 1, include_center=True)
149
190
  )
@@ -1,3 +1,5 @@
1
+ """An agent with movement methods for cell spaces."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -9,8 +11,7 @@ if TYPE_CHECKING:
9
11
 
10
12
 
11
13
  class CellAgent(Agent):
12
- """Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces
13
-
14
+ """Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces.
14
15
 
15
16
  Attributes:
16
17
  unique_id (int): A unique identifier for this agent.
@@ -19,18 +20,22 @@ class CellAgent(Agent):
19
20
  cell: (Cell | None): the cell which the agent occupies
20
21
  """
21
22
 
22
- def __init__(self, unique_id: int, model: Model) -> None:
23
- """
24
- Create a new agent.
23
+ def __init__(self, model: Model) -> None:
24
+ """Create a new agent.
25
25
 
26
26
  Args:
27
- unique_id (int): A unique identifier for this agent.
28
27
  model (Model): The model instance in which the agent exists.
29
28
  """
30
- super().__init__(unique_id, model)
29
+ super().__init__(model)
31
30
  self.cell: Cell | None = None
32
31
 
33
32
  def move_to(self, cell) -> None:
33
+ """Move agent to cell.
34
+
35
+ Args:
36
+ cell: cell to which agent is to move
37
+
38
+ """
34
39
  if self.cell is not None:
35
40
  self.cell.remove_agent(self)
36
41
  self.cell = cell
@@ -1,3 +1,5 @@
1
+ """CellCollection class."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  import itertools
@@ -14,7 +16,7 @@ T = TypeVar("T", bound="Cell")
14
16
 
15
17
 
16
18
  class CellCollection(Generic[T]):
17
- """An immutable collection of cells
19
+ """An immutable collection of cells.
18
20
 
19
21
  Attributes:
20
22
  cells (List[Cell]): The list of cells this collection represents
@@ -28,6 +30,12 @@ class CellCollection(Generic[T]):
28
30
  cells: Mapping[T, list[CellAgent]] | Iterable[T],
29
31
  random: Random | None = None,
30
32
  ) -> None:
33
+ """Initialize a CellCollection.
34
+
35
+ Args:
36
+ cells: cells to add to the collection
37
+ random: a seeded random number generator.
38
+ """
31
39
  if isinstance(cells, dict):
32
40
  self._cells = cells
33
41
  else:
@@ -40,42 +48,71 @@ class CellCollection(Generic[T]):
40
48
  random = Random() # FIXME
41
49
  self.random = random
42
50
 
43
- def __iter__(self):
51
+ def __iter__(self): # noqa
44
52
  return iter(self._cells)
45
53
 
46
- def __getitem__(self, key: T) -> Iterable[CellAgent]:
54
+ def __getitem__(self, key: T) -> Iterable[CellAgent]: # noqa
47
55
  return self._cells[key]
48
56
 
49
57
  # @cached_property
50
- def __len__(self) -> int:
58
+ def __len__(self) -> int: # noqa
51
59
  return len(self._cells)
52
60
 
53
- def __repr__(self):
61
+ def __repr__(self): # noqa
54
62
  return f"CellCollection({self._cells})"
55
63
 
56
64
  @cached_property
57
- def cells(self) -> list[T]:
65
+ def cells(self) -> list[T]: # noqa
58
66
  return list(self._cells.keys())
59
67
 
60
68
  @property
61
- def agents(self) -> Iterable[CellAgent]:
69
+ def agents(self) -> Iterable[CellAgent]: # noqa
62
70
  return itertools.chain.from_iterable(self._cells.values())
63
71
 
64
72
  def select_random_cell(self) -> T:
73
+ """Select a random cell."""
65
74
  return self.random.choice(self.cells)
66
75
 
67
76
  def select_random_agent(self) -> CellAgent:
77
+ """Select a random agent.
78
+
79
+ Returns:
80
+ CellAgent instance
81
+
82
+
83
+ """
68
84
  return self.random.choice(list(self.agents))
69
85
 
70
- def select(self, filter_func: Callable[[T], bool] | None = None, n=0):
71
- # FIXME: n is not considered
72
- if filter_func is None and n == 0:
86
+ def select(
87
+ self,
88
+ filter_func: Callable[[T], bool] | None = None,
89
+ at_most: int | float = float("inf"),
90
+ ):
91
+ """Select cells based on filter function.
92
+
93
+ Args:
94
+ filter_func: filter function
95
+ at_most: The maximum amount of cells to select. Defaults to infinity.
96
+ - If an integer, at most the first number of matching cells is selected.
97
+ - If a float between 0 and 1, at most that fraction of original number of cells
98
+
99
+ Returns:
100
+ CellCollection
101
+
102
+ """
103
+ if filter_func is None and at_most == float("inf"):
73
104
  return self
74
105
 
75
- return CellCollection(
76
- {
77
- cell: agents
78
- for cell, agents in self._cells.items()
79
- if filter_func is None or filter_func(cell)
80
- }
81
- )
106
+ if at_most <= 1.0 and isinstance(at_most, float):
107
+ at_most = int(len(self) * at_most) # Note that it rounds down (floor)
108
+
109
+ def cell_generator(filter_func, at_most):
110
+ count = 0
111
+ for cell in self:
112
+ if count >= at_most:
113
+ break
114
+ if not filter_func or filter_func(cell):
115
+ yield cell
116
+ count += 1
117
+
118
+ return CellCollection(cell_generator(filter_func, at_most))