langtune 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langtune/api.py ADDED
@@ -0,0 +1,320 @@
+ """
+ api.py: High-level API for Langtune
+
+ Provides simple, user-friendly functions that work both locally and via server.
+ """
+
+ import os
+ from typing import Optional, Dict, Any, List, Union
+ from pathlib import Path
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ def finetune(
+     training_data: Union[str, List[str]],
+     model: str = "llama-7b",
+     validation_data: Optional[Union[str, List[str]]] = None,
+     use_server: bool = True,
+     # Hyperparameters
+     epochs: int = 3,
+     batch_size: int = 4,
+     learning_rate: float = 2e-4,
+     lora_rank: int = 16,
+     # Local options
+     preset: str = "small",
+     output_dir: str = "./output",
+     # Server options
+     wait: bool = True,
+     **kwargs
+ ) -> Union[str, Any]:
+     """
+     Fine-tune a language model.
+
+     By default, runs on the Langtune server. Set use_server=False for local training.
+
+     Args:
+         training_data: Path to training data or list of texts
+         model: Model to fine-tune (server: 'llama-7b', local: preset name)
+         validation_data: Optional validation data
+         use_server: Use server for training (default: True)
+         epochs: Number of training epochs
+         batch_size: Batch size
+         learning_rate: Learning rate
+         lora_rank: LoRA rank
+         preset: Local model preset (tiny/small/base/large)
+         output_dir: Output directory for local training
+         wait: Wait for the server job to complete
+         **kwargs: Additional arguments
+
+     Returns:
+         Server: fine-tuned model ID (wait=True) or Job object (wait=False)
+         Local: trained model
+
+     Examples:
+         # Server-side fine-tuning
+         >>> from langtune import finetune
+         >>> model_id = finetune("data.jsonl", model="llama-7b")
+
+         # Local fine-tuning
+         >>> model = finetune("data.txt", use_server=False, preset="small")
+     """
+     if use_server:
+         return _finetune_server(
+             training_data=training_data,
+             model=model,
+             validation_data=validation_data,
+             epochs=epochs,
+             batch_size=batch_size,
+             learning_rate=learning_rate,
+             lora_rank=lora_rank,
+             wait=wait,
+             **kwargs
+         )
+     else:
+         return _finetune_local(
+             training_data=training_data,
+             validation_data=validation_data,
+             preset=preset,
+             epochs=epochs,
+             batch_size=batch_size,
+             learning_rate=learning_rate,
+             lora_rank=lora_rank,
+             output_dir=output_dir,
+             **kwargs
+         )
+
+
+ def _finetune_server(
+     training_data: str,
+     model: str,
+     validation_data: Optional[str],
+     epochs: int,
+     batch_size: int,
+     learning_rate: float,
+     lora_rank: int,
+     wait: bool,
+     **kwargs
+ ):
+     """Server-side fine-tuning."""
+     from .client import LangtuneClient
+
+     client = LangtuneClient()
+
+     hyperparameters = {
+         "n_epochs": epochs,
+         "batch_size": batch_size,
+         "learning_rate_multiplier": learning_rate,
+         "lora_rank": lora_rank
+     }
+
+     job = client.create_finetune_job(
+         training_file=training_data,
+         model=model,
+         validation_file=validation_data,
+         hyperparameters=hyperparameters
+     )
+
+     logger.info(f"Created fine-tuning job: {job.id}")
+
+     if wait:
+         job = client.wait_for_job(job.id)
+         if job.error:
+             raise RuntimeError(f"Fine-tuning failed: {job.error}")
+         logger.info(f"Fine-tuning complete! Model: {job.result_url}")
+         return job.result_url
+
+     return job
+
+
+ def _finetune_local(
+     training_data: Union[str, List[str]],
+     validation_data: Optional[Union[str, List[str]]],
+     preset: str,
+     epochs: int,
+     batch_size: int,
+     learning_rate: float,
+     lora_rank: int,
+     output_dir: str,
+     **kwargs
+ ):
+     """Local fine-tuning."""
+     from .finetune import finetune as local_finetune
+
+     return local_finetune(
+         train_data=training_data,
+         val_data=validation_data,
+         preset=preset,
+         epochs=epochs,
+         batch_size=batch_size,
+         learning_rate=learning_rate,
+         lora_rank=lora_rank,
+         output_dir=output_dir,
+         **kwargs
+     )
+
+
+ def generate(
+     prompt: str,
+     model: str = "llama-7b",
+     max_tokens: int = 256,
+     temperature: float = 0.7,
+     top_p: float = 0.9,
+     use_server: bool = True,
+     **kwargs
+ ) -> str:
+     """
+     Generate text completion.
+
+     Args:
+         prompt: Input prompt
+         model: Model to use
+         max_tokens: Maximum tokens to generate
+         temperature: Sampling temperature
+         top_p: Nucleus sampling parameter
+         use_server: Use server for generation
+         **kwargs: Additional arguments
+
+     Returns:
+         Generated text
+
+     Example:
+         >>> from langtune import generate
+         >>> text = generate("Once upon a time", model="llama-7b")
+     """
+     if use_server:
+         from .client import LangtuneClient
+         client = LangtuneClient()
+         return client.generate(
+             prompt=prompt,
+             model=model,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             **kwargs
+         )
+     else:
+         raise NotImplementedError("Local generation requires a loaded model. Use TextGenerator.")
+
+
+ def chat(
+     messages: List[Dict[str, str]],
+     model: str = "llama-7b-chat",
+     max_tokens: int = 256,
+     temperature: float = 0.7,
+     **kwargs
+ ) -> str:
+     """
+     Chat completion.
+
+     Args:
+         messages: List of message dicts: {"role": "user" or "assistant", "content": "..."}
+         model: Model to use
+         max_tokens: Maximum tokens
+         temperature: Sampling temperature
+
+     Returns:
+         Assistant response
+
+     Example:
+         >>> from langtune import chat
+         >>> response = chat([{"role": "user", "content": "Hello!"}])
+     """
+     from .client import LangtuneClient
+     client = LangtuneClient()
+     return client.chat(
+         messages=messages,
+         model=model,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         **kwargs
+     )
+
+
+ def list_models() -> List[Dict[str, Any]]:
+     """
+     List available models.
+
+     Returns:
+         List of model info dicts
+     """
+     from .client import LangtuneClient
+     client = LangtuneClient()
+     models = client.list_models()
+     return [
+         {
+             "id": m.id,
+             "name": m.name,
+             "description": m.description,
+             "parameters": m.parameters,
+             "supports_finetuning": m.supports_finetuning
+         }
+         for m in models
+     ]
+
+
+ def list_jobs(limit: int = 10) -> List[Dict[str, Any]]:
+     """
+     List fine-tuning jobs.
+
+     Args:
+         limit: Maximum number of jobs to return
+
+     Returns:
+         List of job info dicts
+     """
+     from .client import LangtuneClient
+     client = LangtuneClient()
+     jobs = client.list_finetune_jobs(limit=limit)
+     return [
+         {
+             "id": j.id,
+             "status": j.status.value,
+             "model": j.model,
+             "created_at": j.created_at,
+             "error": j.error
+         }
+         for j in jobs
+     ]
+
+
+ def get_job(job_id: str) -> Dict[str, Any]:
+     """
+     Get fine-tuning job status.
+
+     Args:
+         job_id: Job ID
+
+     Returns:
+         Job info dict
+     """
+     from .client import LangtuneClient
+     client = LangtuneClient()
+     job = client.get_finetune_job(job_id)
+     return {
+         "id": job.id,
+         "status": job.status.value,
+         "model": job.model,
+         "created_at": job.created_at,
+         "completed_at": job.completed_at,
+         "error": job.error,
+         "result_url": job.result_url,
+         "metrics": job.metrics
+     }
+
+
+ def cancel_job(job_id: str) -> Dict[str, Any]:
+     """
+     Cancel a fine-tuning job.
+
+     Args:
+         job_id: Job ID
+
+     Returns:
+         Updated job info
+     """
+     from .client import LangtuneClient
+     client = LangtuneClient()
+     job = client.cancel_finetune_job(job_id)
+     return {"id": job.id, "status": job.status.value}