Mesa 3.0.0a2__py3-none-any.whl → 3.0.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic; the registry page this diff was taken from links to further details (the link was not preserved here).

mesa/__init__.py CHANGED
@@ -24,7 +24,7 @@ __all__ = [
24
24
  ]
25
25
 
26
26
  __title__ = "mesa"
27
- __version__ = "3.0.0a2"
27
+ __version__ = "3.0.0a4"
28
28
  __license__ = "Apache 2.0"
29
29
  _this_year = datetime.datetime.now(tz=datetime.timezone.utc).date().year
30
30
  __copyright__ = f"Copyright {_this_year} Project Mesa Team"
mesa/agent.py CHANGED
@@ -10,13 +10,17 @@ from __future__ import annotations
10
10
 
11
11
  import contextlib
12
12
  import copy
13
+ import functools
14
+ import itertools
13
15
  import operator
16
+ import warnings
14
17
  import weakref
18
+ from collections import defaultdict
15
19
  from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
16
20
  from random import Random
17
21
 
18
22
  # mypy
19
- from typing import TYPE_CHECKING, Any
23
+ from typing import TYPE_CHECKING, Any, Literal
20
24
 
21
25
  if TYPE_CHECKING:
22
26
  # We ensure that these are not imported during runtime to prevent cyclic
@@ -30,36 +34,57 @@ class Agent:
30
34
  Base class for a model agent in Mesa.
31
35
 
32
36
  Attributes:
33
- unique_id (int): A unique identifier for this agent.
34
37
  model (Model): A reference to the model instance.
35
- self.pos: Position | None = None
38
+ unique_id (int): A unique identifier for this agent.
39
+ pos (Position): A reference to the position where this agent is located.
40
+
41
+ Notes:
42
+ unique_id is unique relative to a model instance and starts from 1
43
+
36
44
  """
37
45
 
38
- def __init__(self, unique_id: int, model: Model) -> None:
46
+ # this is a class level attribute
47
+ # it is a dictionary, indexed by model instance
48
+ # so, unique_id is unique relative to a model, and counting starts from 1
49
+ _ids = defaultdict(functools.partial(itertools.count, 1))
50
+
51
+ def __init__(self, *args, **kwargs) -> None:
39
52
  """
40
53
  Create a new agent.
41
54
 
42
55
  Args:
43
- unique_id (int): A unique identifier for this agent.
44
56
  model (Model): The model instance in which the agent exists.
45
57
  """
46
- self.unique_id = unique_id
47
- self.model = model
58
+ # TODO: Cleanup in future Mesa version (3.1+)
59
+ match args:
60
+ # Case 1: Only the model is provided. The new correct behavior.
61
+ case [model]:
62
+ self.model = model
63
+ self.unique_id = next(self._ids[model])
64
+ # Case 2: Both unique_id and model are provided, deprecated
65
+ case [_, model]:
66
+ warnings.warn(
67
+ "unique ids are assigned automatically to Agents in Mesa 3. The use of custom unique_id is "
68
+ "deprecated. Only input a model when calling `super()__init__(model)`. The unique_id inputted is not used.",
69
+ DeprecationWarning,
70
+ stacklevel=2,
71
+ )
72
+ self.model = model
73
+ self.unique_id = next(self._ids[model])
74
+ # Case 3: Anything else, raise an error
75
+ case _:
76
+ raise ValueError(
77
+ "Invalid arguments provided to initialize the Agent. Only input a model: `super()__init__(model)`."
78
+ )
79
+
48
80
  self.pos: Position | None = None
49
81
 
50
- # register agent
51
- try:
52
- self.model.agents_[type(self)][self] = None
53
- except AttributeError as err:
54
- # model super has not been called
55
- raise RuntimeError(
56
- "The Mesa Model class was not initialized. You must explicitly initialize the Model by calling super().__init__() on initialization."
57
- ) from err
82
+ self.model.register_agent(self)
58
83
 
59
84
  def remove(self) -> None:
60
85
  """Remove and delete the agent from the model."""
61
86
  with contextlib.suppress(KeyError):
62
- self.model.agents_[type(self)].pop(self)
87
+ self.model.deregister_agent(self)
63
88
 
64
89
  def step(self) -> None:
65
90
  """A single step of the agent."""
@@ -119,9 +144,10 @@ class AgentSet(MutableSet, Sequence):
119
144
  def select(
120
145
  self,
121
146
  filter_func: Callable[[Agent], bool] | None = None,
122
- n: int = 0,
147
+ at_most: int | float = float("inf"),
123
148
  inplace: bool = False,
124
149
  agent_type: type[Agent] | None = None,
150
+ n: int | None = None,
125
151
  ) -> AgentSet:
126
152
  """
127
153
  Select a subset of agents from the AgentSet based on a filter function and/or quantity limit.
@@ -129,29 +155,47 @@ class AgentSet(MutableSet, Sequence):
129
155
  Args:
130
156
  filter_func (Callable[[Agent], bool], optional): A function that takes an Agent and returns True if the
131
157
  agent should be included in the result. Defaults to None, meaning no filtering is applied.
132
- n (int, optional): The number of agents to select. If 0, all matching agents are selected. Defaults to 0.
158
+ at_most (int | float, optional): The maximum amount of agents to select. Defaults to infinity.
159
+ - If an integer, at most the first number of matching agents are selected.
160
+ - If a float between 0 and 1, at most that fraction of original the agents are selected.
133
161
  inplace (bool, optional): If True, modifies the current AgentSet; otherwise, returns a new AgentSet. Defaults to False.
134
162
  agent_type (type[Agent], optional): The class type of the agents to select. Defaults to None, meaning no type filtering is applied.
135
163
 
136
164
  Returns:
137
165
  AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated.
166
+
167
+ Notes:
168
+ - at_most just return the first n or fraction of agents. To take a random sample, shuffle() beforehand.
169
+ - at_most is an upper limit. When specifying other criteria, the number of agents returned can be smaller.
138
170
  """
171
+ if n is not None:
172
+ warnings.warn(
173
+ "The parameter 'n' is deprecated. Use 'at_most' instead.",
174
+ DeprecationWarning,
175
+ stacklevel=2,
176
+ )
177
+ at_most = n
139
178
 
140
- if filter_func is None and agent_type is None and n == 0:
179
+ inf = float("inf")
180
+ if filter_func is None and agent_type is None and at_most == inf:
141
181
  return self if inplace else copy.copy(self)
142
182
 
143
- def agent_generator(filter_func=None, agent_type=None, n=0):
183
+ # Check if at_most is of type float
184
+ if at_most <= 1.0 and isinstance(at_most, float):
185
+ at_most = int(len(self) * at_most) # Note that it rounds down (floor)
186
+
187
+ def agent_generator(filter_func, agent_type, at_most):
144
188
  count = 0
145
189
  for agent in self:
190
+ if count >= at_most:
191
+ break
146
192
  if (not filter_func or filter_func(agent)) and (
147
193
  not agent_type or isinstance(agent, agent_type)
148
194
  ):
149
195
  yield agent
150
196
  count += 1
151
- if 0 < n <= count:
152
- break
153
197
 
154
- agents = agent_generator(filter_func, agent_type, n)
198
+ agents = agent_generator(filter_func, agent_type, at_most)
155
199
 
156
200
  return AgentSet(agents, self.model) if not inplace else self._update(agents)
157
201
 
@@ -216,25 +260,64 @@ class AgentSet(MutableSet, Sequence):
216
260
  self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
217
261
  return self
218
262
 
219
- def do(
220
- self, method: str | Callable, *args, return_results: bool = False, **kwargs
221
- ) -> AgentSet | list[Any]:
263
+ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
222
264
  """
223
265
  Invoke a method or function on each agent in the AgentSet.
224
266
 
225
267
  Args:
226
- method (str, callable): the callable to do on each agents
268
+ method (str, callable): the callable to do on each agent
227
269
 
228
270
  * in case of str, the name of the method to call on each agent.
229
271
  * in case of callable, the function to be called with each agent as first argument
230
272
 
231
- return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls.
232
273
  *args: Variable length argument list passed to the callable being called.
233
274
  **kwargs: Arbitrary keyword arguments passed to the callable being called.
234
275
 
235
276
  Returns:
236
277
  AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself.
237
278
  """
279
+ try:
280
+ return_results = kwargs.pop("return_results")
281
+ except KeyError:
282
+ return_results = False
283
+ else:
284
+ warnings.warn(
285
+ "Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and "
286
+ "AgentSet.map in case of return_results=True",
287
+ stacklevel=2,
288
+ )
289
+
290
+ if return_results:
291
+ return self.map(method, *args, **kwargs)
292
+
293
+ # we iterate over the actual weakref keys and check if weakref is alive before calling the method
294
+ if isinstance(method, str):
295
+ for agentref in self._agents.keyrefs():
296
+ if (agent := agentref()) is not None:
297
+ getattr(agent, method)(*args, **kwargs)
298
+ else:
299
+ for agentref in self._agents.keyrefs():
300
+ if (agent := agentref()) is not None:
301
+ method(agent, *args, **kwargs)
302
+
303
+ return self
304
+
305
+ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
306
+ """
307
+ Invoke a method or function on each agent in the AgentSet and return the results.
308
+
309
+ Args:
310
+ method (str, callable): the callable to apply on each agent
311
+
312
+ * in case of str, the name of the method to call on each agent.
313
+ * in case of callable, the function to be called with each agent as first argument
314
+
315
+ *args: Variable length argument list passed to the callable being called.
316
+ **kwargs: Arbitrary keyword arguments passed to the callable being called.
317
+
318
+ Returns:
319
+ list[Any]: The results of the callable calls
320
+ """
238
321
  # we iterate over the actual weakref keys and check if weakref is alive before calling the method
239
322
  if isinstance(method, str):
240
323
  res = [
@@ -249,31 +332,89 @@ class AgentSet(MutableSet, Sequence):
249
332
  if (agent := agentref()) is not None
250
333
  ]
251
334
 
252
- return res if return_results else self
335
+ return res
336
+
337
+ def agg(self, attribute: str, func: Callable) -> Any:
338
+ """
339
+ Aggregate an attribute of all agents in the AgentSet using a specified function.
340
+
341
+ Args:
342
+ attribute (str): The name of the attribute to aggregate.
343
+ func (Callable): The function to apply to the attribute values (e.g., min, max, sum, np.mean).
344
+
345
+ Returns:
346
+ Any: The result of applying the function to the attribute values. Often a single value.
347
+ """
348
+ values = self.get(attribute)
349
+ return func(values)
253
350
 
254
- def get(self, attr_names: str | list[str]) -> list[Any]:
351
+ def get(
352
+ self,
353
+ attr_names: str | list[str],
354
+ handle_missing: Literal["error", "default"] = "error",
355
+ default_value: Any = None,
356
+ ) -> list[Any] | list[list[Any]]:
255
357
  """
256
358
  Retrieve the specified attribute(s) from each agent in the AgentSet.
257
359
 
258
360
  Args:
259
361
  attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent.
362
+ handle_missing (str, optional): How to handle missing attributes. Can be:
363
+ - 'error' (default): raises an AttributeError if attribute is missing.
364
+ - 'default': returns the specified default_value.
365
+ default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default'
366
+ and the agent does not have the attribute.
260
367
 
261
368
  Returns:
262
- list[Any]: A list with the attribute value for each agent in the set if attr_names is a str
263
- list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str
369
+ list[Any]: A list with the attribute value for each agent if attr_names is a str.
370
+ list[list[Any]]: A list with a lists of attribute values for each agent if attr_names is a list of str.
264
371
 
265
372
  Raises:
266
- AttributeError if an agent does not have the specified attribute(s)
267
-
373
+ AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s).
374
+ ValueError: If an unknown 'handle_missing' option is provided.
268
375
  """
376
+ is_single_attr = isinstance(attr_names, str)
377
+
378
+ if handle_missing == "error":
379
+ if is_single_attr:
380
+ return [getattr(agent, attr_names) for agent in self._agents]
381
+ else:
382
+ return [
383
+ [getattr(agent, attr) for attr in attr_names]
384
+ for agent in self._agents
385
+ ]
386
+
387
+ elif handle_missing == "default":
388
+ if is_single_attr:
389
+ return [
390
+ getattr(agent, attr_names, default_value) for agent in self._agents
391
+ ]
392
+ else:
393
+ return [
394
+ [getattr(agent, attr, default_value) for attr in attr_names]
395
+ for agent in self._agents
396
+ ]
269
397
 
270
- if isinstance(attr_names, str):
271
- return [getattr(agent, attr_names) for agent in self._agents]
272
398
  else:
273
- return [
274
- [getattr(agent, attr_name) for attr_name in attr_names]
275
- for agent in self._agents
276
- ]
399
+ raise ValueError(
400
+ f"Unknown handle_missing option: {handle_missing}, "
401
+ "should be one of 'error' or 'default'"
402
+ )
403
+
404
+ def set(self, attr_name: str, value: Any) -> AgentSet:
405
+ """
406
+ Set a specified attribute to a given value for all agents in the AgentSet.
407
+
408
+ Args:
409
+ attr_name (str): The name of the attribute to set.
410
+ value (Any): The value to set the attribute to.
411
+
412
+ Returns:
413
+ AgentSet: The AgentSet instance itself, after setting the attribute.
414
+ """
415
+ for agent in self:
416
+ setattr(agent, attr_name, value)
417
+ return self
277
418
 
278
419
  def __getitem__(self, item: int | slice) -> Agent:
279
420
  """
@@ -357,7 +498,116 @@ class AgentSet(MutableSet, Sequence):
357
498
  """
358
499
  return self.model.random
359
500
 
501
+ def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
502
+ """
503
+ Group agents by the specified attribute or return from the callable
504
+
505
+ Args:
506
+ by (Callable, str): used to determine what to group agents by
507
+
508
+ * if ``by`` is a callable, it will be called for each agent and the return is used
509
+ for grouping
510
+ * if ``by`` is a str, it should refer to an attribute on the agent and the value
511
+ of this attribute will be used for grouping
512
+ result_type (str, optional): The datatype for the resulting groups {"agentset", "list"}
513
+ Returns:
514
+ GroupBy
515
+
516
+
517
+ Notes:
518
+ There might be performance benefits to using `result_type='list'` if you don't need the advanced functionality
519
+ of an AgentSet.
520
+
521
+ """
522
+ groups = defaultdict(list)
523
+
524
+ if isinstance(by, Callable):
525
+ for agent in self:
526
+ groups[by(agent)].append(agent)
527
+ else:
528
+ for agent in self:
529
+ groups[getattr(agent, by)].append(agent)
530
+
531
+ if result_type == "agentset":
532
+ return GroupBy(
533
+ {k: AgentSet(v, model=self.model) for k, v in groups.items()}
534
+ )
535
+ else:
536
+ return GroupBy(groups)
537
+
538
+ # consider adding for performance reasons
539
+ # for Sequence: __reversed__, index, and count
540
+ # for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
541
+
542
+
543
+ class GroupBy:
544
+ """Helper class for AgentSet.groupby
545
+
546
+
547
+ Attributes:
548
+ groups (dict): A dictionary with the group_name as key and group as values
549
+
550
+ """
551
+
552
+ def __init__(self, groups: dict[Any, list | AgentSet]):
553
+ self.groups: dict[Any, list | AgentSet] = groups
554
+
555
+ def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]:
556
+ """Apply the specified callable to each group and return the results.
557
+
558
+ Args:
559
+ method (Callable, str): The callable to apply to each group,
560
+
561
+ * if ``method`` is a callable, it will be called it will be called with the group as first argument
562
+ * if ``method`` is a str, it should refer to a method on the group
563
+
564
+ Additional arguments and keyword arguments will be passed on to the callable.
565
+
566
+ Returns:
567
+ dict with group_name as key and the return of the method as value
568
+
569
+ Notes:
570
+ this method is useful for methods or functions that do return something. It
571
+ will break method chaining. For that, use ``do`` instead.
572
+
573
+ """
574
+ if isinstance(method, str):
575
+ return {
576
+ k: getattr(v, method)(*args, **kwargs) for k, v in self.groups.items()
577
+ }
578
+ else:
579
+ return {k: method(v, *args, **kwargs) for k, v in self.groups.items()}
580
+
581
+ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:
582
+ """Apply the specified callable to each group
583
+
584
+ Args:
585
+ method (Callable, str): The callable to apply to each group,
586
+
587
+ * if ``method`` is a callable, it will be called it will be called with the group as first argument
588
+ * if ``method`` is a str, it should refer to a method on the group
589
+
590
+ Additional arguments and keyword arguments will be passed on to the callable.
591
+
592
+ Returns:
593
+ the original GroupBy instance
594
+
595
+ Notes:
596
+ this method is useful for methods or functions that don't return anything and/or
597
+ if you want to chain multiple do calls
598
+
599
+ """
600
+ if isinstance(method, str):
601
+ for v in self.groups.values():
602
+ getattr(v, method)(*args, **kwargs)
603
+ else:
604
+ for v in self.groups.values():
605
+ method(v, *args, **kwargs)
606
+
607
+ return self
608
+
609
+ def __iter__(self):
610
+ return iter(self.groups.items())
360
611
 
361
- # consider adding for performance reasons
362
- # for Sequence: __reversed__, index, and count
363
- # for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
612
+ def __len__(self):
613
+ return len(self.groups)
mesa/batchrunner.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import itertools
2
+ import multiprocessing
2
3
  from collections.abc import Iterable, Mapping
3
4
  from functools import partial
4
5
  from multiprocessing import Pool
@@ -8,6 +9,8 @@ from tqdm.auto import tqdm
8
9
 
9
10
  from mesa.model import Model
10
11
 
12
+ multiprocessing.set_start_method("spawn", force=True)
13
+
11
14
 
12
15
  def batch_run(
13
16
  model_cls: type[Model],
@@ -132,14 +135,14 @@ def _model_run_func(
132
135
  """
133
136
  run_id, iteration, kwargs = run
134
137
  model = model_cls(**kwargs)
135
- while model.running and model._steps <= max_steps:
138
+ while model.running and model.steps <= max_steps:
136
139
  model.step()
137
140
 
138
141
  data = []
139
142
 
140
- steps = list(range(0, model._steps, data_collection_period))
141
- if not steps or steps[-1] != model._steps - 1:
142
- steps.append(model._steps - 1)
143
+ steps = list(range(0, model.steps, data_collection_period))
144
+ if not steps or steps[-1] != model.steps - 1:
145
+ steps.append(model.steps - 1)
143
146
 
144
147
  for step in steps:
145
148
  model_data, all_agents_data = _collect_data(model, step)
mesa/datacollection.py CHANGED
@@ -180,7 +180,7 @@ class DataCollector:
180
180
  rep_funcs = self.agent_reporters.values()
181
181
 
182
182
  def get_reports(agent):
183
- _prefix = (agent.model._steps, agent.unique_id)
183
+ _prefix = (agent.model.steps, agent.unique_id)
184
184
  reports = tuple(rep(agent) for rep in rep_funcs)
185
185
  return _prefix + reports
186
186
 
@@ -216,7 +216,7 @@ class DataCollector:
216
216
 
217
217
  if self.agent_reporters:
218
218
  agent_records = self._record_agents(model)
219
- self._agent_records[model._steps] = list(agent_records)
219
+ self._agent_records[model.steps] = list(agent_records)
220
220
 
221
221
  def add_table_row(self, table_name, row, ignore_missing=False):
222
222
  """Add a row dictionary to a specific table.
@@ -0,0 +1,56 @@
1
+ class UserParam:
2
+ _ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'"
3
+
4
+ def maybe_raise_error(self, param_type, valid):
5
+ if valid:
6
+ return
7
+ msg = self._ERROR_MESSAGE.format(param_type, self.label)
8
+ raise ValueError(msg)
9
+
10
+
11
+ class Slider(UserParam):
12
+ """
13
+ A number-based slider input with settable increment.
14
+
15
+ Example:
16
+
17
+ slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1)
18
+
19
+ Args:
20
+ label: The displayed label in the UI
21
+ value: The initial value of the slider
22
+ min: The minimum possible value of the slider
23
+ max: The maximum possible value of the slider
24
+ step: The step between min and max for a range of possible values
25
+ dtype: either int or float
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ label="",
31
+ value=None,
32
+ min=None,
33
+ max=None,
34
+ step=1,
35
+ dtype=None,
36
+ ):
37
+ self.label = label
38
+ self.value = value
39
+ self.min = min
40
+ self.max = max
41
+ self.step = step
42
+
43
+ # Validate option type to make sure values are supplied properly
44
+ valid = not (self.value is None or self.min is None or self.max is None)
45
+ self.maybe_raise_error("slider", valid)
46
+
47
+ if dtype is None:
48
+ self.is_float_slider = self._check_values_are_float(value, min, max, step)
49
+ else:
50
+ self.is_float_slider = dtype is float
51
+
52
+ def _check_values_are_float(self, value, min, max, step):
53
+ return any(isinstance(n, float) for n in (value, min, max, step))
54
+
55
+ def get(self, attr):
56
+ return getattr(self, attr)
@@ -1,3 +1,5 @@
1
1
  from mesa.experimental import cell_space
2
2
 
3
- __all__ = ["cell_space"]
3
+ from .solara_viz import JupyterViz, Slider, SolaraViz, make_text
4
+
5
+ __all__ = ["cell_space", "JupyterViz", "SolaraViz", "make_text", "Slider"]
@@ -9,6 +9,7 @@ from mesa.experimental.cell_space.grid import (
9
9
  OrthogonalVonNeumannGrid,
10
10
  )
11
11
  from mesa.experimental.cell_space.network import Network
12
+ from mesa.experimental.cell_space.voronoi import VoronoiGrid
12
13
 
13
14
  __all__ = [
14
15
  "CellCollection",
@@ -20,4 +21,5 @@ __all__ = [
20
21
  "OrthogonalMooreGrid",
21
22
  "OrthogonalVonNeumannGrid",
22
23
  "Network",
24
+ "VoronoiGrid",
23
25
  ]
@@ -19,7 +19,7 @@ class CellAgent(Agent):
19
19
  cell: (Cell | None): the cell which the agent occupies
20
20
  """
21
21
 
22
- def __init__(self, unique_id: int, model: Model) -> None:
22
+ def __init__(self, model: Model) -> None:
23
23
  """
24
24
  Create a new agent.
25
25
 
@@ -27,7 +27,7 @@ class CellAgent(Agent):
27
27
  unique_id (int): A unique identifier for this agent.
28
28
  model (Model): The model instance in which the agent exists.
29
29
  """
30
- super().__init__(unique_id, model)
30
+ super().__init__(model)
31
31
  self.cell: Cell | None = None
32
32
 
33
33
  def move_to(self, cell) -> None: