Mesa 3.0.0a3__py3-none-any.whl → 3.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic. Click here for more details.

Files changed (40)
  1. mesa/__init__.py +2 -3
  2. mesa/agent.py +193 -75
  3. mesa/batchrunner.py +18 -23
  4. mesa/cookiecutter-mesa/hooks/post_gen_project.py +2 -0
  5. mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/__init__.py +1 -0
  6. mesa/datacollection.py +138 -27
  7. mesa/experimental/UserParam.py +67 -0
  8. mesa/experimental/__init__.py +5 -1
  9. mesa/experimental/cell_space/__init__.py +7 -0
  10. mesa/experimental/cell_space/cell.py +61 -20
  11. mesa/experimental/cell_space/cell_agent.py +12 -7
  12. mesa/experimental/cell_space/cell_collection.py +54 -17
  13. mesa/experimental/cell_space/discrete_space.py +16 -5
  14. mesa/experimental/cell_space/grid.py +19 -8
  15. mesa/experimental/cell_space/network.py +9 -7
  16. mesa/experimental/cell_space/voronoi.py +26 -33
  17. mesa/experimental/components/altair.py +81 -0
  18. mesa/experimental/components/matplotlib.py +242 -0
  19. mesa/experimental/devs/__init__.py +2 -0
  20. mesa/experimental/devs/eventlist.py +36 -15
  21. mesa/experimental/devs/examples/epstein_civil_violence.py +71 -39
  22. mesa/experimental/devs/examples/wolf_sheep.py +43 -44
  23. mesa/experimental/devs/simulator.py +55 -15
  24. mesa/experimental/solara_viz.py +453 -0
  25. mesa/main.py +6 -4
  26. mesa/model.py +64 -61
  27. mesa/space.py +154 -123
  28. mesa/time.py +57 -67
  29. mesa/visualization/UserParam.py +19 -6
  30. mesa/visualization/__init__.py +14 -2
  31. mesa/visualization/components/altair.py +18 -1
  32. mesa/visualization/components/matplotlib.py +26 -2
  33. mesa/visualization/solara_viz.py +231 -225
  34. mesa/visualization/utils.py +9 -0
  35. {mesa-3.0.0a3.dist-info → mesa-3.0.0a5.dist-info}/METADATA +2 -1
  36. mesa-3.0.0a5.dist-info/RECORD +44 -0
  37. mesa-3.0.0a3.dist-info/RECORD +0 -39
  38. {mesa-3.0.0a3.dist-info → mesa-3.0.0a5.dist-info}/WHEEL +0 -0
  39. {mesa-3.0.0a3.dist-info → mesa-3.0.0a5.dist-info}/entry_points.txt +0 -0
  40. {mesa-3.0.0a3.dist-info → mesa-3.0.0a5.dist-info}/licenses/LICENSE +0 -0
mesa/__init__.py CHANGED
@@ -1,5 +1,4 @@
1
- """
2
- Mesa Agent-Based Modeling Framework
1
+ """Mesa Agent-Based Modeling Framework.
3
2
 
4
3
  Core Objects: Model, and Agent.
5
4
  """
@@ -24,7 +23,7 @@ __all__ = [
24
23
  ]
25
24
 
26
25
  __title__ = "mesa"
27
- __version__ = "3.0.0a3"
26
+ __version__ = "3.0.0a5"
28
27
  __license__ = "Apache 2.0"
29
28
  _this_year = datetime.datetime.now(tz=datetime.timezone.utc).date().year
30
29
  __copyright__ = f"Copyright {_this_year} Project Mesa Team"
mesa/agent.py CHANGED
@@ -1,7 +1,6 @@
1
- """
2
- The agent class for Mesa framework.
1
+ """Agent related classes.
3
2
 
4
- Core Objects: Agent
3
+ Core Objects: Agent and AgentSet.
5
4
  """
6
5
 
7
6
  # Mypy; for the `|` operator purpose
@@ -10,15 +9,17 @@ from __future__ import annotations
10
9
 
11
10
  import contextlib
12
11
  import copy
12
+ import functools
13
+ import itertools
13
14
  import operator
14
15
  import warnings
15
16
  import weakref
16
17
  from collections import defaultdict
17
- from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
18
+ from collections.abc import Callable, Hashable, Iterable, Iterator, MutableSet, Sequence
18
19
  from random import Random
19
20
 
20
21
  # mypy
21
- from typing import TYPE_CHECKING, Any
22
+ from typing import TYPE_CHECKING, Any, Literal, overload
22
23
 
23
24
  if TYPE_CHECKING:
24
25
  # We ensure that these are not imported during runtime to prevent cyclic
@@ -28,25 +29,53 @@ if TYPE_CHECKING:
28
29
 
29
30
 
30
31
  class Agent:
31
- """
32
- Base class for a model agent in Mesa.
32
+ """Base class for a model agent in Mesa.
33
33
 
34
34
  Attributes:
35
- unique_id (int): A unique identifier for this agent.
36
35
  model (Model): A reference to the model instance.
37
- self.pos: Position | None = None
36
+ unique_id (int): A unique identifier for this agent.
37
+ pos (Position): A reference to the position where this agent is located.
38
+
39
+ Notes:
40
+ unique_id is unique relative to a model instance and starts from 1
41
+
38
42
  """
39
43
 
40
- def __init__(self, unique_id: int, model: Model) -> None:
41
- """
42
- Create a new agent.
44
+ # this is a class level attribute
45
+ # it is a dictionary, indexed by model instance
46
+ # so, unique_id is unique relative to a model, and counting starts from 1
47
+ _ids = defaultdict(functools.partial(itertools.count, 1))
48
+
49
+ def __init__(self, *args, **kwargs) -> None:
50
+ """Create a new agent.
43
51
 
44
52
  Args:
45
- unique_id (int): A unique identifier for this agent.
46
53
  model (Model): The model instance in which the agent exists.
47
- """
48
- self.unique_id = unique_id
49
- self.model = model
54
+ args: currently ignored, to be fixed in 3.1
55
+ kwargs: currently ignored, to be fixed in 3.1
56
+ """
57
+ # TODO: Cleanup in future Mesa version (3.1+)
58
+ match args:
59
+ # Case 1: Only the model is provided. The new correct behavior.
60
+ case [model]:
61
+ self.model = model
62
+ self.unique_id = next(self._ids[model])
63
+ # Case 2: Both unique_id and model are provided, deprecated
64
+ case [_, model]:
65
+ warnings.warn(
66
+ "unique ids are assigned automatically to Agents in Mesa 3. The use of custom unique_id is "
67
+ "deprecated. Only input a model when calling `super()__init__(model)`. The unique_id inputted is not used.",
68
+ DeprecationWarning,
69
+ stacklevel=2,
70
+ )
71
+ self.model = model
72
+ self.unique_id = next(self._ids[model])
73
+ # Case 3: Anything else, raise an error
74
+ case _:
75
+ raise ValueError(
76
+ "Invalid arguments provided to initialize the Agent. Only input a model: `super()__init__(model)`."
77
+ )
78
+
50
79
  self.pos: Position | None = None
51
80
 
52
81
  self.model.register_agent(self)
@@ -59,28 +88,25 @@ class Agent:
59
88
  def step(self) -> None:
60
89
  """A single step of the agent."""
61
90
 
62
- def advance(self) -> None:
91
+ def advance(self) -> None: # noqa: D102
63
92
  pass
64
93
 
65
94
  @property
66
95
  def random(self) -> Random:
96
+ """Return a seeded rng."""
67
97
  return self.model.random
68
98
 
69
99
 
70
100
  class AgentSet(MutableSet, Sequence):
71
- """
72
- A collection class that represents an ordered set of agents within an agent-based model (ABM). This class
73
- extends both MutableSet and Sequence, providing set-like functionality with order preservation and
101
+ """A collection class that represents an ordered set of agents within an agent-based model (ABM).
102
+
103
+ This class extends both MutableSet and Sequence, providing set-like functionality with order preservation and
74
104
  sequence operations.
75
105
 
76
106
  Attributes:
77
107
  model (Model): The ABM model instance to which this AgentSet belongs.
78
108
 
79
- Methods:
80
- __len__, __iter__, __contains__, select, shuffle, sort, _update, do, get, __getitem__,
81
- add, discard, remove, __getstate__, __setstate__, random
82
-
83
- Note:
109
+ Notes:
84
110
  The AgentSet maintains weak references to agents, allowing for efficient management of agent lifecycles
85
111
  without preventing garbage collection. It is associated with a specific model instance, enabling
86
112
  interactions with the model's environment and other agents.The implementation uses a WeakKeyDictionary to store agents,
@@ -88,14 +114,12 @@ class AgentSet(MutableSet, Sequence):
88
114
  """
89
115
 
90
116
  def __init__(self, agents: Iterable[Agent], model: Model):
91
- """
92
- Initializes the AgentSet with a collection of agents and a reference to the model.
117
+ """Initializes the AgentSet with a collection of agents and a reference to the model.
93
118
 
94
119
  Args:
95
120
  agents (Iterable[Agent]): An iterable of Agent objects to be included in the set.
96
121
  model (Model): The ABM model instance to which this AgentSet belongs.
97
122
  """
98
-
99
123
  self.model = model
100
124
  self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
101
125
 
@@ -119,8 +143,7 @@ class AgentSet(MutableSet, Sequence):
119
143
  agent_type: type[Agent] | None = None,
120
144
  n: int | None = None,
121
145
  ) -> AgentSet:
122
- """
123
- Select a subset of agents from the AgentSet based on a filter function and/or quantity limit.
146
+ """Select a subset of agents from the AgentSet based on a filter function and/or quantity limit.
124
147
 
125
148
  Args:
126
149
  filter_func (Callable[[Agent], bool], optional): A function that takes an Agent and returns True if the
@@ -130,6 +153,7 @@ class AgentSet(MutableSet, Sequence):
130
153
  - If a float between 0 and 1, at most that fraction of original the agents are selected.
131
154
  inplace (bool, optional): If True, modifies the current AgentSet; otherwise, returns a new AgentSet. Defaults to False.
132
155
  agent_type (type[Agent], optional): The class type of the agents to select. Defaults to None, meaning no type filtering is applied.
156
+ n (int): deprecated, use at_most instead
133
157
 
134
158
  Returns:
135
159
  AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated.
@@ -170,8 +194,7 @@ class AgentSet(MutableSet, Sequence):
170
194
  return AgentSet(agents, self.model) if not inplace else self._update(agents)
171
195
 
172
196
  def shuffle(self, inplace: bool = False) -> AgentSet:
173
- """
174
- Randomly shuffle the order of agents in the AgentSet.
197
+ """Randomly shuffle the order of agents in the AgentSet.
175
198
 
176
199
  Args:
177
200
  inplace (bool, optional): If True, shuffles the agents in the current AgentSet; otherwise, returns a new shuffled AgentSet. Defaults to False.
@@ -200,8 +223,7 @@ class AgentSet(MutableSet, Sequence):
200
223
  ascending: bool = False,
201
224
  inplace: bool = False,
202
225
  ) -> AgentSet:
203
- """
204
- Sort the agents in the AgentSet based on a specified attribute or custom function.
226
+ """Sort the agents in the AgentSet based on a specified attribute or custom function.
205
227
 
206
228
  Args:
207
229
  key (Callable[[Agent], Any] | str): A function or attribute name based on which the agents are sorted.
@@ -224,15 +246,14 @@ class AgentSet(MutableSet, Sequence):
224
246
 
225
247
  def _update(self, agents: Iterable[Agent]):
226
248
  """Update the AgentSet with a new set of agents.
249
+
227
250
  This is a private method primarily used internally by other methods like select, shuffle, and sort.
228
251
  """
229
-
230
252
  self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
231
253
  return self
232
254
 
233
255
  def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
234
- """
235
- Invoke a method or function on each agent in the AgentSet.
256
+ """Invoke a method or function on each agent in the AgentSet.
236
257
 
237
258
  Args:
238
259
  method (str, callable): the callable to do on each agent
@@ -272,9 +293,25 @@ class AgentSet(MutableSet, Sequence):
272
293
 
273
294
  return self
274
295
 
275
- def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
296
+ def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
297
+ """Shuffle the agents in the AgentSet and then invoke a method or function on each agent.
298
+
299
+ It's a fast, optimized version of calling shuffle() followed by do().
276
300
  """
277
- Invoke a method or function on each agent in the AgentSet and return the results.
301
+ agents = list(self._agents.keys())
302
+ self.random.shuffle(agents)
303
+
304
+ if isinstance(method, str):
305
+ for agent in agents:
306
+ getattr(agent, method)(*args, **kwargs)
307
+ else:
308
+ for agent in agents:
309
+ method(agent, *args, **kwargs)
310
+
311
+ return self
312
+
313
+ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
314
+ """Invoke a method or function on each agent in the AgentSet and return the results.
278
315
 
279
316
  Args:
280
317
  method (str, callable): the callable to apply on each agent
@@ -304,33 +341,89 @@ class AgentSet(MutableSet, Sequence):
304
341
 
305
342
  return res
306
343
 
307
- def get(self, attr_names: str | list[str]) -> list[Any]:
344
+ def agg(self, attribute: str, func: Callable) -> Any:
345
+ """Aggregate an attribute of all agents in the AgentSet using a specified function.
346
+
347
+ Args:
348
+ attribute (str): The name of the attribute to aggregate.
349
+ func (Callable): The function to apply to the attribute values (e.g., min, max, sum, np.mean).
350
+
351
+ Returns:
352
+ Any: The result of applying the function to the attribute values. Often a single value.
308
353
  """
309
- Retrieve the specified attribute(s) from each agent in the AgentSet.
354
+ values = self.get(attribute)
355
+ return func(values)
356
+
357
+ @overload
358
+ def get(
359
+ self,
360
+ attr_names: str,
361
+ handle_missing: Literal["error", "default"] = "error",
362
+ default_value: Any = None,
363
+ ) -> list[Any]: ...
364
+
365
+ @overload
366
+ def get(
367
+ self,
368
+ attr_names: list[str],
369
+ handle_missing: Literal["error", "default"] = "error",
370
+ default_value: Any = None,
371
+ ) -> list[list[Any]]: ...
372
+
373
+ def get(
374
+ self,
375
+ attr_names,
376
+ handle_missing="error",
377
+ default_value=None,
378
+ ):
379
+ """Retrieve the specified attribute(s) from each agent in the AgentSet.
310
380
 
311
381
  Args:
312
382
  attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent.
383
+ handle_missing (str, optional): How to handle missing attributes. Can be:
384
+ - 'error' (default): raises an AttributeError if attribute is missing.
385
+ - 'default': returns the specified default_value.
386
+ default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default'
387
+ and the agent does not have the attribute.
313
388
 
314
389
  Returns:
315
- list[Any]: A list with the attribute value for each agent in the set if attr_names is a str
316
- list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str
390
+ list[Any]: A list with the attribute value for each agent if attr_names is a str.
391
+ list[list[Any]]: A list with a lists of attribute values for each agent if attr_names is a list of str.
317
392
 
318
393
  Raises:
319
- AttributeError if an agent does not have the specified attribute(s)
320
-
321
- """
394
+ AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s).
395
+ ValueError: If an unknown 'handle_missing' option is provided.
396
+ """
397
+ is_single_attr = isinstance(attr_names, str)
398
+
399
+ if handle_missing == "error":
400
+ if is_single_attr:
401
+ return [getattr(agent, attr_names) for agent in self._agents]
402
+ else:
403
+ return [
404
+ [getattr(agent, attr) for attr in attr_names]
405
+ for agent in self._agents
406
+ ]
407
+
408
+ elif handle_missing == "default":
409
+ if is_single_attr:
410
+ return [
411
+ getattr(agent, attr_names, default_value) for agent in self._agents
412
+ ]
413
+ else:
414
+ return [
415
+ [getattr(agent, attr, default_value) for attr in attr_names]
416
+ for agent in self._agents
417
+ ]
322
418
 
323
- if isinstance(attr_names, str):
324
- return [getattr(agent, attr_names) for agent in self._agents]
325
419
  else:
326
- return [
327
- [getattr(agent, attr_name) for attr_name in attr_names]
328
- for agent in self._agents
329
- ]
420
+ raise ValueError(
421
+ f"Unknown handle_missing option: {handle_missing}, "
422
+ "should be one of 'error' or 'default'"
423
+ )
330
424
 
331
425
  def set(self, attr_name: str, value: Any) -> AgentSet:
332
- """
333
- Set a specified attribute to a given value for all agents in the AgentSet.
426
+ """Set a specified attribute to a given value for all agents in the AgentSet.
334
427
 
335
428
  Args:
336
429
  attr_name (str): The name of the attribute to set.
@@ -344,8 +437,7 @@ class AgentSet(MutableSet, Sequence):
344
437
  return self
345
438
 
346
439
  def __getitem__(self, item: int | slice) -> Agent:
347
- """
348
- Retrieve an agent or a slice of agents from the AgentSet.
440
+ """Retrieve an agent or a slice of agents from the AgentSet.
349
441
 
350
442
  Args:
351
443
  item (int | slice): The index or slice for selecting agents.
@@ -356,8 +448,7 @@ class AgentSet(MutableSet, Sequence):
356
448
  return list(self._agents.keys())[item]
357
449
 
358
450
  def add(self, agent: Agent):
359
- """
360
- Add an agent to the AgentSet.
451
+ """Add an agent to the AgentSet.
361
452
 
362
453
  Args:
363
454
  agent (Agent): The agent to add to the set.
@@ -368,8 +459,7 @@ class AgentSet(MutableSet, Sequence):
368
459
  self._agents[agent] = None
369
460
 
370
461
  def discard(self, agent: Agent):
371
- """
372
- Remove an agent from the AgentSet if it exists.
462
+ """Remove an agent from the AgentSet if it exists.
373
463
 
374
464
  This method does not raise an error if the agent is not present.
375
465
 
@@ -383,8 +473,7 @@ class AgentSet(MutableSet, Sequence):
383
473
  del self._agents[agent]
384
474
 
385
475
  def remove(self, agent: Agent):
386
- """
387
- Remove an agent from the AgentSet.
476
+ """Remove an agent from the AgentSet.
388
477
 
389
478
  This method raises an error if the agent is not present.
390
479
 
@@ -397,8 +486,7 @@ class AgentSet(MutableSet, Sequence):
397
486
  del self._agents[agent]
398
487
 
399
488
  def __getstate__(self):
400
- """
401
- Retrieve the state of the AgentSet for serialization.
489
+ """Retrieve the state of the AgentSet for serialization.
402
490
 
403
491
  Returns:
404
492
  dict: A dictionary representing the state of the AgentSet.
@@ -406,8 +494,7 @@ class AgentSet(MutableSet, Sequence):
406
494
  return {"agents": list(self._agents.keys()), "model": self.model}
407
495
 
408
496
  def __setstate__(self, state):
409
- """
410
- Set the state of the AgentSet during deserialization.
497
+ """Set the state of the AgentSet during deserialization.
411
498
 
412
499
  Args:
413
500
  state (dict): A dictionary representing the state to restore.
@@ -417,8 +504,7 @@ class AgentSet(MutableSet, Sequence):
417
504
 
418
505
  @property
419
506
  def random(self) -> Random:
420
- """
421
- Provide access to the model's random number generator.
507
+ """Provide access to the model's random number generator.
422
508
 
423
509
  Returns:
424
510
  Random: The random number generator associated with the model.
@@ -426,8 +512,7 @@ class AgentSet(MutableSet, Sequence):
426
512
  return self.model.random
427
513
 
428
514
  def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
429
- """
430
- Group agents by the specified attribute or return from the callable
515
+ """Group agents by the specified attribute or return from the callable.
431
516
 
432
517
  Args:
433
518
  by (Callable, str): used to determine what to group agents by
@@ -437,6 +522,7 @@ class AgentSet(MutableSet, Sequence):
437
522
  * if ``by`` is a str, it should refer to an attribute on the agent and the value
438
523
  of this attribute will be used for grouping
439
524
  result_type (str, optional): The datatype for the resulting groups {"agentset", "list"}
525
+
440
526
  Returns:
441
527
  GroupBy
442
528
 
@@ -468,8 +554,7 @@ class AgentSet(MutableSet, Sequence):
468
554
 
469
555
 
470
556
  class GroupBy:
471
- """Helper class for AgentSet.groupby
472
-
557
+ """Helper class for AgentSet.groupby.
473
558
 
474
559
  Attributes:
475
560
  groups (dict): A dictionary with the group_name as key and group as values
@@ -477,6 +562,12 @@ class GroupBy:
477
562
  """
478
563
 
479
564
  def __init__(self, groups: dict[Any, list | AgentSet]):
565
+ """Initialize a GroupBy instance.
566
+
567
+ Args:
568
+ groups (dict): A dictionary with the group_name as key and group as values
569
+
570
+ """
480
571
  self.groups: dict[Any, list | AgentSet] = groups
481
572
 
482
573
  def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]:
@@ -489,6 +580,8 @@ class GroupBy:
489
580
  * if ``method`` is a str, it should refer to a method on the group
490
581
 
491
582
  Additional arguments and keyword arguments will be passed on to the callable.
583
+ args: arguments to pass to the callable
584
+ kwargs: keyword arguments to pass to the callable
492
585
 
493
586
  Returns:
494
587
  dict with group_name as key and the return of the method as value
@@ -506,7 +599,7 @@ class GroupBy:
506
599
  return {k: method(v, *args, **kwargs) for k, v in self.groups.items()}
507
600
 
508
601
  def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:
509
- """Apply the specified callable to each group
602
+ """Apply the specified callable to each group.
510
603
 
511
604
  Args:
512
605
  method (Callable, str): The callable to apply to each group,
@@ -515,6 +608,8 @@ class GroupBy:
515
608
  * if ``method`` is a str, it should refer to a method on the group
516
609
 
517
610
  Additional arguments and keyword arguments will be passed on to the callable.
611
+ args: arguments to pass to the callable
612
+ kwargs: keyword arguments to pass to the callable
518
613
 
519
614
  Returns:
520
615
  the original GroupBy instance
@@ -533,8 +628,31 @@ class GroupBy:
533
628
 
534
629
  return self
535
630
 
536
- def __iter__(self):
631
+ def count(self) -> dict[Any, int]:
632
+ """Return the count of agents in each group.
633
+
634
+ Returns:
635
+ dict: A dictionary mapping group names to the number of agents in each group.
636
+ """
637
+ return {k: len(v) for k, v in self.groups.items()}
638
+
639
+ def agg(self, attr_name: str, func: Callable) -> dict[Hashable, Any]:
640
+ """Aggregate the values of a specific attribute across each group using the provided function.
641
+
642
+ Args:
643
+ attr_name (str): The name of the attribute to aggregate.
644
+ func (Callable): The function to apply (e.g., sum, min, max, mean).
645
+
646
+ Returns:
647
+ dict[Hashable, Any]: A dictionary mapping group names to the result of applying the aggregation function.
648
+ """
649
+ return {
650
+ group_name: func([getattr(agent, attr_name) for agent in group])
651
+ for group_name, group in self.groups.items()
652
+ }
653
+
654
+ def __iter__(self): # noqa: D105
537
655
  return iter(self.groups.items())
538
656
 
539
- def __len__(self):
657
+ def __len__(self): # noqa: D105
540
658
  return len(self.groups)
mesa/batchrunner.py CHANGED
@@ -1,4 +1,7 @@
1
+ """batchrunner for running a factorial experiment design over a model."""
2
+
1
3
  import itertools
4
+ import multiprocessing
2
5
  from collections.abc import Iterable, Mapping
3
6
  from functools import partial
4
7
  from multiprocessing import Pool
@@ -8,6 +11,8 @@ from tqdm.auto import tqdm
8
11
 
9
12
  from mesa.model import Model
10
13
 
14
+ multiprocessing.set_start_method("spawn", force=True)
15
+
11
16
 
12
17
  def batch_run(
13
18
  model_cls: type[Model],
@@ -21,29 +26,19 @@ def batch_run(
21
26
  ) -> list[dict[str, Any]]:
22
27
  """Batch run a mesa model with a set of parameter values.
23
28
 
24
- Parameters
25
- ----------
26
- model_cls : Type[Model]
27
- The model class to batch-run
28
- parameters : Mapping[str, Union[Any, Iterable[Any]]],
29
- Dictionary with model parameters over which to run the model. You can either pass single values or iterables.
30
- number_processes : int, optional
31
- Number of processes used, by default 1. Set this to None if you want to use all CPUs.
32
- iterations : int, optional
33
- Number of iterations for each parameter combination, by default 1
34
- data_collection_period : int, optional
35
- Number of steps after which data gets collected, by default -1 (end of episode)
36
- max_steps : int, optional
37
- Maximum number of model steps after which the model halts, by default 1000
38
- display_progress : bool, optional
39
- Display batch run process, by default True
29
+ Args:
30
+ model_cls (Type[Model]): The model class to batch-run
31
+ parameters (Mapping[str, Union[Any, Iterable[Any]]]): Dictionary with model parameters over which to run the model. You can either pass single values or iterables.
32
+ number_processes (int, optional): Number of processes used, by default 1. Set this to None if you want to use all CPUs.
33
+ iterations (int, optional): Number of iterations for each parameter combination, by default 1
34
+ data_collection_period (int, optional): Number of steps after which data gets collected, by default -1 (end of episode)
35
+ max_steps (int, optional): Maximum number of model steps after which the model halts, by default 1000
36
+ display_progress (bool, optional): Display batch run process, by default True
40
37
 
41
- Returns
42
- -------
43
- List[Dict[str, Any]]
44
- [description]
45
- """
38
+ Returns:
39
+ List[Dict[str, Any]]
46
40
 
41
+ """
47
42
  runs_list = []
48
43
  run_id = 0
49
44
  for iteration in range(iterations):
@@ -85,7 +80,7 @@ def _make_model_kwargs(
85
80
  parameters : Mapping[str, Union[Any, Iterable[Any]]]
86
81
  Single or multiple values for each model parameter name
87
82
 
88
- Returns
83
+ Returns:
89
84
  -------
90
85
  List[Dict[str, Any]]
91
86
  A list of all kwargs combinations.
@@ -125,7 +120,7 @@ def _model_run_func(
125
120
  data_collection_period : int
126
121
  Number of steps after which data gets collected
127
122
 
128
- Returns
123
+ Returns:
129
124
  -------
130
125
  List[Dict[str, Any]]
131
126
  Return model_data, agent_data from the reporters
mesa/cookiecutter-mesa/hooks/post_gen_project.py CHANGED
@@ -1,3 +1,5 @@
1
+ """helper module."""
2
+
1
3
  import glob
2
4
  import os
3
5