arbor-ai 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,40 @@
1
+ import os
2
+ import random
3
+ import string
4
+ import time
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+
1
8
  from arbor.server.api.models.schemas import FineTuneRequest
2
- from arbor.server.services.job_manager import Job, JobStatus
9
+ from arbor.server.core.config import Settings
3
10
  from arbor.server.services.file_manager import FileManager
11
+ from arbor.server.services.job_manager import Job, JobEvent, JobStatus
4
12
 
5
- class TrainingManager:
6
- def __init__(self):
7
- pass
8
13
 
9
- def find_train_args(self, request: FineTuneRequest, file_manager: FileManager):
14
+ class TrainingManager:
15
+ def __init__(self, settings: Settings):
16
+ self.settings = settings
17
+
18
+ def make_output_dir(self, request: FineTuneRequest):
19
+ model_name = request.model.split("/")[-1].lower()
20
+ suffix = (
21
+ request.suffix
22
+ if request.suffix is not None
23
+ else "".join(random.choices(string.ascii_letters + string.digits, k=6))
24
+ )
25
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
26
+ name = f"ft:{model_name}:{suffix}:{timestamp}"
27
+ return name, str(Path(self.settings.STORAGE_PATH).resolve() / "models" / name)
28
+
29
+ def find_train_args_sft(self, request: FineTuneRequest, file_manager: FileManager):
10
30
  file = file_manager.get_file(request.training_file)
11
31
  if file is None:
12
32
  raise ValueError(f"Training file {request.training_file} not found")
13
33
 
14
34
  data_path = file["path"]
15
- output_dir = f"models/{request.model}" # TODO: This should be updated to be unique in some way
35
+ file_manager.validate_file_format_sft(data_path)
16
36
 
37
+ name, output_dir = self.make_output_dir(request)
17
38
 
18
39
  default_train_kwargs = {
19
40
  "device": None,
@@ -28,25 +49,225 @@ class TrainingManager:
28
49
  "output_dir": output_dir,
29
50
  "train_data_path": data_path,
30
51
  }
31
- train_kwargs = {'packing': False}
32
- train_kwargs={**default_train_kwargs, **(train_kwargs or {})}
33
- output_dir = train_kwargs["output_dir"] # user might have changed the output_dir
52
+ train_kwargs = {"packing": False}
53
+ train_kwargs = {**default_train_kwargs, **(train_kwargs or {})}
34
54
 
35
55
  return train_kwargs
36
56
 
57
+ def find_train_args_dpo(self, request: FineTuneRequest, file_manager: FileManager):
58
+ file = file_manager.get_file(request.training_file)
59
+ if file is None:
60
+ raise ValueError(f"Training file {request.training_file} not found")
61
+
62
+ data_path = file["path"]
63
+ file_manager.validate_file_format_dpo(data_path)
64
+
65
+ name, output_dir = self.make_output_dir(request)
66
+
67
+ default_train_kwargs = {
68
+ "device": "cuda:2",
69
+ "use_peft": False,
70
+ "num_train_epochs": 5,
71
+ "per_device_train_batch_size": 1,
72
+ "gradient_accumulation_steps": 8,
73
+ "learning_rate": 1e-5,
74
+ "packing": True,
75
+ "bf16": True,
76
+ "output_dir": output_dir,
77
+ "train_data_path": data_path,
78
+ "prompt_length": 1024,
79
+ "max_seq_length": 1512,
80
+ "use_peft": False,
81
+ }
82
+
83
+ # https://www.philschmid.de/dpo-align-llms-in-2024-with-trl#3-align-llm-with-trl-and-the-dpotrainer
84
+
85
+ train_kwargs = request.model_dump(exclude_unset=True)
86
+ train_kwargs = {**default_train_kwargs, **(train_kwargs or {})}
87
+
88
+ return train_kwargs
37
89
 
38
90
  def fine_tune(self, request: FineTuneRequest, job: Job, file_manager: FileManager):
39
- # Get logger for this job
40
- logger = job.setup_logger("training")
41
91
 
42
92
  job.status = JobStatus.RUNNING
43
- logger.info("Starting fine-tuning job")
93
+ job.add_event(
94
+ JobEvent(level="info", message="Starting fine-tuning job", data={})
95
+ )
96
+
97
+ fine_tune_type = request.method["type"]
98
+ if fine_tune_type == "dpo":
99
+ self.dpo_fine_tune(request, job, file_manager)
100
+ else:
101
+ self.sft_fine_tune(request, job, file_manager)
102
+
103
+ def dpo_fine_tune(
104
+ self, request: FineTuneRequest, job: Job, file_manager: FileManager
105
+ ):
106
+ try:
107
+
108
+ job.status = JobStatus.RUNNING
109
+ job.add_event(
110
+ JobEvent(level="info", message="Starting fine-tuning job", data={})
111
+ )
112
+
113
+ train_kwargs = self.find_train_args_dpo(request, file_manager)
114
+ import torch
115
+ from transformers import (
116
+ AutoModelForCausalLM,
117
+ AutoTokenizer,
118
+ TrainingArguments,
119
+ )
120
+ from trl import DPOConfig, DPOTrainer, setup_chat_format
121
+
122
+ device = train_kwargs.get("device", None)
123
+ if device is None:
124
+ device = (
125
+ "cuda"
126
+ if torch.cuda.is_available()
127
+ else "mps" if torch.backends.mps.is_available() else "cpu"
128
+ )
129
+
130
+ job.add_event(
131
+ JobEvent(level="info", message=f"Using device: {device}", data={})
132
+ )
133
+
134
+ model = AutoModelForCausalLM.from_pretrained(
135
+ pretrained_model_name_or_path=request.model, device_map="auto"
136
+ )
137
+ tokenizer = AutoTokenizer.from_pretrained(
138
+ pretrained_model_name_or_path=request.model
139
+ )
140
+
141
+ try:
142
+ model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)
143
+ except Exception:
144
+ pass
145
+
146
+ if tokenizer.pad_token_id is None:
147
+ job.add_event(
148
+ JobEvent(
149
+ level="info", message="Adding pad token to tokenizer", data={}
150
+ )
151
+ )
152
+ tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})
153
+
154
+ hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
155
+ train_dataset = hf_dataset
156
+
157
+ use_peft = train_kwargs.get("use_peft", False)
158
+ peft_config = None
159
+
160
+ if use_peft:
161
+ from peft import LoraConfig
162
+
163
+ peft_config = LoraConfig(
164
+ lora_alpha=128,
165
+ lora_dropout=0.05,
166
+ r=256,
167
+ bias="none",
168
+ target_modules="all-linear",
169
+ task_type="CAUSAL_LM",
170
+ )
171
+
172
+ dpo_args = DPOConfig(
173
+ output_dir=train_kwargs["output_dir"],
174
+ num_train_epochs=train_kwargs["num_train_epochs"],
175
+ )
176
+
177
+ trainer = DPOTrainer(
178
+ model=model,
179
+ args=dpo_args,
180
+ train_dataset=train_dataset,
181
+ processing_class=tokenizer,
182
+ peft_config=peft_config,
183
+ )
184
+
185
+ trainer.train()
186
+
187
+ if job.status == JobStatus.PENDING_PAUSE:
188
+ trainer.accelerator.save_state(output_dir=dpo_args.output_dir)
189
+ current_step = trainer.state.global_step
190
+ job.status = JobStatus.PAUSED
191
+ job.add_event(
192
+ JobEvent(
193
+ level="info",
194
+ message="Training paused",
195
+ data={"step": current_step},
196
+ )
197
+ )
198
+
199
+ while job.status == JobStatus.PAUSED:
200
+ time.sleep(1) # Sleep to avoid busy waiting
201
+ if job.status == JobStatus.PENDING_RESUME:
202
+ job.status = JobStatus.RUNNING
203
+ job.add_event(JobEvent(level="info", message="Resuming training"))
204
+ trainer.accelerator.load_state(input_dir=dpo_args.output_dir)
205
+ trainer.train(resume_from_checkpoint=True)
206
+
207
+ if job.status == JobStatus.PENDING_CANCEL:
208
+ job.status = JobStatus.CANCELLED
209
+ job.add_event(JobEvent(level="info", message="Training cancelled"))
210
+
211
+ _cleanup(model, tokenizer, trainer)
212
+ raise Exception(
213
+ "Training cancelled"
214
+ ) # not sure if this should be raised or just return None
215
+
216
+ job.add_event(
217
+ JobEvent(level="info", message="Training completed successfully")
218
+ )
219
+
220
+ job.add_event(JobEvent(level="info", message="Saving model", data={}))
221
+ # Save the model!
222
+ trainer.save_model()
223
+ job.add_event(
224
+ JobEvent(
225
+ level="info",
226
+ message="Model saved",
227
+ data={"location": dpo_args.output_dir},
228
+ )
229
+ )
230
+
231
+ MERGE = True
232
+ if use_peft and MERGE:
233
+ from peft import AutoPeftModelForCausalLM
234
+
235
+ # Load PEFT model on CPU
236
+ model_ = AutoPeftModelForCausalLM.from_pretrained(
237
+ pretrained_model_name_or_path=dpo_args.output_dir,
238
+ torch_dtype=torch.float16,
239
+ low_cpu_mem_usage=True,
240
+ )
241
+
242
+ merged_model = model_.merge_and_unload()
243
+ merged_model.save_pretrained(
244
+ dpo_args.output_dir, safe_serialization=True, max_shard_size="5GB"
245
+ )
246
+
247
+ _cleanup(model, tokenizer, trainer)
248
+
249
+ job.status = JobStatus.SUCCEEDED
250
+ job.fine_tuned_model = dpo_args.output_dir
251
+ except Exception as e:
252
+ job.add_event(
253
+ JobEvent(level="error", message=f"Training failed: {str(e)}", data={})
254
+ )
255
+ job.status = JobStatus.FAILED
256
+ raise
257
+ finally:
258
+ pass
259
+
260
+ return dpo_args.output_dir
261
+
262
+ def sft_fine_tune(
263
+ self, request: FineTuneRequest, job: Job, file_manager: FileManager
264
+ ):
44
265
 
45
266
  try:
46
- train_kwargs = self.find_train_args(request, file_manager)
267
+ train_kwargs = self.find_train_args_sft(request, file_manager)
47
268
 
48
269
  import torch
49
- from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
270
+ from transformers import AutoModelForCausalLM, AutoTokenizer
50
271
  from trl import SFTConfig, SFTTrainer, setup_chat_format
51
272
 
52
273
  device = train_kwargs.get("device", None)
@@ -56,12 +277,16 @@ class TrainingManager:
56
277
  if torch.cuda.is_available()
57
278
  else "mps" if torch.backends.mps.is_available() else "cpu"
58
279
  )
59
- logger.info(f"Using device: {device}")
280
+ job.add_event(
281
+ JobEvent(level="info", message=f"Using device: {device}", data={})
282
+ )
60
283
 
61
284
  model = AutoModelForCausalLM.from_pretrained(
62
285
  pretrained_model_name_or_path=request.model
63
286
  ).to(device)
64
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=request.model)
287
+ tokenizer = AutoTokenizer.from_pretrained(
288
+ pretrained_model_name_or_path=request.model
289
+ )
65
290
 
66
291
  # Set up the chat format; generally only for non-chat model variants, hence the try-except.
67
292
  try:
@@ -70,21 +295,40 @@ class TrainingManager:
70
295
  pass
71
296
 
72
297
  if tokenizer.pad_token_id is None:
73
- logger.info("Adding pad token to tokenizer")
298
+ job.add_event(
299
+ JobEvent(
300
+ level="info", message="Adding pad token to tokenizer", data={}
301
+ )
302
+ )
74
303
  tokenizer.add_special_tokens({"pad_token": "[!#PAD#!]"})
75
304
 
76
- logger.info("Creating dataset")
77
- if "max_seq_length" not in train_kwargs or train_kwargs["max_seq_length"] is None:
305
+ job.add_event(JobEvent(level="info", message="Creating dataset", data={}))
306
+ if (
307
+ "max_seq_length" not in train_kwargs
308
+ or train_kwargs["max_seq_length"] is None
309
+ ):
78
310
  train_kwargs["max_seq_length"] = 4096
79
- logger.info(f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}")
80
-
311
+ job.add_event(
312
+ JobEvent(
313
+ level="info",
314
+ message=f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}",
315
+ data={},
316
+ )
317
+ )
81
318
 
319
+ job.add_event(JobEvent(level="info", message="Tokenizing dataset", data={}))
82
320
  hf_dataset = dataset_from_file(train_kwargs["train_data_path"])
321
+
83
322
  def tokenize_function(example):
84
- return encode_sft_example(example, tokenizer, train_kwargs["max_seq_length"])
323
+ return encode_sft_example(
324
+ example, tokenizer, train_kwargs["max_seq_length"]
325
+ )
326
+
85
327
  tokenized_dataset = hf_dataset.map(tokenize_function, batched=False)
86
328
  tokenized_dataset.set_format(type="torch")
87
- tokenized_dataset = tokenized_dataset.filter(lambda example: (example["labels"] != -100).any())
329
+ tokenized_dataset = tokenized_dataset.filter(
330
+ lambda example: (example["labels"] != -100).any()
331
+ )
88
332
 
89
333
  USE_PEFT = train_kwargs.get("use_peft", False)
90
334
  peft_config = None
@@ -125,20 +369,61 @@ class TrainingManager:
125
369
  },
126
370
  )
127
371
 
128
- logger.info("Starting training")
372
+ job.add_event(JobEvent(level="info", message="Starting training"))
129
373
  trainer = SFTTrainer(
130
374
  model=model,
131
375
  args=sft_config,
132
376
  train_dataset=tokenized_dataset,
133
377
  peft_config=peft_config,
134
-
135
378
  )
136
379
 
137
380
  # Train!
138
381
  trainer.train()
139
382
 
383
+ if job.status == JobStatus.PENDING_PAUSE:
384
+ trainer.accelerator.save_state(output_dir=sft_config.output_dir)
385
+ current_step = trainer.state.global_step
386
+ job.status = JobStatus.PAUSED
387
+ job.add_event(
388
+ JobEvent(
389
+ level="info",
390
+ message="Training paused",
391
+ data={"step": current_step},
392
+ )
393
+ )
394
+
395
+ while job.status == JobStatus.PAUSED:
396
+ time.sleep(1) # Sleep to avoid busy waiting
397
+ if job.status == JobStatus.PENDING_RESUME:
398
+ job.status = JobStatus.RUNNING
399
+ job.add_event(JobEvent(level="info", message="Resuming training"))
400
+ trainer.accelerator.load_state(input_dir=sft_config.output_dir)
401
+ trainer.train(resume_from_checkpoint=True)
402
+
403
+ if job.status == JobStatus.PENDING_CANCEL:
404
+ job.status = JobStatus.CANCELLED
405
+ job.add_event(JobEvent(level="info", message="Training cancelled"))
406
+
407
+ _cleanup(model, tokenizer, trainer)
408
+
409
+ raise Exception(
410
+ "Training cancelled"
411
+ ) # not sure if this should be raised or just return None
412
+
413
+ job.add_event(
414
+ JobEvent(level="info", message="Training completed successfully")
415
+ )
416
+
417
+ job.add_event(JobEvent(level="info", message="Saving model", data={}))
140
418
  # Save the model!
141
419
  trainer.save_model()
420
+ job.add_event(
421
+ JobEvent(
422
+ level="info",
423
+ message="Model saved",
424
+ data={"location": sft_config.output_dir},
425
+ )
426
+ )
142
427
 
143
428
  MERGE = True
144
429
  if USE_PEFT and MERGE:
@@ -156,27 +441,22 @@ class TrainingManager:
156
441
  sft_config.output_dir, safe_serialization=True, max_shard_size="5GB"
157
442
  )
158
443
 
159
- # Clean up!
160
- import gc
444
+ _cleanup(model, tokenizer, trainer)
161
445
 
162
- del model
163
- del tokenizer
164
- del trainer
165
- gc.collect()
166
- torch.cuda.empty_cache()
167
-
168
- logger.info("Training completed successfully")
169
446
  job.status = JobStatus.SUCCEEDED
170
447
  job.fine_tuned_model = sft_config.output_dir
171
448
  except Exception as e:
172
- logger.error(f"Training failed: {str(e)}")
449
+ job.add_event(
450
+ JobEvent(level="error", message=f"Training failed: {str(e)}", data={})
451
+ )
173
452
  job.status = JobStatus.FAILED
174
453
  raise
175
454
  finally:
176
- job.cleanup_logger()
455
+ pass
177
456
 
178
457
  return sft_config.output_dir
179
458
 
459
+
180
460
  def dataset_from_file(data_path):
181
461
  """
182
462
  Creates a HuggingFace Dataset from a JSONL file.
@@ -218,7 +498,9 @@ def encode_sft_example(example, tokenizer, max_seq_length):
218
498
  message_start_idx = 0
219
499
  else:
220
500
  message_start_idx = tokenizer.apply_chat_template(
221
- conversation=messages[:message_idx], # here marks the end of the previous messages
501
+ conversation=messages[
502
+ :message_idx
503
+ ], # here marks the end of the previous messages
222
504
  tokenize=True,
223
505
  return_tensors="pt",
224
506
  padding=False,
@@ -227,7 +509,10 @@ def encode_sft_example(example, tokenizer, max_seq_length):
227
509
  add_generation_prompt=False,
228
510
  ).shape[1]
229
511
  # next, we calculate the end index of this non-assistant message
230
- if message_idx < len(messages) - 1 and messages[message_idx + 1]["role"] == "assistant":
512
+ if (
513
+ message_idx < len(messages) - 1
514
+ and messages[message_idx + 1]["role"] == "assistant"
515
+ ):
231
516
  # for intermediate messages that follow with an assistant message, we need to
232
517
  # set `add_generation_prompt=True` to avoid the assistant generation prefix being included in the loss
233
518
  # (e.g., `<|assistant|>`)
@@ -260,5 +545,17 @@ def encode_sft_example(example, tokenizer, max_seq_length):
260
545
  return {
261
546
  "input_ids": input_ids.flatten(),
262
547
  "labels": labels.flatten(),
263
- "attention_mask": attention_mask.flatten()
264
- }
548
+ "attention_mask": attention_mask.flatten(),
549
+ }
550
+
551
+
552
+ def _cleanup(model, tokenizer, trainer):
553
+ import gc
554
+
555
+ import torch
556
+
557
+ del model
558
+ del tokenizer
559
+ del trainer
560
+ gc.collect()
561
+ torch.cuda.empty_cache()
@@ -0,0 +1,78 @@
1
+ Metadata-Version: 2.4
2
+ Name: arbor-ai
3
+ Version: 0.1.6
4
+ Summary: A framework for fine-tuning and managing language models
5
+ Author-email: Noah Ziems <nziems2@nd.edu>
6
+ Project-URL: Homepage, https://github.com/Ziems/arbor
7
+ Project-URL: Issues, https://github.com/Ziems/arbor/issues
8
+ Requires-Python: >=3.10
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: fastapi
12
+ Requires-Dist: uvicorn
13
+ Requires-Dist: click
14
+ Requires-Dist: python-multipart
15
+ Requires-Dist: pydantic-settings
16
+ Requires-Dist: torch
17
+ Requires-Dist: transformers
18
+ Requires-Dist: trl
19
+ Requires-Dist: peft
20
+ Requires-Dist: ray>=2.9
21
+ Requires-Dist: setuptools<77.0.0,>=76.0.0
22
+ Requires-Dist: pyzmq>=26.4.0
23
+ Requires-Dist: pyyaml>=6.0.2
24
+ Requires-Dist: sglang>=0.4.5.post3
25
+ Requires-Dist: sglang-router
26
+ Dynamic: license-file
27
+
28
+ <p align="center">
29
+ <img src="https://github.com/user-attachments/assets/ed0dd782-65fa-48b5-a762-b343b183be09" alt="Description" width="400"/>
30
+ </p>
31
+
32
+ **A framework for optimizing DSPy programs with RL.**
33
+
34
+ ---
35
+
36
+ ## 🚀 Installation
37
+
38
+ Install Arbor via pip:
39
+
40
+ ```bash
41
+ pip install git+https://github.com/Ziems/arbor.git
42
+ ```
43
+
44
+ ---
45
+
46
+ ## ⚡ Quick Start
47
+
48
+ ### 1️⃣ Make an `arbor.yaml` File
49
+
50
+ This is all dependent on your setup. Here is an example of one:
51
+ ```yaml
52
+ inference:
53
+ gpu_ids: '0'
54
+
55
+ training:
56
+ gpu_ids: '1, 2'
57
+ ```
58
+
59
+ ### 2️⃣ Start the Server
60
+
61
+ **CLI:**
62
+
63
+ ```bash
64
+ python -m arbor.cli serve --arbor-config arbor.yaml
65
+ ```
66
+
67
+ ### 3️⃣ Optimize a DSPy Program
68
+
69
+ Follow the DSPy tutorials here to see usage examples:
70
+ [DSPy RL Optimization Examples](https://dspy.ai/tutorials/rl_papillon/)
71
+
72
+ ---
73
+
74
+ ## 🙏 Acknowledgements
75
+
76
+ Arbor builds on the shoulders of great work. We extend our thanks to:
77
+ - **[Will Brown's Verifiers library](https://github.com/willccbb/verifiers)**
78
+ - **[Hugging Face TRL library](https://github.com/huggingface/trl)**
@@ -0,0 +1,34 @@
1
+ arbor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ arbor/cli.py,sha256=3o9A03Kew9cM5ZvD_6xOTaquNIE_hTYMOeQH3hkuJbY,3110
3
+ arbor/client/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ arbor/client/api.py,sha256=86bgHuGM_AvI1Uhic_QaCnpF4VFqXie9ZzxmbTXUPpQ,19
5
+ arbor/server/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
6
+ arbor/server/main.py,sha256=tY4Vlaaj4oq1FTGYOkbFMGF0quLEeR-VBaKaXhQ5mEE,382
7
+ arbor/server/api/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
8
+ arbor/server/api/models/schemas.py,sha256=s_G8sSb05FjkKEqpKpLlqaEd8NysJddHibRHhcnrKIk,5594
9
+ arbor/server/api/routes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ arbor/server/api/routes/files.py,sha256=DQC_ogH5zlzhHZSAA4Cj5wzK07XBIBVs2Po91W9rcDY,1835
11
+ arbor/server/api/routes/grpo.py,sha256=VuEvSOwwrHegn9qM-1nbHFmmUnnC_BMwnIHsfIdiJyI,1877
12
+ arbor/server/api/routes/inference.py,sha256=xlP-FMpOJAiiPZkE470l9mCR0ujLki8RrcO9hmTQD-k,1662
13
+ arbor/server/api/routes/jobs.py,sha256=BNdaSYUBJX6xSd6Pj6qx1DQJiZ5EKVxxbXDbEkfkCpw,3634
14
+ arbor/server/core/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
15
+ arbor/server/core/config.py,sha256=Mx77S3ByIMvHmPDikQLcczhzA5so3Vrw_U4QefOiHOU,1257
16
+ arbor/server/core/logging.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
+ arbor/server/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
+ arbor/server/services/dependencies.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
+ arbor/server/services/file_manager.py,sha256=Z9z4A4EzvPauid_DBfpim401DDtuJy_TbX4twTWDJWI,12119
20
+ arbor/server/services/grpo_manager.py,sha256=T-f1TrNSi_kmxPOcpaDphS8Xf3UMUbricocc6fuaKIM,12077
21
+ arbor/server/services/inference_manager.py,sha256=qR9xPiYs4Is24vgeF72w7Hbe8j_PGEbl-qewcvUV-dA,9731
22
+ arbor/server/services/job_manager.py,sha256=m_d4UPwN_82f7t7K443DaFpFoyv7JZSZKml8tawt1Bk,2186
23
+ arbor/server/services/training_manager.py,sha256=oQdhpfxdgp_lCTb_lxhvjupdLrcg6HL3TEbct_q9F6I,21065
24
+ arbor/server/services/comms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
+ arbor/server/services/comms/comms.py,sha256=Dg08D2Fm5TAEiGyr0Qcr0uocabQpFD_sBVhxIkj9D2M,7424
26
+ arbor/server/services/scripts/grpo_training.py,sha256=V36pCMZDJj2DdzquxScOddi9zP8EVPGWN3HGiftFfrY,21082
27
+ arbor/server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
+ arbor/server/utils/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
+ arbor_ai-0.1.6.dist-info/licenses/LICENSE,sha256=5vFGrbOFeXXM83JV9o16w7ohH4WLeu3-57GocJSz8ow,1067
30
+ arbor_ai-0.1.6.dist-info/METADATA,sha256=ot1XsjoFawGbNBmaAaxv29lE_uNh8TTdl8OANkTPTS8,1823
31
+ arbor_ai-0.1.6.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
32
+ arbor_ai-0.1.6.dist-info/entry_points.txt,sha256=PGBX-MfNwfIl8UPFgsX3gjtXLqSogRhOktKMpZUysD0,40
33
+ arbor_ai-0.1.6.dist-info/top_level.txt,sha256=jzWdp3BRYqvZDMFsPajrcftvvlluzVDErkD8IMRfhYs,6
34
+ arbor_ai-0.1.6.dist-info/RECORD,,
@@ -1,4 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.1.1
2
+ Generator: setuptools (80.3.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ arbor = arbor.cli:cli
@@ -0,0 +1 @@
1
+ arbor
@@ -1,16 +0,0 @@
1
- from fastapi import APIRouter, BackgroundTasks, Depends
2
-
3
- from arbor.server.api.models.schemas import FineTuneRequest, JobStatusResponse
4
- from arbor.server.services.job_manager import JobManager, JobStatus
5
- from arbor.server.services.file_manager import FileManager
6
- from arbor.server.services.training_manager import TrainingManager
7
- from arbor.server.services.dependencies import get_training_manager, get_job_manager, get_file_manager
8
-
9
- router = APIRouter()
10
-
11
- @router.post("", response_model=JobStatusResponse)
12
- def fine_tune(request: FineTuneRequest, background_tasks: BackgroundTasks, training_manager: TrainingManager = Depends(get_training_manager), job_manager: JobManager = Depends(get_job_manager), file_manager: FileManager = Depends(get_file_manager)):
13
- job = job_manager.create_job()
14
- background_tasks.add_task(training_manager.fine_tune, request, job, file_manager)
15
- job.status = JobStatus.QUEUED
16
- return JobStatusResponse(id=job.id, status=job.status.value)