synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +685 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai-0.2.5.dist-info/METADATA +106 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/RECORD +111 -12
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/entry_points.txt +1 -0
- synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
OpenAI Fine-Tuning Script
|
|
4
|
+
========================
|
|
5
|
+
Uploads a JSONL file, kicks off a fine-tuning job, and polls until completion.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import json
|
|
10
|
+
import os
|
|
11
|
+
import random
|
|
12
|
+
import sys
|
|
13
|
+
import time
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import openai
|
|
19
|
+
except ImportError:
|
|
20
|
+
print("ā OpenAI package not found. Installing...")
|
|
21
|
+
os.system("pip install openai")
|
|
22
|
+
import openai
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
import tiktoken
|
|
26
|
+
except ImportError:
|
|
27
|
+
print("ā tiktoken package not found. Installing...")
|
|
28
|
+
os.system("pip install tiktoken")
|
|
29
|
+
import tiktoken
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def encoding_for(model: str = "gpt-4.1-mini"):
    """Look up the tiktoken encoding for *model*, with a safe fallback.

    Newer models (e.g. the GPT-4.1 family) may not yet be registered in
    tiktoken's model map; in that case fall back to ``o200k_base``, the
    BPE used by 4o-era/modern OpenAI models.
    """
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        # Model name not in tiktoken's registry yet — use the modern BPE.
        enc = tiktoken.get_encoding("o200k_base")
    return enc
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def analyze_jsonl_tokens(file_path: Path, model: str) -> tuple[int, int, float]:
    """Analyze a JSONL fine-tuning file and estimate token usage.

    For each line, the *final* message — when it is an assistant turn — is
    treated as the training target (output); every preceding message is input
    context. Tool-call arguments are counted on both sides.

    Args:
        file_path: Path to the JSONL training file.
        model: Model name used to pick the tiktoken encoding.

    Returns:
        Tuple of (valid_line_count, total_tokens, avg_tokens_per_line).
    """
    print(f"š Analyzing {file_path.name} for token usage...")

    # Get the appropriate encoding for the model (handles GPT-4.1 properly)
    encoding = encoding_for(model)
    print(f"   š¤ Using encoding: {encoding.name}")

    total_input_tokens = 0
    total_output_tokens = 0
    line_count = 0

    with open(file_path, 'r') as f:
        # enumerate gives the true file line number for diagnostics, even
        # after invalid lines are skipped (the old `line_count + 1` drifted).
        for line_no, line in enumerate(f, start=1):
            try:
                data = json.loads(line.strip())
                messages = data.get('messages', [])

                # Split by POSITION, not equality: the previous
                # `msg == messages[-1]` check also matched any earlier
                # message that happened to be identical to the last one,
                # misclassifying it as output.
                output_message = None
                input_messages = messages
                if messages and messages[-1].get('role') == 'assistant':
                    output_message = messages[-1]
                    input_messages = messages[:-1]

                # Estimate input tokens (message text + tool-call arguments).
                input_text = ""
                for msg in input_messages:
                    content = msg.get('content', '')
                    if content:
                        input_text += content + " "

                    # Include tool calls in input if present
                    for tc in msg.get('tool_calls', []):
                        if tc.get('function', {}).get('arguments'):
                            input_text += tc['function']['arguments'] + " "

                total_input_tokens += len(encoding.encode(input_text))

                # Estimate output tokens (content + tool-call arguments).
                output_tokens = 0
                if output_message:
                    content = output_message.get('content', '') or ''
                    output_tokens += len(encoding.encode(content))

                    for tc in output_message.get('tool_calls', []):
                        if tc.get('function', {}).get('arguments'):
                            output_tokens += len(encoding.encode(tc['function']['arguments']))

                total_output_tokens += output_tokens
                line_count += 1

            except json.JSONDecodeError:
                print(f"   ā ļø Skipping invalid JSON line {line_no}")
                continue
            except Exception as e:
                print(f"   ā ļø Error processing line {line_no}: {e}")
                continue

    avg_tokens_per_line = (total_input_tokens + total_output_tokens) / line_count if line_count > 0 else 0

    print(f"   š Analysis complete:")
    print(f"      Lines: {line_count:,}")
    print(f"      Input tokens: {total_input_tokens:,}")
    print(f"      Output tokens: {total_output_tokens:,}")
    print(f"      Total tokens: {total_input_tokens + total_output_tokens:,}")
    print(f"      Avg tokens/line: {avg_tokens_per_line:.1f}")

    return line_count, total_input_tokens + total_output_tokens, avg_tokens_per_line
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def create_subset_file(original_path: Path, num_lines: int) -> Path:
    """Write a random ``num_lines``-line subset of a JSONL file.

    The subset is saved next to the original as
    ``<stem>_subset_<num_lines>.jsonl``; if the file has fewer lines than
    requested, every line is kept. Returns the path of the subset file.
    """
    subset_path = original_path.parent / f"{original_path.stem}_subset_{num_lines}.jsonl"

    print(f"š Creating subset with {num_lines} lines...")

    # Load every non-empty line, stripped of surrounding whitespace.
    with open(original_path, 'r') as src:
        all_lines = [ln.strip() for ln in src if ln.strip()]

    # Sample without replacement, or keep everything when the request
    # meets/exceeds the file size.
    if num_lines < len(all_lines):
        chosen = random.sample(all_lines, num_lines)
    else:
        print(f"   ā ļø Requested {num_lines} lines, but file only has {len(all_lines)}. Using all lines.")
        chosen = all_lines

    with open(subset_path, 'w') as dst:
        dst.writelines(ln + '\n' for ln in chosen)

    print(f"   ā Subset saved to: {subset_path.name}")
    return subset_path
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def upload_file(client: openai.OpenAI, file_path: Path) -> str:
    """Upload a training file to OpenAI and return the server-side file id."""
    size_mb = file_path.stat().st_size / 1024 / 1024
    print(f"š¤ Uploading {file_path.name} ({size_mb:.1f} MB)...")

    # Binary mode is required: the Files API takes the raw file handle.
    with open(file_path, 'rb') as fh:
        uploaded = client.files.create(file=fh, purpose="fine-tune")

    print(f"ā File uploaded: {uploaded.id}")
    return uploaded.id
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def create_fine_tune_job(client: openai.OpenAI, file_id: str, model: str = "gpt-4.1-nano-2025-04-14",
                         suffix: Optional[str] = None) -> str:
    """Kick off a fine-tuning job for an uploaded file and return the job id.

    Args:
        client: Configured OpenAI client.
        file_id: Id of a previously uploaded training file.
        model: Base model to fine-tune.
        suffix: Optional suffix appended to the fine-tuned model's name.
    """
    print(f"š Starting fine-tune job for {model}...")

    # Only pass `suffix` through when the caller supplied one.
    params = {"training_file": file_id, "model": model}
    if suffix:
        params["suffix"] = suffix

    job = client.fine_tuning.jobs.create(**params)

    print(f"ā Fine-tune job created: {job.id}")
    print(f"   Model: {job.model}")
    print(f"   Status: {job.status}")

    return job.id
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def poll_job_status(client: openai.OpenAI, job_id: str, poll_interval: int = 30) -> Optional[str]:
    """Poll an OpenAI fine-tuning job until it reaches a terminal state.

    Blocks the caller, re-checking the job every ``poll_interval`` seconds.

    Returns:
        The fine-tuned model name on success, or ``None`` if the job failed,
        was cancelled, or polling was interrupted with Ctrl-C.
    """
    print(f"ā³ Polling job {job_id} every {poll_interval}s...")

    start_time = time.time()
    last_status = None  # last status we reported, so transitions print once

    while True:
        try:
            job = client.fine_tuning.jobs.retrieve(job_id)

            # Report only on status transitions, with elapsed wall time.
            if job.status != last_status:
                elapsed = time.time() - start_time
                print(f"   Status: {job.status} (elapsed: {elapsed/60:.1f}m)")
                last_status = job.status

            # Show training progress if available
            if hasattr(job, 'trained_tokens') and job.trained_tokens:
                print(f"   Trained tokens: {job.trained_tokens:,}")

            # Terminal states
            if job.status == "succeeded":
                print(f"š Fine-tuning completed successfully!")
                print(f"   Final model: {job.fine_tuned_model}")
                return job.fine_tuned_model

            elif job.status == "failed":
                print(f"ā Fine-tuning failed!")
                if hasattr(job, 'error') and job.error:
                    print(f"   Error: {job.error}")
                return None

            elif job.status == "cancelled":
                print(f"ā ļø Fine-tuning was cancelled")
                return None

            # Continue polling for running states
            elif job.status in ["validating_files", "queued", "running"]:
                time.sleep(poll_interval)
                continue

            else:
                # Unrecognized status: keep polling rather than bailing out.
                print(f"ā ļø Unknown status: {job.status}")
                time.sleep(poll_interval)
                continue

        except KeyboardInterrupt:
            # Ctrl-C stops local polling only; the remote job keeps running.
            print(f"\nā ļø Interrupted by user. Job {job_id} is still running on OpenAI.")
            print(f"   Check status with: openai api fine_tunes.get -i {job_id}")
            return None

        except Exception as e:
            # Transient API/network errors: report and retry after the interval.
            print(f"ā Error polling job: {e}")
            time.sleep(poll_interval)
            continue
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def main():
    """Command-line entry point.

    Pipeline: parse args -> token/cost analysis -> optional interactive
    subsetting -> upload training file -> create fine-tune job -> poll to
    completion. Exits with a non-zero status on any failure.
    """
    parser = argparse.ArgumentParser(description="OpenAI Fine-Tuning Script")
    parser.add_argument("jsonl_file", type=Path, help="Path to JSONL training file")
    parser.add_argument("--model", default="gpt-4.1-nano-2025-04-14",
                        help="Base model to fine-tune (default: gpt-4.1-nano-2025-04-14)")
    parser.add_argument("--suffix", type=str, help="Suffix for the fine-tuned model name")
    parser.add_argument("--poll-interval", type=int, default=30,
                        help="Polling interval in seconds (default: 30)")
    parser.add_argument("--api-key", type=str, help="OpenAI API key (or set OPENAI_API_KEY env var)")

    args = parser.parse_args()

    # Validate the training file before touching the API.
    if not args.jsonl_file.exists():
        print(f"ā File not found: {args.jsonl_file}")
        sys.exit(1)

    # Idiomatic `!=` (was `not ... == ...`); wrong extension is only a warning.
    if args.jsonl_file.suffix != '.jsonl':
        print(f"ā ļø Warning: File doesn't have .jsonl extension: {args.jsonl_file}")

    # Setup API key: the CLI flag wins over the environment variable.
    api_key = args.api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("ā OpenAI API key required. Set OPENAI_API_KEY env var or use --api-key")
        sys.exit(1)

    # Initialize client
    client = openai.OpenAI(api_key=api_key)

    # Token analysis drives the cost estimate and the subset prompt below.
    line_count, total_tokens, avg_tokens = analyze_jsonl_tokens(args.jsonl_file, args.model)

    # Rough cost estimate: ~$8 per 1M tokens (gpt-3.5-turbo fine-tuning rate).
    # GPT-4-family models cost more, so treat this as a ballpark lower bound.
    estimated_cost = total_tokens / 1_000_000 * 8

    print(f"\nš° Estimated fine-tuning cost: ~${estimated_cost:.2f}")
    print("   (Based on $8/1M tokens - actual cost may vary by model)")

    # Interactively offer to train on a random subset instead of the full file.
    print(f"\nš¤ Do you want to fine-tune on all {line_count:,} lines?")
    print(f"   Total tokens: {total_tokens:,}")
    print(f"   Average tokens per line: {avg_tokens:.1f}")

    use_subset = input("\nUse a subset instead? (y/n): ").lower().strip()

    training_file = args.jsonl_file
    if use_subset == 'y':
        while True:
            try:
                subset_size = input(f"How many lines to use? (1-{line_count}): ").strip()
                subset_size = int(subset_size)

                if subset_size < 1:
                    print("   ā Number must be at least 1")
                    continue
                elif subset_size > line_count:
                    print(f"   ā Number cannot exceed {line_count}")
                    continue

                # Show updated estimates for the proposed subset size.
                subset_tokens = int(avg_tokens * subset_size)
                subset_cost = subset_tokens / 1_000_000 * 8

                print(f"\nš Subset estimates:")
                print(f"   Lines: {subset_size:,}")
                print(f"   Estimated tokens: {subset_tokens:,}")
                print(f"   Estimated cost: ~${subset_cost:.2f}")

                confirm = input(f"\nProceed with {subset_size} lines? (y/n): ").lower().strip()
                if confirm == 'y':
                    training_file = create_subset_file(args.jsonl_file, subset_size)
                    break
                else:
                    # Declined this size: loop back and ask for another.
                    continue

            except ValueError:
                print("   ā Please enter a valid number")
                continue

    print("š¤ OpenAI Fine-Tuning Pipeline")
    print("=" * 50)
    print(f"Training file: {training_file}")
    print(f"Base model: {args.model}")
    if args.suffix:
        print(f"Model suffix: {args.suffix}")
    print("=" * 50)

    try:
        # Step 1: Upload file
        file_id = upload_file(client, training_file)

        # Step 2: Create fine-tune job
        job_id = create_fine_tune_job(client, file_id, args.model, args.suffix)

        # Step 3: Poll until completion
        final_model = poll_job_status(client, job_id, args.poll_interval)

        if final_model:
            print("\n" + "=" * 50)
            print(f"šÆ SUCCESS! Fine-tuned model ready: {final_model}")
            print("=" * 50)

            # Show a copy-pastable usage example for the new model.
            # (Placeholder-free strings no longer carry a pointless f-prefix.)
            print("\nš Usage example:")
            print('client = openai.OpenAI()')
            print('response = client.chat.completions.create(')
            print(f'    model="{final_model}",')
            print('    messages=[{"role": "user", "content": "Hello!"}]')
            print(')')
        else:
            print("\nā Fine-tuning did not complete successfully")
            sys.exit(1)

    except Exception as e:
        print(f"\nā Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Configuration for Multi-Model Crafter Evaluation

[experiment]
# Number of episodes to run per model
episodes = 10

# Maximum turns per episode
max_turns = 100

# Difficulty level (easy, medium, hard)
difficulty = "easy"

# Models to test (can be overridden with --models CLI argument)
models = [
    "gpt-4o-mini",
    "gpt-4.1-mini",
    "gpt-4.1-nano",
    "gemini-1.5-flash",
    "gemini-2.5-flash-lite",
    "qwen3/32b"
]

# Concurrency settings
concurrent_models = false  # Whether to run models concurrently
max_concurrent_models = 3  # Maximum number of models to run at once (used only when concurrent_models = true)

[services]
# Crafter environment service URL
crafter_service_url = "http://localhost:8901"

# Database path for storing traces
database_path = "crafter_multi_model_traces.duckdb"

[output]
# Directory for saving traces and results
trace_dir = "./traces_multi_model"

# Whether to save detailed traces
save_traces = true

# Whether to enable verbose output
verbose = true

[http]
# HTTP timeout in seconds
timeout = 30.0

# Maximum retries for failed requests
max_retries = 3