synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.4.dev9__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of synth-ai might be problematic.
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +621 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai-0.2.4.dev9.dist-info/METADATA +91 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/RECORD +110 -11
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/entry_points.txt +1 -0
- synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py
@@ -0,0 +1,150 @@

```python
#!/usr/bin/env python3
"""
Vertex AI Fine-Tuning Script (Gemini Flash)
==========================================
Uploads a JSONL file to GCS, starts a Gemini-Flash tuning job, and polls until completion.
"""

import os
import sys
import time
import argparse
import json
import random
from pathlib import Path
from typing import Optional

# --- Lazy-install required packages -----------------------------------------
def _lazy_import(pkg: str, pip_name: Optional[str] = None):
    try:
        return __import__(pkg)
    except ImportError:  # pragma: no cover
        print(f"📦 Installing {pkg}…")
        os.system(f"pip install {pip_name or pkg}")
        return __import__(pkg)

storage = _lazy_import("google.cloud.storage")
aiplatform = _lazy_import("google.cloud.aiplatform")
vertexai = _lazy_import("vertexai")
generative_models = _lazy_import("vertexai.preview.generative_models")

tiktoken = _lazy_import("tiktoken")

# --- Helpers ----------------------------------------------------------------
def encoding_for(model: str = "cl100k_base"):
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        return tiktoken.get_encoding("cl100k_base")

def analyze_jsonl_tokens(file_path: Path, model: str) -> tuple[int, int, float]:
    enc = encoding_for(model)
    inp_tok = out_tok = lines = 0
    with open(file_path, "r") as fh:
        for ln, raw in enumerate(fh, 1):
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            msgs = data.get("messages", [])
            input_text = " ".join(m["content"] for m in msgs[:-1] if m.get("content"))
            output_text = msgs[-1]["content"] if msgs else ""
            inp_tok += len(enc.encode(input_text))
            out_tok += len(enc.encode(output_text))
            lines += 1
    avg = (inp_tok + out_tok) / lines if lines else 0
    print(f"🔍 {lines:,} lines - {inp_tok + out_tok:,} tokens (avg ≈ {avg:.1f}/line)")
    return lines, inp_tok + out_tok, avg

def create_subset_file(src: Path, n: int) -> Path:
    dst = src.with_name(f"{src.stem}_subset_{n}.jsonl")
    lines = [l.strip() for l in src.open() if l.strip()]
    if n < len(lines):
        lines = random.sample(lines, n)
    with dst.open("w") as fh:
        fh.write("\n".join(lines) + "\n")
    return dst

def upload_to_gcs(bucket: str, file_path: Path) -> str:
    client = storage.Client()
    bkt = client.bucket(bucket)
    blob = bkt.blob(file_path.name)
    blob.upload_from_filename(file_path)
    uri = f"gs://{bucket}/{file_path.name}"
    print(f"📤 Uploaded to {uri}")
    return uri

def start_tuning_job(project: str, region: str, gcs_uri: str,
                     base_model: str, display_name: str,
                     epochs: int, lr_mult: float):
    vertexai.init(project=project, location=region)
    model = generative_models.GenerativeModel(base_model)
    job = model.tune_model(
        training_data=generative_models.FileData(
            path=gcs_uri, mime_type="jsonl"),
        tuned_model_display_name=display_name,
        hyperparameters={
            "epochs": epochs,
            "learning_rate_multiplier": lr_mult,
        },
    )
    return job

def wait(job, poll: int):
    print("⏳ Waiting for tuning to finish…")
    while True:
        job.refresh()
        state = job.state
        print(f"   Status: {state}")
        if state in ("SUCCEEDED", "FAILED", "CANCELLED"):
            break
        time.sleep(poll)
    return state

# --- Main -------------------------------------------------------------------
def main():
    ap = argparse.ArgumentParser(description="Vertex AI Gemini-Flash Fine-Tuning")
    ap.add_argument("jsonl_file", type=Path, help="Training data (JSONL)")
    ap.add_argument("--project", required=True)
    ap.add_argument("--region", default="us-central1")
    ap.add_argument("--bucket", required=True, help="GCS bucket for upload")
    ap.add_argument("--model", default="gemini-1.0-flash")
    ap.add_argument("--display-name", default="gemini-flash-tuned")
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--lr-mult", type=float, default=0.05)
    ap.add_argument("--poll-interval", type=int, default=60)
    args = ap.parse_args()

    if not args.jsonl_file.exists():
        sys.exit("❌ Training file not found.")

    lines, total_tok, _ = analyze_jsonl_tokens(args.jsonl_file, args.model)

    subset = input("Use subset? (y/N): ").strip().lower()
    train_file = args.jsonl_file
    if subset == "y":
        n = int(input(f"Lines (1-{lines}): "))
        train_file = create_subset_file(args.jsonl_file, n)

    gcs_uri = upload_to_gcs(args.bucket, train_file)
    job = start_tuning_job(args.project, args.region, gcs_uri,
                           args.model, args.display_name,
                           args.epochs, args.lr_mult)
    final_state = wait(job, args.poll_interval)

    if final_state == "SUCCEEDED":
        print(f"🎉 Tuned model: {job.resource_name}")
        print("\n📝 Usage example:")
        print(f"""
from vertexai.preview import generative_models
vertexai.init(project="{args.project}", location="{args.region}")
model = generative_models.GenerativeModel("{job.resource_name}")
resp = model.generate_content("Hello!")
print(resp.text)
""")
    else:
        sys.exit("❌ Tuning did not succeed.")

if __name__ == "__main__":
    main()
```
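For orientation, here is a minimal, hypothetical sketch of running the script above non-interactively. The script path is inferred from the file list (old/kick_off_ft_gemini.py), and the project, bucket, and JSONL file names are placeholders; the flags mirror the argparse definitions shown above, and "n" is piped to stdin to decline the interactive subset prompt. A real run additionally needs Google Cloud credentials with access to the project and bucket.

```python
# Hypothetical invocation sketch, not part of the package. Path, project,
# bucket, and training file are placeholder assumptions.
import subprocess

subprocess.run(
    [
        "python",
        "synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py",
        "crafter_sft.jsonl",
        "--project", "my-gcp-project",
        "--bucket", "my-tuning-bucket",
        "--region", "us-central1",
        "--model", "gemini-1.0-flash",
        "--epochs", "3",
        "--lr-mult", "0.05",
        "--poll-interval", "60",
    ],
    input=b"n\n",  # decline the "Use subset?" prompt
    check=True,
)
```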
synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py
@@ -0,0 +1,283 @@

```python
#!/usr/bin/env python3
"""
Script to kick off fine-tuning jobs on Modal using generated Crafter rollout data
"""

import argparse
import asyncio
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Optional

import httpx

# Modal API configuration
MODAL_BASE_URL = "https://synth-laboratories--unified-ft-service-fastapi-app.modal.run"
MODAL_API_KEY = os.environ.get("MODAL_API_KEY", "sk-test-11111111111111111111111111111111")

# Default hyperparameters for Crafter fine-tuning
DEFAULT_HYPERPARAMS = {
    "n_epochs": 3,
    "batch_size": 4,
    "learning_rate_multiplier": 2.0,
    "use_qlora": False,  # Can enable for larger models
}

# Supported base models for fine-tuning
SUPPORTED_MODELS = [
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-14B-Instruct",
    "Qwen/Qwen2.5-32B-Instruct",
    "Qwen/Qwen2.5-72B-Instruct",
]


async def upload_training_file(file_path: Path, api_key: str) -> str:
    """Upload a training file to Modal and return the file ID."""
    async with httpx.AsyncClient(timeout=60.0) as client:
        headers = {"Authorization": f"Bearer {api_key}"}

        print(f"📤 Uploading {file_path.name}...")

        with open(file_path, "rb") as f:
            files = {"file": (file_path.name, f, "application/jsonl")}
            response = await client.post(
                f"{MODAL_BASE_URL}/v1/files?purpose=fine-tune",
                files=files,
                headers=headers
            )

        if response.status_code != 200:
            raise Exception(f"Failed to upload file: {response.status_code} - {response.text}")

        file_data = response.json()
        file_id = file_data["id"]
        print(f"✅ Uploaded successfully: {file_id}")
        return file_id


async def create_fine_tuning_job(
    base_model: str,
    training_file_id: str,
    suffix: str,
    hyperparameters: dict,
    api_key: str,
    validation_file_id: Optional[str] = None
) -> str:
    """Create a fine-tuning job and return the job ID."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": base_model,
            "training_file": training_file_id,
            "hyperparameters": hyperparameters,
            "suffix": suffix
        }

        if validation_file_id:
            payload["validation_file"] = validation_file_id

        print(f"🚀 Creating fine-tuning job...")
        print(f"   Base model: {base_model}")
        print(f"   Suffix: {suffix}")
        print(f"   Hyperparameters: {json.dumps(hyperparameters, indent=2)}")

        response = await client.post(
            f"{MODAL_BASE_URL}/v1/fine_tuning/jobs",
            json=payload,
            headers=headers
        )

        if response.status_code != 200:
            raise Exception(f"Failed to create job: {response.status_code} - {response.text}")

        job_data = response.json()
        job_id = job_data["id"]
        print(f"✅ Job created: {job_id}")
        return job_id


async def monitor_job(job_id: str, api_key: str) -> dict:
    """Monitor a fine-tuning job until completion."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        headers = {"Authorization": f"Bearer {api_key}"}

        print(f"\n📊 Monitoring job {job_id}...")

        while True:
            # Get job status
            response = await client.get(
                f"{MODAL_BASE_URL}/v1/fine_tuning/jobs/{job_id}",
                headers=headers
            )

            if response.status_code != 200:
                print(f"⚠️ Failed to get job status: {response.text}")
                await asyncio.sleep(10)
                continue

            job_data = response.json()
            status = job_data["status"]

            # Print status update
            print(f"   Status: {status}")

            # Get recent events
            events_response = await client.get(
                f"{MODAL_BASE_URL}/v1/fine_tuning/jobs/{job_id}/events?limit=5",
                headers=headers
            )

            if events_response.status_code == 200:
                events = events_response.json()["data"]
                for event in events[:2]:  # Show last 2 events
                    print(f"   - {event.get('message', 'No message')}")

            # Check if job is complete
            if status in ["succeeded", "failed", "cancelled"]:
                if status == "succeeded":
                    print(f"\n✅ Fine-tuning completed successfully!")
                    print(f"   Model ID: {job_data['fine_tuned_model']}")
                else:
                    print(f"\n❌ Fine-tuning {status}")
                    if job_data.get("error"):
                        print(f"   Error: {job_data['error']}")

                return job_data

            # Wait before next check
            await asyncio.sleep(30)


async def validate_training_data(file_path: Path) -> tuple[int, int]:
    """Validate training data and return (num_examples, num_tokens_estimate)."""
    num_examples = 0
    total_chars = 0

    with open(file_path, "r") as f:
        for line in f:
            try:
                data = json.loads(line)
                if "messages" in data:
                    num_examples += 1
                    # Rough token estimate (chars/4)
                    total_chars += len(json.dumps(data["messages"]))
            except json.JSONDecodeError:
                print(f"⚠️ Invalid JSON line: {line[:50]}...")

    estimated_tokens = total_chars // 4
    return num_examples, estimated_tokens


async def main():
    parser = argparse.ArgumentParser(description="Kick off fine-tuning jobs on Modal")
    parser.add_argument("training_file", type=str, help="Path to training data (JSONL)")
    parser.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-0.5B-Instruct",
                        choices=SUPPORTED_MODELS, help="Base model to fine-tune")
    parser.add_argument("--suffix", type=str, default=None,
                        help="Model suffix (default: crafter-TIMESTAMP)")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=4,
                        help="Training batch size")
    parser.add_argument("--learning-rate", type=float, default=2.0,
                        help="Learning rate multiplier")
    parser.add_argument("--use-qlora", action="store_true",
                        help="Enable QLoRA for efficient training")
    parser.add_argument("--validation-file", type=str, default=None,
                        help="Optional validation data file")
    parser.add_argument("--api-key", type=str, default=None,
                        help="Modal API key (or set MODAL_API_KEY)")
    parser.add_argument("--no-monitor", action="store_true",
                        help="Don't monitor job after creation")

    args = parser.parse_args()

    # Get API key
    api_key = args.api_key or os.environ.get("MODAL_API_KEY", MODAL_API_KEY)

    # Generate suffix if not provided
    if args.suffix is None:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        args.suffix = f"crafter-{timestamp}"

    # Validate training file
    training_path = Path(args.training_file)
    if not training_path.exists():
        print(f"❌ Training file not found: {training_path}")
        return

    print(f"🔍 Validating training data...")
    num_examples, est_tokens = await validate_training_data(training_path)
    print(f"   Examples: {num_examples}")
    print(f"   Estimated tokens: {est_tokens:,}")

    if num_examples < 10:
        print("⚠️ Warning: Very few training examples. Consider generating more data.")

    # Prepare hyperparameters
    hyperparams = {
        "n_epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate_multiplier": args.learning_rate,
        "use_qlora": args.use_qlora
    }

    try:
        # Upload training file
        training_file_id = await upload_training_file(training_path, api_key)

        # Upload validation file if provided
        validation_file_id = None
        if args.validation_file:
            val_path = Path(args.validation_file)
            if val_path.exists():
                validation_file_id = await upload_training_file(val_path, api_key)

        # Create fine-tuning job
        job_id = await create_fine_tuning_job(
            base_model=args.base_model,
            training_file_id=training_file_id,
            suffix=args.suffix,
            hyperparameters=hyperparams,
            api_key=api_key,
            validation_file_id=validation_file_id
        )

        # Monitor job unless disabled
        if not args.no_monitor:
            job_data = await monitor_job(job_id, api_key)

            if job_data["status"] == "succeeded":
                print("\n🎉 Fine-tuning complete!")
                print(f"Your model is ready: {job_data['fine_tuned_model']}")
                print("\nTo use your model:")
                print(f"  curl -X POST {MODAL_BASE_URL}/v1/chat/completions \\")
                print(f"    -H 'Authorization: Bearer YOUR_API_KEY' \\")
                print(f"    -H 'Content-Type: application/json' \\")
                print(f"    -d '{{")
                print(f'      "model": "{job_data["fine_tuned_model"]}",')
                print(f'      "messages": [{{"role": "user", "content": "Hello!"}}]')
                print(f"      }}'")
        else:
            print(f"\nJob created: {job_id}")
            print(f"Check status at: {MODAL_BASE_URL}/v1/fine_tuning/jobs/{job_id}")

    except Exception as e:
        print(f"\n❌ Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    asyncio.run(main())
```
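As a rough illustration, the helpers in the script above can also be driven programmatically. The sketch below is not part of the package: the module name kick_off_ft_modal and the training-data path are assumptions, and the test key from the script is only a placeholder; the Modal service would need a real MODAL_API_KEY to accept these requests.

```python
# Minimal usage sketch (not from the package). Assumes the script above is
# importable as kick_off_ft_modal and that crafter_sft.jsonl exists locally.
import asyncio
from pathlib import Path

from kick_off_ft_modal import (
    DEFAULT_HYPERPARAMS,
    create_fine_tuning_job,
    monitor_job,
    upload_training_file,
    validate_training_data,
)


async def run_demo() -> None:
    train = Path("crafter_sft.jsonl")  # placeholder training data path
    api_key = "sk-test-11111111111111111111111111111111"  # placeholder key

    # Validate locally, upload, create the job, then poll to a terminal state,
    # mirroring the flow in main() above.
    n, tokens = await validate_training_data(train)
    print(f"{n} examples, ~{tokens:,} estimated tokens")

    file_id = await upload_training_file(train, api_key)
    job_id = await create_fine_tuning_job(
        base_model="Qwen/Qwen2.5-0.5B-Instruct",
        training_file_id=file_id,
        suffix="crafter-demo",
        hyperparameters=DEFAULT_HYPERPARAMS,
        api_key=api_key,
    )
    await monitor_job(job_id, api_key)


if __name__ == "__main__":
    asyncio.run(run_demo())
```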