openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. openadapt_ml/benchmarks/__init__.py +8 -0
  2. openadapt_ml/benchmarks/agent.py +90 -11
  3. openadapt_ml/benchmarks/azure.py +35 -6
  4. openadapt_ml/benchmarks/cli.py +4449 -201
  5. openadapt_ml/benchmarks/live_tracker.py +180 -0
  6. openadapt_ml/benchmarks/runner.py +41 -4
  7. openadapt_ml/benchmarks/viewer.py +1219 -0
  8. openadapt_ml/benchmarks/vm_monitor.py +610 -0
  9. openadapt_ml/benchmarks/waa.py +61 -4
  10. openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
  11. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  12. openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
  13. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  14. openadapt_ml/benchmarks/waa_live.py +619 -0
  15. openadapt_ml/cloud/local.py +1555 -1
  16. openadapt_ml/cloud/ssh_tunnel.py +553 -0
  17. openadapt_ml/datasets/next_action.py +87 -68
  18. openadapt_ml/evals/grounding.py +26 -8
  19. openadapt_ml/evals/trajectory_matching.py +84 -36
  20. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  21. openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
  22. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  23. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  24. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  25. openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
  26. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  27. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  28. openadapt_ml/experiments/waa_demo/runner.py +717 -0
  29. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  30. openadapt_ml/export/__init__.py +9 -0
  31. openadapt_ml/export/__main__.py +6 -0
  32. openadapt_ml/export/cli.py +89 -0
  33. openadapt_ml/export/parquet.py +265 -0
  34. openadapt_ml/ingest/__init__.py +3 -4
  35. openadapt_ml/ingest/capture.py +89 -81
  36. openadapt_ml/ingest/loader.py +116 -68
  37. openadapt_ml/ingest/synthetic.py +221 -159
  38. openadapt_ml/retrieval/README.md +226 -0
  39. openadapt_ml/retrieval/USAGE.md +391 -0
  40. openadapt_ml/retrieval/__init__.py +91 -0
  41. openadapt_ml/retrieval/demo_retriever.py +817 -0
  42. openadapt_ml/retrieval/embeddings.py +629 -0
  43. openadapt_ml/retrieval/index.py +194 -0
  44. openadapt_ml/retrieval/retriever.py +160 -0
  45. openadapt_ml/runtime/policy.py +10 -10
  46. openadapt_ml/schema/__init__.py +104 -0
  47. openadapt_ml/schema/converters.py +541 -0
  48. openadapt_ml/schema/episode.py +457 -0
  49. openadapt_ml/scripts/compare.py +26 -16
  50. openadapt_ml/scripts/eval_policy.py +4 -5
  51. openadapt_ml/scripts/prepare_synthetic.py +14 -17
  52. openadapt_ml/scripts/train.py +81 -70
  53. openadapt_ml/training/benchmark_viewer.py +3225 -0
  54. openadapt_ml/training/trainer.py +120 -363
  55. openadapt_ml/training/trl_trainer.py +354 -0
  56. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
  57. openadapt_ml-0.2.0.dist-info/RECORD +86 -0
  58. openadapt_ml/schemas/__init__.py +0 -53
  59. openadapt_ml/schemas/sessions.py +0 -122
  60. openadapt_ml/schemas/validation.py +0 -252
  61. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  62. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
  63. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/scripts/train.py
@@ -1,14 +1,27 @@
+"""Train a VLM using TRL SFTTrainer + Unsloth.
+
+This script provides the main training entry point for openadapt-ml.
+It uses TRL's SFTTrainer with optional Unsloth optimizations for
+efficient VLM fine-tuning.
+
+Usage:
+    # Train on synthetic data
+    python -m openadapt_ml.scripts.train --config configs/qwen3vl_synthetic_som.yaml
+
+    # Train on capture recording
+    python -m openadapt_ml.scripts.train --config configs/qwen3vl_capture.yaml \
+        --capture /path/to/capture --goal "Task description" --open
+"""
+
 from __future__ import annotations
 
 from pathlib import Path
-from typing import List, Optional, Dict, Any
+from typing import Dict, Any, Optional
 
 import yaml
 
-from openadapt_ml.datasets.next_action import NextActionDataset, build_next_action_sft_samples
-from openadapt_ml.ingest.synthetic import generate_synthetic_sessions
-from openadapt_ml.models.qwen_vl import QwenVLAdapter
-from openadapt_ml.training.trainer import TrainingConfig, TrainingLogger, train_supervised
+from openadapt_ml.ingest.synthetic import generate_synthetic_episodes
+from openadapt_ml.training.trl_trainer import TRLTrainingConfig, train_with_trl
 
 
 def _load_config(path: str | Path) -> dict:
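For reference, main() below reads its settings from a YAML file via yaml.safe_load. A config in the shape those reads expect would deserialize to roughly the dictionary sketched here; only the keys are taken from the reads visible in this diff, while the "synthetic" section name and all values are illustrative assumptions:

    # Hypothetical result of _load_config("configs/qwen3vl_synthetic_som.yaml").
    # Keys mirror the cfg[...] reads in this diff; values are placeholders.
    cfg = {
        "model": {
            "name": "your-vlm-model-id",  # e.g. a Qwen VL checkpoint
            "load_in_4bit": True,
        },
        "lora": {
            "r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.0,
            "finetune_vision_layers": False,
        },
        "synthetic": {  # section name assumed from the synth_cfg variable
            "num_sessions": 10,
            "seed": 0,
            "use_som": True,
            "scenario": "login",
        },
        "training": {
            "max_seq_length": 4096,
            "num_train_epochs": 3,
            "per_device_train_batch_size": 1,
            "gradient_accumulation_steps": 4,
            "learning_rate": 2e-4,
            "warmup_ratio": 0.03,
            "logging_steps": 10,
            "save_strategy": "epoch",
            "output_dir": "training_output",
        },
    }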
@@ -31,22 +44,27 @@ def main(
     goal: str | None = None,
     output_dir: str | None = None,
     open_dashboard: bool = False,
+    use_unsloth: bool = True,
 ) -> None:
+    """Train a VLM using TRL SFTTrainer.
+
+    Args:
+        config_path: Path to YAML config file
+        capture_path: Optional path to openadapt-capture recording
+        goal: Task goal/description (overrides recording's task description)
+        output_dir: Output directory for logs and dashboard
+        open_dashboard: Open training dashboard in browser after training
+        use_unsloth: Enable Unsloth optimizations (default True)
+    """
     cfg = _load_config(config_path)
 
     model_name = cfg["model"]["name"]
     load_in_4bit = cfg["model"].get("load_in_4bit", False)
-    max_pixels = cfg["model"].get("max_pixels")  # For faster training with smaller images
-    min_pixels = cfg["model"].get("min_pixels")
 
-    # LoRA config may include an optional weights_path where the trained
-    # adapter should be saved. We pass a cleaned config (without
-    # weights_path) to the adapter loader.
+    # LoRA config
     raw_lora_cfg = cfg.get("lora")
-    lora_weights_path: Optional[str] = None
     lora_cfg: Optional[Dict[str, Any]] = None
     if isinstance(raw_lora_cfg, dict):
-        lora_weights_path = raw_lora_cfg.get("weights_path")
         lora_cfg = {k: v for k, v in raw_lora_cfg.items() if k != "weights_path"}
     else:
        lora_cfg = raw_lora_cfg
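The comprehension above strips only weights_path and keeps every other LoRA option intact; a quick illustration with made-up values:

    raw_lora_cfg = {"r": 16, "lora_alpha": 32, "weights_path": "adapters/run1"}
    lora_cfg = {k: v for k, v in raw_lora_cfg.items() if k != "weights_path"}
    assert lora_cfg == {"r": 16, "lora_alpha": 32}

Note that 0.2.0 still tolerates a weights_path key in configs but no longer acts on it: the 0.1.0 save-adapter-to-weights_path block is deleted in the next hunk, and checkpointing moves into train_with_trl (see save_strategy there).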
@@ -65,84 +83,61 @@ def main(
         num_sessions = synth_cfg.get("num_sessions", 10)
         seed = synth_cfg.get("seed")
         default_output_dir = str(Path("synthetic") / "train")
-        output_dir = synth_cfg.get("output_dir", default_output_dir)
+        synth_output = synth_cfg.get("output_dir", default_output_dir)
         use_som = synth_cfg.get("use_som", False)
         scenario = synth_cfg.get("scenario", "login")
 
-        sessions = generate_synthetic_sessions(
-            num_sessions=num_sessions,
+        episodes = generate_synthetic_episodes(
+            num_episodes=num_sessions,
             seed=seed,
-            output_dir=output_dir,
+            output_dir=synth_output,
             use_som=use_som,
             scenario=scenario,
         )
-        episodes = [ep for sess in sessions for ep in sess.episodes]
         data_source = f"synthetic '{scenario}'"
 
-    samples = build_next_action_sft_samples(episodes, use_som=use_som)
-    dataset = NextActionDataset(samples)
-
-    # Adapter + model
-    adapter = QwenVLAdapter.from_pretrained(
-        model_name=model_name,
-        lora_config=lora_cfg,
-        load_in_4bit=load_in_4bit,
-        max_pixels=max_pixels,
-        min_pixels=min_pixels,
-    )
-
-    # Training config
-    train_cfg_raw = cfg.get("training", {})
     # Determine output directory
+    train_cfg_raw = cfg.get("training", {})
     if output_dir is None:
         output_dir = train_cfg_raw.get("output_dir", "training_output")
-    train_cfg = TrainingConfig(
-        num_train_epochs=train_cfg_raw.get("num_train_epochs", 1),
-        per_device_train_batch_size=train_cfg_raw.get("per_device_train_batch_size", 1),
-        gradient_accumulation_steps=train_cfg_raw.get("gradient_accumulation_steps", 1),
+
+    print(f"Using TRL trainer (Unsloth: {use_unsloth})")
+
+    # Build TRL config from YAML config
+    lora_dict = lora_cfg if isinstance(lora_cfg, dict) else {}
+    trl_config = TRLTrainingConfig(
+        model_name=model_name,
+        load_in_4bit=load_in_4bit,
+        max_seq_length=train_cfg_raw.get("max_seq_length", 4096),
+        lora_r=lora_dict.get("r", 16),
+        lora_alpha=lora_dict.get("lora_alpha", 32),
+        lora_dropout=lora_dict.get("lora_dropout", 0.0),
+        finetune_vision_layers=lora_dict.get("finetune_vision_layers", False),
+        num_epochs=train_cfg_raw.get("num_train_epochs", 3),
+        batch_size=train_cfg_raw.get("per_device_train_batch_size", 1),
+        gradient_accumulation_steps=train_cfg_raw.get("gradient_accumulation_steps", 4),
         learning_rate=train_cfg_raw.get("learning_rate", 2e-4),
         warmup_ratio=train_cfg_raw.get("warmup_ratio", 0.03),
-        weight_decay=train_cfg_raw.get("weight_decay", 0.0),
-        max_grad_norm=train_cfg_raw.get("max_grad_norm", 1.0),
-        logging_steps=train_cfg_raw.get("logging_steps", 10),
-        lr_scheduler_type=train_cfg_raw.get("lr_scheduler_type", "linear"),
-        early_stop_loss=train_cfg_raw.get("early_stop_loss", 1e-4),
-        early_stop_patience=train_cfg_raw.get("early_stop_patience", 10),
         output_dir=output_dir,
-        # Evaluation settings
-        eval_every_epoch=train_cfg_raw.get("eval_every_epoch", True),
-        eval_samples=train_cfg_raw.get("eval_samples", 3),
+        logging_steps=train_cfg_raw.get("logging_steps", 10),
+        save_strategy=train_cfg_raw.get("save_strategy", "epoch"),
     )
 
-    som_label = " (SoM mode)" if use_som else " (coordinate mode)"
-    print(f"Loaded {len(episodes)} episodes and {len(samples)} SFT samples{som_label} from {data_source}.")
-    print("Starting training...")
+    # Disable Unsloth if requested
+    if not use_unsloth:
+        import os
+        os.environ["OPENADAPT_DISABLE_UNSLOTH"] = "1"
 
-    # Get goal from episodes (for logging/viewer)
-    episode_goal = episodes[0].goal if episodes else ""
+    base_path = Path(capture_path).parent if capture_path else None
+    print(f"Training on {len(episodes)} episodes from {data_source}")
 
-    # Create logger with metadata for dashboard
-    logger = TrainingLogger(
-        output_dir=train_cfg.output_dir,
-        config=train_cfg,
-        capture_path=str(capture_path) if capture_path else "",
-        config_path=str(config_path),
-        goal=goal or episode_goal,  # Use explicit goal or episode goal
+    checkpoint_path = train_with_trl(
+        episodes=episodes,
+        config=trl_config,
+        use_som=use_som,
+        base_path=base_path,
     )
-
-    # Pass the first episode for periodic evaluation (if available)
-    eval_episode = episodes[0] if episodes else None
-    training_success = train_supervised(adapter, dataset, train_cfg, logger=logger, episode=eval_episode)
-
-    # Persist the trained adapter if a weights_path was provided and training succeeded.
-    if lora_weights_path:
-        if training_success:
-            save_path = Path(lora_weights_path)
-            save_path.mkdir(parents=True, exist_ok=True)
-            adapter.model.save_pretrained(save_path)  # type: ignore[arg-type]
-            print(f"Saved LoRA adapter to {save_path}")
-        else:
-            print("Training aborted due to invalid loss. Skipping checkpoint save to avoid corrupted weights.")
+    print(f"Training complete. Checkpoint saved to: {checkpoint_path}")
 
     # Open dashboard in browser if requested
     if open_dashboard:
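Taken together, this hunk replaces the 0.1.0 pipeline (NextActionDataset, QwenVLAdapter, train_supervised, TrainingLogger) with a two-call TRL path. A minimal programmatic sketch of that path, assuming only the call signatures visible in this diff and that the remaining TRLTrainingConfig fields have defaults (values are illustrative):

    from openadapt_ml.ingest.synthetic import generate_synthetic_episodes
    from openadapt_ml.training.trl_trainer import TRLTrainingConfig, train_with_trl

    # Same kwargs main() passes for the synthetic data source.
    episodes = generate_synthetic_episodes(
        num_episodes=5,
        seed=0,
        output_dir="synthetic/train",
        use_som=True,
        scenario="login",
    )

    config = TRLTrainingConfig(
        model_name="your-vlm-model-id",  # illustrative
        load_in_4bit=True,
        num_epochs=1,
        output_dir="training_output",
    )

    # train_with_trl returns the checkpoint path, which main() prints.
    checkpoint_path = train_with_trl(
        episodes=episodes,
        config=config,
        use_som=True,
        base_path=None,  # main() passes the capture's parent dir when set
    )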
@@ -163,12 +158,28 @@ if __name__ == "__main__":
     parser.add_argument("--goal", type=str, help="Task goal/description (overrides recording's task description).")
     parser.add_argument("--output-dir", type=str, help="Output directory for logs and dashboard.")
     parser.add_argument("--open", action="store_true", help="Open training dashboard in browser.")
+
+    parser.add_argument(
+        "--use-unsloth",
+        action="store_true",
+        default=True,
+        help="Enable Unsloth optimizations (default)."
+    )
+    parser.add_argument(
+        "--no-unsloth",
+        action="store_true",
+        help="Disable Unsloth optimizations."
+    )
     args = parser.parse_args()
 
+    # Determine effective flags
+    use_unsloth = args.use_unsloth and not args.no_unsloth
+
     main(
         args.config,
         capture_path=args.capture,
         goal=args.goal,
         output_dir=args.output_dir,
         open_dashboard=args.open,
+        use_unsloth=use_unsloth,
     )
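A note on the new flag pair: because --use-unsloth is store_true with default=True, passing it never changes anything, and only --no-unsloth affects the computed use_unsloth value. On Python 3.9+, argparse.BooleanOptionalAction declares the same enable/disable pair once; a sketch of that alternative (not what 0.2.0 ships, and note the generated negative flag is --no-use-unsloth rather than --no-unsloth):

    import argparse

    parser = argparse.ArgumentParser()
    # One declaration yields both --use-unsloth and --no-use-unsloth.
    parser.add_argument(
        "--use-unsloth",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enable Unsloth optimizations (default: enabled).",
    )

    args = parser.parse_args(["--no-use-unsloth"])
    assert args.use_unsloth is False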