rapidfireai 0.10.2rc5__py3-none-any.whl → 0.11.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidfireai might be problematic. Click here for more details.

Files changed (36) hide show
  1. rapidfireai/automl/grid_search.py +4 -5
  2. rapidfireai/automl/model_config.py +41 -37
  3. rapidfireai/automl/random_search.py +21 -33
  4. rapidfireai/backend/controller.py +80 -161
  5. rapidfireai/backend/worker.py +26 -8
  6. rapidfireai/cli.py +171 -132
  7. rapidfireai/db/rf_db.py +1 -1
  8. rapidfireai/db/tables.sql +1 -1
  9. rapidfireai/dispatcher/dispatcher.py +3 -1
  10. rapidfireai/dispatcher/gunicorn.conf.py +1 -1
  11. rapidfireai/experiment.py +86 -7
  12. rapidfireai/frontend/build/asset-manifest.json +3 -3
  13. rapidfireai/frontend/build/index.html +1 -1
  14. rapidfireai/frontend/build/static/js/{main.1bf27639.js → main.58393d31.js} +3 -3
  15. rapidfireai/frontend/build/static/js/{main.1bf27639.js.map → main.58393d31.js.map} +1 -1
  16. rapidfireai/frontend/proxy_middleware.py +1 -1
  17. rapidfireai/ml/callbacks.py +85 -59
  18. rapidfireai/ml/trainer.py +42 -86
  19. rapidfireai/start.sh +117 -34
  20. rapidfireai/utils/constants.py +22 -1
  21. rapidfireai/utils/experiment_utils.py +87 -43
  22. rapidfireai/utils/interactive_controller.py +473 -0
  23. rapidfireai/utils/logging.py +1 -2
  24. rapidfireai/utils/metric_logger.py +346 -0
  25. rapidfireai/utils/mlflow_manager.py +0 -1
  26. rapidfireai/utils/ping.py +4 -2
  27. rapidfireai/utils/worker_manager.py +16 -6
  28. rapidfireai/version.py +2 -2
  29. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/METADATA +7 -4
  30. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/RECORD +36 -33
  31. tutorial_notebooks/rf-colab-tensorboard-tutorial.ipynb +314 -0
  32. /rapidfireai/frontend/build/static/js/{main.1bf27639.js.LICENSE.txt → main.58393d31.js.LICENSE.txt} +0 -0
  33. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/WHEEL +0 -0
  34. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/entry_points.txt +0 -0
  35. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/licenses/LICENSE +0 -0
  36. {rapidfireai-0.10.2rc5.dist-info → rapidfireai-0.11.1rc1.dist-info}/top_level.txt +0 -0
rapidfireai/cli.py CHANGED
@@ -3,15 +3,16 @@
3
3
  Command-line interface for RapidFire AI
4
4
  """
5
5
 
6
- import os
7
- import sys
8
- import subprocess
9
6
  import argparse
7
+ import os
10
8
  import platform
11
- import shutil
12
9
  import re
10
+ import shutil
13
11
  import site
12
+ import subprocess
13
+ import sys
14
14
  from pathlib import Path
15
+
15
16
  from .version import __version__
16
17
 
17
18
 
@@ -20,24 +21,24 @@ def get_script_path():
20
21
  # Get the directory where this package is installed
21
22
  package_dir = Path(__file__).parent
22
23
  script_path = package_dir / "start.sh"
23
-
24
+
24
25
  if not script_path.exists():
25
26
  # Fallback: try to find it relative to the current working directory
26
27
  script_path = Path.cwd() / "rapidfireai" / "start.sh"
27
28
  if not script_path.exists():
28
29
  raise FileNotFoundError(f"Could not find start.sh script at {script_path}")
29
-
30
+
30
31
  return script_path
31
32
 
32
33
 
33
34
  def run_script(args):
34
35
  """Run the start.sh script with the given arguments."""
35
36
  script_path = get_script_path()
36
-
37
+
37
38
  # Make sure the script is executable
38
39
  if not os.access(script_path, os.X_OK):
39
40
  os.chmod(script_path, 0o755)
40
-
41
+
41
42
  # Run the script with the provided arguments
42
43
  try:
43
44
  result = subprocess.run([str(script_path)] + args, check=True)
@@ -53,24 +54,27 @@ def run_script(args):
53
54
  def get_python_info():
54
55
  """Get comprehensive Python information."""
55
56
  info = {}
56
-
57
+
57
58
  # Python version and implementation
58
- info['version'] = sys.version
59
- info['implementation'] = platform.python_implementation()
60
- info['executable'] = sys.executable
61
-
59
+ info["version"] = sys.version
60
+ info["implementation"] = platform.python_implementation()
61
+ info["executable"] = sys.executable
62
+
62
63
  # Environment information
63
- info['conda_env'] = os.environ.get('CONDA_DEFAULT_ENV', 'none')
64
- info['venv'] = 'yes' if hasattr(sys, 'real_prefix') or (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix) else 'no'
65
-
64
+ info["conda_env"] = os.environ.get("CONDA_DEFAULT_ENV", "none")
65
+ info["venv"] = (
66
+ "yes"
67
+ if hasattr(sys, "real_prefix") or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix)
68
+ else "no"
69
+ )
70
+
66
71
  return info
67
72
 
68
73
 
69
74
  def get_pip_packages():
70
75
  """Get list of installed pip packages."""
71
76
  try:
72
- result = subprocess.run([sys.executable, '-m', 'pip', 'list'],
73
- capture_output=True, text=True, check=True)
77
+ result = subprocess.run([sys.executable, "-m", "pip", "list"], capture_output=True, text=True, check=True)
74
78
  return result.stdout
75
79
  except (subprocess.CalledProcessError, FileNotFoundError):
76
80
  return "Failed to get pip packages"
@@ -79,113 +83,107 @@ def get_pip_packages():
79
83
  def get_gpu_info():
80
84
  """Get comprehensive GPU and CUDA information."""
81
85
  info = {}
82
-
86
+
83
87
  # Check for nvidia-smi
84
- nvidia_smi_path = shutil.which('nvidia-smi')
85
- info['nvidia_smi'] = 'found' if nvidia_smi_path else 'not found'
86
-
88
+ nvidia_smi_path = shutil.which("nvidia-smi")
89
+ info["nvidia_smi"] = "found" if nvidia_smi_path else "not found"
90
+
87
91
  if nvidia_smi_path:
88
92
  try:
89
93
  # Get driver and CUDA runtime version from the full nvidia-smi output
90
- result = subprocess.run(['nvidia-smi'],
91
- capture_output=True, text=True, check=True)
94
+ result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
92
95
  if result.stdout.strip():
93
- lines = result.stdout.strip().split('\n')
96
+ lines = result.stdout.strip().split("\n")
94
97
  # Look for the header line that contains CUDA version
95
98
  for line in lines:
96
- if 'CUDA Version:' in line:
99
+ if "CUDA Version:" in line:
97
100
  # Extract CUDA version from line like "NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.2"
98
- cuda_version = line.split('CUDA Version:')[1].split()[0]
99
- info['cuda_runtime'] = cuda_version
101
+ cuda_version = line.split("CUDA Version:")[1].split()[0]
102
+ info["cuda_runtime"] = cuda_version
100
103
  # Also extract driver version from the same line
101
- if 'Driver Version:' in line:
102
- driver_version = line.split('Driver Version:')[1].split('CUDA Version:')[0].strip()
103
- info['driver_version'] = driver_version
104
+ if "Driver Version:" in line:
105
+ driver_version = line.split("Driver Version:")[1].split("CUDA Version:")[0].strip()
106
+ info["driver_version"] = driver_version
104
107
  break
105
108
  else:
106
- info['driver_version'] = 'unknown'
107
- info['cuda_runtime'] = 'unknown'
109
+ info["driver_version"] = "unknown"
110
+ info["cuda_runtime"] = "unknown"
108
111
  except (subprocess.CalledProcessError, ValueError):
109
- info['driver_version'] = 'unknown'
110
- info['cuda_runtime'] = 'unknown'
111
-
112
+ info["driver_version"] = "unknown"
113
+ info["cuda_runtime"] = "unknown"
114
+
112
115
  # Get GPU count, models, and VRAM
113
116
  try:
114
- result = subprocess.run(['nvidia-smi', '--query-gpu=count,name,memory.total', '--format=csv,noheader,nounits'],
115
- capture_output=True, text=True, check=True)
117
+ result = subprocess.run(
118
+ ["nvidia-smi", "--query-gpu=count,name,memory.total", "--format=csv,noheader,nounits"],
119
+ capture_output=True,
120
+ text=True,
121
+ check=True,
122
+ )
116
123
  if result.stdout.strip():
117
- lines = result.stdout.strip().split('\n')
124
+ lines = result.stdout.strip().split("\n")
118
125
  if lines:
119
- count, name, memory = lines[0].split(', ')
120
- info['gpu_count'] = int(count)
121
- info['gpu_model'] = name.strip()
126
+ count, name, memory = lines[0].split(", ")
127
+ info["gpu_count"] = int(count)
128
+ info["gpu_model"] = name.strip()
122
129
  # Convert memory from MiB to GB
123
130
  memory_mib = int(memory.split()[0])
124
131
  memory_gb = memory_mib / 1024
125
- info['gpu_memory_gb'] = f"{memory_gb:.1f}"
126
-
132
+ info["gpu_memory_gb"] = f"{memory_gb:.1f}"
133
+
127
134
  # Get detailed info for multiple GPUs if present
128
- if info['gpu_count'] > 1:
129
- info['gpu_details'] = []
135
+ if info["gpu_count"] > 1:
136
+ info["gpu_details"] = []
130
137
  for line in lines:
131
- count, name, memory = line.split(', ')
138
+ count, name, memory = line.split(", ")
132
139
  memory_mib = int(memory.split()[0])
133
140
  memory_gb = memory_mib / 1024
134
- info['gpu_details'].append({
135
- 'name': name.strip(),
136
- 'memory_gb': f"{memory_gb:.1f}"
137
- })
141
+ info["gpu_details"].append({"name": name.strip(), "memory_gb": f"{memory_gb:.1f}"})
138
142
  except (subprocess.CalledProcessError, ValueError):
139
- info['gpu_count'] = 0
140
- info['gpu_model'] = 'unknown'
141
- info['gpu_memory_gb'] = 'unknown'
143
+ info["gpu_count"] = 0
144
+ info["gpu_model"] = "unknown"
145
+ info["gpu_memory_gb"] = "unknown"
142
146
  else:
143
- info['driver_version'] = 'N/A'
144
- info['cuda_runtime'] = 'N/A'
145
- info['gpu_count'] = 0
146
- info['gpu_model'] = 'N/A'
147
- info['gpu_memory_gb'] = 'N/A'
148
-
147
+ info["driver_version"] = "N/A"
148
+ info["cuda_runtime"] = "N/A"
149
+ info["gpu_count"] = 0
150
+ info["gpu_model"] = "N/A"
151
+ info["gpu_memory_gb"] = "N/A"
152
+
149
153
  # Check for nvcc (CUDA compiler)
150
- nvcc_path = shutil.which('nvcc')
151
- info['nvcc'] = 'found' if nvcc_path else 'not found'
152
-
154
+ nvcc_path = shutil.which("nvcc")
155
+ info["nvcc"] = "found" if nvcc_path else "not found"
156
+
153
157
  if nvcc_path:
154
158
  try:
155
- result = subprocess.run(['nvcc', '--version'],
156
- capture_output=True, text=True, check=True)
159
+ result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True, check=True)
157
160
  # Extract version from output like "Cuda compilation tools, release 11.8, V11.8.89"
158
- version_line = result.stdout.split('\n')[0]
159
- if 'release' in version_line:
160
- version = version_line.split('release')[1].split(',')[0].strip()
161
- info['nvcc_version'] = version
161
+ version_line = result.stdout.split("\n")[0]
162
+ if "release" in version_line:
163
+ version = version_line.split("release")[1].split(",")[0].strip()
164
+ info["nvcc_version"] = version
162
165
  else:
163
- info['nvcc_version'] = 'unknown'
166
+ info["nvcc_version"] = "unknown"
164
167
  except subprocess.CalledProcessError:
165
- info['nvcc_version'] = 'unknown'
168
+ info["nvcc_version"] = "unknown"
166
169
  else:
167
- info['nvcc_version'] = 'N/A'
168
-
170
+ info["nvcc_version"] = "N/A"
171
+
169
172
  # Check CUDA installation paths
170
- cuda_paths = [
171
- '/usr/local/cuda',
172
- '/opt/cuda',
173
- '/usr/cuda',
174
- os.path.expanduser('~/cuda')
175
- ]
176
-
173
+ cuda_paths = ["/usr/local/cuda", "/opt/cuda", "/usr/cuda", os.path.expanduser("~/cuda")]
174
+
177
175
  cuda_installed = False
178
176
  for path in cuda_paths:
179
177
  if os.path.exists(path):
180
178
  cuda_installed = True
181
179
  break
182
-
183
- info['cuda_installation'] = 'present' if cuda_installed else 'not present'
184
-
180
+
181
+ info["cuda_installation"] = "present" if cuda_installed else "not present"
182
+
185
183
  # Check if CUDA is on PATH
186
- cuda_on_path = any('cuda' in p.lower() for p in os.environ.get('PATH', '').split(os.pathsep))
187
- info['cuda_on_path'] = 'yes' if cuda_on_path else 'no'
188
-
184
+ cuda_on_path = any("cuda" in p.lower() for p in os.environ.get("PATH", "").split(os.pathsep))
185
+ info["cuda_on_path"] = "yes" if cuda_on_path else "no"
186
+
189
187
  return info
190
188
 
191
189
 
@@ -193,7 +191,7 @@ def run_doctor():
193
191
  """Run the doctor command to diagnose system issues."""
194
192
  print("🔍 RapidFire AI System Diagnostics")
195
193
  print("=" * 50)
196
-
194
+
197
195
  # Python Information
198
196
  print("\n🐍 Python Environment:")
199
197
  print("-" * 30)
@@ -203,94 +201,112 @@ def run_doctor():
203
201
  print(f"Executable: {python_info['executable']}")
204
202
  print(f"Conda Environment: {python_info['conda_env']}")
205
203
  print(f"Virtual Environment: {python_info['venv']}")
206
-
204
+
207
205
  # Pip Packages
208
206
  print("\n📦 Installed Packages:")
209
207
  print("-" * 30)
210
208
  pip_output = get_pip_packages()
211
209
  if pip_output != "Failed to get pip packages":
212
210
  # Show only relevant packages
213
- relevant_packages = ['rapidfireai', 'mlflow', 'torch', 'transformers', 'flask', 'gunicorn', 'peft', 'trl', 'bitsandbytes', 'nltk', 'evaluate', 'rouge-score', 'sentencepiece']
214
- lines = pip_output.split('\n')
211
+ relevant_packages = [
212
+ "rapidfireai",
213
+ "mlflow",
214
+ "torch",
215
+ "transformers",
216
+ "flask",
217
+ "gunicorn",
218
+ "peft",
219
+ "trl",
220
+ "bitsandbytes",
221
+ "nltk",
222
+ "evaluate",
223
+ "rouge-score",
224
+ "sentencepiece",
225
+ ]
226
+ lines = pip_output.split("\n")
215
227
  for line in lines:
216
228
  if any(pkg.lower() in line.lower() for pkg in relevant_packages):
217
229
  print(line)
218
230
  print("... (showing only relevant packages)")
219
231
  else:
220
232
  print(pip_output)
221
-
233
+
222
234
  # GPU Information
223
235
  print("\n🚀 GPU & CUDA Information:")
224
236
  print("-" * 30)
225
237
  gpu_info = get_gpu_info()
226
238
  print(f"nvidia-smi: {gpu_info['nvidia_smi']}")
227
-
228
- if gpu_info['nvidia_smi'] == 'found':
239
+
240
+ if gpu_info["nvidia_smi"] == "found":
229
241
  print(f"Driver Version: {gpu_info['driver_version']}")
230
242
  print(f"CUDA Runtime: {gpu_info['cuda_runtime']}")
231
243
  print(f"GPU Count: {gpu_info['gpu_count']}")
232
-
233
- if gpu_info['gpu_count'] > 0:
234
- if 'gpu_details' in gpu_info:
244
+
245
+ if gpu_info["gpu_count"] > 0:
246
+ if "gpu_details" in gpu_info:
235
247
  print("GPU Details:")
236
- for i, gpu in enumerate(gpu_info['gpu_details']):
248
+ for i, gpu in enumerate(gpu_info["gpu_details"]):
237
249
  print(f" GPU {i}: {gpu['name']} ({gpu['memory_gb']} GB)")
238
250
  else:
239
251
  print(f"GPU Model: {gpu_info['gpu_model']}")
240
252
  print(f"Total VRAM: {gpu_info['gpu_memory_gb']} GB")
241
-
253
+
242
254
  print(f"nvcc: {gpu_info['nvcc']}")
243
- if gpu_info['nvcc'] == 'found':
255
+ if gpu_info["nvcc"] == "found":
244
256
  print(f"nvcc Version: {gpu_info['nvcc_version']}")
245
-
257
+
246
258
  print(f"CUDA Installation: {gpu_info['cuda_installation']}")
247
259
  print(f"CUDA on PATH: {gpu_info['cuda_on_path']}")
248
-
260
+
249
261
  # System Information
250
262
  print("\n💻 System Information:")
251
263
  print("-" * 30)
252
264
  print(f"Platform: {platform.platform()}")
253
265
  print(f"Architecture: {platform.machine()}")
254
266
  print(f"Processor: {platform.processor()}")
255
-
267
+
256
268
  # Environment Variables
257
269
  print("\n🔧 Environment Variables:")
258
270
  print("-" * 30)
259
- relevant_vars = ['CUDA_HOME', 'CUDA_PATH', 'LD_LIBRARY_PATH', 'PATH']
271
+ relevant_vars = ["CUDA_HOME", "CUDA_PATH", "LD_LIBRARY_PATH", "PATH"]
260
272
  for var in relevant_vars:
261
- value = os.environ.get(var, 'not set')
262
- if value != 'not set' and len(value) > 100:
273
+ value = os.environ.get(var, "not set")
274
+ if value != "not set" and len(value) > 100:
263
275
  value = value[:100] + "..."
264
276
  print(f"{var}: {value}")
265
-
277
+
266
278
  print("\n✅ Diagnostics complete!")
267
279
  return 0
268
280
 
281
+
269
282
  def get_cuda_version():
270
283
  """Detect CUDA version from nvcc or nvidia-smi"""
271
284
  try:
272
- result = subprocess.run(['nvcc', '--version'],
273
- capture_output=True, text=True, check=True)
274
- match = re.search(r'release (\d+)\.(\d+)', result.stdout)
285
+ result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True, check=True)
286
+ match = re.search(r"release (\d+)\.(\d+)", result.stdout)
275
287
  if match:
276
288
  return int(match.group(1))
277
289
  except (subprocess.CalledProcessError, FileNotFoundError):
278
290
  try:
279
- result = subprocess.run(['nvidia-smi'],
280
- capture_output=True, text=True, check=True)
281
- match = re.search(r'CUDA Version: (\d+)\.(\d+)', result.stdout)
291
+ result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
292
+ match = re.search(r"CUDA Version: (\d+)\.(\d+)", result.stdout)
282
293
  if match:
283
294
  return int(match.group(1))
284
295
  except (subprocess.CalledProcessError, FileNotFoundError):
285
296
  pass
286
297
  return None
287
298
 
299
+
288
300
  def get_compute_capability():
289
301
  """Get compute capability from nvidia-smi"""
290
302
  try:
291
- result = subprocess.run(['nvidia-smi', '--query-gpu=compute_cap', '--format=csv,noheader,nounits'],
292
- capture_output=True, text=True, check=True)
293
- match = re.search(r'(\d+)\.(\d+)', result.stdout)
303
+ result = subprocess.run(
304
+ ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader,nounits"],
305
+ capture_output=True,
306
+ text=True,
307
+ check=True,
308
+ )
309
+ match = re.search(r"(\d+)\.(\d+)", result.stdout)
294
310
  if match:
295
311
  major = int(match.group(1))
296
312
  minor = int(match.group(2))
@@ -298,6 +314,7 @@ def get_compute_capability():
298
314
  except (subprocess.CalledProcessError, FileNotFoundError):
299
315
  return None
300
316
 
317
+
301
318
  def install_packages():
302
319
  """Install packages for the RapidFire AI project."""
303
320
  packages = []
@@ -314,11 +331,11 @@ def install_packages():
314
331
  # packages.append({"package": "vllm==0.10.1.1", "extra_args": ["--torch-backend=cu118"]})
315
332
  # else:
316
333
  # print("\n⚠️ CUDA version not detected or unsupported.")
317
-
334
+
318
335
  ## TODO: re-enable once flash-attn has fix
319
336
  # if cuda_major is not None:
320
337
  # print(f"\n🎯 Detected CUDA {cuda_major}.x")
321
-
338
+
322
339
  # # Determine flash-attn version based on CUDA version
323
340
  # if cuda_major < 8:
324
341
  # # flash-attn 1.x for CUDA < 8.0
@@ -349,13 +366,14 @@ def install_packages():
349
366
  print(f" You may need to install {package} manually")
350
367
  return 0
351
368
 
369
+
352
370
  def copy_tutorial_notebooks():
353
371
  """Copy the tutorial notebooks to the project."""
354
372
  print("Getting tutorial notebooks...")
355
373
  try:
356
374
  tutorial_path = os.getenv("RF_TUTORIAL_PATH", os.path.join(".", "tutorial_notebooks"))
357
375
  site_packages_path = site.getsitepackages()[0]
358
- source_path =os.path.join(site_packages_path, "tutorial_notebooks")
376
+ source_path = os.path.join(site_packages_path, "tutorial_notebooks")
359
377
  print(f"Copying tutorial notebooks from {source_path} to {tutorial_path}...")
360
378
  os.makedirs(tutorial_path, exist_ok=True)
361
379
  shutil.copytree(source_path, tutorial_path, dirs_exist_ok=True)
@@ -378,29 +396,50 @@ def run_init():
378
396
 
379
397
  return 0
380
398
 
399
+
381
400
  def main():
382
401
  """Main entry point for the rapidfireai command."""
383
- parser = argparse.ArgumentParser(
384
- description="RapidFire AI - Start/stop/manage services",
385
- prog="rapidfireai"
386
- )
387
-
402
+ parser = argparse.ArgumentParser(description="RapidFire AI - Start/stop/manage services", prog="rapidfireai")
403
+
388
404
  parser.add_argument(
389
405
  "command",
390
406
  nargs="?",
391
407
  default="start",
392
408
  choices=["start", "stop", "status", "restart", "setup", "doctor", "init"],
393
- help="Command to execute (default: start)"
409
+ help="Command to execute (default: start)",
410
+ )
411
+
412
+ parser.add_argument("--version", action="version", version=f"RapidFire AI {__version__}")
413
+
414
+ parser.add_argument(
415
+ "--tracking-backend",
416
+ choices=["mlflow", "tensorboard", "both"],
417
+ default=os.getenv("RF_TRACKING_BACKEND", "mlflow"),
418
+ help="Tracking backend to use for metrics (default: mlflow)",
419
+ )
420
+
421
+ parser.add_argument(
422
+ "--tensorboard-log-dir",
423
+ default=os.getenv("RF_TENSORBOARD_LOG_DIR", None),
424
+ help="Directory for TensorBoard logs (default: {experiment_path}/tensorboard_logs)",
394
425
  )
395
-
426
+
396
427
  parser.add_argument(
397
- "--version",
398
- action="version",
399
- version=f"RapidFire AI {__version__}"
428
+ "--colab",
429
+ action="store_true",
430
+ help="Run in Colab mode (skips frontend, conditionally starts MLflow based on tracking backend)",
400
431
  )
401
-
432
+
402
433
  args = parser.parse_args()
403
-
434
+
435
+ # Set environment variables from CLI args
436
+ if args.tracking_backend:
437
+ os.environ["RF_TRACKING_BACKEND"] = args.tracking_backend
438
+ if args.tensorboard_log_dir:
439
+ os.environ["RF_TENSORBOARD_LOG_DIR"] = args.tensorboard_log_dir
440
+ if args.colab:
441
+ os.environ["RF_COLAB_MODE"] = "true"
442
+
404
443
  # Handle doctor command separately
405
444
  if args.command == "doctor":
406
445
  return run_doctor()
@@ -408,10 +447,10 @@ def main():
408
447
  # Handle init command separately
409
448
  if args.command == "init":
410
449
  return run_init()
411
-
450
+
412
451
  # Run the script with the specified command
413
452
  return run_script([args.command])
414
453
 
415
454
 
416
455
  if __name__ == "__main__":
417
- sys.exit(main())
456
+ sys.exit(main())
rapidfireai/db/rf_db.py CHANGED
@@ -113,7 +113,7 @@ class RfDb:
113
113
  def create_experiment(
114
114
  self,
115
115
  experiment_name: str,
116
- mlflow_experiment_id: str,
116
+ mlflow_experiment_id: str | None,
117
117
  config_options: dict[str, Any],
118
118
  ) -> int:
119
119
  """Create a new experiment"""
rapidfireai/db/tables.sql CHANGED
@@ -2,7 +2,7 @@
2
2
  CREATE TABLE IF NOT EXISTS experiments (
3
3
  experiment_id INTEGER PRIMARY KEY AUTOINCREMENT,
4
4
  experiment_name TEXT NOT NULL,
5
- mlflow_experiment_id TEXT NOT NULL,
5
+ mlflow_experiment_id TEXT,
6
6
  config_options TEXT NOT NULL,
7
7
  status TEXT NOT NULL,
8
8
  current_task TEXT NOT NULL,
@@ -30,7 +30,9 @@ class Dispatcher:
30
30
  self.app: Flask = Flask(__name__)
31
31
 
32
32
  # Enable CORS for all routes
33
- _ = CORS(self.app, resources={r"/*": {"origins": CORS_ALLOWED_ORIGINS}})
33
+ # Allow all origins for local development (dispatcher runs on localhost)
34
+ # This is safe since the API is not exposed to the internet
35
+ _ = CORS(self.app, resources={r"/*": {"origins": "*"}})
34
36
 
35
37
  # register routes
36
38
  self.register_routes()
@@ -5,7 +5,7 @@ from rapidfireai.utils.constants import DispatcherConfig
5
5
 
6
6
  # Other Gunicorn settings...
7
7
  bind = f"{DispatcherConfig.HOST}:{DispatcherConfig.PORT}"
8
- workers = 2
8
+ workers = 1 # Single worker for Colab/single-user environments to save memory
9
9
 
10
10
  wsgi_app = "rapidfireai.dispatcher.dispatcher:serve_forever()"
11
11