Mesa 2.3.4__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic; see the registry's advisory page for this release for more details.

Files changed (110):
  1. mesa/__init__.py +3 -5
  2. mesa/agent.py +393 -116
  3. mesa/batchrunner.py +58 -31
  4. mesa/datacollection.py +141 -30
  5. mesa/examples/README.md +37 -0
  6. mesa/examples/__init__.py +21 -0
  7. mesa/examples/advanced/epstein_civil_violence/Epstein Civil Violence.ipynb +116 -0
  8. mesa/examples/advanced/epstein_civil_violence/Readme.md +34 -0
  9. mesa/examples/advanced/epstein_civil_violence/__init__.py +0 -0
  10. mesa/examples/advanced/epstein_civil_violence/agents.py +164 -0
  11. mesa/examples/advanced/epstein_civil_violence/app.py +73 -0
  12. mesa/examples/advanced/epstein_civil_violence/model.py +114 -0
  13. mesa/examples/advanced/pd_grid/Readme.md +43 -0
  14. mesa/examples/advanced/pd_grid/__init__.py +0 -0
  15. mesa/examples/advanced/pd_grid/agents.py +50 -0
  16. mesa/examples/advanced/pd_grid/analysis.ipynb +228 -0
  17. mesa/examples/advanced/pd_grid/app.py +54 -0
  18. mesa/examples/advanced/pd_grid/model.py +71 -0
  19. mesa/examples/advanced/sugarscape_g1mt/Readme.md +64 -0
  20. mesa/examples/advanced/sugarscape_g1mt/__init__.py +0 -0
  21. mesa/examples/advanced/sugarscape_g1mt/agents.py +344 -0
  22. mesa/examples/advanced/sugarscape_g1mt/app.py +62 -0
  23. mesa/examples/advanced/sugarscape_g1mt/model.py +180 -0
  24. mesa/examples/advanced/sugarscape_g1mt/sugar-map.txt +50 -0
  25. mesa/examples/advanced/sugarscape_g1mt/tests.py +69 -0
  26. mesa/examples/advanced/wolf_sheep/Readme.md +57 -0
  27. mesa/examples/advanced/wolf_sheep/__init__.py +0 -0
  28. mesa/examples/advanced/wolf_sheep/agents.py +102 -0
  29. mesa/examples/advanced/wolf_sheep/app.py +84 -0
  30. mesa/examples/advanced/wolf_sheep/model.py +137 -0
  31. mesa/examples/basic/__init__.py +0 -0
  32. mesa/examples/basic/boid_flockers/Readme.md +22 -0
  33. mesa/examples/basic/boid_flockers/__init__.py +0 -0
  34. mesa/examples/basic/boid_flockers/agents.py +71 -0
  35. mesa/examples/basic/boid_flockers/app.py +58 -0
  36. mesa/examples/basic/boid_flockers/model.py +69 -0
  37. mesa/examples/basic/boltzmann_wealth_model/Readme.md +56 -0
  38. mesa/examples/basic/boltzmann_wealth_model/__init__.py +0 -0
  39. mesa/examples/basic/boltzmann_wealth_model/agents.py +31 -0
  40. mesa/examples/basic/boltzmann_wealth_model/app.py +74 -0
  41. mesa/examples/basic/boltzmann_wealth_model/model.py +43 -0
  42. mesa/examples/basic/boltzmann_wealth_model/st_app.py +115 -0
  43. mesa/examples/basic/conways_game_of_life/Readme.md +39 -0
  44. mesa/examples/basic/conways_game_of_life/__init__.py +0 -0
  45. mesa/examples/basic/conways_game_of_life/agents.py +47 -0
  46. mesa/examples/basic/conways_game_of_life/app.py +51 -0
  47. mesa/examples/basic/conways_game_of_life/model.py +31 -0
  48. mesa/examples/basic/conways_game_of_life/st_app.py +72 -0
  49. mesa/examples/basic/schelling/Readme.md +40 -0
  50. mesa/examples/basic/schelling/__init__.py +0 -0
  51. mesa/examples/basic/schelling/agents.py +26 -0
  52. mesa/examples/basic/schelling/analysis.ipynb +205 -0
  53. mesa/examples/basic/schelling/app.py +42 -0
  54. mesa/examples/basic/schelling/model.py +59 -0
  55. mesa/examples/basic/virus_on_network/Readme.md +61 -0
  56. mesa/examples/basic/virus_on_network/__init__.py +0 -0
  57. mesa/examples/basic/virus_on_network/agents.py +69 -0
  58. mesa/examples/basic/virus_on_network/app.py +114 -0
  59. mesa/examples/basic/virus_on_network/model.py +96 -0
  60. mesa/experimental/UserParam.py +18 -7
  61. mesa/experimental/__init__.py +10 -2
  62. mesa/experimental/cell_space/__init__.py +16 -1
  63. mesa/experimental/cell_space/cell.py +93 -23
  64. mesa/experimental/cell_space/cell_agent.py +117 -21
  65. mesa/experimental/cell_space/cell_collection.py +56 -19
  66. mesa/experimental/cell_space/discrete_space.py +92 -8
  67. mesa/experimental/cell_space/grid.py +33 -9
  68. mesa/experimental/cell_space/network.py +15 -10
  69. mesa/experimental/cell_space/voronoi.py +257 -0
  70. mesa/experimental/components/altair.py +11 -2
  71. mesa/experimental/components/matplotlib.py +132 -26
  72. mesa/experimental/devs/__init__.py +2 -0
  73. mesa/experimental/devs/eventlist.py +54 -15
  74. mesa/experimental/devs/examples/epstein_civil_violence.py +71 -39
  75. mesa/experimental/devs/examples/wolf_sheep.py +45 -45
  76. mesa/experimental/devs/simulator.py +57 -16
  77. mesa/experimental/{jupyter_viz.py → solara_viz.py} +151 -98
  78. mesa/model.py +212 -84
  79. mesa/space.py +217 -151
  80. mesa/time.py +63 -80
  81. mesa/visualization/__init__.py +25 -6
  82. mesa/visualization/components/__init__.py +83 -0
  83. mesa/visualization/components/altair_components.py +188 -0
  84. mesa/visualization/components/matplotlib_components.py +175 -0
  85. mesa/visualization/mpl_space_drawing.py +593 -0
  86. mesa/visualization/solara_viz.py +458 -0
  87. mesa/visualization/user_param.py +69 -0
  88. mesa/visualization/utils.py +9 -0
  89. {mesa-2.3.4.dist-info → mesa-3.0.0.dist-info}/METADATA +65 -19
  90. mesa-3.0.0.dist-info/RECORD +95 -0
  91. mesa-3.0.0.dist-info/licenses/LICENSE +202 -0
  92. mesa-2.3.4.dist-info/licenses/LICENSE → mesa-3.0.0.dist-info/licenses/NOTICE +2 -2
  93. mesa/cookiecutter-mesa/cookiecutter.json +0 -8
  94. mesa/cookiecutter-mesa/hooks/post_gen_project.py +0 -11
  95. mesa/cookiecutter-mesa/{{cookiecutter.snake}}/README.md +0 -4
  96. mesa/cookiecutter-mesa/{{cookiecutter.snake}}/run.pytemplate +0 -3
  97. mesa/cookiecutter-mesa/{{cookiecutter.snake}}/setup.pytemplate +0 -11
  98. mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/model.pytemplate +0 -60
  99. mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/server.pytemplate +0 -36
  100. mesa/flat/__init__.py +0 -6
  101. mesa/flat/visualization.py +0 -5
  102. mesa/main.py +0 -63
  103. mesa/visualization/ModularVisualization.py +0 -1
  104. mesa/visualization/TextVisualization.py +0 -1
  105. mesa/visualization/UserParam.py +0 -1
  106. mesa/visualization/modules.py +0 -1
  107. mesa-2.3.4.dist-info/RECORD +0 -45
  108. /mesa/{cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}} → examples/advanced}/__init__.py +0 -0
  109. {mesa-2.3.4.dist-info → mesa-3.0.0.dist-info}/WHEEL +0 -0
  110. {mesa-2.3.4.dist-info → mesa-3.0.0.dist-info}/entry_points.txt +0 -0
mesa/agent.py CHANGED
@@ -1,7 +1,6 @@
1
- """
2
- The agent class for Mesa framework.
1
+ """Agent related classes.
3
2
 
4
- Core Objects: Agent
3
+ Core Objects: Agent and AgentSet.
5
4
  """
6
5
 
7
6
  # Mypy; for the `|` operator purpose
@@ -10,15 +9,19 @@ from __future__ import annotations
10
9
 
11
10
  import contextlib
12
11
  import copy
12
+ import functools
13
+ import itertools
13
14
  import operator
14
15
  import warnings
15
16
  import weakref
16
17
  from collections import defaultdict
17
- from collections.abc import Iterable, Iterator, MutableSet, Sequence
18
+ from collections.abc import Callable, Hashable, Iterable, Iterator, MutableSet, Sequence
18
19
  from random import Random
19
20
 
20
21
  # mypy
21
- from typing import TYPE_CHECKING, Any, Callable
22
+ from typing import TYPE_CHECKING, Any, Literal, overload
23
+
24
+ import numpy as np
22
25
 
23
26
  if TYPE_CHECKING:
24
27
  # We ensure that these are not imported during runtime to prevent cyclic
@@ -28,90 +31,99 @@ if TYPE_CHECKING:
28
31
 
29
32
 
30
33
  class Agent:
31
- """
32
- Base class for a model agent in Mesa.
34
+ """Base class for a model agent in Mesa.
33
35
 
34
36
  Attributes:
35
- unique_id (int): A unique identifier for this agent.
36
37
  model (Model): A reference to the model instance.
37
- self.pos: Position | None = None
38
+ unique_id (int): A unique identifier for this agent.
39
+ pos (Position): A reference to the position where this agent is located.
40
+
41
+ Notes:
42
+ unique_id is unique relative to a model instance and starts from 1
43
+
38
44
  """
39
45
 
40
- def __init__(self, unique_id: int, model: Model) -> None:
41
- """
42
- Create a new agent.
46
+ # this is a class level attribute
47
+ # it is a dictionary, indexed by model instance
48
+ # so, unique_id is unique relative to a model, and counting starts from 1
49
+ _ids = defaultdict(functools.partial(itertools.count, 1))
50
+
51
+ def __init__(self, model: Model, *args, **kwargs) -> None:
52
+ """Create a new agent.
43
53
 
44
54
  Args:
45
- unique_id (int): A unique identifier for this agent.
46
55
  model (Model): The model instance in which the agent exists.
56
+ args: passed on to super
57
+ kwargs: passed on to super
58
+
59
+ Notes:
60
+ to make proper use of python's super, in each class remove the arguments and
61
+ keyword arguments you need and pass on the rest to super
62
+
47
63
  """
48
- self.unique_id = unique_id
49
- self.model = model
64
+ super().__init__(*args, **kwargs)
65
+
66
+ self.model: Model = model
67
+ self.unique_id: int = next(self._ids[model])
50
68
  self.pos: Position | None = None
69
+ self.model.register_agent(self)
51
70
 
52
- # register agent
53
- try:
54
- self.model.agents_[type(self)][self] = None
55
- except AttributeError:
56
- # model super has not been called
57
- self.model.agents_ = defaultdict(dict)
58
- self.model.agents_[type(self)][self] = None
59
- self.model.agentset_experimental_warning_given = False
71
+ def remove(self) -> None:
72
+ """Remove and delete the agent from the model.
60
73
 
61
- warnings.warn(
62
- "The Mesa Model class was not initialized. In the future, you need to explicitly initialize the Model by calling super().__init__() on initialization.",
63
- FutureWarning,
64
- stacklevel=2,
65
- )
74
+ Notes:
75
+ If you need to do additional cleanup when removing an agent by for example removing
76
+ it from a space, consider extending this method in your own agent class.
66
77
 
67
- def remove(self) -> None:
68
- """Remove and delete the agent from the model."""
78
+ """
69
79
  with contextlib.suppress(KeyError):
70
- self.model.agents_[type(self)].pop(self)
80
+ self.model.deregister_agent(self)
71
81
 
72
82
  def step(self) -> None:
73
83
  """A single step of the agent."""
74
84
 
75
- def advance(self) -> None:
85
+ def advance(self) -> None: # noqa: D102
76
86
  pass
77
87
 
78
88
  @property
79
89
  def random(self) -> Random:
90
+ """Return a seeded stdlib rng."""
80
91
  return self.model.random
81
92
 
93
+ @property
94
+ def rng(self) -> np.random.Generator:
95
+ """Return a seeded np.random rng."""
96
+ return self.model.rng
97
+
82
98
 
83
99
  class AgentSet(MutableSet, Sequence):
84
- """
85
- A collection class that represents an ordered set of agents within an agent-based model (ABM). This class
86
- extends both MutableSet and Sequence, providing set-like functionality with order preservation and
100
+ """A collection class that represents an ordered set of agents within an agent-based model (ABM).
101
+
102
+ This class extends both MutableSet and Sequence, providing set-like functionality with order preservation and
87
103
  sequence operations.
88
104
 
89
105
  Attributes:
90
106
  model (Model): The ABM model instance to which this AgentSet belongs.
91
107
 
92
- Methods:
93
- __len__, __iter__, __contains__, select, shuffle, sort, _update, do, get, __getitem__,
94
- add, discard, remove, __getstate__, __setstate__, random
95
-
96
- Note:
108
+ Notes:
97
109
  The AgentSet maintains weak references to agents, allowing for efficient management of agent lifecycles
98
110
  without preventing garbage collection. It is associated with a specific model instance, enabling
99
111
  interactions with the model's environment and other agents.The implementation uses a WeakKeyDictionary to store agents,
100
112
  which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet.
101
113
  """
102
114
 
103
- agentset_experimental_warning_given = False
104
-
105
- def __init__(self, agents: Iterable[Agent], model: Model):
106
- """
107
- Initializes the AgentSet with a collection of agents and a reference to the model.
115
+ def __init__(self, agents: Iterable[Agent], random: Random | None = None):
116
+ """Initializes the AgentSet with a collection of agents and a reference to the model.
108
117
 
109
118
  Args:
110
119
  agents (Iterable[Agent]): An iterable of Agent objects to be included in the set.
111
- model (Model): The ABM model instance to which this AgentSet belongs.
120
+ random (Random): the random number generator
112
121
  """
113
-
114
- self.model = model
122
+ if random is None:
123
+ random = (
124
+ Random()
125
+ ) # FIXME see issue 1981, how to get the central rng from model
126
+ self.random = random
115
127
  self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
116
128
 
117
129
  def __len__(self) -> int:
@@ -129,45 +141,63 @@ class AgentSet(MutableSet, Sequence):
129
141
  def select(
130
142
  self,
131
143
  filter_func: Callable[[Agent], bool] | None = None,
132
- n: int = 0,
144
+ at_most: int | float = float("inf"),
133
145
  inplace: bool = False,
134
146
  agent_type: type[Agent] | None = None,
147
+ n: int | None = None,
135
148
  ) -> AgentSet:
136
- """
137
- Select a subset of agents from the AgentSet based on a filter function and/or quantity limit.
149
+ """Select a subset of agents from the AgentSet based on a filter function and/or quantity limit.
138
150
 
139
151
  Args:
140
152
  filter_func (Callable[[Agent], bool], optional): A function that takes an Agent and returns True if the
141
153
  agent should be included in the result. Defaults to None, meaning no filtering is applied.
142
- n (int, optional): The number of agents to select. If 0, all matching agents are selected. Defaults to 0.
154
+ at_most (int | float, optional): The maximum amount of agents to select. Defaults to infinity.
155
+ - If an integer, at most the first number of matching agents are selected.
156
+ - If a float between 0 and 1, at most that fraction of original the agents are selected.
143
157
  inplace (bool, optional): If True, modifies the current AgentSet; otherwise, returns a new AgentSet. Defaults to False.
144
158
  agent_type (type[Agent], optional): The class type of the agents to select. Defaults to None, meaning no type filtering is applied.
159
+ n (int): deprecated, use at_most instead
145
160
 
146
161
  Returns:
147
162
  AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated.
163
+
164
+ Notes:
165
+ - at_most just return the first n or fraction of agents. To take a random sample, shuffle() beforehand.
166
+ - at_most is an upper limit. When specifying other criteria, the number of agents returned can be smaller.
148
167
  """
168
+ if n is not None:
169
+ warnings.warn(
170
+ "The parameter 'n' is deprecated. Use 'at_most' instead.",
171
+ DeprecationWarning,
172
+ stacklevel=2,
173
+ )
174
+ at_most = n
149
175
 
150
- if filter_func is None and agent_type is None and n == 0:
176
+ inf = float("inf")
177
+ if filter_func is None and agent_type is None and at_most == inf:
151
178
  return self if inplace else copy.copy(self)
152
179
 
153
- def agent_generator(filter_func=None, agent_type=None, n=0):
180
+ # Check if at_most is of type float
181
+ if at_most <= 1.0 and isinstance(at_most, float):
182
+ at_most = int(len(self) * at_most) # Note that it rounds down (floor)
183
+
184
+ def agent_generator(filter_func, agent_type, at_most):
154
185
  count = 0
155
186
  for agent in self:
187
+ if count >= at_most:
188
+ break
156
189
  if (not filter_func or filter_func(agent)) and (
157
190
  not agent_type or isinstance(agent, agent_type)
158
191
  ):
159
192
  yield agent
160
193
  count += 1
161
- if 0 < n <= count:
162
- break
163
194
 
164
- agents = agent_generator(filter_func, agent_type, n)
195
+ agents = agent_generator(filter_func, agent_type, at_most)
165
196
 
166
- return AgentSet(agents, self.model) if not inplace else self._update(agents)
197
+ return AgentSet(agents, self.random) if not inplace else self._update(agents)
167
198
 
168
199
  def shuffle(self, inplace: bool = False) -> AgentSet:
169
- """
170
- Randomly shuffle the order of agents in the AgentSet.
200
+ """Randomly shuffle the order of agents in the AgentSet.
171
201
 
172
202
  Args:
173
203
  inplace (bool, optional): If True, shuffles the agents in the current AgentSet; otherwise, returns a new shuffled AgentSet. Defaults to False.
@@ -187,7 +217,7 @@ class AgentSet(MutableSet, Sequence):
187
217
  return self
188
218
  else:
189
219
  return AgentSet(
190
- (agent for ref in weakrefs if (agent := ref()) is not None), self.model
220
+ (agent for ref in weakrefs if (agent := ref()) is not None), self.random
191
221
  )
192
222
 
193
223
  def sort(
@@ -196,8 +226,7 @@ class AgentSet(MutableSet, Sequence):
196
226
  ascending: bool = False,
197
227
  inplace: bool = False,
198
228
  ) -> AgentSet:
199
- """
200
- Sort the agents in the AgentSet based on a specified attribute or custom function.
229
+ """Sort the agents in the AgentSet based on a specified attribute or custom function.
201
230
 
202
231
  Args:
203
232
  key (Callable[[Agent], Any] | str): A function or attribute name based on which the agents are sorted.
@@ -213,70 +242,193 @@ class AgentSet(MutableSet, Sequence):
213
242
  sorted_agents = sorted(self._agents.keys(), key=key, reverse=not ascending)
214
243
 
215
244
  return (
216
- AgentSet(sorted_agents, self.model)
245
+ AgentSet(sorted_agents, self.random)
217
246
  if not inplace
218
247
  else self._update(sorted_agents)
219
248
  )
220
249
 
221
250
  def _update(self, agents: Iterable[Agent]):
222
251
  """Update the AgentSet with a new set of agents.
252
+
223
253
  This is a private method primarily used internally by other methods like select, shuffle, and sort.
224
254
  """
225
-
226
255
  self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
227
256
  return self
228
257
 
229
- def do(
230
- self, method_name: str, *args, return_results: bool = False, **kwargs
231
- ) -> AgentSet | list[Any]:
258
+ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
259
+ """Invoke a method or function on each agent in the AgentSet.
260
+
261
+ Args:
262
+ method (str, callable): the callable to do on each agent
263
+
264
+ * in case of str, the name of the method to call on each agent.
265
+ * in case of callable, the function to be called with each agent as first argument
266
+
267
+ *args: Variable length argument list passed to the callable being called.
268
+ **kwargs: Arbitrary keyword arguments passed to the callable being called.
269
+
270
+ Returns:
271
+ AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself.
272
+ """
273
+ try:
274
+ return_results = kwargs.pop("return_results")
275
+ except KeyError:
276
+ return_results = False
277
+ else:
278
+ warnings.warn(
279
+ "Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and "
280
+ "AgentSet.map in case of return_results=True",
281
+ stacklevel=2,
282
+ )
283
+
284
+ if return_results:
285
+ return self.map(method, *args, **kwargs)
286
+
287
+ # we iterate over the actual weakref keys and check if weakref is alive before calling the method
288
+ if isinstance(method, str):
289
+ for agentref in self._agents.keyrefs():
290
+ if (agent := agentref()) is not None:
291
+ getattr(agent, method)(*args, **kwargs)
292
+ else:
293
+ for agentref in self._agents.keyrefs():
294
+ if (agent := agentref()) is not None:
295
+ method(agent, *args, **kwargs)
296
+
297
+ return self
298
+
299
+ def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
300
+ """Shuffle the agents in the AgentSet and then invoke a method or function on each agent.
301
+
302
+ It's a fast, optimized version of calling shuffle() followed by do().
232
303
  """
233
- Invoke a method on each agent in the AgentSet.
304
+ weakrefs = list(self._agents.keyrefs())
305
+ self.random.shuffle(weakrefs)
306
+
307
+ if isinstance(method, str):
308
+ for ref in weakrefs:
309
+ if (agent := ref()) is not None:
310
+ getattr(agent, method)(*args, **kwargs)
311
+ else:
312
+ for ref in weakrefs:
313
+ if (agent := ref()) is not None:
314
+ method(agent, *args, **kwargs)
315
+
316
+ return self
317
+
318
+ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
319
+ """Invoke a method or function on each agent in the AgentSet and return the results.
234
320
 
235
321
  Args:
236
- method_name (str): The name of the method to call on each agent.
237
- return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls.
238
- *args: Variable length argument list passed to the method being called.
239
- **kwargs: Arbitrary keyword arguments passed to the method being called.
322
+ method (str, callable): the callable to apply on each agent
323
+
324
+ * in case of str, the name of the method to call on each agent.
325
+ * in case of callable, the function to be called with each agent as first argument
326
+
327
+ *args: Variable length argument list passed to the callable being called.
328
+ **kwargs: Arbitrary keyword arguments passed to the callable being called.
240
329
 
241
330
  Returns:
242
- AgentSet | list[Any]: The results of the method calls if return_results is True, otherwise the AgentSet itself.
331
+ list[Any]: The results of the callable calls
243
332
  """
244
333
  # we iterate over the actual weakref keys and check if weakref is alive before calling the method
245
- res = [
246
- getattr(agent, method_name)(*args, **kwargs)
247
- for agentref in self._agents.keyrefs()
248
- if (agent := agentref()) is not None
249
- ]
334
+ if isinstance(method, str):
335
+ res = [
336
+ getattr(agent, method)(*args, **kwargs)
337
+ for agentref in self._agents.keyrefs()
338
+ if (agent := agentref()) is not None
339
+ ]
340
+ else:
341
+ res = [
342
+ method(agent, *args, **kwargs)
343
+ for agentref in self._agents.keyrefs()
344
+ if (agent := agentref()) is not None
345
+ ]
250
346
 
251
- return res if return_results else self
347
+ return res
252
348
 
253
- def get(self, attr_names: str | list[str]) -> list[Any]:
349
+ def agg(self, attribute: str, func: Callable) -> Any:
350
+ """Aggregate an attribute of all agents in the AgentSet using a specified function.
351
+
352
+ Args:
353
+ attribute (str): The name of the attribute to aggregate.
354
+ func (Callable): The function to apply to the attribute values (e.g., min, max, sum, np.mean).
355
+
356
+ Returns:
357
+ Any: The result of applying the function to the attribute values. Often a single value.
254
358
  """
255
- Retrieve the specified attribute(s) from each agent in the AgentSet.
359
+ values = self.get(attribute)
360
+ return func(values)
361
+
362
+ @overload
363
+ def get(
364
+ self,
365
+ attr_names: str,
366
+ handle_missing: Literal["error", "default"] = "error",
367
+ default_value: Any = None,
368
+ ) -> list[Any]: ...
369
+
370
+ @overload
371
+ def get(
372
+ self,
373
+ attr_names: list[str],
374
+ handle_missing: Literal["error", "default"] = "error",
375
+ default_value: Any = None,
376
+ ) -> list[list[Any]]: ...
377
+
378
+ def get(
379
+ self,
380
+ attr_names,
381
+ handle_missing="error",
382
+ default_value=None,
383
+ ):
384
+ """Retrieve the specified attribute(s) from each agent in the AgentSet.
256
385
 
257
386
  Args:
258
387
  attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent.
388
+ handle_missing (str, optional): How to handle missing attributes. Can be:
389
+ - 'error' (default): raises an AttributeError if attribute is missing.
390
+ - 'default': returns the specified default_value.
391
+ default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default'
392
+ and the agent does not have the attribute.
259
393
 
260
394
  Returns:
261
- list[Any]: A list with the attribute value for each agent in the set if attr_names is a str
262
- list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str
395
+ list[Any]: A list with the attribute value for each agent if attr_names is a str.
396
+ list[list[Any]]: A list with a lists of attribute values for each agent if attr_names is a list of str.
263
397
 
264
398
  Raises:
265
- AttributeError if an agent does not have the specified attribute(s)
266
-
399
+ AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s).
400
+ ValueError: If an unknown 'handle_missing' option is provided.
267
401
  """
402
+ is_single_attr = isinstance(attr_names, str)
403
+
404
+ if handle_missing == "error":
405
+ if is_single_attr:
406
+ return [getattr(agent, attr_names) for agent in self._agents]
407
+ else:
408
+ return [
409
+ [getattr(agent, attr) for attr in attr_names]
410
+ for agent in self._agents
411
+ ]
412
+
413
+ elif handle_missing == "default":
414
+ if is_single_attr:
415
+ return [
416
+ getattr(agent, attr_names, default_value) for agent in self._agents
417
+ ]
418
+ else:
419
+ return [
420
+ [getattr(agent, attr, default_value) for attr in attr_names]
421
+ for agent in self._agents
422
+ ]
268
423
 
269
- if isinstance(attr_names, str):
270
- return [getattr(agent, attr_names) for agent in self._agents]
271
424
  else:
272
- return [
273
- [getattr(agent, attr_name) for attr_name in attr_names]
274
- for agent in self._agents
275
- ]
425
+ raise ValueError(
426
+ f"Unknown handle_missing option: {handle_missing}, "
427
+ "should be one of 'error' or 'default'"
428
+ )
276
429
 
277
430
  def set(self, attr_name: str, value: Any) -> AgentSet:
278
- """
279
- Set a specified attribute to a given value for all agents in the AgentSet.
431
+ """Set a specified attribute to a given value for all agents in the AgentSet.
280
432
 
281
433
  Args:
282
434
  attr_name (str): The name of the attribute to set.
@@ -290,8 +442,7 @@ class AgentSet(MutableSet, Sequence):
290
442
  return self
291
443
 
292
444
  def __getitem__(self, item: int | slice) -> Agent:
293
- """
294
- Retrieve an agent or a slice of agents from the AgentSet.
445
+ """Retrieve an agent or a slice of agents from the AgentSet.
295
446
 
296
447
  Args:
297
448
  item (int | slice): The index or slice for selecting agents.
@@ -302,8 +453,7 @@ class AgentSet(MutableSet, Sequence):
302
453
  return list(self._agents.keys())[item]
303
454
 
304
455
  def add(self, agent: Agent):
305
- """
306
- Add an agent to the AgentSet.
456
+ """Add an agent to the AgentSet.
307
457
 
308
458
  Args:
309
459
  agent (Agent): The agent to add to the set.
@@ -314,8 +464,7 @@ class AgentSet(MutableSet, Sequence):
314
464
  self._agents[agent] = None
315
465
 
316
466
  def discard(self, agent: Agent):
317
- """
318
- Remove an agent from the AgentSet if it exists.
467
+ """Remove an agent from the AgentSet if it exists.
319
468
 
320
469
  This method does not raise an error if the agent is not present.
321
470
 
@@ -329,8 +478,7 @@ class AgentSet(MutableSet, Sequence):
329
478
  del self._agents[agent]
330
479
 
331
480
  def remove(self, agent: Agent):
332
- """
333
- Remove an agent from the AgentSet.
481
+ """Remove an agent from the AgentSet.
334
482
 
335
483
  This method raises an error if the agent is not present.
336
484
 
@@ -343,35 +491,164 @@ class AgentSet(MutableSet, Sequence):
343
491
  del self._agents[agent]
344
492
 
345
493
  def __getstate__(self):
346
- """
347
- Retrieve the state of the AgentSet for serialization.
494
+ """Retrieve the state of the AgentSet for serialization.
348
495
 
349
496
  Returns:
350
497
  dict: A dictionary representing the state of the AgentSet.
351
498
  """
352
- return {"agents": list(self._agents.keys()), "model": self.model}
499
+ return {"agents": list(self._agents.keys()), "random": self.random}
353
500
 
354
501
  def __setstate__(self, state):
355
- """
356
- Set the state of the AgentSet during deserialization.
502
+ """Set the state of the AgentSet during deserialization.
357
503
 
358
504
  Args:
359
505
  state (dict): A dictionary representing the state to restore.
360
506
  """
361
- self.model = state["model"]
507
+ self.random = state["random"]
362
508
  self._update(state["agents"])
363
509
 
364
- @property
365
- def random(self) -> Random:
510
+ def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
511
+ """Group agents by the specified attribute or return from the callable.
512
+
513
+ Args:
514
+ by (Callable, str): used to determine what to group agents by
515
+
516
+ * if ``by`` is a callable, it will be called for each agent and the return is used
517
+ for grouping
518
+ * if ``by`` is a str, it should refer to an attribute on the agent and the value
519
+ of this attribute will be used for grouping
520
+ result_type (str, optional): The datatype for the resulting groups {"agentset", "list"}
521
+
522
+ Returns:
523
+ GroupBy
524
+
525
+
526
+ Notes:
527
+ There might be performance benefits to using `result_type='list'` if you don't need the advanced functionality
528
+ of an AgentSet.
529
+
366
530
  """
367
- Provide access to the model's random number generator.
531
+ groups = defaultdict(list)
532
+
533
+ if isinstance(by, Callable):
534
+ for agent in self:
535
+ groups[by(agent)].append(agent)
536
+ else:
537
+ for agent in self:
538
+ groups[getattr(agent, by)].append(agent)
539
+
540
+ if result_type == "agentset":
541
+ return GroupBy(
542
+ {k: AgentSet(v, random=self.random) for k, v in groups.items()}
543
+ )
544
+ else:
545
+ return GroupBy(groups)
546
+
547
+ # consider adding for performance reasons
548
+ # for Sequence: __reversed__, index, and count
549
+ # for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
550
+
551
+
552
+ class GroupBy:
553
+ """Helper class for AgentSet.groupby.
554
+
555
+ Attributes:
556
+ groups (dict): A dictionary with the group_name as key and group as values
557
+
558
+ """
559
+
560
+ def __init__(self, groups: dict[Any, list | AgentSet]):
561
+ """Initialize a GroupBy instance.
562
+
563
+ Args:
564
+ groups (dict): A dictionary with the group_name as key and group as values
565
+
566
+ """
567
+ self.groups: dict[Any, list | AgentSet] = groups
568
+
569
+ def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]:
570
+ """Apply the specified callable to each group and return the results.
571
+
572
+ Args:
573
+ method (Callable, str): The callable to apply to each group,
574
+
575
+ * if ``method`` is a callable, it will be called it will be called with the group as first argument
576
+ * if ``method`` is a str, it should refer to a method on the group
577
+
578
+ Additional arguments and keyword arguments will be passed on to the callable.
579
+ args: arguments to pass to the callable
580
+ kwargs: keyword arguments to pass to the callable
368
581
 
369
582
  Returns:
370
- Random: The random number generator associated with the model.
583
+ dict with group_name as key and the return of the method as value
584
+
585
+ Notes:
586
+ this method is useful for methods or functions that do return something. It
587
+ will break method chaining. For that, use ``do`` instead.
588
+
371
589
  """
372
- return self.model.random
590
+ if isinstance(method, str):
591
+ return {
592
+ k: getattr(v, method)(*args, **kwargs) for k, v in self.groups.items()
593
+ }
594
+ else:
595
+ return {k: method(v, *args, **kwargs) for k, v in self.groups.items()}
596
+
597
+ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:
598
+ """Apply the specified callable to each group.
599
+
600
+ Args:
601
+ method (Callable, str): The callable to apply to each group,
602
+
603
+ * if ``method`` is a callable, it will be called it will be called with the group as first argument
604
+ * if ``method`` is a str, it should refer to a method on the group
605
+
606
+ Additional arguments and keyword arguments will be passed on to the callable.
607
+ args: arguments to pass to the callable
608
+ kwargs: keyword arguments to pass to the callable
609
+
610
+ Returns:
611
+ the original GroupBy instance
612
+
613
+ Notes:
614
+ this method is useful for methods or functions that don't return anything and/or
615
+ if you want to chain multiple do calls
616
+
617
+ """
618
+ if isinstance(method, str):
619
+ for v in self.groups.values():
620
+ getattr(v, method)(*args, **kwargs)
621
+ else:
622
+ for v in self.groups.values():
623
+ method(v, *args, **kwargs)
624
+
625
+ return self
626
+
627
+ def count(self) -> dict[Any, int]:
628
+ """Return the count of agents in each group.
629
+
630
+ Returns:
631
+ dict: A dictionary mapping group names to the number of agents in each group.
632
+ """
633
+ return {k: len(v) for k, v in self.groups.items()}
634
+
635
+ def agg(self, attr_name: str, func: Callable) -> dict[Hashable, Any]:
636
+ """Aggregate the values of a specific attribute across each group using the provided function.
637
+
638
+ Args:
639
+ attr_name (str): The name of the attribute to aggregate.
640
+ func (Callable): The function to apply (e.g., sum, min, max, mean).
641
+
642
+ Returns:
643
+ dict[Hashable, Any]: A dictionary mapping group names to the result of applying the aggregation function.
644
+ """
645
+ return {
646
+ group_name: func([getattr(agent, attr_name) for agent in group])
647
+ for group_name, group in self.groups.items()
648
+ }
373
649
 
650
+ def __iter__(self): # noqa: D105
651
+ return iter(self.groups.items())
374
652
 
375
- # consider adding for performance reasons
376
- # for Sequence: __reversed__, index, and count
377
- # for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
653
+ def __len__(self): # noqa: D105
654
+ return len(self.groups)