synth-ai 0.2.4.dev7__py3-none-any.whl → 0.2.4.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/balance.py +3 -15
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/config/base_url.py +47 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +621 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/http.py +102 -0
- synth_ai/inference/__init__.py +7 -0
- synth_ai/inference/client.py +20 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai/jobs/client.py +246 -0
- synth_ai/learning/__init__.py +24 -0
- synth_ai/learning/client.py +149 -0
- synth_ai/learning/config.py +43 -0
- synth_ai/learning/constants.py +29 -0
- synth_ai/learning/ft_client.py +59 -0
- synth_ai/learning/health.py +43 -0
- synth_ai/learning/jobs.py +205 -0
- synth_ai/learning/rl_client.py +256 -0
- synth_ai/learning/sse.py +58 -0
- synth_ai/learning/validators.py +48 -0
- synth_ai/lm/core/main_v3.py +13 -0
- synth_ai/lm/core/synth_models.py +48 -0
- synth_ai/lm/core/vendor_clients.py +9 -6
- synth_ai/lm/vendors/core/openai_api.py +31 -3
- synth_ai/lm/vendors/openai_standard.py +45 -14
- synth_ai/lm/vendors/supported/custom_endpoint.py +12 -2
- synth_ai/lm/vendors/synth_client.py +372 -28
- synth_ai/rl/__init__.py +30 -0
- synth_ai/rl/contracts.py +32 -0
- synth_ai/rl/env_keys.py +137 -0
- synth_ai/rl/secrets.py +19 -0
- synth_ai/scripts/verify_rewards.py +100 -0
- synth_ai/task/__init__.py +10 -0
- synth_ai/task/contracts.py +120 -0
- synth_ai/task/health.py +28 -0
- synth_ai/task/validators.py +12 -0
- synth_ai/tracing_v3/hooks.py +3 -1
- synth_ai/tracing_v3/session_tracer.py +123 -2
- synth_ai/tracing_v3/turso/manager.py +218 -0
- synth_ai/tracing_v3/turso/models.py +53 -0
- synth_ai-0.2.4.dev9.dist-info/METADATA +91 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/RECORD +147 -30
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/entry_points.txt +1 -0
- synth_ai/tui/__init__.py +0 -1
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -340
- synth_ai-0.2.4.dev7.dist-info/METADATA +0 -193
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev9.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Modal/Synth Fine-Tuning Script
|
|
4
|
+
==============================
|
|
5
|
+
Uploads a JSONL file to Modal, kicks off a fine-tuning job, and polls until completion.
|
|
6
|
+
Updated for OpenAI v1 compatible unified fine-tuning service.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import argparse
|
|
10
|
+
import asyncio
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import random
|
|
14
|
+
import sys
|
|
15
|
+
import time
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Any, Dict, Optional
|
|
19
|
+
|
|
20
|
+
import httpx
|
|
21
|
+
|
|
22
|
+
# Add synth_ai to path (optional - only if needed)
|
|
23
|
+
# sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))
|
|
24
|
+
|
|
25
|
+
# Import Synth LM utilities (optional - only if needed)
|
|
26
|
+
# from synth_ai.lm import SynthConfig
|
|
27
|
+
|
|
28
|
+
# Modal fine-tuning endpoints - Updated for OpenAI v1 compatible service
# Each setting is resolved in this order: the MODAL_* environment variable,
# then the SYNTH_* environment variable, then the literal default below.
MODAL_BASE_URL = os.getenv('MODAL_BASE_URL', os.getenv('SYNTH_BASE_URL', 'https://synth-laboratories--unified-ft-service-fastapi-app.modal.run'))
# NOTE(review): the fallback below is a non-empty placeholder key, so
# MODAL_API_KEY is only falsy when an environment variable is explicitly set
# to an empty string. Any "key is missing" check on this value will therefore
# not fire for users who simply forgot to export a key - confirm this is the
# intended behaviour.
MODAL_API_KEY = os.getenv('MODAL_API_KEY', os.getenv('SYNTH_API_KEY', 'sk-test-11111111111111111111111111111111'))
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _content_char_count(content: Any) -> int:
    """Return how many characters of text a message ``content`` value carries.

    ``content`` may be ``None`` (e.g. an assistant message that only holds
    tool calls), a plain string, or a list of parts. For a list, string
    parts and the ``text`` field of dict parts are counted; anything else
    contributes nothing.
    """
    if content is None:
        return 0
    if isinstance(content, str):
        return len(content)
    if isinstance(content, list):
        total = 0
        for part in content:
            if isinstance(part, str):
                total += len(part)
            elif isinstance(part, dict) and isinstance(part.get('text'), str):
                total += len(part['text'])
        return total
    return len(str(content))


def analyze_jsonl_tokens(file_path: Path, model: str) -> tuple[int, int, float]:
    """Estimate token usage of a chat-format JSONL training file.

    Every non-blank line is parsed as a JSON object with a ``messages`` list.
    The final message counts as output when its role is ``assistant``; every
    other message counts as input. Tool-call ``arguments`` are added to the
    side their message belongs to. Tokens are approximated as characters
    divided by four, rounded down per example.

    Args:
        file_path: Path of the JSONL file to analyse.
        model: Base model name. Currently unused; kept so the signature
            stays stable for callers.

    Returns:
        ``(line_count, total_tokens, avg_tokens_per_line)`` where
        ``line_count`` is the number of examples that parsed successfully.
        The average is ``0.0`` when no example parsed.
    """
    print(f"🔍 Analyzing {file_path.name} for token usage...")

    # For Modal/Synth, we do a rough estimate based on character count.
    # Approximate: 1 token ≈ 4 characters.
    CHARS_PER_TOKEN = 4  # noqa: N806

    total_input_tokens = 0
    total_output_tokens = 0
    line_count = 0

    with open(file_path, 'r', encoding='utf-8') as f:
        # line_number is the real position in the file, so warnings point at
        # the offending line even after earlier lines were skipped.
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines are not examples; skip them silently.
                continue

            try:
                data = json.loads(line)
                messages = data.get('messages') or []
                last_index = len(messages) - 1

                input_chars = 0
                output_chars = 0

                for i, msg in enumerate(messages):
                    # `content` and `tool_calls` may be present but null,
                    # so never rely on the .get() default alone.
                    char_count = _content_char_count(msg.get('content'))
                    for tc in msg.get('tool_calls') or []:
                        arguments = (tc.get('function') or {}).get('arguments')
                        if arguments:
                            char_count += len(arguments)

                    if i == last_index and msg.get('role') == 'assistant':
                        # This is the target output
                        output_chars += char_count
                    else:
                        # This is input context
                        input_chars += char_count

            except json.JSONDecodeError:
                print(f" ⚠️ Skipping invalid JSON line {line_number}")
                continue
            except Exception as e:
                print(f" ⚠️ Error processing line {line_number}: {e}")
                continue

            total_input_tokens += input_chars // CHARS_PER_TOKEN
            total_output_tokens += output_chars // CHARS_PER_TOKEN
            line_count += 1

    total_tokens = total_input_tokens + total_output_tokens
    avg_tokens_per_line = total_tokens / line_count if line_count > 0 else 0.0

    print(" 📊 Analysis complete:")
    print(f" Lines: {line_count:,}")
    print(f" Input tokens (est.): {total_input_tokens:,}")
    print(f" Output tokens (est.): {total_output_tokens:,}")
    print(f" Total tokens (est.): {total_tokens:,}")
    print(f" Avg tokens/line: {avg_tokens_per_line:.1f}")

    return line_count, total_tokens, avg_tokens_per_line
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def create_subset_file(original_path: Path, num_lines: int, seed: Optional[int] = None) -> Path:
    """Write a random subset of a JSONL file next to the original.

    Blank lines in the source are ignored. The subset is written to
    ``<stem>_subset_<num_lines>.jsonl`` in the same directory, overwriting
    any existing file of that name.

    Args:
        original_path: JSONL file to sample from.
        num_lines: Number of lines wanted. When it is at least the number of
            non-blank lines in the file, every line is kept in file order.
        seed: Optional seed for a reproducible sample. ``None`` (the
            default) samples with the module-level ``random`` generator.

    Returns:
        Path of the subset file that was written.

    Raises:
        ValueError: If ``num_lines`` is negative.
    """
    if num_lines < 0:
        # Fail before touching the filesystem instead of deep inside sample().
        raise ValueError(f"num_lines must be non-negative, got {num_lines}")

    subset_path = original_path.parent / f"{original_path.stem}_subset_{num_lines}.jsonl"

    print(f"📝 Creating subset with {num_lines} lines...")

    # Read all non-blank lines
    with open(original_path, 'r', encoding='utf-8') as f:
        all_lines = [line.strip() for line in f if line.strip()]

    if num_lines >= len(all_lines):
        selected_lines = all_lines
        # Only warn when the request really exceeds the file; an exact match
        # is not a shortfall.
        if num_lines > len(all_lines):
            print(f" ⚠️ Requested {num_lines} lines, but file only has {len(all_lines)}. Using all lines.")
    else:
        # A private generator keeps a seeded sample from disturbing the
        # global random state.
        sampler = random if seed is None else random.Random(seed)
        selected_lines = sampler.sample(all_lines, num_lines)

    # Write subset
    with open(subset_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{line}\n" for line in selected_lines)

    print(f" ✅ Subset saved to: {subset_path.name}")
    return subset_path
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
async def upload_file(file_path: Path) -> str:
    """Upload a training file via the OpenAI v1 compatible ``/v1/files`` endpoint.

    The file is sent as multipart form data with ``purpose=fine-tune`` to
    ``MODAL_BASE_URL``, authenticated with ``MODAL_API_KEY``.

    Args:
        file_path: Local file to upload.

    Returns:
        The ``id`` field of the JSON response.

    Raises:
        RuntimeError: If the service answers with a status other than 200,
            or the response carries no ``id``.
    """
    print(f"📤 Uploading {file_path.name} ({file_path.stat().st_size / 1024 / 1024:.1f} MB)...")

    # Read the payload before opening the HTTP client so no client is held
    # open during disk I/O.
    file_content = file_path.read_bytes()

    # Multipart form data for the OpenAI v1 compatible endpoint
    files = {
        'file': (file_path.name, file_content, 'application/json')
    }

    async with httpx.AsyncClient(timeout=300.0) as client:
        response = await client.post(
            f"{MODAL_BASE_URL}/v1/files?purpose=fine-tune",
            files=files,
            headers={"Authorization": f"Bearer {MODAL_API_KEY}"}
        )

        if response.status_code != 200:
            raise RuntimeError(f"Upload failed: {response.status_code} - {response.text}")

        result = response.json()
        file_id = result.get('id')

    if not file_id:
        # Without an id the follow-up job request would be sent with
        # training_file=None, so fail here with the server's answer instead.
        raise RuntimeError(f"Upload response did not contain a file id: {result}")

    print(f"✅ File uploaded: {file_id}")
    return file_id
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
async def create_fine_tune_job(file_id: str, model: str = "Qwen/Qwen2.5-7B-Instruct",
                               config: Optional[Dict[str, Any]] = None,
                               training_type: str = "sft") -> str:
    """Create a fine-tuning job via the OpenAI v1 compatible endpoint.

    Args:
        file_id: Id of a previously uploaded training file.
        model: Base model to fine-tune.
        config: Optional hyperparameter overrides, merged over the defaults
            defined in this function.
        training_type: Value sent as ``training_type`` (``"sft"`` or
            ``"dpo"``). Defaults to ``"sft"``, the previously hard-coded value.

    Returns:
        The ``id`` field of the created job.

    Raises:
        RuntimeError: If the service answers with a status other than 200,
            or the response carries no ``id``.
    """
    print(f"🚀 Starting fine-tune job for {model}...")

    # Default fine-tuning configuration for OpenAI v1 compatible service
    ft_config = {
        "model": model,
        "training_file": file_id,
        "training_type": training_type,  # sft or dpo
        "hyperparameters": {
            "n_epochs": 3,
            "batch_size": 4,
            "learning_rate": 5e-5,
            "use_qlora": True,
            "lora_r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.1,
        },
        "suffix": f"modal-{int(time.time())}"
    }

    # Update with user config if provided
    if config:
        ft_config["hyperparameters"].update(config)

    async with httpx.AsyncClient(timeout=300.0) as client:
        response = await client.post(
            f"{MODAL_BASE_URL}/v1/fine_tuning/jobs",
            json=ft_config,
            headers={
                "Authorization": f"Bearer {MODAL_API_KEY}",
                "Content-Type": "application/json"
            }
        )

        if response.status_code != 200:
            raise RuntimeError(f"Job creation failed: {response.status_code} - {response.text}")

        result = response.json()
        job_id = result.get('id')

    if not job_id:
        # Polling "/jobs/None" would loop forever; surface the bad response now.
        raise RuntimeError(f"Job creation response did not contain a job id: {result}")

    print(f"✅ Fine-tune job created: {job_id}")
    print(f" Model: {model}")
    print(f" Status: {result.get('status', 'created')}")

    return job_id


async def poll_job_status(job_id: str, poll_interval: int = 30) -> Optional[str]:
    """Poll a fine-tuning job until it reaches a terminal state.

    Non-200 responses, unknown statuses and request errors are reported and
    retried after ``poll_interval`` seconds; there is no overall timeout.

    Args:
        job_id: Id of the job to watch.
        poll_interval: Seconds to wait between polls.

    Returns:
        The ``fine_tuned_model`` value when the job succeeded, otherwise
        ``None`` (failed, cancelled, interrupted, or succeeded without a
        model id in the response).
    """
    print(f"⏳ Polling job {job_id} every {poll_interval}s...")

    start_time = time.time()
    last_status = None

    async with httpx.AsyncClient(timeout=60.0) as client:
        while True:
            try:
                response = await client.get(
                    f"{MODAL_BASE_URL}/v1/fine_tuning/jobs/{job_id}",
                    headers={"Authorization": f"Bearer {MODAL_API_KEY}"}
                )

                if response.status_code != 200:
                    print(f" ⚠️ Failed to get status: {response.status_code}")
                    await asyncio.sleep(poll_interval)
                    continue

                job = response.json()
                status = job.get('status', 'unknown')

                # Only report when the status changes to keep output short.
                if status != last_status:
                    elapsed = time.time() - start_time
                    print(f" Status: {status} (elapsed: {elapsed/60:.1f}m)")
                    last_status = status

                    # Show job details if available
                    if 'hyperparameters' in job:
                        print(f" Model: {job.get('model', 'unknown')}")
                        print(f" Training file: {job.get('training_file', 'unknown')}")

                # Terminal states
                if status == "succeeded":
                    print("🎉 Fine-tuning completed successfully!")
                    final_model = job.get('fine_tuned_model')
                    if final_model:
                        print(f" Final model: {final_model}")
                    else:
                        print(" ⚠️ Job succeeded but the response has no fine_tuned_model")
                    return final_model

                if status == "failed":
                    print("❌ Fine-tuning failed!")
                    if 'error' in job:
                        print(f" Error: {job['error']}")
                    return None

                if status == "cancelled":
                    print("⚠️ Fine-tuning was cancelled")
                    return None

                # Anything else keeps polling; flag statuses we do not know.
                if status not in ("queued", "running", "validating_files"):
                    print(f"⚠️ Unknown status: {status}")
                await asyncio.sleep(poll_interval)

            except KeyboardInterrupt:
                print(f"\n⚠️ Interrupted by user. Job {job_id} is still running on Modal.")
                print(f" Check status later with the job ID: {job_id}")
                return None

            except asyncio.CancelledError:
                # On Python 3.11+ asyncio.run() turns Ctrl-C into a task
                # cancellation, so the hint must be printed here as well.
                # Re-raise so cancellation still propagates.
                print(f"\n⚠️ Interrupted by user. Job {job_id} is still running on Modal.")
                print(f" Check status later with the job ID: {job_id}")
                raise

            except Exception as e:
                print(f"❌ Error polling job: {e}")
                await asyncio.sleep(poll_interval)
                continue


async def main():
    """Command-line entry point: analyse, upload, start the job and poll it.

    Exits with status 1 when the input file is missing, no API key is
    configured, the job does not finish successfully, or an unexpected error
    occurs.
    """
    parser = argparse.ArgumentParser(description="Modal/Synth Fine-Tuning Script")
    parser.add_argument("jsonl_file", type=Path, help="Path to JSONL training file")
    parser.add_argument("--model", default="Qwen/Qwen2.5-7B-Instruct",
                        help="Base model to fine-tune (default: Qwen/Qwen2.5-7B-Instruct)")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=4, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--poll-interval", type=int, default=30,
                        help="Polling interval in seconds (default: 30)")
    parser.add_argument("--subset", type=int, help="Use a random subset of N lines")
    parser.add_argument("--training-type", default="sft", choices=["sft", "dpo"],
                        help="Training type: sft (supervised) or dpo (preference)")

    args = parser.parse_args()

    # A zero or negative subset size would otherwise be ignored (0) or crash
    # later inside the sampler (negative).
    if args.subset is not None and args.subset <= 0:
        parser.error("--subset must be a positive integer")

    # Validate file
    if not args.jsonl_file.exists():
        print(f"❌ File not found: {args.jsonl_file}")
        sys.exit(1)

    if args.jsonl_file.suffix != '.jsonl':
        print(f"⚠️ Warning: File doesn't have .jsonl extension: {args.jsonl_file}")

    # Check API key
    if not MODAL_API_KEY:
        print("❌ Modal API key required. Set MODAL_API_KEY or SYNTH_API_KEY env var")
        sys.exit(1)

    # Analyze tokens first
    line_count, total_tokens, avg_tokens = analyze_jsonl_tokens(args.jsonl_file, args.model)

    # Estimate training time (Modal/Synth throughput varies by model).
    # This is a rough estimate: 0.5 hours per million tokens.
    estimated_hours = total_tokens / 1_000_000 * 0.5
    print(f"\n⏱️ Estimated training time: ~{estimated_hours:.1f} hours")
    print(" (Based on ~2M tokens/hour - actual time may vary)")

    # Use subset if requested
    training_file = args.jsonl_file
    if args.subset:
        if args.subset > line_count:
            print(f"⚠️ Subset size ({args.subset}) exceeds file size ({line_count}). Using all lines.")
        else:
            training_file = create_subset_file(args.jsonl_file, args.subset)

    print("\n🤖 Modal/Synth Fine-Tuning Pipeline")
    print("=" * 50)
    print(f"Training file: {training_file}")
    print(f"Base model: {args.model}")
    print(f"Training type: {args.training_type}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.learning_rate}")
    print("=" * 50)

    try:
        # Step 1: Upload file
        file_id = await upload_file(training_file)

        # Step 2: Create fine-tune job
        ft_config = {
            "n_epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.learning_rate,
        }
        # Forward --training-type; previously it was parsed and printed but
        # the job was always created as "sft".
        job_id = await create_fine_tune_job(
            file_id, args.model, ft_config, training_type=args.training_type
        )

        # Step 3: Poll until completion
        final_model = await poll_job_status(job_id, args.poll_interval)

        if final_model:
            print("\n" + "=" * 50)
            print(f"🎯 SUCCESS! Fine-tuned model ready: {final_model}")
            print("=" * 50)

            # Show usage example. The API key is deliberately NOT interpolated
            # into the snippet so the secret never lands in terminal output.
            usage_lines = [
                "import os",
                "import httpx",
                "",
                "async def test_model():",
                "    async with httpx.AsyncClient() as client:",
                "        response = await client.post(",
                f'            "{MODAL_BASE_URL}/v1/chat/completions",',
                '            headers={"Authorization": "Bearer " + os.environ["SYNTH_API_KEY"]},',
                "            json={",
                f'                "model": "{final_model}",',
                '                "messages": [{"role": "user", "content": "Hello!"}]',
                "            }",
                "        )",
                "        return response.json()",
            ]
            print("\n📝 Usage example:")
            print("\n".join(usage_lines))
        else:
            print("\n❌ Fine-tuning did not complete successfully")
            sys.exit(1)

    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # The poller has already printed how to resume; exit with the usual
        # SIGINT status instead of dumping a traceback.
        sys.exit(130)
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Analyze why actions aren't producing expected results."""
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import glob
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
# Find the latest results file (newest modification time) in the current
# working directory.
result_files = glob.glob("crafter_lm_synth_results_*.json")
if not result_files:
    print("No result files found")
    # SystemExit needs no import and, unlike the site-provided exit(), is
    # always available.
    raise SystemExit(1)

latest_file = max(result_files, key=lambda f: Path(f).stat().st_mtime)
print(f"📊 Analyzing: {latest_file}\n")

with open(latest_file, encoding="utf-8") as f:
    data = json.load(f)

separator = '=' * 60

# Check each episode, skipping those that recorded an error
for ep_idx, episode in enumerate(data.get('results', [])):
    if 'error' in episode:
        continue

    print(f"\n{separator}")
    print(f"EPISODE {ep_idx}")
    print(separator)

    # TODO: per-step inventory tracking is not possible yet because
    # observations are not stored in step_results. Read the actual traces,
    # or add obs to step_results, before analysing individual steps.

    achievements = episode.get('achievements_unlocked', [])
    print(f"\nFinal achievements: {achievements}")
    print(f"Total reward: {episode.get('total_reward', 0)}")
    print(f"Steps: {episode.get('steps', 0)}")

# Summary across all episodes
summary = data.get('summary', {})
print(f"\n{separator}")
print("OVERALL SUMMARY")
print(separator)
print(f"Unique achievements: {summary.get('unique_achievements', [])}")
print(f"Average reward: {summary.get('avg_reward', 0)}")
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Analyze agent actions and model responses from trace data."""
|
|
3
|
+
|
|
4
|
+
import duckdb
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from collections import Counter
|
|
8
|
+
|
|
9
|
+
def _decode_json_object(raw):
    """Decode a JSON column value, returning ``{}`` unless it is a JSON object.

    Falsy input (``None`` or an empty string) yields ``{}``. A value that
    decodes to anything other than a dict also yields ``{}`` so callers can
    use dict methods unconditionally.

    Raises:
        json.JSONDecodeError: if ``raw`` is a non-empty string of invalid JSON.
    """
    if not raw:
        return {}
    decoded = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
    return decoded if isinstance(decoded, dict) else {}


def _print_generation_event(metadata):
    """Print the model name and a short response preview for one generation event."""
    print("\n--- Generation Event ---")
    if 'model' in metadata:
        print(f"Model: {metadata['model']}")

    response = metadata.get('response')
    if isinstance(response, str):
        print(f"Response preview: {response[:200]}")
    elif isinstance(response, dict):
        content = response.get('content')
        # A present-but-null 'content' would make slicing raise TypeError,
        # so only preview real values (stringified to keep the cut uniform).
        if content is not None:
            print(f"Content: {str(content)[:200]}")
        if 'tool_calls' in response:
            print(f"Tool calls: {response['tool_calls']}")


def _report_sessions(conn, max_sessions):
    """Print a per-session report for the first ``max_sessions`` sessions."""
    # Only the columns that are actually read are fetched; event_time is used
    # solely for ordering.
    all_events = conn.execute("""
        SELECT
            e.session_id,
            e.event_type,
            e.metadata,
            e.system_state_after,
            e.reward
        FROM events e
        ORDER BY e.session_id, e.event_time
    """).fetchall()

    print(f"Total events found: {len(all_events)}\n")

    # Group events by session; dict insertion order follows the query's
    # ORDER BY, so sessions are visited in session_id order.
    sessions = {}
    for row in all_events:
        sessions.setdefault(row[0], []).append(row)

    print(f"Total sessions: {len(sessions)}\n")

    for i, (session_id, events) in enumerate(sessions.items(), start=1):
        if i > max_sessions:
            break

        print(f"\n{'='*60}")
        print(f"SESSION {i}: {session_id}")
        print(f"{'='*60}")
        print(f"Total events in session: {len(events)}")

        actions_taken = []
        achievements_unlocked = set()
        total_reward = 0

        for _, event_type, metadata_raw, state_after_raw, reward in events:
            metadata = _decode_json_object(metadata_raw)
            state_after = _decode_json_object(state_after_raw)

            # A NULL reward counts as zero.
            total_reward += reward or 0

            # Runtime events carry the action name in their metadata.
            if event_type == 'runtime' and 'action_name' in metadata:
                actions_taken.append(metadata['action_name'])

            # Collect every achievement flagged truthy in the post-event state.
            public_state = state_after.get('public_state')
            if isinstance(public_state, dict):
                status = public_state.get('achievements_status')
                if isinstance(status, dict):
                    achievements_unlocked.update(
                        name for name, unlocked in status.items() if unlocked
                    )

            if event_type == 'generation' and metadata:
                _print_generation_event(metadata)

        print("\n--- Session Summary ---")
        print(f"Total reward: {total_reward}")
        print(f"Achievements unlocked: {achievements_unlocked if achievements_unlocked else 'None'}")
        print(f"Total actions taken: {len(actions_taken)}")

        if actions_taken:
            print("\nAction distribution:")
            for action, count in Counter(actions_taken).most_common(10):
                print(f"  {action}: {count}")


def _report_achievements(conn):
    """Print the union of achievements unlocked in any recorded state."""
    print(f"\n\n{'='*60}")
    print("OVERALL ANALYSIS")
    print(f"{'='*60}")

    # Every distinct achievement state is scanned (no LIMIT) so that the
    # "none unlocked" verdict below holds for the whole table, not a sample.
    all_achievements = conn.execute("""
        SELECT DISTINCT
            json_extract_string(system_state_after, '$.public_state.achievements_status') as achievements
        FROM events
        WHERE system_state_after IS NOT NULL
        AND json_extract_string(system_state_after, '$.public_state.achievements_status') IS NOT NULL
    """).fetchall()

    print("\nChecking achievement states...")
    unlocked_anywhere = set()
    for (ach_json,) in all_achievements:
        status = _decode_json_object(ach_json)
        unlocked_anywhere.update(name for name, unlocked in status.items() if unlocked)

    if unlocked_anywhere:
        print(f"Found unlocked: {sorted(unlocked_anywhere)}")
    else:
        print("❌ No achievements were unlocked in any session!")


def _report_potential_issues(conn):
    """Print the generation-event count and the key layout of a few metadata samples."""
    print(f"\n\n{'='*60}")
    print("POTENTIAL ISSUES")
    print(f"{'='*60}")

    model_responses = conn.execute("""
        SELECT COUNT(*)
        FROM events
        WHERE event_type = 'generation'
    """).fetchone()[0]

    print(f"Generation events found: {model_responses}")

    if model_responses == 0:
        print("❌ No generation events found - agent may not be responding")

    print("\nChecking event metadata patterns...")
    metadata_samples = conn.execute("""
        SELECT event_type, metadata
        FROM events
        WHERE metadata IS NOT NULL
        AND metadata != '{}'
        AND metadata != '[]'
        LIMIT 5
    """).fetchall()

    for event_type, metadata_str in metadata_samples:
        metadata = json.loads(metadata_str)
        # The query only filters out the empty list, so a non-empty list or a
        # scalar can still appear here; report its type instead of crashing.
        if not isinstance(metadata, dict):
            print(f"\n{event_type}: non-object metadata ({type(metadata).__name__})")
            continue
        print(f"\n{event_type}: {list(metadata.keys())}")
        if 'action_name' in metadata:
            print(f"  Action: {metadata['action_name']}")
        if 'model' in metadata:
            print(f"  Model: {metadata['model']}")


def analyze_agent_actions(db_path: str, max_sessions: int = 3):
    """Print a diagnostic report of agent actions and model responses.

    Opens the DuckDB database at ``db_path`` read-only and queries its
    ``events`` table (columns used: ``session_id``, ``event_type``,
    ``event_time``, ``metadata``, ``system_state_after``, ``reward``).
    Three sections are written to stdout: per-session details, an overall
    achievement check, and a list of potential issues.

    Args:
        db_path: Path handed to ``duckdb.connect``.
        max_sessions: How many sessions get a detailed report. Defaults to
            3, the value that was previously hard-coded.

    Returns:
        None. All results are printed.
    """
    conn = duckdb.connect(db_path, read_only=True)
    # Close the connection even when a query or a JSON decode fails.
    try:
        print("🔍 Analyzing agent actions and responses...\n")
        _report_sessions(conn, max_sessions)
        _report_achievements(conn)
        _report_potential_issues(conn)
    finally:
        conn.close()
|
|
172
|
+
|
|
173
|
+
if __name__ == "__main__":
    # Hard-coded trace database location, relative to the working directory.
    db_path = "./traces_v2_synth/traces.duckdb"
    # Report a missing file up front; otherwise run the analysis on it.
    if not Path(db_path).exists():
        print(f"❌ Database not found at {db_path}")
    else:
        analyze_agent_actions(db_path)
|