arbor-ai 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arbor/client/__init__.py +0 -0
  2. arbor/client/api.py +1 -0
  3. arbor/server/__init__.py +1 -0
  4. arbor/server/api/__init__.py +1 -0
  5. arbor/server/api/models/schemas.py +223 -0
  6. arbor/server/api/routes/__init__.py +0 -0
  7. arbor/server/api/routes/files.py +52 -0
  8. arbor/server/api/routes/grpo.py +54 -0
  9. arbor/server/api/routes/inference.py +53 -0
  10. arbor/server/api/routes/jobs.py +117 -0
  11. arbor/server/core/__init__.py +1 -0
  12. arbor/server/core/config.py +47 -0
  13. arbor/server/core/logging.py +0 -0
  14. arbor/server/main.py +11 -0
  15. arbor/server/services/__init__.py +0 -0
  16. arbor/server/services/comms/__init__.py +0 -0
  17. arbor/server/services/comms/comms.py +226 -0
  18. arbor/server/services/dependencies.py +0 -0
  19. arbor/server/services/file_manager.py +289 -0
  20. arbor/server/services/grpo_manager.py +310 -0
  21. arbor/server/services/inference_manager.py +275 -0
  22. arbor/server/services/job_manager.py +81 -0
  23. arbor/server/services/scripts/grpo_training.py +576 -0
  24. arbor/server/services/training_manager.py +561 -0
  25. arbor/server/utils/__init__.py +0 -0
  26. arbor/server/utils/helpers.py +0 -0
  27. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/METADATA +1 -1
  28. arbor_ai-0.1.6.dist-info/RECORD +34 -0
  29. arbor_ai-0.1.5.dist-info/RECORD +0 -8
  30. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/WHEEL +0 -0
  31. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/entry_points.txt +0 -0
  32. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/licenses/LICENSE +0 -0
  33. {arbor_ai-0.1.5.dist-info → arbor_ai-0.1.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,310 @@
1
+ import json
2
+ import os
3
+ import random
4
+ import signal
5
+ import string
6
+ import subprocess
7
+ import sys
8
+ import threading
9
+ import time
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ from arbor.server.api.models.schemas import GRPOConfigRequest, GRPORequest
15
+ from arbor.server.core.config import Settings
16
+ from arbor.server.services.comms.comms import ArborServerCommsHandler
17
+ from arbor.server.services.inference_manager import InferenceManager
18
+
19
+
20
class GRPOManager:
    """Drive a GRPO training subprocess and keep the inference server in sync.

    ``initialize`` launches ``scripts/grpo_training.py`` through ``accelerate``
    and talks to it through an ``ArborServerCommsHandler``. ``grpo_step``
    forwards batches to it. ``update_model`` asks the trainer to save a
    checkpoint and blocks until the status thread has reloaded the inference
    server with that checkpoint.
    """

    # Keys forwarded to the training script via ``--trl_train_kwargs``.
    # NOTE: These also need to be in the GRPOConfigRequest.
    _TRL_KEYS = (
        "output_dir",
        "temperature",
        "beta",
        "num_iterations",
        "num_generations",
        "per_device_train_batch_size",
        "learning_rate",
        "gradient_accumulation_steps",
        "gradient_checkpointing",
        "lr_scheduler_type",
        "max_prompt_length",
        "max_completion_length",
        "gradient_checkpointing_kwargs",
        "bf16",
        "scale_rewards",
        "max_grad_norm",
    )
    # Keys forwarded to the training script via ``--arbor_train_kwargs``.
    _ARBOR_KEYS = ("update_interval", "lora")

    def __init__(self, settings: Settings):
        self.settings = settings
        self.training_process = None  # Popen handle of the accelerate launcher
        self.current_model = None  # model name/path the inference server should serve
        self.train_kwargs = None
        self.server_comms_handler = None
        self.status_thread = None
        # Set by update_model(); cleared by the status thread once the
        # inference server has been reloaded (or the reload has failed).
        self.model_saved_and_reload_requested = False

        self.data_count = 0
        self.last_inference_update = 0
        # Shut the training subprocess down on Ctrl-C / SIGTERM.
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

    def _signal_handler(self, signum, frame):
        """Handle SIGINT/SIGTERM by terminating training and exiting the process."""
        print("\nReceived keyboard interrupt. Shutting down gracefully...")
        # No inference manager is reachable here; terminate() accepts None.
        self.terminate(None)
        sys.exit(0)

    def make_output_dir(
        self, model_name: str, run_suffix: Optional[str] = None
    ) -> tuple[str, str]:
        """Create a unique output directory name for the training run.

        Args:
            model_name: Model id; only the part after the last ``/`` is used.
            run_suffix: Optional suffix; a random 6-character one is used if empty.

        Returns:
            ``(name, absolute_path)`` where the path is under
            ``<STORAGE_PATH>/models``.
        """
        model_name = model_name.split("/")[-1].lower()
        suffix = (
            run_suffix
            if run_suffix
            else "".join(random.choices(string.ascii_letters + string.digits, k=6))
        )
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        name = f"grpo:{model_name}:{suffix}:{timestamp}"
        return name, str(Path(self.settings.STORAGE_PATH).resolve() / "models" / name)

    def find_training_args(self, request: GRPOConfigRequest) -> dict:
        """Merge the explicitly-set request fields over the training defaults."""
        _name, output_dir = self.make_output_dir(request.model, request.suffix)

        # Defaults for training; request fields that were explicitly set win.
        default_train_kwargs = {
            "output_dir": output_dir,
        }

        train_kwargs = request.model_dump(exclude_unset=True)
        return {**default_train_kwargs, **(train_kwargs or {})}

    def process_training_args(self, train_kwargs: dict) -> tuple[dict, dict]:
        """Split ``train_kwargs`` into (TRL kwargs, arbor kwargs); other keys are dropped."""
        trl_train_kwargs = {
            key: train_kwargs[key] for key in self._TRL_KEYS if key in train_kwargs
        }
        arbor_train_kwargs = {
            key: train_kwargs[key] for key in self._ARBOR_KEYS if key in train_kwargs
        }
        return trl_train_kwargs, arbor_train_kwargs

    def initialize(
        self, request: GRPOConfigRequest, inference_manager: InferenceManager
    ):
        """Launch the training subprocess, wait for it to connect, then start inference.

        Args:
            request: Training configuration; ``request.model`` is the starting model.
            inference_manager: Manager used to serve ``self.current_model``.
        """
        self.train_kwargs = self.find_training_args(request)
        trl_train_kwargs, arbor_train_kwargs = self.process_training_args(
            self.train_kwargs
        )
        self.current_model = request.model

        # ZMQ socket manager; the training processes connect back to its ports.
        self.server_comms_handler = ArborServerCommsHandler()

        gpu_ids = self.settings.arbor_config.training.gpu_ids
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = gpu_ids
        # One accelerate process per comma-separated GPU id.
        num_processes = gpu_ids.count(",") + 1

        params = self._build_launch_command(
            num_processes, trl_train_kwargs, arbor_train_kwargs
        )
        print(f"Running following command:\n{' '.join(params)}")

        self.training_process = subprocess.Popen(
            params,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            env=env,
        )
        self._start_log_tail(self.training_process)

        self.status_thread = threading.Thread(
            target=self._handle_status_updates, args=(inference_manager,), daemon=True
        )
        self.status_thread.start()
        # Block until every training process has connected to the comms handler.
        self.server_comms_handler.wait_for_clients(num_processes)

        print("Launching inference server...")
        inference_manager.launch(self.current_model)

    def _build_launch_command(
        self, num_processes: int, trl_train_kwargs: dict, arbor_train_kwargs: dict
    ) -> list[str]:
        """Return the argv that runs the training script under ``accelerate launch``."""
        training_config = self.settings.arbor_config.training
        comms = self.server_comms_handler
        script_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "scripts", "grpo_training.py"
        )

        # Use the running interpreter so the child sees the same installed packages.
        params = [
            sys.executable,
            "-m",
            "accelerate.commands.launch",
            "--num_processes",
            str(num_processes),
        ]
        if training_config.accelerate_config:
            params.extend(["--config_file", training_config.accelerate_config])
        params.extend(
            [
                script_path,
                # Comms args
                "--host",
                comms.host,
                "--command_port",
                str(comms.command_port),
                "--status_port",
                str(comms.status_port),
                "--data_port",
                str(comms.data_port),
                "--broadcast_port",
                str(comms.broadcast_port),
                "--handshake_port",
                str(comms.handshake_port),
                # Training args
                "--model",
                self.current_model,
                "--trl_train_kwargs",
                json.dumps(trl_train_kwargs),
                "--arbor_train_kwargs",
                json.dumps(arbor_train_kwargs),
            ]
        )
        return params

    @staticmethod
    def _start_log_tail(process) -> threading.Thread:
        """Echo the subprocess's merged stdout/stderr with a ``[GRPO LOG]`` prefix."""

        def _tail():
            # readline() returns "" only at EOF, so this ends when the pipe closes
            # (no busy-spin, and lines are not accumulated in memory).
            for line in iter(process.stdout.readline, ""):
                print(f"[GRPO LOG] {line}", end="")

        thread = threading.Thread(target=_tail, daemon=True)
        thread.start()
        return thread

    def _handle_status_updates(self, inference_manager: InferenceManager):
        """Consume status messages from the training process until it reports termination."""
        print("Starting status update handler...")
        try:
            for status in self.server_comms_handler.receive_status():
                print(f"Received status update: {status}")
                kind = status.get("status")
                if kind == "model_saved":
                    print("Updating inference model...")
                    # "model_saved" can arrive more than once per request; only
                    # reload while a reload is actually pending.
                    if self._should_update_model():
                        self._reload_inference_model(
                            inference_manager, status["output_dir"]
                        )
                elif kind == "error":
                    print(f"Training error: {status.get('error', 'Unknown error')}")
                elif kind == "terminated":
                    print("Training process terminated")
                    break
        except Exception as e:
            print(f"Error in status update handler: {e}")

    def _reload_inference_model(
        self, inference_manager: InferenceManager, output_dir: str
    ):
        """Restart inference on ``output_dir``; always release update_model() waiters."""
        try:
            inference_manager.update_model(output_dir)
            # Publish the new model *before* clearing the flag so update_model()
            # never returns the previous model.
            self.current_model = output_dir
            print("Model update complete")
        except Exception as e:
            print(f"Failed to update inference model: {e}")
            # Don't leave inference paused forever after a failed reload.
            inference_manager.restarting = False
        finally:
            self.model_saved_and_reload_requested = False

    def grpo_step(
        self, request: GRPORequest, inference_manager: InferenceManager
    ) -> str:
        """Send one training batch to the trainer and return the current model.

        Waits while the inference server is restarting or a reload is pending.
        A failure to send is logged, not raised.
        """
        while inference_manager.is_server_restarting():
            print("Inference manager restarting, waiting for GRPO step")
            time.sleep(5)

        while self._should_update_model():
            print(
                f"Waiting for model update. Data count: {self.data_count}, Last inference update: {self.last_inference_update}"
            )
            time.sleep(5)

        try:
            self.server_comms_handler.send_data(request.batch)
            self.data_count += 1
        except Exception as e:
            print(f"Failed to send batch to training process: {e}")

        # Model saves are requested explicitly through update_model(); the
        # status thread reloads inference when the trainer reports "model_saved".
        return self.current_model

    def update_model(self, request, inference_manager: InferenceManager):
        """Ask the trainer to save a checkpoint and block until inference serves it.

        Returns:
            The model identifier the inference server is serving afterwards.

        Raises:
            RuntimeError: if the status thread has stopped, so no reload can happen.
        """
        # HACK: pause inference before the save is requested so no new requests
        # start against the model that is about to be replaced.
        inference_manager.restarting = True
        self.model_saved_and_reload_requested = True
        self.server_comms_handler.send_command({"command": "save_model"})
        while self.model_saved_and_reload_requested:
            status_thread_dead = (
                self.status_thread is not None and not self.status_thread.is_alive()
            )
            # Only the status thread clears the flag; if it is gone we would wait forever.
            if status_thread_dead and self.model_saved_and_reload_requested:
                self.model_saved_and_reload_requested = False
                inference_manager.restarting = False
                raise RuntimeError(
                    "Status update handler stopped before the model was reloaded"
                )
            print(
                "Waiting for model to be saved and reloaded... This usually takes 20-30 seconds"
            )
            time.sleep(5)
        return self.current_model

    def terminate(self, inference_manager: Optional[InferenceManager]):
        """Stop inference and training and release the comms sockets.

        Args:
            inference_manager: Manager whose server should be killed, or ``None``
                (as passed by the signal handler).

        Returns:
            The configured output directory, or ``None`` if there is none.
        """
        try:
            if inference_manager is not None and inference_manager.process is not None:
                inference_manager.kill()

            if self.server_comms_handler is not None:
                # Broadcast so every training process receives the terminate message.
                self.server_comms_handler.send_broadcast({"message": "terminate"})

            if self.training_process is not None:
                try:
                    self.training_process.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    print("Training process did not exit within 30 seconds, killing it...")
                    self.training_process.kill()
                    self.training_process.wait()
        except Exception as e:
            print(f"Error during termination: {e}")
        finally:
            # Clean up ZMQ connections (only once, even if terminate() is called again).
            if self.server_comms_handler is not None:
                self.server_comms_handler.close()
                self.server_comms_handler = None

        if self.train_kwargs and "output_dir" in self.train_kwargs:
            output_dir = self.train_kwargs["output_dir"]
            print(f"Training completed. Model saved to {output_dir}")
            if not os.path.exists(output_dir):
                print(f"Warning: Output directory {output_dir} does not exist")
            return output_dir
        print("Training terminated, no output directory specified")
        return None

    def _should_update_model(self):
        """Return True while a save-and-reload requested by update_model() is pending."""
        # An interval-based policy (data_count - last_inference_update >=
        # train_kwargs["update_interval"]) is currently disabled; reloads happen
        # only on explicit request.
        return self.model_saved_and_reload_requested
@@ -0,0 +1,275 @@
1
+ import os
2
+ import signal
3
+ import socket
4
+ import subprocess
5
+ import sys
6
+ import threading
7
+ import time
8
+ from datetime import datetime
9
+ from typing import Any, Dict, Optional
10
+
11
+ import requests
12
+
13
+ from arbor.server.core.config import Settings
14
+
15
+
16
class InferenceManager:
    """Own the lifecycle of a local SGLang router server used for inference.

    Tracks in-flight requests so ``update_model`` can drain them, then kills
    the server and relaunches it on a new checkpoint directory.
    """

    # Model-name prefixes stripped (in this order) before the name reaches SGLang.
    _MODEL_PREFIXES = ("openai/", "huggingface/", "local:", "arbor:")

    def __init__(self, settings: Settings):
        self.settings = settings
        self.process = None  # Popen handle of the SGLang router
        self.thread = None  # log-tailing thread for self.process
        self.get_logs = None  # callable returning all captured server output
        self.launch_kwargs = {}
        self.last_activity = None
        # While True, run_inference() waits instead of sending requests.
        self.restarting = False
        self._shutting_down = False
        self.current_model = None
        self.inference_count = 0  # number of requests currently in flight
        # Set up signal handler for graceful shutdown
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

    def _signal_handler(self, signum, frame):
        """Kill the server on SIGINT/SIGTERM; a second signal forces an immediate exit."""
        if self._shutting_down:
            print("\nForced exit during cleanup...")
            os._exit(1)

        print("\nReceived signal to terminate. Cleaning up...")
        self._shutting_down = True
        self.kill()
        sys.exit(0)

    def is_server_running(self):
        return self.process is not None

    def is_server_restarting(self):
        return self.restarting

    @classmethod
    def _strip_model_prefix(cls, model: str) -> str:
        """Remove the known routing prefixes from a model name."""
        for prefix in cls._MODEL_PREFIXES:
            if model.startswith(prefix):
                model = model[len(prefix) :]
        return model

    def launch(self, model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
        """Start an SGLang router for ``model`` and block until it answers.

        Args:
            model: Model id or local path, optionally with a known prefix.
            launch_kwargs: Optional overrides; ``timeout`` (seconds, default 1800)
                bounds the readiness wait.

        Raises:
            TimeoutError: if the server does not become ready in time.
        """
        if self.is_server_running():
            print("Server is already launched.")
            return

        launch_kwargs = launch_kwargs or self.launch_kwargs
        model = self._strip_model_prefix(model)

        print(f"Grabbing a free port to launch an SGLang server for model {model}")
        port = get_free_port()
        timeout = launch_kwargs.get("timeout", 1800)
        gpu_ids = self.settings.arbor_config.inference.gpu_ids
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = gpu_ids
        # One data-parallel replica per comma-separated GPU id.
        n_gpus = gpu_ids.count(",") + 1
        # Build argv directly (no string splitting) so model paths containing
        # spaces survive, and use the running interpreter's environment.
        command = [
            sys.executable,
            "-m",
            "sglang_router.launch_server",
            "--model-path",
            model,
            "--dp-size",
            str(n_gpus),
            "--router-policy",
            "round_robin",
            "--port",
            str(port),
            "--host",
            "0.0.0.0",
        ]
        print(f"Running command: {' '.join(command)}")

        process = subprocess.Popen(
            command,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # Merge stderr into stdout
            env=env,
        )
        print(f"SGLang server process started with PID {process.pid}.")

        # All output is kept in logs_buffer (exposed via get_logs); it is echoed
        # to the console only until the server is ready.
        stop_printing_event = threading.Event()
        logs_buffer = []

        def _tail_process(proc, buffer, stop_event):
            # readline() returns "" only at EOF, so this ends when the pipe closes.
            for line in iter(proc.stdout.readline, ""):
                buffer.append(line)
                if not stop_event.is_set():
                    print(f"[SGLang LOG] {line}", end="")

        thread = threading.Thread(
            target=_tail_process,
            args=(process, logs_buffer, stop_printing_event),
            daemon=True,
        )
        thread.start()

        base_url = f"http://localhost:{port}"
        try:
            wait_for_server(base_url, timeout=timeout)
        except TimeoutError:
            # Kill and reap the server that never came up.
            process.kill()
            process.wait()
            raise

        stop_printing_event.set()

        def get_logs() -> str:
            return "".join(logs_buffer)

        print(f"Server ready on random port {port}!")

        self.launch_kwargs["api_base"] = f"http://localhost:{port}/v1"
        self.launch_kwargs["api_key"] = "local"
        self.get_logs = get_logs
        self.process = process
        self.thread = thread
        self.current_model = model

    def kill(self):
        """Stop the server process (escalating to SIGKILL) and clear its references."""
        from sglang.utils import terminate_process

        if self.process is None:
            print("No running server to kill.")
            return

        process = self.process
        thread = self.thread

        terminate_process(process)

        # Clear references first
        self.process = None
        self.thread = None
        self.get_logs = None
        self.last_activity = None

        try:
            if self._shutting_down:
                process.kill()  # Go straight to SIGKILL if we're shutting down
            else:
                process.terminate()  # Try SIGTERM first
                try:
                    process.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    print(
                        "Process did not terminate after 10 seconds, forcing with SIGKILL..."
                    )
                    process.kill()

            process.wait(timeout=5)

            if thread and thread.is_alive():
                thread.join(timeout=5)

        except Exception as e:
            print(f"Error during cleanup: {e}")
            try:
                process.kill()  # Final attempt to kill
            except OSError:
                pass  # best effort: the process may already be gone

        print("Server killed.")

    def run_inference(self, request_json: dict):
        """Forward a chat-completions request to the running server.

        The ``model`` field is rewritten to the currently served model.

        Returns:
            The decoded JSON response, or ``None`` if the server disconnected.

        Raises:
            RuntimeError: if no server is running (after any restart finishes).
        """
        model = self._strip_model_prefix(request_json["model"])
        print(f"Running inference for model {model}")
        # Monkeypatch:
        if model != self.current_model:
            print(f"MONKEYPATCH: Model changed from {model} to {self.current_model}")
            model = self.current_model
        request_json["model"] = model

        self.last_activity = datetime.now()

        # Wait out a restart *before* checking liveness: update_model() kills the
        # server (process = None) while restarting is True.
        while self.restarting:
            print("Inference is paused while server is restarting...")
            time.sleep(5)
        request_json["model"] = self.current_model

        if self.process is None or self.launch_kwargs.get("api_base") is None:
            raise RuntimeError("Server is not running. Please launch it first.")

        url = f"{self.launch_kwargs['api_base']}/chat/completions"
        try:
            self.inference_count += 1
            response = requests.post(url, json=request_json)
            return response.json()
        except requests.exceptions.ConnectionError:
            print("Server disconnected...ignoring")
            return None
        except Exception as e:
            print(f"Error during inference: {e}")
            raise
        finally:
            self.inference_count -= 1

    def update_model(self, output_dir):
        """Drain in-flight requests, then relaunch the server on ``output_dir``.

        Raises:
            RuntimeError: if ``output_dir`` does not exist.
        """
        print("Restarting server with new model...")
        self.restarting = True
        try:
            while self.inference_count > 0:
                print(
                    f"Waiting for inference requests to finish... {self.inference_count} remaining"
                )
                time.sleep(5)

            tik = time.time()
            self.kill()
            print("Just killed server")
            print(f"Checking that output directory {output_dir} exists")
            if not os.path.exists(output_dir):
                raise RuntimeError(
                    f"Failed to save model - output directory {output_dir} does not exist"
                )

            print("Launching new server")
            self.launch(output_dir, self.launch_kwargs)
            tok = time.time()
        finally:
            # Always un-pause inference; otherwise a failed relaunch would make
            # run_inference() wait forever.
            self.restarting = False
        print(f"Time taken to update model: {tok - tik} seconds")
240
+
241
+
242
def get_free_port() -> int:
    """Ask the OS for an unused TCP port on localhost and return its number.

    The probe socket is closed before returning, so another process may still
    grab the port before the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Port 0 lets the kernel pick any free ephemeral port.
        probe.bind(("localhost", 0))
        port_number = probe.getsockname()[1]
    finally:
        probe.close()
    return port_number
249
+
250
+
251
def wait_for_server(base_url: str, timeout: Optional[int] = None) -> None:
    """
    Wait for the server to be ready by polling the /v1/models endpoint.

    Args:
        base_url: The base URL of the server (e.g. http://localhost:1234)
        timeout: Maximum time to wait in seconds. None (or 0) means wait forever.

    Raises:
        TimeoutError: if the server is not ready within ``timeout`` seconds.
    """
    start_time = time.time()
    while True:
        try:
            response = requests.get(
                f"{base_url}/v1/models",
                headers={"Authorization": "Bearer None"},
                # Bound each probe so a hung connection cannot stall the loop.
                timeout=5,
            )
            if response.status_code == 200:
                # A small extra sleep to ensure server is fully up.
                time.sleep(5)
                return
        except requests.exceptions.RequestException:
            pass  # Server not accepting connections yet; retry below.

        # Check the deadline after *every* failed attempt, including connection
        # errors; otherwise a server that never comes up loops forever.
        if timeout and (time.time() - start_time) > timeout:
            raise TimeoutError("Server did not become ready within timeout period")
        # Back off between attempts (also after non-200 responses).
        time.sleep(1)
@@ -0,0 +1,81 @@
1
+ import uuid
2
+ from datetime import datetime
3
+ from typing import Literal, Optional
4
+
5
+ from arbor.server.api.models.schemas import JobStatus
6
+ from arbor.server.core.config import Settings
7
+
8
+
9
class JobEvent:
    """A single log entry (level, message, optional payload) attached to a job."""

    def __init__(
        self,
        level: Literal["info", "warning", "error"],
        message: str,
        data: Optional[dict] = None,
    ):
        self.level = level
        self.message = message
        # Fresh dict per event: a ``{}`` default would be shared by every event.
        self.data = {} if data is None else data

        self.id = f"ftevent-{uuid.uuid4()}"
        self.created_at = datetime.now()
19
+
20
+
21
class JobCheckpoint:
    """Metadata describing one saved checkpoint of a fine-tuning job."""

    def __init__(
        self,
        fine_tuned_model_checkpoint: str,
        fine_tuning_job_id: str,
        metrics: dict,
        step_number: int,
    ):
        checkpoint_id = "ftckpt-" + str(uuid.uuid4())
        self.id = checkpoint_id
        # Where the checkpoint lives and which job produced it.
        self.fine_tuned_model_checkpoint = fine_tuned_model_checkpoint
        self.fine_tuning_job_id = fine_tuning_job_id
        # Training state at the time of the checkpoint.
        self.metrics, self.step_number = metrics, step_number
        self.created_at = datetime.now()
35
+
36
+
37
class Job:
    """An in-memory fine-tuning job with its status, event log and checkpoints."""

    def __init__(self, status: JobStatus):
        self.id = "ftjob-" + str(uuid.uuid4())
        self.status = status
        self.fine_tuned_model = None
        self.events: list[JobEvent] = []
        self.checkpoints: list[JobCheckpoint] = []
        self.created_at = datetime.now()

    def add_event(self, event: JobEvent):
        """Append ``event`` to this job's event log."""
        log = self.events
        log.append(event)

    def get_events(self) -> list[JobEvent]:
        """Return the job's event list itself (not a copy)."""
        events = self.events
        return events

    def add_checkpoint(self, checkpoint: JobCheckpoint):
        """Record ``checkpoint`` for this job."""
        saved = self.checkpoints
        saved.append(checkpoint)

    def get_checkpoints(self) -> list[JobCheckpoint]:
        """Return the job's checkpoint list itself (not a copy)."""
        saved = self.checkpoints
        return saved
58
+
59
+
60
class JobManager:
    """In-memory registry of jobs keyed by job id."""

    def __init__(self, settings: Settings):
        # ``settings`` is accepted for parity with the other managers; it is unused here.
        self.jobs = {}

    def get_job(self, job_id: str):
        """Return the job with ``job_id``; raise ValueError if it is unknown."""
        if job_id in self.jobs:
            return self.jobs[job_id]
        raise ValueError(f"Job {job_id} not found")

    def create_job(self):
        """Create, register and return a new job in the PENDING state."""
        new_job = Job(status=JobStatus.PENDING)
        self.jobs[new_job.id] = new_job
        return new_job

    def get_jobs(self):
        """Return a list of all registered jobs, in insertion order."""
        return [*self.jobs.values()]

    def get_active_job(self):
        """Return the first job whose status is RUNNING, or None."""
        return next(
            (job for job in self.jobs.values() if job.status == JobStatus.RUNNING),
            None,
        )