langtune 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/client.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
1
|
+
"""
|
|
2
|
+
client.py: Langtune Server Client
|
|
3
|
+
|
|
4
|
+
Client SDK for communicating with Langtrain server for heavy computation tasks.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
import logging
import os
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urlencode
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
# Default API base URL
|
|
19
|
+
DEFAULT_API_BASE = "https://api.langtrain.xyz/v1"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class JobStatus(Enum):
    """Lifecycle states of a fine-tuning job.

    Each member's value is the lowercase status string used when parsing
    job payloads, so ``JobStatus("running")`` yields ``JobStatus.RUNNING``.
    """

    # States in which the job has not finished yet.
    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"

    # States treated as final by the polling helper in this module.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class FineTuneJob:
    """Client-side record describing one fine-tuning job."""

    # --- required fields ---
    id: str
    status: JobStatus
    model: str
    created_at: str

    # --- optional fields, absent until reported by the server ---
    updated_at: Optional[str] = None
    completed_at: Optional[str] = None
    error: Optional[str] = None  # error message, if any
    result_url: Optional[str] = None
    metrics: Optional[Dict[str, float]] = None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class Model:
    """Client-side record describing one model offered by the API."""

    id: str
    name: str
    description: str
    parameters: int  # presumably the parameter count -- confirm against the API
    context_length: int
    supports_finetuning: bool
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class Agent:
    """Client-side record describing one AI agent."""

    # --- required fields ---
    id: str
    name: str
    workspace_id: str

    # --- optional fields ---
    description: Optional[str] = None
    model_id: Optional[str] = None
    config: Optional[Dict[str, Any]] = None
    is_active: bool = True
    created_at: Optional[str] = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass
class AgentRun:
    """Client-side record describing a single execution of an agent."""

    # --- required fields ---
    id: str
    status: str
    agent_id: str
    input: Dict[str, Any]

    # --- optional fields ---
    output: Optional[Dict[str, Any]] = None
    token_usage: Optional[Dict[str, int]] = None
    latency_ms: Optional[int] = None
    created_at: Optional[str] = None
    finished_at: Optional[str] = None
    error: Optional[str] = None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
class UsageRecord:
    """Usage counters and their limits for one workspace billing period."""

    # Token consumption.
    tokens_used: int
    tokens_limit: int

    # Fine-tuning jobs.
    finetune_jobs_used: int
    finetune_jobs_limit: int

    # Agent runs.
    agent_runs_used: int
    agent_runs_limit: int

    # Period boundaries; the end is optional.
    period_start: str
    period_end: Optional[str] = None
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
class Plan:
    """Billing plan attached to a workspace, together with its limits."""

    id: Optional[str]  # may be None; no default, so it must still be passed
    name: str
    code: str
    billing_period: str
    limits: Dict[str, int]
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class LangtuneClient:
    """
    Client for Langtrain API.

    Handles authentication and communication with the server for:
    - API key validation and plan/feature lookup
    - Fine-tuning jobs (including training file upload)
    - Text generation and chat
    - Model management
    - Agents and agent runs
    - Billing and usage

    Every network call goes through ``_request``, which raises ``APIError``
    for HTTP error responses and ``ConnectionError`` for other request
    failures.

    Example:
        >>> client = LangtuneClient()
        >>> job = client.create_finetune_job(
        ...     training_file="path/to/data.jsonl",
        ...     model="llama-7b"
        ... )
        >>> client.wait_for_job(job.id)
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: int = 300
    ):
        """
        Initialize the client.

        Args:
            api_key: API key (defaults to LANGTUNE_API_KEY env var)
            base_url: API base URL (defaults to the LANGTUNE_API_BASE env var,
                then to https://api.langtrain.xyz/v1)
            timeout: Request timeout in seconds
        """
        self.api_key = api_key or os.environ.get("LANGTUNE_API_KEY")
        # Trailing slashes are stripped so endpoint joining never produces "//".
        self.base_url = (base_url or os.environ.get("LANGTUNE_API_BASE") or DEFAULT_API_BASE).rstrip("/")
        self.timeout = timeout

        self._session = None

    def _get_headers(self) -> Dict[str, str]:
        """Get request headers; Authorization is included only when a key is set."""
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "langtune-python/0.1"
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @staticmethod
    def _with_query(path: str, **params: Any) -> str:
        """Return ``path`` with ``params`` appended as a URL-encoded query string."""
        return f"{path}?{urlencode(params)}"

    @staticmethod
    def _extract_error_message(exc: Exception) -> str:
        """Return the most specific message available for an HTTP error.

        Uses ``error.message`` (or a plain-string ``error``) from the JSON
        body of ``exc.response`` when present; otherwise ``str(exc)``.
        """
        fallback = str(exc)
        response = getattr(exc, "response", None)
        if response is None:
            return fallback
        try:
            payload = response.json()
        except ValueError:
            # Body is not JSON (e.g. an HTML error page from a proxy).
            return fallback
        if not isinstance(payload, dict):
            return fallback
        error = payload.get("error")
        if isinstance(error, dict):
            return str(error.get("message") or fallback)
        if isinstance(error, str) and error:
            return error
        return fallback

    # ==================== API Key Validation ====================

    def validate(self) -> Dict[str, Any]:
        """
        Validate the API key and return plan/feature info.

        This method never raises: any failure (missing key, HTTP error,
        connection problem, missing ``requests``) is reported as
        ``{"valid": False, "error": "<message>"}``.

        Returns:
            The server's response on success, documented as a dict with:
            - valid: bool
            - plan: str (free, pro, enterprise)
            - features: list of feature names
            - limits: dict of limits
            - workspace_id: str
        """
        if not self.api_key:
            return {"valid": False, "error": "No API key configured"}

        try:
            return self._request("POST", "/auth/api-keys/validate", {"api_key": self.api_key})
        except Exception as e:
            # Deliberate best-effort: callers only inspect the "valid" flag.
            logger.debug("API key validation failed: %s", e)
            return {"valid": False, "error": str(e)}

    def is_valid(self) -> bool:
        """Check if API key is valid (performs a validation request)."""
        return self.validate().get("valid", False)

    def get_features(self) -> List[str]:
        """Get list of available features for current plan (empty on failure)."""
        return self.validate().get("features", [])

    def has_feature(self, feature: str) -> bool:
        """Check if a specific feature is available."""
        return feature in self.get_features()

    def get_limits(self) -> Dict[str, int]:
        """Get current plan limits (empty on failure)."""
        return self.validate().get("limits", {})

    def _request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict] = None,
        files: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Make an API request and return the decoded JSON body.

        Args:
            method: HTTP method name
            endpoint: Path relative to ``base_url`` (leading slash optional)
            data: JSON body, or form fields when ``files`` is given
            files: Files to send as multipart form data

        Raises:
            ImportError: If the ``requests`` library is not installed
            APIError: If the server answers with an HTTP error status
            ConnectionError: For any other ``requests`` failure
        """
        try:
            import requests
        except ImportError as e:
            raise ImportError("requests library required. Install with: pip install requests") from e

        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        headers = self._get_headers()

        if files:
            # Multipart form data: requests must generate the Content-Type
            # itself so that it carries the multipart boundary.
            headers.pop("Content-Type", None)
            body_kwargs = {"files": files, "data": data}
        else:
            body_kwargs = {"json": data}

        try:
            response = requests.request(
                method,
                url,
                headers=headers,
                timeout=self.timeout,
                **body_kwargs
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            raise APIError(f"API error: {self._extract_error_message(e)}") from e
        except requests.exceptions.RequestException as e:
            raise ConnectionError(f"Connection error: {e}") from e

    # ==================== Fine-tuning ====================

    def create_finetune_job(
        self,
        training_file: str,
        model: str = "llama-7b",
        training_method: str = "qlora",
        validation_file: Optional[str] = None,
        hyperparameters: Optional[Dict[str, Any]] = None,
        suffix: Optional[str] = None,
        sft_config: Optional[Dict[str, Any]] = None,
        dpo_config: Optional[Dict[str, Any]] = None,
        rlhf_config: Optional[Dict[str, Any]] = None,
    ) -> "FineTuneJob":
        """
        Create a fine-tuning job.

        The training file (and the validation file, if given) is uploaded
        first; the job is then created with the returned file IDs.

        Args:
            training_file: Path to training data (JSONL format)
            model: Base model to fine-tune
            training_method: Training method - one of:
                - "sft" (Supervised Fine-Tuning)
                - "dpo" (Direct Preference Optimization)
                - "rlhf" (Reinforcement Learning from Human Feedback)
                - "lora" (LoRA adapters)
                - "qlora" (Quantized LoRA, default)
            validation_file: Optional validation data
            hyperparameters: Training hyperparameters
            suffix: Suffix for the fine-tuned model name
            sft_config: SFT-specific config (packing, dataset_text_field)
            dpo_config: DPO-specific config (beta, loss_type)
            rlhf_config: RLHF-specific config (reward_model, ppo_epochs)

        Returns:
            FineTuneJob object

        Raises:
            FileNotFoundError: If a given file path does not exist
        """
        data = {
            "training_file": self._upload_file(training_file, "fine-tune"),
            "model": model,
            "training_method": training_method
        }

        if validation_file:
            data["validation_file"] = self._upload_file(validation_file, "fine-tune")

        if hyperparameters:
            data["hyperparameters"] = hyperparameters

        if suffix:
            data["suffix"] = suffix

        # Only the config matching the selected training method is forwarded.
        method_configs = {"sft": sft_config, "dpo": dpo_config, "rlhf": rlhf_config}
        selected_config = method_configs.get(training_method)
        if selected_config:
            data[f"{training_method}_config"] = selected_config

        response = self._request("POST", "/fine-tuning/jobs", data)
        return self._parse_job(response)

    def get_finetune_job(self, job_id: str) -> "FineTuneJob":
        """Get fine-tuning job status."""
        response = self._request("GET", f"/fine-tuning/jobs/{job_id}")
        return self._parse_job(response)

    def list_finetune_jobs(self, limit: int = 10) -> List["FineTuneJob"]:
        """List fine-tuning jobs (at most ``limit`` are requested)."""
        response = self._request("GET", self._with_query("/fine-tuning/jobs", limit=limit))
        return [self._parse_job(j) for j in response.get("data", [])]

    def cancel_finetune_job(self, job_id: str) -> "FineTuneJob":
        """Cancel a fine-tuning job."""
        response = self._request("POST", f"/fine-tuning/jobs/{job_id}/cancel")
        return self._parse_job(response)

    def wait_for_job(
        self,
        job_id: str,
        poll_interval: int = 30,
        timeout: Optional[int] = None,
        callback: Optional[Callable[["FineTuneJob"], Any]] = None
    ) -> "FineTuneJob":
        """
        Wait for a job to reach a final state.

        Args:
            job_id: Job ID
            poll_interval: Seconds between status checks
            timeout: Maximum wait time in seconds (None for no limit)
            callback: Optional callback function(job) called on each poll

        Returns:
            The job once its status is completed, failed or cancelled

        Raises:
            TimeoutError: If the job is still running when ``timeout`` elapses
        """
        finished_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
        # Monotonic clock: wall-clock adjustments must not distort the timeout.
        started = time.monotonic()

        while True:
            job = self.get_finetune_job(job_id)

            if callback:
                callback(job)

            if job.status in finished_states:
                return job

            pause = poll_interval
            if timeout:
                remaining = timeout - (time.monotonic() - started)
                if remaining <= 0:
                    raise TimeoutError(f"Job {job_id} did not complete within {timeout}s")
                # Never sleep past the deadline; one last poll happens at it.
                pause = min(poll_interval, remaining)

            logger.info("Job %s status: %s", job_id, job.status.value)
            time.sleep(pause)

    def _parse_job(self, data: Dict) -> "FineTuneJob":
        """
        Parse job response.

        ``completed_at`` is read from the ``finished_at`` key, ``result_url``
        is the first entry of ``result_files``, and ``error`` accepts either
        an object with a ``message`` key or a plain value.

        Raises:
            KeyError: If ``id`` is missing
            ValueError: If ``status`` is not a known JobStatus value
        """
        raw_error = data.get("error")
        if isinstance(raw_error, dict):
            error_message = raw_error.get("message")
        elif raw_error:
            error_message = str(raw_error)
        else:
            error_message = None

        result_files = data.get("result_files")

        return FineTuneJob(
            id=data["id"],
            status=JobStatus(data.get("status", "pending")),
            model=data.get("model", ""),
            created_at=data.get("created_at", ""),
            updated_at=data.get("updated_at"),
            completed_at=data.get("finished_at"),
            error=error_message,
            result_url=result_files[0] if result_files else None,
            metrics=data.get("metrics")
        )

    # ==================== Files ====================

    def _upload_file(self, file_path: str, purpose: str = "fine-tune") -> str:
        """
        Upload a file and return file ID.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        # The handle must stay open until the multipart request has been sent.
        with open(path, "rb") as f:
            response = self._request(
                "POST",
                "/files",
                data={"purpose": purpose},
                files={"file": (path.name, f)}
            )

        return response["id"]

    # ==================== Generation ====================

    def generate(
        self,
        prompt: str,
        model: str = "llama-7b",
        max_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None
    ) -> str:
        """
        Generate text completion.

        Args:
            prompt: Input prompt
            model: Model to use
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            stop: Stop sequences

        Returns:
            Generated text (the ``text`` of the first returned choice)
        """
        data = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p
        }
        if stop:
            data["stop"] = stop

        response = self._request("POST", "/completions", data)
        return response["choices"][0]["text"]

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: str = "llama-7b-chat",
        max_tokens: int = 256,
        temperature: float = 0.7
    ) -> str:
        """
        Chat completion.

        Args:
            messages: List of {"role": "user/assistant", "content": "..."}
            model: Model to use
            max_tokens: Maximum tokens
            temperature: Sampling temperature

        Returns:
            Assistant response (message content of the first returned choice)
        """
        data = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature
        }

        response = self._request("POST", "/chat/completions", data)
        return response["choices"][0]["message"]["content"]

    # ==================== Models ====================

    def list_models(self) -> List["Model"]:
        """List available models."""
        response = self._request("GET", "/models")
        return [self._parse_model(m) for m in response.get("data", [])]

    def get_model(self, model_id: str) -> "Model":
        """Get model details."""
        response = self._request("GET", f"/models/{model_id}")
        return self._parse_model(response)

    def _parse_model(self, data: Dict) -> "Model":
        """Parse model response; ``name`` falls back to the model ID."""
        return Model(
            id=data["id"],
            name=data.get("name", data["id"]),
            description=data.get("description", ""),
            parameters=data.get("parameters", 0),
            context_length=data.get("context_length", 4096),
            supports_finetuning=data.get("supports_finetuning", False)
        )

    # ==================== Agents ====================

    def create_agent(
        self,
        workspace_id: str,
        name: str,
        model_id: Optional[str] = None,
        description: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ) -> "Agent":
        """
        Create a new agent.

        Args:
            workspace_id: Workspace ID
            name: Agent name
            model_id: Optional model ID to use
            description: Agent description (sent as "" when omitted)
            config: Agent configuration (system_prompt, temperature, tools, etc.);
                a default system prompt with temperature 0.7 is sent when omitted

        Returns:
            Agent object
        """
        data = {
            "name": name,
            "description": description or "",
            "model_id": model_id,
            "config": config or {"system_prompt": "You are a helpful assistant.", "temperature": 0.7}
        }

        response = self._request("POST", f"/workspaces/{workspace_id}/agents", data)
        return self._parse_agent(response)

    def list_agents(self, workspace_id: str) -> List["Agent"]:
        """List agents in a workspace."""
        response = self._request("GET", f"/workspaces/{workspace_id}/agents")
        return [self._parse_agent(a) for a in response.get("data", [])]

    def get_agent(self, agent_id: str) -> "Agent":
        """Get agent details."""
        response = self._request("GET", f"/agents/{agent_id}")
        return self._parse_agent(response)

    def update_agent(
        self,
        agent_id: str,
        name: Optional[str] = None,
        description: Optional[str] = None,
        model_id: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ) -> "Agent":
        """Update an agent; only arguments that are not None are sent."""
        candidates = {
            "name": name,
            "description": description,
            "model_id": model_id,
            "config": config
        }
        data = {key: value for key, value in candidates.items() if value is not None}

        response = self._request("PATCH", f"/agents/{agent_id}", data)
        return self._parse_agent(response)

    def delete_agent(self, agent_id: str) -> bool:
        """Delete an agent (soft delete); returns the response's ``success`` flag."""
        response = self._request("DELETE", f"/agents/{agent_id}")
        return response.get("success", False)

    def run_agent(
        self,
        agent_id: str,
        messages: List[Dict[str, str]],
        params: Optional[Dict[str, Any]] = None
    ) -> "AgentRun":
        """
        Execute an agent run.

        Args:
            agent_id: Agent ID
            messages: List of {"role": "user/assistant", "content": "..."}
            params: Optional additional parameters, merged into the run input
                next to ``messages`` (a ``messages`` key here overrides it)

        Returns:
            AgentRun with output
        """
        data = {
            "input": {
                "messages": messages,
                **(params or {})
            }
        }

        response = self._request("POST", f"/agents/{agent_id}/runs", data)
        return self._parse_agent_run(response)

    def list_agent_runs(
        self,
        agent_id: str,
        limit: int = 50,
        offset: int = 0
    ) -> List["AgentRun"]:
        """List runs for an agent."""
        endpoint = self._with_query(f"/agents/{agent_id}/runs", limit=limit, offset=offset)
        response = self._request("GET", endpoint)
        return [self._parse_agent_run(r) for r in response.get("data", [])]

    def _parse_agent(self, data: Dict) -> "Agent":
        """Parse agent response."""
        return Agent(
            id=data["id"],
            name=data.get("name", ""),
            workspace_id=data.get("workspace_id", ""),
            description=data.get("description"),
            model_id=data.get("model_id"),
            config=data.get("config"),
            is_active=data.get("is_active", True),
            created_at=data.get("created_at")
        )

    def _parse_agent_run(self, data: Dict) -> "AgentRun":
        """Parse agent run response."""
        return AgentRun(
            id=data["id"],
            status=data.get("status", "unknown"),
            agent_id=data.get("agent_id", ""),
            input=data.get("input", {}),
            output=data.get("output"),
            token_usage=data.get("token_usage"),
            latency_ms=data.get("latency_ms"),
            created_at=data.get("created_at"),
            finished_at=data.get("finished_at"),
            error=data.get("error")
        )

    # ==================== Billing & Usage ====================

    def get_usage(self, workspace_id: str) -> "UsageRecord":
        """
        Get current usage for a workspace.

        Args:
            workspace_id: Workspace ID

        Returns:
            UsageRecord with current usage and limits; counters missing from
            the response default to 0
        """
        response = self._request("GET", self._with_query("/billing/usage", workspace_id=workspace_id))

        # "or {}" also covers sections the server sends as an explicit null.
        tokens = response.get("tokens") or {}
        finetune_jobs = response.get("finetune_jobs") or {}
        agent_runs = response.get("agent_runs") or {}
        period = response.get("period") or {}

        return UsageRecord(
            tokens_used=tokens.get("used", 0),
            tokens_limit=tokens.get("limit", 0),
            finetune_jobs_used=finetune_jobs.get("used", 0),
            finetune_jobs_limit=finetune_jobs.get("limit", 0),
            agent_runs_used=agent_runs.get("used", 0),
            agent_runs_limit=agent_runs.get("limit", 0),
            period_start=period.get("start", ""),
            period_end=period.get("end")
        )

    def get_plan(self, workspace_id: str) -> "Plan":
        """
        Get current plan for a workspace.

        Args:
            workspace_id: Workspace ID

        Returns:
            Plan with limits; missing plan fields default to the "Free" plan
            with a "lifetime" billing period
        """
        response = self._request("GET", self._with_query("/billing/plan", workspace_id=workspace_id))
        plan_data = response.get("plan") or {}
        return Plan(
            id=plan_data.get("id"),
            name=plan_data.get("name", "Free"),
            code=plan_data.get("code", "free"),
            billing_period=plan_data.get("billing_period", "lifetime"),
            limits=response.get("limits", {})
        )

    # ==================== Workspace Finetune Jobs ====================

    def create_workspace_finetune_job(
        self,
        workspace_id: str,
        base_model: str,
        dataset_id: str,
        name: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ) -> "FineTuneJob":
        """
        Create a finetune job in a workspace.

        Args:
            workspace_id: Workspace ID
            base_model: Base model name (e.g., "Llama-3-8B")
            dataset_id: Dataset ID
            name: Optional name for the finetuned model
            config: Training configuration (epochs, lr, batch_size, etc.)

        Returns:
            FineTuneJob object
        """
        data = {
            "base_model": base_model,
            "dataset_id": dataset_id,
            "name": name,
            "config": config or {}
        }

        response = self._request("POST", f"/workspaces/{workspace_id}/finetune-jobs", data)
        return self._parse_job(response)

    def list_workspace_finetune_jobs(self, workspace_id: str) -> List["FineTuneJob"]:
        """List finetune jobs in a workspace."""
        response = self._request("GET", f"/workspaces/{workspace_id}/finetune-jobs")
        return [self._parse_job(j) for j in response.get("data", [])]
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
class APIError(Exception):
    """Raised by the client when the API answers with an HTTP error status."""
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
# Convenience function
|
|
719
|
+
def get_client(api_key: Optional[str] = None) -> LangtuneClient:
    """Build a ``LangtuneClient`` with default settings.

    Args:
        api_key: API key forwarded to the client; when None the client
            applies its own fallback.

    Returns:
        A new LangtuneClient instance.
    """
    client = LangtuneClient(api_key=api_key)
    return client
|