arbor-ai 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arbor/cli.py +89 -5
- arbor/client/api.py +1 -2
- arbor/server/api/models/schemas.py +209 -5
- arbor/server/api/routes/files.py +39 -10
- arbor/server/api/routes/grpo.py +54 -0
- arbor/server/api/routes/inference.py +53 -0
- arbor/server/api/routes/jobs.py +110 -7
- arbor/server/core/config.py +44 -7
- arbor/server/main.py +6 -5
- arbor/server/services/comms/__init__.py +0 -0
- arbor/server/services/comms/comms.py +226 -0
- arbor/server/services/dependencies.py +0 -16
- arbor/server/services/file_manager.py +270 -109
- arbor/server/services/grpo_manager.py +310 -0
- arbor/server/services/inference_manager.py +275 -0
- arbor/server/services/job_manager.py +74 -69
- arbor/server/services/scripts/grpo_training.py +576 -0
- arbor/server/services/training_manager.py +337 -40
- arbor_ai-0.1.6.dist-info/METADATA +78 -0
- arbor_ai-0.1.6.dist-info/RECORD +34 -0
- {arbor_ai-0.1.4.dist-info → arbor_ai-0.1.6.dist-info}/WHEEL +2 -1
- arbor_ai-0.1.6.dist-info/entry_points.txt +2 -0
- arbor_ai-0.1.6.dist-info/top_level.txt +1 -0
- arbor/server/api/routes/training.py +0 -16
- arbor_ai-0.1.4.dist-info/METADATA +0 -97
- arbor_ai-0.1.4.dist-info/RECORD +0 -27
- arbor_ai-0.1.4.dist-info/entry_points.txt +0 -3
- {arbor_ai-0.1.4.dist-info → arbor_ai-0.1.6.dist-info/licenses}/LICENSE +0 -0
@@ -1,76 +1,81 @@
|
|
1
1
|
import uuid
|
2
|
-
from enum import Enum
|
3
|
-
import logging
|
4
2
|
from datetime import datetime
|
3
|
+
from typing import Literal
|
4
|
+
|
5
|
+
from arbor.server.api.models.schemas import JobStatus
|
6
|
+
from arbor.server.core.config import Settings
|
7
|
+
|
8
|
+
|
9
|
+
class JobEvent:
    """A single event recorded against a job.

    Each instance stores the given ``level``, ``message`` and ``data`` and
    is stamped with a generated ``id`` of the form ``ftevent-<uuid4>`` and a
    ``created_at`` timestamp taken from ``datetime.now()``.

    Args:
        level: Severity of the event: "info", "warning" or "error".
        message: Human-readable description of the event.
        data: Optional payload attached to the event. When omitted (or
            ``None``), a new empty dict is created for this event.
    """

    def __init__(
        self,
        level: Literal["info", "warning", "error"],
        message: str,
        data: "dict | None" = None,
    ):
        self.level = level
        self.message = message
        # A `{}` default in the signature is evaluated once and would be
        # shared by every event created without `data`; build a fresh dict
        # per instance instead.
        self.data = {} if data is None else data

        self.id = f"ftevent-{uuid.uuid4()}"
        self.created_at = datetime.now()
|
19
|
+
|
20
|
+
|
21
|
+
class JobCheckpoint:
    """Record describing one checkpoint produced by a fine-tuning job.

    The constructor stores its four arguments unchanged and adds a generated
    ``id`` of the form ``ftckpt-<uuid4>`` plus a ``created_at`` timestamp
    taken from ``datetime.now()``.

    Args:
        fine_tuned_model_checkpoint: Identifier of the checkpointed model.
        fine_tuning_job_id: Id of the job this checkpoint belongs to.
        metrics: Metrics associated with the checkpoint.
        step_number: Step number at which the checkpoint was taken.
    """

    def __init__(
        self,
        fine_tuned_model_checkpoint: str,
        fine_tuning_job_id: str,
        metrics: dict,
        step_number: int,
    ):
        # Generated identity first, then the caller-supplied fields.
        self.id = "ftckpt-" + str(uuid.uuid4())
        (
            self.fine_tuned_model_checkpoint,
            self.fine_tuning_job_id,
            self.metrics,
            self.step_number,
        ) = (
            fine_tuned_model_checkpoint,
            fine_tuning_job_id,
            metrics,
            step_number,
        )
        self.created_at = datetime.now()
|
5
35
|
|
6
|
-
# https://platform.openai.com/docs/api-reference/fine-tuning/object
|
7
|
-
class JobStatus(Enum):
|
8
|
-
PENDING = "pending" # Not in OAI
|
9
|
-
VALIDATING_FILES = "validating_files"
|
10
|
-
QUEUED = "queued"
|
11
|
-
RUNNING = "running"
|
12
|
-
SUCCEEDED = "succeeded"
|
13
|
-
FAILED = "failed"
|
14
|
-
CANCELLED = "cancelled"
|
15
|
-
|
16
|
-
class JobLogHandler(logging.Handler):
|
17
|
-
def __init__(self, job):
|
18
|
-
super().__init__()
|
19
|
-
self.job = job
|
20
|
-
|
21
|
-
def emit(self, record):
|
22
|
-
log_entry = {
|
23
|
-
'timestamp': datetime.fromtimestamp(record.created).isoformat(),
|
24
|
-
'level': record.levelname,
|
25
|
-
'message': record.getMessage()
|
26
|
-
}
|
27
|
-
self.job.logs.append(log_entry)
|
28
36
|
|
29
37
|
class Job:
    """State holder for one job: its status, events and checkpoints.

    A new job receives a generated ``id`` of the form ``ftjob-<uuid4>``,
    starts with no fine-tuned model, empty event and checkpoint lists, and a
    ``created_at`` timestamp taken from ``datetime.now()``.

    Args:
        status: Initial status of the job.
    """

    def __init__(self, status: JobStatus):
        self.id = "ftjob-" + str(uuid.uuid4())
        self.status = status
        self.fine_tuned_model = None
        self.events: list[JobEvent] = list()
        self.checkpoints: list[JobCheckpoint] = list()

        self.created_at = datetime.now()

    def add_event(self, event: JobEvent):
        """Append ``event`` to this job's event list."""
        self.events.append(event)

    def get_events(self) -> list[JobEvent]:
        """Return the job's event list (the live list, not a copy)."""
        return self.events

    def add_checkpoint(self, checkpoint: JobCheckpoint):
        """Append ``checkpoint`` to this job's checkpoint list."""
        self.checkpoints.append(checkpoint)

    def get_checkpoints(self) -> list[JobCheckpoint]:
        """Return the job's checkpoint list (the live list, not a copy)."""
        return self.checkpoints
|
58
|
+
|
63
59
|
|
64
60
|
class JobManager:
    """In-memory registry of jobs, keyed by job id."""

    def __init__(self, settings: Settings):
        # `settings` is accepted but not read by this constructor.
        self.jobs = dict()

    def get_job(self, job_id: str):
        """Return the job stored under ``job_id``.

        Raises:
            ValueError: If no job with that id is registered.
        """
        if job_id in self.jobs:
            return self.jobs[job_id]
        raise ValueError(f"Job {job_id} not found")

    def create_job(self):
        """Create a job in the PENDING state, register it and return it."""
        new_job = Job(status=JobStatus.PENDING)
        self.jobs[new_job.id] = new_job
        return new_job

    def get_jobs(self):
        """Return a new list containing every registered job."""
        return [*self.jobs.values()]

    def get_active_job(self):
        """Return the first job whose status is RUNNING, or ``None``."""
        running = (
            candidate
            for candidate in self.jobs.values()
            if candidate.status == JobStatus.RUNNING
        )
        return next(running, None)
|