cogames 0.3.65__py3-none-any.whl → 0.3.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. cogames/cli/client.py +0 -3
  2. cogames/cli/docsync/docsync.py +7 -1
  3. cogames/cli/mission.py +44 -19
  4. cogames/cli/policy.py +26 -10
  5. cogames/cli/submit.py +127 -141
  6. cogames/cli/utils.py +5 -0
  7. cogames/cogs_vs_clips/clip_difficulty.py +57 -0
  8. cogames/cogs_vs_clips/clips.py +23 -6
  9. cogames/cogs_vs_clips/cog.py +16 -5
  10. cogames/cogs_vs_clips/cogsguard_curriculum.py +122 -0
  11. cogames/cogs_vs_clips/cogsguard_tutorial.py +5 -5
  12. cogames/cogs_vs_clips/config.py +1 -1
  13. cogames/cogs_vs_clips/docs/cogs_vs_clips_mapgen.md +2 -3
  14. cogames/cogs_vs_clips/evals/README.md +8 -32
  15. cogames/cogs_vs_clips/evals/diagnostic_evals.py +0 -1
  16. cogames/cogs_vs_clips/evals/difficulty_variants.py +7 -10
  17. cogames/cogs_vs_clips/mission.py +38 -10
  18. cogames/cogs_vs_clips/missions.py +1 -1
  19. cogames/cogs_vs_clips/reward_variants.py +173 -0
  20. cogames/cogs_vs_clips/sites.py +6 -5
  21. cogames/cogs_vs_clips/stations.py +13 -9
  22. cogames/cogs_vs_clips/team.py +3 -1
  23. cogames/cogs_vs_clips/terrain.py +2 -2
  24. cogames/cogs_vs_clips/variants.py +175 -4
  25. cogames/cogs_vs_clips/weather.py +52 -0
  26. cogames/docs/SCRIPTED_AGENT.md +3 -3
  27. cogames/evaluate.py +4 -2
  28. cogames/main.py +357 -51
  29. cogames/maps/canidate1_1000.map +1 -1
  30. cogames/maps/canidate1_1000_stations.map +2 -2
  31. cogames/maps/canidate1_500.map +1 -1
  32. cogames/maps/canidate1_500_stations.map +2 -2
  33. cogames/maps/canidate2_1000.map +1 -1
  34. cogames/maps/canidate2_1000_stations.map +2 -2
  35. cogames/maps/canidate2_500.map +1 -1
  36. cogames/maps/canidate2_500_stations.map +1 -1
  37. cogames/maps/canidate3_1000.map +1 -1
  38. cogames/maps/canidate3_1000_stations.map +2 -2
  39. cogames/maps/canidate3_500.map +1 -1
  40. cogames/maps/canidate3_500_stations.map +2 -2
  41. cogames/maps/canidate4_500.map +1 -1
  42. cogames/maps/canidate4_500_stations.map +2 -2
  43. cogames/maps/cave_base_50.map +2 -2
  44. cogames/maps/diagnostic_evals/diagnostic_agile.map +2 -2
  45. cogames/maps/diagnostic_evals/diagnostic_agile_hard.map +2 -2
  46. cogames/maps/diagnostic_evals/diagnostic_charge_up.map +6 -6
  47. cogames/maps/diagnostic_evals/diagnostic_charge_up_hard.map +6 -6
  48. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1.map +6 -6
  49. cogames/maps/diagnostic_evals/diagnostic_chest_navigation1_hard.map +6 -6
  50. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2.map +6 -6
  51. cogames/maps/diagnostic_evals/diagnostic_chest_navigation2_hard.map +6 -6
  52. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3.map +6 -6
  53. cogames/maps/diagnostic_evals/diagnostic_chest_navigation3_hard.map +6 -6
  54. cogames/maps/diagnostic_evals/diagnostic_chest_near.map +6 -6
  55. cogames/maps/diagnostic_evals/diagnostic_chest_search.map +6 -6
  56. cogames/maps/diagnostic_evals/diagnostic_chest_search_hard.map +6 -6
  57. cogames/maps/diagnostic_evals/diagnostic_extract_lab.map +6 -6
  58. cogames/maps/diagnostic_evals/diagnostic_extract_lab_hard.map +6 -6
  59. cogames/maps/diagnostic_evals/diagnostic_memory.map +6 -6
  60. cogames/maps/diagnostic_evals/diagnostic_memory_hard.map +6 -6
  61. cogames/maps/diagnostic_evals/diagnostic_radial.map +2 -2
  62. cogames/maps/diagnostic_evals/diagnostic_radial_hard.map +2 -2
  63. cogames/maps/diagnostic_evals/diagnostic_resource_lab.map +6 -6
  64. cogames/maps/diagnostic_evals/diagnostic_unclip.map +6 -6
  65. cogames/maps/evals/eval_balanced_spread.map +6 -6
  66. cogames/maps/evals/eval_clip_oxygen.map +6 -6
  67. cogames/maps/evals/eval_collect_resources.map +6 -6
  68. cogames/maps/evals/eval_collect_resources_hard.map +6 -6
  69. cogames/maps/evals/eval_collect_resources_medium.map +6 -6
  70. cogames/maps/evals/eval_divide_and_conquer.map +6 -6
  71. cogames/maps/evals/eval_energy_starved.map +6 -6
  72. cogames/maps/evals/eval_multi_coordinated_collect_hard.map +6 -6
  73. cogames/maps/evals/eval_oxygen_bottleneck.map +6 -6
  74. cogames/maps/evals/eval_single_use_world.map +6 -6
  75. cogames/maps/evals/extractor_hub_100x100.map +6 -6
  76. cogames/maps/evals/extractor_hub_30x30.map +6 -6
  77. cogames/maps/evals/extractor_hub_50x50.map +6 -6
  78. cogames/maps/evals/extractor_hub_70x70.map +6 -6
  79. cogames/maps/evals/extractor_hub_80x80.map +6 -6
  80. cogames/maps/machina_100_stations.map +2 -2
  81. cogames/maps/machina_200_stations.map +2 -2
  82. cogames/maps/machina_200_stations_small.map +2 -2
  83. cogames/maps/machina_eval_exp01.map +2 -2
  84. cogames/maps/machina_eval_template_large.map +2 -2
  85. cogames/maps/machinatrainer4agents.map +2 -2
  86. cogames/maps/machinatrainer4agentsbase.map +2 -2
  87. cogames/maps/machinatrainerbig.map +2 -2
  88. cogames/maps/machinatrainersmall.map +2 -2
  89. cogames/maps/planky_evals/aligner_avoid_aoe.map +6 -6
  90. cogames/maps/planky_evals/aligner_full_cycle.map +6 -6
  91. cogames/maps/planky_evals/aligner_gear.map +6 -6
  92. cogames/maps/planky_evals/aligner_hearts.map +6 -6
  93. cogames/maps/planky_evals/aligner_junction.map +6 -6
  94. cogames/maps/planky_evals/exploration_distant.map +6 -6
  95. cogames/maps/planky_evals/maze.map +6 -6
  96. cogames/maps/planky_evals/miner_best_resource.map +6 -6
  97. cogames/maps/planky_evals/miner_deposit.map +6 -6
  98. cogames/maps/planky_evals/miner_extract.map +6 -6
  99. cogames/maps/planky_evals/miner_full_cycle.map +6 -6
  100. cogames/maps/planky_evals/miner_gear.map +6 -6
  101. cogames/maps/planky_evals/multi_role.map +6 -6
  102. cogames/maps/planky_evals/resource_chain.map +6 -6
  103. cogames/maps/planky_evals/scout_explore.map +6 -6
  104. cogames/maps/planky_evals/scout_gear.map +6 -6
  105. cogames/maps/planky_evals/scrambler_full_cycle.map +6 -6
  106. cogames/maps/planky_evals/scrambler_gear.map +6 -6
  107. cogames/maps/planky_evals/scrambler_target.map +6 -6
  108. cogames/maps/planky_evals/stuck_corridor.map +6 -6
  109. cogames/maps/planky_evals/survive_retreat.map +6 -6
  110. cogames/maps/training_facility_clipped.map +2 -2
  111. cogames/maps/training_facility_open_1.map +2 -2
  112. cogames/maps/training_facility_open_2.map +2 -2
  113. cogames/maps/training_facility_open_3.map +2 -2
  114. cogames/maps/training_facility_tight_4.map +2 -2
  115. cogames/maps/training_facility_tight_5.map +2 -2
  116. cogames/maps/vanilla_large.map +2 -2
  117. cogames/maps/vanilla_small.map +2 -2
  118. cogames/pickup.py +6 -5
  119. cogames/play.py +14 -16
  120. cogames/policy/nim_agents/__init__.py +0 -2
  121. cogames/policy/nim_agents/agents.py +0 -11
  122. cogames/policy/starter_agent.py +4 -1
  123. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/METADATA +45 -29
  124. cogames-0.3.68.dist-info/RECORD +160 -0
  125. metta_alo/scoring.py +7 -7
  126. cogames-0.3.65.dist-info/RECORD +0 -160
  127. metta_alo/job_specs.py +0 -17
  128. metta_alo/policy.py +0 -16
  129. metta_alo/pure_single_episode_runner.py +0 -75
  130. metta_alo/rollout.py +0 -322
  131. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/WHEEL +0 -0
  132. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/entry_points.txt +0 -0
  133. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/licenses/LICENSE +0 -0
  134. {cogames-0.3.65.dist-info → cogames-0.3.68.dist-info}/top_level.txt +0 -0
@@ -118,7 +118,7 @@ class MachinaArena(Scene[MachinaArenaConfig]):
118
118
 
119
119
  # Building weights
120
120
  default_building_weights = {
121
- "chest": 0.0,
121
+ "c:chest": 0.0,
122
122
  "junction": 0.7,
123
123
  "germanium_extractor": 0.3,
124
124
  "silicon_extractor": 0.3,
@@ -377,7 +377,7 @@ class SequentialMachinaArena(Scene[SequentialMachinaArenaConfig]):
377
377
  raise ValueError(f"Unknown base_biome '{cfg.base_biome}'. Valid: {sorted(biome_map.keys())}")
378
378
  base_cfg: SceneConfig = BaseCfgModel.model_validate(cfg.base_biome_config or {})
379
379
  default_building_weights = {
380
- "chest": 0.0,
380
+ "c:chest": 0.0,
381
381
  "junction": 0.6,
382
382
  "germanium_extractor": 0.2,
383
383
  "silicon_extractor": 0.2,
@@ -2,9 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  from typing import TYPE_CHECKING, override
4
4
 
5
+ from pydantic import Field
6
+
7
+ from cogames.cogs_vs_clips.config import CvCConfig
5
8
  from cogames.cogs_vs_clips.evals.difficulty_variants import DIFFICULTY_VARIANTS
6
9
  from cogames.cogs_vs_clips.terrain import BaseHubVariant, MachinaArenaVariant
7
10
  from cogames.core import CoGameMissionVariant
11
+ from mettagrid.config.mettagrid_config import MettaGridConfig
8
12
  from mettagrid.map_builder.map_builder import MapBuilderConfig
9
13
  from mettagrid.mapgen.mapgen import MapGen
10
14
  from mettagrid.mapgen.scenes.base_hub import DEFAULT_EXTRACTORS as HUB_EXTRACTORS
@@ -14,6 +18,37 @@ if TYPE_CHECKING:
14
18
  from cogames.cogs_vs_clips.mission import CvCMission
15
19
 
16
20
 
21
+ def _apply_clips_settings(
22
+ mission: CvCMission,
23
+ *,
24
+ initial_clips_start: int | None = None,
25
+ initial_clips_spots: int | None = None,
26
+ scramble_start: int | None = None,
27
+ scramble_interval: int | None = None,
28
+ scramble_radius: int | None = None,
29
+ align_start: int | None = None,
30
+ align_interval: int | None = None,
31
+ align_radius: int | None = None,
32
+ ) -> None:
33
+ clips = mission.clips
34
+ if initial_clips_start is not None:
35
+ clips.initial_clips_start = initial_clips_start
36
+ if initial_clips_spots is not None:
37
+ clips.initial_clips_spots = initial_clips_spots
38
+ if scramble_start is not None:
39
+ clips.scramble_start = scramble_start
40
+ if scramble_interval is not None:
41
+ clips.scramble_interval = scramble_interval
42
+ if scramble_radius is not None:
43
+ clips.scramble_radius = scramble_radius
44
+ if align_start is not None:
45
+ clips.align_start = align_start
46
+ if align_interval is not None:
47
+ clips.align_interval = align_interval
48
+ if align_radius is not None:
49
+ clips.align_radius = align_radius
50
+
51
+
17
52
  class NumCogsVariant(CoGameMissionVariant):
18
53
  name: str = "num_cogs"
19
54
  description: str = "Set the number of cogs for the mission."
@@ -30,13 +65,91 @@ class NumCogsVariant(CoGameMissionVariant):
30
65
  mission.num_cogs = self.num_cogs
31
66
 
32
67
 
68
+ class ClipsEasyVariant(CoGameMissionVariant):
69
+ name: str = "clips_easy"
70
+ description: str = "Slow clips expansion with late pressure."
71
+
72
+ @override
73
+ def modify_mission(self, mission: CvCMission) -> None:
74
+ _apply_clips_settings(
75
+ mission,
76
+ initial_clips_start=50,
77
+ initial_clips_spots=1,
78
+ scramble_start=250,
79
+ scramble_interval=250,
80
+ scramble_radius=15,
81
+ align_start=300,
82
+ align_interval=250,
83
+ align_radius=15,
84
+ )
85
+
86
+
87
+ class ClipsMediumVariant(CoGameMissionVariant):
88
+ name: str = "clips_medium"
89
+ description: str = "Baseline clips pressure (Machina1 default)."
90
+
91
+ @override
92
+ def modify_mission(self, mission: CvCMission) -> None:
93
+ _apply_clips_settings(
94
+ mission,
95
+ initial_clips_start=10,
96
+ initial_clips_spots=1,
97
+ scramble_start=50,
98
+ scramble_interval=100,
99
+ scramble_radius=25,
100
+ align_start=100,
101
+ align_interval=100,
102
+ align_radius=25,
103
+ )
104
+
105
+
106
+ class ClipsHardVariant(CoGameMissionVariant):
107
+ name: str = "clips_hard"
108
+ description: str = "Fast clips pressure with wider influence."
109
+
110
+ @override
111
+ def modify_mission(self, mission: CvCMission) -> None:
112
+ _apply_clips_settings(
113
+ mission,
114
+ initial_clips_start=5,
115
+ initial_clips_spots=2,
116
+ scramble_start=25,
117
+ scramble_interval=50,
118
+ scramble_radius=35,
119
+ align_start=50,
120
+ align_interval=50,
121
+ align_radius=35,
122
+ )
123
+
124
+
125
+ class ClipsWaveOnlyVariant(CoGameMissionVariant):
126
+ name: str = "clips_wave_only"
127
+ description: str = "Initial clips wave only, no further spread."
128
+
129
+ @override
130
+ def modify_mission(self, mission: CvCMission) -> None:
131
+ disable_start = mission.max_steps + 1
132
+ _apply_clips_settings(
133
+ mission,
134
+ initial_clips_start=10,
135
+ initial_clips_spots=3,
136
+ scramble_start=disable_start,
137
+ scramble_interval=disable_start,
138
+ align_start=disable_start,
139
+ align_interval=disable_start,
140
+ scramble_radius=25,
141
+ align_radius=25,
142
+ )
143
+
144
+
33
145
  class DarkSideVariant(CoGameMissionVariant):
34
146
  name: str = "dark_side"
35
147
  description: str = "You're on the dark side of the asteroid. You recharge slower."
36
148
 
37
149
  @override
38
150
  def modify_mission(self, mission: CvCMission) -> None:
39
- mission.cog.energy_regen = 0
151
+ mission.weather.day_deltas = {"solar": 0}
152
+ mission.weather.night_deltas = {"solar": 0}
40
153
 
41
154
 
42
155
  class SuperChargedVariant(CoGameMissionVariant):
@@ -45,7 +158,8 @@ class SuperChargedVariant(CoGameMissionVariant):
45
158
 
46
159
  @override
47
160
  def modify_mission(self, mission: CvCMission) -> None:
48
- mission.cog.energy_regen += 2
161
+ mission.weather.day_deltas = {k: v + 2 for k, v in mission.weather.day_deltas.items()}
162
+ mission.weather.night_deltas = {k: v + 2 for k, v in mission.weather.night_deltas.items()}
49
163
 
50
164
 
51
165
  class EnergizedVariant(CoGameMissionVariant):
@@ -55,7 +169,8 @@ class EnergizedVariant(CoGameMissionVariant):
55
169
  @override
56
170
  def modify_mission(self, mission: CvCMission) -> None:
57
171
  mission.cog.energy_limit = max(mission.cog.energy_limit, 255)
58
- mission.cog.energy_regen = mission.cog.energy_limit
172
+ mission.weather.day_deltas = {"solar": 255}
173
+ mission.weather.night_deltas = {"solar": 255}
59
174
 
60
175
 
61
176
  class Small50Variant(CoGameMissionVariant):
@@ -213,14 +328,65 @@ class BalancedCornersVariant(MachinaArenaVariant):
213
328
  node.max_balance_shortcuts = self.max_balance_shortcuts
214
329
 
215
330
 
331
+ class MultiTeamVariant(CoGameMissionVariant):
332
+ """Split the map into multiple team instances, each with their own hub and resources."""
333
+
334
+ name: str = "multi_team"
335
+ description: str = "Split map into separate team instances with independent hubs."
336
+ num_teams: int = Field(default=2, ge=2, le=2, description="Number of teams (max 2 supported)")
337
+
338
+ @override
339
+ def modify_mission(self, mission: CvCMission) -> None:
340
+ team = next(iter(mission.teams.values()))
341
+ # Each team gets the original agent count; clear num_cogs so total is derived from teams
342
+ original_agents = mission.num_agents
343
+ mission.teams = {
344
+ name: team.model_copy(update={"name": name, "short_name": name, "num_agents": original_agents})
345
+ for name in ["cogs_green", "cogs_blue"][: self.num_teams]
346
+ }
347
+ mission.num_cogs = None
348
+
349
+ def modify_env(self, mission: CvCMission, env: MettaGridConfig) -> None:
350
+ original_builder = env.game.map_builder
351
+ # Shrink inner instance borders so teams are close together
352
+ if isinstance(original_builder, MapGen.Config):
353
+ original_builder.border_width = 1
354
+ env.game.map_builder = MapGen.Config(
355
+ instance=original_builder,
356
+ instances=self.num_teams,
357
+ set_team_by_instance=True,
358
+ instance_names=[t.short_name for t in mission.teams.values()],
359
+ instance_object_remap={
360
+ "c:hub": "{instance_name}:hub",
361
+ "c:chest": "{instance_name}:chest",
362
+ **{f"c:{g}": f"{{instance_name}}:{g}" for g in CvCConfig.GEAR},
363
+ },
364
+ # Connect instances: no added borders, clear walls at boundary
365
+ border_width=0, # No outer border (inner instances have their own)
366
+ instance_border_width=0, # No border between instances
367
+ instance_border_clear_radius=3, # Clear walls near instance boundary
368
+ )
369
+
370
+
371
+ class NoClipsVariant(CoGameMissionVariant):
372
+ name: str = "no_clips"
373
+ description: str = "Disable clips behavior entirely."
374
+
375
+ @override
376
+ def modify_mission(self, mission: CvCMission) -> None:
377
+ mission.clips.disabled = True
378
+
379
+
216
380
  VARIANTS: list[CoGameMissionVariant] = [
217
381
  CavesVariant(),
218
382
  CityVariant(),
219
383
  DarkSideVariant(),
384
+ NoClipsVariant(),
220
385
  DesertVariant(),
221
386
  EmptyBaseVariant(),
222
387
  EnergizedVariant(),
223
388
  ForestVariant(),
389
+ MultiTeamVariant(),
224
390
  QuadrantBuildingsVariant(),
225
391
  SingleResourceUniformVariant(),
226
392
  Small50Variant(),
@@ -228,4 +394,9 @@ VARIANTS: list[CoGameMissionVariant] = [
228
394
  *DIFFICULTY_VARIANTS,
229
395
  ]
230
396
 
231
- HIDDEN_VARIANTS: list[CoGameMissionVariant] = []
397
+ HIDDEN_VARIANTS: list[CoGameMissionVariant] = [
398
+ ClipsEasyVariant(),
399
+ ClipsMediumVariant(),
400
+ ClipsHardVariant(),
401
+ ClipsWaveOnlyVariant(),
402
+ ]
@@ -0,0 +1,52 @@
1
+ """Weather system events for CogsGuard missions.
2
+
3
+ Day/night cycle that applies resource deltas to entities at regular intervals.
4
+ """
5
+
6
+ from pydantic import Field
7
+
8
+ from mettagrid.base_config import Config
9
+ from mettagrid.config.event_config import EventConfig, periodic
10
+ from mettagrid.config.mutation import updateTarget
11
+ from mettagrid.config.tag import typeTag
12
+
13
+
14
+ class WeatherConfig(Config):
15
+ """Configuration for day/night weather cycle."""
16
+
17
+ day_length: int = Field(default=200)
18
+ day_deltas: dict[str, int] = Field(default_factory=lambda: {"solar": 3})
19
+ night_deltas: dict[str, int] = Field(default_factory=lambda: {"solar": 1})
20
+ target_tag: str = Field(default="agent")
21
+
22
+ def events(self, max_steps: int) -> dict[str, EventConfig]:
23
+ """Create weather events for a mission.
24
+
25
+ Returns:
26
+ Dictionary of event name to EventConfig.
27
+ """
28
+ events: dict[str, EventConfig] = {}
29
+ tag = typeTag(self.target_tag)
30
+ half = self.day_length // 2
31
+
32
+ def _merge(apply: dict[str, int], reverse: dict[str, int]) -> dict[str, int]:
33
+ keys = set(apply) | set(reverse)
34
+ return {k: apply.get(k, 0) - reverse.get(k, 0) for k in keys}
35
+
36
+ # Dawn: reverse night deltas, apply day deltas
37
+ events["weather_day"] = EventConfig(
38
+ name="weather_day",
39
+ target_tag=tag,
40
+ timesteps=periodic(start=0, period=self.day_length, end=max_steps),
41
+ mutations=[updateTarget(_merge(self.day_deltas, self.night_deltas))],
42
+ )
43
+
44
+ # Dusk: reverse day deltas, apply night deltas
45
+ events["weather_night"] = EventConfig(
46
+ name="weather_night",
47
+ target_tag=tag,
48
+ timesteps=periodic(start=half, period=self.day_length, end=max_steps),
49
+ mutations=[updateTarget(_merge(self.night_deltas, self.day_deltas))],
50
+ )
51
+
52
+ return events
@@ -261,11 +261,11 @@ uv run cogames play --mission evals.diagnostic_assemble_seeded_search -p baselin
261
261
 
262
262
  ```bash
263
263
  # Run full evaluation suite
264
- uv run python packages/cogames/scripts/run_evaluation.py
264
+ uv run cogames diagnose ladybug -S all
265
265
 
266
266
  # Evaluate specific agent
267
- uv run python packages/cogames/scripts/run_evaluation.py --policy baseline
268
- uv run python packages/cogames/scripts/run_evaluation.py --policy ladybug
267
+ uv run cogames diagnose baseline
268
+ uv run cogames diagnose ladybug
269
269
  ```
270
270
 
271
271
  ## Evaluation Results
cogames/evaluate.py CHANGED
@@ -13,10 +13,10 @@ from pydantic import BaseModel, ConfigDict
13
13
  from rich.console import Console
14
14
  from rich.table import Table
15
15
 
16
- from metta_alo.rollout import run_multi_episode_rollout
17
16
  from metta_alo.scoring import allocate_counts, validate_proportions
18
17
  from mettagrid import MettaGridConfig
19
18
  from mettagrid.policy.policy import PolicySpec
19
+ from mettagrid.runner.rollout import run_multi_episode_rollout
20
20
  from mettagrid.simulator.multi_episode.rollout import MultiEpisodeRolloutResult
21
21
  from mettagrid.simulator.multi_episode.summary import MultiEpisodeRolloutSummary, build_multi_episode_rollout_summaries
22
22
 
@@ -46,6 +46,7 @@ def evaluate(
46
46
  episodes: int,
47
47
  action_timeout_ms: int,
48
48
  seed: int = 42,
49
+ device: Optional[str] = None,
49
50
  output_format: Optional[Literal["yaml", "json"]] = None,
50
51
  save_replay: Optional[str] = None,
51
52
  ) -> MissionResultsSummary:
@@ -70,7 +71,7 @@ def evaluate(
70
71
  all_replay_paths: list[str] = []
71
72
  for mission_name, env_cfg in missions:
72
73
  counts = allocate_counts(env_cfg.game.num_agents, proportions)
73
- assignments = np.repeat(np.arange(len(counts), dtype=int), counts)
74
+ assignments = [i for i, c in enumerate(counts) for _ in range(c)]
74
75
 
75
76
  progress_label = f"Simulating ({mission_name})"
76
77
  with typer.progressbar(length=episodes, label=progress_label) as progress:
@@ -83,6 +84,7 @@ def evaluate(
83
84
  max_action_time_ms=action_timeout_ms,
84
85
  replay_dir=save_replay,
85
86
  create_replay_dir=save_replay is not None,
87
+ device=device,
86
88
  on_progress=lambda _episode_idx, _result: progress.update(1),
87
89
  )
88
90