together 1.4.0__py3-none-any.whl → 1.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- together/abstract/api_requestor.py +7 -9
- together/cli/api/endpoints.py +415 -0
- together/cli/api/finetune.py +67 -5
- together/cli/cli.py +2 -0
- together/client.py +1 -0
- together/constants.py +6 -0
- together/error.py +3 -0
- together/legacy/finetune.py +1 -1
- together/resources/__init__.py +4 -1
- together/resources/endpoints.py +488 -0
- together/resources/finetune.py +173 -15
- together/types/__init__.py +25 -20
- together/types/chat_completions.py +6 -0
- together/types/endpoints.py +123 -0
- together/types/finetune.py +45 -0
- together/utils/__init__.py +4 -0
- together/utils/files.py +139 -66
- together/utils/tools.py +53 -2
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/METADATA +93 -23
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/RECORD +23 -20
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/WHEEL +1 -1
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/LICENSE +0 -0
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/entry_points.txt +0 -0
|
@@ -437,7 +437,7 @@ class APIRequestor:
|
|
|
437
437
|
[(k, v) for k, v in options.params.items() if v is not None]
|
|
438
438
|
)
|
|
439
439
|
abs_url = _build_api_url(abs_url, encoded_params)
|
|
440
|
-
elif options.method.lower() in {"post", "put"}:
|
|
440
|
+
elif options.method.lower() in {"post", "put", "patch"}:
|
|
441
441
|
if options.params and (options.files or options.override_headers):
|
|
442
442
|
data = options.params
|
|
443
443
|
elif options.params and not options.files:
|
|
@@ -587,16 +587,14 @@ class APIRequestor:
|
|
|
587
587
|
)
|
|
588
588
|
headers["Content-Type"] = content_type
|
|
589
589
|
|
|
590
|
-
request_kwargs = {
|
|
591
|
-
"headers": headers,
|
|
592
|
-
"data": data,
|
|
593
|
-
"timeout": timeout,
|
|
594
|
-
"allow_redirects": options.allow_redirects,
|
|
595
|
-
}
|
|
596
|
-
|
|
597
590
|
try:
|
|
598
591
|
result = await session.request(
|
|
599
|
-
method=options.method,
|
|
592
|
+
method=options.method,
|
|
593
|
+
url=abs_url,
|
|
594
|
+
headers=headers,
|
|
595
|
+
data=data,
|
|
596
|
+
timeout=timeout,
|
|
597
|
+
allow_redirects=options.allow_redirects,
|
|
600
598
|
)
|
|
601
599
|
utils.log_debug(
|
|
602
600
|
"Together API response",
|
|
@@ -0,0 +1,415 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import sys
|
|
5
|
+
from functools import wraps
|
|
6
|
+
from typing import Any, Callable, Dict, List, Literal, TypeVar, Union
|
|
7
|
+
|
|
8
|
+
import click
|
|
9
|
+
|
|
10
|
+
from together import Together
|
|
11
|
+
from together.error import InvalidRequestError
|
|
12
|
+
from together.types import DedicatedEndpoint, ListEndpoint
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def print_endpoint(
    endpoint: Union[DedicatedEndpoint, ListEndpoint],
) -> None:
    """Write a tab-aligned, human-readable summary of an endpoint to stdout.

    Every endpoint shows its ID, name, model, type, owner, state and creation
    time. A ``DedicatedEndpoint`` additionally shows its display name,
    hardware and autoscaling bounds, inserted right after the name.
    """
    # Identity rows shared by both endpoint models.
    for row in (f"ID:\t\t{endpoint.id}", f"Name:\t\t{endpoint.name}"):
        click.echo(row)

    # Only the dedicated model carries display name / hardware / autoscaling.
    if isinstance(endpoint, DedicatedEndpoint):
        scaling = endpoint.autoscaling
        click.echo(f"Display Name:\t{endpoint.display_name}")
        click.echo(f"Hardware:\t{endpoint.hardware}")
        click.echo(
            f"Autoscaling:\tMin={scaling.min_replicas}, "
            f"Max={scaling.max_replicas}"
        )

    # Remaining rows shared by both endpoint models.
    for row in (
        f"Model:\t\t{endpoint.model}",
        f"Type:\t\t{endpoint.type}",
        f"Owner:\t\t{endpoint.owner}",
        f"State:\t\t{endpoint.state}",
        f"Created:\t{endpoint.created_at}",
    ):
        click.echo(row)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Type variable bound to any callable; `handle_api_errors` is annotated as
# `(F) -> F` so the decorated command keeps its declared type.
F = TypeVar("F", bound=Callable[..., Any])
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def print_api_error(
    e: InvalidRequestError,
) -> None:
    """Report an API request error on stderr in a user-friendly form.

    Credential/authentication failures are reported with a fixed, friendlier
    message; any other error echoes the server-provided message.

    Args:
        e: The error raised by the client. ``api_response`` is only set on
            the exception when it was constructed from a
            ``TogetherErrorResponse``; otherwise the exception's own string
            form is used as the message.
    """
    # Accessing `e.api_response` unconditionally raises AttributeError when the
    # exception was not built from a TogetherErrorResponse. Because this runs
    # inside the caller's `except` handler, that would crash the CLI with a
    # traceback instead of a clean error message.
    api_response = getattr(e, "api_response", None)
    error_details = getattr(api_response, "message", None) or str(e)

    lowered = error_details.lower()
    if "credentials" in lowered or "authentication" in lowered:
        click.echo("Error: Invalid API key or authentication failed", err=True)
    else:
        click.echo(f"Error: {error_details}", err=True)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def handle_api_errors(f: F) -> F:
    """Wrap a CLI command so API and unexpected errors exit cleanly.

    An ``InvalidRequestError`` is reported through ``print_api_error``; any
    other ``Exception`` gets a generic message. Both paths terminate the
    process with exit status 1 rather than showing a traceback. On success
    the wrapped command's return value is passed through unchanged.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            outcome = f(*args, **kwargs)
        except InvalidRequestError as api_error:
            print_api_error(api_error)
        except Exception as unexpected:
            click.echo(
                f"Error: An unexpected error occurred - {unexpected}", err=True
            )
        else:
            return outcome
        # Reached only after one of the error handlers above ran.
        sys.exit(1)

    return wrapper  # type: ignore
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Root `endpoints` command group; every subcommand in this module attaches to
# it via `@endpoints.command()`. The docstring doubles as the group's --help
# text. The body is intentionally empty.
@click.group()
@click.pass_context
def endpoints(ctx: click.Context) -> None:
    """Endpoints API commands"""
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@endpoints.command()
@click.option(
    "--model",
    required=True,
    help="The model to deploy (e.g. mistralai/Mixtral-8x7B-Instruct-v0.1)",
)
@click.option(
    "--min-replicas",
    type=int,
    default=1,
    help="Minimum number of replicas to deploy",
)
@click.option(
    "--max-replicas",
    type=int,
    default=1,
    help="Maximum number of replicas to deploy",
)
@click.option(
    "--gpu",
    type=click.Choice(["h100", "a100", "l40", "l40s", "rtx-6000"]),
    required=True,
    help="GPU type to use for inference",
)
@click.option(
    "--gpu-count",
    type=int,
    default=1,
    help="Number of GPUs to use per replica",
)
@click.option(
    "--display-name",
    help="A human-readable name for the endpoint",
)
@click.option(
    "--no-prompt-cache",
    is_flag=True,
    help="Disable the prompt cache for this endpoint",
)
@click.option(
    "--no-speculative-decoding",
    is_flag=True,
    help="Disable speculative decoding for this endpoint",
)
@click.option(
    "--no-auto-start",
    is_flag=True,
    help="Create the endpoint in STOPPED state instead of auto-starting it",
)
@click.option(
    # A bare `is_flag=True, default=True` option can never be switched off.
    # The secondary `--no-wait` name makes the default overridable while
    # `--wait` keeps working as before.
    "--wait/--no-wait",
    default=True,
    help="Wait for the endpoint to be ready after creation",
)
@click.pass_obj
@handle_api_errors
def create(
    client: Together,
    model: str,
    min_replicas: int,
    max_replicas: int,
    gpu: str,
    gpu_count: int,
    display_name: str | None,
    no_prompt_cache: bool,
    no_speculative_decoding: bool,
    no_auto_start: bool,
    wait: bool,
) -> None:
    """Create a new dedicated inference endpoint."""
    # Reject an inverted autoscaling range before making any API call.
    if min_replicas > max_replicas:
        click.echo(
            "Error: --min-replicas cannot be greater than --max-replicas",
            err=True,
        )
        sys.exit(1)

    # Map GPU types to their full hardware ID names
    gpu_map = {
        "h100": "nvidia_h100_80gb_sxm",
        "a100": "nvidia_a100_80gb_pcie" if gpu_count == 1 else "nvidia_a100_80gb_sxm",
        "l40": "nvidia_l40",
        "l40s": "nvidia_l40s",
        "rtx-6000": "nvidia_rtx_6000_ada",
    }

    hardware_id = f"{gpu_count}x_{gpu_map[gpu]}"

    try:
        response = client.endpoints.create(
            model=model,
            hardware=hardware_id,
            min_replicas=min_replicas,
            max_replicas=max_replicas,
            display_name=display_name,
            disable_prompt_cache=no_prompt_cache,
            disable_speculative_decoding=no_speculative_decoding,
            state="STOPPED" if no_auto_start else "STARTED",
        )
    except InvalidRequestError as e:
        print_api_error(e)
        # When the server's error points at the hardware API, list what is
        # actually available for this model to help the user pick.
        if "check the hardware api" in str(e).lower():
            fetch_and_print_hardware_options(
                client=client, model=model, print_json=False, available=True
            )

        sys.exit(1)

    # Print detailed information to stderr
    click.echo("Created dedicated endpoint with:", err=True)
    click.echo(f" Model: {model}", err=True)
    click.echo(f" Min replicas: {min_replicas}", err=True)
    click.echo(f" Max replicas: {max_replicas}", err=True)
    click.echo(f" Hardware: {hardware_id}", err=True)
    if display_name:
        click.echo(f" Display name: {display_name}", err=True)
    if no_prompt_cache:
        click.echo(" Prompt cache: disabled", err=True)
    if no_speculative_decoding:
        click.echo(" Speculative decoding: disabled", err=True)
    if no_auto_start:
        click.echo(" Auto-start: disabled", err=True)

    click.echo(f"Endpoint created successfully, id: {response.id}", err=True)

    # An endpoint created with --no-auto-start is requested in STOPPED state,
    # so polling until it reports STARTED would never terminate.
    if wait and not no_auto_start:
        import time

        click.echo("Waiting for endpoint to be ready...", err=True)
        while client.endpoints.get(response.id).state != "STARTED":
            time.sleep(1)
        click.echo("Endpoint ready", err=True)

    # Print only the endpoint ID to stdout
    click.echo(response.id)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
@endpoints.command()
@click.argument("endpoint-id", required=True)
@click.option("--json", is_flag=True, help="Print output in JSON format")
@click.pass_obj
@handle_api_errors
def get(client: Together, endpoint_id: str, json: bool) -> None:
    """Get a dedicated inference endpoint."""
    endpoint = client.endpoints.get(endpoint_id)

    # Default path: the human-readable summary.
    if not json:
        print_endpoint(endpoint)
        return

    # The `json` flag parameter shadows the json module inside this function,
    # so the module is imported under an alias.
    import json as json_lib

    click.echo(json_lib.dumps(endpoint.model_dump(), indent=2))
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@endpoints.command()
@click.option("--model", help="Filter hardware options by model")
@click.option("--json", is_flag=True, help="Print output in JSON format")
@click.option(
    "--available",
    is_flag=True,
    help="Print only available hardware options (can only be used if model is passed in)",
)
@click.pass_obj
@handle_api_errors
def hardware(client: Together, model: str | None, json: bool, available: bool) -> None:
    """List all available hardware options, optionally filtered by model."""
    # The help text states --available requires --model. Enforce that instead
    # of silently printing a filtered list that is presumably empty without a
    # model (TODO confirm what availability the API returns when no model is given).
    if available and not model:
        click.echo(
            "Error: --available can only be used together with --model", err=True
        )
        sys.exit(1)

    fetch_and_print_hardware_options(client, model, json, available)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def fetch_and_print_hardware_options(
    client: Together, model: str | None, print_json: bool, available: bool
) -> None:
    """Fetch hardware options (optionally scoped to a model) and print them.

    A header line is always written to stderr first.

    Args:
        client: API client; ``client.endpoints.list_hardware(model)`` is called.
        model: Model to scope the hardware listing to, or None.
        print_json: If true, write the options as an indented JSON array to
            stdout; otherwise write each option's ID to stderr.
        available: Keep only options whose ``availability.status`` is
            ``"available"`` (options with no availability info are dropped).
    """
    header = "Available hardware options:" if available else "All hardware options:"
    click.echo(header, err=True)

    options = client.endpoints.list_hardware(model)

    if available:
        options = [
            option
            for option in options
            if option.availability is not None
            and option.availability.status == "available"
        ]

    if print_json:
        payload = [option.model_dump() for option in options]
        click.echo(json.dumps(payload, indent=2))
        return

    for option in options:
        click.echo(f" {option.id}", err=True)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
@endpoints.command()
@click.argument("endpoint-id", required=True)
@click.option(
    # `--no-wait` lets callers opt out; a bare `is_flag=True, default=True`
    # option could never be disabled from the command line.
    "--wait/--no-wait",
    default=True,
    help="Wait for the endpoint to stop",
)
@click.pass_obj
@handle_api_errors
def stop(client: Together, endpoint_id: str, wait: bool) -> None:
    """Stop a dedicated inference endpoint."""
    client.endpoints.update(endpoint_id, state="STOPPED")
    click.echo("Successfully marked endpoint as stopping", err=True)

    if wait:
        import time

        click.echo("Waiting for endpoint to stop...", err=True)
        # Poll once per second until the server reports the requested state.
        while client.endpoints.get(endpoint_id).state != "STOPPED":
            time.sleep(1)
        click.echo("Endpoint stopped", err=True)

    # Only the ID goes to stdout so the command composes in shell pipelines.
    click.echo(endpoint_id)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
@endpoints.command()
@click.argument("endpoint-id", required=True)
@click.option(
    # `--no-wait` lets callers opt out; a bare `is_flag=True, default=True`
    # option could never be disabled from the command line.
    "--wait/--no-wait",
    default=True,
    help="Wait for the endpoint to start",
)
@click.pass_obj
@handle_api_errors
def start(client: Together, endpoint_id: str, wait: bool) -> None:
    """Start a dedicated inference endpoint."""
    client.endpoints.update(endpoint_id, state="STARTED")
    click.echo("Successfully marked endpoint as starting", err=True)

    if wait:
        import time

        click.echo("Waiting for endpoint to start...", err=True)
        # Poll once per second until the server reports the requested state.
        while client.endpoints.get(endpoint_id).state != "STARTED":
            time.sleep(1)
        click.echo("Endpoint started", err=True)

    # Only the ID goes to stdout so the command composes in shell pipelines.
    click.echo(endpoint_id)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
@endpoints.command()
@click.argument("endpoint-id", required=True)
@click.pass_obj
@handle_api_errors
def delete(client: Together, endpoint_id: str) -> None:
    """Delete a dedicated inference endpoint."""
    client.endpoints.delete(endpoint_id)

    # Status message on stderr, bare ID on stdout (pipe-friendly).
    status_line = "Successfully deleted endpoint"
    click.echo(status_line, err=True)
    click.echo(endpoint_id)
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
@endpoints.command()
@click.option("--json", is_flag=True, help="Print output in JSON format")
@click.option(
    "--type",
    type=click.Choice(["dedicated", "serverless"]),
    help="Filter by endpoint type",
)
@click.pass_obj
@handle_api_errors
def list(
    client: Together, json: bool, type: Literal["dedicated", "serverless"] | None
) -> None:
    """List all inference endpoints (includes both dedicated and serverless endpoints)."""
    # Named `endpoint_list` so it does not shadow the module-level `endpoints`
    # command group.
    endpoint_list: List[ListEndpoint] = client.endpoints.list(type=type)

    if not endpoint_list:
        # Mention the endpoint type only when the user actually filtered by
        # one. Previously this always said "dedicated", even for unfiltered or
        # serverless listings.
        label = f"{type} endpoints" if type else "endpoints"
        click.echo(f"No {label} found", err=True)
        return

    click.echo("Endpoints:", err=True)
    if json:
        # The `json` flag shadows the json module here, hence the alias.
        import json as json_lib

        click.echo(
            json_lib.dumps(
                [endpoint.model_dump() for endpoint in endpoint_list], indent=2
            )
        )
    else:
        for endpoint in endpoint_list:
            print_endpoint(
                endpoint,
            )
            click.echo()
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
@endpoints.command()
@click.argument("endpoint-id", required=True)
@click.option(
    "--display-name",
    help="A new human-readable name for the endpoint",
)
@click.option(
    "--min-replicas",
    type=int,
    help="New minimum number of replicas to maintain",
)
@click.option(
    "--max-replicas",
    type=int,
    help="New maximum number of replicas to scale up to",
)
@click.pass_obj
@handle_api_errors
def update(
    client: Together,
    endpoint_id: str,
    display_name: str | None,
    min_replicas: int | None,
    max_replicas: int | None,
) -> None:
    """Update a dedicated inference endpoint's configuration."""
    # Compare against None explicitly. `--min-replicas 0 --max-replicas 0` is a
    # legitimate request that a truthiness check (`any([...])`) would wrongly
    # reject as "no option specified".
    if display_name is None and min_replicas is None and max_replicas is None:
        click.echo("Error: At least one update option must be specified", err=True)
        sys.exit(1)

    # If only one of min/max replicas is specified, we need both for the update
    if (min_replicas is None) != (max_replicas is None):
        click.echo(
            "Error: Both --min-replicas and --max-replicas must be specified together",
            err=True,
        )
        sys.exit(1)

    # Build kwargs for the update
    kwargs: Dict[str, Any] = {}
    if display_name is not None:
        kwargs["display_name"] = display_name
    if min_replicas is not None and max_replicas is not None:
        # Reject an inverted autoscaling range before calling the API.
        if min_replicas > max_replicas:
            click.echo(
                "Error: --min-replicas cannot be greater than --max-replicas",
                err=True,
            )
            sys.exit(1)
        kwargs["min_replicas"] = min_replicas
        kwargs["max_replicas"] = max_replicas

    client.endpoints.update(endpoint_id, **kwargs)

    # Report exactly what was sent. Use the same `is not None` tests as the
    # kwargs above so the printout matches the request.
    click.echo("Updated endpoint configuration:", err=True)
    if display_name is not None:
        click.echo(f" Display name: {display_name}", err=True)
    if min_replicas is not None and max_replicas is not None:
        click.echo(f" Min replicas: {min_replicas}", err=True)
        click.echo(f" Max replicas: {max_replicas}", err=True)

    click.echo("Successfully updated endpoint", err=True)
    click.echo(endpoint_id)
|
together/cli/api/finetune.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from datetime import datetime
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
5
|
from textwrap import wrap
|
|
6
6
|
from typing import Any, Literal
|
|
7
|
+
import re
|
|
7
8
|
|
|
8
9
|
import click
|
|
9
10
|
from click.core import ParameterSource # type: ignore[attr-defined]
|
|
@@ -17,8 +18,13 @@ from together.utils import (
|
|
|
17
18
|
log_warn,
|
|
18
19
|
log_warn_once,
|
|
19
20
|
parse_timestamp,
|
|
21
|
+
format_timestamp,
|
|
22
|
+
)
|
|
23
|
+
from together.types.finetune import (
|
|
24
|
+
DownloadCheckpointType,
|
|
25
|
+
FinetuneTrainingLimits,
|
|
26
|
+
FinetuneEventType,
|
|
20
27
|
)
|
|
21
|
-
from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
|
|
22
28
|
|
|
23
29
|
|
|
24
30
|
_CONFIRMATION_MESSAGE = (
|
|
@@ -104,6 +110,18 @@ def fine_tuning(ctx: click.Context) -> None:
|
|
|
104
110
|
default="all-linear",
|
|
105
111
|
help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
|
|
106
112
|
)
|
|
113
|
+
@click.option(
|
|
114
|
+
"--training-method",
|
|
115
|
+
type=click.Choice(["sft", "dpo"]),
|
|
116
|
+
default="sft",
|
|
117
|
+
help="Training method to use. Options: sft (supervised fine-tuning), dpo (Direct Preference Optimization)",
|
|
118
|
+
)
|
|
119
|
+
@click.option(
|
|
120
|
+
"--dpo-beta",
|
|
121
|
+
type=float,
|
|
122
|
+
default=0.1,
|
|
123
|
+
help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
|
|
124
|
+
)
|
|
107
125
|
@click.option(
|
|
108
126
|
"--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
|
|
109
127
|
)
|
|
@@ -126,6 +144,14 @@ def fine_tuning(ctx: click.Context) -> None:
|
|
|
126
144
|
help="Whether to mask the user messages in conversational data or prompts in instruction data. "
|
|
127
145
|
"`auto` will automatically determine whether to mask the inputs based on the data format.",
|
|
128
146
|
)
|
|
147
|
+
@click.option(
|
|
148
|
+
"--from-checkpoint",
|
|
149
|
+
type=str,
|
|
150
|
+
default=None,
|
|
151
|
+
help="The checkpoint identifier to continue training from a previous fine-tuning job. "
|
|
152
|
+
"The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}. "
|
|
153
|
+
"The step value is optional, without it the final checkpoint will be used.",
|
|
154
|
+
)
|
|
129
155
|
def create(
|
|
130
156
|
ctx: click.Context,
|
|
131
157
|
training_file: str,
|
|
@@ -152,6 +178,9 @@ def create(
|
|
|
152
178
|
wandb_name: str,
|
|
153
179
|
confirm: bool,
|
|
154
180
|
train_on_inputs: bool | Literal["auto"],
|
|
181
|
+
training_method: str,
|
|
182
|
+
dpo_beta: float,
|
|
183
|
+
from_checkpoint: str,
|
|
155
184
|
) -> None:
|
|
156
185
|
"""Start fine-tuning"""
|
|
157
186
|
client: Together = ctx.obj
|
|
@@ -180,6 +209,9 @@ def create(
|
|
|
180
209
|
wandb_project_name=wandb_project_name,
|
|
181
210
|
wandb_name=wandb_name,
|
|
182
211
|
train_on_inputs=train_on_inputs,
|
|
212
|
+
training_method=training_method,
|
|
213
|
+
dpo_beta=dpo_beta,
|
|
214
|
+
from_checkpoint=from_checkpoint,
|
|
183
215
|
)
|
|
184
216
|
|
|
185
217
|
model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
|
|
@@ -261,7 +293,9 @@ def list(ctx: click.Context) -> None:
|
|
|
261
293
|
|
|
262
294
|
response.data = response.data or []
|
|
263
295
|
|
|
264
|
-
|
|
296
|
+
# Use a default datetime for None values to make sure the key function always returns a comparable value
|
|
297
|
+
epoch_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
|
298
|
+
response.data.sort(key=lambda x: parse_timestamp(x.created_at or "") or epoch_start)
|
|
265
299
|
|
|
266
300
|
display_list = []
|
|
267
301
|
for i in response.data:
|
|
@@ -344,6 +378,34 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
|
|
|
344
378
|
click.echo(table)
|
|
345
379
|
|
|
346
380
|
|
|
381
|
+
@fine_tuning.command()
@click.pass_context
@click.argument("fine_tune_id", type=str, required=True)
def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None:
    """List available checkpoints for a fine-tuning job"""
    client: Together = ctx.obj

    # One table row per checkpoint returned by the API.
    rows = [
        {
            "Type": ckpt.type,
            "Timestamp": format_timestamp(ckpt.timestamp),
            "Name": ckpt.name,
        }
        for ckpt in client.fine_tuning.list_checkpoints(fine_tune_id)
    ]

    if not rows:
        click.echo(f"No checkpoints found for job {fine_tune_id}")
        return

    click.echo(f"Job {fine_tune_id} contains the following checkpoints:")
    click.echo(tabulate(rows, headers="keys", tablefmt="grid"))
    click.echo("\nTo download a checkpoint, use `together fine-tuning download`")
|
|
407
|
+
|
|
408
|
+
|
|
347
409
|
@fine_tuning.command()
|
|
348
410
|
@click.pass_context
|
|
349
411
|
@click.argument("fine_tune_id", type=str, required=True)
|
|
@@ -358,7 +420,7 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
|
|
|
358
420
|
"--checkpoint-step",
|
|
359
421
|
type=int,
|
|
360
422
|
required=False,
|
|
361
|
-
default
|
|
423
|
+
default=None,
|
|
362
424
|
help="Download fine-tuning checkpoint. Defaults to latest.",
|
|
363
425
|
)
|
|
364
426
|
@click.option(
|
|
@@ -372,7 +434,7 @@ def download(
|
|
|
372
434
|
ctx: click.Context,
|
|
373
435
|
fine_tune_id: str,
|
|
374
436
|
output_dir: str,
|
|
375
|
-
checkpoint_step: int,
|
|
437
|
+
checkpoint_step: int | None,
|
|
376
438
|
checkpoint_type: DownloadCheckpointType,
|
|
377
439
|
) -> None:
|
|
378
440
|
"""Download fine-tuning checkpoint"""
|
together/cli/cli.py
CHANGED
|
@@ -8,6 +8,7 @@ import click
|
|
|
8
8
|
import together
|
|
9
9
|
from together.cli.api.chat import chat, interactive
|
|
10
10
|
from together.cli.api.completions import completions
|
|
11
|
+
from together.cli.api.endpoints import endpoints
|
|
11
12
|
from together.cli.api.files import files
|
|
12
13
|
from together.cli.api.finetune import fine_tuning
|
|
13
14
|
from together.cli.api.images import images
|
|
@@ -72,6 +73,7 @@ main.add_command(images)
|
|
|
72
73
|
main.add_command(files)
|
|
73
74
|
main.add_command(fine_tuning)
|
|
74
75
|
main.add_command(models)
|
|
76
|
+
main.add_command(endpoints)
|
|
75
77
|
|
|
76
78
|
if __name__ == "__main__":
|
|
77
79
|
main()
|
together/client.py
CHANGED
together/constants.py
CHANGED
|
@@ -39,12 +39,18 @@ class DatasetFormat(enum.Enum):
|
|
|
39
39
|
GENERAL = "general"
|
|
40
40
|
CONVERSATION = "conversation"
|
|
41
41
|
INSTRUCTION = "instruction"
|
|
42
|
+
PREFERENCE_OPENAI = "preference_openai"
|
|
42
43
|
|
|
43
44
|
|
|
44
45
|
JSONL_REQUIRED_COLUMNS_MAP = {
|
|
45
46
|
DatasetFormat.GENERAL: ["text"],
|
|
46
47
|
DatasetFormat.CONVERSATION: ["messages"],
|
|
47
48
|
DatasetFormat.INSTRUCTION: ["prompt", "completion"],
|
|
49
|
+
DatasetFormat.PREFERENCE_OPENAI: [
|
|
50
|
+
"input",
|
|
51
|
+
"preferred_output",
|
|
52
|
+
"non_preferred_output",
|
|
53
|
+
],
|
|
48
54
|
}
|
|
49
55
|
REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
|
|
50
56
|
POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
|
together/error.py
CHANGED
|
@@ -18,6 +18,9 @@ class TogetherException(Exception):
|
|
|
18
18
|
request_id: str | None = None,
|
|
19
19
|
http_status: int | None = None,
|
|
20
20
|
) -> None:
|
|
21
|
+
if isinstance(message, TogetherErrorResponse):
|
|
22
|
+
self.api_response = message
|
|
23
|
+
|
|
21
24
|
_message = (
|
|
22
25
|
json.dumps(message.model_dump(exclude_none=True))
|
|
23
26
|
if isinstance(message, TogetherErrorResponse)
|
together/legacy/finetune.py
CHANGED
together/resources/__init__.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
|
+
from together.resources.audio import AsyncAudio, Audio
|
|
1
2
|
from together.resources.chat import AsyncChat, Chat
|
|
2
3
|
from together.resources.completions import AsyncCompletions, Completions
|
|
3
4
|
from together.resources.embeddings import AsyncEmbeddings, Embeddings
|
|
5
|
+
from together.resources.endpoints import AsyncEndpoints, Endpoints
|
|
4
6
|
from together.resources.files import AsyncFiles, Files
|
|
5
7
|
from together.resources.finetune import AsyncFineTuning, FineTuning
|
|
6
8
|
from together.resources.images import AsyncImages, Images
|
|
7
9
|
from together.resources.models import AsyncModels, Models
|
|
8
10
|
from together.resources.rerank import AsyncRerank, Rerank
|
|
9
|
-
from together.resources.audio import AsyncAudio, Audio
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
__all__ = [
|
|
@@ -28,4 +29,6 @@ __all__ = [
|
|
|
28
29
|
"Rerank",
|
|
29
30
|
"AsyncAudio",
|
|
30
31
|
"Audio",
|
|
32
|
+
"AsyncEndpoints",
|
|
33
|
+
"Endpoints",
|
|
31
34
|
]
|