openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,14 +1,27 @@
1
+ """Train a VLM using TRL SFTTrainer + Unsloth.
2
+
3
+ This script provides the main training entry point for openadapt-ml.
4
+ It uses TRL's SFTTrainer with optional Unsloth optimizations for
5
+ efficient VLM fine-tuning.
6
+
7
+ Usage:
8
+ # Train on synthetic data
9
+ python -m openadapt_ml.scripts.train --config configs/qwen3vl_synthetic_som.yaml
10
+
11
+ # Train on capture recording
12
+ python -m openadapt_ml.scripts.train --config configs/qwen3vl_capture.yaml \
13
+ --capture /path/to/capture --goal "Task description" --open
14
+ """
15
+
1
16
  from __future__ import annotations
2
17
 
3
18
  from pathlib import Path
4
- from typing import List, Optional, Dict, Any
19
+ from typing import Dict, Any, Optional
5
20
 
6
21
  import yaml
7
22
 
8
- from openadapt_ml.datasets.next_action import NextActionDataset, build_next_action_sft_samples
9
- from openadapt_ml.ingest.synthetic import generate_synthetic_sessions
10
- from openadapt_ml.models.qwen_vl import QwenVLAdapter
11
- from openadapt_ml.training.trainer import TrainingConfig, TrainingLogger, train_supervised
23
+ from openadapt_ml.ingest.synthetic import generate_synthetic_episodes
24
+ from openadapt_ml.training.trl_trainer import TRLTrainingConfig, train_with_trl
12
25
 
13
26
 
14
27
  def _load_config(path: str | Path) -> dict:
@@ -31,22 +44,27 @@ def main(
31
44
  goal: str | None = None,
32
45
  output_dir: str | None = None,
33
46
  open_dashboard: bool = False,
47
+ use_unsloth: bool = True,
34
48
  ) -> None:
49
+ """Train a VLM using TRL SFTTrainer.
50
+
51
+ Args:
52
+ config_path: Path to YAML config file
53
+ capture_path: Optional path to openadapt-capture recording
54
+ goal: Task goal/description (overrides recording's task description)
55
+ output_dir: Output directory for logs and dashboard
56
+ open_dashboard: Open training dashboard in browser after training
57
+ use_unsloth: Enable Unsloth optimizations (default True)
58
+ """
35
59
  cfg = _load_config(config_path)
36
60
 
37
61
  model_name = cfg["model"]["name"]
38
62
  load_in_4bit = cfg["model"].get("load_in_4bit", False)
39
- max_pixels = cfg["model"].get("max_pixels") # For faster training with smaller images
40
- min_pixels = cfg["model"].get("min_pixels")
41
63
 
42
- # LoRA config may include an optional weights_path where the trained
43
- # adapter should be saved. We pass a cleaned config (without
44
- # weights_path) to the adapter loader.
64
+ # LoRA config
45
65
  raw_lora_cfg = cfg.get("lora")
46
- lora_weights_path: Optional[str] = None
47
66
  lora_cfg: Optional[Dict[str, Any]] = None
48
67
  if isinstance(raw_lora_cfg, dict):
49
- lora_weights_path = raw_lora_cfg.get("weights_path")
50
68
  lora_cfg = {k: v for k, v in raw_lora_cfg.items() if k != "weights_path"}
51
69
  else:
52
70
  lora_cfg = raw_lora_cfg
@@ -65,88 +83,67 @@ def main(
65
83
  num_sessions = synth_cfg.get("num_sessions", 10)
66
84
  seed = synth_cfg.get("seed")
67
85
  default_output_dir = str(Path("synthetic") / "train")
68
- output_dir = synth_cfg.get("output_dir", default_output_dir)
86
+ synth_output = synth_cfg.get("output_dir", default_output_dir)
69
87
  use_som = synth_cfg.get("use_som", False)
70
88
  scenario = synth_cfg.get("scenario", "login")
71
89
 
72
- sessions = generate_synthetic_sessions(
73
- num_sessions=num_sessions,
90
+ episodes = generate_synthetic_episodes(
91
+ num_episodes=num_sessions,
74
92
  seed=seed,
75
- output_dir=output_dir,
93
+ output_dir=synth_output,
76
94
  use_som=use_som,
77
95
  scenario=scenario,
78
96
  )
79
- episodes = [ep for sess in sessions for ep in sess.episodes]
80
97
  data_source = f"synthetic '{scenario}'"
81
98
 
82
- samples = build_next_action_sft_samples(episodes, use_som=use_som)
83
- dataset = NextActionDataset(samples)
84
-
85
- # Adapter + model
86
- adapter = QwenVLAdapter.from_pretrained(
87
- model_name=model_name,
88
- lora_config=lora_cfg,
89
- load_in_4bit=load_in_4bit,
90
- max_pixels=max_pixels,
91
- min_pixels=min_pixels,
92
- )
93
-
94
- # Training config
95
- train_cfg_raw = cfg.get("training", {})
96
99
  # Determine output directory
100
+ train_cfg_raw = cfg.get("training", {})
97
101
  if output_dir is None:
98
102
  output_dir = train_cfg_raw.get("output_dir", "training_output")
99
- train_cfg = TrainingConfig(
100
- num_train_epochs=train_cfg_raw.get("num_train_epochs", 1),
101
- per_device_train_batch_size=train_cfg_raw.get("per_device_train_batch_size", 1),
102
- gradient_accumulation_steps=train_cfg_raw.get("gradient_accumulation_steps", 1),
103
+
104
+ print(f"Using TRL trainer (Unsloth: {use_unsloth})")
105
+
106
+ # Build TRL config from YAML config
107
+ lora_dict = lora_cfg if isinstance(lora_cfg, dict) else {}
108
+ trl_config = TRLTrainingConfig(
109
+ model_name=model_name,
110
+ load_in_4bit=load_in_4bit,
111
+ max_seq_length=train_cfg_raw.get("max_seq_length", 4096),
112
+ lora_r=lora_dict.get("r", 16),
113
+ lora_alpha=lora_dict.get("lora_alpha", 32),
114
+ lora_dropout=lora_dict.get("lora_dropout", 0.0),
115
+ finetune_vision_layers=lora_dict.get("finetune_vision_layers", False),
116
+ num_epochs=train_cfg_raw.get("num_train_epochs", 3),
117
+ batch_size=train_cfg_raw.get("per_device_train_batch_size", 1),
118
+ gradient_accumulation_steps=train_cfg_raw.get("gradient_accumulation_steps", 4),
103
119
  learning_rate=train_cfg_raw.get("learning_rate", 2e-4),
104
120
  warmup_ratio=train_cfg_raw.get("warmup_ratio", 0.03),
105
- weight_decay=train_cfg_raw.get("weight_decay", 0.0),
106
- max_grad_norm=train_cfg_raw.get("max_grad_norm", 1.0),
107
- logging_steps=train_cfg_raw.get("logging_steps", 10),
108
- lr_scheduler_type=train_cfg_raw.get("lr_scheduler_type", "linear"),
109
- early_stop_loss=train_cfg_raw.get("early_stop_loss", 1e-4),
110
- early_stop_patience=train_cfg_raw.get("early_stop_patience", 10),
111
121
  output_dir=output_dir,
112
- # Evaluation settings
113
- eval_every_epoch=train_cfg_raw.get("eval_every_epoch", True),
114
- eval_samples=train_cfg_raw.get("eval_samples", 3),
122
+ logging_steps=train_cfg_raw.get("logging_steps", 10),
123
+ save_strategy=train_cfg_raw.get("save_strategy", "epoch"),
115
124
  )
116
125
 
117
- som_label = " (SoM mode)" if use_som else " (coordinate mode)"
118
- print(f"Loaded {len(episodes)} episodes and {len(samples)} SFT samples{som_label} from {data_source}.")
119
- print("Starting training...")
126
+ # Disable Unsloth if requested
127
+ if not use_unsloth:
128
+ import os
120
129
 
121
- # Get goal from episodes (for logging/viewer)
122
- episode_goal = episodes[0].goal if episodes else ""
123
-
124
- # Create logger with metadata for dashboard
125
- logger = TrainingLogger(
126
- output_dir=train_cfg.output_dir,
127
- config=train_cfg,
128
- capture_path=str(capture_path) if capture_path else "",
129
- config_path=str(config_path),
130
- goal=goal or episode_goal, # Use explicit goal or episode goal
131
- )
130
+ os.environ["OPENADAPT_DISABLE_UNSLOTH"] = "1"
132
131
 
133
- # Pass the first episode for periodic evaluation (if available)
134
- eval_episode = episodes[0] if episodes else None
135
- training_success = train_supervised(adapter, dataset, train_cfg, logger=logger, episode=eval_episode)
132
+ base_path = Path(capture_path).parent if capture_path else None
133
+ print(f"Training on {len(episodes)} episodes from {data_source}")
136
134
 
137
- # Persist the trained adapter if a weights_path was provided and training succeeded.
138
- if lora_weights_path:
139
- if training_success:
140
- save_path = Path(lora_weights_path)
141
- save_path.mkdir(parents=True, exist_ok=True)
142
- adapter.model.save_pretrained(save_path) # type: ignore[arg-type]
143
- print(f"Saved LoRA adapter to {save_path}")
144
- else:
145
- print("Training aborted due to invalid loss. Skipping checkpoint save to avoid corrupted weights.")
135
+ checkpoint_path = train_with_trl(
136
+ episodes=episodes,
137
+ config=trl_config,
138
+ use_som=use_som,
139
+ base_path=base_path,
140
+ )
141
+ print(f"Training complete. Checkpoint saved to: {checkpoint_path}")
146
142
 
147
143
  # Open dashboard in browser if requested
148
144
  if open_dashboard:
149
145
  import webbrowser
146
+
150
147
  dashboard_path = Path(output_dir) / "dashboard.html"
151
148
  if dashboard_path.exists():
152
149
  webbrowser.open(f"file://{dashboard_path.absolute()}")
@@ -158,17 +155,43 @@ if __name__ == "__main__":
158
155
  parser = argparse.ArgumentParser(
159
156
  description="Train Qwen-VL adapter on synthetic data or openadapt-capture recordings."
160
157
  )
161
- parser.add_argument("--config", type=str, required=True, help="Path to YAML config file.")
162
- parser.add_argument("--capture", type=str, help="Path to openadapt-capture recording directory.")
163
- parser.add_argument("--goal", type=str, help="Task goal/description (overrides recording's task description).")
164
- parser.add_argument("--output-dir", type=str, help="Output directory for logs and dashboard.")
165
- parser.add_argument("--open", action="store_true", help="Open training dashboard in browser.")
158
+ parser.add_argument(
159
+ "--config", type=str, required=True, help="Path to YAML config file."
160
+ )
161
+ parser.add_argument(
162
+ "--capture", type=str, help="Path to openadapt-capture recording directory."
163
+ )
164
+ parser.add_argument(
165
+ "--goal",
166
+ type=str,
167
+ help="Task goal/description (overrides recording's task description).",
168
+ )
169
+ parser.add_argument(
170
+ "--output-dir", type=str, help="Output directory for logs and dashboard."
171
+ )
172
+ parser.add_argument(
173
+ "--open", action="store_true", help="Open training dashboard in browser."
174
+ )
175
+
176
+ parser.add_argument(
177
+ "--use-unsloth",
178
+ action="store_true",
179
+ default=True,
180
+ help="Enable Unsloth optimizations (default).",
181
+ )
182
+ parser.add_argument(
183
+ "--no-unsloth", action="store_true", help="Disable Unsloth optimizations."
184
+ )
166
185
  args = parser.parse_args()
167
186
 
187
+ # Determine effective flags
188
+ use_unsloth = args.use_unsloth and not args.no_unsloth
189
+
168
190
  main(
169
191
  args.config,
170
192
  capture_path=args.capture,
171
193
  goal=args.goal,
172
194
  output_dir=args.output_dir,
173
195
  open_dashboard=args.open,
196
+ use_unsloth=use_unsloth,
174
197
  )