Mesa 2.4.0__py3-none-any.whl → 3.0.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of Mesa might be problematic. Click here for more details.
- mesa/__init__.py +1 -3
- mesa/agent.py +53 -332
- mesa/batchrunner.py +8 -11
- mesa/cookiecutter-mesa/{{cookiecutter.snake}}/app.pytemplate +27 -0
- mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/model.pytemplate +1 -1
- mesa/datacollection.py +21 -136
- mesa/experimental/__init__.py +1 -3
- mesa/experimental/cell_space/cell_collection.py +2 -2
- mesa/experimental/cell_space/grid.py +1 -1
- mesa/experimental/cell_space/network.py +3 -3
- mesa/experimental/devs/eventlist.py +2 -1
- mesa/experimental/devs/examples/epstein_civil_violence.py +1 -2
- mesa/experimental/devs/examples/wolf_sheep.py +2 -3
- mesa/experimental/devs/simulator.py +2 -1
- mesa/model.py +23 -93
- mesa/space.py +11 -17
- mesa/time.py +1 -3
- mesa/visualization/UserParam.py +56 -1
- mesa/visualization/__init__.py +2 -6
- mesa/{experimental → visualization}/components/altair.py +1 -2
- mesa/{experimental → visualization}/components/matplotlib.py +4 -6
- mesa/{experimental → visualization}/jupyter_viz.py +117 -20
- {mesa-2.4.0.dist-info → mesa-3.0.0a0.dist-info}/METADATA +4 -11
- mesa-3.0.0a0.dist-info/RECORD +38 -0
- mesa/cookiecutter-mesa/{{cookiecutter.snake}}/run.pytemplate +0 -3
- mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/server.pytemplate +0 -36
- mesa/experimental/UserParam.py +0 -56
- mesa/flat/__init__.py +0 -6
- mesa/flat/visualization.py +0 -5
- mesa/visualization/ModularVisualization.py +0 -1
- mesa/visualization/TextVisualization.py +0 -1
- mesa/visualization/modules.py +0 -1
- mesa-2.4.0.dist-info/RECORD +0 -45
- {mesa-2.4.0.dist-info → mesa-3.0.0a0.dist-info}/WHEEL +0 -0
- {mesa-2.4.0.dist-info → mesa-3.0.0a0.dist-info}/entry_points.txt +0 -0
- {mesa-2.4.0.dist-info → mesa-3.0.0a0.dist-info}/licenses/LICENSE +0 -0
mesa/__init__.py
CHANGED
|
@@ -8,7 +8,6 @@ import datetime
|
|
|
8
8
|
|
|
9
9
|
import mesa.space as space
|
|
10
10
|
import mesa.time as time
|
|
11
|
-
import mesa.visualization as visualization
|
|
12
11
|
from mesa.agent import Agent
|
|
13
12
|
from mesa.batchrunner import batch_run
|
|
14
13
|
from mesa.datacollection import DataCollector
|
|
@@ -19,14 +18,13 @@ __all__ = [
|
|
|
19
18
|
"Agent",
|
|
20
19
|
"time",
|
|
21
20
|
"space",
|
|
22
|
-
"visualization",
|
|
23
21
|
"DataCollector",
|
|
24
22
|
"batch_run",
|
|
25
23
|
"experimental",
|
|
26
24
|
]
|
|
27
25
|
|
|
28
26
|
__title__ = "mesa"
|
|
29
|
-
__version__ = "
|
|
27
|
+
__version__ = "3.0.0a0"
|
|
30
28
|
__license__ = "Apache 2.0"
|
|
31
29
|
_this_year = datetime.datetime.now(tz=datetime.timezone.utc).date().year
|
|
32
30
|
__copyright__ = f"Copyright {_this_year} Project Mesa Team"
|
mesa/agent.py
CHANGED
|
@@ -14,11 +14,11 @@ import operator
|
|
|
14
14
|
import warnings
|
|
15
15
|
import weakref
|
|
16
16
|
from collections import defaultdict
|
|
17
|
-
from collections.abc import
|
|
17
|
+
from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
|
|
18
18
|
from random import Random
|
|
19
19
|
|
|
20
20
|
# mypy
|
|
21
|
-
from typing import TYPE_CHECKING, Any
|
|
21
|
+
from typing import TYPE_CHECKING, Any
|
|
22
22
|
|
|
23
23
|
if TYPE_CHECKING:
|
|
24
24
|
# We ensure that these are not imported during runtime to prevent cyclic
|
|
@@ -49,12 +49,25 @@ class Agent:
|
|
|
49
49
|
self.model = model
|
|
50
50
|
self.pos: Position | None = None
|
|
51
51
|
|
|
52
|
-
|
|
52
|
+
# register agent
|
|
53
|
+
try:
|
|
54
|
+
self.model.agents_[type(self)][self] = None
|
|
55
|
+
except AttributeError:
|
|
56
|
+
# model super has not been called
|
|
57
|
+
self.model.agents_ = defaultdict(dict)
|
|
58
|
+
self.model.agents_[type(self)][self] = None
|
|
59
|
+
self.model.agentset_experimental_warning_given = False
|
|
60
|
+
|
|
61
|
+
warnings.warn(
|
|
62
|
+
"The Mesa Model class was not initialized. In the future, you need to explicitly initialize the Model by calling super().__init__() on initialization.",
|
|
63
|
+
FutureWarning,
|
|
64
|
+
stacklevel=2,
|
|
65
|
+
)
|
|
53
66
|
|
|
54
67
|
def remove(self) -> None:
|
|
55
68
|
"""Remove and delete the agent from the model."""
|
|
56
69
|
with contextlib.suppress(KeyError):
|
|
57
|
-
self.model.
|
|
70
|
+
self.model.agents_[type(self)].pop(self)
|
|
58
71
|
|
|
59
72
|
def step(self) -> None:
|
|
60
73
|
"""A single step of the agent."""
|
|
@@ -116,10 +129,9 @@ class AgentSet(MutableSet, Sequence):
|
|
|
116
129
|
def select(
|
|
117
130
|
self,
|
|
118
131
|
filter_func: Callable[[Agent], bool] | None = None,
|
|
119
|
-
|
|
132
|
+
n: int = 0,
|
|
120
133
|
inplace: bool = False,
|
|
121
134
|
agent_type: type[Agent] | None = None,
|
|
122
|
-
n: int | None = None,
|
|
123
135
|
) -> AgentSet:
|
|
124
136
|
"""
|
|
125
137
|
Select a subset of agents from the AgentSet based on a filter function and/or quantity limit.
|
|
@@ -127,47 +139,29 @@ class AgentSet(MutableSet, Sequence):
|
|
|
127
139
|
Args:
|
|
128
140
|
filter_func (Callable[[Agent], bool], optional): A function that takes an Agent and returns True if the
|
|
129
141
|
agent should be included in the result. Defaults to None, meaning no filtering is applied.
|
|
130
|
-
|
|
131
|
-
- If an integer, at most the first number of matching agents are selected.
|
|
132
|
-
- If a float between 0 and 1, at most that fraction of original the agents are selected.
|
|
142
|
+
n (int, optional): The number of agents to select. If 0, all matching agents are selected. Defaults to 0.
|
|
133
143
|
inplace (bool, optional): If True, modifies the current AgentSet; otherwise, returns a new AgentSet. Defaults to False.
|
|
134
144
|
agent_type (type[Agent], optional): The class type of the agents to select. Defaults to None, meaning no type filtering is applied.
|
|
135
145
|
|
|
136
146
|
Returns:
|
|
137
147
|
AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated.
|
|
138
|
-
|
|
139
|
-
Notes:
|
|
140
|
-
- at_most just return the first n or fraction of agents. To take a random sample, shuffle() beforehand.
|
|
141
|
-
- at_most is an upper limit. When specifying other criteria, the number of agents returned can be smaller.
|
|
142
148
|
"""
|
|
143
|
-
if n is not None:
|
|
144
|
-
warnings.warn(
|
|
145
|
-
"The parameter 'n' is deprecated. Use 'at_most' instead.",
|
|
146
|
-
DeprecationWarning,
|
|
147
|
-
stacklevel=2,
|
|
148
|
-
)
|
|
149
|
-
at_most = n
|
|
150
149
|
|
|
151
|
-
|
|
152
|
-
if filter_func is None and agent_type is None and at_most == inf:
|
|
150
|
+
if filter_func is None and agent_type is None and n == 0:
|
|
153
151
|
return self if inplace else copy.copy(self)
|
|
154
152
|
|
|
155
|
-
|
|
156
|
-
if at_most <= 1.0 and isinstance(at_most, float):
|
|
157
|
-
at_most = int(len(self) * at_most) # Note that it rounds down (floor)
|
|
158
|
-
|
|
159
|
-
def agent_generator(filter_func, agent_type, at_most):
|
|
153
|
+
def agent_generator(filter_func=None, agent_type=None, n=0):
|
|
160
154
|
count = 0
|
|
161
155
|
for agent in self:
|
|
162
|
-
if count >= at_most:
|
|
163
|
-
break
|
|
164
156
|
if (not filter_func or filter_func(agent)) and (
|
|
165
157
|
not agent_type or isinstance(agent, agent_type)
|
|
166
158
|
):
|
|
167
159
|
yield agent
|
|
168
160
|
count += 1
|
|
161
|
+
if 0 < n <= count:
|
|
162
|
+
break
|
|
169
163
|
|
|
170
|
-
agents = agent_generator(filter_func, agent_type,
|
|
164
|
+
agents = agent_generator(filter_func, agent_type, n)
|
|
171
165
|
|
|
172
166
|
return AgentSet(agents, self.model) if not inplace else self._update(agents)
|
|
173
167
|
|
|
@@ -232,194 +226,53 @@ class AgentSet(MutableSet, Sequence):
|
|
|
232
226
|
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
|
|
233
227
|
return self
|
|
234
228
|
|
|
235
|
-
def do(
|
|
229
|
+
def do(
|
|
230
|
+
self, method_name: str, *args, return_results: bool = False, **kwargs
|
|
231
|
+
) -> AgentSet | list[Any]:
|
|
236
232
|
"""
|
|
237
|
-
Invoke a method
|
|
233
|
+
Invoke a method on each agent in the AgentSet.
|
|
238
234
|
|
|
239
235
|
Args:
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
*args: Variable length argument list passed to the callable being called.
|
|
246
|
-
**kwargs: Arbitrary keyword arguments passed to the callable being called.
|
|
236
|
+
method_name (str): The name of the method to call on each agent.
|
|
237
|
+
return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls.
|
|
238
|
+
*args: Variable length argument list passed to the method being called.
|
|
239
|
+
**kwargs: Arbitrary keyword arguments passed to the method being called.
|
|
247
240
|
|
|
248
241
|
Returns:
|
|
249
|
-
AgentSet | list[Any]: The results of the
|
|
242
|
+
AgentSet | list[Any]: The results of the method calls if return_results is True, otherwise the AgentSet itself.
|
|
250
243
|
"""
|
|
251
|
-
try:
|
|
252
|
-
return_results = kwargs.pop("return_results")
|
|
253
|
-
except KeyError:
|
|
254
|
-
return_results = False
|
|
255
|
-
else:
|
|
256
|
-
warnings.warn(
|
|
257
|
-
"Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and "
|
|
258
|
-
"AgentSet.map in case of return_results=True",
|
|
259
|
-
stacklevel=2,
|
|
260
|
-
)
|
|
261
|
-
|
|
262
|
-
if return_results:
|
|
263
|
-
return self.map(method, *args, **kwargs)
|
|
264
|
-
|
|
265
244
|
# we iterate over the actual weakref keys and check if weakref is alive before calling the method
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
for agentref in self._agents.keyrefs():
|
|
272
|
-
if (agent := agentref()) is not None:
|
|
273
|
-
method(agent, *args, **kwargs)
|
|
274
|
-
|
|
275
|
-
return self
|
|
276
|
-
|
|
277
|
-
def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
|
|
278
|
-
"""Shuffle the agents in the AgentSet and then invoke a method or function on each agent.
|
|
279
|
-
|
|
280
|
-
It's a fast, optimized version of calling shuffle() followed by do().
|
|
281
|
-
"""
|
|
282
|
-
agents = list(self._agents.keys())
|
|
283
|
-
self.random.shuffle(agents)
|
|
284
|
-
|
|
285
|
-
if isinstance(method, str):
|
|
286
|
-
for agent in agents:
|
|
287
|
-
getattr(agent, method)(*args, **kwargs)
|
|
288
|
-
else:
|
|
289
|
-
for agent in agents:
|
|
290
|
-
method(agent, *args, **kwargs)
|
|
245
|
+
res = [
|
|
246
|
+
getattr(agent, method_name)(*args, **kwargs)
|
|
247
|
+
for agentref in self._agents.keyrefs()
|
|
248
|
+
if (agent := agentref()) is not None
|
|
249
|
+
]
|
|
291
250
|
|
|
292
|
-
return self
|
|
251
|
+
return res if return_results else self
|
|
293
252
|
|
|
294
|
-
def
|
|
295
|
-
"""
|
|
296
|
-
Invoke a method or function on each agent in the AgentSet and return the results.
|
|
297
|
-
|
|
298
|
-
Args:
|
|
299
|
-
method (str, callable): the callable to apply on each agent
|
|
300
|
-
|
|
301
|
-
* in case of str, the name of the method to call on each agent.
|
|
302
|
-
* in case of callable, the function to be called with each agent as first argument
|
|
303
|
-
|
|
304
|
-
*args: Variable length argument list passed to the callable being called.
|
|
305
|
-
**kwargs: Arbitrary keyword arguments passed to the callable being called.
|
|
306
|
-
|
|
307
|
-
Returns:
|
|
308
|
-
list[Any]: The results of the callable calls
|
|
309
|
-
"""
|
|
310
|
-
# we iterate over the actual weakref keys and check if weakref is alive before calling the method
|
|
311
|
-
if isinstance(method, str):
|
|
312
|
-
res = [
|
|
313
|
-
getattr(agent, method)(*args, **kwargs)
|
|
314
|
-
for agentref in self._agents.keyrefs()
|
|
315
|
-
if (agent := agentref()) is not None
|
|
316
|
-
]
|
|
317
|
-
else:
|
|
318
|
-
res = [
|
|
319
|
-
method(agent, *args, **kwargs)
|
|
320
|
-
for agentref in self._agents.keyrefs()
|
|
321
|
-
if (agent := agentref()) is not None
|
|
322
|
-
]
|
|
323
|
-
|
|
324
|
-
return res
|
|
325
|
-
|
|
326
|
-
def agg(self, attribute: str, func: Callable) -> Any:
|
|
327
|
-
"""
|
|
328
|
-
Aggregate an attribute of all agents in the AgentSet using a specified function.
|
|
329
|
-
|
|
330
|
-
Args:
|
|
331
|
-
attribute (str): The name of the attribute to aggregate.
|
|
332
|
-
func (Callable): The function to apply to the attribute values (e.g., min, max, sum, np.mean).
|
|
333
|
-
|
|
334
|
-
Returns:
|
|
335
|
-
Any: The result of applying the function to the attribute values. Often a single value.
|
|
336
|
-
"""
|
|
337
|
-
values = self.get(attribute)
|
|
338
|
-
return func(values)
|
|
339
|
-
|
|
340
|
-
@overload
|
|
341
|
-
def get(
|
|
342
|
-
self,
|
|
343
|
-
attr_names: str,
|
|
344
|
-
handle_missing: Literal["error", "default"] = "error",
|
|
345
|
-
default_value: Any = None,
|
|
346
|
-
) -> list[Any]: ...
|
|
347
|
-
|
|
348
|
-
@overload
|
|
349
|
-
def get(
|
|
350
|
-
self,
|
|
351
|
-
attr_names: list[str],
|
|
352
|
-
handle_missing: Literal["error", "default"] = "error",
|
|
353
|
-
default_value: Any = None,
|
|
354
|
-
) -> list[list[Any]]: ...
|
|
355
|
-
|
|
356
|
-
def get(
|
|
357
|
-
self,
|
|
358
|
-
attr_names,
|
|
359
|
-
handle_missing="error",
|
|
360
|
-
default_value=None,
|
|
361
|
-
):
|
|
253
|
+
def get(self, attr_names: str | list[str]) -> list[Any]:
|
|
362
254
|
"""
|
|
363
255
|
Retrieve the specified attribute(s) from each agent in the AgentSet.
|
|
364
256
|
|
|
365
257
|
Args:
|
|
366
258
|
attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent.
|
|
367
|
-
handle_missing (str, optional): How to handle missing attributes. Can be:
|
|
368
|
-
- 'error' (default): raises an AttributeError if attribute is missing.
|
|
369
|
-
- 'default': returns the specified default_value.
|
|
370
|
-
default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default'
|
|
371
|
-
and the agent does not have the attribute.
|
|
372
259
|
|
|
373
260
|
Returns:
|
|
374
|
-
list[Any]: A list with the attribute value for each agent if attr_names is a str
|
|
375
|
-
list[list[Any]]: A list with a
|
|
261
|
+
list[Any]: A list with the attribute value for each agent in the set if attr_names is a str
|
|
262
|
+
list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str
|
|
376
263
|
|
|
377
264
|
Raises:
|
|
378
|
-
AttributeError
|
|
379
|
-
ValueError: If an unknown 'handle_missing' option is provided.
|
|
380
|
-
"""
|
|
381
|
-
is_single_attr = isinstance(attr_names, str)
|
|
382
|
-
|
|
383
|
-
if handle_missing == "error":
|
|
384
|
-
if is_single_attr:
|
|
385
|
-
return [getattr(agent, attr_names) for agent in self._agents]
|
|
386
|
-
else:
|
|
387
|
-
return [
|
|
388
|
-
[getattr(agent, attr) for attr in attr_names]
|
|
389
|
-
for agent in self._agents
|
|
390
|
-
]
|
|
391
|
-
|
|
392
|
-
elif handle_missing == "default":
|
|
393
|
-
if is_single_attr:
|
|
394
|
-
return [
|
|
395
|
-
getattr(agent, attr_names, default_value) for agent in self._agents
|
|
396
|
-
]
|
|
397
|
-
else:
|
|
398
|
-
return [
|
|
399
|
-
[getattr(agent, attr, default_value) for attr in attr_names]
|
|
400
|
-
for agent in self._agents
|
|
401
|
-
]
|
|
265
|
+
AttributeError if an agent does not have the specified attribute(s)
|
|
402
266
|
|
|
403
|
-
else:
|
|
404
|
-
raise ValueError(
|
|
405
|
-
f"Unknown handle_missing option: {handle_missing}, "
|
|
406
|
-
"should be one of 'error' or 'default'"
|
|
407
|
-
)
|
|
408
|
-
|
|
409
|
-
def set(self, attr_name: str, value: Any) -> AgentSet:
|
|
410
267
|
"""
|
|
411
|
-
Set a specified attribute to a given value for all agents in the AgentSet.
|
|
412
|
-
|
|
413
|
-
Args:
|
|
414
|
-
attr_name (str): The name of the attribute to set.
|
|
415
|
-
value (Any): The value to set the attribute to.
|
|
416
268
|
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
269
|
+
if isinstance(attr_names, str):
|
|
270
|
+
return [getattr(agent, attr_names) for agent in self._agents]
|
|
271
|
+
else:
|
|
272
|
+
return [
|
|
273
|
+
[getattr(agent, attr_name) for attr_name in attr_names]
|
|
274
|
+
for agent in self._agents
|
|
275
|
+
]
|
|
423
276
|
|
|
424
277
|
def __getitem__(self, item: int | slice) -> Agent:
|
|
425
278
|
"""
|
|
@@ -503,139 +356,7 @@ class AgentSet(MutableSet, Sequence):
|
|
|
503
356
|
"""
|
|
504
357
|
return self.model.random
|
|
505
358
|
|
|
506
|
-
def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
|
|
507
|
-
"""
|
|
508
|
-
Group agents by the specified attribute or return from the callable
|
|
509
|
-
|
|
510
|
-
Args:
|
|
511
|
-
by (Callable, str): used to determine what to group agents by
|
|
512
|
-
|
|
513
|
-
* if ``by`` is a callable, it will be called for each agent and the return is used
|
|
514
|
-
for grouping
|
|
515
|
-
* if ``by`` is a str, it should refer to an attribute on the agent and the value
|
|
516
|
-
of this attribute will be used for grouping
|
|
517
|
-
result_type (str, optional): The datatype for the resulting groups {"agentset", "list"}
|
|
518
|
-
Returns:
|
|
519
|
-
GroupBy
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
Notes:
|
|
523
|
-
There might be performance benefits to using `result_type='list'` if you don't need the advanced functionality
|
|
524
|
-
of an AgentSet.
|
|
525
|
-
|
|
526
|
-
"""
|
|
527
|
-
groups = defaultdict(list)
|
|
528
|
-
|
|
529
|
-
if isinstance(by, Callable):
|
|
530
|
-
for agent in self:
|
|
531
|
-
groups[by(agent)].append(agent)
|
|
532
|
-
else:
|
|
533
|
-
for agent in self:
|
|
534
|
-
groups[getattr(agent, by)].append(agent)
|
|
535
|
-
|
|
536
|
-
if result_type == "agentset":
|
|
537
|
-
return GroupBy(
|
|
538
|
-
{k: AgentSet(v, model=self.model) for k, v in groups.items()}
|
|
539
|
-
)
|
|
540
|
-
else:
|
|
541
|
-
return GroupBy(groups)
|
|
542
|
-
|
|
543
|
-
# consider adding for performance reasons
|
|
544
|
-
# for Sequence: __reversed__, index, and count
|
|
545
|
-
# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
class GroupBy:
|
|
549
|
-
"""Helper class for AgentSet.groupby
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
Attributes:
|
|
553
|
-
groups (dict): A dictionary with the group_name as key and group as values
|
|
554
|
-
|
|
555
|
-
"""
|
|
556
|
-
|
|
557
|
-
def __init__(self, groups: dict[Any, list | AgentSet]):
|
|
558
|
-
self.groups: dict[Any, list | AgentSet] = groups
|
|
559
|
-
|
|
560
|
-
def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]:
|
|
561
|
-
"""Apply the specified callable to each group and return the results.
|
|
562
|
-
|
|
563
|
-
Args:
|
|
564
|
-
method (Callable, str): The callable to apply to each group,
|
|
565
|
-
|
|
566
|
-
* if ``method`` is a callable, it will be called it will be called with the group as first argument
|
|
567
|
-
* if ``method`` is a str, it should refer to a method on the group
|
|
568
|
-
|
|
569
|
-
Additional arguments and keyword arguments will be passed on to the callable.
|
|
570
|
-
|
|
571
|
-
Returns:
|
|
572
|
-
dict with group_name as key and the return of the method as value
|
|
573
|
-
|
|
574
|
-
Notes:
|
|
575
|
-
this method is useful for methods or functions that do return something. It
|
|
576
|
-
will break method chaining. For that, use ``do`` instead.
|
|
577
|
-
|
|
578
|
-
"""
|
|
579
|
-
if isinstance(method, str):
|
|
580
|
-
return {
|
|
581
|
-
k: getattr(v, method)(*args, **kwargs) for k, v in self.groups.items()
|
|
582
|
-
}
|
|
583
|
-
else:
|
|
584
|
-
return {k: method(v, *args, **kwargs) for k, v in self.groups.items()}
|
|
585
|
-
|
|
586
|
-
def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:
|
|
587
|
-
"""Apply the specified callable to each group
|
|
588
|
-
|
|
589
|
-
Args:
|
|
590
|
-
method (Callable, str): The callable to apply to each group,
|
|
591
|
-
|
|
592
|
-
* if ``method`` is a callable, it will be called it will be called with the group as first argument
|
|
593
|
-
* if ``method`` is a str, it should refer to a method on the group
|
|
594
|
-
|
|
595
|
-
Additional arguments and keyword arguments will be passed on to the callable.
|
|
596
|
-
|
|
597
|
-
Returns:
|
|
598
|
-
the original GroupBy instance
|
|
599
|
-
|
|
600
|
-
Notes:
|
|
601
|
-
this method is useful for methods or functions that don't return anything and/or
|
|
602
|
-
if you want to chain multiple do calls
|
|
603
|
-
|
|
604
|
-
"""
|
|
605
|
-
if isinstance(method, str):
|
|
606
|
-
for v in self.groups.values():
|
|
607
|
-
getattr(v, method)(*args, **kwargs)
|
|
608
|
-
else:
|
|
609
|
-
for v in self.groups.values():
|
|
610
|
-
method(v, *args, **kwargs)
|
|
611
|
-
|
|
612
|
-
return self
|
|
613
|
-
|
|
614
|
-
def count(self) -> dict[Any, int]:
|
|
615
|
-
"""Return the count of agents in each group.
|
|
616
|
-
|
|
617
|
-
Returns:
|
|
618
|
-
dict: A dictionary mapping group names to the number of agents in each group.
|
|
619
|
-
"""
|
|
620
|
-
return {k: len(v) for k, v in self.groups.items()}
|
|
621
|
-
|
|
622
|
-
def agg(self, attr_name: str, func: Callable) -> dict[Hashable, Any]:
|
|
623
|
-
"""Aggregate the values of a specific attribute across each group using the provided function.
|
|
624
|
-
|
|
625
|
-
Args:
|
|
626
|
-
attr_name (str): The name of the attribute to aggregate.
|
|
627
|
-
func (Callable): The function to apply (e.g., sum, min, max, mean).
|
|
628
|
-
|
|
629
|
-
Returns:
|
|
630
|
-
dict[Hashable, Any]: A dictionary mapping group names to the result of applying the aggregation function.
|
|
631
|
-
"""
|
|
632
|
-
return {
|
|
633
|
-
group_name: func([getattr(agent, attr_name) for agent in group])
|
|
634
|
-
for group_name, group in self.groups.items()
|
|
635
|
-
}
|
|
636
|
-
|
|
637
|
-
def __iter__(self):
|
|
638
|
-
return iter(self.groups.items())
|
|
639
359
|
|
|
640
|
-
|
|
641
|
-
|
|
360
|
+
# consider adding for performance reasons
|
|
361
|
+
# for Sequence: __reversed__, index, and count
|
|
362
|
+
# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
|
mesa/batchrunner.py
CHANGED
|
@@ -1,22 +1,19 @@
|
|
|
1
1
|
import itertools
|
|
2
|
-
import multiprocessing
|
|
3
2
|
from collections.abc import Iterable, Mapping
|
|
4
3
|
from functools import partial
|
|
5
4
|
from multiprocessing import Pool
|
|
6
|
-
from typing import Any
|
|
5
|
+
from typing import Any
|
|
7
6
|
|
|
8
7
|
from tqdm.auto import tqdm
|
|
9
8
|
|
|
10
9
|
from mesa.model import Model
|
|
11
10
|
|
|
12
|
-
multiprocessing.set_start_method("spawn", force=True)
|
|
13
|
-
|
|
14
11
|
|
|
15
12
|
def batch_run(
|
|
16
13
|
model_cls: type[Model],
|
|
17
|
-
parameters: Mapping[str,
|
|
14
|
+
parameters: Mapping[str, Any | Iterable[Any]],
|
|
18
15
|
# We still retain the Optional[int] because users may set it to None (i.e. use all CPUs)
|
|
19
|
-
number_processes:
|
|
16
|
+
number_processes: int | None = 1,
|
|
20
17
|
iterations: int = 1,
|
|
21
18
|
data_collection_period: int = -1,
|
|
22
19
|
max_steps: int = 1000,
|
|
@@ -79,7 +76,7 @@ def batch_run(
|
|
|
79
76
|
|
|
80
77
|
|
|
81
78
|
def _make_model_kwargs(
|
|
82
|
-
parameters: Mapping[str,
|
|
79
|
+
parameters: Mapping[str, Any | Iterable[Any]],
|
|
83
80
|
) -> list[dict[str, Any]]:
|
|
84
81
|
"""Create model kwargs from parameters dictionary.
|
|
85
82
|
|
|
@@ -135,14 +132,14 @@ def _model_run_func(
|
|
|
135
132
|
"""
|
|
136
133
|
run_id, iteration, kwargs = run
|
|
137
134
|
model = model_cls(**kwargs)
|
|
138
|
-
while model.running and model.
|
|
135
|
+
while model.running and model.schedule.steps <= max_steps:
|
|
139
136
|
model.step()
|
|
140
137
|
|
|
141
138
|
data = []
|
|
142
139
|
|
|
143
|
-
steps = list(range(0, model.
|
|
144
|
-
if not steps or steps[-1] != model.
|
|
145
|
-
steps.append(model.
|
|
140
|
+
steps = list(range(0, model.schedule.steps, data_collection_period))
|
|
141
|
+
if not steps or steps[-1] != model.schedule.steps - 1:
|
|
142
|
+
steps.append(model.schedule.steps - 1)
|
|
146
143
|
|
|
147
144
|
for step in steps:
|
|
148
145
|
model_data, all_agents_data = _collect_data(model, step)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configure visualization elements and instantiate a server
|
|
3
|
+
"""
|
|
4
|
+
from mesa.visualization import JupyterViz
|
|
5
|
+
|
|
6
|
+
from {{ cookiecutter.snake }}.model import {{ cookiecutter.model }}, {{ cookiecutter.agent }} # noqa
|
|
7
|
+
|
|
8
|
+
import mesa
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def circle_portrayal_example(agent):
|
|
12
|
+
return {
|
|
13
|
+
"size": 40,
|
|
14
|
+
# This is Matplotlib's color
|
|
15
|
+
"color": "tab:pink",
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
model_params = {"num_agents": 10, "width": 10, "height": 10}
|
|
20
|
+
|
|
21
|
+
page = JupyterViz(
|
|
22
|
+
{{cookiecutter.model}},
|
|
23
|
+
model_params,
|
|
24
|
+
measures=["num_agents"],
|
|
25
|
+
agent_portrayal=circle_portrayal_example
|
|
26
|
+
)
|
|
27
|
+
page # noqa
|
|
@@ -47,7 +47,7 @@ class {{cookiecutter.model}}(mesa.Model):
|
|
|
47
47
|
self.grid.place_agent(agent, (x, y))
|
|
48
48
|
|
|
49
49
|
# example data collector
|
|
50
|
-
self.datacollector = mesa.datacollection.DataCollector()
|
|
50
|
+
self.datacollector = mesa.datacollection.DataCollector({"num_agents": "num_agents"})
|
|
51
51
|
|
|
52
52
|
self.running = True
|
|
53
53
|
self.datacollector.collect(self)
|