synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.1.9.dist-info/METADATA +0 -37
- synth_ai-0.1.9.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,493 @@
|
|
1
|
+
import pytest
|
2
|
+
import numpy as np
|
3
|
+
from uuid import uuid4
|
4
|
+
|
5
|
+
from synth_ai.environments.environment.tools import EnvToolCall, ToolResult
|
6
|
+
from synth_ai.environments.tasks.core import TaskInstance, Impetus, Intent
|
7
|
+
from synth_ai.environments.examples.tictactoe.environment import (
|
8
|
+
TicTacToeEnvironment,
|
9
|
+
TicTacToeActionInput,
|
10
|
+
TicTacToeInteractTool,
|
11
|
+
)
|
12
|
+
from synth_ai.environments.examples.tictactoe.engine import TicTacToeEngine
|
13
|
+
from synth_ai.environments.examples.tictactoe.taskset import (
|
14
|
+
TicTacToeTaskInstance,
|
15
|
+
TicTacToeTaskInstanceMetadata,
|
16
|
+
)
|
17
|
+
|
18
|
+
|
19
|
+
@pytest.fixture
|
20
|
+
def simple_task_instance():
|
21
|
+
"""Create a simple task instance for testing."""
|
22
|
+
metadata = TicTacToeTaskInstanceMetadata(
|
23
|
+
starting_player="X",
|
24
|
+
opening_moves=[],
|
25
|
+
optimal_outcome="draw",
|
26
|
+
position_complexity=0,
|
27
|
+
shortest_win_length=5,
|
28
|
+
)
|
29
|
+
|
30
|
+
return TicTacToeTaskInstance(
|
31
|
+
id=uuid4(),
|
32
|
+
impetus=Impetus(instructions="Test TicTacToe game"),
|
33
|
+
intent=Intent(rubric={"goal": "Test game"}, gold_trajectories=None, gold_state_diff={}),
|
34
|
+
metadata=metadata,
|
35
|
+
is_reproducible=True,
|
36
|
+
initial_engine_snapshot=None,
|
37
|
+
)
|
38
|
+
|
39
|
+
|
40
|
+
class TestTicTacToeEnvironment:
|
41
|
+
@pytest.mark.asyncio
|
42
|
+
async def test_environment_initialization(self, simple_task_instance):
|
43
|
+
"""Test environment initializes correctly."""
|
44
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
45
|
+
|
46
|
+
assert env.name == "TicTacToe"
|
47
|
+
assert env.task_instance == simple_task_instance
|
48
|
+
assert env.engine is not None
|
49
|
+
assert env._interact_tool is not None
|
50
|
+
|
51
|
+
# Test initial observation
|
52
|
+
obs = await env.initialize()
|
53
|
+
|
54
|
+
assert "board_text" in obs
|
55
|
+
assert "current_player" in obs
|
56
|
+
assert obs["current_player"] == "X"
|
57
|
+
assert obs["move_count"] == 0
|
58
|
+
assert obs["terminated"] == False
|
59
|
+
|
60
|
+
@pytest.mark.asyncio
|
61
|
+
async def test_step_with_valid_move(self, simple_task_instance):
|
62
|
+
"""Test stepping with valid moves."""
|
63
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
64
|
+
await env.initialize()
|
65
|
+
|
66
|
+
# Test dictionary format
|
67
|
+
obs = await env.step({"action": "B2"})
|
68
|
+
|
69
|
+
assert obs["last_move"] == "B2"
|
70
|
+
assert obs["current_player"] == "O"
|
71
|
+
assert obs["move_count"] == 1
|
72
|
+
assert "X" in obs["board_text"]
|
73
|
+
|
74
|
+
# Test EnvToolCall format
|
75
|
+
tool_call = EnvToolCall(tool="interact", args={"action": "A1"})
|
76
|
+
obs = await env.step(tool_call)
|
77
|
+
|
78
|
+
assert obs["last_move"] == "A1"
|
79
|
+
assert obs["current_player"] == "X"
|
80
|
+
assert obs["move_count"] == 2
|
81
|
+
|
82
|
+
@pytest.mark.asyncio
|
83
|
+
async def test_step_with_invalid_move(self, simple_task_instance):
|
84
|
+
"""Test stepping with invalid moves."""
|
85
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
86
|
+
await env.initialize()
|
87
|
+
|
88
|
+
# Make a valid move first
|
89
|
+
await env.step({"action": "B2"})
|
90
|
+
|
91
|
+
# Try to make move in occupied cell
|
92
|
+
obs = await env.step({"action": "B2"})
|
93
|
+
|
94
|
+
assert obs["terminated"] == True
|
95
|
+
assert obs["reward_last"] == -1.0
|
96
|
+
|
97
|
+
@pytest.mark.asyncio
|
98
|
+
async def test_checkpoint(self, simple_task_instance):
|
99
|
+
"""Test checkpoint functionality."""
|
100
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
101
|
+
await env.initialize()
|
102
|
+
|
103
|
+
# Make some moves
|
104
|
+
await env.step({"action": "B2"})
|
105
|
+
await env.step({"action": "A1"})
|
106
|
+
|
107
|
+
# Get checkpoint
|
108
|
+
checkpoint = await env.checkpoint()
|
109
|
+
|
110
|
+
assert "board_text_final" in checkpoint
|
111
|
+
assert "winner_final" in checkpoint
|
112
|
+
assert "move_count_final" in checkpoint
|
113
|
+
assert checkpoint["move_count_final"] == 2
|
114
|
+
assert checkpoint["total_reward"] == 0.0
|
115
|
+
|
116
|
+
@pytest.mark.asyncio
|
117
|
+
async def test_terminate(self, simple_task_instance):
|
118
|
+
"""Test terminate functionality."""
|
119
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
120
|
+
await env.initialize()
|
121
|
+
|
122
|
+
obs = await env.terminate()
|
123
|
+
|
124
|
+
assert obs["terminated"] == True
|
125
|
+
assert "board_text_final" in obs
|
126
|
+
|
127
|
+
@pytest.mark.asyncio
|
128
|
+
async def test_validate_tool_calls_various_formats(self, simple_task_instance):
|
129
|
+
"""Test tool call validation with various input formats."""
|
130
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
131
|
+
|
132
|
+
# Test EnvToolCall format
|
133
|
+
call = EnvToolCall(tool="interact", args={"action": "A1"})
|
134
|
+
validated = env.validate_tool_calls(call)
|
135
|
+
assert validated.tool == "interact"
|
136
|
+
assert validated.args["letter"] == "A"
|
137
|
+
assert validated.args["number"] == 1
|
138
|
+
|
139
|
+
# Test dict with tool/args
|
140
|
+
validated = env.validate_tool_calls({"tool": "interact", "args": {"action": "B2"}})
|
141
|
+
assert validated.tool == "interact"
|
142
|
+
assert validated.args["letter"] == "B"
|
143
|
+
assert validated.args["number"] == 2
|
144
|
+
|
145
|
+
# Test dict with name/parameters (legacy)
|
146
|
+
validated = env.validate_tool_calls({"name": "interact", "parameters": {"action": "C3"}})
|
147
|
+
assert validated.tool == "interact"
|
148
|
+
assert validated.args["letter"] == "C"
|
149
|
+
assert validated.args["number"] == 3
|
150
|
+
|
151
|
+
# Test OpenAI function format
|
152
|
+
validated = env.validate_tool_calls(
|
153
|
+
{"function": {"name": "interact", "arguments": {"action": "A2"}}}
|
154
|
+
)
|
155
|
+
assert validated.tool == "interact"
|
156
|
+
assert validated.args["letter"] == "A"
|
157
|
+
assert validated.args["number"] == 2
|
158
|
+
|
159
|
+
# Test bare dict (assumed to be args)
|
160
|
+
validated = env.validate_tool_calls({"action": "B1"})
|
161
|
+
assert validated.tool == "interact"
|
162
|
+
assert validated.args["letter"] == "B"
|
163
|
+
assert validated.args["number"] == 1
|
164
|
+
|
165
|
+
# Test list format
|
166
|
+
validated = env.validate_tool_calls([{"tool": "interact", "args": {"action": "C1"}}])
|
167
|
+
assert validated.tool == "interact"
|
168
|
+
assert validated.args["letter"] == "C"
|
169
|
+
assert validated.args["number"] == 1
|
170
|
+
|
171
|
+
# Test string conversion
|
172
|
+
validated = env.validate_tool_calls("A3")
|
173
|
+
assert validated.tool == "interact"
|
174
|
+
assert validated.args["letter"] == "A"
|
175
|
+
assert validated.args["number"] == 3
|
176
|
+
|
177
|
+
@pytest.mark.asyncio
|
178
|
+
async def test_validate_tool_calls_errors(self, simple_task_instance):
|
179
|
+
"""Test tool call validation error cases."""
|
180
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
181
|
+
|
182
|
+
# Test wrong tool name
|
183
|
+
with pytest.raises(ValueError, match="Unknown tool"):
|
184
|
+
env.validate_tool_calls({"tool": "wrong_tool", "args": {}})
|
185
|
+
|
186
|
+
# Test empty list
|
187
|
+
with pytest.raises(ValueError, match="Empty tool calls list"):
|
188
|
+
env.validate_tool_calls([])
|
189
|
+
|
190
|
+
@pytest.mark.asyncio
|
191
|
+
async def test_serialization(self, simple_task_instance):
|
192
|
+
"""Test environment serialization and deserialization."""
|
193
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
194
|
+
await env.initialize()
|
195
|
+
|
196
|
+
# Make some moves
|
197
|
+
await env.step({"action": "B2"})
|
198
|
+
await env.step({"action": "A1"})
|
199
|
+
|
200
|
+
# Serialize
|
201
|
+
snapshot = await env._serialize_engine()
|
202
|
+
|
203
|
+
# Deserialize
|
204
|
+
restored_env = await TicTacToeEnvironment._deserialize_engine(
|
205
|
+
snapshot, simple_task_instance
|
206
|
+
)
|
207
|
+
|
208
|
+
# Check state is preserved
|
209
|
+
assert restored_env.engine.current_player == env.engine.current_player
|
210
|
+
assert restored_env.engine.move_count == env.engine.move_count
|
211
|
+
assert np.array_equal(restored_env.engine.board, env.engine.board)
|
212
|
+
|
213
|
+
@pytest.mark.asyncio
|
214
|
+
async def test_full_game_to_win(self, simple_task_instance):
|
215
|
+
"""Test playing a full game to win."""
|
216
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
217
|
+
await env.initialize()
|
218
|
+
|
219
|
+
# X wins in top row
|
220
|
+
moves = [
|
221
|
+
("A1", "X", "O"),
|
222
|
+
("B1", "O", "X"),
|
223
|
+
("A2", "X", "O"),
|
224
|
+
("B2", "O", "X"),
|
225
|
+
("A3", "X", None), # X wins
|
226
|
+
]
|
227
|
+
|
228
|
+
for move, player_before, player_after in moves:
|
229
|
+
obs = await env.step({"action": move})
|
230
|
+
|
231
|
+
if player_after:
|
232
|
+
assert obs["current_player"] == player_after
|
233
|
+
else:
|
234
|
+
assert obs["terminated"] == True
|
235
|
+
assert obs["winner"] == "X"
|
236
|
+
assert obs["reward_last"] == 1.0
|
237
|
+
|
238
|
+
@pytest.mark.asyncio
|
239
|
+
async def test_full_game_to_draw(self, simple_task_instance):
|
240
|
+
"""Test playing a full game to draw."""
|
241
|
+
env = TicTacToeEnvironment(simple_task_instance)
|
242
|
+
await env.initialize()
|
243
|
+
|
244
|
+
# Play a draw game
|
245
|
+
moves = ["A1", "B2", "A2", "A3", "B3", "B1", "C1", "C3", "C2"]
|
246
|
+
|
247
|
+
for i, move in enumerate(moves):
|
248
|
+
obs = await env.step({"action": move})
|
249
|
+
|
250
|
+
if i < len(moves) - 1:
|
251
|
+
assert not obs["terminated"]
|
252
|
+
else:
|
253
|
+
assert obs["terminated"]
|
254
|
+
assert obs["winner"] == "draw"
|
255
|
+
assert obs["move_count"] == 9
|
256
|
+
assert obs["reward_last"] == 0.0
|
257
|
+
|
258
|
+
|
259
|
+
class TestTicTacToeInteractTool:
    """Unit tests exercising TicTacToeInteractTool directly against an engine."""

    @pytest.mark.asyncio
    async def test_interact_tool_valid_action(self, simple_task_instance):
        """A well-formed letter/number call succeeds and advances the turn."""
        game_engine = TicTacToeEngine(simple_task_instance)
        interact_tool = TicTacToeInteractTool(game_engine)

        tool_call = EnvToolCall(tool="interact", args={"letter": "B", "number": 2})
        outcome = await interact_tool(tool_call)

        assert outcome.ok
        assert "public_state" in outcome.payload
        assert "private_state" in outcome.payload

        public = outcome.payload["public_state"]
        assert public.last_move == "B2"
        assert public.current_player == "O"

    @pytest.mark.asyncio
    async def test_interact_tool_invalid_action(self, simple_task_instance):
        """Replaying an occupied cell terminates the game with a -1 reward."""
        game_engine = TicTacToeEngine(simple_task_instance)
        interact_tool = TicTacToeInteractTool(game_engine)

        # Occupy B2 first.
        await interact_tool(EnvToolCall(tool="interact", args={"letter": "B", "number": 2}))

        # The same cell again is an illegal move.
        repeat_call = EnvToolCall(tool="interact", args={"letter": "B", "number": 2})
        outcome = await interact_tool(repeat_call)

        assert outcome.ok
        assert outcome.payload["public_state"].terminated
        assert outcome.payload["private_state"].reward_last == -1.0

    @pytest.mark.asyncio
    async def test_interact_tool_no_action(self, simple_task_instance):
        """Omitting both parameters yields a not-ok result with an error message."""
        game_engine = TicTacToeEngine(simple_task_instance)
        interact_tool = TicTacToeInteractTool(game_engine)

        outcome = await interact_tool(EnvToolCall(tool="interact", args={}))

        assert not outcome.ok
        assert outcome.error == "Both letter and number parameters are required"

    @pytest.mark.asyncio
    async def test_interact_tool_exception_handling(self, simple_task_instance):
        """A None parameter is rejected with the missing-parameter error.

        NOTE(review): despite the name, this exercises the tool's parameter
        validation path (number=None), not an internal exception — confirm
        intent against the tool implementation.
        """
        game_engine = TicTacToeEngine(simple_task_instance)
        interact_tool = TicTacToeInteractTool(game_engine)

        bad_call = EnvToolCall(tool="interact", args={"letter": "A", "number": None})
        outcome = await interact_tool(bad_call)

        assert not outcome.ok
        assert outcome.error == "Both letter and number parameters are required"
|
318
|
+
|
319
|
+
|
320
|
+
class TestTicTacToeActionInput:
    """Checks for the TicTacToeActionInput Pydantic model and its JSON schema."""

    def test_action_input_model(self):
        """Construct a valid action input and inspect the generated schema."""
        action = TicTacToeActionInput(letter="A", number=1)
        assert action.letter == "A"
        assert action.number == 1

        schema = TicTacToeActionInput.model_json_schema()
        assert "properties" in schema
        properties = schema["properties"]
        # Both fields must be present with the expected JSON-schema types.
        for field_name, json_type in (("letter", "string"), ("number", "integer")):
            assert field_name in properties
            assert properties[field_name]["type"] == json_type
|
335
|
+
|
336
|
+
|
337
|
+
class TestTicTacToeValidation:
    """Validation behaviour of the TicTacToe environment's tool-call parsing."""

    @pytest.mark.asyncio
    async def test_position_validation_valid_positions(self, simple_task_instance):
        """Positions 0-8 map row-major onto (letter, number) coordinates."""
        env = TicTacToeEnvironment(simple_task_instance)

        for pos in range(9):
            call = env.validate_tool_calls(
                {"tool": "interact", "args": {"position": pos}}
            )
            # Row-major layout: 0..2 -> A1..A3, 3..5 -> B1..B3, 6..8 -> C1..C3.
            assert call.args["letter"] == "ABC"[pos // 3]
            assert call.args["number"] == pos % 3 + 1

    @pytest.mark.asyncio
    async def test_position_validation_invalid_positions(self, simple_task_instance):
        """Out-of-range positions raise a descriptive ValueError."""
        env = TicTacToeEnvironment(simple_task_instance)

        for bad in (-1, 9, 10, 100, -5):
            with pytest.raises(ValueError, match=f"Position {bad} must be between 0 and 8"):
                env.validate_tool_calls({"tool": "interact", "args": {"position": bad}})

    @pytest.mark.asyncio
    async def test_coordinate_validation_valid_coordinates(self, simple_task_instance):
        """Every coordinate string A1..C3 is split into its letter and number."""
        env = TicTacToeEnvironment(simple_task_instance)

        for letter in "ABC":
            for number in (1, 2, 3):
                call = env.validate_tool_calls(
                    {"tool": "interact", "args": {"action": f"{letter}{number}"}}
                )
                assert call.args["letter"] == letter
                assert call.args["number"] == number

    @pytest.mark.asyncio
    async def test_coordinate_validation_invalid_coordinates(self, simple_task_instance):
        """Malformed coordinate strings are rejected with specific messages."""
        env = TicTacToeEnvironment(simple_task_instance)

        rejected = [
            # Bad row numbers.
            ("A0", "Number must be 1, 2, or 3, got 0"),
            ("A4", "Number must be 1, 2, or 3, got 4"),
            # Bad column letters.
            ("D1", "Letter must be A, B, or C, got 'D'"),
            ("Z9", "Letter must be A, B, or C, got 'Z'"),
            # Bad overall format.
            ("", "Action '' must be 2 characters"),
            ("1A", "Action '1A' must have a numeric second character"),
            ("BB", "Action 'BB' must have a numeric second character"),
        ]
        for action, message in rejected:
            with pytest.raises(ValueError, match=message):
                env.validate_tool_calls({"tool": "interact", "args": {"action": action}})

    @pytest.mark.asyncio
    async def test_direct_letter_number_validation_valid(self, simple_task_instance):
        """Every explicit (letter, number) pair on the board validates unchanged."""
        env = TicTacToeEnvironment(simple_task_instance)

        for letter in "ABC":
            for number in (1, 2, 3):
                call = env.validate_tool_calls(
                    {"tool": "interact", "args": {"letter": letter, "number": number}}
                )
                assert call.args["letter"] == letter
                assert call.args["number"] == number

    @pytest.mark.asyncio
    async def test_direct_letter_number_validation_invalid(self, simple_task_instance):
        """Invalid explicit letter/number pairs raise with specific messages."""
        env = TicTacToeEnvironment(simple_task_instance)

        rejected = [
            # Bad numbers.
            ({"letter": "A", "number": 0}, "Number must be 1, 2, or 3, got 0"),
            ({"letter": "A", "number": 4}, "Number must be 1, 2, or 3, got 4"),
            # Bad letters, including case sensitivity and length.
            ({"letter": "D", "number": 1}, "Letter must be A, B, or C, got 'D'"),
            ({"letter": "a", "number": 1}, "Letter must be A, B, or C, got 'a'"),
            ({"letter": "", "number": 1}, "Letter must be A, B, or C, got ''"),
            ({"letter": "AB", "number": 1}, "Letter must be A, B, or C, got 'AB'"),
        ]
        for args, message in rejected:
            with pytest.raises(ValueError, match=message):
                env.validate_tool_calls({"tool": "interact", "args": args})

    @pytest.mark.asyncio
    async def test_tool_validates_before_execution(self, simple_task_instance):
        """Invalid input reaching the tool directly still produces an error result."""
        env = TicTacToeEnvironment(simple_task_instance)

        # Invalid letter passed straight to the tool.
        bad_letter = await env._interact_tool(
            EnvToolCall(tool="interact", args={"letter": "Z", "number": 1})
        )
        assert not bad_letter.ok
        assert "Letter must be A, B, or C, got 'Z'" in bad_letter.error

        # Invalid number passed straight to the tool.
        bad_number = await env._interact_tool(
            EnvToolCall(tool="interact", args={"letter": "A", "number": 0})
        )
        assert not bad_number.ok
        assert "Number must be 1, 2, or 3, got 0" in bad_number.error

        # No parameters at all.
        missing = await env._interact_tool(EnvToolCall(tool="interact", args={}))
        assert not missing.ok
        assert "Both letter and number parameters are required" in missing.error
|
@@ -0,0 +1,191 @@
|
|
1
|
+
import pytest
|
2
|
+
import numpy as np
|
3
|
+
from uuid import UUID
|
4
|
+
|
5
|
+
from synth_ai.environments.examples.tictactoe.taskset import (
|
6
|
+
create_tictactoe_taskset,
|
7
|
+
TicTacToeTaskInstance,
|
8
|
+
TicTacToeTaskInstanceMetadata,
|
9
|
+
_evaluate_position,
|
10
|
+
_count_shortest_win,
|
11
|
+
COORD_TO_IDX,
|
12
|
+
PLAYER_MARKS,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
class TestTasksetGeneration:
    """Tests for the procedurally generated TicTacToe taskset."""

    @pytest.mark.asyncio
    async def test_create_taskset(self):
        """Taskset creation yields 50 instances covering all complexity tiers."""
        taskset = await create_tictactoe_taskset()

        assert taskset.name == "TicTacToe Procedural TaskSet"
        assert len(taskset.instances) == 50  # 10 + 15 + 15 + 10
        assert taskset.split_info._is_split_defined

        # Every complexity tier must be represented:
        # 0=opening, 1=early, 2=mid, 3=complex positions.
        complexities = [inst.metadata.position_complexity for inst in taskset.instances]
        assert 0 in complexities
        assert 1 in complexities
        assert 2 in complexities
        assert 3 in complexities

    @pytest.mark.asyncio
    async def test_task_instance_metadata(self):
        """Every instance's metadata fields are well-formed and self-consistent."""
        taskset = await create_tictactoe_taskset()

        for instance in taskset.instances:
            metadata = instance.metadata

            # Field-level sanity checks.
            assert metadata.starting_player in ["X", "O"]
            assert isinstance(metadata.opening_moves, list)
            assert metadata.optimal_outcome in ["win", "draw", "loss"]
            assert metadata.position_complexity >= 0
            assert metadata.shortest_win_length >= 1

            # Complexity counts the number of pre-played opening moves.
            assert len(metadata.opening_moves) == metadata.position_complexity

    @pytest.mark.asyncio
    async def test_task_instance_structure(self):
        """A task instance exposes id, impetus, intent, and metadata correctly."""
        taskset = await create_tictactoe_taskset()
        instance = taskset.instances[0]

        # Required attributes are present and typed as expected.
        assert isinstance(instance.id, UUID)
        assert instance.impetus is not None
        assert instance.intent is not None
        assert instance.metadata is not None
        # Truthiness assert instead of `== True` (flake8 E712); accepts any
        # truthy flag the taskset sets without over-constraining its type.
        assert instance.is_reproducible
        assert instance.initial_engine_snapshot is None

        # Impetus mentions the game and the starting player.
        assert "TicTacToe" in instance.impetus.instructions
        assert instance.metadata.starting_player in instance.impetus.instructions

        # Intent carries a goal rubric but no gold trajectories.
        assert "goal" in instance.intent.rubric
        assert instance.intent.gold_trajectories is None
        assert isinstance(instance.intent.gold_state_diff, dict)

    @pytest.mark.asyncio
    async def test_splits(self):
        """Val/test splits are disjoint, non-empty subsets leaving a train set."""
        taskset = await create_tictactoe_taskset()

        val_ids = taskset.split_info.val_instance_ids
        test_ids = taskset.split_info.test_instance_ids
        all_ids = {inst.id for inst in taskset.instances}

        # Splits must not overlap and must reference real instances.
        assert len(val_ids & test_ids) == 0
        assert val_ids.issubset(all_ids)
        assert test_ids.issubset(all_ids)

        # Each declared split is non-empty.
        assert len(val_ids) > 0
        assert len(test_ids) > 0

        # Train is implicitly everything not in val/test and must be non-empty.
        train_ids = all_ids - val_ids - test_ids
        assert len(train_ids) > 0

    @pytest.mark.asyncio
    async def test_serialization(self):
        """An instance round-trips through serialize()/deserialize()."""
        taskset = await create_tictactoe_taskset()
        instance = taskset.instances[0]

        data = await instance.serialize()

        # Serialized payload exposes all top-level sections.
        assert "id" in data
        assert "impetus" in data
        assert "intent" in data
        assert "metadata" in data
        assert "is_reproducible" in data

        # Metadata values survive serialization.
        assert data["metadata"]["starting_player"] == instance.metadata.starting_player
        assert data["metadata"]["opening_moves"] == instance.metadata.opening_moves

        # Deserialization reconstructs an equivalent instance.
        restored = await TicTacToeTaskInstance.deserialize(data)

        assert str(restored.id) == str(instance.id)
        assert restored.impetus.instructions == instance.impetus.instructions
        assert restored.metadata.starting_player == instance.metadata.starting_player
        assert restored.metadata.opening_moves == instance.metadata.opening_moves

    @pytest.mark.asyncio
    async def test_opening_moves_validity(self):
        """Opening moves are legal, non-repeating, and replayable from scratch."""
        taskset = await create_tictactoe_taskset()

        for instance in taskset.instances:
            # All moves are valid board coordinates.
            for move in instance.metadata.opening_moves:
                assert move in COORD_TO_IDX

            # No cell is played twice.
            assert len(instance.metadata.opening_moves) == len(set(instance.metadata.opening_moves))

            # Replay the moves on an empty board, alternating X and O,
            # to confirm every move lands on an empty cell.
            board = np.zeros(9, dtype=int)
            current_player = "X"

            for move in instance.metadata.opening_moves:
                idx = COORD_TO_IDX[move]
                assert board[idx] == 0  # Cell must be empty before the move
                board[idx] = PLAYER_MARKS[current_player]
                current_player = "O" if current_player == "X" else "X"
|
147
|
+
|
148
|
+
|
149
|
+
class TestHelperFunctions:
    """Unit tests for the taskset's position-analysis helpers."""

    def test_evaluate_position_wins(self):
        """A completed line is a win for its owner and a loss for the opponent."""
        # X (mark 1) owns the entire top row.
        top_row_x = np.array([1, 1, 1, 2, 2, 0, 0, 0, 0])
        assert _evaluate_position(top_row_x, 1) == "win"
        assert _evaluate_position(top_row_x, 2) == "loss"

        # O (mark 2) owns the entire first column.
        first_col_o = np.array([2, 1, 1, 2, 1, 0, 2, 0, 0])
        assert _evaluate_position(first_col_o, 1) == "loss"
        assert _evaluate_position(first_col_o, 2) == "win"

    def test_evaluate_position_draw(self):
        """A full board with no three-in-a-row is a draw for both players."""
        # X O X / X O X / O X O — neither mark completes a line.
        drawn = np.array([1, 2, 1, 1, 2, 1, 2, 1, 2])
        assert _evaluate_position(drawn, 1) == "draw"
        assert _evaluate_position(drawn, 2) == "draw"

    def test_evaluate_position_ongoing(self):
        """A non-terminal position is reported as a draw by this helper."""
        in_progress = np.array([1, 2, 0, 0, 1, 0, 0, 0, 0])
        # The implementation deliberately returns "draw" for non-terminal boards.
        assert _evaluate_position(in_progress, 1) == "draw"

    def test_count_shortest_win(self):
        """Shortest-win estimate is half the empty-cell count, floored at 1."""
        # Empty board: 9 empty cells // 2 -> 4.
        assert _count_shortest_win(np.zeros(9, dtype=int), 1) == 4

        # Partially filled: 6 empty cells // 2 -> 3.
        mid_game = np.array([1, 2, 0, 0, 1, 0, 0, 0, 0])
        assert _count_shortest_win(mid_game, 1) == 3

        # Nearly full: clamped to a minimum of 1.
        end_game = np.array([1, 2, 1, 2, 1, 2, 2, 1, 0])
        assert _count_shortest_win(end_game, 1) == 1
|
@@ -0,0 +1,10 @@
|
|
1
|
+
from .engine import VerilogEngine
from .environment import VerilogEnvironment
from .taskset import VerilogTaskInstance, create_verilog_taskset

# Explicit public API: the engine, environment, task-instance type, and
# taskset factory re-exported from this package's submodules.
__all__ = [
    "VerilogEngine",
    "VerilogEnvironment",
    "VerilogTaskInstance",
    "create_verilog_taskset",
]
|