Mesa 3.0.0a2__py3-none-any.whl → 3.0.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of Mesa might be problematic; see the package registry's release page for more details.

@@ -0,0 +1,462 @@
1
+ """
2
+ Mesa visualization module for creating interactive model visualizations.
3
+
4
+ This module provides components to create browser- and Jupyter notebook-based visualizations of
5
+ Mesa models, allowing users to watch models run step-by-step and interact with model parameters.
6
+
7
+ Key features:
8
+ - SolaraViz: Main component for creating visualizations, supporting grid displays and plots
9
+ - ModelController: Handles model execution controls (step, play, pause, reset)
10
+ - UserInputs: Generates UI elements for adjusting model parameters
11
+ - Card: Renders individual visualization elements (space, measures)
12
+
13
+ The module uses Solara for rendering in Jupyter notebooks or as standalone web applications.
14
+ It supports various types of visualizations including matplotlib plots, agent grids, and
15
+ custom visualization components.
16
+
17
+ Usage:
18
+ 1. Define an agent_portrayal function to specify how agents should be displayed
19
+ 2. Set up model_params to define adjustable parameters
20
+ 3. Create a SolaraViz instance with your model, parameters, and desired measures
21
+ 4. Display the visualization in a Jupyter notebook or run as a Solara app
22
+
23
+ See the Visualization Tutorial and example models for more details.
24
+ """
25
+
26
+ import threading
27
+
28
+ import reacton.ipywidgets as widgets
29
+ import solara
30
+ from solara.alias import rv
31
+
32
+ import mesa.experimental.components.altair as components_altair
33
+ import mesa.experimental.components.matplotlib as components_matplotlib
34
+ from mesa.experimental.UserParam import Slider
35
+
36
+
37
# TODO: Turn this function into a Solara component once the current_step.value
# dependency is passed to measure()
def Card(
    model, measures, agent_portrayal, space_drawer, dependencies, color, layout_type
):
    """
    Build an ``rv.Card`` that shows either the model's space or one measure.

    Args:
        model: The Mesa model instance being visualized
        measures: Sequence of measures; ``layout_type["Measure"]`` indexes into it
        agent_portrayal: Function describing how an agent is drawn
        space_drawer: ``"default"`` (matplotlib), ``"altair"``, a callable
            renderer, or a falsy value to draw no space
        dependencies: Values whose change triggers a redraw
        color: CSS background color of the card
        layout_type: Dict carrying either a ``"Space"`` or a ``"Measure"`` key

    Returns:
        rv.Card: The card element wrapping the rendered content
    """
    card_style = f"background-color: {color}; width: 100%; height: 100%"
    with rv.Card(style_=card_style) as main:
        if "Space" in layout_type:
            rv.CardTitle(children=["Space"])
            if space_drawer == "default":
                renderer = components_matplotlib.SpaceMatplotlib
            elif space_drawer == "altair":
                renderer = components_altair.SpaceAltair
            else:
                # A user-supplied renderer; a falsy value means "draw nothing".
                renderer = space_drawer
            if renderer:
                renderer(model, agent_portrayal, dependencies=dependencies)
        elif "Measure" in layout_type:
            rv.CardTitle(children=["Measure"])
            measure = measures[layout_type["Measure"]]
            if not callable(measure):
                # A data attribute name: plot it with the matplotlib helper.
                components_matplotlib.PlotMatplotlib(
                    model, measure, dependencies=dependencies
                )
            else:
                # A custom object that renders itself given the model.
                measure(model)
    return main
85
+
86
+
87
@solara.component
def SolaraViz(
    model_class,
    model_params,
    measures=None,
    name=None,
    agent_portrayal=None,
    space_drawer="default",
    play_interval=150,
    seed=None,
):
    """
    Initialize a component to visualize a model.

    Args:
        model_class: Class of the model to instantiate
        model_params: Parameters for initializing the model. Entries that are
            ``Slider`` objects or dicts with a ``"type"`` key become UI inputs;
            all other entries are passed to the model unchanged.
        measures: List of callables or data attributes to plot
        name: Name for display; defaults to ``model_class.__name__``
        agent_portrayal: Function mapping an agent to its rendering options
            (dictionary); the default drawer supports custom `"size"`,
            `"color"`, and `"shape"`.
        space_drawer: Method to render the agent space for
            the model; default implementation is the `SpaceMatplotlib` component;
            simulations with no space to visualize should
            specify `space_drawer=False`
        play_interval: Play interval (default: 150)
        seed: The random seed used to initialize the model; ``None`` (the
            default) starts from seed 0
    """
    if name is None:
        name = model_class.__name__

    current_step = solara.use_reactive(0)

    # 1. Set up model parameters
    # Honour the caller's ``seed``; fall back to 0 when none is given.
    reactive_seed = solara.use_reactive(0 if seed is None else seed)
    user_params, fixed_params = split_model_params(model_params)
    model_parameters, set_model_parameters = solara.use_state(
        {**fixed_params, **{k: v.get("value") for k, v in user_params.items()}}
    )

    # 2. Set up Model
    def make_model():
        """Create a new model instance with current parameters and seed."""
        # ``seed`` goes to ``__new__`` only (which presumably uses it to seed
        # ``model.random``); ``__init__`` receives just the model parameters.
        model = model_class.__new__(
            model_class, **model_parameters, seed=reactive_seed.value
        )
        model.__init__(**model_parameters)
        current_step.value = 0
        return model

    reset_counter = solara.use_reactive(0)
    # Rebuild the model whenever a parameter, the reset counter or the seed changes.
    model = solara.use_memo(
        make_model,
        dependencies=[
            *list(model_parameters.values()),
            reset_counter.value,
            reactive_seed.value,
        ],
    )

    def handle_change_model_params(param_name: str, value: object):
        """Update model parameters when user input changes."""
        set_model_parameters({**model_parameters, param_name: value})

    # 3. Set up UI

    with solara.AppBar():
        solara.AppBarTitle(name)

    # render layout and plot
    def do_reseed():
        """Update the random seed for the model."""
        reactive_seed.value = model.random.random()

    dependencies = [
        *list(model_parameters.values()),
        current_step.value,
        reactive_seed.value,
    ]

    # if space drawer is disabled, do not include it
    layout_types = [{"Space": "default"}] if space_drawer else []

    if measures:
        layout_types += [{"Measure": elem} for elem in range(len(measures))]

    grid_layout_initial = make_initial_grid_layout(layout_types=layout_types)
    grid_layout, set_grid_layout = solara.use_state(grid_layout_initial)

    with solara.Sidebar():
        with solara.Card("Controls", margin=1, elevation=2):
            solara.InputText(
                label="Seed",
                value=reactive_seed,
                continuous_update=True,
            )
            UserInputs(user_params, on_change=handle_change_model_params)
            ModelController(model, play_interval, current_step, reset_counter)
            solara.Button(label="Reseed", color="primary", on_click=do_reseed)
        with solara.Card("Information", margin=1, elevation=2):
            # Interpolate the plain step number, not the Reactive wrapper.
            solara.Markdown(md_text=f"Step - {current_step.value}")

    items = [
        Card(
            model,
            measures,
            agent_portrayal,
            space_drawer,
            dependencies,
            color="white",
            layout_type=layout_type,
        )
        for layout_type in layout_types
    ]
    solara.GridDraggable(
        items=items,
        grid_layout=grid_layout,
        resizable=True,
        draggable=True,
        on_grid_layout=set_grid_layout,
    )


# Alias: ``JupyterViz`` refers to the same component as ``SolaraViz``.
JupyterViz = SolaraViz
211
+
212
+
213
@solara.component
def ModelController(model, play_interval, current_step, reset_counter):
    """
    Create controls for model execution (step, play, pause, reset).

    Args:
        model: The model being visualized
        play_interval: Interval between steps during play (passed as the
            ``interval`` of the ``widgets.Play`` control)
        current_step: Reactive value for the current step
        reset_counter: Reactive counter; incrementing it triggers a model reset
    """
    playing = solara.use_reactive(False)
    # Reactive holder for the background thread used by ``threaded_do_play``;
    # the ``threading.Thread`` object itself lives in ``thread.value``.
    thread = solara.use_reactive(None)
    # We track the previous step to detect if user resets the model via
    # clicking the reset button or changing the parameters. If previous_step >
    # current_step, it means a model reset happens while the simulation is
    # still playing.
    previous_step = solara.use_reactive(0)

    def on_value_play(change):
        """Handle a Play-widget tick: step, or stop on reset / halted model."""
        if previous_step.value > current_step.value and current_step.value == 0:
            # We add extra checks for current_step.value == 0, just to be sure.
            # We automatically stop the playing if a model is reset.
            playing.value = False
        elif model.running:
            do_step()
        else:
            playing.value = False

    def do_step():
        """Advance the model by one step."""
        model.step()
        previous_step.value = current_step.value
        current_step.value = model.steps

    def do_play():
        """Run the model continuously until ``model.running`` is cleared."""
        model.running = True
        while model.running:
            do_step()

    def threaded_do_play():
        """Start a new thread for continuous model execution."""
        running_thread = thread.value
        if running_thread is not None and running_thread.is_alive():
            return
        thread.value = threading.Thread(target=do_play)
        thread.value.start()

    def do_pause():
        """Pause the model execution and wait for the play thread to finish."""
        running_thread = thread.value
        if running_thread is None or not running_thread.is_alive():
            return
        model.running = False
        running_thread.join()

    def do_reset():
        """Reset the model."""
        reset_counter.value += 1

    def do_set_playing(value):
        """Set the playing state."""
        if current_step.value == 0:
            # This means the model has been recreated, and the step resets to
            # 0. We want to avoid triggering the playing.value = False in the
            # on_value_play function.
            previous_step.value = current_step.value
        playing.set(value)

    with solara.Row():
        solara.Button(label="Step", color="primary", on_click=do_step)
        # This style is necessary so that the play widget has almost the same
        # height as typical Solara buttons.
        solara.Style(
            """
        .widget-play {
            height: 35px;
        }
        .widget-play button {
            color: white;
            background-color: #1976D2; // Solara blue color
        }
        """
        )
        widgets.Play(
            value=0,
            interval=play_interval,
            repeat=True,
            show_repeat=False,
            on_value=on_value_play,
            playing=playing.value,
            on_playing=do_set_playing,
        )
        solara.Button(label="Reset", color="primary", on_click=do_reset)
        # threaded_do_play is not used for now because it
        # doesn't work in Google colab. We use
        # ipywidgets.Play until it is fixed. The threading
        # version is definitely a much better implementation,
        # if it works.
        # solara.Button(label="▶", color="primary", on_click=threaded_do_play)
        # solara.Button(label="⏸︎", color="primary", on_click=do_pause)
        # solara.Button(label="Reset", color="primary", on_click=do_reset)
315
+
316
+
317
def split_model_params(model_params):
    """
    Partition model parameters into user-adjustable and fixed groups.

    Each entry is classified with ``check_param_is_fixed``; insertion order
    of ``model_params`` is preserved within each group.

    Args:
        model_params: Dictionary of all model parameters

    Returns:
        tuple: (user_adjustable_params, fixed_params)
    """
    user_adjustable = {}
    fixed = {}
    for key, param in model_params.items():
        destination = fixed if check_param_is_fixed(param) else user_adjustable
        destination[key] = param
    return user_adjustable, fixed
335
+
336
+
337
def check_param_is_fixed(param):
    """
    Check if a parameter is fixed (not user-adjustable).

    A parameter is user-adjustable when it is a ``Slider`` or a dict that
    declares its input widget via a ``"type"`` key; anything else is fixed.

    Args:
        param: Parameter to check

    Returns:
        bool: True if parameter is fixed, False otherwise
    """
    if isinstance(param, Slider):
        return False
    if not isinstance(param, dict):
        return True
    # Previously this fell off the end and returned None for typed dicts;
    # return an explicit bool instead.
    return "type" not in param
353
+
354
+
355
@solara.component
def UserInputs(user_params, on_change=None):
    """
    Initialize user inputs for configurable model parameters.
    Currently supports :class:`solara.SliderInt`, :class:`solara.SliderFloat`,
    :class:`solara.Select`, and :class:`solara.Checkbox`.

    Args:
        user_params: Dictionary mapping parameter names to either a ``Slider``
            instance or a dict of options: ``"type"``, ``"value"``, an optional
            ``"label"`` (defaults to the parameter name), and type-specific
            fields such as ``"min"``, ``"max"``, ``"step"`` or ``"values"``.
        on_change: Function to be called with (name, value) when the value of
            an input changes. When ``None``, input changes are ignored.

    Raises:
        ValueError: If a dict entry's ``"type"`` is not a supported input type.
    """

    for name, options in user_params.items():

        def change_handler(value, name=name):
            # ``name`` is bound as a default so each handler reports its own
            # parameter rather than the loop's final one.
            if on_change is not None:
                on_change(name, value)

        if isinstance(options, Slider):
            slider_class = (
                solara.SliderFloat if options.is_float_slider else solara.SliderInt
            )
            slider_class(
                options.label,
                value=options.value,
                on_value=change_handler,
                min=options.min,
                max=options.max,
                step=options.step,
            )
            continue

        # label for the input is "label" from options or name
        label = options.get("label", name)
        input_type = options.get("type")
        if input_type == "SliderInt":
            solara.SliderInt(
                label,
                value=options.get("value"),
                on_value=change_handler,
                min=options.get("min"),
                max=options.get("max"),
                step=options.get("step"),
            )
        elif input_type == "SliderFloat":
            solara.SliderFloat(
                label,
                value=options.get("value"),
                on_value=change_handler,
                min=options.get("min"),
                max=options.get("max"),
                step=options.get("step"),
            )
        elif input_type == "Select":
            solara.Select(
                label,
                value=options.get("value"),
                on_value=change_handler,
                values=options.get("values"),
            )
        elif input_type == "Checkbox":
            solara.Checkbox(
                label=label,
                on_value=change_handler,
                value=options.get("value"),
            )
        else:
            raise ValueError(f"{input_type} is not a supported input type")
423
+
424
+
425
def make_text(renderer):
    """
    Wrap a text-producing ``renderer`` as a measure that displays Markdown.

    Args:
        renderer: Callable taking the model and returning a Markdown string

    Returns:
        function: Callable that, given a model, renders ``renderer(model)``
        with ``solara.Markdown``; ``renderer`` is not called until then
    """

    def render_markdown(model):
        text = renderer(model)
        solara.Markdown(text)

    return render_markdown
440
+
441
+
442
def make_initial_grid_layout(layout_types):
    """
    Build the starting grid layout: one 6x10 tile per layout type, two per row.

    Args:
        layout_types: List of layout types (Space or Measure)

    Returns:
        list: One layout dict per entry (keys ``i``, ``w``, ``h``, ``moved``,
        ``x``, ``y``); even indices sit at x=0, odd ones at x=6
    """
    layout = []
    for index in range(len(layout_types)):
        pair_number, column = divmod(index, 2)
        layout.append(
            {
                "i": index,
                "w": 6,
                "h": 10,
                "moved": False,
                "x": 6 * column,
                "y": 32 * pair_number,
            }
        )
    return layout
mesa/model.py CHANGED
@@ -8,9 +8,8 @@ Core Objects: Model
8
8
  # Remove this __future__ import once the oldest supported Python is 3.10
9
9
  from __future__ import annotations
10
10
 
11
- import itertools
12
11
  import random
13
- from collections import defaultdict
12
+ import warnings
14
13
 
15
14
  # mypy
16
15
  from typing import Any
@@ -18,8 +17,6 @@ from typing import Any
18
17
  from mesa.agent import Agent, AgentSet
19
18
  from mesa.datacollection import DataCollector
20
19
 
21
- TimeT = float | int
22
-
23
20
 
24
21
  class Model:
25
22
  """Base class for models in the Mesa ABM library.
@@ -31,21 +28,30 @@ class Model:
31
28
  Attributes:
32
29
  running: A boolean indicating if the model should continue running.
33
30
  schedule: An object to manage the order and execution of agent steps.
34
- current_id: A counter for assigning unique IDs to agents.
35
- agents_: A defaultdict mapping each agent type to a dict of its instances.
36
- This private attribute is used internally to manage agents.
37
31
 
38
32
  Properties:
39
- agents: An AgentSet containing all agents in the model, generated from the _agents attribute.
33
+ agents: An AgentSet containing all agents in the model
40
34
  agent_types: A list of different agent types present in the model.
35
+ agents_by_type: A dictionary where the keys are agent types and the values are the corresponding AgentSets.
36
+ steps: An integer representing the number of steps the model has taken.
37
+ It increases automatically at the start of each step() call.
41
38
 
42
39
  Methods:
43
40
  get_agents_of_type: Returns an AgentSet of agents of the specified type.
41
+ Deprecated: Use agents_by_type[agenttype] instead.
44
42
  run_model: Runs the model's simulation until a defined end condition is reached.
45
43
  step: Executes a single step of the model's simulation process.
46
44
  next_id: Generates and returns the next unique identifier for an agent.
47
45
  reset_randomizer: Resets the model's random number generator with a new or existing seed.
48
46
  initialize_data_collector: Sets up the data collector for the model, requiring an initialized scheduler and agents.
47
+ register_agent : register an agent with the model
48
+ deregister_agent : remove an agent from the model
49
+
50
+ Notes:
51
+ Model.agents returns the AgentSet containing all agents registered with the model. Changing
52
+ the content of the AgentSet directly can result in strange behavior. If you want change the
53
+ composition of this AgentSet, ensure you operate on a copy.
54
+
49
55
  """
50
56
 
51
57
  def __new__(cls, *args: Any, **kwargs: Any) -> Any:
@@ -57,9 +63,6 @@ class Model:
57
63
  # advance.
58
64
  obj._seed = random.random()
59
65
  obj.random = random.Random(obj._seed)
60
- # TODO: Remove these 2 lines just before Mesa 3.0
61
- obj._steps = 0
62
- obj._time = 0
63
66
  return obj
64
67
 
65
68
  def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -70,40 +73,117 @@ class Model:
70
73
 
71
74
  self.running = True
72
75
  self.schedule = None
73
- self.current_id = 0
74
- self.agents_: defaultdict[type, dict] = defaultdict(dict)
76
+ self.steps: int = 0
77
+
78
+ self._setup_agent_registration()
75
79
 
76
- self._steps: int = 0
77
- self._time: TimeT = 0 # the model's clock
80
+ # Wrap the user-defined step method
81
+ self._user_step = self.step
82
+ self.step = self._wrapped_step
83
+
84
    def _wrapped_step(self, *args: Any, **kwargs: Any) -> None:
        """Increment ``self.steps``, then call the user's step method.

        ``__init__`` stores the original bound ``step`` as ``self._user_step``
        and rebinds ``self.step`` to this method. All positional and keyword
        arguments are forwarded unchanged.
        """
        # The counter is bumped *before* the user's step runs, so inside
        # step() ``self.steps`` already reflects the step being executed.
        self.steps += 1
        # Call the original user-defined step method
        self._user_step(*args, **kwargs)
90
+
91
    def next_id(self) -> int:
        """Deprecated: no longer generates IDs.

        Emits a ``DeprecationWarning`` (attributed to the caller via
        ``stacklevel=2``) and always returns ``0``. Per the warning text,
        agents now track their unique ID automatically.
        """
        warnings.warn(
            "using model.next_id() is deprecated. Agents track their unique ID automatically",
            DeprecationWarning,
            stacklevel=2,
        )
        return 0
78
98
 
79
99
  @property
80
100
  def agents(self) -> AgentSet:
81
101
  """Provides an AgentSet of all agents in the model, combining agents from all types."""
82
-
83
- if hasattr(self, "_agents"):
84
- return self._agents
85
- else:
86
- all_agents = itertools.chain.from_iterable(self.agents_.values())
87
- return AgentSet(all_agents, self)
102
+ return self._all_agents
88
103
 
89
104
  @agents.setter
90
105
  def agents(self, agents: Any) -> None:
91
106
  raise AttributeError(
92
- "You are trying to set model.agents. In Mesa 3.0 and higher, this attribute will be "
107
+ "You are trying to set model.agents. In Mesa 3.0 and higher, this attribute is "
93
108
  "used by Mesa itself, so you cannot use it directly anymore."
94
109
  "Please adjust your code to use a different attribute name for custom agent storage."
95
110
  )
96
111
 
97
- self._agents = agents
98
-
99
112
  @property
100
113
  def agent_types(self) -> list[type]:
101
- """Return a list of different agent types."""
102
- return list(self.agents_.keys())
114
+ """Return a list of all unique agent types registered with the model."""
115
+ return list(self._agents_by_type.keys())
116
+
117
    @property
    def agents_by_type(self) -> dict[type[Agent], AgentSet]:
        """A dictionary where the keys are agent types and the values are the corresponding AgentSets.

        This returns the model's internal registry itself, not a copy.
        Mutating the dict or its AgentSets changes the model's bookkeeping.
        Indexing a class with no registered agents raises ``KeyError``.
        """
        return self._agents_by_type
103
121
 
104
122
  def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet:
105
- """Retrieves an AgentSet containing all agents of the specified type."""
106
- return AgentSet(self.agents_[agenttype].keys(), self)
123
+ """Deprecated: Retrieves an AgentSet containing all agents of the specified type."""
124
+ warnings.warn(
125
+ f"Model.get_agents_of_type() is deprecated, please replace get_agents_of_type({agenttype})"
126
+ f"with the property agents_by_type[{agenttype}].",
127
+ DeprecationWarning,
128
+ stacklevel=2,
129
+ )
130
+ return self.agents_by_type[agenttype]
131
+
132
    def _setup_agent_registration(self):
        """Initialize the data structures used to track registered agents.

        Called from ``__init__``. ``register_agent`` also calls it as a
        fallback when the model was never initialized.
        """
        self._agents = {}  # hard references to all agents; register_agent stores None as each value
        self._agents_by_type: dict[
            type[Agent], AgentSet
        ] = {}  # a dict with an agentset for each class of agents
        self._all_agents = AgentSet([], self)  # an agentset with all agents
139
+
140
+ def register_agent(self, agent):
141
+ """Register the agent with the model
142
+
143
+ Args:
144
+ agent: The agent to register.
145
+
146
+ Notes:
147
+ This method is called automatically by ``Agent.__init__``, so there is no need to use this
148
+ if you are subclassing Agent and calling its super in the ``__init__`` method.
149
+
150
+ """
151
+ if not hasattr(self, "_agents"):
152
+ self._setup_agent_registration()
153
+
154
+ warnings.warn(
155
+ "The Mesa Model class was not initialized. In the future, you need to explicitly initialize "
156
+ "the Model by calling super().__init__() on initialization.",
157
+ FutureWarning,
158
+ stacklevel=2,
159
+ )
160
+
161
+ self._agents[agent] = None
162
+
163
+ # because AgentSet requires model, we cannot use defaultdict
164
+ # tricks with a function won't work because model then cannot be pickled
165
+ try:
166
+ self._agents_by_type[type(agent)].add(agent)
167
+ except KeyError:
168
+ self._agents_by_type[type(agent)] = AgentSet(
169
+ [
170
+ agent,
171
+ ],
172
+ self,
173
+ )
174
+
175
+ self._all_agents.add(agent)
176
+
177
    def deregister_agent(self, agent):
        """Deregister the agent with the model

        Removes ``agent`` from the hard-reference dict, from the AgentSet of
        its class, and from the all-agents set. The class's entry in
        ``_agents_by_type`` is kept even if its set becomes empty.

        Notes:
            This method is called automatically by ``Agent.remove``

        Raises:
            KeyError: if ``agent`` is not currently registered (from the
                ``del`` on ``self._agents``).
        """
        del self._agents[agent]
        self._agents_by_type[type(agent)].remove(agent)
        self._all_agents.remove(agent)
107
187
 
108
188
  def run_model(self) -> None:
109
189
  """Run the model until the end condition is reached. Overload as
@@ -115,16 +195,6 @@ class Model:
115
195
    def step(self) -> None:
        """A single step. Fill in here.

        Override this in subclasses. ``__init__`` stores the bound ``step`` as
        ``self._user_step`` and rebinds ``self.step`` to ``_wrapped_step``.
        The wrapper increments ``self.steps`` before delegating to this method.
        """
117
197
 
118
- def _advance_time(self, deltat: TimeT = 1):
119
- """Increment the model's steps counter and clock."""
120
- self._steps += 1
121
- self._time += deltat
122
-
123
- def next_id(self) -> int:
124
- """Return the next unique ID for agents, increment current_id"""
125
- self.current_id += 1
126
- return self.current_id
127
-
128
198
  def reset_randomizer(self, seed: int | None = None) -> None:
129
199
  """Reset the model random number generator.
130
200
 
mesa/space.py CHANGED
@@ -459,15 +459,21 @@ class _Grid:
459
459
  elif selection == "closest":
460
460
  current_pos = agent.pos
461
461
  # Find the closest position without sorting all positions
462
- closest_pos = None
462
+ # TODO: See if this method can be optimized further
463
+ closest_pos = []
463
464
  min_distance = float("inf")
464
465
  agent.random.shuffle(pos)
465
466
  for p in pos:
466
467
  distance = self._distance_squared(p, current_pos)
467
468
  if distance < min_distance:
468
469
  min_distance = distance
469
- closest_pos = p
470
- chosen_pos = closest_pos
470
+ closest_pos.clear()
471
+ closest_pos.append(p)
472
+ elif distance == min_distance:
473
+ closest_pos.append(p)
474
+
475
+ chosen_pos = agent.random.choice(closest_pos)
476
+
471
477
  else:
472
478
  raise ValueError(
473
479
  f"Invalid selection method {selection}. Choose 'random' or 'closest'."