kaggle-environments 1.23.3__py3-none-any.whl → 1.23.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kaggle-environments might be problematic. Click here for more details.

Files changed (46):
  1. kaggle_environments/envs/open_spiel_env/games/repeated_poker/repeated_poker.js +2 -2
  2. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/components/getRepeatedPokerStateForStep.js +6 -6
  3. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_1.svg +22 -0
  4. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_10.svg +22 -0
  5. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_100.svg +48 -0
  6. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_25.svg +22 -0
  7. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_5.svg +22 -0
  8. kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/repeated_poker_renderer.js +550 -331
  9. kaggle_environments/envs/werewolf/README.md +190 -0
  10. kaggle_environments/envs/werewolf/harness/__init__.py +0 -0
  11. kaggle_environments/envs/werewolf/harness/base.py +767 -0
  12. kaggle_environments/envs/werewolf/harness/litellm_models.yaml +51 -0
  13. kaggle_environments/envs/werewolf/harness/test_base.py +35 -0
  14. kaggle_environments/envs/werewolf/runner.py +146 -0
  15. kaggle_environments/envs/werewolf/scripts/__init__.py +0 -0
  16. kaggle_environments/envs/werewolf/scripts/add_audio.py +425 -0
  17. kaggle_environments/envs/werewolf/scripts/configs/audio/standard.yaml +24 -0
  18. kaggle_environments/envs/werewolf/scripts/configs/run/block_basic.yaml +102 -0
  19. kaggle_environments/envs/werewolf/scripts/configs/run/comprehensive.yaml +100 -0
  20. kaggle_environments/envs/werewolf/scripts/configs/run/roundrobin_discussion_DisableDoctorSelfSave_DisableDoctorConsecutiveSave_large.yaml +104 -0
  21. kaggle_environments/envs/werewolf/scripts/configs/run/roundrobin_discussion_large.yaml +103 -0
  22. kaggle_environments/envs/werewolf/scripts/configs/run/roundrobin_discussion_small.yaml +103 -0
  23. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard.yaml +103 -0
  24. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_DisableDoctorConsecutiveSave.yaml +104 -0
  25. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_SeerRevealTeam.yaml +105 -0
  26. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_SeerRevealTeam_NightEliminationNoReveal_DayExileNoReveal.yaml +105 -0
  27. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_SeerRevealTeam_NightEliminationRevealTeam_DayExileRevealTeam.yaml +105 -0
  28. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_disable_doctor_self_save.yaml +103 -0
  29. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_parallel_voting.yaml +103 -0
  30. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_parallel_voting_no_tie_exile.yaml +103 -0
  31. kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_parallel_voting_roundbiddiscussion.yaml +105 -0
  32. kaggle_environments/envs/werewolf/scripts/configs/run/run_config.yaml +58 -0
  33. kaggle_environments/envs/werewolf/scripts/configs/run/vertex_api_example_config.yaml +115 -0
  34. kaggle_environments/envs/werewolf/scripts/measure_cost.py +251 -0
  35. kaggle_environments/envs/werewolf/scripts/plot_existing_trajectories.py +135 -0
  36. kaggle_environments/envs/werewolf/scripts/rerender_html.py +87 -0
  37. kaggle_environments/envs/werewolf/scripts/run.py +93 -0
  38. kaggle_environments/envs/werewolf/scripts/run_block.py +237 -0
  39. kaggle_environments/envs/werewolf/scripts/run_pairwise_matrix.py +222 -0
  40. kaggle_environments/envs/werewolf/scripts/self_play.py +196 -0
  41. kaggle_environments/envs/werewolf/scripts/utils.py +47 -0
  42. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/METADATA +1 -1
  43. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/RECORD +46 -8
  44. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/WHEEL +0 -0
  45. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/entry_points.txt +0 -0
  46. {kaggle_environments-1.23.3.dist-info → kaggle_environments-1.23.4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,767 @@
1
+ import functools
2
+ import json
3
+ import logging
4
+ import os
5
+ import re
6
+ import traceback
7
+ from abc import ABC, abstractmethod
8
+ from collections import namedtuple
9
+ from typing import List, Optional
10
+
11
+ import litellm
12
+ import pyjson5
13
+ import tenacity
14
+ import yaml
15
+ from dotenv import load_dotenv
16
+ from litellm import completion, cost_per_token
17
+ from litellm.types.utils import Usage
18
+ from pydantic import BaseModel, Field
19
+
20
+ from kaggle_environments.envs.werewolf.game.actions import (
21
+ BidAction,
22
+ ChatAction,
23
+ EliminateProposalAction,
24
+ HealAction,
25
+ InspectAction,
26
+ NoOpAction,
27
+ TargetedAction,
28
+ VoteAction,
29
+ )
30
+ from kaggle_environments.envs.werewolf.game.consts import ActionType, DetailedPhase, EventName, RoleConst
31
+ from kaggle_environments.envs.werewolf.game.records import get_raw_observation
32
+ from kaggle_environments.envs.werewolf.game.states import get_last_action_request
33
+
34
# Load the YAML file that ships next to this module and hand its contents to
# litellm.register_model (the file is named "litellm_models.yaml").
_LITELLM_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "litellm_models.yaml")
litellm.config_path = _LITELLM_CONFIG_PATH
# Explicit encoding: without it the decoding depends on the platform's locale default.
with open(_LITELLM_CONFIG_PATH, "r", encoding="utf-8") as _file:
    _MODEL_COST_DICT = yaml.safe_load(_file)
litellm.register_model(_MODEL_COST_DICT)


logger = logging.getLogger(__name__)

# NOTE(review): presumably makes litellm drop parameters a provider does not support — confirm.
litellm.drop_params = True

# Load environment variables from a .env file, if one is found.
load_dotenv()
47
+
48
+
49
class LLMActionException(Exception):
    """Failure of an LLM-driven action, bundled with the context that produced it.

    Attributes:
        original_exception: The underlying exception that was caught.
        raw_out: The raw LLM output associated with the failure, if any.
        prompt: The prompt associated with the failure, if any.
    """

    def __init__(self, message, original_exception, raw_out=None, prompt=None):
        """Store the message and the failure context on the instance."""
        super().__init__(message)
        self.original_exception = original_exception
        self.raw_out = raw_out
        self.prompt = prompt

    def __str__(self):
        """Return the base message followed by the raw output."""
        base_text = super().__str__()
        return "{} | Raw Output: '{}'".format(base_text, self.raw_out)
60
+
61
+
62
+ def _log_retry_warning(retry_state: tenacity.RetryCallState):
63
+ assert retry_state.outcome is not None
64
+ exception = retry_state.outcome.exception()
65
+ traceback_str = "".join(traceback.format_exception(exception))
66
+ if retry_state.attempt_number < 1:
67
+ loglevel = logging.INFO
68
+ else:
69
+ loglevel = logging.WARNING
70
+ logging.log(
71
+ loglevel,
72
+ "Retrying: $s attempt # %s ended with: $s Traceback: %s Retry state: %s",
73
+ retry_state.fn,
74
+ retry_state.attempt_number,
75
+ retry_state.outcome,
76
+ traceback_str,
77
+ retry_state,
78
+ )
79
+
80
+
81
+ def _is_rate_limit_error(exception) -> bool:
82
+ """
83
+ Checks if an exception is a RateLimitError that warrants a context reduction retry.
84
+ This checks for both OpenAI's specific error and the generic HTTP 429 status code.
85
+ """
86
+ is_openai_rate_limit = "RateLimitError" in str(type(exception))
87
+ is_http_429 = hasattr(exception, "status_code") and exception.status_code == 429
88
+ return is_openai_rate_limit or is_http_429
89
+
90
+
91
+ def _is_context_window_exceeded_error(exception) -> bool:
92
+ """"""
93
+ is_error = "ContextWindowExceededError" in str(type(exception))
94
+ return is_error
95
+
96
+
97
def _is_json_parsing_error(exception) -> bool:
    """Tell whether ``exception`` is, or wraps, a pyjson5 parsing error.

    ``query_parse`` re-raises parser failures wrapped in ``LLMActionException``,
    so checking only the outer exception type would never match and the
    JSON-retry policy built on this predicate would never fire. The wrapper's
    ``original_exception`` is therefore inspected as well.
    """
    if isinstance(exception, pyjson5.Json5Exception):
        return True
    wrapped = getattr(exception, "original_exception", None)
    return isinstance(wrapped, pyjson5.Json5Exception)
100
+
101
+
102
def _truncate_and_log_on_retry(retry_state: tenacity.RetryCallState):
    """Tenacity ``before_sleep`` hook that shrinks the prompt context before a retry.

    The hook is attached to the retry policy that fires on
    ``ContextWindowExceededError``, so that error must trigger the truncation.
    Previously only rate-limit errors were checked, which meant the context
    was never reduced. Rate-limit errors still trigger it for backward
    compatibility.

    Args:
        retry_state: The tenacity call state; ``args[0]`` is the agent
            instance whose ``_event_log_items_to_keep`` is reduced.
    """
    # The first positional argument of the retried method is the instance ("self").
    agent_instance = retry_state.args[0]
    exception = retry_state.outcome.exception()

    if _is_context_window_exceeded_error(exception) or _is_rate_limit_error(exception):
        # Keep 75% of the history items on each attempt.
        original_count = agent_instance._event_log_items_to_keep
        agent_instance._event_log_items_to_keep = int(original_count * 0.75)

        logger.warning(
            "%s detected. Retrying with smaller context. Reducing event log from %d to %d items.",
            type(exception).__name__,
            original_count,
            agent_instance._event_log_items_to_keep,
        )

    # Also emit the general retry log entry.
    _log_retry_warning(retry_state)
124
+
125
+
126
def _add_error_entry_on_retry(retry_state: tenacity.RetryCallState):
    """Tenacity ``before_sleep`` hook that feeds the last failure into the next attempt.

    The formatted traceback of the failure is stored under the
    ``error_stack_trace`` keyword argument of the retried call. When the
    failure is an ``LLMActionException``, the wrapped original exception is
    the one formatted, and the failed output and prompt are logged.
    """
    failure = retry_state.outcome.exception()
    root_cause = failure
    if isinstance(failure, LLMActionException):
        root_cause = failure.original_exception
        raw_out = failure.raw_out
        prompt = failure.prompt
        logger.warning(f"Retrying due to JSON parsing error. Failed output: {raw_out} Failed prompt: {prompt}")

    retry_state.kwargs["error_stack_trace"] = "".join(traceback.format_exception(root_cause))
    _log_retry_warning(retry_state)
141
+
142
+
143
# Player-facing schemas, produced by each action class's ``schema_for_player``.
TARGETED_ACTION_SCHEMA = TargetedAction.schema_for_player()
CHAT_ACTION_SCHEMA = ChatAction.schema_for_player()

BID_ACTION_SCHEMA = BidAction.schema_for_player()
# NOTE(review): "target_id" is requested here although the bid exemplars use "amount" — confirm.
BID_ACTION_SCHEMA_REASONING = BidAction.schema_for_player(
    (
        "perceived_threat_level",
        "reasoning",
        "target_id",
    )
)
148
+
149
+
150
# Example responses shown to the LLM, each wrapped in a Markdown ```json fence.
TARGETED_ACTION_EXEMPLAR = (
    "```json\n"
    + json.dumps(
        {
            "perceived_threat_level": "SAFE",
            "reasoning": "I chose this target randomly.",
            "target_id": "some_player_id",
        }
    )
    + "\n```"
)

BID_ACTION_EXEMPLAR = "```json\n" + json.dumps({"perceived_threat_level": "UNEASY", "amount": 4}) + "\n```"
BID_ACTION_EXEMPLAR_REASONING = (
    "```json\n"
    + json.dumps(
        {
            "perceived_threat_level": "UNEASY",
            "reasoning": "I have important information to share, so I am bidding high.",
            "amount": 4,
        }
    )
    + "\n```"
)
160
+
161
# Example "message" values for audio mode: a delivery direction followed by the quoted line.
AUDIO_EXAMPLE = 'Say in an spooky whisper: "By the pricking of my thumbs... Something wicked this way comes!"'
AUDIO_EXAMPLE_2 = 'Deliver in a thoughtful tone: "I was stunned. I really suspect John\'s intent of bringing up Tim."'
AUDIO_EXAMPLE_3 = 'Read this in as fast as possible while remaining intelligible: "My nomination for Jack was purely incidental."'
AUDIO_EXAMPLE_4 = 'Sound amused and relaxed: "that was a very keen observation, AND a classic wolf play.\n(voice: curious)\nI\'m wondering what the seer might say."'

# Complete example chat responses built around the messages above.
CHAT_AUDIO_DICT = dict(
    perceived_threat_level="SAFE",
    reasoning="To draw attention to other players ...",
    message=AUDIO_EXAMPLE,
)
CHAT_AUDIO_DICT_2 = dict(
    perceived_threat_level="DANGER",
    reasoning="This accusation is uncalled for ...",
    message=AUDIO_EXAMPLE_2,
)
CHAT_AUDIO_DICT_3 = dict(
    perceived_threat_level="UNEASY",
    reasoning="I sense there are some suspicion directed towards me ...",
    message=AUDIO_EXAMPLE_3,
)
CHAT_AUDIO_DICT_4 = dict(
    perceived_threat_level="UNEASY",
    reasoning="I am redirecting the attention to other leads ...",
    message=AUDIO_EXAMPLE_4,
)

# Fenced JSON renderings. Note the numbering is offset from the dicts:
# EXEMPLAR_2 <- DICT, EXEMPLAR_3 <- DICT_2, EXEMPLAR <- DICT_3, EXEMPLAR_4 <- DICT_4.
CHAT_ACTION_EXEMPLAR_2 = "```json\n" + json.dumps(CHAT_AUDIO_DICT) + "\n```"
CHAT_ACTION_EXEMPLAR_3 = "```json\n" + json.dumps(CHAT_AUDIO_DICT_2) + "\n```"
CHAT_ACTION_EXEMPLAR = "```json\n" + json.dumps(CHAT_AUDIO_DICT_3) + "\n```"
CHAT_ACTION_EXEMPLAR_4 = "```json\n" + json.dumps(CHAT_AUDIO_DICT_4) + "\n```"


# Extra constraint bullets used when the chat message is rendered to speech.
# A further "Be Expressive" bullet that referenced CHAT_ACTION_EXEMPLAR_4 is currently disabled.
CHAT_ACTION_ADDITIONAL_CONSTRAINTS_AUDIO = [
    '- The "message" will be rendered to TTS and shown to other players, so make sure to control the style, tone, '
    + "accent and pace of your message using natural language prompt. e.g.\n"
    + CHAT_ACTION_EXEMPLAR_2,
    "- Since this is a social game, the script in the message should sound conversational.",
    '- Be Informal: Use contractions (like "it\'s," "gonna"), and simple language.',
    "- Be Spontaneous: Vary your sentence length. It's okay to have short, incomplete thoughts or to restart a sentence.",
    "- [Optional] If appropriate, you could add natural sounds in (sound: ...) e.g. (sound: chuckles), or (sound: laughs), etc.",
    "- [Optional] Be Dynamic: A real chat is never monotonous. Use (voice: ...) instructions to constantly and subtly shift the tone to match the words.",
]
204
+
205
+
206
# Example chat response for text mode, plus its fenced JSON rendering.
CHAT_TEXT_DICT = dict(
    perceived_threat_level="UNEASY",
    reasoning="I want to put pressure on Player3 and see how they react. A quiet player is often a werewolf.",
    message="I'm suspicious of Player3. They've been too quiet. What do you all think?",
)
CHAT_ACTION_EXEMPLAR_TEXT = "```json\n" + json.dumps(CHAT_TEXT_DICT) + "\n```"


# Extra constraint bullets used when the chat message is shown as plain text.
CHAT_ACTION_ADDITIONAL_CONSTRAINTS_TEXT = [
    '- The "message" will be displayed as text to other players. Focus on being clear and persuasive',
    "- Your goal is to win the game as a team. Think about how to reach that goal strategically.",
    '- Refer to players by their ID (e.g., "Player1", "Player3") to avoid ambiguity.',
    "- Keep your messages concise and to the point. ",
    '- You can simply say "Pass!", if you have nothing valuable you would like to share.',
]
221
+
222
+
223
class WerewolfAgentBase(ABC):
    """Abstract base for werewolf agents; concrete agents must be callable."""

    @abstractmethod
    def __call__(self, obs):
        """The instance is meant to be used as callable for kaggle environments."""
227
+
228
+
229
# Outer prompt layout. Placeholders filled via str.format:
# system_prompt, current_state, event_log, instruction, error_instruction.
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}

### Current Game State
{current_state}

### Game Timeline
This is the complete, chronological timeline of all public events and your private actions.
{event_log}

### Your Instruction
Based on the game state and event log, please respond to the following instruction.

{instruction}{error_instruction}
"""

# Per-action instruction layout. Placeholders filled via str.format:
# role, task, additional_constraints, json_schema, exemplar.
INSTRUCTION_TEMPLATE = """#### ROLE
{role}

#### TASK
{task}

#### CONSTRAINTS
- Your response MUST be a single, valid JSON object.
- generate the "reasoning" key first to think through your response. Your "reasoning" is invisible to other players.
{additional_constraints}

#### JSON SCHEMA
Your JSON output must conform to the following schema. Do NOT include this schema in your response.
```json
{json_schema}
```

#### EXAMPLE OUTPUT
Please format your response as a Markdown JSON code block, which should include the fences. Here's a valid example:
{exemplar}
"""
265
+
266
+
267
class TokenCost(BaseModel):
    """Running totals and per-update history of token counts and USD costs."""

    # Sum of every token count passed to ``update``.
    total_tokens: int = 0
    # Sum of every cost passed to ``update``.
    total_costs_usd: float = 0.0
    # One entry per ``update`` call, in call order.
    token_count_history: List[int] = []
    cost_history_usd: List[float] = []

    def update(self, token_count, cost):
        """Add one observation to the totals and append it to both histories."""
        self.total_tokens = self.total_tokens + token_count
        self.total_costs_usd = self.total_costs_usd + cost
        self.token_count_history.append(token_count)
        self.cost_history_usd.append(cost)
278
+
279
+
280
class LLMCostTracker(BaseModel):
    """Accumulates token usage and cost for the responses of one model."""

    model_name: str
    # Whole-query figures (prompt + completion tokens, response cost).
    query_token_cost: TokenCost = Field(default_factory=TokenCost)
    # Prompt-side and completion-side figures, costed via ``cost_per_token``.
    prompt_token_cost: TokenCost = Field(default_factory=TokenCost)
    completion_token_cost: TokenCost = Field(default_factory=TokenCost)
    # Raw ``response.usage`` objects, one per update. Example item from a gemini flash
    # model dump: response.usage = {'completion_tokens': 579, 'prompt_tokens': 1112,
    # 'total_tokens': 1691, 'completion_tokens_details': {'accepted_prediction_tokens': None,
    # 'audio_tokens': None, 'reasoning_tokens': 483, 'rejected_prediction_tokens': None,
    # 'text_tokens': 96}, 'prompt_tokens_details': {'audio_tokens': None, 'cached_tokens': None,
    # 'text_tokens': 1112, 'image_tokens': None}}
    usage_history: List[Usage] = []

    def update(self, response):
        """Record the usage and cost of one completion response.

        Raises:
            Exception: If ``cost_per_token`` fails for ``model_name``; the
                original error is chained as the cause.
        """
        usage = response["usage"]
        completion_tokens = usage["completion_tokens"]
        prompt_tokens = usage["prompt_tokens"]
        response_cost = response._hidden_params["response_cost"]

        try:
            prompt_cost, completion_cost = cost_per_token(
                model=self.model_name,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
        except Exception as exception:
            raise Exception(
                f"Could not find cost for {self.model_name} in litellm or custom dict. "
                f'You can register the cost in "litellm_models.yaml"'
            ) from exception
        else:
            logger.info(f"Used litellm cost for {self.model_name}")

        total_tokens = prompt_tokens + completion_tokens
        self.query_token_cost.update(token_count=total_tokens, cost=response_cost)
        self.prompt_token_cost.update(token_count=prompt_tokens, cost=prompt_cost)
        self.completion_token_cost.update(token_count=completion_tokens, cost=completion_cost)
        self.usage_history.append(response.usage)
312
+
313
+
314
class ActionRegistry:
    """Maps a (phase, role) pair to the handler registered for it."""

    def __init__(self):
        # phase -> {role -> handler function}
        self._registry = {}

    def register(self, phase: DetailedPhase, role: Optional[RoleConst] = None):
        """Return a decorator that registers a handler for ``phase``.

        With a ``role`` the handler is stored for that role only; with
        ``role=None`` it is stored for every member of ``RoleConst``. The
        decorated name is bound to a pass-through wrapper, while the registry
        keeps the undecorated function.
        """

        def decorator(func):
            handlers = self._registry.setdefault(phase, {})
            target_roles = RoleConst if role is None else (role,)
            for target in target_roles:
                handlers[target] = func

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        return decorator

    def get(self, phase: DetailedPhase, role: RoleConst):
        """Return the handler for ``phase`` and ``role``; raises ``KeyError`` if none is registered."""
        return self._registry[phase][role]
344
+
345
+
346
class EventLogKeys:
    """Tags that distinguish the two kinds of entry kept in an agent's event log."""

    PUBLIC_EVENT = "public_event"
    PRIVATE_ACTION = "private_action"
349
+
350
+
351
# One event-log record: the EventLogKeys tag, the day and phase it belongs to, and the payload.
EventLogItem = namedtuple("EventLogItem", "event_log_key day phase log_item")
352
+
353
+
354
+ class LLMWerewolfAgent(WerewolfAgentBase):
355
+ action_registry = ActionRegistry()
356
+
357
def __init__(
    self,
    model_name: str,
    agent_config: Optional[dict] = None,
    system_prompt: str = "",
    prompt_template: str = DEFAULT_PROMPT_TEMPLATE,
    kaggle_config=None,
):
    """Set up an agent backed by a single LLM.

    Args:
        model_name: Model identifier used for completion calls and cost tracking.
        agent_config: Optional settings. Keys read here: ``llms`` (only the
            first entry's ``parameters`` are used, as decoding kwargs),
            ``chat_mode`` (default ``"audio"``) and ``enable_bid_reasoning``
            (default ``False``).
        system_prompt: Text substituted for ``{system_prompt}`` in the template.
        prompt_template: Format string used to render the final prompt.
        kaggle_config: Stored as-is; defaults to an empty dict.
    """
    agent_config = agent_config or {}
    # Tolerate a missing, None or empty "llms" list instead of raising on ``[0]``.
    llm_configs = agent_config.get("llms") or [{}]
    decoding_kwargs = llm_configs[0].get("parameters")
    # Copy: the vertex-specific update below must not mutate the caller's config dict.
    self._decoding_kwargs = dict(decoding_kwargs) if decoding_kwargs else {}
    self._kaggle_config = kaggle_config or {}
    self._chat_mode = agent_config.get("chat_mode", "audio")
    self._enable_bid_reasoning = agent_config.get("enable_bid_reasoning", False)
    self._cost_tracker = LLMCostTracker(model_name=model_name)

    self._model_name = model_name
    self._system_prompt = system_prompt
    self._prompt_template = prompt_template
    self._is_vertex_ai = "vertex_ai" in self._model_name

    # All events, internal and external, in arrival order.
    self._event_logs: List[EventLogItem] = []

    # How many trailing history items to include in the prompt on each retry attempt.
    self._event_log_items_to_keep = 0

    if self._is_vertex_ai:
        self._decoding_kwargs.update(
            {
                "vertex_ai_project": os.environ.get("VERTEXAI_PROJECT", ""),
                "vertex_ai_location": os.environ.get("VERTEXAI_LOCATION", ""),
            }
        )
392
+
393
@property
def cost_tracker(self) -> LLMCostTracker:
    """Read-only access to this agent's cost tracker."""
    tracker = self._cost_tracker
    return tracker
396
+
397
def log_token_usage(self):
    """Log cumulative prompt/completion tokens, total query cost and the latest query cost."""
    tracker = self._cost_tracker
    per_query_costs = tracker.query_token_cost.cost_history_usd
    # No queries recorded yet -> the latest cost is reported as None.
    latest_cost = per_query_costs[-1] if per_query_costs else None
    summary = (
        f"*** Total prompt tokens: {tracker.prompt_token_cost.total_tokens}, "
        f"total completion_tokens: {tracker.completion_token_cost.total_tokens}, "
        f"total query cost: $ {tracker.query_token_cost.total_costs_usd}, "
        f"current query cost: $ {latest_cost}"
    )
    logger.info(summary)
410
+
411
def __del__(self):
    """Log the accumulated token counts when the instance is finalised.

    The finaliser also runs on instances whose ``__init__`` raised part-way
    through, so the attributes are looked up defensively instead of raising
    ``AttributeError`` from inside ``__del__``.
    """
    cost_tracker = getattr(self, "_cost_tracker", None)
    if cost_tracker is None:
        # Construction never got far enough to track anything.
        return
    model_name = getattr(self, "_model_name", cost_tracker.model_name)
    logger.info(
        f"Instance '{model_name}' is being deleted. "
        f"Prompt tokens: '{cost_tracker.prompt_token_cost.total_tokens}' "
        f"completion_tokens: '{cost_tracker.completion_token_cost.total_tokens}'."
    )
417
+
418
@tenacity.retry(
    retry=tenacity.retry_if_exception(_is_rate_limit_error),
    stop=tenacity.stop_after_attempt(5),
    wait=tenacity.wait_random_exponential(multiplier=1, min=2, max=10),
    reraise=True,
)
def query(self, prompt):
    """Send ``prompt`` as a single user message and return the reply text.

    Retried by tenacity (up to 5 attempts) when the raised exception
    satisfies ``_is_rate_limit_error``. Each successful response is also
    recorded in the cost tracker.
    """
    logger.info(f"prompt for {self._model_name}: {prompt}")
    messages = [{"content": prompt, "role": "user"}]
    response = completion(model=self._model_name, messages=messages, **self._decoding_kwargs)
    first_choice = response["choices"][0]
    msg = first_choice["message"]["content"]
    self._cost_tracker.update(response)
    logger.info(f"message from {self._model_name}: {msg}")
    return msg
433
+
434
def parse(self, out: str) -> dict:
    """Parse the raw string output of an LLM into a Python object.

    Steps:
    1. Take the content of the first Markdown code block (```json ... ```
       preferred, any ``` ... ``` otherwise). If the closing fence is
       missing, everything after the opening fence is used. Without any
       fence the whole output is treated as JSON.
    2. Strip trailing commas before ``}`` / ``]``.
    3. Parse with ``pyjson5``.

    Args:
        out: The raw string output from the LLM.

    Returns:
        The value parsed from the JSON text.

    Raises:
        Exception: Any error raised during extraction or parsing is logged
            and re-raised unchanged.
    """
    try:
        # 1. Extract the JSON string from a Markdown code block, if there is one.
        fence = "```json" if "```json" in out else "```"
        fence_pos = out.find(fence)
        if fence_pos == -1:
            # No code block: assume the whole output might be JSON.
            json_str = out
        else:
            start = fence_pos + len(fence)
            end = out.find("```", start)
            if end == -1:
                # Unclosed fence: slicing with -1 would silently drop the last character.
                end = len(out)
            json_str = out[start:end].strip()

        # 2. Remove trailing commas from objects and arrays, a common LLM mistake.
        # NOTE(review): this also rewrites ",}" sequences inside string values — confirm it is still needed.
        json_str = re.sub(r",\s*([\}\]])", r"\1", json_str)

        # 3. Parse the cleaned string.
        return pyjson5.loads(json_str)
    except Exception:
        # Log the failure with full context, then propagate it to the caller.
        error_trace = traceback.format_exc()
        logger.error("An error occurred:\n%s", error_trace)
        logger.error(f'The model out failed to parse is model_name="{self._model_name}".')
        logger.error(f"Failed to parse out={out}")
        raise
481
+
482
def render_prompt(self, instruction: str, obs, max_log_items: int = -1, error_stack_trace=None, error_prompt=None):
    """Build the full prompt text from the template.

    Args:
        instruction: Text substituted for ``{instruction}``.
        obs: Observation handed to ``current_state``.
        max_log_items: When between 0 and the log length (exclusive), only
            that many trailing event-log items are rendered; otherwise all.
        error_stack_trace: When truthy, an error section containing it and
            ``error_prompt`` is appended after the instruction.
        error_prompt: Follow-up text placed after the stack trace.

    Returns:
        The formatted prompt string.
    """
    state_snapshot = self.current_state(obs)

    # NOTE(review): with max_log_items == 0 the slice is ``[-0:]``, i.e. the whole log — confirm intended.
    should_truncate = 0 <= max_log_items < len(self._event_logs)
    visible_logs = self._event_logs[-max_log_items:] if should_truncate else self._event_logs

    # Render the tagged timeline, emitting a header whenever (day, phase) changes.
    rendered = []
    last_header = (None, None)
    for log_key, day, phase, log_item in visible_logs:
        header = (day, phase)
        if header != last_header:
            last_header = header
            rendered.append(f"**--- {phase} {day} ---**")
        if log_key == EventLogKeys.PUBLIC_EVENT:
            rendered.append(f"[EVENT] {log_item.description}")
        elif log_key == EventLogKeys.PRIVATE_ACTION:
            lines = [f"[YOUR ACTION & REASONING] You decided to use {type(log_item).__name__} "]
            # Skipped when the action has no action field (e.g. a no-op).
            if log_item.action_field:
                lines.append(f" - {log_item.action_field.capitalize()}: {getattr(log_item, log_item.action_field)}")
            lines.append(f" - Reasoning: {log_item.reasoning}")
            lines.append(f" - Perceived threat level: {log_item.perceived_threat_level}")
            rendered.append("\n".join(lines))

    timeline = "\n\n".join(rendered)

    error_section = (
        f"\n\nYour previous attempt resulted in the following error:\n{error_stack_trace}\n\n{error_prompt}"
        if error_stack_trace
        else ""
    )

    return self._prompt_template.format(
        system_prompt=self._system_prompt,
        current_state=json.dumps(state_snapshot, sort_keys=True),
        event_log=timeline,
        instruction=instruction,
        error_instruction=error_section,
    )
532
+
533
@staticmethod
def current_state(obs):
    """Return a dict of the player's identity and the player lists, read from the raw observation."""
    raw = get_raw_observation(obs)
    return dict(
        your_name=raw.player_id,
        your_team=raw.team,
        your_role_name=raw.role,
        all_player_ids=raw.all_player_ids,
        alive_players=raw.alive_players,
        revealed_players=raw.revealed_players,
    )
545
+
546
@tenacity.retry(
    retry=tenacity.retry_if_exception(_is_context_window_exceeded_error),
    stop=tenacity.stop_after_attempt(5),
    wait=tenacity.wait_random_exponential(multiplier=1, min=2, max=10),
    before_sleep=_truncate_and_log_on_retry,
    reraise=True,
)
def render_prompt_query(self, instruction, obs, error_stack_trace=None, error_prompt=None):
    """Render the prompt with the current history limit, query the LLM, and return ``(output, prompt)``.

    Retried by tenacity (up to 5 attempts) when the raised exception
    satisfies ``_is_context_window_exceeded_error``; the prompt is rebuilt
    on every attempt, so a changed ``_event_log_items_to_keep`` takes effect.
    """
    prompt = self.render_prompt(
        instruction,
        obs,
        max_log_items=self._event_log_items_to_keep,
        error_stack_trace=error_stack_trace,
        error_prompt=error_prompt,
    )
    return self.query(prompt), prompt
563
+
564
@tenacity.retry(
    retry=tenacity.retry_if_exception(_is_json_parsing_error),
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_random_exponential(multiplier=1, min=2, max=10),
    before_sleep=_add_error_entry_on_retry,
    reraise=True,
)
def query_parse(self, instruction, obs, error_stack_trace=None, error_prompt=None):
    """Query the LLM and parse its reply.

    Returns:
        The parsed output, extended with ``raw_prompt`` and ``raw_completion``.

    Raises:
        LLMActionException: When parsing raises ``pyjson5.Json5Exception``;
            carries the original error, the raw output and the prompt.
    """
    raw_out, prompt = self.render_prompt_query(instruction, obs, error_stack_trace, error_prompt)
    try:
        result = self.parse(raw_out)
    except pyjson5.Json5Exception as e:
        # Attach the failed output and prompt so the retry hook can report them.
        raise LLMActionException(
            message="Failed to parse LLM output.", original_exception=e, raw_out=raw_out, prompt=prompt
        )
    else:
        # Keep the raw prompt and completion alongside the parsed fields.
        result["raw_prompt"] = prompt
        result["raw_completion"] = raw_out
        return result
585
+
586
@action_registry.register(DetailedPhase.NIGHT_AWAIT_ACTIONS, RoleConst.WEREWOLF)
def _night_werewolf_vote(self, entries, obs, common_args):
    """Night handler for werewolves.

    Returns an ``EliminateProposalAction`` built from the LLM's reply when a
    vote request is found in ``entries``, otherwise a ``NoOpAction``.
    """
    request = get_last_action_request(entries, EventName.VOTE_REQUEST)
    action = NoOpAction(**common_args, reasoning="There's nothing to be done.")
    if request:
        valid_targets = request.data.get("valid_targets")
        fields = dict(
            role="You are a Werewolf.",
            task="Vote for a player to eliminate.",
            additional_constraints=f"- Valid targets are: `{valid_targets}`.",
            json_schema=json.dumps(TARGETED_ACTION_SCHEMA),
            exemplar=TARGETED_ACTION_EXEMPLAR,
        )
        instruction = INSTRUCTION_TEMPLATE.format(**fields)
        reply = self.query_parse(instruction, obs, error_prompt="Your previous attempt failed. Please vote again.")
        action = EliminateProposalAction(**common_args, **reply)
    return action
607
+
608
@action_registry.register(DetailedPhase.NIGHT_AWAIT_ACTIONS, RoleConst.SEER)
def _night_seer_inspect(self, entries, obs, common_args):
    """Night handler for the seer.

    Returns an ``InspectAction`` built from the LLM's reply when an inspect
    request is found in ``entries``, otherwise a ``NoOpAction``.
    """
    request = get_last_action_request(entries, EventName.INSPECT_REQUEST)
    action = NoOpAction(**common_args, reasoning="There's nothing to be done.")
    if request:
        valid_targets = request.data["valid_candidates"]
        fields = dict(
            role="You are a Seer.",
            task="Choose a player to inspect and reveal their role.",
            additional_constraints=f'- The "target_id" must be in this list: `{valid_targets}`.',
            json_schema=json.dumps(TARGETED_ACTION_SCHEMA),
            exemplar=TARGETED_ACTION_EXEMPLAR,
        )
        instruction = INSTRUCTION_TEMPLATE.format(**fields)
        reply = self.query_parse(
            instruction,
            obs,
            error_prompt="Your previous attempt failed. Please choose one player to inspect again.",
        )
        action = InspectAction(**common_args, **reply)
    return action
631
+
632
@action_registry.register(DetailedPhase.NIGHT_AWAIT_ACTIONS, RoleConst.DOCTOR)
def _night_doctor_heal(self, entries, obs, common_args):
    """Night handler for the doctor.

    Returns a ``HealAction`` built from the LLM's reply when a heal request
    is found in ``entries``, otherwise a ``NoOpAction``.
    """
    action = NoOpAction(**common_args, reasoning="There's nothing to be done.")
    request = get_last_action_request(entries, EventName.HEAL_REQUEST)
    if request:
        valid_targets = request.data["valid_candidates"]
        fields = dict(
            role="You are a Doctor.",
            task="Choose a player to save from the werewolf attack.",
            additional_constraints=f'- The "target_id" must be in this list: `{valid_targets}`.',
            json_schema=json.dumps(TARGETED_ACTION_SCHEMA),
            exemplar=TARGETED_ACTION_EXEMPLAR,
        )
        instruction = INSTRUCTION_TEMPLATE.format(**fields)
        reply = self.query_parse(
            instruction, obs, error_prompt="Your previous attempt failed. Please choose one player to heal again."
        )
        action = HealAction(**common_args, **reply)
    return action
652
+
653
@action_registry.register(DetailedPhase.DAY_BIDDING_AWAIT)
def _day_bid(self, entries, obs, common_args):
    """Bidding handler, registered for every role: ask the LLM for a bid and return a ``BidAction``."""
    # NOTE(review): the exemplar switches on _enable_bid_reasoning but the schema does not — confirm intended.
    if self._enable_bid_reasoning:
        exemplar = BID_ACTION_EXEMPLAR_REASONING
    else:
        exemplar = BID_ACTION_EXEMPLAR
    fields = dict(
        role="It is bidding time. You can bid to get a chance to speak.",
        task="Decide how much to bid for a speaking turn. A higher bid increases your chance of speaking. You can bid from 0 to 4.",
        additional_constraints="- The 'amount' must be an integer between 0 and 4.",
        json_schema=json.dumps(BID_ACTION_SCHEMA),
        exemplar=exemplar,
    )
    instruction = INSTRUCTION_TEMPLATE.format(**fields)
    reply = self.query_parse(
        instruction, obs, error_prompt="Your previous attempt failed. Please place your bid again."
    )
    return BidAction(**common_args, **reply)
669
+
670
+ @action_registry.register(DetailedPhase.DAY_CHAT_AWAIT)
671
def _day_chat(self, entries, obs, common_args):
    """Prompt the model for a discussion message and wrap it in a ``ChatAction``.

    ``self._chat_mode`` selects which constraint list and exemplar are put
    into the prompt; only ``"text"`` and ``"audio"`` are accepted.

    Args:
        entries: Unused by this handler; kept for the shared handler signature.
        obs: Observation forwarded unchanged to ``self.query_parse``.
        common_args: Keyword arguments shared by every action constructor.

    Returns:
        A ``ChatAction`` built from ``common_args`` and the parsed model output.

    Raises:
        ValueError: If ``self._chat_mode`` is neither ``"text"`` nor ``"audio"``.
    """
    mode = self._chat_mode
    # Reject unknown modes first so the selection below is a plain two-way choice.
    if mode != "text" and mode != "audio":
        raise ValueError(
            f'Can only select between "text" mode and "audio" mode to prompt the LLM. "{mode}" mode detected.'
        )

    use_text = mode == "text"
    rules = CHAT_ACTION_ADDITIONAL_CONSTRAINTS_TEXT if use_text else CHAT_ACTION_ADDITIONAL_CONSTRAINTS_AUDIO
    # Audio mode uses the generic CHAT_ACTION_EXEMPLAR, as in the original.
    sample = CHAT_ACTION_EXEMPLAR_TEXT if use_text else CHAT_ACTION_EXEMPLAR

    prompt = INSTRUCTION_TEMPLATE.format(
        role="It is day time. Participate in the discussion.",
        task='Discuss with other players to decide who to vote out. Formulate a "message" to persuade others.',
        additional_constraints="\n".join(rules),
        json_schema=json.dumps(CHAT_ACTION_SCHEMA),
        exemplar=sample,
    )
    reply = self.query_parse(
        prompt,
        obs,
        error_prompt="Your previous attempt failed. Please prepare your message again.",
    )
    return ChatAction(**common_args, **reply)
697
+
698
+ @action_registry.register(DetailedPhase.DAY_VOTING_AWAIT)
699
def _day_vote(self, entries, obs, common_args):
    """Prompt the model for an exile vote and wrap it in a ``VoteAction``.

    The list of allowed targets shown to the model is every entry of the raw
    observation's ``alive_players`` except this agent's own ``player_id``.

    Args:
        entries: Unused by this handler; kept for the shared handler signature.
        obs: Observation; unpacked with ``get_raw_observation`` and also
            forwarded unchanged to ``self.query_parse``.
        common_args: Keyword arguments shared by every action constructor.

    Returns:
        A ``VoteAction`` built from ``common_args`` and the parsed model output.
    """
    observation = get_raw_observation(obs)
    living = observation.alive_players
    me = observation.player_id
    # An agent may not vote for itself.
    candidates = [player for player in living if player != me]

    prompt = INSTRUCTION_TEMPLATE.format(
        role="It is day time. It is time to vote.",
        task="Choose a player to exile.",
        additional_constraints=f'- The "target_id" must be in this list: `{candidates}`.',
        json_schema=json.dumps(TARGETED_ACTION_SCHEMA),
        exemplar=TARGETED_ACTION_EXEMPLAR,
    )
    reply = self.query_parse(
        prompt,
        obs,
        error_prompt="Your previous attempt failed. Please cast your vote again.",
    )
    return VoteAction(**common_args, **reply)
718
+
719
def __call__(self, obs):
    """Turn one observation into a serialized action.

    The new event views in the observation are appended to the agent's event
    log, the handler registered for the current phase and role is invoked,
    and the resulting action is logged and serialized. Any exception raised
    by the handler is converted into a ``NoOpAction`` carrying the traceback.

    Args:
        obs: Observation as accepted by ``get_raw_observation``.

    Returns:
        ``action.serialize()`` for the chosen (or fallback) action, or a
        plain NO_OP dict when the observation is missing or contains no new
        event views.
    """
    raw_obs = get_raw_observation(obs)
    # Guard BEFORE touching raw_obs: the previous order dereferenced
    # raw_obs.new_player_event_views first, so a missing observation raised
    # instead of reaching the NO_OP fallback below.
    entries = raw_obs.new_player_event_views if raw_obs else None

    # Default to NO_OP if observation is missing or agent cannot act
    if not entries:
        return {"action_type": ActionType.NO_OP.value, "target_idx": None, "message": None}

    self._event_logs.extend(
        EventLogItem(EventLogKeys.PUBLIC_EVENT, day=entry.day, phase=entry.phase, log_item=entry)
        for entry in entries
    )

    self._event_log_items_to_keep = len(self._event_logs)

    current_phase = DetailedPhase(raw_obs.detailed_phase)
    my_role = RoleConst(raw_obs.role)

    common_args = {"day": raw_obs.day, "phase": raw_obs.game_state_phase, "actor_id": raw_obs.player_id}

    handler = self.action_registry.get(phase=current_phase, role=my_role)

    try:
        action = handler(self, entries, obs, common_args)
    except LLMActionException as e:
        # Raised once the handler has given up on getting a parsable reply;
        # the exception carries the last prompt and raw completion.
        error_trace = traceback.format_exc()
        logger.error("An LLMActionException occurred after all retries:\n%s", error_trace)
        logger.error('The model failed to act is model_name="%s".', self._model_name)

        action = NoOpAction(
            **common_args,
            reasoning="Fell back to NoOp after multiple parsing failures.",
            error=error_trace,
            raw_completion=e.raw_out,
            raw_prompt=e.prompt,
        )
    except Exception:
        # Top-level boundary: any other failure still yields a valid action
        # so the caller is never handed an exception.
        error_trace = traceback.format_exc()
        logger.error("An error occurred:\n%s", error_trace)
        logger.error('The model failed to act is model_name="%s".', self._model_name)
        action = NoOpAction(**common_args, reasoning="", error=error_trace)

    self.log_token_usage()
    # Record the agent's own action alongside the public events.
    self._event_logs.append(
        EventLogItem(EventLogKeys.PRIVATE_ACTION, day=raw_obs.day, phase=raw_obs.game_state_phase, log_item=action)
    )
    return action.serialize()