camel-ai 0.2.67__py3-none-any.whl → 0.2.80a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- camel/__init__.py +1 -1
- camel/agents/_types.py +6 -2
- camel/agents/_utils.py +38 -0
- camel/agents/chat_agent.py +4014 -410
- camel/agents/mcp_agent.py +30 -27
- camel/agents/repo_agent.py +2 -1
- camel/benchmarks/browsecomp.py +6 -6
- camel/configs/__init__.py +15 -0
- camel/configs/aihubmix_config.py +88 -0
- camel/configs/amd_config.py +70 -0
- camel/configs/cometapi_config.py +104 -0
- camel/configs/minimax_config.py +93 -0
- camel/configs/nebius_config.py +103 -0
- camel/configs/vllm_config.py +2 -0
- camel/data_collectors/alpaca_collector.py +15 -6
- camel/datagen/self_improving_cot.py +1 -1
- camel/datasets/base_generator.py +39 -10
- camel/environments/__init__.py +12 -0
- camel/environments/rlcards_env.py +860 -0
- camel/environments/single_step.py +28 -3
- camel/environments/tic_tac_toe.py +1 -1
- camel/interpreters/__init__.py +2 -0
- camel/interpreters/docker/Dockerfile +4 -16
- camel/interpreters/docker_interpreter.py +3 -2
- camel/interpreters/e2b_interpreter.py +34 -1
- camel/interpreters/internal_python_interpreter.py +51 -2
- camel/interpreters/microsandbox_interpreter.py +395 -0
- camel/loaders/__init__.py +11 -2
- camel/loaders/base_loader.py +85 -0
- camel/loaders/chunkr_reader.py +9 -0
- camel/loaders/firecrawl_reader.py +4 -4
- camel/logger.py +1 -1
- camel/memories/agent_memories.py +84 -1
- camel/memories/base.py +34 -0
- camel/memories/blocks/chat_history_block.py +122 -4
- camel/memories/blocks/vectordb_block.py +8 -1
- camel/memories/context_creators/score_based.py +29 -237
- camel/memories/records.py +88 -8
- camel/messages/base.py +166 -40
- camel/messages/func_message.py +32 -5
- camel/models/__init__.py +10 -0
- camel/models/aihubmix_model.py +83 -0
- camel/models/aiml_model.py +1 -16
- camel/models/amd_model.py +101 -0
- camel/models/anthropic_model.py +117 -18
- camel/models/aws_bedrock_model.py +2 -33
- camel/models/azure_openai_model.py +205 -91
- camel/models/base_audio_model.py +3 -1
- camel/models/base_model.py +189 -24
- camel/models/cohere_model.py +5 -17
- camel/models/cometapi_model.py +83 -0
- camel/models/crynux_model.py +1 -16
- camel/models/deepseek_model.py +6 -16
- camel/models/fish_audio_model.py +6 -0
- camel/models/gemini_model.py +71 -20
- camel/models/groq_model.py +1 -17
- camel/models/internlm_model.py +1 -16
- camel/models/litellm_model.py +49 -32
- camel/models/lmstudio_model.py +1 -17
- camel/models/minimax_model.py +83 -0
- camel/models/mistral_model.py +1 -16
- camel/models/model_factory.py +27 -1
- camel/models/model_manager.py +24 -6
- camel/models/modelscope_model.py +1 -16
- camel/models/moonshot_model.py +185 -19
- camel/models/nebius_model.py +83 -0
- camel/models/nemotron_model.py +0 -5
- camel/models/netmind_model.py +1 -16
- camel/models/novita_model.py +1 -16
- camel/models/nvidia_model.py +1 -16
- camel/models/ollama_model.py +4 -19
- camel/models/openai_compatible_model.py +171 -46
- camel/models/openai_model.py +205 -77
- camel/models/openrouter_model.py +1 -17
- camel/models/ppio_model.py +1 -16
- camel/models/qianfan_model.py +1 -16
- camel/models/qwen_model.py +1 -16
- camel/models/reka_model.py +1 -16
- camel/models/samba_model.py +34 -47
- camel/models/sglang_model.py +64 -31
- camel/models/siliconflow_model.py +1 -16
- camel/models/stub_model.py +0 -4
- camel/models/togetherai_model.py +1 -16
- camel/models/vllm_model.py +1 -16
- camel/models/volcano_model.py +0 -17
- camel/models/watsonx_model.py +1 -16
- camel/models/yi_model.py +1 -16
- camel/models/zhipuai_model.py +60 -16
- camel/parsers/__init__.py +18 -0
- camel/parsers/mcp_tool_call_parser.py +176 -0
- camel/retrievers/auto_retriever.py +1 -0
- camel/runtimes/configs.py +11 -11
- camel/runtimes/daytona_runtime.py +15 -16
- camel/runtimes/docker_runtime.py +6 -6
- camel/runtimes/remote_http_runtime.py +5 -5
- camel/services/agent_openapi_server.py +380 -0
- camel/societies/__init__.py +2 -0
- camel/societies/role_playing.py +26 -28
- camel/societies/workforce/__init__.py +2 -0
- camel/societies/workforce/events.py +122 -0
- camel/societies/workforce/prompts.py +249 -38
- camel/societies/workforce/role_playing_worker.py +82 -20
- camel/societies/workforce/single_agent_worker.py +634 -34
- camel/societies/workforce/structured_output_handler.py +512 -0
- camel/societies/workforce/task_channel.py +169 -23
- camel/societies/workforce/utils.py +176 -9
- camel/societies/workforce/worker.py +77 -23
- camel/societies/workforce/workflow_memory_manager.py +772 -0
- camel/societies/workforce/workforce.py +3168 -478
- camel/societies/workforce/workforce_callback.py +74 -0
- camel/societies/workforce/workforce_logger.py +203 -175
- camel/societies/workforce/workforce_metrics.py +33 -0
- camel/storages/__init__.py +4 -0
- camel/storages/key_value_storages/json.py +15 -2
- camel/storages/key_value_storages/mem0_cloud.py +48 -47
- camel/storages/object_storages/google_cloud.py +1 -1
- camel/storages/vectordb_storages/__init__.py +6 -0
- camel/storages/vectordb_storages/chroma.py +731 -0
- camel/storages/vectordb_storages/oceanbase.py +13 -13
- camel/storages/vectordb_storages/pgvector.py +349 -0
- camel/storages/vectordb_storages/qdrant.py +3 -3
- camel/storages/vectordb_storages/surreal.py +365 -0
- camel/storages/vectordb_storages/tidb.py +8 -6
- camel/tasks/task.py +244 -27
- camel/toolkits/__init__.py +46 -8
- camel/toolkits/aci_toolkit.py +64 -19
- camel/toolkits/arxiv_toolkit.py +6 -6
- camel/toolkits/base.py +63 -5
- camel/toolkits/code_execution.py +28 -1
- camel/toolkits/context_summarizer_toolkit.py +684 -0
- camel/toolkits/craw4ai_toolkit.py +93 -0
- camel/toolkits/dappier_toolkit.py +10 -6
- camel/toolkits/dingtalk.py +1135 -0
- camel/toolkits/edgeone_pages_mcp_toolkit.py +49 -0
- camel/toolkits/excel_toolkit.py +901 -67
- camel/toolkits/file_toolkit.py +1402 -0
- camel/toolkits/function_tool.py +30 -6
- camel/toolkits/github_toolkit.py +107 -20
- camel/toolkits/gmail_toolkit.py +1839 -0
- camel/toolkits/google_calendar_toolkit.py +38 -4
- camel/toolkits/google_drive_mcp_toolkit.py +54 -0
- camel/toolkits/human_toolkit.py +34 -10
- camel/toolkits/hybrid_browser_toolkit/__init__.py +18 -0
- camel/toolkits/hybrid_browser_toolkit/config_loader.py +185 -0
- camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit.py +246 -0
- camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit_ts.py +1973 -0
- camel/toolkits/hybrid_browser_toolkit/installer.py +203 -0
- camel/toolkits/hybrid_browser_toolkit/ts/package-lock.json +3749 -0
- camel/toolkits/hybrid_browser_toolkit/ts/package.json +32 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/browser-scripts.js +125 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/browser-session.ts +1815 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/config-loader.ts +233 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/hybrid-browser-toolkit.ts +590 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/index.ts +7 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/parent-child-filter.ts +226 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/snapshot-parser.ts +219 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/som-screenshot-injected.ts +543 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/types.ts +130 -0
- camel/toolkits/hybrid_browser_toolkit/ts/tsconfig.json +26 -0
- camel/toolkits/hybrid_browser_toolkit/ts/websocket-server.js +319 -0
- camel/toolkits/hybrid_browser_toolkit/ws_wrapper.py +1032 -0
- camel/toolkits/hybrid_browser_toolkit_py/__init__.py +17 -0
- camel/toolkits/hybrid_browser_toolkit_py/actions.py +575 -0
- camel/toolkits/hybrid_browser_toolkit_py/agent.py +311 -0
- camel/toolkits/hybrid_browser_toolkit_py/browser_session.py +787 -0
- camel/toolkits/hybrid_browser_toolkit_py/config_loader.py +490 -0
- camel/toolkits/hybrid_browser_toolkit_py/hybrid_browser_toolkit.py +2390 -0
- camel/toolkits/hybrid_browser_toolkit_py/snapshot.py +233 -0
- camel/toolkits/hybrid_browser_toolkit_py/stealth_script.js +0 -0
- camel/toolkits/hybrid_browser_toolkit_py/unified_analyzer.js +1043 -0
- camel/toolkits/image_generation_toolkit.py +390 -0
- camel/toolkits/jina_reranker_toolkit.py +3 -4
- camel/toolkits/klavis_toolkit.py +5 -1
- camel/toolkits/markitdown_toolkit.py +104 -0
- camel/toolkits/math_toolkit.py +64 -10
- camel/toolkits/mcp_toolkit.py +370 -45
- camel/toolkits/memory_toolkit.py +5 -1
- camel/toolkits/message_agent_toolkit.py +608 -0
- camel/toolkits/message_integration.py +724 -0
- camel/toolkits/minimax_mcp_toolkit.py +195 -0
- camel/toolkits/note_taking_toolkit.py +277 -0
- camel/toolkits/notion_mcp_toolkit.py +224 -0
- camel/toolkits/openbb_toolkit.py +5 -1
- camel/toolkits/origene_mcp_toolkit.py +56 -0
- camel/toolkits/playwright_mcp_toolkit.py +12 -31
- camel/toolkits/pptx_toolkit.py +25 -12
- camel/toolkits/resend_toolkit.py +168 -0
- camel/toolkits/screenshot_toolkit.py +213 -0
- camel/toolkits/search_toolkit.py +437 -142
- camel/toolkits/slack_toolkit.py +104 -50
- camel/toolkits/sympy_toolkit.py +1 -1
- camel/toolkits/task_planning_toolkit.py +3 -3
- camel/toolkits/terminal_toolkit/__init__.py +18 -0
- camel/toolkits/terminal_toolkit/terminal_toolkit.py +957 -0
- camel/toolkits/terminal_toolkit/utils.py +532 -0
- camel/toolkits/thinking_toolkit.py +1 -1
- camel/toolkits/vertex_ai_veo_toolkit.py +590 -0
- camel/toolkits/video_analysis_toolkit.py +106 -26
- camel/toolkits/video_download_toolkit.py +17 -14
- camel/toolkits/web_deploy_toolkit.py +1219 -0
- camel/toolkits/wechat_official_toolkit.py +483 -0
- camel/toolkits/zapier_toolkit.py +5 -1
- camel/types/__init__.py +2 -2
- camel/types/agents/tool_calling_record.py +4 -1
- camel/types/enums.py +316 -40
- camel/types/openai_types.py +2 -2
- camel/types/unified_model_type.py +31 -4
- camel/utils/commons.py +36 -5
- camel/utils/constants.py +3 -0
- camel/utils/context_utils.py +1003 -0
- camel/utils/mcp.py +138 -4
- camel/utils/mcp_client.py +45 -1
- camel/utils/message_summarizer.py +148 -0
- camel/utils/token_counting.py +43 -20
- camel/utils/tool_result.py +44 -0
- {camel_ai-0.2.67.dist-info → camel_ai-0.2.80a2.dist-info}/METADATA +296 -85
- {camel_ai-0.2.67.dist-info → camel_ai-0.2.80a2.dist-info}/RECORD +219 -146
- camel/loaders/pandas_reader.py +0 -368
- camel/toolkits/dalle_toolkit.py +0 -175
- camel/toolkits/file_write_toolkit.py +0 -444
- camel/toolkits/openai_agent_toolkit.py +0 -135
- camel/toolkits/terminal_toolkit.py +0 -1037
- {camel_ai-0.2.67.dist-info → camel_ai-0.2.80a2.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.67.dist-info → camel_ai-0.2.80a2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,860 @@
|
|
|
1
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
from abc import abstractmethod
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from rlcard.agents import RandomAgent
|
|
21
|
+
|
|
22
|
+
from camel.environments.models import Action, Observation
|
|
23
|
+
from camel.environments.multi_step import MultiStepEnv
|
|
24
|
+
from camel.extractors import BaseExtractor, BaseExtractorStrategy
|
|
25
|
+
from camel.logger import get_logger
|
|
26
|
+
|
|
27
|
+
# Module-level logger used by the RLCard environments in this module.
logger = get_logger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ActionExtractor(BaseExtractorStrategy):
    r"""A strategy for extracting RLCard actions from text.

    The environment asks the model to *end* its response with
    ``<Action> [your action]``. When the pattern occurs more than once
    (for example because the model quotes the instruction while reasoning),
    the last occurrence is treated as the intended action.
    """

    def __init__(self, action_pattern: str = r"<Action>\s*(.+)") -> None:
        r"""Initialize the action extractor with a regex pattern.

        Args:
            action_pattern (str): The regex pattern to extract actions. Its
                first capture group must contain the action text.
                (default: :obj:`"<Action>\\s*(.+)"`).
        """
        self.action_pattern = action_pattern

    async def extract(self, text: str) -> Optional[str]:
        r"""Extract a valid RLCard action from text.

        Looks for a pattern '<Action> action_str' where action_str is the
        string representation of the action. If several matches exist, the
        last one wins, because the prompt asks for the action at the end of
        the response.

        Args:
            text (str): The text to extract the action from.

        Returns:
            Optional[str]: The extracted action as a stripped string, or None
                if no match is found or the captured text is empty.
        """
        matches = list(re.finditer(self.action_pattern, text))
        if not matches:
            return None
        action_str = matches[-1].group(1).strip()
        # An empty capture (e.g. "<Action>   ") is not a usable action.
        return action_str or None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class RLCardsEnv(MultiStepEnv):
    r"""A base environment for RLCard games.

    This environment implements a wrapper around RLCard environments for
    reinforcement learning with LLMs. It handles the conversion between
    RLCard states and actions and the CAMEL environment interface.

    Player 0 is controlled by the LLM agent. Every other seat is played by
    an RLCard ``RandomAgent`` created in :meth:`_setup`. Whenever the turn
    passes to another seat, those opponents are advanced automatically
    until it is player 0's turn again or the game ends.
    """

    def __init__(
        self,
        game_name: str,
        extractor: Optional[BaseExtractor] = None,
        max_steps: Optional[int] = None,
        num_players: int = 2,
        **kwargs,
    ) -> None:
        r"""Initialize the RLCard environment.

        Args:
            game_name (str): The name of the RLCard game to play; passed to
                ``rlcard.make``.
            extractor (Optional[BaseExtractor]): Extractor to process LLM
                responses. If None, a default extractor with ActionExtractor
                will be used. (default: :obj:`None`)
            max_steps (Optional[int]): Maximum steps per episode.
                (default: :obj:`None`)
            num_players (int): Number of players in the game.
                (default: :obj:`2`)
            **kwargs: Additional environment parameters, forwarded to
                :class:`MultiStepEnv`.
        """
        if extractor is None:
            extractor = BaseExtractor(pipeline=[[ActionExtractor()]])

        super().__init__(extractor, max_steps, **kwargs)

        self.game_name = game_name
        self.num_players = num_players
        # Created in ``_setup`` and released in ``_close``.
        self.rlcard_env = None
        # Seat whose turn it is, as reported by RLCard's reset()/step().
        self.current_player_id: Optional[int] = None
        # Index 0 stays None (the LLM seat); other seats hold RandomAgents.
        self.agents: Optional[List[Optional['RandomAgent']]] = None

    async def _setup(self) -> None:
        r"""Set up the RLCard environment.

        This method initializes the RLCard environment with the specified game
        and parameters, and creates a random agent for every non-LLM seat.
        Extra RLCard config can be supplied through the ``rlcard_config``
        metadata entry; it overrides the defaults set here.

        Raises:
            Exception: Any error raised while creating the RLCard
                environment is logged and re-raised.
        """
        import rlcard

        try:
            # Create the RLCard environment
            self.rlcard_env = rlcard.make(
                self.game_name,
                config={
                    'game_num_players': self.num_players,
                    'allow_step_back': True,
                    **self._metadata.get('rlcard_config', {}),
                },
            )

            from rlcard.agents import RandomAgent

            # Initialize random agents for opponents
            self.agents = [None] * self.num_players
            assert self.rlcard_env is not None

            for i in range(1, self.num_players):  # Skip player 0 (LLM agent)
                self.agents[i] = RandomAgent(
                    num_actions=self.rlcard_env.num_actions
                )

            logger.info(
                f"RLCard environment for {self.game_name} initialized "
                f"successfully"
            )
        except Exception as e:
            logger.error(f"Failed to initialize RLCard environment: {e}")
            raise

    async def _close(self) -> None:
        r"""Clean up the RLCard environment."""
        self.rlcard_env = None
        self.agents = None

    def _get_initial_state(self) -> Dict[str, Any]:
        r"""Get the initial state of the environment.

        Returns:
            Dict[str, Any]: A dictionary containing the initial state with
                game state, player info, and game status flags.
                ``last_action`` holds the most recent action string the
                agent attempted (legal or not).
        """
        return {
            "rlcard_state": None,
            "legal_actions": [],
            "game_over": False,
            "winner": None,
            "payoffs": None,
            "last_action": None,
            "last_action_illegal": False,
            "extraction_error": None,
        }

    @staticmethod
    def _extract_legal_actions(
        state: Optional[Dict[str, Any]],
    ) -> List[Any]:
        r"""Return the legal action ids contained in an RLCard state.

        NOTE: RLCard's ``RandomAgent`` (used for the opponents here) samples
        from ``state['legal_actions'].keys()``, i.e. the mapping is keyed by
        action id. Iterating it yields those ids. A plain list/tuple of ids
        is accepted as well.

        Args:
            state (Optional[Dict[str, Any]]): An RLCard state dictionary.

        Returns:
            List[Any]: The legal action ids, or an empty list when the state
                or its ``legal_actions`` entry is missing.
        """
        if not state:
            return []
        legal_actions = state.get('legal_actions')
        if legal_actions is None:
            return []
        return list(legal_actions)

    def _record_game_result(self) -> None:
        r"""Mark the game as over and store payoffs and the winner.

        The LLM (player 0) wins with a positive payoff. It loses with a
        negative payoff, or when any other seat has a positive payoff. Any
        other outcome is a draw. Checking the sign of player 0's own payoff
        matters for single-seat games (e.g. Blackjack against the dealer),
        where there are no other payoffs to compare with.
        """
        assert self.rlcard_env is not None
        payoffs = self.rlcard_env.get_payoffs()
        self._state["game_over"] = True
        self._state["payoffs"] = payoffs

        own_payoff = payoffs[0]
        if own_payoff > 0:
            self._state["winner"] = "player"
        elif own_payoff < 0 or any(p > 0 for p in payoffs[1:]):
            self._state["winner"] = "opponent"
        else:
            self._state["winner"] = "draw"

    def _sync_rlcard_state(self, state: Dict[str, Any]) -> bool:
        r"""Store a fresh RLCard state and check for game over.

        Args:
            state (Dict[str, Any]): State returned by RLCard reset()/step().

        Returns:
            bool: True if the RLCard environment reports the game is over.
        """
        assert self.rlcard_env is not None
        self._state["rlcard_state"] = state
        self._state["legal_actions"] = self._extract_legal_actions(state)
        if self.rlcard_env.is_over():
            self._record_game_result()
            return True
        return False

    def _play_opponent_turns(self, state: Dict[str, Any]) -> None:
        r"""Let the random opponents move until it is player 0's turn.

        Stops early if the game ends.

        Args:
            state (Dict[str, Any]): The current RLCard state, as seen by
                ``self.current_player_id``.

        Raises:
            RuntimeError: If it is a seat's turn but no agent exists for it.
        """
        assert self.rlcard_env is not None
        while self.current_player_id != 0 and not self._state["game_over"]:
            agent = (
                self.agents[self.current_player_id] if self.agents else None
            )
            if agent is None:
                raise RuntimeError(
                    f"No opponent agent configured for player "
                    f"{self.current_player_id}"
                )
            decision = agent.eval_step(state)
            # RLCard agents' eval_step returns (action, info); tolerate a
            # bare action as well.
            opponent_action = (
                decision[0] if isinstance(decision, tuple) else decision
            )
            state, self.current_player_id = self.rlcard_env.step(
                opponent_action
            )
            self._sync_rlcard_state(state)

    def _safe_convert_action(self, action_str: str) -> Any:
        r"""Run :meth:`_convert_to_rlcard_action`, mapping parse errors to
        None so that they are reported as illegal moves instead of crashing.

        Args:
            action_str (str): The extracted action text.

        Returns:
            Any: The RLCard action, or None if it could not be converted.
        """
        try:
            return self._convert_to_rlcard_action(action_str)
        except (ValueError, KeyError) as e:
            logger.debug(f"Could not convert action '{action_str}': {e}")
            return None

    async def _update_state(self, action: Action) -> None:
        r"""Update the environment state based on the agent's action.

        This method processes the agent's action, updates the game state,
        and handles opponent moves if necessary. An action that cannot be
        extracted, converted, or is not currently legal leaves the RLCard
        game untouched and only sets ``last_action_illegal``.

        Args:
            action (Action): The action containing the LLM's response with the
                chosen move.
        """
        assert self.rlcard_env is not None

        if self._state["game_over"]:
            return

        # Extract action from LLM response
        extraction_result = await self.extractor.extract(action.llm_response)
        if not extraction_result:
            self._state["last_action"] = None
            self._state["last_action_illegal"] = True
            self._state["extraction_error"] = (
                "Could not extract a valid action"
            )
            return

        # Remember the attempt so an illegal-move prompt can echo it back.
        self._state["last_action"] = extraction_result

        # Convert extracted action to RLCard action format
        rlcard_action = self._safe_convert_action(extraction_result)
        if (
            rlcard_action is None
            or rlcard_action not in self._state["legal_actions"]
        ):
            self._state["last_action_illegal"] = True
            self._state["extraction_error"] = (
                f"'{extraction_result}' is not a valid action"
            )
            return

        # Reset illegal action flag
        self._state["last_action_illegal"] = False
        self._state["extraction_error"] = None

        # Take the action in the environment
        next_state, self.current_player_id = self.rlcard_env.step(
            rlcard_action
        )
        if self._sync_rlcard_state(next_state):
            return

        # If next player is not the LLM agent (player 0), let opponents play
        self._play_opponent_turns(next_state)

    def _get_next_observation(self) -> Observation:
        r"""Get the next observation based on the current state.

        This method generates a text observation describing the current state
        of the game and prompting the agent to make a move. On the first call
        of an episode it resets the RLCard game. If RLCard gives the first
        turn to another seat, the opponents move before the LLM is prompted.

        Returns:
            Observation: An Observation object containing the game state
                description.
        """
        assert self.rlcard_env is not None

        if self._state["rlcard_state"] is None:
            # Initial observation before the game starts
            state, self.current_player_id = self.rlcard_env.reset()
            if not self._sync_rlcard_state(state):
                self._play_opponent_turns(state)

        state_text = self._format_state_for_observation(
            self._state['rlcard_state']
        )
        actions_text = self._format_legal_actions(
            self._state['legal_actions']
        )

        header = f"You are playing {self.game_name}.\n"
        if self._state["last_action_illegal"]:
            last_action = self._state.get("last_action")
            error_msg = self._state.get("extraction_error") or "Unknown error"
            header += (
                f"Your last action '{last_action}' was illegal.\n"
                f"Error: {error_msg}\n"
            )

        obs_text = (
            f"{header}"
            f"Current game state:\n"
            f"{state_text}\n"
            f"Legal actions: {actions_text}\n"
            f"Please choose an action and end your response with "
            f"<Action> [your action]"
        )

        return Observation(
            question=obs_text,
            context={
                "game_name": self.game_name,
                "raw_state": self._state["rlcard_state"],
                "legal_actions": self._state["legal_actions"],
            },
            metadata={},
        )

    def _get_terminal_observation(self) -> Observation:
        r"""Get the final observation when the game is over.

        This method generates a text observation describing the final state
        of the game and the game result (win, loss, or draw).

        Returns:
            Observation: An Observation object containing the final game state
                description.
        """
        if self._state["winner"] == "player":
            result_message = "Congratulations, you won!"
        elif self._state["winner"] == "opponent":
            result_message = "Sorry, you lost!"
        else:
            result_message = "It's a draw!"

        # Use an explicit length check: payoffs may be an array, whose truth
        # value is ambiguous (and a single 0 payoff would read as "empty").
        payoffs = self._state.get("payoffs")
        if payoffs is not None and len(payoffs) > 0:
            payoffs_str = ", ".join(f"{p:.2f}" for p in payoffs)
        else:
            payoffs_str = "N/A"

        obs_text = (
            f"Game Over. {result_message}\n"
            f"Final game state:\n"
            f"{self._format_state_for_observation(self._state['rlcard_state'])}\n"
            f"Payoffs: [{payoffs_str}]\n"
        )

        return Observation(
            question=obs_text,
            context={
                "game_name": self.game_name,
                "raw_state": self._state["rlcard_state"],
                "payoffs": self._state["payoffs"],
                "winner": self._state["winner"],
            },
            metadata={},
        )

    async def compute_reward(
        self,
    ) -> Tuple[float, Dict[str, float]]:
        r"""Compute the reward for the current state.

        Returns:
            Tuple[float, Dict[str, float]]: A tuple containing the total
                reward and a dictionary of reward components:
                - 1.0 for a win
                - 0.0 for a loss or illegal move
                - 0.5 for a draw
                - For ongoing games, returns a small step penalty (-0.01)
        """
        if self._state["game_over"]:
            if self._state["winner"] == "player":
                return 1.0, {"win": 1.0}
            elif self._state["winner"] == "opponent":
                return 0.0, {"loss": 0.0}
            else:
                return 0.5, {"draw": 0.5}
        elif self._state["last_action_illegal"]:
            return 0.0, {"illegal_move": 0.0}
        else:
            # Small negative reward for each step to encourage efficiency
            step_penalty = -0.01
            return step_penalty, {"step_penalty": step_penalty}

    def _is_done(self) -> bool:
        r"""Check if the episode is done.

        Returns:
            bool: True if the game is over, False otherwise.
        """
        return self._state["game_over"]

    @abstractmethod
    def _convert_to_rlcard_action(self, action_str: str) -> Any:
        r"""Convert a string action to the format expected by RLCard.

        This method must be implemented by subclasses to handle the specific
        action format of each game.

        Args:
            action_str (str): The string representation of the action.

        Returns:
            Any: The action in the format expected by the RLCard environment,
                or None if ``action_str`` is not a recognized action.
        """
        pass

    @abstractmethod
    def _format_state_for_observation(self, state: Dict[str, Any]) -> str:
        r"""Format the RLCard state for human-readable observation.

        This method must be implemented by subclasses to create a
        human-readable representation of the game state.

        Args:
            state (Dict[str, Any]): The RLCard state dictionary.

        Returns:
            str: A human-readable representation of the state.
        """
        pass

    @abstractmethod
    def _format_legal_actions(self, legal_actions: List[Any]) -> str:
        r"""Format the legal actions for human-readable observation.

        This method must be implemented by subclasses to create a
        human-readable representation of the legal actions.

        Args:
            legal_actions (List[Any]): The list of legal actions.

        Returns:
            str: A human-readable representation of the legal actions.
        """
        pass
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
class BlackjackEnv(RLCardsEnv):
    r"""A Blackjack environment for reinforcement learning with LLMs.

    This environment implements a standard Blackjack game where the LLM agent
    plays against a dealer.
    """

    # Suit letters that may prefix ("SA") or suffix ("AS") a card string.
    _SUITS = frozenset("SHDC")
    # Ranks worth 10 points.
    _TEN_RANKS = frozenset({'J', 'Q', 'K', 'T', '10'})

    def __init__(
        self,
        extractor: Optional[BaseExtractor] = None,
        max_steps: Optional[int] = None,
        **kwargs,
    ) -> None:
        r"""Initialize the Blackjack environment.

        Args:
            extractor (Optional[BaseExtractor]): Extractor to process LLM
                responses. If None, a default extractor will be used.
                (default: :obj:`None`)
            max_steps (Optional[int]): Maximum steps per episode.
                (default: :obj:`None`)
            **kwargs: Additional environment parameters.
        """
        super().__init__(
            "blackjack", extractor, max_steps, num_players=1, **kwargs
        )

    def _convert_to_rlcard_action(self, action_str: str) -> Optional[int]:
        r"""Convert a string action to the format expected by RLCard Blackjack.

        Args:
            action_str (str): The string representation of the action.
                Expected to be 'hit' or 'stand' (case-insensitive).

        Returns:
            Optional[int]: 0 for 'hit', 1 for 'stand', or None for any other
                text. The base environment treats None as an illegal move.
        """
        normalized = action_str.lower().strip()
        if normalized == "hit":
            return 0
        if normalized == "stand":
            return 1
        return None

    @staticmethod
    def _extract_hands(
        raw_obs: Dict[str, Any],
    ) -> Tuple[List[str], List[str]]:
        r"""Pull the player's and dealer's hands out of a raw observation.

        The ``'player'`` / ``'dealer'`` keys are preferred.
        NOTE(review): RLCard's blackjack game state appears to expose
        ``'player0 hand'`` / ``'dealer hand'`` and a ``(player, dealer)``
        tuple under ``'state'`` instead, so those are used as fallbacks.

        Args:
            raw_obs (Dict[str, Any]): The ``raw_obs`` entry of the state.

        Returns:
            Tuple[List[str], List[str]]: The player's and dealer's cards;
                missing hands become empty lists.
        """
        player_hand = raw_obs.get('player')
        dealer_hand = raw_obs.get('dealer')
        if player_hand is None:
            player_hand = raw_obs.get('player0 hand')
        if dealer_hand is None:
            dealer_hand = raw_obs.get('dealer hand')

        combined = raw_obs.get('state')
        if isinstance(combined, (list, tuple)) and len(combined) == 2:
            if player_hand is None:
                player_hand = combined[0]
            if dealer_hand is None:
                dealer_hand = combined[1]

        # Ensure player_hand and dealer_hand are lists
        return list(player_hand or []), list(dealer_hand or [])

    def _format_state_for_observation(self, state: Dict[str, Any]) -> str:
        r"""Format the Blackjack state for human-readable observation.

        Args:
            state (Dict[str, Any]): The RLCard state dictionary.

        Returns:
            str: A human-readable representation of the state.
        """
        if state is None:
            return "Game not started yet."

        # Extract state information safely
        raw_obs = state.get('raw_obs') or {}
        player_hand, dealer_hand = self._extract_hands(raw_obs)

        # Format hands and calculate hand values
        player_cards = self._format_cards(player_hand)
        dealer_cards = self._format_cards(dealer_hand)
        player_value = self._calculate_hand_value(player_hand)
        dealer_value = self._calculate_hand_value(dealer_hand)

        return (
            f"Your hand: {player_cards} (Value: {player_value})\n"
            f"Dealer's hand: {dealer_cards} (Value: {dealer_value})"
        )

    def _format_legal_actions(self, legal_actions: List[int]) -> str:
        r"""Format the legal actions for Blackjack.

        Args:
            legal_actions (List[int]): The list of legal actions.

        Returns:
            str: A human-readable representation of the legal actions.
                Unknown action ids are shown as their string form.
        """
        if not legal_actions:
            return "No legal actions available"

        action_map = {0: "hit", 1: "stand"}
        return ", ".join(action_map.get(a, str(a)) for a in legal_actions)

    def _format_cards(self, cards: List[str]) -> str:
        r"""Format a list of cards for display.

        Args:
            cards (List[str]): List of card strings.

        Returns:
            str: Formatted card string.
        """
        return ", ".join(str(card) for card in cards)

    def _card_rank(self, card: str) -> str:
        r"""Return the rank part of a card string.

        A single suit letter is removed from the front ("SA") or the back
        ("AS"). Suit letters and rank symbols do not overlap, so both
        layouts parse the same way.

        Args:
            card (str): A card string such as "SA", "AS", "HT" or "10D".

        Returns:
            str: The upper-cased rank, e.g. "A", "T", "10", "7".
        """
        text = str(card).strip().upper()
        if len(text) > 1 and text[0] in self._SUITS:
            return text[1:]
        if len(text) > 1 and text[-1] in self._SUITS:
            return text[:-1]
        return text

    def _calculate_hand_value(self, cards: List[str]) -> int:
        r"""Calculate the value of a hand in Blackjack.

        Aces count 11 and are demoted to 1 while the total exceeds 21.
        Cards whose rank cannot be parsed are skipped with a warning instead
        of aborting the observation.

        Args:
            cards (List[str]): List of card strings.

        Returns:
            int: The value of the hand.
        """
        value = 0
        num_aces = 0

        for card in cards:
            rank = self._card_rank(card)
            if rank == 'A':
                num_aces += 1
                value += 11
            elif rank in self._TEN_RANKS:
                value += 10
            else:
                try:
                    value += int(rank)
                except ValueError:
                    logger.warning(
                        f"Unrecognized card '{card}'; ignoring it in the "
                        f"hand value"
                    )

        # Adjust for aces if needed
        while value > 21 and num_aces > 0:
            value -= 10  # Change an ace from 11 to 1
            num_aces -= 1

        return value
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
class LeducHoldemEnv(RLCardsEnv):
    r"""A Leduc Hold'em environment for reinforcement learning with LLMs.

    This environment implements a Leduc Hold'em poker game where the LLM agent
    plays against one or more opponents. Agent responses are parsed into the
    action ids 0 ('fold'), 1 ('check'/'call') and 2 ('raise').
    """

    # Display names for action ids, used by `_format_legal_actions`.
    # NOTE(review): RLCard's built-in Leduc Hold'em env appears to order its
    # actions as ['call', 'raise', 'fold', 'check']; confirm that the ids
    # used here match whatever the base class forwards to RLCard.
    _ACTION_NAMES = {0: "fold", 1: "check/call", 2: "raise"}

    def __init__(
        self,
        extractor: Optional[BaseExtractor] = None,
        max_steps: Optional[int] = None,
        num_players: int = 2,
        **kwargs,
    ) -> None:
        r"""Initialize the Leduc Hold'em environment.

        Args:
            extractor (Optional[BaseExtractor]): Extractor to process LLM
                responses. If None, a default extractor will be used.
                (default: :obj:`None`)
            max_steps (Optional[int]): Maximum steps per episode.
                (default: :obj:`None`)
            num_players (int): Number of players in the game.
                (default: :obj:`2`)
            **kwargs: Additional environment parameters.
        """
        super().__init__(
            "leduc-holdem",
            extractor,
            max_steps,
            num_players=num_players,
            **kwargs,
        )

    def _convert_to_rlcard_action(self, action_str: str) -> int:
        r"""Convert a string action to the format expected by RLCard
        Leduc Hold'em.

        Args:
            action_str (str): The string representation of the action.
                Expected to be 'fold', 'check', 'call', or 'raise'
                (case-insensitive, surrounding whitespace ignored).

        Returns:
            int: 0 for 'fold', 1 for 'check/call', 2 for 'raise'.

        Raises:
            ValueError: If ``action_str`` is not one of the expected actions.
        """
        normalized = action_str.lower().strip()
        if normalized == "fold":
            return 0
        if normalized in ("check", "call"):
            return 1
        if normalized == "raise":
            return 2
        # Name the offending input so the failure is diagnosable.
        raise ValueError(
            f"Invalid Leduc Hold'em action {action_str!r}; expected 'fold', "
            "'check', 'call', or 'raise'."
        )

    def _format_state_for_observation(self, state: Dict[str, Any]) -> str:
        r"""Format the Leduc Hold'em state for human-readable observation.

        Args:
            state (Dict[str, Any]): The RLCard state dictionary. Fields are
                read from ``state['raw_obs']``: 'hand', 'public_card',
                'all_chips' and, optionally, 'stage'.

        Returns:
            str: A human-readable representation of the state, one fact per
                line, each line terminated by a newline.
        """
        if state is None:
            return "Game not started yet."

        raw_obs = state.get('raw_obs') or {}

        hand = raw_obs.get('hand', [])
        public_card = raw_obs.get('public_card', None)
        # `all_chips` may be missing or None; normalize so indexing below is
        # safe. The first entry is treated as the agent's own chips.
        all_chips = raw_obs.get('all_chips') or []
        my_chips = all_chips[0] if all_chips else 0
        opponent_chips = all_chips[1:]

        stage = raw_obs.get('stage')
        if stage is None:
            # Without an explicit stage, a revealed public card means the
            # second betting round has started.
            current_round = "flop" if public_card else "pre-flop"
        else:
            current_round = "pre-flop" if stage == 0 else "flop"

        lines = [
            f"Round: {current_round}",
            f"Your hand: {hand}",
            f"Public card: {public_card if public_card else None}",
            f"Your chips: {my_chips}",
        ]
        lines.extend(
            f"Opponent {i + 1} chips: {chips}"
            for i, chips in enumerate(opponent_chips)
        )
        return "\n".join(lines) + "\n"

    def _format_legal_actions(self, legal_actions: List[int]) -> str:
        r"""Format the legal actions for Leduc Hold'em.

        Args:
            legal_actions (List[int]): The list of legal action ids.

        Returns:
            str: A human-readable representation of the legal actions.
                Unknown ids are shown via ``str`` instead of raising, and an
                empty list yields a notice.
        """
        if not legal_actions:
            return "No legal actions available"
        return ", ".join(
            self._ACTION_NAMES.get(a, str(a)) for a in legal_actions
        )
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
class DoudizhuEnv(RLCardsEnv):
    r"""A Doudizhu environment for reinforcement learning with LLMs.

    This environment implements a standard Doudizhu game where the LLM agent
    plays against two AI opponents. The agent is treated as player 0.
    """

    # How many legal actions `_format_legal_actions` spells out before it
    # summarizes the remainder as a count.
    _MAX_LISTED_ACTIONS = 10

    def __init__(
        self,
        extractor: Optional[BaseExtractor] = None,
        max_steps: Optional[int] = None,
        **kwargs,
    ) -> None:
        r"""Initialize the Doudizhu environment.

        Args:
            extractor (Optional[BaseExtractor]): Extractor to process LLM
                responses. If None, a default extractor will be used.
                (default: :obj:`None`)
            max_steps (Optional[int]): Maximum steps per episode.
                (default: :obj:`None`)
            **kwargs: Additional environment parameters.
        """
        super().__init__(
            "doudizhu", extractor, max_steps, num_players=3, **kwargs
        )

    def _convert_to_rlcard_action(self, action_str: str) -> Any:
        r"""Convert a string action to the format expected by RLCard Doudizhu.

        Args:
            action_str (str): The string representation of the action.
                Expected to be a card combination or 'pass'.

        Returns:
            Optional[str]: ``"pass"`` for a pass; the normalized card
                combination (spaces removed, upper-cased) if it appears among
                the current legal actions; otherwise ``None``, including when
                no game state is available yet.
        """
        action_str = action_str.lower().strip()
        if action_str == "pass":
            return "pass"

        # NOTE(review): simplified conversion -- confirm against the exact
        # action-string format RLCard's Doudizhu env expects.
        action_str = action_str.replace(" ", "").upper()

        state = getattr(self, "_state", None)
        if state is None:
            # No game in progress, so nothing can be legal.
            return None

        # Legal actions may be stored under 'legal_actions' and/or
        # 'raw_legal_actions' (assumed string form -- confirm); accept a
        # match in either container.
        for key in ("legal_actions", "raw_legal_actions"):
            candidates = state.get(key)
            if candidates is not None and action_str in candidates:
                return action_str

        return None

    def _format_state_for_observation(self, state: Dict[str, Any]) -> str:
        r"""Format the Doudizhu state for human-readable observation.

        Args:
            state (Dict[str, Any]): The RLCard state dictionary. Requires
                ``state['raw_obs']['landlord']``; the hand, played-card lists
                and bomb count are optional and default to empty / 0.

        Returns:
            str: A human-readable representation of the state.
        """
        if state is None:
            return "Game not started yet."

        raw_obs = state['raw_obs']
        # Index of the landlord; the agent is player 0.
        landlord = raw_obs['landlord']
        current_hand = raw_obs.get('current_hand', [])
        landlord_played_cards = raw_obs.get('landlord_played_cards', [])
        landlord_up_played_cards = raw_obs.get('landlord_up_played_cards', [])
        landlord_down_played_cards = raw_obs.get(
            'landlord_down_played_cards', []
        )
        bomb_num = raw_obs.get('bomb_num', 0)

        obs_text = ""

        # Player role
        if landlord == 0:
            obs_text += "You are the Landlord.\n"
        else:
            obs_text += (
                f"You are a Peasant. Player {landlord} is the Landlord.\n"
            )

        # Current hand
        obs_text += f"Your hand: {self._format_cards(current_hand)}\n"

        # Cards played by each seat. With seats 0, 1, 2, the agent (player 0)
        # sits immediately before landlord 1 ("up") and immediately after
        # landlord 2 ("down").
        obs_text += "Last played cards:\n"
        if landlord == 0:
            inter_text = self._format_cards(landlord_played_cards)
            obs_text += f" You (Landlord): {inter_text}\n"
            inter_text = self._format_cards(landlord_up_played_cards)
            obs_text += f" Peasant 1: {inter_text}\n"
            inter_text = self._format_cards(landlord_down_played_cards)
            obs_text += f" Peasant 2: {inter_text}\n"
        elif landlord == 1:
            obs_text += (
                f" You: {self._format_cards(landlord_up_played_cards)}\n"
            )
            obs_text += (
                f" Landlord: {self._format_cards(landlord_played_cards)}\n"
            )
            inter_text = self._format_cards(landlord_down_played_cards)
            obs_text += f" Other Peasant: {inter_text}\n"
        else:  # landlord == 2
            obs_text += (
                f" You: {self._format_cards(landlord_down_played_cards)}\n"
            )
            obs_text += (
                f" Landlord: {self._format_cards(landlord_played_cards)}\n"
            )
            inter_text = self._format_cards(landlord_up_played_cards)
            obs_text += f" Other Peasant: {inter_text}\n"

        # Bomb count
        obs_text += f"Number of bombs played: {bomb_num}\n"

        return obs_text

    def _format_legal_actions(self, legal_actions: List[str]) -> str:
        r"""Format the legal actions for Doudizhu.

        Args:
            legal_actions (List[str]): The legal actions. Any iterable is
                accepted (e.g. a list, or a dict keyed by action); entries
                are rendered with ``str``.

        Returns:
            str: A human-readable representation of the legal actions. At
                most ``_MAX_LISTED_ACTIONS`` entries are listed, followed by a
                count of the rest. An empty input yields a notice.
        """
        actions = (
            [] if legal_actions is None else [str(a) for a in legal_actions]
        )
        if not actions:
            return "No legal actions available"

        limit = self._MAX_LISTED_ACTIONS
        if len(actions) > limit:
            return (
                ", ".join(actions[:limit])
                + f" and {len(actions) - limit} more options"
            )
        return ", ".join(actions)

    def _format_cards(self, cards: List[str]) -> str:
        r"""Format a list of cards for display.

        Args:
            cards (List[str]): List of card strings.

        Returns:
            str: The cards joined by single spaces, or ``"None"`` when
                ``cards`` is empty or None.
        """
        if not cards:
            return "None"
        return " ".join(cards)
|