kaggle-environments 1.23.2__py3-none-any.whl → 1.23.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaggle-environments might be problematic.
- kaggle_environments/envs/open_spiel_env/games/repeated_poker/repeated_poker.js +2 -2
- kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/components/getRepeatedPokerStateForStep.js +6 -6
- kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_1.svg +22 -0
- kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_10.svg +22 -0
- kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_100.svg +48 -0
- kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_25.svg +22 -0
- kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/images/poker_chip_5.svg +22 -0
- kaggle_environments/envs/open_spiel_env/games/repeated_poker/visualizer/default/src/repeated_poker_renderer.js +586 -361
- kaggle_environments/envs/werewolf/README.md +190 -0
- kaggle_environments/envs/werewolf/harness/__init__.py +0 -0
- kaggle_environments/envs/werewolf/harness/base.py +767 -0
- kaggle_environments/envs/werewolf/harness/litellm_models.yaml +51 -0
- kaggle_environments/envs/werewolf/harness/test_base.py +35 -0
- kaggle_environments/envs/werewolf/runner.py +146 -0
- kaggle_environments/envs/werewolf/scripts/__init__.py +0 -0
- kaggle_environments/envs/werewolf/scripts/add_audio.py +425 -0
- kaggle_environments/envs/werewolf/scripts/configs/audio/standard.yaml +24 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/block_basic.yaml +102 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/comprehensive.yaml +100 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/roundrobin_discussion_DisableDoctorSelfSave_DisableDoctorConsecutiveSave_large.yaml +104 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/roundrobin_discussion_large.yaml +103 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/roundrobin_discussion_small.yaml +103 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard.yaml +103 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_DisableDoctorConsecutiveSave.yaml +104 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_SeerRevealTeam.yaml +105 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_SeerRevealTeam_NightEliminationNoReveal_DayExileNoReveal.yaml +105 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_DisableDoctorSelfSave_SeerRevealTeam_NightEliminationRevealTeam_DayExileRevealTeam.yaml +105 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_disable_doctor_self_save.yaml +103 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_parallel_voting.yaml +103 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_parallel_voting_no_tie_exile.yaml +103 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/rule_experiment/standard_parallel_voting_roundbiddiscussion.yaml +105 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/run_config.yaml +58 -0
- kaggle_environments/envs/werewolf/scripts/configs/run/vertex_api_example_config.yaml +115 -0
- kaggle_environments/envs/werewolf/scripts/measure_cost.py +251 -0
- kaggle_environments/envs/werewolf/scripts/plot_existing_trajectories.py +135 -0
- kaggle_environments/envs/werewolf/scripts/rerender_html.py +87 -0
- kaggle_environments/envs/werewolf/scripts/run.py +93 -0
- kaggle_environments/envs/werewolf/scripts/run_block.py +237 -0
- kaggle_environments/envs/werewolf/scripts/run_pairwise_matrix.py +222 -0
- kaggle_environments/envs/werewolf/scripts/self_play.py +196 -0
- kaggle_environments/envs/werewolf/scripts/utils.py +47 -0
- {kaggle_environments-1.23.2.dist-info → kaggle_environments-1.23.4.dist-info}/METADATA +1 -1
- {kaggle_environments-1.23.2.dist-info → kaggle_environments-1.23.4.dist-info}/RECORD +46 -8
- {kaggle_environments-1.23.2.dist-info → kaggle_environments-1.23.4.dist-info}/WHEEL +0 -0
- {kaggle_environments-1.23.2.dist-info → kaggle_environments-1.23.4.dist-info}/entry_points.txt +0 -0
- {kaggle_environments-1.23.2.dist-info → kaggle_environments-1.23.4.dist-info}/licenses/LICENSE +0 -0
kaggle_environments/envs/werewolf/harness/base.py
@@ -0,0 +1,767 @@
import functools
import json
import logging
import os
import re
import traceback
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import List, Optional

import litellm
import pyjson5
import tenacity
import yaml
from dotenv import load_dotenv
from litellm import completion, cost_per_token
from litellm.types.utils import Usage
from pydantic import BaseModel, Field

from kaggle_environments.envs.werewolf.game.actions import (
    BidAction,
    ChatAction,
    EliminateProposalAction,
    HealAction,
    InspectAction,
    NoOpAction,
    TargetedAction,
    VoteAction,
)
from kaggle_environments.envs.werewolf.game.consts import ActionType, DetailedPhase, EventName, RoleConst
from kaggle_environments.envs.werewolf.game.records import get_raw_observation
from kaggle_environments.envs.werewolf.game.states import get_last_action_request

_LITELLM_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "litellm_models.yaml")
litellm.config_path = _LITELLM_CONFIG_PATH
with open(_LITELLM_CONFIG_PATH, "r") as _file:
    _MODEL_COST_DICT = yaml.safe_load(_file)
litellm.register_model(_MODEL_COST_DICT)


logger = logging.getLogger(__name__)

litellm.drop_params = True

# Load environment variables from a .env file in the same directory
load_dotenv()


class LLMActionException(Exception):
    """Custom exception to carry context from a failed LLM action."""

    def __init__(self, message, original_exception, raw_out=None, prompt=None):
        super().__init__(message)
        self.original_exception = original_exception
        self.raw_out = raw_out
        self.prompt = prompt

    def __str__(self):
        return f"{super().__str__()} | Raw Output: '{self.raw_out}'"


def _log_retry_warning(retry_state: tenacity.RetryCallState):
    assert retry_state.outcome is not None
    exception = retry_state.outcome.exception()
    traceback_str = "".join(traceback.format_exception(exception))
    if retry_state.attempt_number < 1:
        loglevel = logging.INFO
    else:
        loglevel = logging.WARNING
    logging.log(
        loglevel,
        "Retrying: %s attempt # %s ended with: %s Traceback: %s Retry state: %s",
        retry_state.fn,
        retry_state.attempt_number,
        retry_state.outcome,
        traceback_str,
        retry_state,
    )


def _is_rate_limit_error(exception) -> bool:
    """
    Checks if an exception is a RateLimitError that warrants a context reduction retry.
    This checks for both OpenAI's specific error and the generic HTTP 429 status code.
    """
    is_openai_rate_limit = "RateLimitError" in str(type(exception))
    is_http_429 = hasattr(exception, "status_code") and exception.status_code == 429
    return is_openai_rate_limit or is_http_429


def _is_context_window_exceeded_error(exception) -> bool:
    """Checks if an exception indicates that the model's context window was exceeded."""
    is_error = "ContextWindowExceededError" in str(type(exception))
    return is_error


def _is_json_parsing_error(exception) -> bool:
    wrapped = exception.original_exception if isinstance(exception, LLMActionException) else exception
    return isinstance(wrapped, pyjson5.Json5Exception)


def _truncate_and_log_on_retry(retry_state: tenacity.RetryCallState):
    """
    Tenacity hook called before a retry. It reduces the context size if a
    ContextWindowExceededError was detected.
    """
    # The first argument of the retried method is the class instance 'self'
    agent_instance = retry_state.args[0]

    if _is_context_window_exceeded_error(retry_state.outcome.exception()):
        # Reduce the number of history items to keep by 25% on each attempt
        original_count = agent_instance._event_log_items_to_keep
        agent_instance._event_log_items_to_keep = int(original_count * 0.75)

        logger.warning(
            "ContextWindowExceededError detected. Retrying with smaller context. "
            "Reducing event log from %d to %d items.",
            original_count,
            agent_instance._event_log_items_to_keep,
        )

    # Also call the original logging function for general retry logging
    _log_retry_warning(retry_state)


def _add_error_entry_on_retry(retry_state: tenacity.RetryCallState):
    last_exception_wrapper = retry_state.outcome.exception()
    if isinstance(last_exception_wrapper, LLMActionException):
        last_exception = last_exception_wrapper.original_exception
        # You can also access the failed output here if needed for logging
        raw_out = last_exception_wrapper.raw_out
        prompt = last_exception_wrapper.prompt
        logger.warning(f"Retrying due to JSON parsing error. Failed output: {raw_out} Failed prompt: {prompt}")
    else:
        last_exception = last_exception_wrapper

    stack_trace_list = traceback.format_exception(last_exception)
    stack_trace_str = "".join(stack_trace_list)
    retry_state.kwargs["error_stack_trace"] = stack_trace_str
    _log_retry_warning(retry_state)


TARGETED_ACTION_SCHEMA = TargetedAction.schema_for_player()
CHAT_ACTION_SCHEMA = ChatAction.schema_for_player()

BID_ACTION_SCHEMA = BidAction.schema_for_player()
BID_ACTION_SCHEMA_REASONING = BidAction.schema_for_player(("perceived_threat_level", "reasoning", "target_id"))


TARGETED_ACTION_EXEMPLAR = f"""```json
{json.dumps(dict(perceived_threat_level="SAFE", reasoning="I chose this target randomly.", target_id="some_player_id"))}
```"""

BID_ACTION_EXEMPLAR = f"""```json
{json.dumps(dict(perceived_threat_level="UNEASY", amount=4))}
```"""
BID_ACTION_EXEMPLAR_REASONING = f"""```json
{json.dumps(dict(perceived_threat_level="UNEASY", reasoning="I have important information to share, so I am bidding high.", amount=4))}
```"""

AUDIO_EXAMPLE = 'Say in a spooky whisper: "By the pricking of my thumbs... Something wicked this way comes!"'
AUDIO_EXAMPLE_2 = 'Deliver in a thoughtful tone: "I was stunned. I really suspect John\'s intent of bringing up Tim."'
AUDIO_EXAMPLE_3 = (
    'Read this as fast as possible while remaining intelligible: "My nomination for Jack was purely incidental."'
)
AUDIO_EXAMPLE_4 = 'Sound amused and relaxed: "that was a very keen observation, AND a classic wolf play.\n(voice: curious)\nI\'m wondering what the seer might say."'
CHAT_AUDIO_DICT = {
    "perceived_threat_level": "SAFE",
    "reasoning": "To draw attention to other players ...",
    "message": AUDIO_EXAMPLE,
}
CHAT_AUDIO_DICT_2 = {
    "perceived_threat_level": "DANGER",
    "reasoning": "This accusation is uncalled for ...",
    "message": AUDIO_EXAMPLE_2,
}
CHAT_AUDIO_DICT_3 = {
    "perceived_threat_level": "UNEASY",
    "reasoning": "I sense there is some suspicion directed towards me ...",
    "message": AUDIO_EXAMPLE_3,
}
CHAT_AUDIO_DICT_4 = {
    "perceived_threat_level": "UNEASY",
    "reasoning": "I am redirecting the attention to other leads ...",
    "message": AUDIO_EXAMPLE_4,
}
CHAT_ACTION_EXEMPLAR_2 = f"```json\n{json.dumps(CHAT_AUDIO_DICT)}\n```"
CHAT_ACTION_EXEMPLAR_3 = f"```json\n{json.dumps(CHAT_AUDIO_DICT_2)}\n```"
CHAT_ACTION_EXEMPLAR = f"```json\n{json.dumps(CHAT_AUDIO_DICT_3)}\n```"
CHAT_ACTION_EXEMPLAR_4 = f"```json\n{json.dumps(CHAT_AUDIO_DICT_4)}\n```"


CHAT_ACTION_ADDITIONAL_CONSTRAINTS_AUDIO = [
    f'- The "message" will be rendered to TTS and shown to other players, so make sure to control the style, tone, '
    f"accent and pace of your message using natural language prompts. e.g.\n{CHAT_ACTION_EXEMPLAR_2}",
    "- Since this is a social game, the script in the message should sound conversational.",
    '- Be Informal: Use contractions (like "it\'s," "gonna"), and simple language.',
    "- Be Spontaneous: Vary your sentence length. It's okay to have short, incomplete thoughts or to restart a sentence.",
    "- [Optional] If appropriate, you could add natural sounds in (sound: ...) e.g. (sound: chuckles), or (sound: laughs), etc.",
    "- [Optional] Be Dynamic: A real chat is never monotonous. Use (voice: ...) instructions to constantly and subtly shift the tone to match the words.",
    # f'- Be Expressive: Use a variety of descriptive tones. Don\'t just use happy or sad. Try tones like amused, '
    # f'thoughtful, curious, energetic, sarcastic, or conspiratorial. e.g. \n{CHAT_ACTION_EXEMPLAR_4}'
]


CHAT_TEXT_DICT = {
    "perceived_threat_level": "UNEASY",
    "reasoning": "I want to put pressure on Player3 and see how they react. A quiet player is often a werewolf.",
    "message": "I'm suspicious of Player3. They've been too quiet. What do you all think?",
}
CHAT_ACTION_EXEMPLAR_TEXT = f"```json\n{json.dumps(CHAT_TEXT_DICT)}\n```"


CHAT_ACTION_ADDITIONAL_CONSTRAINTS_TEXT = [
    '- The "message" will be displayed as text to other players. Focus on being clear and persuasive.',
    "- Your goal is to win the game as a team. Think about how to reach that goal strategically.",
    '- Refer to players by their ID (e.g., "Player1", "Player3") to avoid ambiguity.',
    "- Keep your messages concise and to the point.",
    '- You can simply say "Pass!" if you have nothing valuable you would like to share.',
]


class WerewolfAgentBase(ABC):
    @abstractmethod
    def __call__(self, obs):
        """The instance is meant to be used as a callable for kaggle environments."""


DEFAULT_PROMPT_TEMPLATE = """{system_prompt}

### Current Game State
{current_state}

### Game Timeline
This is the complete, chronological timeline of all public events and your private actions.
{event_log}

### Your Instruction
Based on the game state and event log, please respond to the following instruction.

{instruction}{error_instruction}
"""

INSTRUCTION_TEMPLATE = """#### ROLE
{role}

#### TASK
{task}

#### CONSTRAINTS
- Your response MUST be a single, valid JSON object.
- Generate the "reasoning" key first to think through your response. Your "reasoning" is invisible to other players.
{additional_constraints}

#### JSON SCHEMA
Your JSON output must conform to the following schema. Do NOT include this schema in your response.
```json
{json_schema}
```

#### EXAMPLE OUTPUT
Please format your response as a Markdown JSON code block, which should include the fences. Here's a valid example:
{exemplar}
"""


class TokenCost(BaseModel):
    total_tokens: int = 0
    total_costs_usd: float = 0.0
    token_count_history: List[int] = []
    cost_history_usd: List[float] = []

    def update(self, token_count, cost):
        self.total_tokens += token_count
        self.total_costs_usd += cost
        self.token_count_history.append(token_count)
        self.cost_history_usd.append(cost)


class LLMCostTracker(BaseModel):
    model_name: str
    query_token_cost: TokenCost = Field(default_factory=TokenCost)
    prompt_token_cost: TokenCost = Field(default_factory=TokenCost)
    completion_token_cost: TokenCost = Field(default_factory=TokenCost)
    usage_history: List[Usage] = []
    """example item from gemini flash model dump: response.usage = {'completion_tokens': 579, 'prompt_tokens': 1112,
    'total_tokens': 1691, 'completion_tokens_details': {'accepted_prediction_tokens': None,
    'audio_tokens': None, 'reasoning_tokens': 483, 'rejected_prediction_tokens': None,
    'text_tokens': 96}, 'prompt_tokens_details': {'audio_tokens': None, 'cached_tokens': None,
    'text_tokens': 1112, 'image_tokens': None}}"""

    def update(self, response):
        completion_tokens = response["usage"]["completion_tokens"]
        prompt_tokens = response["usage"]["prompt_tokens"]
        response_cost = response._hidden_params["response_cost"]

        try:
            prompt_cost, completion_cost = cost_per_token(
                model=self.model_name, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
            )
            logger.info(f"Used litellm cost for {self.model_name}")
        except Exception as exception:
            raise Exception(
                f"Could not find cost for {self.model_name} in litellm or custom dict. "
                f'You can register the cost in "litellm_models.yaml"'
            ) from exception

        self.query_token_cost.update(token_count=prompt_tokens + completion_tokens, cost=response_cost)
        self.prompt_token_cost.update(token_count=prompt_tokens, cost=prompt_cost)
        self.completion_token_cost.update(token_count=completion_tokens, cost=completion_cost)
        self.usage_history.append(response.usage)


class ActionRegistry:
    """A registry of action handlers keyed by phase and role."""

    def __init__(self):
        self._registry = {}

    def register(self, phase: DetailedPhase, role: Optional[RoleConst] = None):
        """If an action is not role specific, role can be left as None, in which case all roles
        point to the same handler.
        """

        def decorator(func):
            self._registry.setdefault(phase, {})
            if role is not None:
                self._registry[phase][role] = func
            else:
                for item in RoleConst:
                    self._registry[phase][item] = func

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        return decorator

    def get(self, phase: DetailedPhase, role: RoleConst):
        func = self._registry[phase][role]
        return func


class EventLogKeys:
    PUBLIC_EVENT = "public_event"
    PRIVATE_ACTION = "private_action"


EventLogItem = namedtuple("EventLogItem", ["event_log_key", "day", "phase", "log_item"])


class LLMWerewolfAgent(WerewolfAgentBase):
    action_registry = ActionRegistry()

    def __init__(
        self,
        model_name: str,
        agent_config: Optional[dict] = None,
        system_prompt: str = "",
        prompt_template: str = DEFAULT_PROMPT_TEMPLATE,
        kaggle_config=None,
    ):
        """This wrapper only supports a single LLM."""
        agent_config = agent_config or {}
        decoding_kwargs = agent_config.get("llms", [{}])[0].get("parameters")
        self._decoding_kwargs = decoding_kwargs or {}
        self._kaggle_config = kaggle_config or {}
        self._chat_mode = agent_config.get("chat_mode", "audio")
        self._enable_bid_reasoning = agent_config.get("enable_bid_reasoning", False)
        self._cost_tracker = LLMCostTracker(model_name=model_name)

        self._model_name = model_name
        self._system_prompt = system_prompt
        self._prompt_template = prompt_template
        self._is_vertex_ai = "vertex_ai" in self._model_name

        # storing all events including internal and external
        self._event_logs: List[EventLogItem] = []

        # This attribute tracks how much history to include for each retry attempt
        self._event_log_items_to_keep = 0

        if self._is_vertex_ai:
            self._decoding_kwargs.update(
                {
                    "vertex_ai_project": os.environ.get("VERTEXAI_PROJECT", ""),
                    "vertex_ai_location": os.environ.get("VERTEXAI_LOCATION", ""),
                }
            )

    @property
    def cost_tracker(self) -> LLMCostTracker:
        return self._cost_tracker

    def log_token_usage(self):
        cost_history = self._cost_tracker.query_token_cost.cost_history_usd
        query_cost = cost_history[-1] if cost_history else None
        logger.info(
            ", ".join(
                [
                    f"*** Total prompt tokens: {self._cost_tracker.prompt_token_cost.total_tokens}",
                    f"total completion_tokens: {self._cost_tracker.completion_token_cost.total_tokens}",
                    f"total query cost: $ {self._cost_tracker.query_token_cost.total_costs_usd}",
                    f"current query cost: $ {query_cost}",
                ]
            )
        )

    def __del__(self):
        logger.info(
            f"Instance '{self._model_name}' is being deleted. "
            f"Prompt tokens: '{self._cost_tracker.prompt_token_cost.total_tokens}' "
            f"completion_tokens: '{self._cost_tracker.completion_token_cost.total_tokens}'."
        )

    @tenacity.retry(
        retry=tenacity.retry_if_exception(_is_rate_limit_error),
        stop=tenacity.stop_after_attempt(5),
        wait=tenacity.wait_random_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    def query(self, prompt):
        logger.info(f"prompt for {self._model_name}: {prompt}")
        response = completion(
            model=self._model_name, messages=[{"content": prompt, "role": "user"}], **self._decoding_kwargs
        )
        msg = response["choices"][0]["message"]["content"]
        self._cost_tracker.update(response)
        logger.info(f"message from {self._model_name}: {msg}")
        return msg

    def parse(self, out: str) -> dict:
        """
        Parses the string output from an LLM into a dictionary.

        This method implements best practices for parsing potentially malformed
        JSON output from a large language model.
        1. It looks for JSON within Markdown code blocks (```json ... ```).
        2. It attempts to clean the extracted string to fix common LLM mistakes.
        3. It uses a robust JSON parser.
        4. If parsing still fails, the error is logged and re-raised so that the
           caller's retry logic can handle it.

        Args:
            out: The raw string output from the LLM.

        Returns:
            A dictionary parsed from the JSON.
        """
        try:
            # 1. Extract JSON string from Markdown code blocks
            if "```json" in out:
                # Find the start and end of the json block
                start = out.find("```json") + len("```json")
                end = out.find("```", start)
                json_str = out[start:end].strip()
            elif "```" in out:
                start = out.find("```") + len("```")
                end = out.find("```", start)
                json_str = out[start:end].strip()
            else:
                # If no code block, assume the whole output might be JSON
                json_str = out

            # 2. Clean the JSON string
            # Remove trailing commas from objects and arrays, which is a common mistake
            json_str = re.sub(r",\s*([\}\]])", r"\1", json_str)

            # 3. Parse the cleaned string
            return pyjson5.loads(json_str)
        except Exception:
            # Catch any other unexpected errors during string manipulation or parsing
            error_trace = traceback.format_exc()
            logger.error("An error occurred:\n%s", error_trace)
            logger.error(f'The model whose output failed to parse is model_name="{self._model_name}".')
            logger.error(f"Failed to parse out={out}")
            # reraise the error
            raise

    def render_prompt(self, instruction: str, obs, max_log_items: int = -1, error_stack_trace=None, error_prompt=None):
        """
        Renders the final prompt, optionally truncating the event log
        to include only the last 'max_log_items' events.
        """
        current_state = self.current_state(obs)

        # Greedily take the last n items from the event log if a limit is set
        if 0 <= max_log_items < len(self._event_logs):
            event_logs = self._event_logs[-max_log_items:]
        else:
            event_logs = self._event_logs

        # Build the unified, tagged event logs
        log_parts = []
        day_phase = (None, None)
        for log_key, day, phase, log_item in event_logs:
            if (day, phase) != day_phase:
                day_phase = (day, phase)
                log_parts.append(f"**--- {phase} {day} ---**")
            if log_key == EventLogKeys.PUBLIC_EVENT:
                log_parts.append(f"[EVENT] {log_item.description}")
            elif log_key == EventLogKeys.PRIVATE_ACTION:
                text_parts = [f"[YOUR ACTION & REASONING] You decided to use {type(log_item).__name__} "]
                # account for NOOP
                if log_item.action_field:
                    action_field_item = (
                        f" - {log_item.action_field.capitalize()}: {getattr(log_item, log_item.action_field)}"
                    )
                    text_parts.append(action_field_item)
                text_parts.append(f" - Reasoning: {log_item.reasoning}")
                text_parts.append(f" - Perceived threat level: {log_item.perceived_threat_level}")
                log_parts.append("\n".join(text_parts))

        event_log = "\n\n".join(log_parts)

        error_instruction = ""
        if error_stack_trace:
            error_instruction = (
                f"\n\nYour previous attempt resulted in the following error:\n{error_stack_trace}\n\n{error_prompt}"
            )

        content = {
            "system_prompt": self._system_prompt,
            "current_state": json.dumps(current_state, sort_keys=True),
            "event_log": event_log,
            "instruction": instruction,
            "error_instruction": error_instruction,
        }
        return self._prompt_template.format(**content)

    @staticmethod
    def current_state(obs):
        obs_model = get_raw_observation(obs)
        content = {
            "your_name": obs_model.player_id,
            "your_team": obs_model.team,
            "your_role_name": obs_model.role,
            "all_player_ids": obs_model.all_player_ids,
            "alive_players": obs_model.alive_players,
            "revealed_players": obs_model.revealed_players,
        }
        return content

    @tenacity.retry(
        retry=tenacity.retry_if_exception(_is_context_window_exceeded_error),
        stop=tenacity.stop_after_attempt(5),
        wait=tenacity.wait_random_exponential(multiplier=1, min=2, max=10),
        before_sleep=_truncate_and_log_on_retry,
        reraise=True,
    )
    def render_prompt_query(self, instruction, obs, error_stack_trace=None, error_prompt=None):
        prompt = self.render_prompt(
            instruction=instruction,
            obs=obs,
            max_log_items=self._event_log_items_to_keep,
            error_stack_trace=error_stack_trace,
            error_prompt=error_prompt,
        )
        out = self.query(prompt)
        return out, prompt

    @tenacity.retry(
        retry=tenacity.retry_if_exception(_is_json_parsing_error),
        stop=tenacity.stop_after_attempt(3),
        wait=tenacity.wait_random_exponential(multiplier=1, min=2, max=10),
        before_sleep=_add_error_entry_on_retry,
        reraise=True,
    )
    def query_parse(self, instruction, obs, error_stack_trace=None, error_prompt=None):
        raw_out, prompt = self.render_prompt_query(instruction, obs, error_stack_trace, error_prompt)
        try:
            parsed_out = self.parse(raw_out)
            # Add the raw_out and prompt to the output dict
            parsed_out["raw_prompt"] = prompt
            parsed_out["raw_completion"] = raw_out
            return parsed_out
        except pyjson5.Json5Exception as e:
            # Catch the parsing error, wrap it with context, and re-raise.
            # Tenacity will catch this and decide whether to retry.
            raise LLMActionException(
                message="Failed to parse LLM output.", original_exception=e, raw_out=raw_out, prompt=prompt
            )

    @action_registry.register(DetailedPhase.NIGHT_AWAIT_ACTIONS, RoleConst.WEREWOLF)
    def _night_werewolf_vote(self, entries, obs, common_args):
        # Werewolves target other alive players.
        history_entry = get_last_action_request(entries, EventName.VOTE_REQUEST)
        action = NoOpAction(**common_args, reasoning="There's nothing to be done.")
        if history_entry:
            valid_targets = history_entry.data.get("valid_targets")
            instruction = INSTRUCTION_TEMPLATE.format(
                **{
                    "role": "You are a Werewolf.",
                    "task": "Vote for a player to eliminate.",
                    "additional_constraints": f"- Valid targets are: `{valid_targets}`.",
                    "json_schema": json.dumps(TARGETED_ACTION_SCHEMA),
                    "exemplar": TARGETED_ACTION_EXEMPLAR,
                }
            )
            parsed_out = self.query_parse(
                instruction, obs, error_prompt="Your previous attempt failed. Please vote again."
            )
            action = EliminateProposalAction(**common_args, **parsed_out)
        return action

    @action_registry.register(DetailedPhase.NIGHT_AWAIT_ACTIONS, RoleConst.SEER)
    def _night_seer_inspect(self, entries, obs, common_args):
        # Seers can inspect any alive player.
        history_entry = get_last_action_request(entries, EventName.INSPECT_REQUEST)
        action = NoOpAction(**common_args, reasoning="There's nothing to be done.")
        if history_entry:
            valid_targets = history_entry.data["valid_candidates"]
            instruction = INSTRUCTION_TEMPLATE.format(
                **{
                    "role": "You are a Seer.",
                    "task": "Choose a player to inspect and reveal their role.",
                    "additional_constraints": f'- The "target_id" must be in this list: `{valid_targets}`.',
                    "json_schema": json.dumps(TARGETED_ACTION_SCHEMA),
                    "exemplar": TARGETED_ACTION_EXEMPLAR,
                }
            )
            parsed_out = self.query_parse(
                instruction,
                obs,
                error_prompt="Your previous attempt failed. Please choose one player to inspect again.",
            )
            action = InspectAction(**common_args, **parsed_out)
        return action

    @action_registry.register(DetailedPhase.NIGHT_AWAIT_ACTIONS, RoleConst.DOCTOR)
    def _night_doctor_heal(self, entries, obs, common_args):
        action = NoOpAction(**common_args, reasoning="There's nothing to be done.")
        history_entry = get_last_action_request(entries, EventName.HEAL_REQUEST)
        if history_entry:
            valid_targets = history_entry.data["valid_candidates"]
            instruction = INSTRUCTION_TEMPLATE.format(
                **{
                    "role": "You are a Doctor.",
                    "task": "Choose a player to save from the werewolf attack.",
                    "additional_constraints": f'- The "target_id" must be in this list: `{valid_targets}`.',
                    "json_schema": json.dumps(TARGETED_ACTION_SCHEMA),
                    "exemplar": TARGETED_ACTION_EXEMPLAR,
                }
            )
            parsed_out = self.query_parse(
                instruction, obs, error_prompt="Your previous attempt failed. Please choose one player to heal again."
            )
            action = HealAction(**common_args, **parsed_out)
        return action

    @action_registry.register(DetailedPhase.DAY_BIDDING_AWAIT)
    def _day_bid(self, entries, obs, common_args):
        instruction = INSTRUCTION_TEMPLATE.format(
            **{
                "role": "It is bidding time. You can bid to get a chance to speak.",
                "task": "Decide how much to bid for a speaking turn. A higher bid increases your chance of speaking. You can bid from 0 to 4.",
                "additional_constraints": "- The 'amount' must be an integer between 0 and 4.",
                "json_schema": json.dumps(BID_ACTION_SCHEMA),
                "exemplar": BID_ACTION_EXEMPLAR_REASONING if self._enable_bid_reasoning else BID_ACTION_EXEMPLAR,
            }
        )
        parsed_out = self.query_parse(
            instruction, obs, error_prompt="Your previous attempt failed. Please place your bid again."
        )
        action = BidAction(**common_args, **parsed_out)
        return action

    @action_registry.register(DetailedPhase.DAY_CHAT_AWAIT)
    def _day_chat(self, entries, obs, common_args):
        # All alive players can discuss.
        if self._chat_mode == "text":
            constraints = CHAT_ACTION_ADDITIONAL_CONSTRAINTS_TEXT
            exemplar = CHAT_ACTION_EXEMPLAR_TEXT
        elif self._chat_mode == "audio":  # audio mode
            constraints = CHAT_ACTION_ADDITIONAL_CONSTRAINTS_AUDIO
            exemplar = CHAT_ACTION_EXEMPLAR
        else:
            raise ValueError(
                f'Can only select between "text" mode and "audio" mode to prompt the LLM. "{self._chat_mode}" mode detected.'
            )
        instruction = INSTRUCTION_TEMPLATE.format(
            **{
                "role": "It is day time. Participate in the discussion.",
                "task": 'Discuss with other players to decide who to vote out. Formulate a "message" to persuade others.',
                "additional_constraints": "\n".join(constraints),
                "json_schema": json.dumps(CHAT_ACTION_SCHEMA),
                "exemplar": exemplar,
            }
        )
        parsed_out = self.query_parse(
            instruction, obs, error_prompt="Your previous attempt failed. Please prepare your message again."
        )
        action = ChatAction(**common_args, **parsed_out)
        return action

    @action_registry.register(DetailedPhase.DAY_VOTING_AWAIT)
    def _day_vote(self, entries, obs, common_args):
        raw_obs = get_raw_observation(obs)
        alive_players = raw_obs.alive_players
        my_id = raw_obs.player_id
        valid_targets = [p for p in alive_players if p != my_id]
        instruction = INSTRUCTION_TEMPLATE.format(
            **{
                "role": "It is day time. It is time to vote.",
                "task": "Choose a player to exile.",
                "additional_constraints": f'- The "target_id" must be in this list: `{valid_targets}`.',
                "json_schema": json.dumps(TARGETED_ACTION_SCHEMA),
                "exemplar": TARGETED_ACTION_EXEMPLAR,
            }
        )
        parsed_out = self.query_parse(
            instruction, obs, error_prompt="Your previous attempt failed. Please cast your vote again."
        )
        action = VoteAction(**common_args, **parsed_out)
        return action

    def __call__(self, obs):
        raw_obs = get_raw_observation(obs)
        entries = raw_obs.new_player_event_views

        for entry in entries:
            self._event_logs.append(
                EventLogItem(EventLogKeys.PUBLIC_EVENT, day=entry.day, phase=entry.phase, log_item=entry)
            )

        # Default to NO_OP if observation is missing or agent cannot act
        if not raw_obs or not entries:
            return {"action_type": ActionType.NO_OP.value, "target_idx": None, "message": None}

        self._event_log_items_to_keep = len(self._event_logs)

        current_phase = DetailedPhase(raw_obs.detailed_phase)
        my_role = RoleConst(raw_obs.role)

        common_args = {"day": raw_obs.day, "phase": raw_obs.game_state_phase, "actor_id": raw_obs.player_id}

        handler = self.action_registry.get(phase=current_phase, role=my_role)

        try:
            action = handler(self, entries, obs, common_args)
        except LLMActionException as e:
            # Catch the specific exception after all retries have failed
            error_trace = traceback.format_exc()
            logger.error("An LLMActionException occurred after all retries:\n%s", error_trace)
            logger.error(f'The model that failed to act is model_name="{self._model_name}".')

            # Preserve the failed prompt and completion on the fallback action
            action = NoOpAction(
                **common_args,
                reasoning="Fell back to NoOp after multiple parsing failures.",
                error=error_trace,
                raw_completion=e.raw_out,  # preserved from the failed attempt
                raw_prompt=e.prompt,  # preserved from the failed attempt
            )
        except Exception:
            error_trace = traceback.format_exc()
            logger.error("An error occurred:\n%s", error_trace)
            logger.error(f'The model that failed to act is model_name="{self._model_name}".')
            action = NoOpAction(**common_args, reasoning="", error=error_trace)
        self.log_token_usage()
        # record self action
        self._event_logs.append(
            EventLogItem(EventLogKeys.PRIVATE_ACTION, day=raw_obs.day, phase=raw_obs.game_state_phase, log_item=action)
        )
        return action.serialize()
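
For orientation, the sketch below shows how this new harness class might be wired into a kaggle_environments run; it is illustrative and not part of the released package. The constructor arguments and the agent_config keys (chat_mode, enable_bid_reasoning, llms[0].parameters) are taken from base.py above, while the environment name "werewolf", the eight-player setup, the litellm model string, and the decoding parameters are assumptions.

# Minimal usage sketch (illustrative, not from the package). The environment
# registration name, player count, and model string are assumptions.
from kaggle_environments import make

from kaggle_environments.envs.werewolf.harness.base import LLMWerewolfAgent

# Config keys mirror what LLMWerewolfAgent.__init__ reads from agent_config.
agent_config = {
    "chat_mode": "text",  # "audio" is the default in base.py
    "enable_bid_reasoning": False,
    "llms": [{"parameters": {"temperature": 0.7}}],  # forwarded to litellm.completion
}

# Each agent is a callable: __call__(obs) returns a serialized action dict.
agents = [
    LLMWerewolfAgent(model_name="gemini/gemini-2.5-flash", agent_config=agent_config)
    for _ in range(8)  # assumed player count
]

env = make("werewolf", debug=True)  # assumes the env registers under this name
env.run(agents)

The runner.py and scripts/run.py files added in the same release likely provide a fuller driver for this kind of setup.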