openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- openadapt_ml/benchmarks/__init__.py +8 -0
- openadapt_ml/benchmarks/agent.py +90 -11
- openadapt_ml/benchmarks/azure.py +35 -6
- openadapt_ml/benchmarks/cli.py +4449 -201
- openadapt_ml/benchmarks/live_tracker.py +180 -0
- openadapt_ml/benchmarks/runner.py +41 -4
- openadapt_ml/benchmarks/viewer.py +1219 -0
- openadapt_ml/benchmarks/vm_monitor.py +610 -0
- openadapt_ml/benchmarks/waa.py +61 -4
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/benchmarks/waa_live.py +619 -0
- openadapt_ml/cloud/local.py +1555 -1
- openadapt_ml/cloud/ssh_tunnel.py +553 -0
- openadapt_ml/datasets/next_action.py +87 -68
- openadapt_ml/evals/grounding.py +26 -8
- openadapt_ml/evals/trajectory_matching.py +84 -36
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +717 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +265 -0
- openadapt_ml/ingest/__init__.py +3 -4
- openadapt_ml/ingest/capture.py +89 -81
- openadapt_ml/ingest/loader.py +116 -68
- openadapt_ml/ingest/synthetic.py +221 -159
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +817 -0
- openadapt_ml/retrieval/embeddings.py +629 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +160 -0
- openadapt_ml/runtime/policy.py +10 -10
- openadapt_ml/schema/__init__.py +104 -0
- openadapt_ml/schema/converters.py +541 -0
- openadapt_ml/schema/episode.py +457 -0
- openadapt_ml/scripts/compare.py +26 -16
- openadapt_ml/scripts/eval_policy.py +4 -5
- openadapt_ml/scripts/prepare_synthetic.py +14 -17
- openadapt_ml/scripts/train.py +81 -70
- openadapt_ml/training/benchmark_viewer.py +3225 -0
- openadapt_ml/training/trainer.py +120 -363
- openadapt_ml/training/trl_trainer.py +354 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
- openadapt_ml-0.2.0.dist-info/RECORD +86 -0
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/scripts/train.py
CHANGED
@@ -1,14 +1,27 @@
+"""Train a VLM using TRL SFTTrainer + Unsloth.
+
+This script provides the main training entry point for openadapt-ml.
+It uses TRL's SFTTrainer with optional Unsloth optimizations for
+efficient VLM fine-tuning.
+
+Usage:
+    # Train on synthetic data
+    python -m openadapt_ml.scripts.train --config configs/qwen3vl_synthetic_som.yaml
+
+    # Train on capture recording
+    python -m openadapt_ml.scripts.train --config configs/qwen3vl_capture.yaml \
+        --capture /path/to/capture --goal "Task description" --open
+"""
+
 from __future__ import annotations
 
 from pathlib import Path
-from typing import
+from typing import Dict, Any, Optional
 
 import yaml
 
-from openadapt_ml.
-from openadapt_ml.
-from openadapt_ml.models.qwen_vl import QwenVLAdapter
-from openadapt_ml.training.trainer import TrainingConfig, TrainingLogger, train_supervised
+from openadapt_ml.ingest.synthetic import generate_synthetic_episodes
+from openadapt_ml.training.trl_trainer import TRLTrainingConfig, train_with_trl
 
 
 def _load_config(path: str | Path) -> dict:
@@ -31,22 +44,27 @@ def main(
     goal: str | None = None,
     output_dir: str | None = None,
     open_dashboard: bool = False,
+    use_unsloth: bool = True,
 ) -> None:
+    """Train a VLM using TRL SFTTrainer.
+
+    Args:
+        config_path: Path to YAML config file
+        capture_path: Optional path to openadapt-capture recording
+        goal: Task goal/description (overrides recording's task description)
+        output_dir: Output directory for logs and dashboard
+        open_dashboard: Open training dashboard in browser after training
+        use_unsloth: Enable Unsloth optimizations (default True)
+    """
     cfg = _load_config(config_path)
 
     model_name = cfg["model"]["name"]
     load_in_4bit = cfg["model"].get("load_in_4bit", False)
-    max_pixels = cfg["model"].get("max_pixels")  # For faster training with smaller images
-    min_pixels = cfg["model"].get("min_pixels")
 
-    # LoRA config
-    # adapter should be saved. We pass a cleaned config (without
-    # weights_path) to the adapter loader.
+    # LoRA config
     raw_lora_cfg = cfg.get("lora")
-    lora_weights_path: Optional[str] = None
     lora_cfg: Optional[Dict[str, Any]] = None
     if isinstance(raw_lora_cfg, dict):
-        lora_weights_path = raw_lora_cfg.get("weights_path")
         lora_cfg = {k: v for k, v in raw_lora_cfg.items() if k != "weights_path"}
     else:
         lora_cfg = raw_lora_cfg
@@ -65,84 +83,61 @@ def main(
         num_sessions = synth_cfg.get("num_sessions", 10)
         seed = synth_cfg.get("seed")
         default_output_dir = str(Path("synthetic") / "train")
-
+        synth_output = synth_cfg.get("output_dir", default_output_dir)
         use_som = synth_cfg.get("use_som", False)
         scenario = synth_cfg.get("scenario", "login")
 
-
-
+        episodes = generate_synthetic_episodes(
+            num_episodes=num_sessions,
             seed=seed,
-            output_dir=
+            output_dir=synth_output,
             use_som=use_som,
             scenario=scenario,
         )
-        episodes = [ep for sess in sessions for ep in sess.episodes]
         data_source = f"synthetic '{scenario}'"
 
-    samples = build_next_action_sft_samples(episodes, use_som=use_som)
-    dataset = NextActionDataset(samples)
-
-    # Adapter + model
-    adapter = QwenVLAdapter.from_pretrained(
-        model_name=model_name,
-        lora_config=lora_cfg,
-        load_in_4bit=load_in_4bit,
-        max_pixels=max_pixels,
-        min_pixels=min_pixels,
-    )
-
-    # Training config
-    train_cfg_raw = cfg.get("training", {})
     # Determine output directory
+    train_cfg_raw = cfg.get("training", {})
     if output_dir is None:
         output_dir = train_cfg_raw.get("output_dir", "training_output")
-
-
-
-
+
+    print(f"Using TRL trainer (Unsloth: {use_unsloth})")
+
+    # Build TRL config from YAML config
+    lora_dict = lora_cfg if isinstance(lora_cfg, dict) else {}
+    trl_config = TRLTrainingConfig(
+        model_name=model_name,
+        load_in_4bit=load_in_4bit,
+        max_seq_length=train_cfg_raw.get("max_seq_length", 4096),
+        lora_r=lora_dict.get("r", 16),
+        lora_alpha=lora_dict.get("lora_alpha", 32),
+        lora_dropout=lora_dict.get("lora_dropout", 0.0),
+        finetune_vision_layers=lora_dict.get("finetune_vision_layers", False),
+        num_epochs=train_cfg_raw.get("num_train_epochs", 3),
+        batch_size=train_cfg_raw.get("per_device_train_batch_size", 1),
+        gradient_accumulation_steps=train_cfg_raw.get("gradient_accumulation_steps", 4),
         learning_rate=train_cfg_raw.get("learning_rate", 2e-4),
         warmup_ratio=train_cfg_raw.get("warmup_ratio", 0.03),
-        weight_decay=train_cfg_raw.get("weight_decay", 0.0),
-        max_grad_norm=train_cfg_raw.get("max_grad_norm", 1.0),
-        logging_steps=train_cfg_raw.get("logging_steps", 10),
-        lr_scheduler_type=train_cfg_raw.get("lr_scheduler_type", "linear"),
-        early_stop_loss=train_cfg_raw.get("early_stop_loss", 1e-4),
-        early_stop_patience=train_cfg_raw.get("early_stop_patience", 10),
         output_dir=output_dir,
-
-
-        eval_samples=train_cfg_raw.get("eval_samples", 3),
+        logging_steps=train_cfg_raw.get("logging_steps", 10),
+        save_strategy=train_cfg_raw.get("save_strategy", "epoch"),
     )
 
-
-
-
+    # Disable Unsloth if requested
+    if not use_unsloth:
+        import os
+        os.environ["OPENADAPT_DISABLE_UNSLOTH"] = "1"
 
-
-
+    base_path = Path(capture_path).parent if capture_path else None
+    print(f"Training on {len(episodes)} episodes from {data_source}")
 
-
-
-
-
-
-        config_path=str(config_path),
-        goal=goal or episode_goal,  # Use explicit goal or episode goal
+    checkpoint_path = train_with_trl(
+        episodes=episodes,
+        config=trl_config,
+        use_som=use_som,
+        base_path=base_path,
    )
-
-    # Pass the first episode for periodic evaluation (if available)
-    eval_episode = episodes[0] if episodes else None
-    training_success = train_supervised(adapter, dataset, train_cfg, logger=logger, episode=eval_episode)
-
-    # Persist the trained adapter if a weights_path was provided and training succeeded.
-    if lora_weights_path:
-        if training_success:
-            save_path = Path(lora_weights_path)
-            save_path.mkdir(parents=True, exist_ok=True)
-            adapter.model.save_pretrained(save_path)  # type: ignore[arg-type]
-            print(f"Saved LoRA adapter to {save_path}")
-        else:
-            print("Training aborted due to invalid loss. Skipping checkpoint save to avoid corrupted weights.")
+    print(f"Training complete. Checkpoint saved to: {checkpoint_path}")
 
     # Open dashboard in browser if requested
     if open_dashboard:
@@ -163,12 +158,28 @@ if __name__ == "__main__":
     parser.add_argument("--goal", type=str, help="Task goal/description (overrides recording's task description).")
     parser.add_argument("--output-dir", type=str, help="Output directory for logs and dashboard.")
     parser.add_argument("--open", action="store_true", help="Open training dashboard in browser.")
+
+    parser.add_argument(
+        "--use-unsloth",
+        action="store_true",
+        default=True,
+        help="Enable Unsloth optimizations (default)."
+    )
+    parser.add_argument(
+        "--no-unsloth",
+        action="store_true",
+        help="Disable Unsloth optimizations."
+    )
    args = parser.parse_args()
 
+    # Determine effective flags
+    use_unsloth = args.use_unsloth and not args.no_unsloth
+
     main(
         args.config,
         capture_path=args.capture,
         goal=args.goal,
         output_dir=args.output_dir,
         open_dashboard=args.open,
+        use_unsloth=use_unsloth,
     )