openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""Shared UI components for dashboards and viewers.
|
|
2
|
+
|
|
3
|
+
This module contains CSS and HTML generation functions used by both
|
|
4
|
+
the Training Dashboard and the Viewer for visual consistency.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_shared_header_css() -> str:
    """Return the stylesheet for the unified dashboard header.

    The Training Dashboard and the Viewer both embed this CSS, so editing it
    restyles every page that renders the shared header. The rules reference
    CSS custom properties (``--text-primary``, ``--text-secondary``,
    ``--text-muted``, ``--bg-primary``, ``--accent``) that are not defined
    here, so the embedding page has to supply them.
    """
    # NOTE(review): the hover/focus tints hard-code rgba(0,212,170,...),
    # presumably the RGB of --accent; confirm if the accent colour changes.
    header_css = """
    .unified-header {
        display: flex;
        align-items: center;
        justify-content: space-between;
        padding: 12px 24px;
        background: linear-gradient(180deg, rgba(18,18,26,0.98) 0%, rgba(26,26,36,0.98) 100%);
        border-bottom: 1px solid rgba(255,255,255,0.08);
        margin-bottom: 20px;
        gap: 16px;
        flex-wrap: wrap;
        box-shadow: 0 2px 8px rgba(0,0,0,0.3);
    }
    .unified-header .nav-tabs {
        display: flex;
        align-items: center;
        gap: 4px;
        background: rgba(0,0,0,0.3);
        padding: 4px;
        border-radius: 8px;
    }
    .unified-header .nav-tab {
        padding: 8px 16px;
        border-radius: 6px;
        font-size: 0.85rem;
        font-weight: 500;
        text-decoration: none;
        color: var(--text-secondary);
        background: transparent;
        border: none;
        transition: all 0.2s;
        cursor: pointer;
    }
    .unified-header .nav-tab:hover {
        color: var(--text-primary);
        background: rgba(255,255,255,0.05);
    }
    .unified-header .nav-tab.active {
        color: var(--bg-primary);
        background: var(--accent);
        font-weight: 600;
    }
    .unified-header .controls-section {
        display: flex;
        align-items: center;
        gap: 24px;
        flex-wrap: wrap;
    }
    .unified-header .control-group {
        display: flex;
        align-items: center;
        gap: 10px;
    }
    .unified-header .control-label {
        font-size: 0.7rem;
        color: var(--text-muted);
        font-weight: 600;
        letter-spacing: 0.5px;
        text-transform: uppercase;
    }
    .unified-header select {
        padding: 8px 32px 8px 12px;
        border-radius: 8px;
        font-size: 0.85rem;
        background: rgba(0,0,0,0.4);
        color: var(--text-primary);
        border: 1px solid rgba(255,255,255,0.1);
        cursor: pointer;
        appearance: none;
        background-image: url('data:image/svg+xml,%3Csvg xmlns=%27http://www.w3.org/2000/svg%27 width=%2712%27 height=%278%27%3E%3Cpath fill=%27%23888%27 d=%27M0 0l6 8 6-8z%27/%3E%3C/svg%3E');
        background-repeat: no-repeat;
        background-position: right 10px center;
        transition: all 0.2s;
    }
    .unified-header select:hover {
        border-color: var(--accent);
        background-color: rgba(0,212,170,0.1);
    }
    .unified-header select:focus {
        outline: none;
        border-color: var(--accent);
        box-shadow: 0 0 0 2px rgba(0,212,170,0.2);
    }
    .unified-header .header-meta {
        font-size: 0.75rem;
        color: var(--text-muted);
        font-family: "SF Mono", Monaco, monospace;
    }
    """
    return header_css
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def generate_shared_header_html(
    active_page: str,
    controls_html: str = "",
    meta_html: str = "",
) -> str:
    """Build the HTML for the unified navigation header.

    Args:
        active_page: Key of the tab to highlight: "training", "viewer" or
            "benchmarks". Any other value highlights no tab.
        controls_html: Optional raw HTML (dropdowns, etc.) placed in the
            controls section. It is inserted verbatim, not escaped.
        meta_html: Optional raw HTML for metadata (job ID, capture ID, etc.),
            wrapped in a ``header-meta`` span. Inserted verbatim.

    Returns:
        The header markup as a string. The controls section is emitted only
        when at least one of ``controls_html`` / ``meta_html`` is non-empty.
    """
    # (page key, target file, visible label), in display order.
    tab_table = (
        ("training", "dashboard.html", "Training"),
        ("viewer", "viewer.html", "Viewer"),
        ("benchmarks", "benchmark.html", "Benchmarks"),
    )

    tab_lines = []
    for page_key, href, label in tab_table:
        # Inactive tabs keep the trailing space after "nav-tab", exactly as
        # the markup has always been rendered.
        state = "active" if page_key == active_page else ""
        tab_lines.append(f'<a href="{href}" class="nav-tab {state}">{label}</a>')
    tabs_markup = "\n        ".join(tab_lines)

    if not (controls_html or meta_html):
        controls_section = ""
    else:
        meta_span = f'<span class="header-meta">{meta_html}</span>' if meta_html else ""
        controls_section = (
            "\n"
            '    <div class="controls-section">\n'
            f"        {controls_html}\n"
            f"        {meta_span}\n"
            "    </div>\n"
        )

    return (
        "\n"
        '<div class="unified-header">\n'
        '    <div class="nav-tabs">\n'
        f"        {tabs_markup}\n"
        "    </div>\n"
        f"    {controls_section}\n"
        "</div>\n"
    )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def build_nav_links() -> list[tuple[str, str]]:
    """Return the navigation entries for multi-capture dashboards.

    Returns:
        A freshly built list of ``(filename, label)`` pairs in display order,
        so callers may mutate it without affecting later calls.
    """
    page_files = ("dashboard.html", "viewer.html", "benchmark.html")
    page_labels = ("Training", "Viewer", "Benchmarks")
    return list(zip(page_files, page_labels))
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
"""Stub training provider for rapid UI testing without actual training."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import random
|
|
5
|
+
import sys
|
|
6
|
+
import time
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class StubTrainingProvider:
    """Simulates training without actual computation.

    Use this to test dashboard, viewer, stop button, etc. without
    waiting for real training on GPU or Lambda.

    Each simulated step appends a fake loss, rewrites
    ``<output_dir>/training_log.json`` and logs a progress line to
    ``<output_dir>/training.log`` (and stdout). A run ends when all epochs
    are simulated, when a ``STOP_TRAINING`` file appears in ``output_dir``,
    or (if ``early_stop_loss > 0``) when the loss stays below
    ``early_stop_loss`` for ``early_stop_patience`` consecutive steps.
    """

    # Fallback capture directory, relative to the user's home directory,
    # whose first screenshot is copied in as the evaluation sample. This is a
    # developer-local path; pass ``sample_screenshots_dir`` to override it.
    _DEFAULT_SAMPLE_SCREENSHOTS_DIR = Path(
        "oa/src/openadapt-capture/turn-off-nightshift/screenshots"
    )

    def __init__(
        self,
        output_dir: Path,
        epochs: int = 5,
        steps_per_epoch: int = 10,
        step_delay: float = 0.5,
        early_stop_loss: float = 0.0,
        early_stop_patience: int = 3,
        sample_screenshots_dir: "Path | None" = None,
    ):
        """Initialize stub provider.

        Args:
            output_dir: Directory to write training_log.json and training.log
                (created if it does not exist).
            epochs: Number of epochs to simulate.
            steps_per_epoch: Steps per epoch; must be >= 1.
            step_delay: Delay between steps in seconds (for realistic feel).
            early_stop_loss: Stop if loss drops below this threshold
                (disabled when <= 0).
            early_stop_patience: Number of consecutive steps below threshold
                before stopping.
            sample_screenshots_dir: Directory whose first ``*.png`` (by file
                name) is copied to ``<output_dir>/screenshots/sample.png`` for
                the fake evaluations. Defaults to a developer-local capture
                under the home directory. If the directory does not exist,
                nothing is copied.

        Raises:
            ValueError: If ``steps_per_epoch`` is less than 1. The simulation
                uses it as a divisor and modulus, so 0 would crash mid-run.
        """
        if steps_per_epoch < 1:
            raise ValueError(f"steps_per_epoch must be >= 1, got {steps_per_epoch}")

        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.step_delay = step_delay
        self.early_stop_loss = early_stop_loss
        self.early_stop_patience = early_stop_patience
        self.sample_screenshots_dir = (
            Path(sample_screenshots_dir) if sample_screenshots_dir is not None else None
        )

        self.current_epoch = 0
        self.current_step = 0
        self.losses = []
        self.evaluations = []
        self.start_time = time.time()
        self.job_id = f"stub_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.consecutive_low_loss = 0
        self.termination_status = None
        self.termination_message = None

        # Set up logging to file (the handle is opened lazily by _log).
        self.log_file = self.output_dir / "training.log"
        self.log_handle = None

    def _log(self, message: str, to_stdout: bool = True):
        """Write message to both log file and stdout.

        Args:
            message: Message to log
            to_stdout: If True, also print to stdout (default: True)
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_line = f"[{timestamp}] {message}"

        # Opened on first use in "w" mode, so a log from a previous run in the
        # same directory is truncated. Explicit UTF-8 keeps the file
        # independent of the platform's locale encoding.
        if self.log_handle is None:
            self.log_handle = open(self.log_file, "w", encoding="utf-8", buffering=1)  # Line buffered

        self.log_handle.write(log_line + "\n")
        self.log_handle.flush()

        if to_stdout:
            print(message)

    def _close_log(self) -> None:
        """Close the training.log handle if it is open."""
        if self.log_handle is not None:
            self.log_handle.close()
            self.log_handle = None

    def simulate_step(self) -> dict:
        """Simulate one training step.

        Appends a noisy, linearly decreasing loss (about 2.5 down to about
        0.5, floored at 0.1). Every ``steps_per_epoch`` steps it also records
        a fake evaluation and advances the epoch counter.

        Returns:
            Current training status dict
        """
        # Guard the divisor so a direct call with epochs <= 0 cannot divide by zero.
        total_steps = max(1, self.epochs * self.steps_per_epoch)
        progress = self.current_step / total_steps
        base_loss = 2.5 * (1 - progress * 0.8)  # Decrease from 2.5 to ~0.5
        noise = random.uniform(-0.15, 0.15)
        loss = max(0.1, base_loss + noise)

        elapsed = time.time() - self.start_time

        self.losses.append({
            "epoch": self.current_epoch,
            "step": self.current_step + 1,
            "loss": loss,
            "lr": 5e-5,
            "time": elapsed,
        })

        self.current_step += 1

        # Check for epoch completion
        if self.current_step % self.steps_per_epoch == 0:
            self._generate_epoch_evaluation()
            self.current_epoch += 1
            # Cap at max epochs for display
            if self.current_epoch > self.epochs:
                self.current_epoch = self.epochs

        return self.get_status()

    def _copy_sample_screenshot(self, sample_path: Path) -> None:
        """Copy the first PNG (by name) from the sample directory, if any."""
        import shutil

        source_dir = self.sample_screenshots_dir
        if source_dir is None:
            source_dir = Path.home() / self._DEFAULT_SAMPLE_SCREENSHOTS_DIR
        if not source_dir.exists():
            return

        # glob() order is unspecified, so sort for a deterministic choice.
        candidates = sorted(source_dir.glob("*.png"))
        if not candidates:
            return
        sample_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(candidates[0], sample_path)

    def _generate_epoch_evaluation(self):
        """Generate fake evaluation for completed epoch."""
        # Improve accuracy as training progresses
        progress = self.current_epoch / self.epochs if self.epochs > 0 else 1.0
        accuracy_boost = progress * 0.3  # Up to 30% improvement

        # Use real screenshot if available, otherwise placeholder
        sample_path = self.output_dir / "screenshots" / "sample.png"
        if not sample_path.exists():
            self._copy_sample_screenshot(sample_path)

        self.evaluations.append({
            "epoch": self.current_epoch,
            "sample_idx": 7,  # Match the real training sample
            "image_path": "screenshots/sample.png",
            "human_action": {
                "type": "click",
                "x": 0.65,
                "y": 0.65,
                "text": None,
            },
            "predicted_action": {
                "type": "click",
                "x": 0.65 + random.uniform(-0.15, 0.15) * (1 - accuracy_boost),
                "y": 0.65 + random.uniform(-0.15, 0.15) * (1 - accuracy_boost),
                "raw_output": f"Thought: [Stub] Epoch {self.current_epoch} - analyzing screenshot to find target element. The model is learning to identify UI components.\nAction: CLICK(x=0.65, y=0.65)",
            },
            "distance": random.uniform(0.05, 0.2) * (1 - accuracy_boost),
            "correct": random.random() > (0.5 - accuracy_boost),
        })

    def get_status(self) -> dict:
        """Return current training status.

        Returns:
            Status dict compatible with training_log.json format.
            ``started_at`` is an ISO-8601 UTC timestamp with a ``Z`` suffix.
        """
        from datetime import timezone

        current_loss = self.losses[-1]["loss"] if self.losses else 0
        elapsed = time.time() - self.start_time

        # Determine status
        if self.termination_status:
            status = "completed" if self.termination_status == "auto_complete" else self.termination_status
        elif self.is_complete():
            status = "completed"
        else:
            status = "training"

        # Convert in UTC explicitly. A naive fromtimestamp() yields local
        # time, which must not be labelled with the UTC "Z" designator.
        started_at = (
            datetime.fromtimestamp(self.start_time, tz=timezone.utc)
            .isoformat()
            .replace("+00:00", "Z")
        )

        return {
            "job_id": self.job_id,
            "hostname": "stub-local",
            "capture_path": "/stub/capture",
            "config_path": "configs/stub.yaml",
            "instance_type": "stub",
            "instance_ip": "127.0.0.1",
            "started_at": started_at,
            "cloud_provider": "stub",
            "cloud_dashboard_url": "",
            "cloud_instance_id": "stub",
            "setup_status": "training",
            "setup_logs": ["[Stub] Simulated training in progress..."],
            "epoch": self.current_epoch,
            "step": self.current_step,
            "total_steps": self.epochs * self.steps_per_epoch,
            "total_epochs": self.epochs,
            "loss": current_loss,
            "learning_rate": 5e-5,
            "samples_seen": self.current_step,
            "elapsed_time": elapsed,
            "losses": self.losses,
            "evaluations": self.evaluations,
            "status": status,
            "termination_status": self.termination_status,
            "termination_message": self.termination_message,
        }

    def write_status(self):
        """Write current status to training_log.json."""
        log_path = self.output_dir / "training_log.json"
        log_path.write_text(json.dumps(self.get_status(), indent=2), encoding="utf-8")

    def is_complete(self) -> bool:
        """Check if training simulation is complete."""
        return self.current_epoch >= self.epochs

    def check_stop_signal(self) -> bool:
        """Check if stop signal file exists."""
        stop_file = self.output_dir / "STOP_TRAINING"
        return stop_file.exists()

    def run(self, callback=None):
        """Run the full training simulation.

        Args:
            callback: Optional function called after each step with status dict

        Returns:
            The final status dict (same shape as ``get_status``).

        The training.log handle is closed when this returns or raises, for
        example when ``callback`` throws.
        """
        try:
            self._log(f"[Stub] Starting simulated training: {self.epochs} epochs, {self.steps_per_epoch} steps/epoch")
            self._log(f"[Stub] Output: {self.output_dir}")
            self._log(f"[Stub] Step delay: {self.step_delay}s (total ~{self.epochs * self.steps_per_epoch * self.step_delay:.0f}s)")
            if self.early_stop_loss > 0:
                self._log(f"[Stub] Early stop: loss < {self.early_stop_loss} for {self.early_stop_patience} steps")
            self._log("")

            while not self.is_complete():
                # Check for user stop signal (consumed so the next run starts clean)
                if self.check_stop_signal():
                    self._log("\n[Stub] Stop signal received from user!")
                    (self.output_dir / "STOP_TRAINING").unlink(missing_ok=True)
                    self.termination_status = "user_stop"
                    self.termination_message = f"Stopped at epoch {self.current_epoch + 1}, step {self.current_step}"
                    self.write_status()
                    break

                status = self.simulate_step()

                # Check for early stop loss
                loss = status["loss"]
                if self.early_stop_loss > 0 and loss < self.early_stop_loss:
                    self.consecutive_low_loss += 1
                    if self.consecutive_low_loss >= self.early_stop_patience:
                        self._log(f"\n[Stub] Auto-stopped: loss ({loss:.4f}) < {self.early_stop_loss} for {self.early_stop_patience} steps")
                        self.termination_status = "auto_low_loss"
                        self.termination_message = f"Loss reached {loss:.4f} (< {self.early_stop_loss})"
                        self.write_status()
                        break
                else:
                    self.consecutive_low_loss = 0

                self.write_status()

                # Progress output
                epoch = status["epoch"]
                step = status["step"]
                display_epoch = min(epoch + 1, self.epochs)  # Cap at max for display
                self._log(f"  Epoch {display_epoch}/{self.epochs} | Step {step} | Loss: {loss:.4f}")

                if callback:
                    callback(status)

                time.sleep(self.step_delay)

            # Set completion status if not already set
            if self.termination_status is None:
                self.termination_status = "auto_complete"
                self.termination_message = f"Completed {self.epochs} epochs"
                self.write_status()

            self._log(f"\n[Stub] Training complete: {self.termination_status}")
        finally:
            # Always release the log file, even if a callback or write raised.
            self._close_log()

        return self.get_status()
|