kaggle-environments 1.23.3__py3-none-any.whl → 1.23.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaggle-environments might be problematic. Click here for more details.

Files changed (46)
  1. kaggle_environments/envs/open_spiel_env/games/repeated_poker/repeated_poker.js +2 -2
  2. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/components/getRepeatedPokerStateForStep.js +6 -6
  3. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_1.svg +22 -0
  4. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_10.svg +22 -0
  5. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_100.svg +48 -0
  6. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_25.svg +22 -0
  7. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_5.svg +22 -0
  8. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/repeated_poker_renderer.js +550 -331
  9. kaggle_environments/envs/werewolf/README.md +190 -0
  10. kaggle_environments/envs/werewolf/harness/__init__.py +0 -0
  11. kaggle_environments/envs/werewolf/harness/base.py +767 -0
  12. kaggle_environments/envs/werewolf/harness/litellm_models.yaml +51 -0
  13. kaggle_environments/envs/werewolf/harness/test_base.py +35 -0
  14. kaggle_environments/envs/werewolf/runner.py +146 -0
  15. kaggle_environments/envs/werewolf/scripts/__init__.py +0 -0
  16. kaggle_environments/envs/werewolf/scripts/add_audio.py +425 -0
  17. kaggle_environments/envs/werewolf/scripts/configs/audio/standard.yaml +24 -0
  18. kaggle_environments/envs/werewolf/scripts/configs/run/block_basic.yaml +102 -0
  19. kaggle_environments/envs/werewolf/scripts/configs/run/comprehensive.yaml +100 -0
  20. kaggle_environments/envs/werewolf/scripts/configs/run/roundrobin_discussion_DisableDoctorSelfSave_DisableDoctorConsecutiveSave_large.yaml +104 -0
  21. kaggle_environments/envs/werewolf/scripts/configs/run/roundrobin_discussion_large.yaml +103 -0
  22. kaggle_environments/envs/werewolf/scripts/configs/run/roundrobin_discussion_small.yaml +103 -0
  23. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard.yaml +103 -0
  24. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_DisableDoctorConsecutiveSave.yaml +104 -0
  25. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_SeerRevealTeam.yaml +105 -0
  26. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_SeerRevealTeam_NightEliminationNoReveal_DayExileNoReveal.yaml +105 -0
  27. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_SeerRevealTeam_NightEliminationRevealTeam_DayExileRevealTeam.yaml +105 -0
  28. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_disable_doctor_self_save.yaml +103 -0
  29. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_parallel_voting.yaml +103 -0
  30. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_parallel_voting_no_tie_exile.yaml +103 -0
  31. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_parallel_voting_roundbiddiscussion.yaml +105 -0
  32. kaggle_environments/envs/werewolf/scripts/configs/run/run_config.yaml +58 -0
  33. kaggle_environments/envs/werewolf/scripts/configs/run/vertex_api_example_config.yaml +115 -0
  34. kaggle_environments/envs/werewolf/scripts/measure_cost.py +251 -0
  35. kaggle_environments/envs/werewolf/scripts/plot_existing_trajectories.py +135 -0
  36. kaggle_environments/envs/werewolf/scripts/rerender_html.py +87 -0
  37. kaggle_environments/envs/werewolf/scripts/run.py +93 -0
  38. kaggle_environments/envs/werewolf/scripts/run_block.py +237 -0
  39. kaggle_environments/envs/werewolf/scripts/run_pairwise_matrix.py +222 -0
  40. kaggle_environments/envs/werewolf/scripts/self_play.py +196 -0
  41. kaggle_environments/envs/werewolf/scripts/utils.py +47 -0
  42. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/METADATA +1 -1
  43. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/RECORD +46 -8
  44. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/WHEEL +0 -0
  45. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/entry_points.txt +0 -0
  46. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,135 @@
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ import sys
8
+
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+
12
# Make the repository root (the directory that contains the ``kaggle_environments``
# package) importable when this script is run straight from a source checkout.
# This file lives at kaggle_environments/envs/werewolf/scripts/, so the root is four
# levels up; three levels would point at the package directory itself and put its
# internal modules (utils, core, ...) at the front of sys.path. The path must be set
# up BEFORE the first kaggle_environments import, otherwise it cannot affect it.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../.."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from kaggle_environments.envs.werewolf.werewolf import CostSummary  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
21
+
22
+
23
def plot_token_trajectories(trajectories_data, output_dir):
    """
    Plot per-query-step token usage trajectories and save one PNG per metric.

    Args:
        trajectories_data: Mapping ``metric -> {max_turns: [trajectory, ...]}``.
            ``max_turns`` keys are digit strings (they are sorted with ``int``), and each
            trajectory is a sequence of per-query-step token counts.
        output_dir: Directory that receives the ``<metric>_trajectories.png`` files.

    Metrics with no data are skipped with a warning. Trajectories containing
    non-numeric values are skipped with an error log. Each max_turns group gets a
    single legend entry, attached to the first trajectory of that group that is
    actually plotted.
    """
    for metric, trajectories_by_turns in trajectories_data.items():
        if not trajectories_by_turns:
            logger.warning(f"No data found for metric '{metric}'. Skipping plot.")
            continue

        # Sort groups numerically (so "100" comes after "20") and give each its own colour.
        turn_keys = sorted(trajectories_by_turns, key=int)
        colors = plt.cm.viridis(np.linspace(0, 1, len(turn_keys)))
        color_map = dict(zip(turn_keys, colors))

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            any_labeled = False
            for turns in turn_keys:
                group_labeled = False
                for traj in trajectories_by_turns[turns]:
                    if not all(isinstance(x, (int, float)) for x in traj):
                        logger.error(
                            f"Trajectory for metric '{metric}' (turns={turns}) contains non-numeric data. Skipping."
                        )
                        continue
                    # Label only the first *plotted* trajectory of each group so every group
                    # appears exactly once in the legend, even if earlier ones were skipped.
                    label = None if group_labeled else f"Max Turns: {turns}"
                    group_labeled = True
                    ax.plot(
                        np.arange(len(traj)), traj, linestyle="-", alpha=0.4, color=color_map[turns], label=label
                    )
                any_labeled = any_labeled or group_labeled

            pretty_metric = metric.replace("_", " ").title()
            ax.set_title(f"{pretty_metric} per Query Step Trajectories")
            ax.set_xlabel("Query Step")
            ax.set_ylabel(f"{pretty_metric} per Query Step")
            ax.grid(True, which="both", linestyle="--", linewidth=0.5)
            # Skip the legend when nothing was plotted (avoids matplotlib's
            # "No artists with labels found" warning).
            if any_labeled:
                ax.legend()

            plot_filename = os.path.join(output_dir, f"{metric}_trajectories.png")
            fig.savefig(plot_filename)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
        logger.info(f"Saved trajectory plot: {plot_filename}")
60
+
61
+
62
def _reasoning_tokens(usage):
    """
    Return the reasoning-token count recorded in one usage dict, defaulting to 0.

    ``completion_tokens_details`` may be missing or explicitly ``None``.
    ``.get(key, {})`` only covers the missing case, so ``or {}`` is used to cover both.
    """
    details = usage.get("completion_tokens_details") or {}
    return details.get("reasoning_tokens", 0) or 0


def main():
    """
    Load game replays written by measure_cost.py and plot per-query-step token trajectories.

    Every ``game_*_run_*.json`` file in ``--input_dir`` is read. Each agent's usage
    history is grouped by the max_turns value encoded in the filename
    (``game_turns_<N>_run_...``). Total, reasoning and text (completion minus reasoning)
    token plots are written back into the same directory. Files that cannot be read or
    lack a cost summary are skipped with a warning.
    """
    parser = argparse.ArgumentParser(
        description="Load data from a measure_cost.py output directory and generate token trajectory plots."
    )
    parser.add_argument(
        "-i",
        "--input_dir",
        type=str,
        required=True,
        help="Path to the output directory of a previous measure_cost.py run.",
    )
    args = parser.parse_args()

    if not os.path.isdir(args.input_dir):
        logger.error(f"Input directory not found: {args.input_dir}")
        return

    logger.info(f"Loading data from: {args.input_dir}")

    all_trajectories = {"total_tokens": {}, "reasoning_tokens": {}, "text_tokens": {}}

    # Sorted so files are processed (and trajectories plotted) in a reproducible order.
    game_files = sorted(glob.glob(os.path.join(args.input_dir, "game_*_run_*.json")))
    if not game_files:
        logger.error(f"No game replay files (game_*_run_*.json) found in {args.input_dir}.")
        return

    logger.info(f"Found {len(game_files)} game replay files to process.")

    turns_pattern = re.compile(r"game_turns_(\d+)_run_")
    for game_file in game_files:
        # Extract max_turns from filename
        match = turns_pattern.search(os.path.basename(game_file))
        if not match:
            logger.warning(f"Could not parse max_turns from filename: {game_file}. Skipping.")
            continue
        turns = match.group(1)

        try:
            with open(game_file, "r", encoding="utf-8") as f:
                game_data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Could not read {game_file}: {e}. Skipping.")
            continue

        if not isinstance(game_data, dict):
            logger.warning(f"Unexpected JSON structure in {game_file}. Skipping.")
            continue

        # ``or {}`` also tolerates keys that are present but null.
        game_end = (game_data.get("info") or {}).get("GAME_END") or {}
        cost_summary_dict = game_end.get("cost_summary")
        if not cost_summary_dict:
            logger.warning(f"No cost_summary found in {game_file}. Skipping.")
            continue

        cost_summary = CostSummary(**cost_summary_dict)

        for agent_summary in cost_summary.cost_per_agent:
            if not (agent_summary.data and agent_summary.data.usage_history):
                continue
            usage_history_dicts = [usage.model_dump() for usage in agent_summary.data.usage_history]

            total_tokens_traj = [u.get("total_tokens", 0) or 0 for u in usage_history_dicts]
            reasoning_tokens_traj = [_reasoning_tokens(u) for u in usage_history_dicts]
            # Text tokens: completion tokens that were not spent on reasoning.
            text_tokens_traj = [
                (u.get("completion_tokens", 0) or 0) - reasoning
                for u, reasoning in zip(usage_history_dicts, reasoning_tokens_traj)
            ]

            all_trajectories["total_tokens"].setdefault(turns, []).append(total_tokens_traj)
            all_trajectories["reasoning_tokens"].setdefault(turns, []).append(reasoning_tokens_traj)
            all_trajectories["text_tokens"].setdefault(turns, []).append(text_tokens_traj)

    logger.info("Finished processing all files. Generating plots...")
    plot_token_trajectories(all_trajectories, args.input_dir)
    logger.info(f"--- Script finished. Plots saved in {args.input_dir} ---")


if __name__ == "__main__":
    main()
@@ -0,0 +1,87 @@
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+
6
+ from kaggle_environments import make
7
+
8
# Root logging for this script: INFO and above, with every record prefixed by a
# timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
10
+
11
+
12
def main():
    """
    Rerender a Werewolf game replay HTML file from an existing game record JSON.

    This is useful for updating the replay viewer to the latest version without
    rerunning the entire game simulation. Every failure (missing or unreadable
    input, environment/render errors, write errors) is logged and ends the run
    early; nothing is raised to the caller.
    """
    parser = argparse.ArgumentParser(
        description="Rerender a Werewolf game HTML replay from a JSON game record.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-i",
        "--input_json",
        type=str,
        required=True,
        help="Path to the input game record JSON file (e.g., werewolf_game.json).",
    )
    parser.add_argument(
        "-o", "--output_html", type=str, required=True, help="Path to write the newly rendered HTML output file."
    )
    args = parser.parse_args()

    logging.info(f"Loading game record from: {args.input_json}")
    # isfile (not exists) so that a directory path is reported here instead of failing in open().
    if not os.path.isfile(args.input_json):
        logging.error(f"Error: Input file not found at {args.input_json}")
        return

    try:
        with open(args.input_json, "r", encoding="utf-8") as f:
            replay_data = json.load(f)
    except json.JSONDecodeError:
        logging.error(f"Error: Failed to decode JSON from {args.input_json}. The file might be corrupted.")
        return
    except Exception as e:
        # logging.exception keeps the traceback of unexpected read/decoding errors.
        logging.exception(f"An unexpected error occurred while reading the file: {e}")
        return

    if not isinstance(replay_data, dict):
        logging.error(f"Error: Expected a JSON object at the top level of {args.input_json}.")
        return

    logging.info("Successfully loaded game data. Initializing Kaggle environment...")

    # The environment name should be stored in the replay, but we default to 'werewolf'
    env_name = replay_data.get("name", "werewolf")
    if env_name != "werewolf":
        logging.warning(f"Game record is for '{env_name}', but we are rendering with the 'werewolf' environment.")

    try:
        # Recreate the environment state from the replay file. The ``or`` fallbacks match
        # make()'s own defaults and also cover keys that are present but null in the JSON.
        env = make(
            "werewolf",
            configuration=replay_data.get("configuration") or {},
            steps=replay_data.get("steps") or [],
            info=replay_data.get("info") or {},
        )
        logging.info("Environment initialized. Rendering new HTML...")

        # Render the HTML. This will use the werewolf.js file included in the
        # installed kaggle_environments package.
        html_content = env.render(mode="html")
    except Exception as e:
        logging.exception(f"An error occurred during environment creation or rendering: {e}")
        logging.error(
            "Please ensure the 'kaggle_environments' package is correctly installed and the JSON file is valid."
        )
        return

    # Writing is handled separately so I/O errors are not misreported as render errors.
    try:
        output_dir = os.path.dirname(args.output_html)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(args.output_html, "w", encoding="utf-8") as f:
            f.write(html_content)
    except OSError as e:
        logging.error(f"Error: Failed to write HTML to {args.output_html}: {e}")
        return

    logging.info(f"Successfully rerendered HTML to: {args.output_html}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,93 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import random
5
+
6
+ import yaml
7
+
8
+ from kaggle_environments.envs.werewolf.harness.base import LLMWerewolfAgent
9
+ from kaggle_environments.envs.werewolf.runner import (
10
+ LogExecutionTime,
11
+ append_timestamp_to_dir,
12
+ log_git_hash,
13
+ run_werewolf,
14
+ setup_logger,
15
+ )
16
+ from kaggle_environments.envs.werewolf.werewolf import LLM_SYSTEM_PROMPT, AgentFactoryWrapper, register_agents
17
+
18
logger = logging.getLogger(__name__)


def main():
    """
    Run a single Werewolf game from a YAML config and save its log and replay.

    Command-line options select the config file, the output directory (optionally
    timestamped), debug mode, all-random agents for quick testing, and whether the
    (role, role_params) pairs listed in the config are shuffled among the agents.

    Agents whose ``agent_id`` starts with ``"llm/"`` are registered as
    ``LLMWerewolfAgent`` factories, using the rest of the id as the model name.
    """
    parser = argparse.ArgumentParser(description="Run a single Werewolf game.")
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default=os.path.join(os.path.dirname(__file__), "configs/run/run_config.yaml"),
        help="Path to the YAML configuration file.",
    )
    parser.add_argument(
        "-o", "--output_dir", type=str, default="werewolf_run", help="Output directory for the log and replay file."
    )
    parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode.")
    parser.add_argument(
        "-r", "--random_agents", action="store_true", help="Use random agents for all players for fast testing."
    )
    parser.add_argument(
        "-a", "--append_timestamp_to_dir", action="store_true", help="Append a timestamp to the output directory."
    )
    parser.add_argument(
        "-s", "--shuffle_roles", action="store_true", help="If provided, shuffle the roles provided in the config."
    )

    args = parser.parse_args()

    # Create a unique subdirectory for this run
    run_output_dir = append_timestamp_to_dir(args.output_dir, append=args.append_timestamp_to_dir)

    os.makedirs(run_output_dir, exist_ok=True)

    base_name = "werewolf_game"
    setup_logger(output_dir=run_output_dir, base_name=base_name)

    log_git_hash()

    # Load game configuration. An empty YAML file parses to None, hence ``or {}``.
    with open(args.config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}
    game_config = config.get("game_config", {})
    # Same list object as game_config["agents"] (when present), so in-place edits below
    # are visible to run_werewolf.
    agent_configs = game_config.get("agents", [])

    # Shuffle roles: permute the (role, role_params) pairs among the configured agents.
    if args.shuffle_roles:
        role_and_params = [(agent["role"], agent.get("role_params", {})) for agent in agent_configs]
        random.shuffle(role_and_params)
        for agent, (new_role, new_role_params) in zip(agent_configs, role_and_params):
            agent["role"] = new_role
            agent["role_params"] = new_role_params

    # Extract agent harnesses from the config and register the LLM-backed agents.
    agents_ = [agent.get("agent_id", "random") for agent in agent_configs]
    llm_prefix = "llm/"
    agent_dict = {}
    for agent_name in agents_:
        if agent_name.startswith(llm_prefix):
            # Slice the prefix off. str.lstrip("llm/") would strip any leading run of the
            # characters 'l', 'm' and '/', turning e.g. "llm/mistral-large" into "istral-large".
            model_name = agent_name[len(llm_prefix):]
            agent_dict[agent_name] = AgentFactoryWrapper(
                LLMWerewolfAgent, model_name=model_name, system_prompt=LLM_SYSTEM_PROMPT
            )
    register_agents(agent_dict)

    if args.random_agents:
        logger.info("Using random agents for all players.")
        agents_ = ["random"] * len(agents_)

    logger.info(f"Starting Werewolf game run. Output will be saved to: {run_output_dir}")
    with LogExecutionTime(logger_obj=logger, task_str="single game"):
        run_werewolf(
            output_dir=run_output_dir, base_name=base_name, config=game_config, agents=agents_, debug=args.debug
        )
    logger.info(f"Game finished. Replay and log saved in: {run_output_dir}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,237 @@
1
+ import argparse
2
+ import collections
3
+ import logging
4
+ import math
5
+ import multiprocessing
6
+ import os
7
+ import random
8
+ from itertools import permutations
9
+ from typing import Any, Dict, List
10
+
11
+ import tenacity
12
+ import yaml
13
+ from tqdm import tqdm
14
+
15
+ from kaggle_environments.envs.werewolf.runner import LogExecutionTime, append_timestamp_to_dir, setup_logger
16
+ from kaggle_environments.envs.werewolf.scripts.utils import run_single_game_cli
17
+
18
# Module-level logger for this script. NOTE(review): presumably given its handlers by
# setup_logger() in main() — confirm against runner.setup_logger.
logger = logging.getLogger(__name__)


def load_config(config_path):
    """Parse the YAML file at *config_path* with ``yaml.safe_load`` and return the result."""
    with open(config_path, "r") as config_file:
        parsed = yaml.safe_load(config_file)
    return parsed
26
+
27
+
28
def get_all_unique_role_configs(role_configs: List[Dict[str, Any]]) -> List[List[Dict[str, Any]]]:
    """
    Generate all distinct orderings of a list of role configurations.

    A role configuration is a dict with a ``"role"`` key and an optional
    ``"role_params"`` dict. Entries with the same role and equal params are
    interchangeable, so only distinct arrangements are returned. For example, four
    entries with two identical ones yield 4!/2! = 12 arrangements, not 24.

    The arrangements are enumerated directly as multiset permutations. The previous
    approach built all ``n!`` permutations and then deduplicated them, which is
    infeasible for larger player counts. This way the work is proportional to the
    number of distinct arrangements, and the output order is deterministic: it follows
    the first-appearance order of the roles in the input.

    Args:
        role_configs: Role configurations, one per player seat.

    Returns:
        A list of arrangements. Each arrangement is a list of new dicts of the form
        ``{"role": ..., "role_params": {...}}``. An empty input yields ``[[]]``.

    Note:
        ``role_params`` values must be hashable, because they are stored in a frozenset.
    """

    def make_hashable(config):
        # Missing, None or empty role_params all map to the same empty frozenset.
        return config["role"], frozenset((config.get("role_params") or {}).items())

    def make_unhashable(hashable_config):
        role, params_frozenset = hashable_config
        return {"role": role, "role_params": dict(params_frozenset)}

    keys = [make_hashable(c) for c in role_configs]
    # Counter is insertion-ordered, so distinct keys keep their first-appearance order.
    remaining = collections.Counter(keys)
    distinct_keys = list(remaining)
    total = len(keys)

    arrangements: List[List[Dict[str, Any]]] = []
    prefix = []

    def extend():
        # Depth-first: place each still-available distinct key in the next seat.
        if len(prefix) == total:
            arrangements.append([make_unhashable(k) for k in prefix])
            return
        for key in distinct_keys:
            if remaining[key]:
                remaining[key] -= 1
                prefix.append(key)
                extend()
                prefix.pop()
                remaining[key] += 1

    extend()
    return arrangements
49
+
50
+
51
# run_single_game_cli wrapped with retries: at most 3 attempts in total, a randomized
# exponential backoff between attempts (min=2, max=10), and an INFO log line before
# each retry sleep. NOTE(review): reraise is not set, so per tenacity's defaults the
# final failure presumably surfaces as tenacity.RetryError — confirm if callers care.
run_single_game_with_retry = tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_random_exponential(multiplier=1, min=2, max=10),
    before_sleep=tenacity.before_sleep_log(logger, logging.INFO),
)(run_single_game_cli)
56
+
57
+
58
def game_runner_wrapper(args):
    """
    Run one game task produced by ``generate_game_tasks``; usable as a Pool worker.

    Args:
        args: Tuple ``(game_dir, game_config, use_random_agents, debug, block_index,
            game_in_block)``.

    Returns:
        True if the game completed (possibly after retries), otherwise False.

    A game that still fails after all retries is logged with its block/game position
    and reported as False instead of raising. Otherwise the exception would propagate
    out of ``imap_unordered`` (or the sequential loop) and abort the whole experiment.
    """
    game_dir, game_config, use_random_agents, debug, block_index, game_in_block = args
    try:
        run_single_game_with_retry(game_dir, game_config, use_random_agents, debug)
    except Exception:
        logger.exception(
            f"Game failed after retries (block {block_index}, game {game_in_block}, dir {game_dir}). Continuing."
        )
        return False
    return True
62
+
63
+
64
def generate_game_tasks(output_dir, num_blocks, config, use_random_agents, debug, shuffle_player_ids):
    """
    Lazily generate every game task of the block-design experiment.

    For each block:
    - One unique role arrangement is drawn. Draws are without replacement until every
      arrangement has been used; then the pool is refilled and reshuffled.
    - The players are put into a random order.
    - One game is yielded per player, rotating the players by one seat between games,
      so each player takes every seat's role exactly once per block.

    Args:
        output_dir: Root directory; ``block_<i>/game_<j>`` subdirectories are created here.
        num_blocks: Number of blocks to generate.
        config: Parsed experiment config. ``config["game_config"]["agents"]`` lists the
            players, each with a ``"role"`` and an optional ``"role_params"``.
        use_random_agents: Passed through unchanged in every task.
        debug: Passed through unchanged in every task.
        shuffle_player_ids: If set, the players' ``"id"`` values are shuffled within each game.

    Yields:
        ``(game_dir, game_config, use_random_agents, debug, block_index, game_in_block)``
        tuples, the format unpacked by ``game_runner_wrapper``.
    """
    base_game_config = config["game_config"]
    # Shuffle a copy so the caller's ``agents`` list keeps its original order.
    players_data = list(base_game_config["agents"])
    base_role_configs = [{"role": agent["role"], "role_params": agent.get("role_params", {})} for agent in players_data]

    logger.info("Generating all unique role configurations...")
    all_role_configs = get_all_unique_role_configs(base_role_configs)
    logger.info(f"Found {len(all_role_configs)} unique arrangements.")

    # Warn once up front instead of every time the arrangement pool is refilled.
    if num_blocks > len(all_role_configs):
        logger.warning("Sampling with replacement as num_blocks > unique configurations.")

    available_role_configs = []

    for block_index in range(num_blocks):
        block_dir = os.path.join(output_dir, f"block_{block_index}")
        os.makedirs(block_dir, exist_ok=True)

        if not available_role_configs:
            available_role_configs = list(all_role_configs)
            random.shuffle(available_role_configs)

        block_role_config = available_role_configs.pop()
        random.shuffle(players_data)
        current_players_deque = collections.deque(players_data)

        for game_in_block in range(len(players_data)):
            game_dir = os.path.join(block_dir, f"game_{game_in_block}")
            os.makedirs(game_dir, exist_ok=True)

            # Seat i receives the i-th role of this block's arrangement. role_params is
            # copied so games in the same block do not share one mutable dict.
            game_agents_config = [
                {**player_config, **role_config, "role_params": dict(role_config.get("role_params") or {})}
                for player_config, role_config in zip(current_players_deque, block_role_config)
            ]

            if shuffle_player_ids:
                player_ids = [agent["id"] for agent in game_agents_config]
                random.shuffle(player_ids)
                for agent, player_id in zip(game_agents_config, player_ids):
                    agent["id"] = player_id

            game_config = {**base_game_config, "agents": game_agents_config}
            yield (game_dir, game_config, use_random_agents, debug, block_index, game_in_block)
            # Rotate seats so each player moves to the next role for the next game.
            current_players_deque.rotate(1)
110
+
111
+
112
def run_experiment(
    output_dir, num_blocks, config, use_random_agents, debug, parallel, num_processes, shuffle_player_ids
):
    """
    Run a block-design tournament: generate every game task and process them,
    either sequentially or in a multiprocessing pool.

    Args:
        output_dir: Root directory; games are written under ``block_<i>/game_<j>``.
        num_blocks: Number of blocks; each block plays ``len(config["game_config"]["agents"])`` games.
        config: Parsed experiment config.
        use_random_agents: Replace all agents with random agents.
        debug: Passed through to every game. It does not change how games are scheduled here.
        parallel: Run games in a ``multiprocessing.Pool`` instead of sequentially.
        num_processes: Pool size when ``parallel`` is set (None lets the pool choose).
        shuffle_player_ids: Shuffle player ids within each game.
    """
    if debug:
        # Debug mode switches each game to sequential agent execution (see the --debug
        # help text in main); game-level parallelism is deliberately left unchanged, and
        # main() sizes the pool assuming parallel+debug is allowed.
        logger.warning("Debug mode is enabled. Agent actions within each game run sequentially.")

    players_data = config["game_config"]["agents"]
    total_games = num_blocks * len(players_data)

    if parallel:
        logger.info(f"Running games in parallel with up to {num_processes} processes.")

    game_tasks = generate_game_tasks(output_dir, num_blocks, config, use_random_agents, debug, shuffle_player_ids)

    with tqdm(total=total_games, desc="Processing Games") as pbar:
        if parallel:
            with multiprocessing.Pool(processes=num_processes) as pool:
                for _ in pool.imap_unordered(game_runner_wrapper, game_tasks):
                    pbar.update(1)
        else:
            for task_args in game_tasks:
                game_runner_wrapper(task_args)
                pbar.update(1)

    logger.info("All game tasks have been processed.")
142
+
143
+
144
def main():
    """
    Parse command-line options and run a block-design Werewolf experiment.

    Each block is a complete role rotation amongst the players. The effective
    settings are logged before the games are run, optionally in parallel. When
    ``--num_processes`` is omitted, the pool size is derived from the CPU count and,
    outside debug mode, from the number of players per game.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_config_path = os.path.join(script_dir, "configs", "run", "run_config.yaml")

    parser = argparse.ArgumentParser(
        description="Run a block-design experiment for the Werewolf game, "
        "where each block is a complete role rotation amongst the players."
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        help="Output directory for game replays and logs.",
        default="werewolf_block_experiment",
    )
    parser.add_argument(
        "-c", "--config", type=str, default=default_config_path, help="Path to the base configuration YAML file."
    )
    parser.add_argument(
        "-b",
        "--num_blocks",
        type=int,
        default=10,
        help="Number of blocks to run. Each block is a complete role rotation.",
    )
    parser.add_argument(
        "-r", "--use_random_agents", action="store_true", help="Use random agents for all players for fast testing."
    )
    parser.add_argument(
        "-d",
        "--debug",
        action="store_true",
        help="Enable debug mode for the game environment. "
        "Note that you can use debug mode to enable intra game sequential execution.",
    )
    parser.add_argument("-p", "--parallel", action="store_true", help="Run games in parallel using multiple processes.")
    parser.add_argument(
        "-n", "--num_processes", type=int, default=None, help="Number of processes for parallel execution."
    )
    parser.add_argument(
        "-a", "--append_timestamp_to_dir", action="store_true", help="Append a timestamp to the output directory."
    )
    parser.add_argument(
        "-s",
        "--shuffle_player_ids",
        action="store_true",
        help="Shuffle player ids for each game to account for name bias.",
    )

    args = parser.parse_args()
    # multiprocessing.Pool rejects non-positive sizes; fail early with a clear CLI error.
    if args.num_processes is not None and args.num_processes < 1:
        parser.error("--num_processes must be a positive integer.")

    output_dir = append_timestamp_to_dir(args.output_dir, append=args.append_timestamp_to_dir)

    os.makedirs(output_dir, exist_ok=True)

    setup_logger(output_dir, "run_block")

    config = load_config(args.config)

    num_players = len(config.get("game_config", {}).get("agents", []))
    if args.num_processes is None:
        # Default: ~90% of the CPUs. Outside debug mode each game presumably runs its agents
        # concurrently (the --debug help text says debug enables intra-game sequential
        # execution), so budget roughly one CPU per player per game.
        num_processes = multiprocessing.cpu_count() * 0.9
        # num_players > 0 guards against a config without agents (division by zero).
        if not args.debug and num_players > 0:
            num_processes /= num_players
        num_processes = max(1, math.floor(num_processes))
    else:
        num_processes = args.num_processes

    logger.info("Starting experiment with the following settings:")
    logger.info(f"Output Directory: {output_dir}")
    logger.info(f"Number of Blocks: {args.num_blocks}")
    logger.info(f"Parallel Execution: {args.parallel}")
    if args.parallel:
        logger.info(f"Number of Processes: {num_processes}")
    logger.info(f"Debug Mode: {args.debug}")
    logger.info(f"Use Random Agents: {args.use_random_agents}")
    logger.info(f"Shuffle Player IDs: {args.shuffle_player_ids}")

    with LogExecutionTime(logger_obj=logger, task_str="block experiment"):
        run_experiment(
            output_dir=output_dir,
            num_blocks=args.num_blocks,
            config=config,
            use_random_agents=args.use_random_agents,
            debug=args.debug,
            parallel=args.parallel,
            num_processes=num_processes,
            shuffle_player_ids=args.shuffle_player_ids,
        )
    logger.info("Experiment finished successfully.")


if __name__ == "__main__":
    main()