rapidfireai 0.10.2rc5__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidfireai might be problematic. Click here for more details.

Files changed (36) hide show
  1. rapidfireai/automl/grid_search.py +4 -5
  2. rapidfireai/automl/model_config.py +41 -37
  3. rapidfireai/automl/random_search.py +21 -33
  4. rapidfireai/backend/controller.py +80 -161
  5. rapidfireai/backend/worker.py +26 -8
  6. rapidfireai/cli.py +171 -132
  7. rapidfireai/db/rf_db.py +1 -1
  8. rapidfireai/db/tables.sql +1 -1
  9. rapidfireai/dispatcher/dispatcher.py +3 -1
  10. rapidfireai/dispatcher/gunicorn.conf.py +1 -1
  11. rapidfireai/experiment.py +86 -7
  12. rapidfireai/frontend/build/asset-manifest.json +3 -3
  13. rapidfireai/frontend/build/index.html +1 -1
  14. rapidfireai/frontend/build/static/js/{main.1bf27639.js → main.58393d31.js} +3 -3
  15. rapidfireai/frontend/build/static/js/{main.1bf27639.js.map → main.58393d31.js.map} +1 -1
  16. rapidfireai/frontend/proxy_middleware.py +1 -1
  17. rapidfireai/ml/callbacks.py +85 -59
  18. rapidfireai/ml/trainer.py +42 -86
  19. rapidfireai/start.sh +117 -34
  20. rapidfireai/utils/constants.py +22 -1
  21. rapidfireai/utils/experiment_utils.py +87 -43
  22. rapidfireai/utils/interactive_controller.py +473 -0
  23. rapidfireai/utils/logging.py +1 -2
  24. rapidfireai/utils/metric_logger.py +346 -0
  25. rapidfireai/utils/mlflow_manager.py +0 -1
  26. rapidfireai/utils/ping.py +4 -2
  27. rapidfireai/utils/worker_manager.py +16 -6
  28. rapidfireai/version.py +2 -2
  29. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/METADATA +7 -4
  30. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/RECORD +36 -33
  31. tutorial_notebooks/rf-colab-tensorboard-tutorial.ipynb +314 -0
  32. /rapidfireai/frontend/build/static/js/{main.1bf27639.js.LICENSE.txt → main.58393d31.js.LICENSE.txt} +0 -0
  33. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/WHEEL +0 -0
  34. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/entry_points.txt +0 -0
  35. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
  36. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,473 @@
1
+ """
2
+ Interactive Controller for Jupyter/Colab notebooks.
3
+ Provides UI controls for managing training runs similar to the frontend.
4
+ """
5
+
6
+ import json
7
+ import threading
8
+ import time
9
+
10
+ import requests
11
+ from IPython.display import display
12
+
13
+ try:
14
+ import ipywidgets as widgets
15
+ except ImportError as e:
16
+ raise ImportError("ipywidgets is required for InteractiveController. Install with: pip install ipywidgets") from e
17
+
18
+
19
class InteractiveController:
    """Interactive run controller for Jupyter/Colab notebooks.

    Builds an ipywidgets panel that talks to the dispatcher HTTP API
    (``<dispatcher_url>/dispatcher/...``) to list runs, load a run's details,
    resume/stop/delete it, and clone it with an edited JSON configuration.

    Typical use::

        controller = InteractiveController()
        controller.display()
        controller.auto_refresh(interval=5)   # optional background polling
        controller.stop_auto_refresh()
    """

    # Colour palette (background / border / text) for each message severity.
    _MESSAGE_STYLES = {
        "success": {"bg": "#d4edda", "border": "#28a745", "text": "#155724"},
        "error": {"bg": "#f8d7da", "border": "#dc3545", "text": "#721c24"},
        "info": {"bg": "#d1ecf1", "border": "#17a2b8", "text": "#0c5460"},
        "warning": {"bg": "#fff3cd", "border": "#ffc107", "text": "#856404"},
    }

    # Class-level defaults so the attributes exist even before __init__ runs.
    request_timeout = 5
    _auto_refresh_stop = None

    def __init__(self, dispatcher_url: str = "http://127.0.0.1:8081", request_timeout: float = 5):
        """Create the controller and its widgets (nothing is displayed yet).

        Args:
            dispatcher_url: Base URL of the dispatcher; a trailing slash is stripped.
            request_timeout: Timeout in seconds applied to every dispatcher request.
        """
        self.dispatcher_url = dispatcher_url.rstrip("/")
        self.request_timeout = request_timeout
        self.run_id: int | None = None
        self.config: dict | None = None
        self.status: str = "Unknown"
        self.chunk_number: int = 0
        # Event used to stop the background polling thread, if one is running.
        self._auto_refresh_stop = None

        # Create UI widgets
        self._create_widgets()

    # ------------------------------------------------------------------
    # Widget construction
    # ------------------------------------------------------------------
    def _create_widgets(self):
        """Create the ipywidgets UI components and bind their callbacks."""
        # Run selector
        self.run_selector = widgets.Dropdown(
            options=[], description="", disabled=False, layout=widgets.Layout(width="300px")
        )
        self.load_btn = widgets.Button(
            description="Load Run", button_style="primary", tooltip="Load the selected run", icon="download"
        )
        self.refresh_selector_btn = widgets.Button(
            description="Refresh List",
            button_style="info",
            tooltip="Refresh the list of available runs",
            icon="refresh",
        )

        # Status display
        self.status_label = widgets.HTML(value="<b>Status:</b> Not loaded")
        self.chunk_label = widgets.HTML(value="<b>Chunk:</b> N/A")
        self.run_id_label = widgets.HTML(value="<b>Run ID:</b> N/A")

        # Action buttons
        self.resume_btn = widgets.Button(
            description="Resume",
            button_style="success",
            tooltip="Resume this run",
            icon="play",
        )
        self.stop_btn = widgets.Button(description="Stop", button_style="danger", tooltip="Stop this run", icon="stop")
        self.delete_btn = widgets.Button(
            description="Delete",
            button_style="danger",
            tooltip="Delete this run",
            icon="trash",
        )
        self.refresh_btn = widgets.Button(
            description="Refresh Status",
            button_style="info",
            tooltip="Refresh current run status and metrics",
            icon="sync",
        )

        # Config editor (for clone/modify). It is read-only until clone mode is
        # enabled; its ``disabled`` flag doubles as the "in clone mode" marker.
        self.config_text = widgets.Textarea(
            value="{}",
            placeholder="Run configuration (JSON)",
            disabled=True,
            layout=widgets.Layout(width="100%", height="200px"),
        )
        self.warm_start_checkbox = widgets.Checkbox(
            value=False,
            description="Warm Start (continue from previous checkpoint)",
            disabled=True,
            style={"description_width": "initial"},
            layout=widgets.Layout(margin="10px 0px"),
        )
        self.clone_btn = widgets.Button(
            description="Clone",
            button_style="primary",
            tooltip="Clone this run with modifications",
        )
        self.submit_clone_btn = widgets.Button(description="✓ Submit Clone", button_style="success", disabled=True)
        self.cancel_clone_btn = widgets.Button(description="✗ Cancel", button_style="", disabled=True)

        # Status message box
        self.status_message = widgets.HTML(
            value="",
            layout=widgets.Layout(
                width="100%",
                min_height="40px",
                padding="10px",
                margin="10px 0px",
                border="2px solid #ddd",
                border_radius="5px",
            ),
        )

        # Experiment status display (live progress). The widget is created so
        # that _update_experiment_status() has a target to write to, but it is
        # intentionally NOT added to the layout built in display().
        self.experiment_status = widgets.HTML(
            value=(
                '<div style="padding: 10px; background-color: #f8f9fa; '
                'border: 2px solid #dee2e6; border-radius: 5px;">'
                "<b>Experiment Status:</b> Loading..."
                "</div>"
            ),
            layout=widgets.Layout(width="100%", margin="10px 0px"),
        )

        # Bind button callbacks (on_click passes the button, which is unused).
        self.refresh_selector_btn.on_click(lambda b: self.fetch_all_runs())
        self.load_btn.on_click(lambda b: self._handle_load())
        self.resume_btn.on_click(lambda b: self._handle_resume())
        self.stop_btn.on_click(lambda b: self._handle_stop())
        self.delete_btn.on_click(lambda b: self._handle_delete())
        self.refresh_btn.on_click(lambda b: self._handle_refresh())
        self.clone_btn.on_click(lambda b: self._enable_clone_mode())
        self.submit_clone_btn.on_click(lambda b: self._handle_clone())
        self.cancel_clone_btn.on_click(lambda b: self._handle_cancel_clone())

        # Auto-load run when dropdown selection changes
        self.run_selector.observe(self._on_run_selected, names="value")

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _escape(value) -> str:
        """Return ``str(value)`` with HTML special characters escaped."""
        from html import escape

        return escape(str(value))

    def _get_all_runs(self):
        """GET the list of runs from the dispatcher; raises requests.RequestException."""
        response = requests.get(
            f"{self.dispatcher_url}/dispatcher/get-all-runs",
            timeout=self.request_timeout,
        )
        response.raise_for_status()
        return response.json()

    def _post_json(self, endpoint: str, payload: dict):
        """POST ``payload`` as JSON to a dispatcher endpoint and return the decoded reply.

        Raises:
            requests.RequestException: on connection errors or non-2xx responses.
        """
        response = requests.post(
            f"{self.dispatcher_url}/dispatcher/{endpoint}",
            json=payload,
            timeout=self.request_timeout,
        )
        response.raise_for_status()
        return response.json()

    def _require_run(self) -> bool:
        """Return True if a run is loaded; otherwise show a warning and return False."""
        if self.run_id is None:
            self._show_message("Please load a run first", "warning")
            return False
        return True

    def _show_message(self, message: str, message_type: str = "info"):
        """Display a status message with styling.

        Args:
            message: Plain text to show. It is HTML-escaped, because error text
                from failed requests can contain ``<...>`` fragments that would
                otherwise be parsed as markup and vanish from the output.
            message_type: One of "success", "error", "info", "warning";
                unknown values fall back to the "info" palette.
        """
        style = self._MESSAGE_STYLES.get(message_type, self._MESSAGE_STYLES["info"])
        bg, border, text = style["bg"], style["border"], style["text"]

        self.status_message.value = f"""
        <div style="
            background-color: {bg};
            border: 2px solid {border};
            color: {text};
            padding: 10px;
            border-radius: 5px;
            font-weight: 600;
        ">
            {self._escape(message)}
        </div>
        """

    # ------------------------------------------------------------------
    # Experiment-level status (currently not shown in the UI)
    # ------------------------------------------------------------------
    def _update_experiment_status(self):
        """Update the experiment status widget with overall run progress.

        Request failures are deliberately ignored so a transient dispatcher
        outage leaves the previous status in place.
        """
        try:
            runs = self._get_all_runs()
        except requests.RequestException:
            # Best-effort: keep whatever was displayed before.
            return

        if not runs:
            self.experiment_status.value = (
                '<div style="padding: 10px; background-color: #f8f9fa; '
                'border: 2px solid #dee2e6; border-radius: 5px;">'
                "<b>Experiment Status:</b> No runs found"
                "</div>"
            )
            return

        total_runs = len(runs)
        completed_runs = sum(1 for r in runs if r.get("status") == "COMPLETED")
        ongoing_runs = sum(1 for r in runs if r.get("status") == "ONGOING")

        # Determine status colour and icon
        if completed_runs == total_runs:
            palette, icon, status_text = self._MESSAGE_STYLES["success"], "✓", "All runs completed"
        elif ongoing_runs > 0:
            palette, icon, status_text = self._MESSAGE_STYLES["info"], "🔄", "Training in progress"
        else:
            palette, icon, status_text = self._MESSAGE_STYLES["warning"], "⏸", "Training paused or stopped"

        self.experiment_status.value = (
            f'<div style="padding: 10px; background-color: {palette["bg"]}; '
            f'border: 2px solid {palette["border"]}; border-radius: 5px; color: {palette["text"]};">'
            f"<b>{icon} Experiment Status:</b> {status_text}<br>"
            f"<b>Progress:</b> {completed_runs}/{total_runs} runs completed"
            "</div>"
        )

    # ------------------------------------------------------------------
    # Run listing / loading
    # ------------------------------------------------------------------
    def fetch_all_runs(self):
        """Fetch all runs from the dispatcher and populate the dropdown."""
        try:
            runs = self._get_all_runs()
        except requests.RequestException as e:
            self._show_message(f"Error fetching runs: {e}", "error")
            return

        if runs:
            # Options are (label, value) tuples; the value is the run id.
            self.run_selector.options = [
                (f"Run {run['run_id']} - {run.get('status', 'Unknown')}", run["run_id"]) for run in runs
            ]
            self._show_message(f"Found {len(runs)} runs", "success")
        else:
            self.run_selector.options = []
            self._show_message("No runs found", "info")

    def _on_run_selected(self, change):
        """Handle dropdown selection change - auto-load the newly selected run."""
        selected = change["new"]
        if selected is not None:
            self.load_run(selected)

    def _handle_load(self):
        """Handle load button click."""
        selected = self.run_selector.value
        if selected is None:
            self._show_message("Please select a run first", "warning")
            return
        self.load_run(selected)

    def _handle_refresh(self):
        """Handle "Refresh Status" click: reload the current run, if any."""
        if self.run_id is not None:
            self.load_run(self.run_id)

    def load_run(self, run_id: int, *, show_success: bool = True):
        """Load run details from the dispatcher API and refresh the widgets.

        ``self.run_id`` is only updated once the request has succeeded, so the
        action buttons never target a run whose details are not on screen.

        Args:
            run_id: Identifier of the run to load.
            show_success: When False, the "Loaded run N" message is suppressed
                so that it does not overwrite a more relevant message (used by
                the action handlers and by auto-refresh). Errors are always shown.
        """
        try:
            data = self._post_json("get-run", {"run_id": run_id})
        except requests.RequestException as e:
            self._show_message(f"Error loading run: {e}", "error")
            return

        # Update state
        self.run_id = run_id
        self.config = data.get("config", {})
        # ``or`` also covers an explicit JSON null, which .get() would pass through.
        self.status = data.get("status") or "Unknown"
        self.chunk_number = data.get("num_chunks_visited", 0)

        # Update UI
        self._update_display()
        if show_success:
            self._show_message(f"Loaded run {run_id}", "success")

    def _update_display(self):
        """Push the current run state into the widgets."""
        self.run_id_label.value = f"<b>Run ID:</b> {self._escape(self.run_id)}"
        self.status_label.value = f"<b>Status:</b> {self._escape(self.status)}"
        self.chunk_label.value = f"<b>Chunk:</b> {self._escape(self.chunk_number)}"

        # The config editor is enabled only while the user is editing a clone.
        in_clone_mode = not self.config_text.disabled
        if not in_clone_mode:
            # Never overwrite the editor while the user is typing in it.
            self.config_text.value = json.dumps(self.config, indent=2)

        # Disable buttons if completed
        is_completed = str(self.status).lower() == "completed"
        self.resume_btn.disabled = is_completed
        self.stop_btn.disabled = is_completed
        self.delete_btn.disabled = is_completed
        # Clone stays disabled during clone mode (Submit/Cancel are active instead).
        self.clone_btn.disabled = is_completed or in_clone_mode

    # ------------------------------------------------------------------
    # Run actions
    # ------------------------------------------------------------------
    def _run_action(self, endpoint: str, done_verb: str, doing_verb: str, reload_after: bool):
        """POST a simple ``{"run_id": ...}`` action and report the outcome.

        Args:
            endpoint: Dispatcher endpoint name, e.g. "resume-run".
            done_verb: Capitalised past tense for the success message ("Resumed").
            doing_verb: Gerund for the failure message ("resuming").
            reload_after: Whether to reload the run details after success.
        """
        if not self._require_run():
            return
        try:
            result = self._post_json(endpoint, {"run_id": self.run_id})
        except requests.RequestException as e:
            self._show_message(f"Error {doing_verb} run: {e}", "error")
            return

        if result.get("error"):
            self._show_message(f"Error: {result['error']}", "error")
            return

        self._show_message(f"{done_verb} run {self.run_id}", "success")
        if reload_after:
            # Quiet reload so the success message above stays visible.
            self.load_run(self.run_id, show_success=False)

    def _handle_resume(self):
        """Resume the run."""
        self._run_action("resume-run", "Resumed", "resuming", reload_after=True)

    def _handle_stop(self):
        """Stop the run."""
        self._run_action("stop-run", "Stopped", "stopping", reload_after=True)

    def _handle_delete(self):
        """Delete the run (the run details are not reloaded afterwards)."""
        self._run_action("delete-run", "Deleted", "deleting", reload_after=False)

    # ------------------------------------------------------------------
    # Clone / modify
    # ------------------------------------------------------------------
    def _enable_clone_mode(self):
        """Enable config editing for clone/modify."""
        self.config_text.disabled = False
        self.warm_start_checkbox.disabled = False
        self.submit_clone_btn.disabled = False
        self.cancel_clone_btn.disabled = False
        self.clone_btn.disabled = True
        self._show_message("Edit config and click Submit to clone", "info")

    def _disable_clone_mode(self):
        """Disable config editing and restore the loaded run's configuration."""
        self.config_text.disabled = True
        self.config_text.value = json.dumps(self.config, indent=2)
        self.warm_start_checkbox.disabled = True
        self.warm_start_checkbox.value = False
        self.submit_clone_btn.disabled = True
        self.cancel_clone_btn.disabled = True
        # Keep Clone disabled if the run completed while it was being edited.
        self.clone_btn.disabled = str(self.status).lower() == "completed"

    def _handle_cancel_clone(self):
        """Handle cancel clone button click."""
        self._disable_clone_mode()
        self._show_message("Cancelled clone", "info")

    def _enable_colab_widgets(self):
        """Enable the custom widget manager when running inside Google Colab."""
        try:
            # Importing from google.colab doubles as the "are we in Colab?" check.
            from google.colab import output
        except ImportError:
            # Not in Colab, no action needed
            return
        output.enable_custom_widget_manager()

    def _handle_clone(self):
        """Clone/modify the run using the JSON currently in the config editor."""
        if not self._require_run():
            return

        # Parse config
        try:
            new_config = json.loads(self.config_text.value)
        except json.JSONDecodeError as e:
            self._show_message(f"Invalid JSON: {e}", "error")
            return

        payload = {
            "run_id": self.run_id,
            "config": new_config,
            "warm_start": self.warm_start_checkbox.value,
        }
        try:
            result = self._post_json("clone-modify-run", payload)
        except requests.RequestException as e:
            self._show_message(f"Error cloning run: {e}", "error")
            return

        # The reply signals failure either via "error" or via result == False.
        if result.get("error") or (result.get("result") is False):
            error_msg = result.get("err_msg") or result.get("error")
            self._show_message(f"Error: {error_msg}", "error")
            return

        self._show_message(f"Cloned run {self.run_id}", "success")
        self._disable_clone_mode()

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------
    def display(self):
        """Display the interactive controller UI and fetch the available runs."""
        # Enable custom widget manager for Google Colab
        self._enable_colab_widgets()

        # Layout
        header = widgets.VBox(
            [
                widgets.HTML("<h3>Interactive Run Controller</h3>"),
                widgets.HBox([self.run_id_label, self.status_label, self.chunk_label]),
            ]
        )

        # Run selector section
        selector_section = widgets.VBox(
            [
                widgets.HTML("<b>Select a Run:</b>"),
                widgets.HBox([self.run_selector, self.load_btn, self.refresh_selector_btn]),
            ]
        )

        actions = widgets.HBox([self.resume_btn, self.stop_btn, self.delete_btn, self.refresh_btn])

        config_section = widgets.VBox(
            [
                widgets.HTML("<b>Configuration:</b>"),
                self.config_text,
                self.warm_start_checkbox,
                widgets.HBox([self.clone_btn, self.submit_clone_btn, self.cancel_clone_btn]),
            ]
        )

        # self.experiment_status is intentionally left out of the layout.
        ui = widgets.VBox([header, self.status_message, selector_section, actions, config_section])

        display(ui)

        # Automatically fetch available runs
        self.fetch_all_runs()

        # Load initial data if run_id set (0 is a valid id, so compare to None).
        if self.run_id is not None:
            self.load_run(self.run_id)

    def auto_refresh(self, interval: int = 5):
        """Reload the current run every ``interval`` seconds in a daemon thread.

        Calling this again replaces the previous polling thread instead of
        stacking a second one. Use :meth:`stop_auto_refresh` to stop polling.

        Args:
            interval: Seconds between refreshes; must be positive.

        Raises:
            ValueError: if ``interval`` is not positive.
        """
        if interval <= 0:
            raise ValueError("interval must be a positive number of seconds")

        self.stop_auto_refresh()
        stop_event = threading.Event()

        def refresh_loop():
            while not stop_event.is_set():
                if self.run_id is not None:
                    # Quiet reload: do not clobber the message box every tick.
                    self.load_run(self.run_id, show_success=False)
                # Sleeps for ``interval`` but wakes immediately when stopped.
                stop_event.wait(interval)

        self._auto_refresh_stop = stop_event
        thread = threading.Thread(target=refresh_loop, daemon=True)
        thread.start()

    def stop_auto_refresh(self):
        """Stop the background polling started by :meth:`auto_refresh`, if any."""
        if self._auto_refresh_stop is not None:
            self._auto_refresh_stop.set()
            self._auto_refresh_stop = None
@@ -1,7 +1,6 @@
1
1
  import os
2
2
  import threading
3
3
  from abc import ABC, abstractmethod
4
- from typing import Dict
5
4
 
6
5
  from loguru import logger
7
6
 
@@ -13,7 +12,7 @@ class BaseRFLogger(ABC):
13
12
  """Base class for RapidFire loggers"""
14
13
 
15
14
  _experiment_name = ""
16
- _initialized_loggers: Dict[str, bool] = {}
15
+ _initialized_loggers: dict[str, bool] = {}
17
16
  _lock = threading.Lock()
18
17
 
19
18
  def __init__(self, level: str = "DEBUG"):