synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of synth-ai might be problematic.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +685 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai-0.2.5.dist-info/METADATA +106 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/RECORD +111 -12
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/entry_points.txt +1 -0
- synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py
@@ -0,0 +1,202 @@
+#!/usr/bin/env python3
+"""
+Trace evaluation functions for Crafter episodes.
+Scores traces based on achievements and invalid actions.
+"""
+
+from typing import Dict, List, Any, Tuple
+from pathlib import Path
+import json
+from collections import defaultdict
+
+# Scoring weights
+WEIGHTS = {
+    'easy_achievement': 1.0,    # Easy achievement (e.g., collect_wood)
+    'medium_achievement': 2.5,  # Medium achievement (e.g., make_wood_pickaxe)
+    'hard_achievement': 5.0,    # Hard achievement (e.g., make_iron_sword)
+    'invalid_action': -0.05,    # Invalid action penalty (50 invalid = -1 medium achievement)
+}
+
+# Map hook names to scoring categories
+HOOK_TO_SCORE_TYPE = {
+    'easy_achievement': 'easy_achievement',
+    'medium_achievement': 'medium_achievement',
+    'hard_achievement': 'hard_achievement',
+    'invalid_action': 'invalid_action'
+}
+
+
+def evaluate_event(event: Dict[str, Any]) -> Tuple[float, str]:
+    """
+    Evaluate a single event based on its hooks.
+    Returns: (score, symbol) where symbol is '+', '-', or '0'
+    """
+    score = 0.0
+    symbol = '0'
+
+    # Check if event has metadata from hooks
+    event_metadata = event.get('event_metadata', [])
+
+    for metadata in event_metadata:
+        hook_name = metadata.get('hook_name', '')
+
+        if hook_name in HOOK_TO_SCORE_TYPE:
+            score_type = HOOK_TO_SCORE_TYPE[hook_name]
+            weight = WEIGHTS[score_type]
+            score += weight
+
+            # Determine symbol
+            if weight > 0:
+                symbol = '+'
+            elif weight < 0:
+                symbol = '-'
+
+    return score, symbol
+
+
+def evaluate_trace(trace_path: Path) -> Dict[str, Any]:
+    """
+    Evaluate an entire trace file.
+    Returns detailed scoring breakdown and trajectory visualization.
+    """
+    with open(trace_path, 'r') as f:
+        trace_data = json.load(f)
+
+    # Track counts
+    counts = defaultdict(int)
+    total_score = 0.0
+    trajectory_symbols = []
+
+    # Process event history
+    event_history = trace_data.get('event_history', [])
+
+    for event in event_history:
+        event_score, symbol = evaluate_event(event)
+        total_score += event_score
+
+        # Only add symbol if score is non-zero
+        if event_score != 0:
+            trajectory_symbols.append(symbol)
+
+        # Count hook types
+        for metadata in event.get('event_metadata', []):
+            hook_name = metadata.get('hook_name', '')
+            if hook_name in HOOK_TO_SCORE_TYPE:
+                score_type = HOOK_TO_SCORE_TYPE[hook_name]
+                counts[score_type] += 1
+
+    # Create trajectory string
+    trajectory_str = ''.join(trajectory_symbols) if trajectory_symbols else '(no scored events)'
+
+    return {
+        'total_score': total_score,
+        'counts': dict(counts),
+        'trajectory': trajectory_str,
+        'num_events': len(event_history),
+        'trace_file': trace_path.name
+    }
+
+
+def print_trace_evaluation(eval_result: Dict[str, Any]):
+    """Print a formatted evaluation result for a single trace."""
+    print(f"\n📊 Trace: {eval_result['trace_file']}")
+    print(f"   Score: {eval_result['total_score']:.2f}")
+    print(f"   Events: {eval_result['num_events']}")
+
+    counts = eval_result['counts']
+    if counts:
+        print("   Breakdown:")
+        if 'easy_achievement' in counts:
+            print(f"     Easy achievements: {counts['easy_achievement']} × {WEIGHTS['easy_achievement']} = {counts['easy_achievement'] * WEIGHTS['easy_achievement']:.2f}")
+        if 'medium_achievement' in counts:
+            print(f"     Medium achievements: {counts['medium_achievement']} × {WEIGHTS['medium_achievement']} = {counts['medium_achievement'] * WEIGHTS['medium_achievement']:.2f}")
+        if 'hard_achievement' in counts:
+            print(f"     Hard achievements: {counts['hard_achievement']} × {WEIGHTS['hard_achievement']} = {counts['hard_achievement'] * WEIGHTS['hard_achievement']:.2f}")
+        if 'invalid_action' in counts:
+            print(f"     Invalid actions: {counts['invalid_action']} × {WEIGHTS['invalid_action']} = {counts['invalid_action'] * WEIGHTS['invalid_action']:.2f}")
+
+    print(f"   Trajectory: {eval_result['trajectory']}")
+
+
+def evaluate_all_traces(trace_dir: Path, pattern: str = "*.json") -> List[Dict[str, Any]]:
+    """
+    Evaluate all trace files in a directory.
+    Returns list of evaluation results sorted by score.
+    """
+    trace_files = list(trace_dir.glob(pattern))
+    results = []
+
+    for trace_file in trace_files:
+        try:
+            result = evaluate_trace(trace_file)
+            results.append(result)
+        except Exception as e:
+            print(f"⚠️ Error evaluating {trace_file.name}: {e}")
+
+    # Sort by score (descending)
+    results.sort(key=lambda x: x['total_score'], reverse=True)
+
+    return results
+
+
+def print_evaluation_summary(results: List[Dict[str, Any]]):
+    """Print a summary of all trace evaluations."""
+    if not results:
+        print("No traces to evaluate.")
+        return
+
+    print("\n" + "=" * 80)
+    print("📊 TRACE EVALUATION SUMMARY")
+    print("=" * 80)
+    print(f"{'Rank':<6} {'Score':<10} {'Trajectory':<50} {'File':<30}")
+    print("-" * 80)
+
+    for i, result in enumerate(results, 1):
+        trajectory = result['trajectory']
+        if len(trajectory) > 50:
+            trajectory = trajectory[:47] + "..."
+        print(f"{i:<6} {result['total_score']:<10.2f} {trajectory:<50} {result['trace_file'][:30]:<30}")

+    print("-" * 80)
+
+    # Summary statistics
+    scores = [r['total_score'] for r in results]
+    avg_score = sum(scores) / len(scores) if scores else 0
+    max_score = max(scores) if scores else 0
+    min_score = min(scores) if scores else 0
+
+    print(f"Average Score: {avg_score:.2f}")
+    print(f"Best Score: {max_score:.2f}")
+    print(f"Worst Score: {min_score:.2f}")
+    print("=" * 80)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Evaluate Crafter trace files")
+    parser.add_argument("trace_path", type=str, help="Directory containing trace files or single trace file")
+    parser.add_argument("--pattern", type=str, default="*.json", help="File pattern to match (for directories)")
+    parser.add_argument("--verbose", action="store_true", help="Show detailed evaluation for each trace")
+
+    args = parser.parse_args()
+
+    trace_path = Path(args.trace_path)
+    if not trace_path.exists():
+        print(f"❌ Path not found: {trace_path}")
+        exit(1)
+
+    # Check if it's a file or directory
+    if trace_path.is_file():
+        print(f"📊 Evaluating single trace: {trace_path}")
+        result = evaluate_trace(trace_path)
+        print_trace_evaluation(result)
+    else:
+        print(f"📊 Evaluating traces in: {trace_path}")
+        results = evaluate_all_traces(trace_path, args.pattern)
+
+        if args.verbose:
+            for result in results:
+                print_trace_evaluation(result)
+
+        print_evaluation_summary(results)
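For reference, the trace schema this script consumes can be read off evaluate_trace/evaluate_event above: a JSON object whose event_history list holds events carrying event_metadata entries tagged with a hook_name. A minimal sketch follows; the file name and values are illustrative, not part of the package:

    import json
    from pathlib import Path

    # Two scored events: one easy achievement (+1.0) and one invalid action (-0.05).
    trace = {
        "event_history": [
            {"event_metadata": [{"hook_name": "easy_achievement"}]},
            {"event_metadata": [{"hook_name": "invalid_action"}]},
            {"event_metadata": []},  # unscored event; contributes no trajectory symbol
        ]
    }
    Path("sample_trace.json").write_text(json.dumps(trace))
    # python trace_eval_OLD.py sample_trace.json
    # prints Score: 0.95 and Trajectory: +-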
synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python3
+"""
+Validate that JSONL files are compatible with OpenAI fine-tuning format
+"""
+
+import json
+import sys
+from pathlib import Path
+from typing import List, Dict, Any
+
+
+def validate_openai_format(file_path: Path) -> tuple[bool, List[str]]:
+    """
+    Validate JSONL file for OpenAI fine-tuning compatibility.
+    Returns (is_valid, list_of_errors)
+    """
+    errors = []
+    line_count = 0
+
+    try:
+        with open(file_path, 'r') as f:
+            for line_num, line in enumerate(f, 1):
+                line_count += 1
+
+                # Parse JSON
+                try:
+                    data = json.loads(line.strip())
+                except json.JSONDecodeError as e:
+                    errors.append(f"Line {line_num}: Invalid JSON - {e}")
+                    continue
+
+                # Check required structure
+                if not isinstance(data, dict):
+                    errors.append(f"Line {line_num}: Root must be a dictionary")
+                    continue
+
+                # Check for messages key
+                if 'messages' not in data:
+                    errors.append(f"Line {line_num}: Missing 'messages' key")
+                    continue
+
+                messages = data['messages']
+                if not isinstance(messages, list):
+                    errors.append(f"Line {line_num}: 'messages' must be a list")
+                    continue
+
+                if len(messages) < 2:
+                    errors.append(f"Line {line_num}: Need at least 2 messages (system/user + assistant)")
+                    continue
+
+                # Validate each message
+                roles_seen = []
+                for msg_idx, msg in enumerate(messages):
+                    if not isinstance(msg, dict):
+                        errors.append(f"Line {line_num}, message {msg_idx}: Message must be a dictionary")
+                        continue
+
+                    # Check required fields
+                    if 'role' not in msg:
+                        errors.append(f"Line {line_num}, message {msg_idx}: Missing 'role'")
+                        continue
+
+                    role = msg['role']
+                    roles_seen.append(role)
+
+                    if role not in ['system', 'user', 'assistant', 'tool']:
+                        errors.append(f"Line {line_num}, message {msg_idx}: Invalid role '{role}'")
+
+                    # Check content or tool_calls
+                    if role == 'assistant':
+                        # Assistant messages can have content, tool_calls, or both
+                        if 'content' not in msg and 'tool_calls' not in msg:
+                            errors.append(f"Line {line_num}, message {msg_idx}: Assistant message needs 'content' or 'tool_calls'")
+
+                        # Validate tool_calls structure
+                        if 'tool_calls' in msg:
+                            tool_calls = msg['tool_calls']
+                            if not isinstance(tool_calls, list):
+                                errors.append(f"Line {line_num}, message {msg_idx}: 'tool_calls' must be a list")
+                            else:
+                                for tc_idx, tc in enumerate(tool_calls):
+                                    if not isinstance(tc, dict):
+                                        errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Must be a dict")
+                                        continue
+
+                                    # Check tool call structure
+                                    if 'id' not in tc:
+                                        errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing 'id'")
+                                    if 'type' not in tc:
+                                        errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing 'type'")
+                                    if 'function' not in tc:
+                                        errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing 'function'")
+                                    else:
+                                        func = tc['function']
+                                        if not isinstance(func, dict):
+                                            errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: 'function' must be a dict")
+                                        else:
+                                            if 'name' not in func:
+                                                errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing function 'name'")
+                                            if 'arguments' not in func:
+                                                errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing function 'arguments'")
+                    else:
+                        # Other roles must have content
+                        if 'content' not in msg:
+                            errors.append(f"Line {line_num}, message {msg_idx}: {role} message missing 'content'")
+
+                # Check message order
+                if roles_seen[-1] != 'assistant':
+                    errors.append(f"Line {line_num}: Last message must be from assistant (found {roles_seen[-1]})")
+
+    except Exception as e:
+        errors.append(f"Error reading file: {e}")
+        return False, errors
+
+    print(f"Validated {line_count} examples")
+    return len(errors) == 0, errors
+
+
+def main():
+    if len(sys.argv) < 2:
+        print("Usage: python validate_openai_format.py <jsonl_file>")
+        sys.exit(1)
+
+    file_path = Path(sys.argv[1])
+    if not file_path.exists():
+        print(f"File not found: {file_path}")
+        sys.exit(1)
+
+    print(f"Validating OpenAI format for: {file_path}")
+    print("=" * 60)
+
+    is_valid, errors = validate_openai_format(file_path)
+
+    if is_valid:
+        print("✅ File is valid for OpenAI fine-tuning!")
+
+        # Show sample
+        with open(file_path, 'r') as f:
+            first_line = f.readline()
+            example = json.loads(first_line)
+
+        print("\nSample example structure:")
+        print(f"- Number of messages: {len(example['messages'])}")
+        print(f"- Message roles: {[msg['role'] for msg in example['messages']]}")
+
+        # Check if assistant has tool calls
+        last_msg = example['messages'][-1]
+        if 'tool_calls' in last_msg:
+            print(f"- Assistant uses tool calls: Yes")
+            print(f"- Number of tool calls: {len(last_msg['tool_calls'])}")
+            if last_msg['tool_calls']:
+                print(f"- First tool: {last_msg['tool_calls'][0]['function']['name']}")
+        else:
+            print(f"- Assistant uses tool calls: No")
+    else:
+        print(f"❌ File has {len(errors)} validation errors:")
+        for i, error in enumerate(errors[:10]):  # Show first 10 errors
+            print(f"  {i+1}. {error}")
+        if len(errors) > 10:
+            print(f"  ... and {len(errors) - 10} more errors")
+
+    return 0 if is_valid else 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
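For comparison, here is a minimal sketch of one JSONL record that passes the checks above: at least two messages, valid roles, and a final assistant message carrying 'content' or structurally complete 'tool_calls'. The file name and tool name are illustrative, not from the package:

    import json

    record = {
        "messages": [
            {"role": "system", "content": "You are playing Crafter."},
            {"role": "user", "content": "Choose the next action."},
            {
                "role": "assistant",
                "content": None,  # content may be null when tool_calls are present
                "tool_calls": [
                    {
                        "id": "call_0",
                        "type": "function",
                        "function": {"name": "interact", "arguments": "{\"action\": \"move_up\"}"},
                    }
                ],
            },
        ]
    }
    with open("train.jsonl", "w") as f:
        f.write(json.dumps(record) + "\n")
    # python validate_openai_format.py train.jsonl  -> prints "Validated 1 examples"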
synth_ai/environments/examples/crafter_classic/environment.py
@@ -119,11 +119,20 @@ class CrafterInteractTool(AbstractTool):
 
 # Default observation callable (can be customized via __init__)
 class SynthCrafterObservationCallable(GetObservationCallable):
+    """Default observation: public state dict + per-step reward/flags.
+
+    Additionally computes a small local semantic patch centered on the player
+    to simplify visualization on the client. The patch is exposed under the
+    key `semantic_map_patch7` as a list-of-lists of ints (7x7 unless the
+    semantic map is smaller, in which case it is cropped at edges).
+    """
+
+    def __init__(self, view_size: int = 7) -> None:
+        self.view_size = max(1, int(view_size))
+
     async def get_observation(
         self, pub: CrafterPublicState, priv: CrafterPrivateState
     ) -> InternalObservation:
-        # Example: return a dictionary combining public and selected private info
-        # Actual observation structure depends on agent's needs.
         obs_dict: Dict[str, Any] = dataclasses.asdict(pub)  # type: ignore
         obs_dict["reward_last_step"] = priv.reward_last_step
         obs_dict["total_reward_episode"] = priv.total_reward_episode
@@ -131,6 +140,36 @@ class SynthCrafterObservationCallable(GetObservationCallable):
         obs_dict["truncated"] = priv.truncated
         if pub.error_info:
             obs_dict["tool_error"] = pub.error_info
+
+        # Derive a simple local semantic patch around the player for easy rendering
+        try:
+            sem = pub.semantic_map
+            if sem is not None:
+                rows = int(getattr(sem, "shape", [0, 0])[0])  # type: ignore
+                cols = int(getattr(sem, "shape", [0, 0])[1])  # type: ignore
+                if rows > 0 and cols > 0:
+                    px, py = int(pub.player_position[0]), int(pub.player_position[1])
+                    half = max(1, self.view_size // 2)
+                    x0, y0 = px - half, py - half
+                    x1, y1 = px + half, py + half
+                    patch: list[list[int]] = []
+                    for gy in range(y0, y1 + 1):
+                        row_vals: list[int] = []
+                        for gx in range(x0, x1 + 1):
+                            if 0 <= gy < rows and 0 <= gx < cols:
+                                try:
+                                    val = int(sem[gy, gx])  # type: ignore[index]
+                                except Exception:
+                                    val = 0
+                            else:
+                                val = 0
+                            row_vals.append(val)
+                        patch.append(row_vals)
+                    obs_dict["semantic_map_patch7"] = patch
+        except Exception:
+            # Best-effort; omit patch on error
+            pass
+
         return obs_dict
 
 
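Note that the loop above fills out-of-bounds cells with 0, so the patch always comes back view_size × view_size even when the player stands at a map edge. A self-contained sketch of that fill rule, with numpy standing in for the real semantic_map and toy values:

    import numpy as np

    sem = np.arange(16).reshape(4, 4)  # toy 4x4 semantic map
    px, py, half = 0, 0, 3             # player at the corner, 7x7 view
    patch = [
        [int(sem[gy, gx]) if 0 <= gy < 4 and 0 <= gx < 4 else 0
         for gx in range(px - half, px + half + 1)]
        for gy in range(py - half, py + half + 1)
    ]
    print(len(patch), len(patch[0]))  # 7 7, zero-filled at the edges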
synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py
@@ -0,0 +1 @@
+"""Agent demonstrations for custom Crafter environments."""
synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py
@@ -0,0 +1,202 @@
+#!/usr/bin/env python3
+"""
+Trace evaluation functions for Crafter episodes.
+Scores traces based on achievements and invalid actions.
+"""
+
+import json
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+# Scoring weights
+WEIGHTS = {
+    'easy_achievement': 1.0,    # Easy achievement (e.g., collect_wood)
+    'medium_achievement': 2.5,  # Medium achievement (e.g., make_wood_pickaxe)
+    'hard_achievement': 5.0,    # Hard achievement (e.g., make_iron_sword)
+    'invalid_action': -0.05,    # Invalid action penalty (50 invalid = -1 medium achievement)
+}
+
+# Map hook names to scoring categories
+HOOK_TO_SCORE_TYPE = {
+    'easy_achievement': 'easy_achievement',
+    'medium_achievement': 'medium_achievement',
+    'hard_achievement': 'hard_achievement',
+    'invalid_action': 'invalid_action'
+}
+
+
+def evaluate_event(event: Dict[str, Any]) -> Tuple[float, str]:
+    """
+    Evaluate a single event based on its hooks.
+    Returns: (score, symbol) where symbol is '+', '-', or '0'
+    """
+    score = 0.0
+    symbol = '0'
+
+    # Check if event has metadata from hooks
+    event_metadata = event.get('event_metadata', [])
+
+    for metadata in event_metadata:
+        hook_name = metadata.get('hook_name', '')
+
+        if hook_name in HOOK_TO_SCORE_TYPE:
+            score_type = HOOK_TO_SCORE_TYPE[hook_name]
+            weight = WEIGHTS[score_type]
+            score += weight
+
+            # Determine symbol
+            if weight > 0:
+                symbol = '+'
+            elif weight < 0:
+                symbol = '-'
+
+    return score, symbol
+
+
+def evaluate_trace(trace_path: Path) -> Dict[str, Any]:
+    """
+    Evaluate an entire trace file.
+    Returns detailed scoring breakdown and trajectory visualization.
+    """
+    with open(trace_path, 'r') as f:
+        trace_data = json.load(f)
+
+    # Track counts
+    counts = defaultdict(int)
+    total_score = 0.0
+    trajectory_symbols = []
+
+    # Process event history
+    event_history = trace_data.get('event_history', [])
+
+    for event in event_history:
+        event_score, symbol = evaluate_event(event)
+        total_score += event_score
+
+        # Only add symbol if score is non-zero
+        if event_score != 0:
+            trajectory_symbols.append(symbol)
+
+        # Count hook types
+        for metadata in event.get('event_metadata', []):
+            hook_name = metadata.get('hook_name', '')
+            if hook_name in HOOK_TO_SCORE_TYPE:
+                score_type = HOOK_TO_SCORE_TYPE[hook_name]
+                counts[score_type] += 1
+
+    # Create trajectory string
+    trajectory_str = ''.join(trajectory_symbols) if trajectory_symbols else '(no scored events)'
+
+    return {
+        'total_score': total_score,
+        'counts': dict(counts),
+        'trajectory': trajectory_str,
+        'num_events': len(event_history),
+        'trace_file': trace_path.name
+    }
+
+
+def print_trace_evaluation(eval_result: Dict[str, Any]):
+    """Print a formatted evaluation result for a single trace."""
+    print(f"\n📊 Trace: {eval_result['trace_file']}")
+    print(f"   Score: {eval_result['total_score']:.2f}")
+    print(f"   Events: {eval_result['num_events']}")
+
+    counts = eval_result['counts']
+    if counts:
+        print("   Breakdown:")
+        if 'easy_achievement' in counts:
+            print(f"     Easy achievements: {counts['easy_achievement']} × {WEIGHTS['easy_achievement']} = {counts['easy_achievement'] * WEIGHTS['easy_achievement']:.2f}")
+        if 'medium_achievement' in counts:
+            print(f"     Medium achievements: {counts['medium_achievement']} × {WEIGHTS['medium_achievement']} = {counts['medium_achievement'] * WEIGHTS['medium_achievement']:.2f}")
+        if 'hard_achievement' in counts:
+            print(f"     Hard achievements: {counts['hard_achievement']} × {WEIGHTS['hard_achievement']} = {counts['hard_achievement'] * WEIGHTS['hard_achievement']:.2f}")
+        if 'invalid_action' in counts:
+            print(f"     Invalid actions: {counts['invalid_action']} × {WEIGHTS['invalid_action']} = {counts['invalid_action'] * WEIGHTS['invalid_action']:.2f}")
+
+    print(f"   Trajectory: {eval_result['trajectory']}")
+
+
+def evaluate_all_traces(trace_dir: Path, pattern: str = "*.json") -> List[Dict[str, Any]]:
+    """
+    Evaluate all trace files in a directory.
+    Returns list of evaluation results sorted by score.
+    """
+    trace_files = list(trace_dir.glob(pattern))
+    results = []
+
+    for trace_file in trace_files:
+        try:
+            result = evaluate_trace(trace_file)
+            results.append(result)
+        except Exception as e:
+            print(f"⚠️ Error evaluating {trace_file.name}: {e}")
+
+    # Sort by score (descending)
+    results.sort(key=lambda x: x['total_score'], reverse=True)
+
+    return results
+
+
+def print_evaluation_summary(results: List[Dict[str, Any]]):
+    """Print a summary of all trace evaluations."""
+    if not results:
+        print("No traces to evaluate.")
+        return
+
+    print("\n" + "=" * 80)
+    print("📊 TRACE EVALUATION SUMMARY")
+    print("=" * 80)
+    print(f"{'Rank':<6} {'Score':<10} {'Trajectory':<50} {'File':<30}")
+    print("-" * 80)
+
+    for i, result in enumerate(results, 1):
+        trajectory = result['trajectory']
+        if len(trajectory) > 50:
+            trajectory = trajectory[:47] + "..."
+        print(f"{i:<6} {result['total_score']:<10.2f} {trajectory:<50} {result['trace_file'][:30]:<30}")
+
+    print("-" * 80)
+
+    # Summary statistics
+    scores = [r['total_score'] for r in results]
+    avg_score = sum(scores) / len(scores) if scores else 0
+    max_score = max(scores) if scores else 0
+    min_score = min(scores) if scores else 0
+
+    print(f"Average Score: {avg_score:.2f}")
+    print(f"Best Score: {max_score:.2f}")
+    print(f"Worst Score: {min_score:.2f}")
+    print("=" * 80)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Evaluate Crafter trace files")
+    parser.add_argument("trace_path", type=str, help="Directory containing trace files or single trace file")
+    parser.add_argument("--pattern", type=str, default="*.json", help="File pattern to match (for directories)")
+    parser.add_argument("--verbose", action="store_true", help="Show detailed evaluation for each trace")
+
+    args = parser.parse_args()
+
+    trace_path = Path(args.trace_path)
+    if not trace_path.exists():
+        print(f"❌ Path not found: {trace_path}")
+        exit(1)
+
+    # Check if it's a file or directory
+    if trace_path.is_file():
+        print(f"📊 Evaluating single trace: {trace_path}")
+        result = evaluate_trace(trace_path)
+        print_trace_evaluation(result)
+    else:
+        print(f"📊 Evaluating traces in: {trace_path}")
+        results = evaluate_all_traces(trace_path, args.pattern)
+
+        if args.verbose:
+            for result in results:
+                print_trace_evaluation(result)
+
+        print_evaluation_summary(results)