together-2.0.0a9-py3-none-any.whl → together-2.0.0a10-py3-none-any.whl

This diff shows the changes between two publicly available versions of this package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (30)
  1. together/_types.py +3 -2
  2. together/_version.py +1 -1
  3. together/lib/cli/api/fine_tuning.py +65 -3
  4. together/lib/cli/api/models.py +1 -6
  5. together/lib/resources/fine_tuning.py +41 -2
  6. together/resources/chat/completions.py +48 -0
  7. together/resources/fine_tuning.py +213 -5
  8. together/resources/models.py +41 -5
  9. together/types/__init__.py +3 -0
  10. together/types/audio/voice_list_response.py +4 -0
  11. together/types/autoscaling.py +2 -0
  12. together/types/autoscaling_param.py +2 -0
  13. together/types/chat/completion_create_params.py +78 -5
  14. together/types/dedicated_endpoint.py +2 -0
  15. together/types/endpoint_list_avzones_response.py +2 -0
  16. together/types/endpoint_list_response.py +2 -0
  17. together/types/execute_response.py +7 -0
  18. together/types/fine_tuning_cancel_response.py +6 -0
  19. together/types/fine_tuning_estimate_price_params.py +98 -0
  20. together/types/fine_tuning_estimate_price_response.py +24 -0
  21. together/types/fine_tuning_list_response.py +6 -0
  22. together/types/hardware_list_response.py +8 -0
  23. together/types/model_list_params.py +12 -0
  24. together/types/video_job.py +8 -0
  25. {together-2.0.0a9.dist-info → together-2.0.0a10.dist-info}/METADATA +9 -11
  26. {together-2.0.0a9.dist-info → together-2.0.0a10.dist-info}/RECORD +29 -27
  27. together/lib/resources/models.py +0 -35
  28. {together-2.0.0a9.dist-info → together-2.0.0a10.dist-info}/WHEEL +0 -0
  29. {together-2.0.0a9.dist-info → together-2.0.0a10.dist-info}/entry_points.txt +0 -0
  30. {together-2.0.0a9.dist-info → together-2.0.0a10.dist-info}/licenses/LICENSE +0 -0
together/_types.py CHANGED
@@ -243,6 +243,9 @@ _T_co = TypeVar("_T_co", covariant=True)
243
243
  if TYPE_CHECKING:
244
244
  # This works because str.__contains__ does not accept object (either in typeshed or at runtime)
245
245
  # https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285
246
+ #
247
+ # Note: index() and count() methods are intentionally omitted to allow pyright to properly
248
+ # infer TypedDict types when dict literals are used in lists assigned to SequenceNotStr.
246
249
  class SequenceNotStr(Protocol[_T_co]):
247
250
  @overload
248
251
  def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
@@ -251,8 +254,6 @@ if TYPE_CHECKING:
251
254
  def __contains__(self, value: object, /) -> bool: ...
252
255
  def __len__(self) -> int: ...
253
256
  def __iter__(self) -> Iterator[_T_co]: ...
254
- def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
255
- def count(self, value: Any, /) -> int: ...
256
257
  def __reversed__(self) -> Iterator[_T_co]: ...
257
258
  else:
258
259
  # just point this to a normal `Sequence` at runtime to avoid having to special case
together/_version.py CHANGED
@@ -1,4 +1,4 @@
1
1
  # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
2
 
3
3
  __title__ = "together"
4
- __version__ = "2.0.0-alpha.9" # x-release-please-version
4
+ __version__ = "2.0.0-alpha.10" # x-release-please-version
@@ -13,6 +13,7 @@ from tabulate import tabulate
13
13
  from click.core import ParameterSource # type: ignore[attr-defined]
14
14
 
15
15
  from together import Together
16
+ from together.types import fine_tuning_estimate_price_params as pe_params
16
17
  from together._types import NOT_GIVEN, NotGiven
17
18
  from together.lib.utils import log_warn
18
19
  from together.lib.utils.tools import format_timestamp, finetune_price_to_dollars
@@ -24,13 +25,21 @@ from together.lib.resources.fine_tuning import get_model_limits
24
25
 
25
26
  _CONFIRMATION_MESSAGE = (
26
27
  "You are about to create a fine-tuning job. "
27
- "The cost of your job will be determined by the model size, the number of tokens "
28
+ "The estimated price of this job is {price}. "
29
+ "The actual cost of your job will be determined by the model size, the number of tokens "
28
30
  "in the training file, the number of tokens in the validation file, the number of epochs, and "
29
- "the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
31
+ "the number of evaluations. Visit https://www.together.ai/pricing to learn more about pricing.\n"
32
+ "{warning}"
30
33
  "You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
31
34
  "Do you want to proceed?"
32
35
  )
33
36
 
37
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
38
+ "The estimated price of this job is significantly greater than your current credit limit and balance combined. "
39
+ "It will likely get cancelled due to insufficient funds. "
40
+ "Consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
41
+ )
42
+
34
43
  _FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
35
44
 
36
45
 
@@ -323,7 +332,60 @@ def create(
323
332
  elif n_evals > 0 and not validation_file:
324
333
  raise click.BadParameter("You have specified a number of evaluation loops but no validation file.")
325
334
 
326
- if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
335
+ training_type_cls: pe_params.TrainingType
336
+ if lora:
337
+ training_type_cls = pe_params.TrainingTypeLoRaTrainingType(
338
+ lora_alpha=int(lora_alpha or 0),
339
+ lora_r=lora_r or 0,
340
+ lora_dropout=lora_dropout or 0,
341
+ lora_trainable_modules=lora_trainable_modules or "all-linear",
342
+ type="Lora",
343
+ )
344
+ else:
345
+ training_type_cls = pe_params.TrainingTypeFullTrainingType(
346
+ type="Full",
347
+ )
348
+
349
+ training_method_cls: pe_params.TrainingMethod
350
+ if training_method == "sft":
351
+ training_method_cls = pe_params.TrainingMethodTrainingMethodSft(
352
+ method="sft",
353
+ train_on_inputs=train_on_inputs or "auto",
354
+ )
355
+ else:
356
+ training_method_cls = pe_params.TrainingMethodTrainingMethodDpo(
357
+ method="dpo",
358
+ dpo_beta=dpo_beta or 0,
359
+ dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length or False,
360
+ dpo_reference_free=False,
361
+ rpo_alpha=rpo_alpha or 0,
362
+ simpo_gamma=simpo_gamma or 0,
363
+ )
364
+
365
+ finetune_price_estimation_result = client.fine_tuning.estimate_price(
366
+ training_file=training_file,
367
+ validation_file=validation_file,
368
+ model=model or "",
369
+ n_epochs=n_epochs,
370
+ n_evals=n_evals,
371
+ training_type=training_type_cls,
372
+ training_method=training_method_cls,
373
+ )
374
+ price = click.style(
375
+ f"${finetune_price_estimation_result.estimated_total_price:.2f}",
376
+ bold=True,
377
+ )
378
+ if not finetune_price_estimation_result.allowed_to_proceed:
379
+ warning = click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
380
+ else:
381
+ warning = ""
382
+
383
+ confirmation_message = _CONFIRMATION_MESSAGE.format(
384
+ price=price,
385
+ warning=warning,
386
+ )
387
+
388
+ if confirm or click.confirm(confirmation_message, default=True, show_default=True):
327
389
  response = client.fine_tuning.create(
328
390
  **training_args,
329
391
  verbose=True,
@@ -7,7 +7,6 @@ from tabulate import tabulate
7
7
  from together import Together, omit
8
8
  from together._models import BaseModel
9
9
  from together._response import APIResponse as APIResponse
10
- from together.lib.resources.models import filter_by_dedicated_models
11
10
  from together.types.model_upload_response import ModelUploadResponse
12
11
 
13
12
 
@@ -34,11 +33,7 @@ def list(ctx: click.Context, type: Optional[str], json: bool) -> None:
34
33
  """List models"""
35
34
  client: Together = ctx.obj
36
35
 
37
- response = client.models.list()
38
- models_list = response
39
-
40
- if type == "dedicated":
41
- models_list = filter_by_dedicated_models(client, models_list)
36
+ models_list = client.models.list(dedicated=type == "dedicated" if type else omit)
42
37
 
43
38
  display_list: List[Dict[str, Any]] = []
44
39
  model: BaseModel
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal
4
4
 
5
5
  from rich import print as rprint
6
6
 
7
+ from together.types import fine_tuning_estimate_price_params as pe_params
7
8
  from together.lib.utils import log_warn_once
8
9
 
9
10
  if TYPE_CHECKING:
@@ -66,7 +67,7 @@ def create_finetune_request(
66
67
  hf_model_revision: str | None = None,
67
68
  hf_api_token: str | None = None,
68
69
  hf_output_repo_name: str | None = None,
69
- ) -> FinetuneRequest:
70
+ ) -> tuple[FinetuneRequest, pe_params.TrainingType, pe_params.TrainingMethod]:
70
71
  if model is not None and from_checkpoint is not None:
71
72
  raise ValueError("You must specify either a model or a checkpoint to start a job from, not both")
72
73
 
@@ -233,8 +234,46 @@ def create_finetune_request(
233
234
  hf_output_repo_name=hf_output_repo_name,
234
235
  )
235
236
 
236
- return finetune_request
237
+ training_type_pe, training_method_pe = create_price_estimation_params(finetune_request)
237
238
 
239
+ return finetune_request, training_type_pe, training_method_pe
240
+
241
+ def create_price_estimation_params(finetune_request: FinetuneRequest) -> tuple[pe_params.TrainingType, pe_params.TrainingMethod]:
242
+ training_type_cls: pe_params.TrainingType
243
+ if isinstance(finetune_request.training_type, FullTrainingType):
244
+ training_type_cls = pe_params.TrainingTypeFullTrainingType(
245
+ type="Full",
246
+ )
247
+ elif isinstance(finetune_request.training_type, LoRATrainingType):
248
+ training_type_cls = pe_params.TrainingTypeLoRaTrainingType(
249
+ lora_alpha=finetune_request.training_type.lora_alpha,
250
+ lora_r=finetune_request.training_type.lora_r,
251
+ lora_dropout=finetune_request.training_type.lora_dropout,
252
+ lora_trainable_modules=finetune_request.training_type.lora_trainable_modules,
253
+ type="Lora",
254
+ )
255
+ else:
256
+ raise ValueError(f"Unknown training type: {finetune_request.training_type}")
257
+
258
+ training_method_cls: pe_params.TrainingMethod
259
+ if isinstance(finetune_request.training_method, TrainingMethodSFT):
260
+ training_method_cls = pe_params.TrainingMethodTrainingMethodSft(
261
+ method="sft",
262
+ train_on_inputs=finetune_request.training_method.train_on_inputs,
263
+ )
264
+ elif isinstance(finetune_request.training_method, TrainingMethodDPO):
265
+ training_method_cls = pe_params.TrainingMethodTrainingMethodDpo(
266
+ method="dpo",
267
+ dpo_beta=finetune_request.training_method.dpo_beta or 0,
268
+ dpo_normalize_logratios_by_length=finetune_request.training_method.dpo_normalize_logratios_by_length,
269
+ dpo_reference_free=finetune_request.training_method.dpo_reference_free,
270
+ rpo_alpha=finetune_request.training_method.rpo_alpha or 0,
271
+ simpo_gamma=finetune_request.training_method.simpo_gamma or 0,
272
+ )
273
+ else:
274
+ raise ValueError(f"Unknown training method: {finetune_request.training_method}")
275
+
276
+ return training_type_cls, training_method_cls
238
277
 
239
278
  def get_model_limits(client: Together, model: str) -> FinetuneTrainingLimits:
240
279
  """
@@ -136,6 +136,14 @@ class CompletionsResource(SyncAPIResource):
136
136
 
137
137
  response_format: An object specifying the format that the model must output.
138
138
 
139
+ Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
140
+ Outputs which ensures the model will match your supplied JSON schema. Learn more
141
+ in the [Structured Outputs guide](https://docs.together.ai/docs/json-mode).
142
+
143
+ Setting to `{ "type": "json_object" }` enables the older JSON mode, which
144
+ ensures the message the model generates is valid JSON. Using `json_schema` is
145
+ preferred for models that support it.
146
+
139
147
  safety_model: The name of the moderation model used to validate tokens. Choose from the
140
148
  available moderation models found
141
149
  [here](https://docs.together.ai/docs/inference-models#moderation-models).
@@ -277,6 +285,14 @@ class CompletionsResource(SyncAPIResource):
277
285
 
278
286
  response_format: An object specifying the format that the model must output.
279
287
 
288
+ Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
289
+ Outputs which ensures the model will match your supplied JSON schema. Learn more
290
+ in the [Structured Outputs guide](https://docs.together.ai/docs/json-mode).
291
+
292
+ Setting to `{ "type": "json_object" }` enables the older JSON mode, which
293
+ ensures the message the model generates is valid JSON. Using `json_schema` is
294
+ preferred for models that support it.
295
+
280
296
  safety_model: The name of the moderation model used to validate tokens. Choose from the
281
297
  available moderation models found
282
298
  [here](https://docs.together.ai/docs/inference-models#moderation-models).
@@ -414,6 +430,14 @@ class CompletionsResource(SyncAPIResource):
414
430
 
415
431
  response_format: An object specifying the format that the model must output.
416
432
 
433
+ Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
434
+ Outputs which ensures the model will match your supplied JSON schema. Learn more
435
+ in the [Structured Outputs guide](https://docs.together.ai/docs/json-mode).
436
+
437
+ Setting to `{ "type": "json_object" }` enables the older JSON mode, which
438
+ ensures the message the model generates is valid JSON. Using `json_schema` is
439
+ preferred for models that support it.
440
+
417
441
  safety_model: The name of the moderation model used to validate tokens. Choose from the
418
442
  available moderation models found
419
443
  [here](https://docs.together.ai/docs/inference-models#moderation-models).
@@ -653,6 +677,14 @@ class AsyncCompletionsResource(AsyncAPIResource):
653
677
 
654
678
  response_format: An object specifying the format that the model must output.
655
679
 
680
+ Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
681
+ Outputs which ensures the model will match your supplied JSON schema. Learn more
682
+ in the [Structured Outputs guide](https://docs.together.ai/docs/json-mode).
683
+
684
+ Setting to `{ "type": "json_object" }` enables the older JSON mode, which
685
+ ensures the message the model generates is valid JSON. Using `json_schema` is
686
+ preferred for models that support it.
687
+
656
688
  safety_model: The name of the moderation model used to validate tokens. Choose from the
657
689
  available moderation models found
658
690
  [here](https://docs.together.ai/docs/inference-models#moderation-models).
@@ -794,6 +826,14 @@ class AsyncCompletionsResource(AsyncAPIResource):
794
826
 
795
827
  response_format: An object specifying the format that the model must output.
796
828
 
829
+ Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
830
+ Outputs which ensures the model will match your supplied JSON schema. Learn more
831
+ in the [Structured Outputs guide](https://docs.together.ai/docs/json-mode).
832
+
833
+ Setting to `{ "type": "json_object" }` enables the older JSON mode, which
834
+ ensures the message the model generates is valid JSON. Using `json_schema` is
835
+ preferred for models that support it.
836
+
797
837
  safety_model: The name of the moderation model used to validate tokens. Choose from the
798
838
  available moderation models found
799
839
  [here](https://docs.together.ai/docs/inference-models#moderation-models).
@@ -931,6 +971,14 @@ class AsyncCompletionsResource(AsyncAPIResource):
931
971
 
932
972
  response_format: An object specifying the format that the model must output.
933
973
 
974
+ Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured
975
+ Outputs which ensures the model will match your supplied JSON schema. Learn more
976
+ in the [Structured Outputs guide](https://docs.together.ai/docs/json-mode).
977
+
978
+ Setting to `{ "type": "json_object" }` enables the older JSON mode, which
979
+ ensures the message the model generates is valid JSON. Using `json_schema` is
980
+ preferred for models that support it.
981
+
934
982
  safety_model: The name of the moderation model used to validate tokens. Choose from the
935
983
  available moderation models found
936
984
  [here](https://docs.together.ai/docs/inference-models#moderation-models).
@@ -7,7 +7,7 @@ from typing_extensions import Literal
7
7
  import httpx
8
8
  from rich import print as rprint
9
9
 
10
- from ..types import fine_tuning_delete_params, fine_tuning_content_params
10
+ from ..types import fine_tuning_delete_params, fine_tuning_content_params, fine_tuning_estimate_price_params
11
11
  from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
12
12
  from .._utils import maybe_transform, async_maybe_transform
13
13
  from .._compat import cached_property
@@ -27,17 +27,31 @@ from .._response import (
27
27
  async_to_custom_streamed_response_wrapper,
28
28
  )
29
29
  from .._base_client import make_request_options
30
- from ..lib.types.fine_tuning import FinetuneResponse as FinetuneResponseLib, FinetuneTrainingLimits
30
+ from ..lib.types.fine_tuning import (
31
+ FinetuneResponse as FinetuneResponseLib,
32
+ FinetuneTrainingLimits,
33
+ )
31
34
  from ..types.finetune_response import FinetuneResponse
32
- from ..lib.resources.fine_tuning import get_model_limits, async_get_model_limits, create_finetune_request
35
+ from ..lib.resources.fine_tuning import (
36
+ get_model_limits,
37
+ async_get_model_limits,
38
+ create_finetune_request,
39
+ )
33
40
  from ..types.fine_tuning_list_response import FineTuningListResponse
34
41
  from ..types.fine_tuning_cancel_response import FineTuningCancelResponse
35
42
  from ..types.fine_tuning_delete_response import FineTuningDeleteResponse
36
43
  from ..types.fine_tuning_list_events_response import FineTuningListEventsResponse
44
+ from ..types.fine_tuning_estimate_price_response import FineTuningEstimatePriceResponse
37
45
  from ..types.fine_tuning_list_checkpoints_response import FineTuningListCheckpointsResponse
38
46
 
39
47
  __all__ = ["FineTuningResource", "AsyncFineTuningResource"]
40
48
 
49
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
50
+ "The estimated price of the fine-tuning job is {} which is significantly "
51
+ "greater than your current credit limit and balance combined. "
52
+ "It will likely get cancelled due to insufficient funds. "
53
+ "Proceed at your own risk."
54
+ )
41
55
 
42
56
  class FineTuningResource(SyncAPIResource):
43
57
  @cached_property
@@ -179,7 +193,7 @@ class FineTuningResource(SyncAPIResource):
179
193
  pass
180
194
  model_limits = get_model_limits(self._client, str(model_name))
181
195
 
182
- finetune_request = create_finetune_request(
196
+ finetune_request, training_type_cls, training_method_cls = create_finetune_request(
183
197
  model_limits=model_limits,
184
198
  training_file=training_file,
185
199
  model=model,
@@ -218,11 +232,32 @@ class FineTuningResource(SyncAPIResource):
218
232
  hf_output_repo_name=hf_output_repo_name,
219
233
  )
220
234
 
235
+
236
+ price_estimation_result = self.estimate_price(
237
+ training_file=training_file,
238
+ from_checkpoint=from_checkpoint or Omit(),
239
+ validation_file=validation_file or Omit(),
240
+ model=model or "",
241
+ n_epochs=finetune_request.n_epochs,
242
+ n_evals=finetune_request.n_evals or 0,
243
+ training_type=training_type_cls,
244
+ training_method=training_method_cls,
245
+ )
246
+
247
+
221
248
  if verbose:
222
249
  rprint(
223
250
  "Submitting a fine-tuning job with the following parameters:",
224
251
  finetune_request,
225
252
  )
253
+ if not price_estimation_result.allowed_to_proceed:
254
+ rprint(
255
+ "[red]"
256
+ + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
257
+ price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable]
258
+ )
259
+ + "[/red]",
260
+ )
226
261
  parameter_payload = finetune_request.model_dump(exclude_none=True)
227
262
 
228
263
  return self._client.post(
@@ -413,6 +448,76 @@ class FineTuningResource(SyncAPIResource):
413
448
  cast_to=BinaryAPIResponse,
414
449
  )
415
450
 
451
+ def estimate_price(
452
+ self,
453
+ *,
454
+ training_file: str,
455
+ from_checkpoint: str | Omit = omit,
456
+ model: str | Omit = omit,
457
+ n_epochs: int | Omit = omit,
458
+ n_evals: int | Omit = omit,
459
+ training_method: fine_tuning_estimate_price_params.TrainingMethod | Omit = omit,
460
+ training_type: fine_tuning_estimate_price_params.TrainingType | Omit = omit,
461
+ validation_file: str | Omit = omit,
462
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
463
+ # The extra values given here take precedence over values defined on the client or passed to this method.
464
+ extra_headers: Headers | None = None,
465
+ extra_query: Query | None = None,
466
+ extra_body: Body | None = None,
467
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
468
+ ) -> FineTuningEstimatePriceResponse:
469
+ """
470
+ Estimate the price of a fine-tuning job.
471
+
472
+ Args:
473
+ training_file: File-ID of a training file uploaded to the Together API
474
+
475
+ from_checkpoint: The checkpoint identifier to continue training from a previous fine-tuning job.
476
+ Format is `{$JOB_ID}` or `{$OUTPUT_MODEL_NAME}` or `{$JOB_ID}:{$STEP}` or
477
+ `{$OUTPUT_MODEL_NAME}:{$STEP}`. The step value is optional; without it, the
478
+ final checkpoint will be used.
479
+
480
+ model: Name of the base model to run fine-tune job on
481
+
482
+ n_epochs: Number of complete passes through the training dataset (higher values may
483
+ improve results but increase cost and risk of overfitting)
484
+
485
+ n_evals: Number of evaluations to be run on a given validation set during training
486
+
487
+ training_method: The training method to use. 'sft' for Supervised Fine-Tuning or 'dpo' for Direct
488
+ Preference Optimization.
489
+
490
+ validation_file: File-ID of a validation file uploaded to the Together API
491
+
492
+ extra_headers: Send extra headers
493
+
494
+ extra_query: Add additional query parameters to the request
495
+
496
+ extra_body: Add additional JSON properties to the request
497
+
498
+ timeout: Override the client-level default timeout for this request, in seconds
499
+ """
500
+ return self._post(
501
+ "/fine-tunes/estimate-price",
502
+ body=maybe_transform(
503
+ {
504
+ "training_file": training_file,
505
+ "from_checkpoint": from_checkpoint,
506
+ "model": model,
507
+ "n_epochs": n_epochs,
508
+ "n_evals": n_evals,
509
+ "training_method": training_method,
510
+ "training_type": training_type,
511
+ "validation_file": validation_file,
512
+ },
513
+ fine_tuning_estimate_price_params.FineTuningEstimatePriceParams,
514
+ ),
515
+ options=make_request_options(
516
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
517
+ ),
518
+ cast_to=FineTuningEstimatePriceResponse,
519
+ )
520
+
416
521
  def list_checkpoints(
417
522
  self,
418
523
  id: str,
@@ -620,7 +725,7 @@ class AsyncFineTuningResource(AsyncAPIResource):
620
725
  pass
621
726
  model_limits = await async_get_model_limits(self._client, str(model_name))
622
727
 
623
- finetune_request = create_finetune_request(
728
+ finetune_request, training_type_cls, training_method_cls = create_finetune_request(
624
729
  model_limits=model_limits,
625
730
  training_file=training_file,
626
731
  model=model,
@@ -659,11 +764,32 @@ class AsyncFineTuningResource(AsyncAPIResource):
659
764
  hf_output_repo_name=hf_output_repo_name,
660
765
  )
661
766
 
767
+
768
+ price_estimation_result = await self.estimate_price(
769
+ training_file=training_file,
770
+ from_checkpoint=from_checkpoint or Omit(),
771
+ validation_file=validation_file or Omit(),
772
+ model=model or "",
773
+ n_epochs=finetune_request.n_epochs,
774
+ n_evals=finetune_request.n_evals or 0,
775
+ training_type=training_type_cls,
776
+ training_method=training_method_cls,
777
+ )
778
+
779
+
662
780
  if verbose:
663
781
  rprint(
664
782
  "Submitting a fine-tuning job with the following parameters:",
665
783
  finetune_request,
666
784
  )
785
+ if not price_estimation_result.allowed_to_proceed:
786
+ rprint(
787
+ "[red]"
788
+ + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
789
+ price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable]
790
+ )
791
+ + "[/red]",
792
+ )
667
793
  parameter_payload = finetune_request.model_dump(exclude_none=True)
668
794
 
669
795
  return await self._client.post(
@@ -854,6 +980,76 @@ class AsyncFineTuningResource(AsyncAPIResource):
854
980
  cast_to=AsyncBinaryAPIResponse,
855
981
  )
856
982
 
983
+ async def estimate_price(
984
+ self,
985
+ *,
986
+ training_file: str,
987
+ from_checkpoint: str | Omit = omit,
988
+ model: str | Omit = omit,
989
+ n_epochs: int | Omit = omit,
990
+ n_evals: int | Omit = omit,
991
+ training_method: fine_tuning_estimate_price_params.TrainingMethod | Omit = omit,
992
+ training_type: fine_tuning_estimate_price_params.TrainingType | Omit = omit,
993
+ validation_file: str | Omit = omit,
994
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
995
+ # The extra values given here take precedence over values defined on the client or passed to this method.
996
+ extra_headers: Headers | None = None,
997
+ extra_query: Query | None = None,
998
+ extra_body: Body | None = None,
999
+ timeout: float | httpx.Timeout | None | NotGiven = not_given,
1000
+ ) -> FineTuningEstimatePriceResponse:
1001
+ """
1002
+ Estimate the price of a fine-tuning job.
1003
+
1004
+ Args:
1005
+ training_file: File-ID of a training file uploaded to the Together API
1006
+
1007
+ from_checkpoint: The checkpoint identifier to continue training from a previous fine-tuning job.
1008
+ Format is `{$JOB_ID}` or `{$OUTPUT_MODEL_NAME}` or `{$JOB_ID}:{$STEP}` or
1009
+ `{$OUTPUT_MODEL_NAME}:{$STEP}`. The step value is optional; without it, the
1010
+ final checkpoint will be used.
1011
+
1012
+ model: Name of the base model to run fine-tune job on
1013
+
1014
+ n_epochs: Number of complete passes through the training dataset (higher values may
1015
+ improve results but increase cost and risk of overfitting)
1016
+
1017
+ n_evals: Number of evaluations to be run on a given validation set during training
1018
+
1019
+ training_method: The training method to use. 'sft' for Supervised Fine-Tuning or 'dpo' for Direct
1020
+ Preference Optimization.
1021
+
1022
+ validation_file: File-ID of a validation file uploaded to the Together API
1023
+
1024
+ extra_headers: Send extra headers
1025
+
1026
+ extra_query: Add additional query parameters to the request
1027
+
1028
+ extra_body: Add additional JSON properties to the request
1029
+
1030
+ timeout: Override the client-level default timeout for this request, in seconds
1031
+ """
1032
+ return await self._post(
1033
+ "/fine-tunes/estimate-price",
1034
+ body=await async_maybe_transform(
1035
+ {
1036
+ "training_file": training_file,
1037
+ "from_checkpoint": from_checkpoint,
1038
+ "model": model,
1039
+ "n_epochs": n_epochs,
1040
+ "n_evals": n_evals,
1041
+ "training_method": training_method,
1042
+ "training_type": training_type,
1043
+ "validation_file": validation_file,
1044
+ },
1045
+ fine_tuning_estimate_price_params.FineTuningEstimatePriceParams,
1046
+ ),
1047
+ options=make_request_options(
1048
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
1049
+ ),
1050
+ cast_to=FineTuningEstimatePriceResponse,
1051
+ )
1052
+
857
1053
  async def list_checkpoints(
858
1054
  self,
859
1055
  id: str,
@@ -941,6 +1137,9 @@ class FineTuningResourceWithRawResponse:
941
1137
  fine_tuning.content,
942
1138
  BinaryAPIResponse,
943
1139
  )
1140
+ self.estimate_price = to_raw_response_wrapper(
1141
+ fine_tuning.estimate_price,
1142
+ )
944
1143
  self.list_checkpoints = to_raw_response_wrapper(
945
1144
  fine_tuning.list_checkpoints,
946
1145
  )
@@ -969,6 +1168,9 @@ class AsyncFineTuningResourceWithRawResponse:
969
1168
  fine_tuning.content,
970
1169
  AsyncBinaryAPIResponse,
971
1170
  )
1171
+ self.estimate_price = async_to_raw_response_wrapper(
1172
+ fine_tuning.estimate_price,
1173
+ )
972
1174
  self.list_checkpoints = async_to_raw_response_wrapper(
973
1175
  fine_tuning.list_checkpoints,
974
1176
  )
@@ -997,6 +1199,9 @@ class FineTuningResourceWithStreamingResponse:
997
1199
  fine_tuning.content,
998
1200
  StreamedBinaryAPIResponse,
999
1201
  )
1202
+ self.estimate_price = to_streamed_response_wrapper(
1203
+ fine_tuning.estimate_price,
1204
+ )
1000
1205
  self.list_checkpoints = to_streamed_response_wrapper(
1001
1206
  fine_tuning.list_checkpoints,
1002
1207
  )
@@ -1025,6 +1230,9 @@ class AsyncFineTuningResourceWithStreamingResponse:
1025
1230
  fine_tuning.content,
1026
1231
  AsyncStreamedBinaryAPIResponse,
1027
1232
  )
1233
+ self.estimate_price = async_to_streamed_response_wrapper(
1234
+ fine_tuning.estimate_price,
1235
+ )
1028
1236
  self.list_checkpoints = async_to_streamed_response_wrapper(
1029
1237
  fine_tuning.list_checkpoints,
1030
1238
  )