synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (112) hide show
  1. synth_ai/__init__.py +1 -1
  2. synth_ai/cli/__init__.py +6 -0
  3. synth_ai/cli/demo.py +68 -9
  4. synth_ai/cli/rl_demo.py +137 -0
  5. synth_ai/cli/root.py +65 -0
  6. synth_ai/demos/core/__init__.py +1 -0
  7. synth_ai/demos/core/cli.py +685 -0
  8. synth_ai/demos/demo_task_apps/__init__.py +1 -0
  9. synth_ai/demos/demo_task_apps/core.py +374 -0
  10. synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
  11. synth_ai/demos/demo_task_apps/math/app.py +37 -0
  12. synth_ai/demos/demo_task_apps/math/config.toml +44 -0
  13. synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
  14. synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
  15. synth_ai/environments/examples/bandit/__init__.py +33 -0
  16. synth_ai/environments/examples/bandit/engine.py +294 -0
  17. synth_ai/environments/examples/bandit/environment.py +194 -0
  18. synth_ai/environments/examples/bandit/taskset.py +200 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
  26. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
  27. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
  28. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
  29. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
  30. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
  31. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
  32. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
  33. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
  34. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
  35. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
  36. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
  37. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
  38. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
  39. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
  40. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
  41. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
  42. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
  43. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
  44. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
  45. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
  46. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
  47. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
  48. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
  49. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
  50. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
  51. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
  52. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
  53. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
  54. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  55. synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
  56. synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
  57. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
  58. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
  59. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
  60. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
  61. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
  62. synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
  63. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
  64. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
  65. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
  66. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
  67. synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
  68. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
  69. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
  70. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
  71. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
  72. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
  73. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
  74. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
  75. synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
  76. synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
  77. synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
  78. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
  79. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
  80. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
  81. synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
  82. synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
  83. synth_ai/environments/examples/crafter_classic/environment.py +41 -2
  84. synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
  85. synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
  86. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
  87. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
  88. synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
  89. synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
  90. synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
  91. synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
  92. synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
  93. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  94. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
  95. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  96. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
  97. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  98. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  99. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
  100. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  101. synth_ai/environments/examples/red/units/__init__.py +1 -0
  102. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
  103. synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/install_sqld.sh +40 -0
  106. synth_ai-0.2.5.dist-info/METADATA +106 -0
  107. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/RECORD +111 -12
  108. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/entry_points.txt +1 -0
  109. synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
  110. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/WHEEL +0 -0
  111. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/licenses/LICENSE +0 -0
  112. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,150 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Vertex AI Fine‑Tuning Script (Gemini Flash)
5
+ ==========================================
6
+ Uploads a JSONL file to GCS, starts a Gemini‑Flash tuning job, and polls until completion.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import time
12
+ import argparse
13
+ import json
14
+ import random
15
+ from pathlib import Path
16
+ from typing import Optional
17
+
18
+ # --- Lazy‑install required packages -----------------------------------------
19
def _lazy_import(pkg: str, pip_name: Optional[str] = None):
    """Import ``pkg`` and return it, installing it with pip if the import fails.

    Args:
        pkg: Dotted module path, e.g. ``"google.cloud.storage"``.
        pip_name: Distribution name passed to ``pip install`` when the first
            import attempt fails; defaults to ``pkg``.

    Returns:
        The module named by ``pkg`` itself. For a dotted path this is the leaf
        submodule, not the top-level package.

    Raises:
        ImportError: If the module still cannot be imported after the install
            attempt.
    """
    import importlib
    import subprocess
    import sys

    try:
        # import_module returns the leaf module; __import__("a.b.c") would
        # return the top-level package "a" instead.
        return importlib.import_module(pkg)
    except ImportError:  # pragma: no cover
        print(f"📦 Installing {pkg}…")
        # Install into the interpreter that is running this script rather than
        # whichever "pip" happens to be first on PATH.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", pip_name or pkg],
            check=False,
        )
        # Let the import system notice the freshly installed files.
        importlib.invalidate_caches()
        return importlib.import_module(pkg)
26
+
27
# Third-party SDKs are resolved lazily. An explicit pip name is passed because
# a dotted import path is not necessarily an installable distribution name
# (the ``vertexai`` modules are distributed in ``google-cloud-aiplatform``).
storage = _lazy_import("google.cloud.storage", "google-cloud-storage")
aiplatform = _lazy_import("google.cloud.aiplatform", "google-cloud-aiplatform")
vertexai = _lazy_import("vertexai", "google-cloud-aiplatform")
generative_models = _lazy_import(
    "vertexai.preview.generative_models", "google-cloud-aiplatform"
)

tiktoken = _lazy_import("tiktoken")
33
+
34
+ # --- Helpers ----------------------------------------------------------------
35
def encoding_for(model: str = "cl100k_base"):
    """Return the tiktoken encoding for ``model``.

    When tiktoken raises ``KeyError`` for the given name, the ``cl100k_base``
    encoding is returned instead.
    """
    fallback_name = "cl100k_base"
    try:
        encoder = tiktoken.encoding_for_model(model)
    except KeyError:
        encoder = tiktoken.get_encoding(fallback_name)
    return encoder
40
+
41
def analyze_jsonl_tokens(file_path: Path, model: str) -> tuple[int, int, float]:
    """Count tokens in a chat-style JSONL file and print a one-line summary.

    Each line is expected to be a JSON object with a ``"messages"`` list. All
    messages except the last are treated as input, and the last as output.
    Lines that are not valid JSON, or whose top-level value is not an object,
    are skipped.

    Args:
        file_path: Path of the JSONL file to analyze.
        model: Model name handed to ``encoding_for`` to select the tokenizer.
            NOTE(review): tiktoken encodings are presumably only an
            approximation for Gemini models; treat counts as estimates.

    Returns:
        ``(lines, total_tokens, avg_tokens_per_line)`` where ``lines`` counts
        the records that were analyzed.
    """
    enc = encoding_for(model)
    inp_tok = out_tok = lines = 0
    with open(file_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                # A bare list/number/string has no "messages" to analyze.
                continue
            msgs = data.get("messages") or []
            input_text = " ".join(
                m["content"] for m in msgs[:-1] if m.get("content")
            )
            # The final message may lack "content" or carry null.
            output_text = (msgs[-1].get("content") or "") if msgs else ""
            inp_tok += len(enc.encode(input_text))
            out_tok += len(enc.encode(output_text))
            lines += 1
    total = inp_tok + out_tok
    avg = total / lines if lines else 0.0
    print(f"🔍 {lines:,} lines ‑ {inp_tok+out_tok:,} tokens (avg ≈ {avg:.1f}/line)")
    return lines, total, avg
59
+
60
def create_subset_file(src: Path, n: int) -> Path:
    """Write a random subset of ``src``'s non-blank lines to a sibling file.

    Args:
        src: Source JSONL file.
        n: Number of lines to keep. When ``n`` is at least the number of
            non-blank lines, every line is kept in its original order.

    Returns:
        Path of the written file, ``<src stem>_subset_<n>.jsonl`` in the same
        directory as ``src``.

    Raises:
        ValueError: If ``n`` is negative (raised by ``random.sample``).
    """
    dst = src.with_name(f"{src.stem}_subset_{n}.jsonl")
    # Read inside a context manager so the source handle is closed promptly.
    with src.open("r", encoding="utf-8") as fh:
        lines = [stripped for raw in fh if (stripped := raw.strip())]
    if n < len(lines):
        lines = random.sample(lines, n)
    with dst.open("w", encoding="utf-8") as fh:
        # One newline per record; an empty selection yields an empty file
        # rather than a single blank line.
        fh.writelines(f"{line}\n" for line in lines)
    return dst
68
+
69
def upload_to_gcs(bucket: str, file_path: Path) -> str:
    """Upload ``file_path`` to the GCS bucket ``bucket`` and return its URI.

    The object is stored under the file's base name, so the returned URI has
    the form ``gs://<bucket>/<file name>``.
    """
    object_name = file_path.name
    target_blob = storage.Client().bucket(bucket).blob(object_name)
    target_blob.upload_from_filename(file_path)
    uri = f"gs://{bucket}/{object_name}"
    print(f"📤 Uploaded to {uri}")
    return uri
77
+
78
def start_tuning_job(project: str, region: str, gcs_uri: str,
                     base_model: str, display_name: str,
                     epochs: int, lr_mult: float):
    """Initialise Vertex AI and launch a tuning job for ``base_model``.

    Args:
        project: Cloud project passed to ``vertexai.init``.
        region: Location passed to ``vertexai.init``.
        gcs_uri: ``gs://`` URI of the training JSONL file.
        base_model: Name of the model to tune.
        display_name: Display name for the tuned model.
        epochs: Value sent as the ``epochs`` hyperparameter.
        lr_mult: Value sent as ``learning_rate_multiplier``.

    Returns:
        Whatever ``tune_model`` returns (used by the caller as a job handle).

    NOTE(review): confirm that ``GenerativeModel.tune_model`` and
    ``generative_models.FileData`` exist with these signatures in the
    installed SDK version.
    """
    vertexai.init(project=project, location=region)
    launch = generative_models.GenerativeModel(base_model).tune_model
    dataset = generative_models.FileData(path=gcs_uri, mime_type="jsonl")
    settings = {
        "epochs": epochs,
        "learning_rate_multiplier": lr_mult,
    }
    return launch(
        training_data=dataset,
        tuned_model_display_name=display_name,
        hyperparameters=settings,
    )
93
+
94
def wait(job, poll: int):
    """Block until ``job`` reports a terminal state and return that state.

    The job is refreshed and its state printed on every iteration, with a
    sleep of ``poll`` seconds between checks. Polling stops once the state
    equals ``"SUCCEEDED"``, ``"FAILED"`` or ``"CANCELLED"``.
    NOTE(review): this assumes ``job.state`` compares equal to these strings;
    confirm for the SDK's job type.
    """
    finished = ("SUCCEEDED", "FAILED", "CANCELLED")

    def _check():
        # Refresh first so the reported state is current.
        job.refresh()
        current = job.state
        print(f" Status: {current}")
        return current

    print("⏳ Waiting for tuning to finish…")
    state = _check()
    while state not in finished:
        time.sleep(poll)
        state = _check()
    return state
104
+
105
+ # --- Main -------------------------------------------------------------------
106
def main():
    """Command-line entry point: analyze, optionally subsample, upload, tune, wait.

    Parses the CLI arguments, prints token statistics for the training file,
    and interactively offers to train on a random subset. It then uploads the
    chosen file to GCS, starts the tuning job and polls it until it finishes.

    Exits with an error message when the training file is missing, the subset
    size is invalid, or the job ends in any state other than ``"SUCCEEDED"``.
    """
    ap = argparse.ArgumentParser(description="Vertex AI Gemini‑Flash Fine‑Tuning")
    ap.add_argument("jsonl_file", type=Path, help="Training data (JSONL)")
    ap.add_argument("--project", required=True)
    # Defaults that are sent to the API must use plain ASCII hyphens; a
    # typographic non-breaking hyphen yields an invalid region/model name.
    ap.add_argument("--region", default="us-central1")
    ap.add_argument("--bucket", required=True, help="GCS bucket for upload")
    ap.add_argument("--model", default="gemini-1.0-flash")
    ap.add_argument("--display-name", default="gemini-flash-tuned")
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--lr-mult", type=float, default=0.05)
    ap.add_argument("--poll-interval", type=int, default=60)
    args = ap.parse_args()

    if not args.jsonl_file.exists():
        sys.exit("❌ Training file not found.")

    lines, total_tok, _ = analyze_jsonl_tokens(args.jsonl_file, args.model)

    subset = input("Use subset? (y/N): ").strip().lower()
    train_file = args.jsonl_file
    if subset == "y":
        try:
            n = int(input(f"Lines (1‑{lines}): "))
        except ValueError:
            sys.exit("❌ Subset size must be an integer.")
        if n < 1:
            sys.exit("❌ Subset size must be at least 1.")
        train_file = create_subset_file(args.jsonl_file, n)

    gcs_uri = upload_to_gcs(args.bucket, train_file)
    job = start_tuning_job(args.project, args.region, gcs_uri,
                           args.model, args.display_name,
                           args.epochs, args.lr_mult)
    final_state = wait(job, args.poll_interval)

    if final_state == "SUCCEEDED":
        print(f"🎉 Tuned model: {job.resource_name}")
        print("\n📝 Usage example:")
        # The snippet imports vertexai itself so it runs when pasted as-is.
        print(f"""
import vertexai
from vertexai.preview import generative_models
vertexai.init(project="{args.project}", location="{args.region}")
model = generative_models.GenerativeModel("{job.resource_name}")
resp = model.generate_content("Hello!")
print(resp.text)
""")
    else:
        sys.exit("❌ Tuning did not succeed.")


if __name__ == "__main__":
    main()
@@ -0,0 +1,283 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to kick off fine-tuning jobs on Modal using generated Crafter rollout data
4
+ """
5
+
6
+ import argparse
7
+ import asyncio
8
+ import json
9
+ import os
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ import httpx
15
+
16
+ # Modal API configuration
17
MODAL_BASE_URL = "https://synth-laboratories--unified-ft-service-fastapi-app.modal.run"
# Read once at import time; a placeholder test key is used when the
# MODAL_API_KEY environment variable is not set.
MODAL_API_KEY = os.getenv("MODAL_API_KEY", "sk-test-11111111111111111111111111111111")

# Default hyperparameters for Crafter fine-tuning
DEFAULT_HYPERPARAMS = dict(
    n_epochs=3,
    batch_size=4,
    learning_rate_multiplier=2.0,
    use_qlora=False,  # Can enable for larger models
)

# Supported base models for fine-tuning (Qwen2.5 Instruct, by parameter count)
SUPPORTED_MODELS = [
    f"Qwen/Qwen2.5-{size}B-Instruct"
    for size in ("0.5", "1.5", "3", "7", "14", "32", "72")
]
38
+
39
+
40
async def upload_training_file(file_path: Path, api_key: str) -> str:
    """Upload a JSONL file to the fine-tuning service and return its file ID.

    Args:
        file_path: Local file to send as the multipart ``file`` field.
        api_key: Bearer token for the ``Authorization`` header.

    Returns:
        The ``"id"`` field of the JSON response.

    Raises:
        Exception: If the service answers with any status other than 200.
    """
    endpoint = f"{MODAL_BASE_URL}/v1/files?purpose=fine-tune"
    async with httpx.AsyncClient(timeout=60.0) as http:
        auth_headers = {"Authorization": f"Bearer {api_key}"}

        print(f"📤 Uploading {file_path.name}...")

        with open(file_path, "rb") as handle:
            multipart = {"file": (file_path.name, handle, "application/jsonl")}
            reply = await http.post(endpoint, files=multipart, headers=auth_headers)

        if reply.status_code != 200:
            raise Exception(f"Failed to upload file: {reply.status_code} - {reply.text}")

        uploaded_id = reply.json()["id"]
        print(f"✅ Uploaded successfully: {uploaded_id}")
        return uploaded_id
62
+
63
+
64
async def create_fine_tuning_job(
    base_model: str,
    training_file_id: str,
    suffix: str,
    hyperparameters: dict,
    api_key: str,
    validation_file_id: Optional[str] = None
) -> str:
    """Submit a fine-tuning job and return the job ID reported by the service.

    Args:
        base_model: Model identifier sent as ``"model"``.
        training_file_id: ID of a previously uploaded training file.
        suffix: Suffix for the fine-tuned model's name.
        hyperparameters: Hyperparameter mapping sent verbatim.
        api_key: Bearer token for the ``Authorization`` header.
        validation_file_id: Optional uploaded validation file; it is only
            included in the request when truthy.

    Returns:
        The ``"id"`` field of the JSON response.

    Raises:
        Exception: If the service answers with any status other than 200.
    """
    async with httpx.AsyncClient(timeout=30.0) as http:
        request_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        body = {
            "model": base_model,
            "training_file": training_file_id,
            "hyperparameters": hyperparameters,
            "suffix": suffix,
            # Appended last, and only when a validation file was supplied.
            **({"validation_file": validation_file_id} if validation_file_id else {}),
        }

        for message in (
            "🚀 Creating fine-tuning job...",
            f" Base model: {base_model}",
            f" Suffix: {suffix}",
            f" Hyperparameters: {json.dumps(hyperparameters, indent=2)}",
        ):
            print(message)

        reply = await http.post(
            f"{MODAL_BASE_URL}/v1/fine_tuning/jobs",
            json=body,
            headers=request_headers
        )

        if reply.status_code != 200:
            raise Exception(f"Failed to create job: {reply.status_code} - {reply.text}")

        new_job_id = reply.json()["id"]
        print(f"✅ Job created: {new_job_id}")
        return new_job_id
107
+
108
+
109
async def monitor_job(
    job_id: str,
    api_key: str,
    poll_interval: float = 30.0,
    retry_interval: float = 10.0,
    max_consecutive_failures: int = 30,
) -> dict:
    """Poll a fine-tuning job until it reaches a terminal state.

    Each round prints the job status and up to two of the events returned by
    the events endpoint. Polling ends when the status is ``"succeeded"``,
    ``"failed"`` or ``"cancelled"``.

    Args:
        job_id: Identifier of the job to watch.
        api_key: Bearer token for the ``Authorization`` header.
        poll_interval: Seconds to sleep between successful status checks.
        retry_interval: Seconds to sleep after a failed status check.
        max_consecutive_failures: Number of non-200 status responses in a row
            after which monitoring gives up instead of retrying forever
            (e.g. for an unknown job ID or a rejected API key).

    Returns:
        The job's JSON payload as of the final status check.

    Raises:
        Exception: If the status endpoint fails ``max_consecutive_failures``
            times in a row.
    """
    terminal_states = ("succeeded", "failed", "cancelled")
    job_url = f"{MODAL_BASE_URL}/v1/fine_tuning/jobs/{job_id}"
    failures = 0

    async with httpx.AsyncClient(timeout=30.0) as client:
        headers = {"Authorization": f"Bearer {api_key}"}

        print(f"\n📊 Monitoring job {job_id}...")

        while True:
            # Get job status
            response = await client.get(job_url, headers=headers)

            if response.status_code != 200:
                failures += 1
                print(f"⚠️ Failed to get job status: {response.text}")
                if failures >= max_consecutive_failures:
                    raise Exception(
                        f"Giving up on job {job_id}: status check failed "
                        f"{failures} times in a row "
                        f"(last response: {response.status_code})"
                    )
                await asyncio.sleep(retry_interval)
                continue

            failures = 0
            job_data = response.json()
            status = job_data["status"]

            # Print status update
            print(f" Status: {status}")

            # Get recent events
            events_response = await client.get(
                f"{job_url}/events?limit=5",
                headers=headers
            )

            if events_response.status_code == 200:
                events = events_response.json().get("data", [])
                # Show the first two entries returned (presumably the newest
                # first; verify against the service's ordering).
                for event in events[:2]:
                    print(f" - {event.get('message', 'No message')}")

            # Check if job is complete
            if status in terminal_states:
                if status == "succeeded":
                    print(f"\n✅ Fine-tuning completed successfully!")
                    print(f" Model ID: {job_data['fine_tuned_model']}")
                else:
                    print(f"\n❌ Fine-tuning {status}")
                    if job_data.get("error"):
                        print(f" Error: {job_data['error']}")

                return job_data

            # Wait before next check
            await asyncio.sleep(poll_interval)
159
+
160
+
161
async def validate_training_data(file_path: Path) -> tuple[int, int]:
    """Validate training data and return (num_examples, num_tokens_estimate).

    A line counts as an example when it parses to a JSON object containing a
    ``"messages"`` key. Blank lines are ignored, unparseable lines are
    reported and skipped, and JSON values that are not objects are skipped.

    Args:
        file_path: Path of the JSONL training file.

    Returns:
        ``(num_examples, estimated_tokens)`` where the estimate is the total
        length of the JSON-serialised ``messages`` values divided by four.
    """
    num_examples = 0
    total_chars = 0

    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # A trailing or separating blank line is not a data error.
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                print(f"⚠️ Invalid JSON line: {line[:50]}...")
                continue
            # Only objects can carry a "messages" key; a bare number would
            # raise on the membership test and a string would match by
            # substring.
            if isinstance(data, dict) and "messages" in data:
                num_examples += 1
                # Rough token estimate (chars/4)
                total_chars += len(json.dumps(data["messages"]))

    estimated_tokens = total_chars // 4
    return num_examples, estimated_tokens
179
+
180
+
181
async def main():
    """Command-line entry point: validate, upload, create and monitor a job.

    Parses the CLI arguments, validates the training file, uploads it (and
    the optional validation file), creates the fine-tuning job and, unless
    ``--no-monitor`` is given, follows it to completion.

    Returns:
        ``0`` on success, and ``1`` when the training file is missing, holds
        no usable examples, or any step raises.
    """
    parser = argparse.ArgumentParser(description="Kick off fine-tuning jobs on Modal")
    parser.add_argument("training_file", type=str, help="Path to training data (JSONL)")
    parser.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-0.5B-Instruct",
                        choices=SUPPORTED_MODELS, help="Base model to fine-tune")
    parser.add_argument("--suffix", type=str, default=None,
                        help="Model suffix (default: crafter-TIMESTAMP)")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=4,
                        help="Training batch size")
    parser.add_argument("--learning-rate", type=float, default=2.0,
                        help="Learning rate multiplier")
    parser.add_argument("--use-qlora", action="store_true",
                        help="Enable QLoRA for efficient training")
    parser.add_argument("--validation-file", type=str, default=None,
                        help="Optional validation data file")
    parser.add_argument("--api-key", type=str, default=None,
                        help="Modal API key (or set MODAL_API_KEY)")
    parser.add_argument("--no-monitor", action="store_true",
                        help="Don't monitor job after creation")

    args = parser.parse_args()

    # Get API key
    api_key = args.api_key or os.environ.get("MODAL_API_KEY", MODAL_API_KEY)

    # Generate suffix if not provided
    if args.suffix is None:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        args.suffix = f"crafter-{timestamp}"

    # Validate training file
    training_path = Path(args.training_file)
    if not training_path.exists():
        print(f"❌ Training file not found: {training_path}")
        # Report failure through the exit status, not just the message.
        return 1

    print(f"🔍 Validating training data...")
    num_examples, est_tokens = await validate_training_data(training_path)
    print(f" Examples: {num_examples}")
    print(f" Estimated tokens: {est_tokens:,}")

    if num_examples == 0:
        # Nothing to train on; avoid uploading and starting a pointless job.
        print("❌ No training examples with a \"messages\" field were found.")
        return 1

    if num_examples < 10:
        print("⚠️ Warning: Very few training examples. Consider generating more data.")

    # Prepare hyperparameters
    hyperparams = {
        "n_epochs": args.epochs,
        "batch_size": args.batch_size,
        "learning_rate_multiplier": args.learning_rate,
        "use_qlora": args.use_qlora
    }

    try:
        # Upload training file
        training_file_id = await upload_training_file(training_path, api_key)

        # Upload validation file if provided
        validation_file_id = None
        if args.validation_file:
            val_path = Path(args.validation_file)
            if val_path.exists():
                validation_file_id = await upload_training_file(val_path, api_key)
            else:
                # Make the omission visible instead of silently dropping it.
                print(f"⚠️ Validation file not found, continuing without it: {val_path}")

        # Create fine-tuning job
        job_id = await create_fine_tuning_job(
            base_model=args.base_model,
            training_file_id=training_file_id,
            suffix=args.suffix,
            hyperparameters=hyperparams,
            api_key=api_key,
            validation_file_id=validation_file_id
        )

        # Monitor job unless disabled
        if not args.no_monitor:
            job_data = await monitor_job(job_id, api_key)

            if job_data["status"] == "succeeded":
                print("\n🎉 Fine-tuning complete!")
                print(f"Your model is ready: {job_data['fine_tuned_model']}")
                print("\nTo use your model:")
                print(f" curl -X POST {MODAL_BASE_URL}/v1/chat/completions \\")
                print(f" -H 'Authorization: Bearer YOUR_API_KEY' \\")
                print(f" -H 'Content-Type: application/json' \\")
                print(f" -d '{{")
                print(f' "model": "{job_data["fine_tuned_model"]}",')
                print(f' "messages": [{{"role": "user", "content": "Hello!"}}]')
                print(f" }}'")
        else:
            print(f"\nJob created: {job_id}")
            print(f"Check status at: {MODAL_BASE_URL}/v1/fine_tuning/jobs/{job_id}")

    except Exception as e:
        print(f"\n❌ Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s 0/1 result as the process exit status.
    raise SystemExit(asyncio.run(main()))