synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (112)
  1. synth_ai/__init__.py +1 -1
  2. synth_ai/cli/__init__.py +6 -0
  3. synth_ai/cli/demo.py +68 -9
  4. synth_ai/cli/rl_demo.py +137 -0
  5. synth_ai/cli/root.py +65 -0
  6. synth_ai/demos/core/__init__.py +1 -0
  7. synth_ai/demos/core/cli.py +685 -0
  8. synth_ai/demos/demo_task_apps/__init__.py +1 -0
  9. synth_ai/demos/demo_task_apps/core.py +374 -0
  10. synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
  11. synth_ai/demos/demo_task_apps/math/app.py +37 -0
  12. synth_ai/demos/demo_task_apps/math/config.toml +44 -0
  13. synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
  14. synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
  15. synth_ai/environments/examples/bandit/__init__.py +33 -0
  16. synth_ai/environments/examples/bandit/engine.py +294 -0
  17. synth_ai/environments/examples/bandit/environment.py +194 -0
  18. synth_ai/environments/examples/bandit/taskset.py +200 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
  26. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
  27. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
  28. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
  29. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
  30. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
  31. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
  32. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
  33. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
  34. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
  35. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
  36. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
  37. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
  38. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
  39. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
  40. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
  41. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
  42. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
  43. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
  44. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
  45. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
  46. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
  47. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
  48. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
  49. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
  50. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
  51. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
  52. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
  53. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
  54. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  55. synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
  56. synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
  57. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
  58. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
  59. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
  60. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
  61. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
  62. synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
  63. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
  64. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
  65. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
  66. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
  67. synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
  68. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
  69. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
  70. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
  71. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
  72. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
  73. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
  74. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
  75. synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
  76. synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
  77. synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
  78. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
  79. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
  80. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
  81. synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
  82. synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
  83. synth_ai/environments/examples/crafter_classic/environment.py +41 -2
  84. synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
  85. synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
  86. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
  87. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
  88. synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
  89. synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
  90. synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
  91. synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
  92. synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
  93. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  94. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
  95. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  96. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
  97. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  98. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  99. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
  100. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  101. synth_ai/environments/examples/red/units/__init__.py +1 -0
  102. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
  103. synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/install_sqld.sh +40 -0
  106. synth_ai-0.2.5.dist-info/METADATA +106 -0
  107. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/RECORD +111 -12
  108. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/entry_points.txt +1 -0
  109. synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
  110. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/WHEEL +0 -0
  111. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/licenses/LICENSE +0 -0
  112. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,202 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Trace evaluation functions for Crafter episodes.
4
+ Scores traces based on achievements and invalid actions.
5
+ """
6
+
7
+ from typing import Dict, List, Any, Tuple
8
+ from pathlib import Path
9
+ import json
10
+ from collections import defaultdict
11
+
12
+ # Scoring weights
13
+ WEIGHTS = {
14
+ 'easy_achievement': 1.0, # Easy achievement (e.g., collect_wood)
15
+ 'medium_achievement': 2.5, # Medium achievement (e.g., make_wood_pickaxe)
16
+ 'hard_achievement': 5.0, # Hard achievement (e.g., make_iron_sword)
17
+ 'invalid_action': -0.05, # Invalid action penalty (50 invalid = -1 medium achievement)
18
+ }
19
+
20
+ # Map hook names to scoring categories
21
+ HOOK_TO_SCORE_TYPE = {
22
+ 'easy_achievement': 'easy_achievement',
23
+ 'medium_achievement': 'medium_achievement',
24
+ 'hard_achievement': 'hard_achievement',
25
+ 'invalid_action': 'invalid_action'
26
+ }
27
+
28
+
29
+ def evaluate_event(event: Dict[str, Any]) -> Tuple[float, str]:
30
+ """
31
+ Evaluate a single event based on its hooks.
32
+ Returns: (score, symbol) where symbol is '+', '-', or '0'
33
+ """
34
+ score = 0.0
35
+ symbol = '0'
36
+
37
+ # Check if event has metadata from hooks
38
+ event_metadata = event.get('event_metadata', [])
39
+
40
+ for metadata in event_metadata:
41
+ hook_name = metadata.get('hook_name', '')
42
+
43
+ if hook_name in HOOK_TO_SCORE_TYPE:
44
+ score_type = HOOK_TO_SCORE_TYPE[hook_name]
45
+ weight = WEIGHTS[score_type]
46
+ score += weight
47
+
48
+ # Determine symbol
49
+ if weight > 0:
50
+ symbol = '+'
51
+ elif weight < 0:
52
+ symbol = '-'
53
+
54
+ return score, symbol
55
+
56
+
57
def evaluate_trace(trace_path: Path) -> Dict[str, Any]:
    """Score an entire trace file.

    Loads the JSON trace, scores each event via evaluate_event, tallies
    per-category hook counts, and builds a compact '+'/'-' trajectory
    string of the scored events.
    """
    with open(trace_path, 'r') as fh:
        trace_data = json.load(fh)

    events = trace_data.get('event_history', [])
    category_counts: Dict[str, int] = defaultdict(int)
    running_total = 0.0
    symbols: List[str] = []

    for event in events:
        step_score, step_symbol = evaluate_event(event)
        running_total += step_score

        # Only events that actually scored contribute a trajectory symbol.
        if step_score != 0:
            symbols.append(step_symbol)

        # Tally how many times each scoring category fired.
        for hook_meta in event.get('event_metadata', []):
            name = hook_meta.get('hook_name', '')
            if name in HOOK_TO_SCORE_TYPE:
                category_counts[HOOK_TO_SCORE_TYPE[name]] += 1

    return {
        'total_score': running_total,
        'counts': dict(category_counts),
        'trajectory': ''.join(symbols) if symbols else '(no scored events)',
        'num_events': len(events),
        'trace_file': trace_path.name,
    }
98
+
99
+
100
def print_trace_evaluation(eval_result: Dict[str, Any]):
    """Pretty-print the scoring breakdown for one evaluated trace."""
    print(f"\nšŸ“Š Trace: {eval_result['trace_file']}")
    print(f" Score: {eval_result['total_score']:.2f}")
    print(f" Events: {eval_result['num_events']}")

    counts = eval_result['counts']
    if counts:
        print(" Breakdown:")
        # Data-driven rendering keeps the four category lines uniform.
        for key, label in (
            ('easy_achievement', 'Easy achievements'),
            ('medium_achievement', 'Medium achievements'),
            ('hard_achievement', 'Hard achievements'),
            ('invalid_action', 'Invalid actions'),
        ):
            if key in counts:
                n = counts[key]
                w = WEIGHTS[key]
                print(f" {label}: {n} Ɨ {w} = {n * w:.2f}")

    print(f" Trajectory: {eval_result['trajectory']}")
119
+
120
+
121
def evaluate_all_traces(trace_dir: Path, pattern: str = "*.json") -> List[Dict[str, Any]]:
    """Evaluate every trace file in *trace_dir* matching *pattern*.

    Files that fail to evaluate are reported and skipped. Results are
    returned sorted by total score, highest first.
    """
    results: List[Dict[str, Any]] = []

    for candidate in trace_dir.glob(pattern):
        try:
            results.append(evaluate_trace(candidate))
        except Exception as exc:
            # Best-effort: a single corrupt trace shouldn't abort the run.
            print(f"āš ļø Error evaluating {candidate.name}: {exc}")

    results.sort(key=lambda item: item['total_score'], reverse=True)
    return results
140
+
141
+
142
def print_evaluation_summary(results: List[Dict[str, Any]]):
    """Print a ranked table of all trace evaluations plus summary stats."""
    if not results:
        print("No traces to evaluate.")
        return

    bar = "=" * 80
    rule = "-" * 80

    print("\n" + bar)
    print("šŸ“ˆ TRACE EVALUATION SUMMARY")
    print(bar)
    print(f"{'Rank':<6} {'Score':<10} {'Trajectory':<50} {'File':<30}")
    print(rule)

    for rank, entry in enumerate(results, 1):
        traj = entry['trajectory']
        # Truncate long trajectories so the table columns stay aligned.
        if len(traj) > 50:
            traj = traj[:47] + "..."
        print(f"{rank:<6} {entry['total_score']:<10.2f} {traj:<50} {entry['trace_file'][:30]:<30}")

    print(rule)

    # Summary statistics (results is non-empty here, so len() is safe).
    scores = [entry['total_score'] for entry in results]
    print(f"Average Score: {sum(scores) / len(scores):.2f}")
    print(f"Best Score: {max(scores):.2f}")
    print(f"Worst Score: {min(scores):.2f}")
    print(bar)
172
+
173
+
174
+ if __name__ == "__main__":
175
+ import argparse
176
+
177
+ parser = argparse.ArgumentParser(description="Evaluate Crafter trace files")
178
+ parser.add_argument("trace_path", type=str, help="Directory containing trace files or single trace file")
179
+ parser.add_argument("--pattern", type=str, default="*.json", help="File pattern to match (for directories)")
180
+ parser.add_argument("--verbose", action="store_true", help="Show detailed evaluation for each trace")
181
+
182
+ args = parser.parse_args()
183
+
184
+ trace_path = Path(args.trace_path)
185
+ if not trace_path.exists():
186
+ print(f"āŒ Path not found: {trace_path}")
187
+ exit(1)
188
+
189
+ # Check if it's a file or directory
190
+ if trace_path.is_file():
191
+ print(f"šŸ” Evaluating single trace: {trace_path}")
192
+ result = evaluate_trace(trace_path)
193
+ print_trace_evaluation(result)
194
+ else:
195
+ print(f"šŸ” Evaluating traces in: {trace_path}")
196
+ results = evaluate_all_traces(trace_path, args.pattern)
197
+
198
+ if args.verbose:
199
+ for result in results:
200
+ print_trace_evaluation(result)
201
+
202
+ print_evaluation_summary(results)
@@ -0,0 +1,166 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Validate that JSONL files are compatible with OpenAI fine-tuning format
4
+ """
5
+
6
+ import json
7
+ import sys
8
+ from pathlib import Path
9
+ from typing import List, Dict, Any
10
+
11
+
12
def validate_openai_format(file_path: Path) -> tuple[bool, List[str]]:
    """Validate a JSONL file for OpenAI fine-tuning compatibility.

    Each non-blank line must be a JSON object with a 'messages' list of at
    least two role/content dicts; assistant messages may carry 'tool_calls'
    instead of 'content', and the final message must come from the assistant.

    Returns:
        (is_valid, errors) where errors is a list of human-readable problems.
    """
    errors: List[str] = []
    line_count = 0

    try:
        with open(file_path, 'r') as f:
            for line_num, line in enumerate(f, 1):
                stripped = line.strip()
                # Skip blank lines (e.g. a trailing newline) instead of
                # reporting them as invalid JSON.
                if not stripped:
                    continue
                line_count += 1

                # Parse JSON
                try:
                    data = json.loads(stripped)
                except json.JSONDecodeError as e:
                    errors.append(f"Line {line_num}: Invalid JSON - {e}")
                    continue

                # Check required structure
                if not isinstance(data, dict):
                    errors.append(f"Line {line_num}: Root must be a dictionary")
                    continue

                # Check for messages key
                if 'messages' not in data:
                    errors.append(f"Line {line_num}: Missing 'messages' key")
                    continue

                messages = data['messages']
                if not isinstance(messages, list):
                    errors.append(f"Line {line_num}: 'messages' must be a list")
                    continue

                if len(messages) < 2:
                    errors.append(f"Line {line_num}: Need at least 2 messages (system/user + assistant)")
                    continue

                # Validate each message
                roles_seen = []
                for msg_idx, msg in enumerate(messages):
                    if not isinstance(msg, dict):
                        errors.append(f"Line {line_num}, message {msg_idx}: Message must be a dictionary")
                        continue

                    # Check required fields
                    if 'role' not in msg:
                        errors.append(f"Line {line_num}, message {msg_idx}: Missing 'role'")
                        continue

                    role = msg['role']
                    roles_seen.append(role)

                    if role not in ['system', 'user', 'assistant', 'tool']:
                        errors.append(f"Line {line_num}, message {msg_idx}: Invalid role '{role}'")

                    # Check content or tool_calls
                    if role == 'assistant':
                        # Assistant messages can have content, tool_calls, or both
                        if 'content' not in msg and 'tool_calls' not in msg:
                            errors.append(f"Line {line_num}, message {msg_idx}: Assistant message needs 'content' or 'tool_calls'")

                        # Validate tool_calls structure
                        if 'tool_calls' in msg:
                            tool_calls = msg['tool_calls']
                            if not isinstance(tool_calls, list):
                                errors.append(f"Line {line_num}, message {msg_idx}: 'tool_calls' must be a list")
                            else:
                                for tc_idx, tc in enumerate(tool_calls):
                                    if not isinstance(tc, dict):
                                        errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Must be a dict")
                                        continue

                                    # Check tool call structure
                                    if 'id' not in tc:
                                        errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing 'id'")
                                    if 'type' not in tc:
                                        errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing 'type'")
                                    if 'function' not in tc:
                                        errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing 'function'")
                                    else:
                                        func = tc['function']
                                        if not isinstance(func, dict):
                                            errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: 'function' must be a dict")
                                        else:
                                            if 'name' not in func:
                                                errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing function 'name'")
                                            if 'arguments' not in func:
                                                errors.append(f"Line {line_num}, message {msg_idx}, tool_call {tc_idx}: Missing function 'arguments'")
                    else:
                        # Other roles must have content
                        if 'content' not in msg:
                            errors.append(f"Line {line_num}, message {msg_idx}: {role} message missing 'content'")

                # Check message order. Guard against roles_seen being empty
                # (every message malformed), which previously raised
                # IndexError and was mis-reported as a file read error.
                if not roles_seen:
                    errors.append(f"Line {line_num}: No valid message roles found")
                elif roles_seen[-1] != 'assistant':
                    errors.append(f"Line {line_num}: Last message must be from assistant (found {roles_seen[-1]})")

    except Exception as e:
        errors.append(f"Error reading file: {e}")
        return False, errors

    print(f"Validated {line_count} examples")
    return len(errors) == 0, errors
117
+
118
+
119
def main():
    """CLI entry point: validate the JSONL file named on the command line."""
    # Guard clauses: usage, then existence.
    if len(sys.argv) < 2:
        print("Usage: python validate_openai_format.py <jsonl_file>")
        sys.exit(1)

    file_path = Path(sys.argv[1])
    if not file_path.exists():
        print(f"File not found: {file_path}")
        sys.exit(1)

    print(f"Validating OpenAI format for: {file_path}")
    print("=" * 60)

    is_valid, errors = validate_openai_format(file_path)

    if not is_valid:
        # Cap the report at the first 10 errors to keep output readable.
        print(f"āŒ File has {len(errors)} validation errors:")
        for idx, err in enumerate(errors[:10], 1):
            print(f" {idx}. {err}")
        if len(errors) > 10:
            print(f" ... and {len(errors) - 10} more errors")
        return 1

    print("āœ… File is valid for OpenAI fine-tuning!")

    # Show a quick structural summary of the first example.
    with open(file_path, 'r') as f:
        example = json.loads(f.readline())

    print("\nSample example structure:")
    print(f"- Number of messages: {len(example['messages'])}")
    print(f"- Message roles: {[msg['role'] for msg in example['messages']]}")

    last_msg = example['messages'][-1]
    if 'tool_calls' in last_msg:
        print(f"- Assistant uses tool calls: Yes")
        print(f"- Number of tool calls: {len(last_msg['tool_calls'])}")
        if last_msg['tool_calls']:
            print(f"- First tool: {last_msg['tool_calls'][0]['function']['name']}")
    else:
        print(f"- Assistant uses tool calls: No")

    return 0
163
+
164
+
165
+ if __name__ == "__main__":
166
+ sys.exit(main())
@@ -119,11 +119,20 @@ class CrafterInteractTool(AbstractTool):
119
119
 
120
120
  # Default observation callable (can be customized via __init__)
121
121
  class SynthCrafterObservationCallable(GetObservationCallable):
122
+ """Default observation: public state dict + per-step reward/flags.
123
+
124
+ Additionally computes a small local semantic patch centered on the player
125
+ to simplify visualization on the client. The patch is exposed under the
126
+ key `semantic_map_patch7` as a list-of-lists of ints (7x7 unless the
127
+ semantic map is smaller, in which case it is cropped at edges).
128
+ """
129
+
130
+ def __init__(self, view_size: int = 7) -> None:
131
+ self.view_size = max(1, int(view_size))
132
+
122
133
  async def get_observation(
123
134
  self, pub: CrafterPublicState, priv: CrafterPrivateState
124
135
  ) -> InternalObservation:
125
- # Example: return a dictionary combining public and selected private info
126
- # Actual observation structure depends on agent's needs.
127
136
  obs_dict: Dict[str, Any] = dataclasses.asdict(pub) # type: ignore
128
137
  obs_dict["reward_last_step"] = priv.reward_last_step
129
138
  obs_dict["total_reward_episode"] = priv.total_reward_episode
@@ -131,6 +140,36 @@ class SynthCrafterObservationCallable(GetObservationCallable):
131
140
  obs_dict["truncated"] = priv.truncated
132
141
  if pub.error_info:
133
142
  obs_dict["tool_error"] = pub.error_info
143
+
144
+ # Derive a simple local semantic patch around the player for easy rendering
145
+ try:
146
+ sem = pub.semantic_map
147
+ if sem is not None:
148
+ rows = int(getattr(sem, "shape", [0, 0])[0]) # type: ignore
149
+ cols = int(getattr(sem, "shape", [0, 0])[1]) # type: ignore
150
+ if rows > 0 and cols > 0:
151
+ px, py = int(pub.player_position[0]), int(pub.player_position[1])
152
+ half = max(1, self.view_size // 2)
153
+ x0, y0 = px - half, py - half
154
+ x1, y1 = px + half, py + half
155
+ patch: list[list[int]] = []
156
+ for gy in range(y0, y1 + 1):
157
+ row_vals: list[int] = []
158
+ for gx in range(x0, x1 + 1):
159
+ if 0 <= gy < rows and 0 <= gx < cols:
160
+ try:
161
+ val = int(sem[gy, gx]) # type: ignore[index]
162
+ except Exception:
163
+ val = 0
164
+ else:
165
+ val = 0
166
+ row_vals.append(val)
167
+ patch.append(row_vals)
168
+ obs_dict["semantic_map_patch7"] = patch
169
+ except Exception:
170
+ # Best-effort; omit patch on error
171
+ pass
172
+
134
173
  return obs_dict
135
174
 
136
175
 
@@ -0,0 +1 @@
1
+ """Agent demonstrations for custom Crafter environments."""
@@ -0,0 +1,202 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Trace evaluation functions for Crafter episodes.
4
+ Scores traces based on achievements and invalid actions.
5
+ """
6
+
7
+ import json
8
+ from collections import defaultdict
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List, Tuple
11
+
12
# Scoring weights
WEIGHTS = {
    'easy_achievement': 1.0,    # Easy achievement (e.g., collect_wood)
    'medium_achievement': 2.5,  # Medium achievement (e.g., make_wood_pickaxe)
    'hard_achievement': 5.0,    # Hard achievement (e.g., make_iron_sword)
    'invalid_action': -0.05,    # Invalid action penalty (50 invalid = -1 medium achievement)
}

# Map hook names to scoring categories
HOOK_TO_SCORE_TYPE = {
    'easy_achievement': 'easy_achievement',
    'medium_achievement': 'medium_achievement',
    'hard_achievement': 'hard_achievement',
    'invalid_action': 'invalid_action'
}


def evaluate_event(event: Dict[str, Any]) -> Tuple[float, str]:
    """Score a single event from its hook metadata.

    Sums the WEIGHTS of every recognized hook attached to the event and
    returns ``(score, symbol)`` where symbol is '+', '-', or '0' —
    reflecting the sign of the last weighted hook seen.
    """
    total = 0.0
    marker = '0'

    # Hooks attach their results under 'event_metadata'.
    for hook_meta in event.get('event_metadata', []):
        name = hook_meta.get('hook_name', '')
        if name not in HOOK_TO_SCORE_TYPE:
            continue
        weight = WEIGHTS[HOOK_TO_SCORE_TYPE[name]]
        total += weight
        # Positive weights mark '+', negative '-'; zero leaves the marker as-is.
        marker = '+' if weight > 0 else ('-' if weight < 0 else marker)

    return total, marker
55
+
56
+
57
def evaluate_trace(trace_path: Path) -> Dict[str, Any]:
    """Score an entire trace file.

    Loads the JSON trace, scores each event via evaluate_event, tallies
    per-category hook counts, and builds a compact '+'/'-' trajectory
    string of the scored events.
    """
    with open(trace_path, 'r') as fh:
        trace_data = json.load(fh)

    events = trace_data.get('event_history', [])
    category_counts: Dict[str, int] = defaultdict(int)
    running_total = 0.0
    symbols: List[str] = []

    for event in events:
        step_score, step_symbol = evaluate_event(event)
        running_total += step_score

        # Only events that actually scored contribute a trajectory symbol.
        if step_score != 0:
            symbols.append(step_symbol)

        # Tally how many times each scoring category fired.
        for hook_meta in event.get('event_metadata', []):
            name = hook_meta.get('hook_name', '')
            if name in HOOK_TO_SCORE_TYPE:
                category_counts[HOOK_TO_SCORE_TYPE[name]] += 1

    return {
        'total_score': running_total,
        'counts': dict(category_counts),
        'trajectory': ''.join(symbols) if symbols else '(no scored events)',
        'num_events': len(events),
        'trace_file': trace_path.name,
    }
98
+
99
+
100
def print_trace_evaluation(eval_result: Dict[str, Any]):
    """Pretty-print the scoring breakdown for one evaluated trace."""
    print(f"\nšŸ“Š Trace: {eval_result['trace_file']}")
    print(f" Score: {eval_result['total_score']:.2f}")
    print(f" Events: {eval_result['num_events']}")

    counts = eval_result['counts']
    if counts:
        print(" Breakdown:")
        # Data-driven rendering keeps the four category lines uniform.
        for key, label in (
            ('easy_achievement', 'Easy achievements'),
            ('medium_achievement', 'Medium achievements'),
            ('hard_achievement', 'Hard achievements'),
            ('invalid_action', 'Invalid actions'),
        ):
            if key in counts:
                n = counts[key]
                w = WEIGHTS[key]
                print(f" {label}: {n} Ɨ {w} = {n * w:.2f}")

    print(f" Trajectory: {eval_result['trajectory']}")
119
+
120
+
121
def evaluate_all_traces(trace_dir: Path, pattern: str = "*.json") -> List[Dict[str, Any]]:
    """Evaluate every trace file in *trace_dir* matching *pattern*.

    Files that fail to evaluate are reported and skipped. Results are
    returned sorted by total score, highest first.
    """
    results: List[Dict[str, Any]] = []

    for candidate in trace_dir.glob(pattern):
        try:
            results.append(evaluate_trace(candidate))
        except Exception as exc:
            # Best-effort: a single corrupt trace shouldn't abort the run.
            print(f"āš ļø Error evaluating {candidate.name}: {exc}")

    results.sort(key=lambda item: item['total_score'], reverse=True)
    return results
140
+
141
+
142
def print_evaluation_summary(results: List[Dict[str, Any]]):
    """Print a ranked table of all trace evaluations plus summary stats."""
    if not results:
        print("No traces to evaluate.")
        return

    bar = "=" * 80
    rule = "-" * 80

    print("\n" + bar)
    print("šŸ“ˆ TRACE EVALUATION SUMMARY")
    print(bar)
    print(f"{'Rank':<6} {'Score':<10} {'Trajectory':<50} {'File':<30}")
    print(rule)

    for rank, entry in enumerate(results, 1):
        traj = entry['trajectory']
        # Truncate long trajectories so the table columns stay aligned.
        if len(traj) > 50:
            traj = traj[:47] + "..."
        print(f"{rank:<6} {entry['total_score']:<10.2f} {traj:<50} {entry['trace_file'][:30]:<30}")

    print(rule)

    # Summary statistics (results is non-empty here, so len() is safe).
    scores = [entry['total_score'] for entry in results]
    print(f"Average Score: {sum(scores) / len(scores):.2f}")
    print(f"Best Score: {max(scores):.2f}")
    print(f"Worst Score: {min(scores):.2f}")
    print(bar)
172
+
173
+
174
+ if __name__ == "__main__":
175
+ import argparse
176
+
177
+ parser = argparse.ArgumentParser(description="Evaluate Crafter trace files")
178
+ parser.add_argument("trace_path", type=str, help="Directory containing trace files or single trace file")
179
+ parser.add_argument("--pattern", type=str, default="*.json", help="File pattern to match (for directories)")
180
+ parser.add_argument("--verbose", action="store_true", help="Show detailed evaluation for each trace")
181
+
182
+ args = parser.parse_args()
183
+
184
+ trace_path = Path(args.trace_path)
185
+ if not trace_path.exists():
186
+ print(f"āŒ Path not found: {trace_path}")
187
+ exit(1)
188
+
189
+ # Check if it's a file or directory
190
+ if trace_path.is_file():
191
+ print(f"šŸ” Evaluating single trace: {trace_path}")
192
+ result = evaluate_trace(trace_path)
193
+ print_trace_evaluation(result)
194
+ else:
195
+ print(f"šŸ” Evaluating traces in: {trace_path}")
196
+ results = evaluate_all_traces(trace_path, args.pattern)
197
+
198
+ if args.verbose:
199
+ for result in results:
200
+ print_trace_evaluation(result)
201
+
202
+ print_evaluation_summary(results)