rapidfireai 0.10.3rc1__py3-none-any.whl → 0.11.1rc2__py3-none-any.whl
This diff shows the contents of two publicly available package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that public registry.
- rapidfireai/automl/grid_search.py +4 -5
- rapidfireai/automl/model_config.py +41 -37
- rapidfireai/automl/random_search.py +21 -33
- rapidfireai/backend/controller.py +54 -148
- rapidfireai/backend/worker.py +14 -3
- rapidfireai/cli.py +148 -136
- rapidfireai/experiment.py +22 -11
- rapidfireai/frontend/build/asset-manifest.json +3 -3
- rapidfireai/frontend/build/index.html +1 -1
- rapidfireai/frontend/build/static/js/{main.e7d3b759.js → main.aee6c455.js} +3 -3
- rapidfireai/frontend/build/static/js/{main.e7d3b759.js.map → main.aee6c455.js.map} +1 -1
- rapidfireai/ml/callbacks.py +10 -24
- rapidfireai/ml/trainer.py +37 -81
- rapidfireai/utils/constants.py +3 -1
- rapidfireai/utils/interactive_controller.py +40 -61
- rapidfireai/utils/logging.py +1 -2
- rapidfireai/utils/mlflow_manager.py +1 -0
- rapidfireai/utils/ping.py +4 -2
- rapidfireai/version.py +2 -2
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/METADATA +1 -1
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/RECORD +26 -26
- /rapidfireai/frontend/build/static/js/{main.e7d3b759.js.LICENSE.txt → main.aee6c455.js.LICENSE.txt} +0 -0
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/WHEEL +0 -0
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/entry_points.txt +0 -0
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/licenses/LICENSE +0 -0
- {rapidfireai-0.10.3rc1.dist-info → rapidfireai-0.11.1rc2.dist-info}/top_level.txt +0 -0
rapidfireai/ml/callbacks.py
CHANGED
@@ -1,14 +1,9 @@
-from
+from collections.abc import Callable
 
 import torch
 from datasets import Dataset
 from tqdm import tqdm
-from transformers import (
-    TrainerCallback,
-    TrainerControl,
-    TrainerState,
-    TrainingArguments,
-)
+from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
 from transformers.trainer_utils import IntervalStrategy, SaveStrategy
 
 
@@ -17,7 +12,7 @@ class GenerationMetricsCallback(TrainerCallback):
         self,
         tokenizer,
         eval_dataset: Dataset,
-        generation_config:
+        generation_config: dict | None = None,
         compute_metrics: Callable = None,
         batch_size: int = 8,
         metric_logger=None,
@@ -90,9 +85,7 @@ class GenerationMetricsCallback(TrainerCallback):
            elif "prompt" in item and "completion" in item:
                input_text = item["prompt"]
                reference = item["completion"][-1]["content"]
-                input_text = self.tokenizer.apply_chat_template(
-                    input_text, tokenize=False
-                )
+                input_text = self.tokenizer.apply_chat_template(input_text, tokenize=False)
            elif "text" in item:
                # SFT format - use text as input, response as reference
                input_text = item["text"]
@@ -114,7 +107,7 @@ class GenerationMetricsCallback(TrainerCallback):
 
        return input_texts, references
 
-    def _generate_batch(self, model, input_texts:
+    def _generate_batch(self, model, input_texts: list[str]) -> torch.Tensor:
        """Generate text for a batch of inputs with defensive validation"""
        # Defensive validation for empty inputs
        if not input_texts:
@@ -136,7 +129,7 @@ class GenerationMetricsCallback(TrainerCallback):
            print(f"Warning: Tokenization error in generation callback: {e}")
            return torch.empty((0, 0), dtype=torch.long).to(model.device)
 
-    def _compute_generation_metrics(self, model, step: int) ->
+    def _compute_generation_metrics(self, model, step: int) -> dict[str, float]:
        """Generate text and compute BLEU/ROUGE metrics with batch processing"""
        model.eval()
 
@@ -163,14 +156,10 @@ class GenerationMetricsCallback(TrainerCallback):
            return {}
 
        with torch.no_grad():
-            for i in tqdm(
-                range(0, len(indices), self.batch_size), desc="Generating for metrics"
-            ):
+            for i in tqdm(range(0, len(indices), self.batch_size), desc="Generating for metrics"):
                input_ids_batch = input_ids[i : i + self.batch_size]
                with torch.inference_mode(), torch.cuda.amp.autocast():
-                    outputs_batch = model.generate(
-                        input_ids_batch, **self.generation_config
-                    )
+                    outputs_batch = model.generate(input_ids_batch, **self.generation_config)
                generated_texts = self.tokenizer.batch_decode(
                    outputs_batch[:, input_ids_batch.shape[1] :],
                    skip_special_tokens=True,
@@ -257,7 +246,7 @@ class LogLevelCallback(TrainerCallback):
    A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
    """
 
-    def __init__(self, global_step_args:
+    def __init__(self, global_step_args: dict):
        self.eval_first_step = global_step_args.get("eval_first_step", 0)
        self.actual_steps = global_step_args.get("actual_steps", 0)
        self.log_first_step = global_step_args.get("log_first_step", 0)
@@ -315,10 +304,7 @@ class LogLevelCallback(TrainerCallback):
            control.should_log = True
 
        # Evaluate
-        if (
-            args.eval_strategy == IntervalStrategy.EPOCH
-            and args.eval_delay <= state.epoch
-        ):
+        if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch:
            control.should_evaluate = True
 
        # Save
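A note on the generation loop reformatted above: a causal LM's `generate()` returns the prompt tokens followed by the newly generated tokens, which is why the callback decodes only `outputs_batch[:, input_ids_batch.shape[1]:]`. A minimal standalone sketch of that slicing pattern (the tiny model name is illustrative, not part of rapidfireai):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM behaves the same way; this one is just small enough to run anywhere.
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tok(["Hello world"], return_tensors="pt")
with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=8)

# generate() echoes the prompt, so strip the first input_ids.shape[1] tokens
# to decode only the continuation, as GenerationMetricsCallback does.
new_tokens = out[:, inputs["input_ids"].shape[1] :]
print(tok.batch_decode(new_tokens, skip_special_tokens=True))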
rapidfireai/ml/trainer.py
CHANGED
@@ -1,4 +1,3 @@
-import logging
 import math
 import os
 
@@ -7,11 +6,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from transformers.utils.logging import set_verbosity_error
 from trl import DPOConfig, DPOTrainer, GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
 
-from rapidfireai.ml.callbacks import (
-    GenerationMetricsCallback,
-    MLflowLoggingCallback,
-    LogLevelCallback,
-)
+from rapidfireai.ml.callbacks import GenerationMetricsCallback, LogLevelCallback, MLflowLoggingCallback
 from rapidfireai.ml.checkpoint_utils import (
    ensure_gradient_compatibility,
    load_checkpoint_from_disk,
@@ -51,21 +46,15 @@ def create_trainer_instance(
    compute_metrics = additional_trainer_kwargs.get("compute_metrics", None)
 
    # Configure training arguments
-    training_args, global_step_args = _configure_training_args(
-        training_args, trainer_config
-    )
+    training_args, global_step_args = _configure_training_args(training_args, trainer_config)
    trainer_config_obj = _create_trainer_config_object(trainer_type, training_args)
    # check if peft params is empty dict
    is_peft = bool(config_leaf.get("peft_params"))
    # Load model and tokenizer
    if use_shared_memory:
-        model_instance, tokenizer = load_checkpoint_from_shared_memory(
-            trainer_config, shm_manager, is_peft=is_peft
-        )
+        model_instance, tokenizer = load_checkpoint_from_shared_memory(trainer_config, shm_manager, is_peft=is_peft)
    else:
-        model_instance, tokenizer = load_checkpoint_from_disk(
-            trainer_config, is_peft=is_peft
-        )
+        model_instance, tokenizer = load_checkpoint_from_disk(trainer_config, is_peft=is_peft)
    # add model name to model config
    config_leaf["model_name"] = model_instance.config._name_or_path
 
@@ -84,30 +73,26 @@ def create_trainer_instance(
 
    model_instance = model_instance.to(device)
 
-    trainer_kwargs, formatting_func, additional_trainer_kwargs = (
-        _prepare_trainer_kwargs(
-            model_instance,
-            trainer_config_obj,
-            tokenizer,
-            trainer_config,
-            additional_trainer_kwargs,
-            ref_model_instance,
-            config_leaf,
-        )
+    trainer_kwargs, formatting_func, additional_trainer_kwargs = _prepare_trainer_kwargs(
+        model_instance,
+        trainer_config_obj,
+        tokenizer,
+        trainer_config,
+        additional_trainer_kwargs,
+        ref_model_instance,
+        config_leaf,
    )
 
-    callbacks, additional_trainer_kwargs = (
-        _setup_callbacks(
-            metric_logger,
-            trainer_config,
-            chunk_id,
-            compute_metrics,
-            additional_trainer_kwargs,
-            tokenizer,
-            training_args,
-            formatting_func,
-            global_step_args,
-        )
+    callbacks, additional_trainer_kwargs = _setup_callbacks(  # FIXME: avoid returning additional_trainer_kwargs
+        metric_logger,
+        trainer_config,
+        chunk_id,
+        compute_metrics,
+        additional_trainer_kwargs,
+        tokenizer,
+        training_args,
+        formatting_func,
+        global_step_args,
    )
 
    if callbacks:
@@ -116,29 +101,22 @@ def create_trainer_instance(
    trainer_kwargs.update(additional_trainer_kwargs)
    trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if v is not None}
 
-    trainer = _create_trainer_by_type(
-        trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager
-    )
+    trainer = _create_trainer_by_type(trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager)
    return trainer, config_leaf["model_name"]
 
 
-def _configure_training_args(
-    training_args: dict, trainer_config: TrainerConfig
-) -> dict:
+def _configure_training_args(training_args: dict, trainer_config: TrainerConfig) -> dict:
    """Configure training arguments with default values."""
    completed_steps = trainer_config.completed_steps
    per_device_train_batch_size = training_args.get("per_device_train_batch_size", 1)
    gradient_accumulation_steps = training_args.get("gradient_accumulation_steps", 1)
-    len_dataloader = math.ceil(
-        trainer_config.train_dataset.num_rows / per_device_train_batch_size
-    )
+    len_dataloader = math.ceil(trainer_config.train_dataset.num_rows / per_device_train_batch_size)
    steps_per_epoch = max(
-        len_dataloader // gradient_accumulation_steps
-        + int(len_dataloader % gradient_accumulation_steps > 0),
+        len_dataloader // gradient_accumulation_steps + int(len_dataloader % gradient_accumulation_steps > 0),
        1,
    )
 
-    if trainer_config.config_leaf.get("trainer_type","SFT") == "GRPO":
+    if trainer_config.config_leaf.get("trainer_type", "SFT") == "GRPO":
        num_generations = training_args.get("num_generations", 8)
        steps_per_epoch = (num_generations * trainer_config.train_dataset.num_rows) // (
            gradient_accumulation_steps * per_device_train_batch_size
@@ -215,10 +193,7 @@ def _setup_reference_model(
    if model_adapter_name is not None and ref_adapter_name is not None:
        if use_shared_memory:
            peft_config = LoraConfig(**config_leaf["peft_params"])
-            if (
-                trainer_config.completed_steps == 0
-                and trainer_config.warm_started_from is None
-            ):
+            if trainer_config.completed_steps == 0 and trainer_config.warm_started_from is None:
                reference_state_dict = get_peft_model_state_dict(model_instance)
                reference_state_dict = move_tensors_to_cpu(reference_state_dict)
                shm_manager.save_model_object(
@@ -230,14 +205,10 @@ def _setup_reference_model(
                reference_state_dict = shm_manager.load_model_object(
                    trainer_config.run_id, SHMObjectType.REF_STATE_DICT
                )
-                reference_state_dict = move_tensors_to_device(
-                    reference_state_dict, device
-                )
+                reference_state_dict = move_tensors_to_device(reference_state_dict, device)
            model_instance.add_adapter(ref_adapter_name, peft_config)
            model_instance.set_adapter(ref_adapter_name)
-            set_peft_model_state_dict(
-                model_instance, reference_state_dict, adapter_name=ref_adapter_name
-            )
+            set_peft_model_state_dict(model_instance, reference_state_dict, adapter_name=ref_adapter_name)
            model_instance.set_adapter(model_adapter_name)
        else:
            base_run_path = DataPath.base_run_path(trainer_config.run_id)
@@ -289,9 +260,7 @@ def _prepare_trainer_kwargs(
 
    if additional_trainer_kwargs.get("formatting_func") is not None:
        formatting_func = additional_trainer_kwargs.get("formatting_func")
-        train_dataset = train_dataset.map(
-            formatting_func
-        )  # FIXME: add try exception with batched/unbatched
+        train_dataset = train_dataset.map(formatting_func)  # FIXME: add try exception with batched/unbatched
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(formatting_func)
    additional_trainer_kwargs_copy = additional_trainer_kwargs.copy()
@@ -337,10 +306,7 @@ def _setup_callbacks(
        )
        callbacks.append(mlflow_callback)
 
-    if (
-        compute_metrics is not None
-        and additional_trainer_kwargs.get("generation_config") is not None
-    ):
+    if compute_metrics is not None and additional_trainer_kwargs.get("generation_config") is not None:
        compute_metrics_function = compute_metrics
        if formatting_func is not None:
            formatted_eval_dataset = trainer_config.eval_dataset.map(formatting_func)
@@ -365,15 +331,11 @@ def _setup_callbacks(
    return callbacks, additional_trainer_kwargs
 
 
-def _create_trainer_by_type(
-    trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager
-):
+def _create_trainer_by_type(trainer_type, trainer_kwargs, trainer_config, use_shared_memory, shm_manager):
    """Create trainer instance based on type with proper state restoration."""
    if trainer_type == "SFT":
        dummy_trainer = SFTTrainer(**trainer_kwargs)
-        dummy_trainer.create_optimizer_and_scheduler(
-            num_training_steps=trainer_config.total_steps
-        )
+        dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
        trainer = SFTTrainer(
            **trainer_kwargs,
            optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
@@ -382,9 +344,7 @@ def _create_trainer_by_type(
 
    elif trainer_type == "DPO":
        dummy_trainer = DPOTrainer(**trainer_kwargs)
-        dummy_trainer.create_optimizer_and_scheduler(
-            num_training_steps=trainer_config.total_steps
-        )
+        dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
        trainer = DPOTrainer(
            **trainer_kwargs,
            optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
@@ -393,9 +353,7 @@ def _create_trainer_by_type(
 
    elif trainer_type == "GRPO":
        dummy_trainer = GRPOTrainer(**trainer_kwargs)
-        dummy_trainer.create_optimizer_and_scheduler(
-            num_training_steps=trainer_config.total_steps
-        )
+        dummy_trainer.create_optimizer_and_scheduler(num_training_steps=trainer_config.total_steps)
        trainer = GRPOTrainer(
            **trainer_kwargs,
            optimizers=(dummy_trainer.optimizer, dummy_trainer.lr_scheduler),
@@ -406,9 +364,7 @@ def _create_trainer_by_type(
 
    if trainer_config.completed_steps > 0:
        if use_shared_memory:
-            trainer = restore_trainer_from_shared_memory(
-                trainer, trainer_config, shm_manager
-            )
+            trainer = restore_trainer_from_shared_memory(trainer, trainer_config, shm_manager)
        else:
            trainer = restore_trainer_from_disk(trainer, trainer_config)
 
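`_configure_training_args` above derives optimizer steps per epoch from dataset rows, per-device batch size, and gradient accumulation, then swaps in a different formula for GRPO, where every prompt is expanded into `num_generations` completions. A worked sketch of both formulas with assumed values (the numbers are illustrative, not from the diff):

import math

# Assumed values for illustration only.
num_rows = 1000
per_device_train_batch_size = 4
gradient_accumulation_steps = 8

# SFT/DPO path: batches per epoch, then optimizer steps, rounding up at each stage.
len_dataloader = math.ceil(num_rows / per_device_train_batch_size)  # 250 batches
steps_per_epoch = max(
    len_dataloader // gradient_accumulation_steps
    + int(len_dataloader % gradient_accumulation_steps > 0),  # ceil(250 / 8) = 32
    1,
)

# GRPO path: each prompt yields num_generations completions before batching.
num_generations = 8
grpo_steps_per_epoch = (num_generations * num_rows) // (
    gradient_accumulation_steps * per_device_train_batch_size
)  # 8000 // 32 = 250

print(steps_per_epoch, grpo_steps_per_epoch)  # 32 250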
rapidfireai/utils/constants.py
CHANGED
@@ -1,8 +1,9 @@
-from enum import Enum
 import os
+from enum import Enum
 
 MLFLOW_URL = "http://127.0.0.1:5002"
 
+
 # Tracking Backend Configuration
 def get_tracking_backend() -> str:
    """
@@ -17,6 +18,7 @@ def get_tracking_backend() -> str:
    backend = os.getenv("RF_TRACKING_BACKEND", "mlflow")
    return backend
 
+
 # Backwards compatibility: Keep constant but it will be stale if env var changes after import
 TRACKING_BACKEND = get_tracking_backend()  # Options: 'mlflow', 'tensorboard', 'both'
 TENSORBOARD_LOG_DIR = os.getenv("RF_TENSORBOARD_LOG_DIR", None)  # Default set by experiment path
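The "Backwards compatibility" comment above is about import-time evaluation: a module-level constant snapshots the environment variable once, while the function re-reads it on every call. A self-contained illustration of that staleness (standalone sketch, not rapidfireai code):

import os

os.environ["RF_TRACKING_BACKEND"] = "mlflow"

# Snapshot taken once, the way a module-level constant is at import time.
TRACKING_BACKEND = os.getenv("RF_TRACKING_BACKEND", "mlflow")

def get_tracking_backend() -> str:
    """Re-read the environment on every call."""
    return os.getenv("RF_TRACKING_BACKEND", "mlflow")

os.environ["RF_TRACKING_BACKEND"] = "tensorboard"

print(TRACKING_BACKEND)        # mlflow       (stale snapshot)
print(get_tracking_backend())  # tensorboard  (current value)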
rapidfireai/utils/interactive_controller.py
CHANGED
@@ -6,18 +6,14 @@ Provides UI controls for managing training runs similar to the frontend.
 import json
 import threading
 import time
-from typing import Any, Dict, Optional
 
 import requests
-from IPython.display import
+from IPython.display import display
 
 try:
    import ipywidgets as widgets
-except ImportError:
-    raise ImportError(
-        "ipywidgets is required for InteractiveController. "
-        "Install with: pip install ipywidgets"
-    )
+except ImportError as e:
+    raise ImportError("ipywidgets is required for InteractiveController. Install with: pip install ipywidgets") from e
 
 
 class InteractiveController:
@@ -25,8 +21,8 @@ class InteractiveController:
 
    def __init__(self, dispatcher_url: str = "http://127.0.0.1:8081"):
        self.dispatcher_url = dispatcher_url.rstrip("/")
-        self.run_id:
-        self.config:
+        self.run_id: int | None = None
+        self.config: dict | None = None
        self.status: str = "Unknown"
        self.chunk_number: int = 0
 
@@ -37,22 +33,16 @@ class InteractiveController:
        """Create ipywidgets UI components"""
        # Run selector
        self.run_selector = widgets.Dropdown(
-            options=[],
-            description='',
-            disabled=False,
-            layout=widgets.Layout(width='300px')
+            options=[], description="", disabled=False, layout=widgets.Layout(width="300px")
        )
        self.load_btn = widgets.Button(
-            description="Load Run",
-            button_style="primary",
-            tooltip="Load the selected run",
-            icon="download"
+            description="Load Run", button_style="primary", tooltip="Load the selected run", icon="download"
        )
        self.refresh_selector_btn = widgets.Button(
            description="Refresh List",
            button_style="info",
            tooltip="Refresh the list of available runs",
-            icon="refresh"
+            icon="refresh",
        )
 
        # Status display
@@ -67,12 +57,7 @@ class InteractiveController:
            tooltip="Resume this run",
            icon="play",
        )
-        self.stop_btn = widgets.Button(
-            description="Stop",
-            button_style="danger",
-            tooltip="Stop this run",
-            icon="stop"
-        )
+        self.stop_btn = widgets.Button(description="Stop", button_style="danger", tooltip="Stop this run", icon="stop")
        self.delete_btn = widgets.Button(
            description="Delete",
            button_style="danger",
@@ -97,32 +82,28 @@ class InteractiveController:
            value=False,
            description="Warm Start (continue from previous checkpoint)",
            disabled=True,
-            style={
-            layout=widgets.Layout(margin=
+            style={"description_width": "initial"},
+            layout=widgets.Layout(margin="10px 0px"),
        )
        self.clone_btn = widgets.Button(
            description="Clone",
            button_style="primary",
            tooltip="Clone this run with modifications",
        )
-        self.submit_clone_btn = widgets.Button(
-            description="✓ Submit Clone", button_style="success", disabled=True
-        )
-        self.cancel_clone_btn = widgets.Button(
-            description="✗ Cancel", button_style="", disabled=True
-        )
+        self.submit_clone_btn = widgets.Button(description="✓ Submit Clone", button_style="success", disabled=True)
+        self.cancel_clone_btn = widgets.Button(description="✗ Cancel", button_style="", disabled=True)
 
        # Status message box
        self.status_message = widgets.HTML(
-            value=
+            value="",
            layout=widgets.Layout(
-                width=
-                min_height=
-                padding=
-                margin=
-                border=
-                border_radius=
-            )
+                width="100%",
+                min_height="40px",
+                padding="10px",
+                margin="10px 0px",
+                border="2px solid #ddd",
+                border_radius="5px",
+            ),
        )
 
        # Experiment status display (live progress)
@@ -148,7 +129,7 @@ class InteractiveController:
        self.cancel_clone_btn.on_click(lambda b: self._handle_cancel_clone())
 
        # Auto-load run when dropdown selection changes
-        self.run_selector.observe(self._on_run_selected, names=
+        self.run_selector.observe(self._on_run_selected, names="value")
 
    def _show_message(self, message: str, message_type: str = "info"):
        """Display a status message with styling"""
@@ -156,23 +137,23 @@ class InteractiveController:
            "success": {"bg": "#d4edda", "border": "#28a745", "text": "#155724"},
            "error": {"bg": "#f8d7da", "border": "#dc3545", "text": "#721c24"},
            "info": {"bg": "#d1ecf1", "border": "#17a2b8", "text": "#0c5460"},
-            "warning": {"bg": "#fff3cd", "border": "#ffc107", "text": "#856404"}
+            "warning": {"bg": "#fff3cd", "border": "#ffc107", "text": "#856404"},
        }
 
        style = colors.get(message_type, colors["info"])
 
-        self.status_message.value = f
+        self.status_message.value = f"""
        <div style="
-            background-color: {style[
-            border: 2px solid {style[
-            color: {style[
+            background-color: {style["bg"]};
+            border: 2px solid {style["border"]};
+            color: {style["text"]};
            padding: 10px;
            border-radius: 5px;
            font-weight: 600;
        ">
            {message}
        </div>
-
+        """
 
    def _update_experiment_status(self):
        """Update experiment status display with live progress"""
@@ -186,8 +167,8 @@ class InteractiveController:
 
            if runs:
                total_runs = len(runs)
-                completed_runs = sum(1 for r in runs if r.get(
-                ongoing_runs = sum(1 for r in runs if r.get(
+                completed_runs = sum(1 for r in runs if r.get("status") == "COMPLETED")
+                ongoing_runs = sum(1 for r in runs if r.get("status") == "ONGOING")
 
                # Determine status color and icon
                if completed_runs == total_runs:
@@ -212,16 +193,16 @@ class InteractiveController:
                self.experiment_status.value = (
                    f'<div style="padding: 10px; background-color: {bg_color}; '
                    f'border: 2px solid {border_color}; border-radius: 5px; color: {text_color};">'
-                    f
-                    f
-
+                    f"<b>{icon} Experiment Status:</b> {status_text}<br>"
+                    f"<b>Progress:</b> {completed_runs}/{total_runs} runs completed"
+                    "</div>"
                )
            else:
                self.experiment_status.value = (
                    '<div style="padding: 10px; background-color: #f8f9fa; '
                    'border: 2px solid #dee2e6; border-radius: 5px;">'
-
-
+                    "<b>Experiment Status:</b> No runs found"
+                    "</div>"
                )
 
        except requests.RequestException:
@@ -240,8 +221,7 @@ class InteractiveController:
 
            if runs:
                # Create options as (label, value) tuples
-                options = [(f"Run {run['run_id']} - {run.get('status', 'Unknown')}", run[
-                    for run in runs]
+                options = [(f"Run {run['run_id']} - {run.get('status', 'Unknown')}", run["run_id"]) for run in runs]
                self.run_selector.options = options
                self._show_message(f"Found {len(runs)} runs", "success")
            else:
@@ -257,8 +237,8 @@ class InteractiveController:
 
    def _on_run_selected(self, change):
        """Handle dropdown selection change - auto-load run"""
-        if change[
-            self.load_run(change[
+        if change["new"] is not None:
+            self.load_run(change["new"])
 
    def _handle_load(self):
        """Handle load button click"""
@@ -397,6 +377,7 @@ class InteractiveController:
 
        # Enable custom widget manager for ipywidgets to work in Colab
        from google.colab import output
+
        output.enable_custom_widget_manager()
    except ImportError:
        # Not in Colab, no action needed
@@ -455,9 +436,7 @@ class InteractiveController:
            ]
        )
 
-        actions = widgets.HBox(
-            [self.resume_btn, self.stop_btn, self.delete_btn, self.refresh_btn]
-        )
+        actions = widgets.HBox([self.resume_btn, self.stop_btn, self.delete_btn, self.refresh_btn])
 
        config_section = widgets.VBox(
            [
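The controller wires its UI in two ways visible in these hunks: `Button.on_click` for actions and `Dropdown.observe(..., names="value")` so a selection change auto-loads the run. A minimal standalone ipywidgets sketch of that wiring (labels and options are illustrative, not from the package):

import ipywidgets as widgets
from IPython.display import display

dropdown = widgets.Dropdown(options=[("Run 1 - ONGOING", 1), ("Run 2 - COMPLETED", 2)], description="")
status = widgets.HTML(value="")

def on_selected(change):
    # observe() passes a change event; change["new"] is the newly selected
    # value, mirroring _on_run_selected above.
    status.value = f"<b>Loaded run:</b> {change['new']}"

# Fire only when the dropdown's value trait changes.
dropdown.observe(on_selected, names="value")

refresh_btn = widgets.Button(description="Refresh List", button_style="info", icon="refresh")
refresh_btn.on_click(lambda b: setattr(status, "value", "<b>Refreshed</b>"))

display(widgets.VBox([dropdown, refresh_btn, status]))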
rapidfireai/utils/logging.py
CHANGED
@@ -1,7 +1,6 @@
 import os
 import threading
 from abc import ABC, abstractmethod
-from typing import Dict
 
 from loguru import logger
 
@@ -13,7 +12,7 @@ class BaseRFLogger(ABC):
    """Base class for RapidFire loggers"""
 
    _experiment_name = ""
-    _initialized_loggers:
+    _initialized_loggers: dict[str, bool] = {}
    _lock = threading.Lock()
 
    def __init__(self, level: str = "DEBUG"):
rapidfireai/utils/ping.py
CHANGED
@@ -1,9 +1,10 @@
 #!/usr/bin/env python
-import socket
 import argparse
+import socket
+
 
 def ping_server(server: str, port: int, timeout=3):
-    """ping server:port
+    """ping server:port"""
    try:
        socket.setdefaulttimeout(timeout)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -14,6 +15,7 @@ def ping_server(server: str, port: int, timeout=3):
        s.close()
        return True
 
+
 if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Ping a server port")
    parser.add_argument("server", type=str, help="Server to ping")
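`ping_server` is a plain TCP connect probe: it succeeds when a socket can reach `server:port` within `timeout` seconds. A possible usage sketch, assuming the failure path (not shown in this hunk) returns a falsy value:

from rapidfireai.utils.ping import ping_server

# Probe the local MLflow port from constants.py (MLFLOW_URL).
if ping_server("127.0.0.1", 5002, timeout=3):
    print("listener on 127.0.0.1:5002")
else:
    print("no listener on 127.0.0.1:5002")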
rapidfireai/version.py
CHANGED