rapidfireai 0.10.3rc1__py3-none-any.whl → 0.11.1rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of rapidfireai might be problematic.

Files changed (26)
  1. rapidfireai/automl/grid_search.py +4 -5
  2. rapidfireai/automl/model_config.py +41 -37
  3. rapidfireai/automl/random_search.py +21 -33
  4. rapidfireai/backend/controller.py +54 -148
  5. rapidfireai/backend/worker.py +14 -3
  6. rapidfireai/cli.py +148 -136
  7. rapidfireai/experiment.py +22 -11
  8. rapidfireai/frontend/build/asset-manifest.json +3 -3
  9. rapidfireai/frontend/build/index.html +1 -1
  10. rapidfireai/frontend/build/static/js/{main.e7d3b759.js → main.aee6c455.js} +3 -3
  11. rapidfireai/frontend/build/static/js/{main.e7d3b759.js.map → main.aee6c455.js.map} +1 -1
  12. rapidfireai/ml/callbacks.py +10 -24
  13. rapidfireai/ml/trainer.py +37 -81
  14. rapidfireai/utils/constants.py +3 -1
  15. rapidfireai/utils/interactive_controller.py +40 -61
  16. rapidfireai/utils/logging.py +1 -2
  17. rapidfireai/utils/mlflow_manager.py +1 -0
  18. rapidfireai/utils/ping.py +4 -2
  19. rapidfireai/version.py +2 -2
  20. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/METADATA +1 -1
  21. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/RECORD +26 -26
  22. /rapidfireai/frontend/build/static/js/{main.e7d3b759.js.LICENSE.txt → main.aee6c455.js.LICENSE.txt} +0 -0
  23. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/WHEEL +0 -0
  24. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/entry_points.txt +0 -0
  25. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/licenses/LICENSE +0 -0
  26. {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/top_level.txt +0 -0
rapidfireai/ml/callbacks.py CHANGED
@@ -1,14 +1,9 @@
-from typing import Callable, Dict, List, Optional
+from collections.abc import Callable
 
 import torch
 from datasets import Dataset
 from tqdm import tqdm
-from transformers import (
-    TrainerCallback,
-    TrainerControl,
-    TrainerState,
-    TrainingArguments,
-)
+from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
 from transformers.trainer_utils import IntervalStrategy, SaveStrategy
 
 
@@ -17,7 +12,7 @@ class GenerationMetricsCallback(TrainerCallback):
         self,
         tokenizer,
         eval_dataset: Dataset,
-        generation_config: Optional[Dict] = None,
+        generation_config: dict | None = None,
         compute_metrics: Callable = None,
         batch_size: int = 8,
         metric_logger=None,
@@ -90,9 +85,7 @@
             elif "prompt" in item and "completion" in item:
                 input_text = item["prompt"]
                 reference = item["completion"][-1]["content"]
-                input_text = self.tokenizer.apply_chat_template(
-                    input_text, tokenize=False
-                )
+                input_text = self.tokenizer.apply_chat_template(input_text, tokenize=False)
             elif "text" in item:
                 # SFT format - use text as input, response as reference
                 input_text = item["text"]
@@ -114,7 +107,7 @@
 
         return input_texts, references
 
-    def _generate_batch(self, model, input_texts: List[str]) -> torch.Tensor:
+    def _generate_batch(self, model, input_texts: list[str]) -> torch.Tensor:
         """Generate text for a batch of inputs with defensive validation"""
         # Defensive validation for empty inputs
         if not input_texts:
@@ -136,7 +129,7 @@
             print(f"Warning: Tokenization error in generation callback: {e}")
             return torch.empty((0, 0), dtype=torch.long).to(model.device)
 
-    def _compute_generation_metrics(self, model, step: int) -> Dict[str, float]:
+    def _compute_generation_metrics(self, model, step: int) -> dict[str, float]:
         """Generate text and compute BLEU/ROUGE metrics with batch processing"""
         model.eval()
 
@@ -163,14 +156,10 @@
             return {}
 
         with torch.no_grad():
-            for i in tqdm(
-                range(0, len(indices), self.batch_size), desc="Generating for metrics"
-            ):
+            for i in tqdm(range(0, len(indices), self.batch_size), desc="Generating for metrics"):
                 input_ids_batch = input_ids[i : i + self.batch_size]
                 with torch.inference_mode(), torch.cuda.amp.autocast():
-                    outputs_batch = model.generate(
-                        input_ids_batch, **self.generation_config
-                    )
+                    outputs_batch = model.generate(input_ids_batch, **self.generation_config)
                 generated_texts = self.tokenizer.batch_decode(
                     outputs_batch[:, input_ids_batch.shape[1] :],
                     skip_special_tokens=True,
@@ -257,7 +246,7 @@ class LogLevelCallback(TrainerCallback):
     A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
     """
 
-    def __init__(self, global_step_args: Dict):
+    def __init__(self, global_step_args: dict):
         self.eval_first_step = global_step_args.get("eval_first_step", 0)
         self.actual_steps = global_step_args.get("actual_steps", 0)
         self.log_first_step = global_step_args.get("log_first_step", 0)
@@ -315,10 +304,7 @@
             control.should_log = True
 
         # Evaluate
-        if (
-            args.eval_strategy == IntervalStrategy.EPOCH
-            and args.eval_delay <= state.epoch
-        ):
+        if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch:
            control.should_evaluate = True
 
        # Save
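The bulk of this file's churn is a typing cleanup: `typing.Dict`/`List`/`Optional` give way to built-in generics (PEP 585) and `X | None` unions (PEP 604), and `Callable` now comes from `collections.abc`. A minimal sketch of the two styles side by side; the function names here are illustrative, not from the package:

```python
# Pre-3.9 style, as removed above
from typing import Dict, List, Optional


def render_old(scores: Optional[Dict[str, float]]) -> List[str]:
    return [f"{name}={value:.2f}" for name, value in (scores or {}).items()]


# 3.10+ style, as added above: built-in generics, | unions,
# and Callable imported from collections.abc
from collections.abc import Callable


def render_new(
    scores: dict[str, float] | None,
    fmt: Callable[[str, float], str] = "{}={:.2f}".format,
) -> list[str]:
    return [fmt(name, value) for name, value in (scores or {}).items()]
```

Note that `dict | None` in runtime-evaluated annotations effectively pins these modules to Python 3.10 or newer.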
rapidfireai/ml/trainer.py CHANGED
@@ -1,4 +1,3 @@
-import logging
 import math
 import os
 
@@ -7,11 +6,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
 from transformers.utils.logging import set_verbosity_error
 from trl import DPOConfig, DPOTrainer, GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
 
-from rapidfireai.ml.callbacks import (
-    GenerationMetricsCallback,
-    MLflowLoggingCallback,
-    LogLevelCallback,
-)
+from rapidfireai.ml.callbacks import GenerationMetricsCallback, LogLevelCallback, MLflowLoggingCallback
 from rapidfireai.ml.checkpoint_utils import (
     ensure_gradient_compatibility,
     load_checkpoint_from_disk,
@@ -51,21 +46,15 @@ def create_trainer_instance(
     compute_metrics = additional_trainer_kwargs.get("compute_metrics", None)
 
     # Configure training arguments
-    training_args, global_step_args = _configure_training_args(
-        training_args, trainer_config
-    )
+    training_args, global_step_args = _configure_training_args(training_args, trainer_config)
     trainer_config_obj = _create_trainer_config_object(trainer_type, training_args)
     # check if peft params is empty dict
     is_peft = bool(config_leaf.get("peft_params"))
     # Load model and tokenizer
     if use_shared_memory:
-        model_instance, tokenizer = load_checkpoint_from_shared_memory(
-            trainer_config, shm_manager, is_peft=is_peft
-        )
+        model_instance, tokenizer = load_checkpoint_from_shared_memory(trainer_config, shm_manager, is_peft=is_peft)
     else:
-        model_instance, tokenizer = load_checkpoint_from_disk(
-            trainer_config, is_peft=is_peft
-        )
+        model_instance, tokenizer = load_checkpoint_from_disk(trainer_config, is_peft=is_peft)
     # add model name to model config
     config_leaf["model_name"] = model_instance.config._name_or_path
 
@@ -84,30 +73,26 @@
 
     model_instance = model_instance.to(device)
 
-    trainer_kwargs, formatting_func, additional_trainer_kwargs = (
-        _prepare_trainer_kwargs(
-            model_instance,
-            trainer_config_obj,
-            tokenizer,
-            trainer_config,
-            additional_trainer_kwargs,
-            ref_model_instance,
-            config_leaf,
-        )
+    trainer_kwargs, formatting_func, additional_trainer_kwargs = _prepare_trainer_kwargs(
+        model_instance,
+        trainer_config_obj,
+        tokenizer,
+        trainer_config,
+        additional_trainer_kwargs,
+        ref_model_instance,
+        config_leaf,
     )
 
-    callbacks, additional_trainer_kwargs = (
-        _setup_callbacks(  # FIXME: avoid returning additional_trainer_kwargs
-            metric_logger,
-            trainer_config,
-            chunk_id,
-            compute_metrics,
-            additional_trainer_kwargs,
-            tokenizer,
-            training_args,
-            formatting_func,
-            global_step_args,
-        )
+    callbacks, additional_trainer_kwargs = _setup_callbacks(  # FIXME: avoid returning additional_trainer_kwargs
+        metric_logger,
+        trainer_config,
+        chunk_id,
+        compute_metrics,
+        additional_trainer_kwargs,
+        tokenizer,
+        training_args,
+        formatting_func,
+        global_step_args,
     )
 
     if callbacks:
@@ -116,29 +101,22 @@
     trainer_kwargs.update(additional_trainer_kwargs)
     trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if v is not None}
 
-    trainer = _create_trainer_by_type(
-        trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager
-    )
+    trainer = _create_trainer_by_type(trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager)
     return trainer, config_leaf["model_name"]
 
 
-def _configure_training_args(
-    training_args: dict, trainer_config: TrainerConfig
-) -> dict:
+def _configure_training_args(training_args: dict, trainer_config: TrainerConfig) -> dict:
     """Configure training arguments with default values."""
     completed_steps = trainer_config.completed_steps
     per_device_train_batch_size = training_args.get("per_device_train_batch_size", 1)
     gradient_accumulation_steps = training_args.get("gradient_accumulation_steps", 1)
-    len_dataloader = math.ceil(
-        trainer_config.train_dataset.num_rows / per_device_train_batch_size
-    )
+    len_dataloader = math.ceil(trainer_config.train_dataset.num_rows / per_device_train_batch_size)
     steps_per_epoch = max(
-        len_dataloader // gradient_accumulation_steps
-        + int(len_dataloader % gradient_accumulation_steps > 0),
+        len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
         1,
     )
 
-    if trainer_config.config_leaf.get("trainer_type","SFT") == "GRPO":
+    if trainer_config.config_leaf.get("trainer_type", "SFT") == "GRPO":
         num_generations = training_args.get("num_generations", 8)
         steps_per_epoch = (num_generations * trainer_config.train_dataset.num_rows) // (
             gradient_accumulation_steps * per_device_train_batch_size
@@ -215,10 +193,7 @@ def _setup_reference_model(
     if model_adapter_name is not None and ref_adapter_name is not None:
         if use_shared_memory:
             peft_config = LoraConfig(**config_leaf["peft_params"])
-            if (
-                trainer_config.completed_steps == 0
-                and trainer_config.warm_started_from is None
-            ):
+            if trainer_config.completed_steps == 0 and trainer_config.warm_started_from is None:
                 reference_state_dict = get_peft_model_state_dict(model_instance)
                 reference_state_dict = move_tensors_to_cpu(reference_state_dict)
                 shm_manager.save_model_object(
@@ -230,14 +205,10 @@
                 reference_state_dict = shm_manager.load_model_object(
                     trainer_config.run_id, SHMObjectType.REF_STATE_DICT
                 )
-                reference_state_dict = move_tensors_to_device(
-                    reference_state_dict, device
-                )
+                reference_state_dict = move_tensors_to_device(reference_state_dict, device)
                 model_instance.add_adapter(ref_adapter_name, peft_config)
                 model_instance.set_adapter(ref_adapter_name)
-                set_peft_model_state_dict(
-                    model_instance, reference_state_dict, adapter_name=ref_adapter_name
-                )
+                set_peft_model_state_dict(model_instance, reference_state_dict, adapter_name=ref_adapter_name)
             model_instance.set_adapter(model_adapter_name)
         else:
             base_run_path = DataPath.base_run_path(trainer_config.run_id)
@@ -289,9 +260,7 @@ def _prepare_trainer_kwargs(
 
     if additional_trainer_kwargs.get("formatting_func") is not None:
         formatting_func = additional_trainer_kwargs.get("formatting_func")
-        train_dataset = train_dataset.map(
-            formatting_func
-        )  # FIXME: add try exception with batched/unbatched
+        train_dataset = train_dataset.map(formatting_func)  # FIXME: add try exception with batched/unbatched
         if eval_dataset is not None:
             eval_dataset = eval_dataset.map(formatting_func)
     additional_trainer_kwargs_copy = additional_trainer_kwargs.copy()
@@ -337,10 +306,7 @@ def _setup_callbacks(
     )
     callbacks.append(mlflow_callback)
 
-    if (
-        compute_metrics is not None
-        and additional_trainer_kwargs.get("generation_config") is not None
-    ):
+    if compute_metrics is not None and additional_trainer_kwargs.get("generation_config") is not None:
         compute_metrics_function = compute_metrics
         if formatting_func is not None:
             formatted_eval_dataset = trainer_config.eval_dataset.map(formatting_func)
@@ -365,15 +331,11 @@
     return callbacks, additional_trainer_kwargs
 
 
-def _create_trainer_by_type(
-    trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager
-):
+def _create_trainer_by_type(trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager):
     """Create trainer instance based on type with proper state restoration."""
     if trainer_type == "SFT":
         dummy_trainer = SFTTrainer(**trainer_kwargs)
-        dummy_trainer.create_optimizer_and_scheduler(
-            num_training_steps=trainer_config.total_steps
-        )
+        dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
         trainer = SFTTrainer(
             **trainer_kwargs,
             optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
@@ -382,9 +344,7 @@ def _create_trainer_by_type(
 
     elif trainer_type == "DPO":
         dummy_trainer = DPOTrainer(**trainer_kwargs)
-        dummy_trainer.create_optimizer_and_scheduler(
-            num_training_steps=trainer_config.total_steps
-        )
+        dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
         trainer = DPOTrainer(
             **trainer_kwargs,
             optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
@@ -393,9 +353,7 @@
 
     elif trainer_type == "GRPO":
         dummy_trainer = GRPOTrainer(**trainer_kwargs)
-        dummy_trainer.create_optimizer_and_scheduler(
-            num_training_steps=trainer_config.total_steps
-        )
+        dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
         trainer = GRPOTrainer(
             **trainer_kwargs,
             optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
@@ -406,9 +364,7 @@ def _create_trainer_by_type(
 
     if trainer_config.completed_steps > 0:
         if use_shared_memory:
-            trainer = restore_trainer_from_shared_memory(
-                trainer, trainer_config, shm_manager
-            )
+            trainer = restore_trainer_from_shared_memory(trainer, trainer_config, shm_manager)
         else:
             trainer = restore_trainer_from_disk(trainer, trainer_config)
 
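The `_create_trainer_by_type` hunks above keep a two-pass construction idiom: a throwaway trainer is built first so `create_optimizer_and_scheduler` (a standard `transformers.Trainer` method) can size the optimizer and LR schedule for the run's total step budget, and the real trainer then receives those objects through the `optimizers=` argument. A rough sketch of the idea for the SFT case, with illustrative names; this is an interpretation of the pattern, not the package's exact code:

```python
from trl import SFTTrainer


def build_resumable_sft_trainer(trainer_kwargs: dict, total_steps: int) -> SFTTrainer:
    # Pass 1: throwaway instance, used only to materialize an optimizer
    # and scheduler sized for the *full* training budget.
    dummy = SFTTrainer(**trainer_kwargs)
    dummy.create_optimizer_and_scheduler(num_training_steps=total_steps)

    # Pass 2: the trainer that actually runs. Reusing the pre-built
    # optimizer/scheduler keeps the LR schedule consistent when training
    # is chunked and resumed rather than run end to end.
    return SFTTrainer(
        **trainer_kwargs,
        optimizers=(dummy.optimizer, dummy.lr_scheduler),
    )
```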
rapidfireai/utils/constants.py CHANGED
@@ -1,8 +1,9 @@
-from enum import Enum
 import os
+from enum import Enum
 
 MLFLOW_URL = "http://127.0.0.1:5002"
 
+
 # Tracking Backend Configuration
 def get_tracking_backend() -> str:
     """
@@ -17,6 +18,7 @@ def get_tracking_backend() -> str:
     backend = os.getenv("RF_TRACKING_BACKEND", "mlflow")
     return backend
 
+
 # Backwards compatibility: Keep constant but it will be stale if env var changes after import
 TRACKING_BACKEND = get_tracking_backend()  # Options: 'mlflow', 'tensorboard', 'both'
 TENSORBOARD_LOG_DIR = os.getenv("RF_TENSORBOARD_LOG_DIR", None)  # Default set by experiment path
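As the backwards-compatibility comment in the hunk warns, `TRACKING_BACKEND` is frozen at import time, so callers that need to react to a later change of `RF_TRACKING_BACKEND` should call `get_tracking_backend()` instead. A small sketch of the difference, assuming the module layout shown above:

```python
import os

from rapidfireai.utils.constants import TRACKING_BACKEND, get_tracking_backend

print(TRACKING_BACKEND)        # captured once at import, e.g. "mlflow"

os.environ["RF_TRACKING_BACKEND"] = "tensorboard"
print(TRACKING_BACKEND)        # still the stale import-time value
print(get_tracking_backend())  # re-reads the env var -> "tensorboard"
```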
rapidfireai/utils/interactive_controller.py CHANGED
@@ -6,18 +6,14 @@ Provides UI controls for managing training runs similar to the frontend.
 import json
 import threading
 import time
-from typing import Any, Dict, Optional
 
 import requests
-from IPython.display import clear_output, display
+from IPython.display import display
 
 try:
     import ipywidgets as widgets
-except ImportError:
-    raise ImportError(
-        "ipywidgets is required for InteractiveController. "
-        "Install with: pip install ipywidgets"
-    )
+except ImportError as e:
+    raise ImportError("ipywidgets is required for InteractiveController. Install with: pip install ipywidgets") from e
 
 
 class InteractiveController:
@@ -25,8 +21,8 @@ class InteractiveController:
 
     def __init__(self, dispatcher_url: str = "http://127.0.0.1:8081"):
         self.dispatcher_url = dispatcher_url.rstrip("/")
-        self.run_id: Optional[int] = None
-        self.config: Optional[Dict] = None
+        self.run_id: int | None = None
+        self.config: dict | None = None
         self.status: str = "Unknown"
         self.chunk_number: int = 0
 
@@ -37,22 +33,16 @@
         """Create ipywidgets UI components"""
         # Run selector
         self.run_selector = widgets.Dropdown(
-            options=[],
-            description='',
-            disabled=False,
-            layout=widgets.Layout(width='300px')
+            options=[], description="", disabled=False, layout=widgets.Layout(width="300px")
         )
         self.load_btn = widgets.Button(
-            description="Load Run",
-            button_style="primary",
-            tooltip="Load the selected run",
-            icon="download"
+            description="Load Run", button_style="primary", tooltip="Load the selected run", icon="download"
         )
         self.refresh_selector_btn = widgets.Button(
             description="Refresh List",
             button_style="info",
             tooltip="Refresh the list of available runs",
-            icon="refresh"
+            icon="refresh",
         )
 
         # Status display
@@ -67,12 +57,7 @@
             tooltip="Resume this run",
             icon="play",
         )
-        self.stop_btn = widgets.Button(
-            description="Stop",
-            button_style="danger",
-            tooltip="Stop this run",
-            icon="stop"
-        )
+        self.stop_btn = widgets.Button(description="Stop", button_style="danger", tooltip="Stop this run", icon="stop")
         self.delete_btn = widgets.Button(
             description="Delete",
             button_style="danger",
@@ -97,32 +82,28 @@
             value=False,
             description="Warm Start (continue from previous checkpoint)",
             disabled=True,
-            style={'description_width': 'initial'},
-            layout=widgets.Layout(margin='10px 0px')
+            style={"description_width": "initial"},
+            layout=widgets.Layout(margin="10px 0px"),
         )
         self.clone_btn = widgets.Button(
             description="Clone",
             button_style="primary",
             tooltip="Clone this run with modifications",
         )
-        self.submit_clone_btn = widgets.Button(
-            description="✓ Submit Clone", button_style="success", disabled=True
-        )
-        self.cancel_clone_btn = widgets.Button(
-            description="✗ Cancel", button_style="", disabled=True
-        )
+        self.submit_clone_btn = widgets.Button(description="✓ Submit Clone", button_style="success", disabled=True)
+        self.cancel_clone_btn = widgets.Button(description="✗ Cancel", button_style="", disabled=True)
 
         # Status message box
         self.status_message = widgets.HTML(
-            value='',
+            value="",
             layout=widgets.Layout(
-                width='100%',
-                min_height='40px',
-                padding='10px',
-                margin='10px 0px',
-                border='2px solid #ddd',
-                border_radius='5px'
-            )
+                width="100%",
+                min_height="40px",
+                padding="10px",
+                margin="10px 0px",
+                border="2px solid #ddd",
+                border_radius="5px",
+            ),
         )
 
         # Experiment status display (live progress)
@@ -148,7 +129,7 @@
         self.cancel_clone_btn.on_click(lambda b: self._handle_cancel_clone())
 
         # Auto-load run when dropdown selection changes
-        self.run_selector.observe(self._on_run_selected, names='value')
+        self.run_selector.observe(self._on_run_selected, names="value")
 
     def _show_message(self, message: str, message_type: str = "info"):
         """Display a status message with styling"""
@@ -156,23 +137,23 @@
             "success": {"bg": "#d4edda", "border": "#28a745", "text": "#155724"},
             "error": {"bg": "#f8d7da", "border": "#dc3545", "text": "#721c24"},
             "info": {"bg": "#d1ecf1", "border": "#17a2b8", "text": "#0c5460"},
-            "warning": {"bg": "#fff3cd", "border": "#ffc107", "text": "#856404"}
+            "warning": {"bg": "#fff3cd", "border": "#ffc107", "text": "#856404"},
         }
 
         style = colors.get(message_type, colors["info"])
 
-        self.status_message.value = f'''
+        self.status_message.value = f"""
         <div style="
-            background-color: {style['bg']};
-            border: 2px solid {style['border']};
-            color: {style['text']};
+            background-color: {style["bg"]};
+            border: 2px solid {style["border"]};
+            color: {style["text"]};
             padding: 10px;
             border-radius: 5px;
             font-weight: 600;
         ">
             {message}
         </div>
-        '''
+        """
 
     def _update_experiment_status(self):
         """Update experiment status display with live progress"""
@@ -186,8 +167,8 @@
 
             if runs:
                 total_runs = len(runs)
-                completed_runs = sum(1 for r in runs if r.get('status') == 'COMPLETED')
-                ongoing_runs = sum(1 for r in runs if r.get('status') == 'ONGOING')
+                completed_runs = sum(1 for r in runs if r.get("status") == "COMPLETED")
+                ongoing_runs = sum(1 for r in runs if r.get("status") == "ONGOING")
 
                 # Determine status color and icon
                 if completed_runs == total_runs:
@@ -212,16 +193,16 @@
                 self.experiment_status.value = (
                     f'<div style="padding: 10px; background-color: {bg_color}; '
                     f'border: 2px solid {border_color}; border-radius: 5px; color: {text_color};">'
-                    f'<b>{icon} Experiment Status:</b> {status_text}<br>'
-                    f'<b>Progress:</b> {completed_runs}/{total_runs} runs completed'
-                    '</div>'
+                    f"<b>{icon} Experiment Status:</b> {status_text}<br>"
+                    f"<b>Progress:</b> {completed_runs}/{total_runs} runs completed"
+                    "</div>"
                 )
             else:
                 self.experiment_status.value = (
                     '<div style="padding: 10px; background-color: #f8f9fa; '
                     'border: 2px solid #dee2e6; border-radius: 5px;">'
-                    '<b>Experiment Status:</b> No runs found'
-                    '</div>'
+                    "<b>Experiment Status:</b> No runs found"
+                    "</div>"
                 )
 
         except requests.RequestException:
@@ -240,8 +221,7 @@
 
             if runs:
                 # Create options as (label, value) tuples
-                options = [(f"Run {run['run_id']} - {run.get('status', 'Unknown')}", run['run_id'])
-                           for run in runs]
+                options = [(f"Run {run['run_id']} - {run.get('status', 'Unknown')}", run["run_id"]) for run in runs]
                 self.run_selector.options = options
                 self._show_message(f"Found {len(runs)} runs", "success")
             else:
@@ -257,8 +237,8 @@
 
     def _on_run_selected(self, change):
         """Handle dropdown selection change - auto-load run"""
-        if change['new'] is not None:
-            self.load_run(change['new'])
+        if change["new"] is not None:
+            self.load_run(change["new"])
 
     def _handle_load(self):
         """Handle load button click"""
@@ -397,6 +377,7 @@
 
             # Enable custom widget manager for ipywidgets to work in Colab
             from google.colab import output
+
             output.enable_custom_widget_manager()
         except ImportError:
             # Not in Colab, no action needed
@@ -455,9 +436,7 @@
             ]
         )
 
-        actions = widgets.HBox(
-            [self.resume_btn, self.stop_btn, self.delete_btn, self.refresh_btn]
-        )
+        actions = widgets.HBox([self.resume_btn, self.stop_btn, self.delete_btn, self.refresh_btn])
 
         config_section = widgets.VBox(
            [
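One behavioral detail in the import-guard hunk at the top of this file: re-raising with `from e` chains the original `ImportError` as `__cause__`, so the traceback shows both the missing module and the install hint rather than hiding the root cause. A generic sketch of the pattern outside this class:

```python
try:
    import ipywidgets as widgets  # optional dependency
except ImportError as e:
    # `from e` attaches the original error; the traceback will read
    # "The above exception was the direct cause of the following exception".
    raise ImportError("ipywidgets is required; install with: pip install ipywidgets") from e
```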
rapidfireai/utils/logging.py CHANGED
@@ -1,7 +1,6 @@
 import os
 import threading
 from abc import ABC, abstractmethod
-from typing import Dict
 
 from loguru import logger
 
@@ -13,7 +12,7 @@ class BaseRFLogger(ABC):
     """Base class for RapidFire loggers"""
 
     _experiment_name = ""
-    _initialized_loggers: Dict[str, bool] = {}
+    _initialized_loggers: dict[str, bool] = {}
     _lock = threading.Lock()
 
     def __init__(self, level: str = "DEBUG"):
rapidfireai/utils/mlflow_manager.py CHANGED
@@ -1,4 +1,5 @@
 """This module contains the MLflowManager class which is responsible for managing the MLflow runs."""
+
 import mlflow
 from mlflow.tracking import MlflowClient
 
rapidfireai/utils/ping.py CHANGED
@@ -1,9 +1,10 @@
 #!/usr/bin/env python
-import socket
 import argparse
+import socket
+
 
 def ping_server(server: str, port: int, timeout=3):
-    """ping server:port """
+    """ping server:port"""
     try:
         socket.setdefaulttimeout(timeout)
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -14,6 +15,7 @@ def ping_server(server: str, port: int, timeout=3):
     s.close()
     return True
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Ping a server port")
     parser.add_argument("server", type=str, help="Server to ping")
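For reference, a quick usage sketch of `ping_server`; the port mirrors the `MLFLOW_URL` constant from `rapidfireai/utils/constants.py`, and the truthy return on success is inferred from the `return True` visible in the hunk:

```python
from rapidfireai.utils.ping import ping_server

# TCP-level reachability check against the local MLflow server (port 5002).
if ping_server("127.0.0.1", 5002, timeout=3):
    print("MLflow port is reachable")
else:
    print("MLflow port is not reachable")
```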
rapidfireai/version.py CHANGED
@@ -2,5 +2,5 @@
 Version information for RapidFire AI
 """
 
-__version__ = "0.10.3rc1"
-__version_info__ = (0, 10, "3rc1")
+__version__ = "0.11.1rc2"
+__version_info__ = (0, 11, "1rc2")
{rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rapidfireai
-Version: 0.10.3rc1
+Version: 0.11.1rc2
 Summary: RapidFire AI: Rapid Experimentation Engine for Customizing LLMs
 Author-email: "RapidFire AI Inc." <support@rapidfire.ai>
 License: Apache-2.0