rapidfireai 0.10.3rc1__py3-none-any.whl → 0.11.1rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of rapidfireai might be problematic.

Files changed (26)
  1. rapidfireai/automl/grid_search.py +4 -5
  2. rapidfireai/automl/model_config.py +41 -37
  3. rapidfireai/automl/random_search.py +21 -33
  4. rapidfireai/backend/controller.py +54 -148
  5. rapidfireai/backend/worker.py +14 -3
  6. rapidfireai/cli.py +148 -136
  7. rapidfireai/experiment.py +22 -11
  8. rapidfireai/frontend/build/asset-manifest.json +3 -3
  9. rapidfireai/frontend/build/index.html +1 -1
  10. rapidfireai/frontend/build/static/js/{main.e7d3b759.js → main.aee6c455.js} +3 -3
  11. rapidfireai/frontend/build/static/js/{main.e7d3b759.js.map → main.aee6c455.js.map} +1 -1
  12. rapidfireai/ml/callbacks.py +10 -24
  13. rapidfireai/ml/trainer.py +37 -81
  14. rapidfireai/utils/constants.py +3 -1
  15. rapidfireai/utils/interactive_controller.py +40 -61
  16. rapidfireai/utils/logging.py +1 -2
  17. rapidfireai/utils/mlflow_manager.py +1 -0
  18. rapidfireai/utils/ping.py +4 -2
  19. rapidfireai/version.py +2 -2
  20. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/METADATA +1 -1
  21. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/RECORD +26 -26
  22. /rapidfireai/frontend/build/static/js/{main.e7d3b759.js.LICENSE.txt → main.aee6c455.js.LICENSE.txt} +0 -0
  23. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/WHEEL +0 -0
  24. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/entry_points.txt +0 -0
  25. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/licenses/LICENSE +0 -0
  26. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/top_level.txt +0 -0
rapidfireai/cli.py CHANGED
@@ -3,15 +3,16 @@
  Command-line interface for RapidFire AI
  """

- import os
- import sys
- import subprocess
  import argparse
+ import os
  import platform
- import shutil
  import re
+ import shutil
  import site
+ import subprocess
+ import sys
  from pathlib import Path
+
  from .version import __version__


@@ -20,24 +21,24 @@ def get_script_path():
      # Get the directory where this package is installed
      package_dir = Path(__file__).parent
      script_path = package_dir / "start.sh"
-
+
      if not script_path.exists():
          # Fallback: try to find it relative to the current working directory
          script_path = Path.cwd() / "rapidfireai" / "start.sh"
          if not script_path.exists():
              raise FileNotFoundError(f"Could not find start.sh script at {script_path}")
-
+
      return script_path


  def run_script(args):
      """Run the start.sh script with the given arguments."""
      script_path = get_script_path()
-
+
      # Make sure the script is executable
      if not os.access(script_path, os.X_OK):
          os.chmod(script_path, 0o755)
-
+
      # Run the script with the provided arguments
      try:
          result = subprocess.run([str(script_path)] + args, check=True)
@@ -53,24 +54,27 @@ def run_script(args):
  def get_python_info():
      """Get comprehensive Python information."""
      info = {}
-
+
      # Python version and implementation
-     info['version'] = sys.version
-     info['implementation'] = platform.python_implementation()
-     info['executable'] = sys.executable
-
+     info["version"] = sys.version
+     info["implementation"] = platform.python_implementation()
+     info["executable"] = sys.executable
+
      # Environment information
-     info['conda_env'] = os.environ.get('CONDA_DEFAULT_ENV', 'none')
-     info['venv'] = 'yes' if hasattr(sys, 'real_prefix') or (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix) else 'no'
-
+     info["conda_env"] = os.environ.get("CONDA_DEFAULT_ENV", "none")
+     info["venv"] = (
+         "yes"
+         if hasattr(sys, "real_prefix") or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix)
+         else "no"
+     )
+
      return info


  def get_pip_packages():
      """Get list of installed pip packages."""
      try:
-         result = subprocess.run([sys.executable, '-m', 'pip', 'list'],
-                                 capture_output=True, text=True, check=True)
+         result = subprocess.run([sys.executable, "-m", "pip", "list"], capture_output=True, text=True, check=True)
          return result.stdout
      except (subprocess.CalledProcessError, FileNotFoundError):
          return "Failed to get pip packages"
@@ -79,113 +83,107 @@ def get_pip_packages():
  def get_gpu_info():
      """Get comprehensive GPU and CUDA information."""
      info = {}
-
+
      # Check for nvidia-smi
-     nvidia_smi_path = shutil.which('nvidia-smi')
-     info['nvidia_smi'] = 'found' if nvidia_smi_path else 'not found'
-
+     nvidia_smi_path = shutil.which("nvidia-smi")
+     info["nvidia_smi"] = "found" if nvidia_smi_path else "not found"
+
      if nvidia_smi_path:
          try:
              # Get driver and CUDA runtime version from the full nvidia-smi output
-             result = subprocess.run(['nvidia-smi'],
-                                     capture_output=True, text=True, check=True)
+             result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
              if result.stdout.strip():
-                 lines = result.stdout.strip().split('\n')
+                 lines = result.stdout.strip().split("\n")
                  # Look for the header line that contains CUDA version
                  for line in lines:
-                     if 'CUDA Version:' in line:
+                     if "CUDA Version:" in line:
                          # Extract CUDA version from line like "NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.2"
-                         cuda_version = line.split('CUDA Version:')[1].split()[0]
-                         info['cuda_runtime'] = cuda_version
+                         cuda_version = line.split("CUDA Version:")[1].split()[0]
+                         info["cuda_runtime"] = cuda_version
                          # Also extract driver version from the same line
-                         if 'Driver Version:' in line:
-                             driver_version = line.split('Driver Version:')[1].split('CUDA Version:')[0].strip()
-                             info['driver_version'] = driver_version
+                         if "Driver Version:" in line:
+                             driver_version = line.split("Driver Version:")[1].split("CUDA Version:")[0].strip()
+                             info["driver_version"] = driver_version
                          break
              else:
-                 info['driver_version'] = 'unknown'
-                 info['cuda_runtime'] = 'unknown'
+                 info["driver_version"] = "unknown"
+                 info["cuda_runtime"] = "unknown"
          except (subprocess.CalledProcessError, ValueError):
-             info['driver_version'] = 'unknown'
-             info['cuda_runtime'] = 'unknown'
-
+             info["driver_version"] = "unknown"
+             info["cuda_runtime"] = "unknown"
+
          # Get GPU count, models, and VRAM
          try:
-             result = subprocess.run(['nvidia-smi', '--query-gpu=count,name,memory.total', '--format=csv,noheader,nounits'],
-                                     capture_output=True, text=True, check=True)
+             result = subprocess.run(
+                 ["nvidia-smi", "--query-gpu=count,name,memory.total", "--format=csv,noheader,nounits"],
+                 capture_output=True,
+                 text=True,
+                 check=True,
+             )
              if result.stdout.strip():
-                 lines = result.stdout.strip().split('\n')
+                 lines = result.stdout.strip().split("\n")
                  if lines:
-                     count, name, memory = lines[0].split(', ')
-                     info['gpu_count'] = int(count)
-                     info['gpu_model'] = name.strip()
+                     count, name, memory = lines[0].split(", ")
+                     info["gpu_count"] = int(count)
+                     info["gpu_model"] = name.strip()
                      # Convert memory from MiB to GB
                      memory_mib = int(memory.split()[0])
                      memory_gb = memory_mib / 1024
-                     info['gpu_memory_gb'] = f"{memory_gb:.1f}"
-
+                     info["gpu_memory_gb"] = f"{memory_gb:.1f}"
+
                      # Get detailed info for multiple GPUs if present
-                     if info['gpu_count'] > 1:
-                         info['gpu_details'] = []
+                     if info["gpu_count"] > 1:
+                         info["gpu_details"] = []
                          for line in lines:
-                             count, name, memory = line.split(', ')
+                             count, name, memory = line.split(", ")
                              memory_mib = int(memory.split()[0])
                              memory_gb = memory_mib / 1024
-                             info['gpu_details'].append({
-                                 'name': name.strip(),
-                                 'memory_gb': f"{memory_gb:.1f}"
-                             })
+                             info["gpu_details"].append({"name": name.strip(), "memory_gb": f"{memory_gb:.1f}"})
          except (subprocess.CalledProcessError, ValueError):
-             info['gpu_count'] = 0
-             info['gpu_model'] = 'unknown'
-             info['gpu_memory_gb'] = 'unknown'
+             info["gpu_count"] = 0
+             info["gpu_model"] = "unknown"
+             info["gpu_memory_gb"] = "unknown"
      else:
-         info['driver_version'] = 'N/A'
-         info['cuda_runtime'] = 'N/A'
-         info['gpu_count'] = 0
-         info['gpu_model'] = 'N/A'
-         info['gpu_memory_gb'] = 'N/A'
-
+         info["driver_version"] = "N/A"
+         info["cuda_runtime"] = "N/A"
+         info["gpu_count"] = 0
+         info["gpu_model"] = "N/A"
+         info["gpu_memory_gb"] = "N/A"
+
      # Check for nvcc (CUDA compiler)
-     nvcc_path = shutil.which('nvcc')
-     info['nvcc'] = 'found' if nvcc_path else 'not found'
-
+     nvcc_path = shutil.which("nvcc")
+     info["nvcc"] = "found" if nvcc_path else "not found"
+
      if nvcc_path:
          try:
-             result = subprocess.run(['nvcc', '--version'],
-                                     capture_output=True, text=True, check=True)
+             result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True, check=True)
              # Extract version from output like "Cuda compilation tools, release 11.8, V11.8.89"
-             version_line = result.stdout.split('\n')[0]
-             if 'release' in version_line:
-                 version = version_line.split('release')[1].split(',')[0].strip()
-                 info['nvcc_version'] = version
+             version_line = result.stdout.split("\n")[0]
+             if "release" in version_line:
+                 version = version_line.split("release")[1].split(",")[0].strip()
+                 info["nvcc_version"] = version
              else:
-                 info['nvcc_version'] = 'unknown'
+                 info["nvcc_version"] = "unknown"
          except subprocess.CalledProcessError:
-             info['nvcc_version'] = 'unknown'
+             info["nvcc_version"] = "unknown"
      else:
-         info['nvcc_version'] = 'N/A'
-
+         info["nvcc_version"] = "N/A"
+
      # Check CUDA installation paths
-     cuda_paths = [
-         '/usr/local/cuda',
-         '/opt/cuda',
-         '/usr/cuda',
-         os.path.expanduser('~/cuda')
-     ]
-
+     cuda_paths = ["/usr/local/cuda", "/opt/cuda", "/usr/cuda", os.path.expanduser("~/cuda")]
+
      cuda_installed = False
      for path in cuda_paths:
          if os.path.exists(path):
              cuda_installed = True
              break
-
-     info['cuda_installation'] = 'present' if cuda_installed else 'not present'
-
+
+     info["cuda_installation"] = "present" if cuda_installed else "not present"
+
      # Check if CUDA is on PATH
-     cuda_on_path = any('cuda' in p.lower() for p in os.environ.get('PATH', '').split(os.pathsep))
-     info['cuda_on_path'] = 'yes' if cuda_on_path else 'no'
-
+     cuda_on_path = any("cuda" in p.lower() for p in os.environ.get("PATH", "").split(os.pathsep))
+     info["cuda_on_path"] = "yes" if cuda_on_path else "no"
+
      return info


@@ -193,7 +191,7 @@ def run_doctor():
      """Run the doctor command to diagnose system issues."""
      print("🔍 RapidFire AI System Diagnostics")
      print("=" * 50)
-
+
      # Python Information
      print("\n🐍 Python Environment:")
      print("-" * 30)
@@ -203,94 +201,112 @@ def run_doctor():
      print(f"Executable: {python_info['executable']}")
      print(f"Conda Environment: {python_info['conda_env']}")
      print(f"Virtual Environment: {python_info['venv']}")
-
+
      # Pip Packages
      print("\n📦 Installed Packages:")
      print("-" * 30)
      pip_output = get_pip_packages()
      if pip_output != "Failed to get pip packages":
          # Show only relevant packages
-         relevant_packages = ['rapidfireai', 'mlflow', 'torch', 'transformers', 'flask', 'gunicorn', 'peft', 'trl', 'bitsandbytes', 'nltk', 'evaluate', 'rouge-score', 'sentencepiece']
-         lines = pip_output.split('\n')
+         relevant_packages = [
+             "rapidfireai",
+             "mlflow",
+             "torch",
+             "transformers",
+             "flask",
+             "gunicorn",
+             "peft",
+             "trl",
+             "bitsandbytes",
+             "nltk",
+             "evaluate",
+             "rouge-score",
+             "sentencepiece",
+         ]
+         lines = pip_output.split("\n")
          for line in lines:
              if any(pkg.lower() in line.lower() for pkg in relevant_packages):
                  print(line)
          print("... (showing only relevant packages)")
      else:
          print(pip_output)
-
+
      # GPU Information
      print("\n🚀 GPU & CUDA Information:")
      print("-" * 30)
      gpu_info = get_gpu_info()
      print(f"nvidia-smi: {gpu_info['nvidia_smi']}")
-
-     if gpu_info['nvidia_smi'] == 'found':
+
+     if gpu_info["nvidia_smi"] == "found":
          print(f"Driver Version: {gpu_info['driver_version']}")
          print(f"CUDA Runtime: {gpu_info['cuda_runtime']}")
          print(f"GPU Count: {gpu_info['gpu_count']}")
-
-         if gpu_info['gpu_count'] > 0:
-             if 'gpu_details' in gpu_info:
+
+         if gpu_info["gpu_count"] > 0:
+             if "gpu_details" in gpu_info:
                  print("GPU Details:")
-                 for i, gpu in enumerate(gpu_info['gpu_details']):
+                 for i, gpu in enumerate(gpu_info["gpu_details"]):
                      print(f"  GPU {i}: {gpu['name']} ({gpu['memory_gb']} GB)")
              else:
                  print(f"GPU Model: {gpu_info['gpu_model']}")
                  print(f"Total VRAM: {gpu_info['gpu_memory_gb']} GB")
-
+
          print(f"nvcc: {gpu_info['nvcc']}")
-         if gpu_info['nvcc'] == 'found':
+         if gpu_info["nvcc"] == "found":
              print(f"nvcc Version: {gpu_info['nvcc_version']}")
-
+
          print(f"CUDA Installation: {gpu_info['cuda_installation']}")
          print(f"CUDA on PATH: {gpu_info['cuda_on_path']}")
-
+
      # System Information
      print("\n💻 System Information:")
      print("-" * 30)
      print(f"Platform: {platform.platform()}")
      print(f"Architecture: {platform.machine()}")
      print(f"Processor: {platform.processor()}")
-
+
      # Environment Variables
      print("\n🔧 Environment Variables:")
      print("-" * 30)
-     relevant_vars = ['CUDA_HOME', 'CUDA_PATH', 'LD_LIBRARY_PATH', 'PATH']
+     relevant_vars = ["CUDA_HOME", "CUDA_PATH", "LD_LIBRARY_PATH", "PATH"]
      for var in relevant_vars:
-         value = os.environ.get(var, 'not set')
-         if value != 'not set' and len(value) > 100:
+         value = os.environ.get(var, "not set")
+         if value != "not set" and len(value) > 100:
              value = value[:100] + "..."
          print(f"{var}: {value}")
-
+
      print("\n✅ Diagnostics complete!")
      return 0

+
  def get_cuda_version():
      """Detect CUDA version from nvcc or nvidia-smi"""
      try:
-         result = subprocess.run(['nvcc', '--version'],
-                                 capture_output=True, text=True, check=True)
-         match = re.search(r'release (\d+)\.(\d+)', result.stdout)
+         result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True, check=True)
+         match = re.search(r"release (\d+)\.(\d+)", result.stdout)
          if match:
              return int(match.group(1))
      except (subprocess.CalledProcessError, FileNotFoundError):
          try:
-             result = subprocess.run(['nvidia-smi'],
-                                     capture_output=True, text=True, check=True)
-             match = re.search(r'CUDA Version: (\d+)\.(\d+)', result.stdout)
+             result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
+             match = re.search(r"CUDA Version: (\d+)\.(\d+)", result.stdout)
              if match:
                  return int(match.group(1))
          except (subprocess.CalledProcessError, FileNotFoundError):
              pass
      return None

+
  def get_compute_capability():
      """Get compute capability from nvidia-smi"""
      try:
-         result = subprocess.run(['nvidia-smi', '--query-gpu=compute_cap', '--format=csv,noheader,nounits'],
-                                 capture_output=True, text=True, check=True)
-         match = re.search(r'(\d+)\.(\d+)', result.stdout)
+         result = subprocess.run(
+             ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader,nounits"],
+             capture_output=True,
+             text=True,
+             check=True,
+         )
+         match = re.search(r"(\d+)\.(\d+)", result.stdout)
          if match:
              major = int(match.group(1))
              minor = int(match.group(2))
@@ -298,6 +314,7 @@ def get_compute_capability():
      except (subprocess.CalledProcessError, FileNotFoundError):
          return None

+
  def install_packages():
      """Install packages for the RapidFire AI project."""
      packages = []
@@ -314,11 +331,11 @@ def install_packages():
      #     packages.append({"package": "vllm==0.10.1.1", "extra_args": ["--torch-backend=cu118"]})
      # else:
      #     print("\n⚠️ CUDA version not detected or unsupported.")
-
+
      ## TODO: re-enable once flash-attn has fix
      # if cuda_major is not None:
      #     print(f"\n🎯 Detected CUDA {cuda_major}.x")
-
+
      #     # Determine flash-attn version based on CUDA version
      #     if cuda_major < 8:
      #         # flash-attn 1.x for CUDA < 8.0
@@ -349,13 +366,14 @@ def install_packages():
              print(f"   You may need to install {package} manually")
      return 0

+
  def copy_tutorial_notebooks():
      """Copy the tutorial notebooks to the project."""
      print("Getting tutorial notebooks...")
      try:
          tutorial_path = os.getenv("RF_TUTORIAL_PATH", os.path.join(".", "tutorial_notebooks"))
          site_packages_path = site.getsitepackages()[0]
-         source_path =os.path.join(site_packages_path, "tutorial_notebooks")
+         source_path = os.path.join(site_packages_path, "tutorial_notebooks")
          print(f"Copying tutorial notebooks from {source_path} to {tutorial_path}...")
          os.makedirs(tutorial_path, exist_ok=True)
          shutil.copytree(source_path, tutorial_path, dirs_exist_ok=True)
@@ -378,44 +396,38 @@ def run_init():

      return 0

+
  def main():
      """Main entry point for the rapidfireai command."""
-     parser = argparse.ArgumentParser(
-         description="RapidFire AI - Start/stop/manage services",
-         prog="rapidfireai"
-     )
-
+     parser = argparse.ArgumentParser(description="RapidFire AI - Start/stop/manage services", prog="rapidfireai")
+
      parser.add_argument(
          "command",
          nargs="?",
          default="start",
          choices=["start", "stop", "status", "restart", "setup", "doctor", "init"],
-         help="Command to execute (default: start)"
-     )
-
-     parser.add_argument(
-         "--version",
-         action="version",
-         version=f"RapidFire AI {__version__}"
+         help="Command to execute (default: start)",
      )

+     parser.add_argument("--version", action="version", version=f"RapidFire AI {__version__}")
+
      parser.add_argument(
          "--tracking-backend",
          choices=["mlflow", "tensorboard", "both"],
          default=os.getenv("RF_TRACKING_BACKEND", "mlflow"),
-         help="Tracking backend to use for metrics (default: mlflow)"
+         help="Tracking backend to use for metrics (default: mlflow)",
      )

      parser.add_argument(
          "--tensorboard-log-dir",
          default=os.getenv("RF_TENSORBOARD_LOG_DIR", None),
-         help="Directory for TensorBoard logs (default: {experiment_path}/tensorboard_logs)"
+         help="Directory for TensorBoard logs (default: {experiment_path}/tensorboard_logs)",
      )

      parser.add_argument(
          "--colab",
          action="store_true",
-         help="Run in Colab mode (skips frontend, conditionally starts MLflow based on tracking backend)"
+         help="Run in Colab mode (skips frontend, conditionally starts MLflow based on tracking backend)",
      )

      args = parser.parse_args()
@@ -427,7 +439,7 @@ def main():
          os.environ["RF_TENSORBOARD_LOG_DIR"] = args.tensorboard_log_dir
      if args.colab:
          os.environ["RF_COLAB_MODE"] = "true"
-
+
      # Handle doctor command separately
      if args.command == "doctor":
          return run_doctor()
@@ -435,10 +447,10 @@ def main():
      # Handle init command separately
      if args.command == "init":
          return run_init()
-
+
      # Run the script with the specified command
      return run_script([args.command])


  if __name__ == "__main__":
-     sys.exit(main())
+     sys.exit(main())
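
The version-detection helpers above reduce to two regexes run over tool output: prefer nvcc's "release X.Y" line, then fall back to the "CUDA Version: X.Y" field in the nvidia-smi header. A minimal sketch of that fallback logic, with canned strings standing in for the subprocess output (the sample version numbers are illustrative, not captured from a real machine):

import re

# Hypothetical tool output, standing in for subprocess.run(...).stdout
nvcc_stdout = "Cuda compilation tools, release 12.2, V12.2.140"
smi_stdout = "NVIDIA-SMI 535.183.06  Driver Version: 535.183.06  CUDA Version: 12.2"

# Try nvcc's "release X.Y" first; fall back to nvidia-smi's "CUDA Version: X.Y"
match = re.search(r"release (\d+)\.(\d+)", nvcc_stdout) or re.search(r"CUDA Version: (\d+)\.(\d+)", smi_stdout)
cuda_major = int(match.group(1)) if match else None
print(cuda_major)  # 12

Returning only the major version is enough here because the install logic branches on the CUDA generation, not the exact point release.
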
rapidfireai/experiment.py CHANGED
@@ -100,16 +100,18 @@ class Experiment:
          # Detect if running in Google Colab
          try:
              import google.colab
+
              in_colab = True
          except ImportError:
              in_colab = False

          if in_colab:
              # Run Controller in background thread to keep kernel responsive
-             import threading
              import sys
+             import threading
              from io import StringIO
-             from IPython.display import display, HTML
+
+             from IPython.display import HTML, display

              def _run_controller_background():
                  """Run controller in background thread with output suppression"""
@@ -130,20 +132,26 @@ class Experiment:
                      # Restore stdout
                      sys.stdout = old_stdout
                      # Display completion message
-                     display(HTML('<p style="color: blue; font-weight: bold;">🎉 Training completed! Check InteractiveController for final results.</p>'))
+                     display(
+                         HTML(
+                             '<p style="color: blue; font-weight: bold;">🎉 Training completed! Check InteractiveController for final results.</p>'
+                         )
+                     )
                      self._training_thread = None

              self._training_thread = threading.Thread(target=_run_controller_background, daemon=True)
              self._training_thread.start()

              # Use IPython display for reliable output in Colab
-             display(HTML(
-                 '<div style="padding: 10px; background-color: #d4edda; border: 1px solid #28a745; border-radius: 5px; color: #155724;">'
-                 '<b>✓ Training started in background</b><br>'
-                 'Use InteractiveController to monitor progress. The notebook kernel will remain responsive while training runs.<br>'
-                 '<small>Tip: Interact with InteractiveController periodically to keep Colab active.</small>'
-                 '</div>'
-             ))
+             display(
+                 HTML(
+                     '<div style="padding: 10px; background-color: #d4edda; border: 1px solid #28a745; border-radius: 5px; color: #155724;">'
+                     "<b>✓ Training started in background</b><br>"
+                     "Use InteractiveController to monitor progress. The notebook kernel will remain responsive while training runs.<br>"
+                     "<small>Tip: Interact with InteractiveController periodically to keep Colab active.</small>"
+                     "</div>"
+                 )
+             )
          else:
              # Original blocking behavior for non-Colab environments
              try:
@@ -162,7 +170,9 @@ class Experiment:
          runs_info_df = self.experiment_utils.get_runs_info()

          # Check if there are any mlflow_run_ids before importing MLflow
-         has_mlflow_runs = runs_info_df.get("mlflow_run_id") is not None and runs_info_df["mlflow_run_id"].notna().any()
+         has_mlflow_runs = (
+             runs_info_df.get("mlflow_run_id") is not None and runs_info_df["mlflow_run_id"].notna().any()
+         )

          if not has_mlflow_runs:
              # No MLflow runs to fetch, return empty DataFrame
@@ -170,6 +180,7 @@ class Experiment:

          # Lazy import - only import when we actually have MLflow runs to fetch
          from rapidfireai.utils.mlflow_manager import MLflowManager
+
          mlflow_manager = MLflowManager(MLFLOW_URL)

          metrics_data = []
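
The Colab branch above follows a common notebook pattern: run the blocking call on a daemon thread with sys.stdout swapped for a StringIO buffer, so the kernel stays responsive and training chatter does not flood the cell output. A rough standalone sketch of the same pattern, where train() is a hypothetical stand-in for the Controller's blocking fit loop; note that sys.stdout is process-wide, so the redirection (as in _run_controller_background above) silences prints from every thread while it is in effect:

import sys
import threading
import time
from io import StringIO

def train():
    # Hypothetical stand-in for the Controller's blocking fit loop
    time.sleep(1)
    print("progress output that should not clutter the notebook")

def run_in_background():
    old_stdout = sys.stdout
    sys.stdout = StringIO()  # suppress output (process-wide, as in the diff)
    try:
        train()
    finally:
        sys.stdout = old_stdout  # always restore stdout, even on error

thread = threading.Thread(target=run_in_background, daemon=True)
thread.start()
print("main thread stays responsive while training runs")
thread.join()  # in a notebook you would poll or monitor instead of joining
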
rapidfireai/frontend/build/asset-manifest.json CHANGED
@@ -1,7 +1,7 @@
  {
    "files": {
      "main.css": "/static-files/static/css/main.702595df.css",
-     "main.js": "/static-files/static/js/main.e7d3b759.js",
+     "main.js": "/static-files/static/js/main.aee6c455.js",
      "ml-model-trace-renderer.js": "/static-files/lib/notebook-trace-renderer/js/ml-model-trace-renderer.5490ebc325fe0f300ad9.js",
      "static/js/6019.9025341e.chunk.js": "/static-files/static/js/6019.9025341e.chunk.js",
      "static/js/6336.8153bc1c.chunk.js": "/static-files/static/js/6336.8153bc1c.chunk.js",
@@ -120,7 +120,7 @@
      "static/media/chart-line.svg": "/static-files/static/media/chart-line.0adaa2036bb4eb5956db6d0c7e925a3d.svg",
      "lib/notebook-trace-renderer/index.html": "/static-files/lib/notebook-trace-renderer/index.html",
      "main.702595df.css.map": "/static-files/static/css/main.702595df.css.map",
-     "main.e7d3b759.js.map": "/static-files/static/js/main.e7d3b759.js.map",
+     "main.aee6c455.js.map": "/static-files/static/js/main.aee6c455.js.map",
      "ml-model-trace-renderer.js.map": "/static-files/lib/notebook-trace-renderer/js/ml-model-trace-renderer.5490ebc325fe0f300ad9.js.map",
      "6336.8153bc1c.chunk.js.map": "/static-files/static/js/6336.8153bc1c.chunk.js.map",
      "9478.cbf55ef3.chunk.js.map": "/static-files/static/js/9478.cbf55ef3.chunk.js.map",
@@ -216,6 +216,6 @@
    },
    "entrypoints": [
      "static/css/main.702595df.css",
-     "static/js/main.e7d3b759.js"
+     "static/js/main.aee6c455.js"
    ]
  }
rapidfireai/frontend/build/index.html CHANGED
@@ -1 +1 @@
- <!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,shrink-to-fit=no"/><link rel="shortcut icon" href="./static-files/favicon.ico"/><meta name="theme-color" content="#000000"/><link rel="manifest" href="./static-files/manifest.json" crossorigin="use-credentials"/><title>RapidFire AI</title><script defer="defer" src="static-files/static/js/main.e7d3b759.js"></script><link href="static-files/static/css/main.702595df.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root" class="mlflow-ui-container"></div><div id="modal" class="mlflow-ui-container"></div></body></html>
+ <!doctype html><html lang="en"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,shrink-to-fit=no"/><link rel="shortcut icon" href="./static-files/favicon.ico"/><meta name="theme-color" content="#000000"/><link rel="manifest" href="./static-files/manifest.json" crossorigin="use-credentials"/><title>RapidFire AI</title><script defer="defer" src="static-files/static/js/main.aee6c455.js"></script><link href="static-files/static/css/main.702595df.css" rel="stylesheet"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root" class="mlflow-ui-container"></div><div id="modal" class="mlflow-ui-container"></div></body></html>