langtune 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/api.py
ADDED
@@ -0,0 +1,320 @@
+"""
+api.py: High-level API for Langtune
+
+Provides simple, user-friendly functions that work both locally and via server.
+"""
+
+import os
+from typing import Optional, Dict, Any, List, Union
+from pathlib import Path
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def finetune(
+    training_data: Union[str, List[str]],
+    model: str = "llama-7b",
+    validation_data: Optional[Union[str, List[str]]] = None,
+    use_server: bool = True,
+    # Hyperparameters
+    epochs: int = 3,
+    batch_size: int = 4,
+    learning_rate: float = 2e-4,
+    lora_rank: int = 16,
+    # Local options
+    preset: str = "small",
+    output_dir: str = "./output",
+    # Server options
+    wait: bool = True,
+    **kwargs
+) -> Union[str, Any]:
+    """
+    Fine-tune a language model.
+
+    By default, runs on the Langtune server. Set use_server=False for local training.
+
+    Args:
+        training_data: Path to training data or list of texts
+        model: Model to fine-tune (server: 'llama-7b', local: preset name)
+        validation_data: Optional validation data
+        use_server: Use server for training (default: True)
+        epochs: Number of training epochs
+        batch_size: Batch size
+        learning_rate: Learning rate
+        lora_rank: LoRA rank
+        preset: Local model preset (tiny/small/base/large)
+        output_dir: Output directory for local training
+        wait: Wait for server job to complete
+        **kwargs: Additional arguments
+
+    Returns:
+        Server: Fine-tuned model ID or Job object
+        Local: Trained model
+
+    Examples:
+        # Server-side fine-tuning
+        >>> from langtune import finetune
+        >>> model_id = finetune("data.jsonl", model="llama-7b")
+
+        # Local fine-tuning
+        >>> model = finetune("data.txt", use_server=False, preset="small")
+    """
+    if use_server:
+        return _finetune_server(
+            training_data=training_data,
+            model=model,
+            validation_data=validation_data,
+            epochs=epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            lora_rank=lora_rank,
+            wait=wait,
+            **kwargs
+        )
+    else:
+        return _finetune_local(
+            training_data=training_data,
+            validation_data=validation_data,
+            preset=preset,
+            epochs=epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            lora_rank=lora_rank,
+            output_dir=output_dir,
+            **kwargs
+        )
+
+
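The wait flag above controls whether the call blocks: per the Returns section, wait=True resolves to the fine-tuned model ID, while wait=False hands back the Job object immediately. A minimal sketch of the non-blocking form, assuming only the attributes this file itself uses:

    from langtune import finetune

    # Returns a Job object right away instead of blocking until completion.
    job = finetune("data.jsonl", model="llama-7b", wait=False)
    print(job.id)  # the id consumed by get_job/cancel_job later in this file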
+def _finetune_server(
+    training_data: str,
+    model: str,
+    validation_data: Optional[str],
+    epochs: int,
+    batch_size: int,
+    learning_rate: float,
+    lora_rank: int,
+    wait: bool,
+    **kwargs
+):
+    """Server-side fine-tuning."""
+    from .client import LangtuneClient
+
+    client = LangtuneClient()
+
+    hyperparameters = {
+        "n_epochs": epochs,
+        "batch_size": batch_size,
+        "learning_rate_multiplier": learning_rate,
+        "lora_rank": lora_rank
+    }
+
+    job = client.create_finetune_job(
+        training_file=training_data,
+        model=model,
+        validation_file=validation_data,
+        hyperparameters=hyperparameters
+    )
+
+    logger.info(f"Created fine-tuning job: {job.id}")
+
+    if wait:
+        job = client.wait_for_job(job.id)
+        if job.error:
+            raise RuntimeError(f"Fine-tuning failed: {job.error}")
+        logger.info(f"Fine-tuning complete! Model: {job.result_url}")
+        return job.result_url
+
+    return job
+
+
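Note that the public learning_rate argument is forwarded to the server under the key learning_rate_multiplier. With the defaults from finetune above, the submitted payload works out to:

    hyperparameters = {
        "n_epochs": 3,
        "batch_size": 4,
        "learning_rate_multiplier": 2e-4,
        "lora_rank": 16,
    }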
+def _finetune_local(
+    training_data: Union[str, List[str]],
+    validation_data: Optional[Union[str, List[str]]],
+    preset: str,
+    epochs: int,
+    batch_size: int,
+    learning_rate: float,
+    lora_rank: int,
+    output_dir: str,
+    **kwargs
+):
+    """Local fine-tuning."""
+    from .finetune import finetune as local_finetune
+
+    return local_finetune(
+        train_data=training_data,
+        val_data=validation_data,
+        preset=preset,
+        epochs=epochs,
+        batch_size=batch_size,
+        learning_rate=learning_rate,
+        lora_rank=lora_rank,
+        output_dir=output_dir,
+        **kwargs
+    )
+
+
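The local path mostly renames keywords (training_data becomes train_data, validation_data becomes val_data) before delegating to langtune/finetune.py. An equivalent direct call, assuming that module's finetune accepts exactly the keywords forwarded above:

    from langtune.finetune import finetune as local_finetune

    # Hypothetical direct invocation mirroring _finetune_local's forwarding.
    model = local_finetune(
        train_data="data.txt",
        val_data=None,
        preset="small",
        epochs=3,
        batch_size=4,
        learning_rate=2e-4,
        lora_rank=16,
        output_dir="./output",
    )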
+def generate(
+    prompt: str,
+    model: str = "llama-7b",
+    max_tokens: int = 256,
+    temperature: float = 0.7,
+    top_p: float = 0.9,
+    use_server: bool = True,
+    **kwargs
+) -> str:
+    """
+    Generate text completion.
+
+    Args:
+        prompt: Input prompt
+        model: Model to use
+        max_tokens: Maximum tokens to generate
+        temperature: Sampling temperature
+        top_p: Nucleus sampling parameter
+        use_server: Use server for generation
+        **kwargs: Additional arguments
+
+    Returns:
+        Generated text
+
+    Example:
+        >>> from langtune import generate
+        >>> text = generate("Once upon a time", model="llama-7b")
+    """
+    if use_server:
+        from .client import LangtuneClient
+        client = LangtuneClient()
+        return client.generate(
+            prompt=prompt,
+            model=model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            **kwargs
+        )
+    else:
+        raise NotImplementedError("Local generation requires a loaded model. Use TextGenerator.")
+
+
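Since the sampling parameters pass straight through to LangtuneClient.generate, callers can tighten or loosen decoding without touching the client. A quick sketch using only the parameters declared above:

    from langtune import generate

    # Lower temperature and top_p for more deterministic completions.
    text = generate(
        "Summarize the plot of Hamlet in one sentence.",
        model="llama-7b",
        max_tokens=64,
        temperature=0.2,
        top_p=0.8,
    )
    print(text)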
+def chat(
+    messages: List[Dict[str, str]],
+    model: str = "llama-7b-chat",
+    max_tokens: int = 256,
+    temperature: float = 0.7,
+    **kwargs
+) -> str:
+    """
+    Chat completion.
+
+    Args:
+        messages: List of {"role": "user/assistant", "content": "..."}
+        model: Model to use
+        max_tokens: Maximum tokens
+        temperature: Sampling temperature
+
+    Returns:
+        Assistant response
+
+    Example:
+        >>> from langtune import chat
+        >>> response = chat([{"role": "user", "content": "Hello!"}])
+    """
+    from .client import LangtuneClient
+    client = LangtuneClient()
+    return client.chat(
+        messages=messages,
+        model=model,
+        max_tokens=max_tokens,
+        temperature=temperature,
+        **kwargs
+    )
+
+
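Because chat takes the full message list and returns only the assistant's text, multi-turn conversations are built by appending to that list between calls. A minimal sketch following the role convention in the docstring:

    from langtune import chat

    history = [{"role": "user", "content": "Hello!"}]
    reply = chat(history, model="llama-7b-chat")

    # Carry the conversation forward: append the reply, then the follow-up.
    history.append({"role": "assistant", "content": reply})
    history.append({"role": "user", "content": "Can you say that more formally?"})
    reply = chat(history, model="llama-7b-chat")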
+def list_models() -> List[Dict[str, Any]]:
+    """
+    List available models.
+
+    Returns:
+        List of model info dicts
+    """
+    from .client import LangtuneClient
+    client = LangtuneClient()
+    models = client.list_models()
+    return [
+        {
+            "id": m.id,
+            "name": m.name,
+            "description": m.description,
+            "parameters": m.parameters,
+            "supports_finetuning": m.supports_finetuning
+        }
+        for m in models
+    ]
+
+
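Each returned dict carries a supports_finetuning flag, so selecting tunable models is a one-liner:

    from langtune import list_models

    tunable = [m["id"] for m in list_models() if m["supports_finetuning"]]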
+def list_jobs(limit: int = 10) -> List[Dict[str, Any]]:
+    """
+    List fine-tuning jobs.
+
+    Args:
+        limit: Maximum number of jobs to return
+
+    Returns:
+        List of job info dicts
+    """
+    from .client import LangtuneClient
+    client = LangtuneClient()
+    jobs = client.list_finetune_jobs(limit=limit)
+    return [
+        {
+            "id": j.id,
+            "status": j.status.value,
+            "model": j.model,
+            "created_at": j.created_at,
+            "error": j.error
+        }
+        for j in jobs
+    ]
+
+
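A quick way to inspect recent activity with the dicts returned above (the exact status strings come from the client's job-status enum, which is not part of this file):

    from langtune import list_jobs

    for job in list_jobs(limit=5):
        print(job["id"], job["status"], job["error"])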
+def get_job(job_id: str) -> Dict[str, Any]:
+    """
+    Get fine-tuning job status.
+
+    Args:
+        job_id: Job ID
+
+    Returns:
+        Job info dict
+    """
+    from .client import LangtuneClient
+    client = LangtuneClient()
+    job = client.get_finetune_job(job_id)
+    return {
+        "id": job.id,
+        "status": job.status.value,
+        "model": job.model,
+        "created_at": job.created_at,
+        "completed_at": job.completed_at,
+        "error": job.error,
+        "result_url": job.result_url,
+        "metrics": job.metrics
+    }
+
+
+def cancel_job(job_id: str) -> Dict[str, Any]:
+    """
+    Cancel a fine-tuning job.
+
+    Args:
+        job_id: Job ID
+
+    Returns:
+        Updated job info
+    """
+    from .client import LangtuneClient
+    client = LangtuneClient()
+    job = client.cancel_finetune_job(job_id)
+    return {"id": job.id, "status": job.status.value}
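Putting the pieces together, the module supports a submit/poll/cancel loop built entirely from the functions above. A sketch, with the caveat that the terminal status strings ("succeeded", "failed") are assumptions; the real values live in the client's status enum:

    import time
    from langtune import finetune, get_job, cancel_job

    job = finetune("data.jsonl", model="llama-7b", wait=False)

    # Poll every 30 seconds for up to ~30 minutes.
    for _ in range(60):
        info = get_job(job.id)
        if info["status"] in ("succeeded", "failed"):  # assumed status values
            break
        time.sleep(30)
    else:
        cancel_job(job.id)  # give up and cancel the job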