together 1.5.25__py3-none-any.whl → 1.5.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
together/filemanager.py CHANGED
@@ -1,28 +1,40 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import math
3
4
  import os
4
5
  import shutil
5
6
  import stat
6
7
  import tempfile
7
8
  import uuid
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
10
  from functools import partial
9
11
  from pathlib import Path
10
- from typing import Tuple
12
+ from typing import Any, Dict, List, Tuple
11
13
 
12
14
  import requests
13
15
  from filelock import FileLock
14
16
  from requests.structures import CaseInsensitiveDict
15
17
  from tqdm import tqdm
16
- from tqdm.utils import CallbackIOWrapper
17
18
 
18
- import together.utils
19
19
  from together.abstract import api_requestor
20
- from together.constants import DISABLE_TQDM, DOWNLOAD_BLOCK_SIZE, MAX_RETRIES
20
+ from together.constants import (
21
+ DISABLE_TQDM,
22
+ DOWNLOAD_BLOCK_SIZE,
23
+ MAX_CONCURRENT_PARTS,
24
+ MAX_FILE_SIZE_GB,
25
+ MAX_RETRIES,
26
+ MIN_PART_SIZE_MB,
27
+ NUM_BYTES_IN_GB,
28
+ TARGET_PART_SIZE_MB,
29
+ MAX_MULTIPART_PARTS,
30
+ MULTIPART_UPLOAD_TIMEOUT,
31
+ )
21
32
  from together.error import (
22
33
  APIError,
23
34
  AuthenticationError,
24
35
  DownloadError,
25
36
  FileTypeError,
37
+ ResponseError,
26
38
  )
27
39
  from together.together_response import TogetherResponse
28
40
  from together.types import (
@@ -32,6 +44,8 @@ from together.types import (
32
44
  TogetherClient,
33
45
  TogetherRequest,
34
46
  )
47
+ from tqdm.utils import CallbackIOWrapper
48
+ import together.utils
35
49
 
36
50
 
37
51
  def chmod_and_replace(src: Path, dst: Path) -> None:
@@ -339,7 +353,7 @@ class UploadManager:
339
353
  )
340
354
  redirect_url, file_id = self.get_upload_url(url, file, purpose, filetype)
341
355
 
342
- file_size = os.stat(file.as_posix()).st_size
356
+ file_size = os.stat(file).st_size
343
357
 
344
358
  with tqdm(
345
359
  total=file_size,
@@ -385,3 +399,214 @@ class UploadManager:
385
399
  assert isinstance(response, TogetherResponse)
386
400
 
387
401
  return FileResponse(**response.data)
402
+
403
+
404
+ class MultipartUploadManager:
405
+ """Handles multipart uploads for large files"""
406
+
407
+ def __init__(self, client: TogetherClient) -> None:
408
+ self._client = client
409
+ self.max_concurrent_parts = MAX_CONCURRENT_PARTS
410
+
411
+ def upload(
412
+ self,
413
+ url: str,
414
+ file: Path,
415
+ purpose: FilePurpose,
416
+ ) -> FileResponse:
417
+ """Upload large file using multipart upload"""
418
+
419
+ file_size = os.stat(file).st_size
420
+
421
+ file_size_gb = file_size / NUM_BYTES_IN_GB
422
+ if file_size_gb > MAX_FILE_SIZE_GB:
423
+ raise FileTypeError(
424
+ f"File size {file_size_gb:.1f}GB exceeds maximum supported size of {MAX_FILE_SIZE_GB}GB"
425
+ )
426
+
427
+ part_size, num_parts = self._calculate_parts(file_size)
428
+
429
+ file_type = self._get_file_type(file)
430
+ upload_info = None
431
+
432
+ try:
433
+ upload_info = self._initiate_upload(
434
+ url, file, file_size, num_parts, purpose, file_type
435
+ )
436
+
437
+ completed_parts = self._upload_parts_concurrent(
438
+ file, upload_info, part_size
439
+ )
440
+
441
+ return self._complete_upload(
442
+ url, upload_info["upload_id"], upload_info["file_id"], completed_parts
443
+ )
444
+
445
+ except Exception as e:
446
+ # Cleanup on failure
447
+ if upload_info is not None:
448
+ self._abort_upload(
449
+ url, upload_info["upload_id"], upload_info["file_id"]
450
+ )
451
+ raise e
452
+
453
+ def _get_file_type(self, file: Path) -> str:
454
+ """Get file type from extension, raising ValueError for unsupported extensions"""
455
+ if file.suffix == ".jsonl":
456
+ return "jsonl"
457
+ elif file.suffix == ".parquet":
458
+ return "parquet"
459
+ elif file.suffix == ".csv":
460
+ return "csv"
461
+ else:
462
+ raise ValueError(
463
+ f"Unsupported file extension: '{file.suffix}'. "
464
+ f"Supported extensions: .jsonl, .parquet, .csv"
465
+ )
466
+
467
+ def _calculate_parts(self, file_size: int) -> tuple[int, int]:
468
+ """Calculate optimal part size and count"""
469
+ min_part_size = MIN_PART_SIZE_MB * 1024 * 1024 # 5MB
470
+ target_part_size = TARGET_PART_SIZE_MB * 1024 * 1024 # 100MB
471
+
472
+ if file_size <= target_part_size:
473
+ return file_size, 1
474
+
475
+ num_parts = min(MAX_MULTIPART_PARTS, math.ceil(file_size / target_part_size))
476
+ part_size = math.ceil(file_size / num_parts)
477
+
478
+ if part_size < min_part_size:
479
+ part_size = min_part_size
480
+ num_parts = math.ceil(file_size / part_size)
481
+
482
+ return part_size, num_parts
483
+
484
+ def _initiate_upload(
485
+ self,
486
+ url: str,
487
+ file: Path,
488
+ file_size: int,
489
+ num_parts: int,
490
+ purpose: FilePurpose,
491
+ file_type: str,
492
+ ) -> Any:
493
+ """Initiate multipart upload with backend"""
494
+
495
+ requestor = api_requestor.APIRequestor(client=self._client)
496
+
497
+ payload = {
498
+ "file_name": file.name,
499
+ "file_size": file_size,
500
+ "num_parts": num_parts,
501
+ "purpose": purpose.value,
502
+ "file_type": file_type,
503
+ }
504
+
505
+ response, _, _ = requestor.request(
506
+ options=TogetherRequest(
507
+ method="POST",
508
+ url="files/multipart/initiate",
509
+ params=payload,
510
+ ),
511
+ )
512
+
513
+ return response.data
514
+
515
+ def _upload_parts_concurrent(
516
+ self, file: Path, upload_info: Dict[str, Any], part_size: int
517
+ ) -> List[Dict[str, Any]]:
518
+ """Upload file parts concurrently with progress tracking"""
519
+
520
+ parts = upload_info["parts"]
521
+ completed_parts = []
522
+
523
+ with ThreadPoolExecutor(max_workers=self.max_concurrent_parts) as executor:
524
+ with tqdm(total=len(parts), desc="Uploading parts", unit="part") as pbar:
525
+ future_to_part = {}
526
+
527
+ with open(file, "rb") as f:
528
+ for part_info in parts:
529
+ f.seek((part_info["PartNumber"] - 1) * part_size)
530
+ part_data = f.read(part_size)
531
+
532
+ future = executor.submit(
533
+ self._upload_single_part, part_info, part_data
534
+ )
535
+ future_to_part[future] = part_info["PartNumber"]
536
+
537
+ # Collect results
538
+ for future in as_completed(future_to_part):
539
+ part_number = future_to_part[future]
540
+ try:
541
+ etag = future.result()
542
+ completed_parts.append(
543
+ {"part_number": part_number, "etag": etag}
544
+ )
545
+ pbar.update(1)
546
+ except Exception as e:
547
+ raise Exception(f"Failed to upload part {part_number}: {e}")
548
+
549
+ completed_parts.sort(key=lambda x: x["part_number"])
550
+ return completed_parts
551
+
552
+ def _upload_single_part(self, part_info: Dict[str, Any], part_data: bytes) -> str:
553
+ """Upload a single part and return ETag"""
554
+
555
+ response = requests.put(
556
+ part_info["URL"],
557
+ data=part_data,
558
+ headers=part_info.get("Headers", {}),
559
+ timeout=MULTIPART_UPLOAD_TIMEOUT,
560
+ )
561
+ response.raise_for_status()
562
+
563
+ etag = response.headers.get("ETag", "").strip('"')
564
+ if not etag:
565
+ raise ResponseError(f"No ETag returned for part {part_info['PartNumber']}")
566
+
567
+ return etag
568
+
569
+ def _complete_upload(
570
+ self,
571
+ url: str,
572
+ upload_id: str,
573
+ file_id: str,
574
+ completed_parts: List[Dict[str, Any]],
575
+ ) -> FileResponse:
576
+ """Complete the multipart upload"""
577
+
578
+ requestor = api_requestor.APIRequestor(client=self._client)
579
+
580
+ payload = {
581
+ "upload_id": upload_id,
582
+ "file_id": file_id,
583
+ "parts": completed_parts,
584
+ }
585
+
586
+ response, _, _ = requestor.request(
587
+ options=TogetherRequest(
588
+ method="POST",
589
+ url="files/multipart/complete",
590
+ params=payload,
591
+ ),
592
+ )
593
+
594
+ return FileResponse(**response.data.get("file", response.data))
595
+
596
+ def _abort_upload(self, url: str, upload_id: str, file_id: str) -> None:
597
+ """Abort the multipart upload"""
598
+
599
+ requestor = api_requestor.APIRequestor(client=self._client)
600
+
601
+ payload = {
602
+ "upload_id": upload_id,
603
+ "file_id": file_id,
604
+ }
605
+
606
+ requestor.request(
607
+ options=TogetherRequest(
608
+ method="POST",
609
+ url="files/multipart/abort",
610
+ params=payload,
611
+ ),
612
+ )
@@ -10,6 +10,7 @@ from together.resources.models import AsyncModels, Models
10
10
  from together.resources.rerank import AsyncRerank, Rerank
11
11
  from together.resources.batch import Batches, AsyncBatches
12
12
  from together.resources.evaluation import Evaluation, AsyncEvaluation
13
+ from together.resources.videos import AsyncVideos, Videos
13
14
 
14
15
 
15
16
  __all__ = [
@@ -37,4 +38,6 @@ __all__ = [
37
38
  "AsyncBatches",
38
39
  "Evaluation",
39
40
  "AsyncEvaluation",
41
+ "AsyncVideos",
42
+ "Videos",
40
43
  ]
@@ -104,7 +104,12 @@ class Transcriptions:
104
104
  )
105
105
 
106
106
  # Add any additional kwargs
107
- params_data.update(kwargs)
107
+ # Convert boolean values to lowercase strings for proper form encoding
108
+ for key, value in kwargs.items():
109
+ if isinstance(value, bool):
110
+ params_data[key] = str(value).lower()
111
+ else:
112
+ params_data[key] = value
108
113
 
109
114
  try:
110
115
  response, _, _ = requestor.request(
@@ -131,7 +136,8 @@ class Transcriptions:
131
136
  response_format == "verbose_json"
132
137
  or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
133
138
  ):
134
- return AudioTranscriptionVerboseResponse(**response.data)
139
+ # Create response with model validation that preserves extra fields
140
+ return AudioTranscriptionVerboseResponse.model_validate(response.data)
135
141
  else:
136
142
  return AudioTranscriptionResponse(**response.data)
137
143
 
@@ -234,7 +240,12 @@ class AsyncTranscriptions:
234
240
  )
235
241
 
236
242
  # Add any additional kwargs
237
- params_data.update(kwargs)
243
+ # Convert boolean values to lowercase strings for proper form encoding
244
+ for key, value in kwargs.items():
245
+ if isinstance(value, bool):
246
+ params_data[key] = str(value).lower()
247
+ else:
248
+ params_data[key] = value
238
249
 
239
250
  try:
240
251
  response, _, _ = await requestor.arequest(
@@ -261,6 +272,7 @@ class AsyncTranscriptions:
261
272
  response_format == "verbose_json"
262
273
  or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
263
274
  ):
264
- return AudioTranscriptionVerboseResponse(**response.data)
275
+ # Create response with model validation that preserves extra fields
276
+ return AudioTranscriptionVerboseResponse.model_validate(response.data)
265
277
  else:
266
278
  return AudioTranscriptionResponse(**response.data)
@@ -56,8 +56,8 @@ class Endpoints:
56
56
  min_replicas: int,
57
57
  max_replicas: int,
58
58
  display_name: Optional[str] = None,
59
- disable_prompt_cache: bool = False,
60
- disable_speculative_decoding: bool = False,
59
+ disable_prompt_cache: bool = True,
60
+ disable_speculative_decoding: bool = True,
61
61
  state: Literal["STARTED", "STOPPED"] = "STARTED",
62
62
  inactive_timeout: Optional[int] = None,
63
63
  ) -> DedicatedEndpoint:
@@ -304,8 +304,8 @@ class AsyncEndpoints:
304
304
  min_replicas: int,
305
305
  max_replicas: int,
306
306
  display_name: Optional[str] = None,
307
- disable_prompt_cache: bool = False,
308
- disable_speculative_decoding: bool = False,
307
+ disable_prompt_cache: bool = True,
308
+ disable_speculative_decoding: bool = True,
309
309
  state: Literal["STARTED", "STOPPED"] = "STARTED",
310
310
  inactive_timeout: Optional[int] = None,
311
311
  ) -> DedicatedEndpoint:
@@ -27,9 +27,12 @@ class Evaluation:
27
27
  def create(
28
28
  self,
29
29
  type: str,
30
- judge_model_name: str,
30
+ judge_model: str,
31
+ judge_model_source: str,
31
32
  judge_system_template: str,
32
33
  input_data_file_path: str,
34
+ judge_external_api_token: Optional[str] = None,
35
+ judge_external_base_url: Optional[str] = None,
33
36
  # Classify-specific parameters
34
37
  labels: Optional[List[str]] = None,
35
38
  pass_labels: Optional[List[str]] = None,
@@ -48,9 +51,12 @@ class Evaluation:
48
51
 
49
52
  Args:
50
53
  type: The type of evaluation ("classify", "score", or "compare")
51
- judge_model_name: Name of the judge model
54
+ judge_model: Name or URL of the judge model
55
+ judge_model_source: Source of the judge model ("serverless", "dedicated", or "external")
52
56
  judge_system_template: System template for the judge
53
57
  input_data_file_path: Path to input data file
58
+ judge_external_api_token: Optional external API token for the judge model
59
+ judge_external_base_url: Optional external base URL for the judge model
54
60
  labels: List of classification labels (required for classify)
55
61
  pass_labels: List of labels considered as passing (required for classify)
56
62
  min_score: Minimum score value (required for score)
@@ -67,10 +73,18 @@ class Evaluation:
67
73
  client=self._client,
68
74
  )
69
75
 
76
+ if judge_model_source == "external" and not judge_external_api_token:
77
+ raise ValueError(
78
+ "judge_external_api_token is required when judge_model_source is 'external'"
79
+ )
80
+
70
81
  # Build judge config
71
82
  judge_config = JudgeModelConfig(
72
- model_name=judge_model_name,
83
+ model=judge_model,
84
+ model_source=judge_model_source,
73
85
  system_template=judge_system_template,
86
+ external_api_token=judge_external_api_token,
87
+ external_base_url=judge_external_base_url,
74
88
  )
75
89
  parameters: Union[ClassifyParameters, ScoreParameters, CompareParameters]
76
90
  # Build parameters based on type
@@ -112,7 +126,8 @@ class Evaluation:
112
126
  elif isinstance(model_to_evaluate, dict):
113
127
  # Validate that all required fields are present for model config
114
128
  required_fields = [
115
- "model_name",
129
+ "model",
130
+ "model_source",
116
131
  "max_tokens",
117
132
  "temperature",
118
133
  "system_template",
@@ -128,6 +143,12 @@ class Evaluation:
128
143
  f"All model config parameters are required when using detailed configuration. "
129
144
  f"Missing: {', '.join(missing_fields)}"
130
145
  )
146
+ if model_to_evaluate.get(
147
+ "model_source"
148
+ ) == "external" and not model_to_evaluate.get("external_api_token"):
149
+ raise ValueError(
150
+ "external_api_token is required when model_source is 'external' for model_to_evaluate"
151
+ )
131
152
  parameters.model_to_evaluate = ModelRequest(**model_to_evaluate)
132
153
 
133
154
  elif type == "score":
@@ -163,7 +184,8 @@ class Evaluation:
163
184
  elif isinstance(model_to_evaluate, dict):
164
185
  # Validate that all required fields are present for model config
165
186
  required_fields = [
166
- "model_name",
187
+ "model",
188
+ "model_source",
167
189
  "max_tokens",
168
190
  "temperature",
169
191
  "system_template",
@@ -179,6 +201,12 @@ class Evaluation:
179
201
  f"All model config parameters are required when using detailed configuration. "
180
202
  f"Missing: {', '.join(missing_fields)}"
181
203
  )
204
+ if model_to_evaluate.get(
205
+ "model_source"
206
+ ) == "external" and not model_to_evaluate.get("external_api_token"):
207
+ raise ValueError(
208
+ "external_api_token is required when model_source is 'external' for model_to_evaluate"
209
+ )
182
210
  parameters.model_to_evaluate = ModelRequest(**model_to_evaluate)
183
211
 
184
212
  elif type == "compare":
@@ -223,7 +251,8 @@ class Evaluation:
223
251
  elif isinstance(model_a, dict):
224
252
  # Validate that all required fields are present for model config
225
253
  required_fields = [
226
- "model_name",
254
+ "model",
255
+ "model_source",
227
256
  "max_tokens",
228
257
  "temperature",
229
258
  "system_template",
@@ -237,6 +266,12 @@ class Evaluation:
237
266
  f"All model config parameters are required for model_a when using detailed configuration. "
238
267
  f"Missing: {', '.join(missing_fields)}"
239
268
  )
269
+ if model_a.get("model_source") == "external" and not model_a.get(
270
+ "external_api_token"
271
+ ):
272
+ raise ValueError(
273
+ "external_api_token is required when model_source is 'external' for model_a"
274
+ )
240
275
  parameters.model_a = ModelRequest(**model_a)
241
276
 
242
277
  # Handle model_b
@@ -245,7 +280,8 @@ class Evaluation:
245
280
  elif isinstance(model_b, dict):
246
281
  # Validate that all required fields are present for model config
247
282
  required_fields = [
248
- "model_name",
283
+ "model",
284
+ "model_source",
249
285
  "max_tokens",
250
286
  "temperature",
251
287
  "system_template",
@@ -259,6 +295,12 @@ class Evaluation:
259
295
  f"All model config parameters are required for model_b when using detailed configuration. "
260
296
  f"Missing: {', '.join(missing_fields)}"
261
297
  )
298
+ if model_b.get("model_source") == "external" and not model_b.get(
299
+ "external_api_token"
300
+ ):
301
+ raise ValueError(
302
+ "external_api_token is required when model_source is 'external' for model_b"
303
+ )
262
304
  parameters.model_b = ModelRequest(**model_b)
263
305
 
264
306
  else:
@@ -379,9 +421,12 @@ class AsyncEvaluation:
379
421
  async def create(
380
422
  self,
381
423
  type: str,
382
- judge_model_name: str,
424
+ judge_model: str,
425
+ judge_model_source: str,
383
426
  judge_system_template: str,
384
427
  input_data_file_path: str,
428
+ judge_external_api_token: Optional[str] = None,
429
+ judge_external_base_url: Optional[str] = None,
385
430
  # Classify-specific parameters
386
431
  labels: Optional[List[str]] = None,
387
432
  pass_labels: Optional[List[str]] = None,
@@ -400,9 +445,12 @@ class AsyncEvaluation:
400
445
 
401
446
  Args:
402
447
  type: The type of evaluation ("classify", "score", or "compare")
403
- judge_model_name: Name of the judge model
448
+ judge_model: Name or URL of the judge model
449
+ judge_model_source: Source of the judge model ("serverless", "dedicated", or "external")
404
450
  judge_system_template: System template for the judge
405
451
  input_data_file_path: Path to input data file
452
+ judge_external_api_token: Optional external API token for the judge model
453
+ judge_external_base_url: Optional external base URL for the judge model
406
454
  labels: List of classification labels (required for classify)
407
455
  pass_labels: List of labels considered as passing (required for classify)
408
456
  min_score: Minimum score value (required for score)
@@ -419,10 +467,18 @@ class AsyncEvaluation:
419
467
  client=self._client,
420
468
  )
421
469
 
470
+ if judge_model_source == "external" and not judge_external_api_token:
471
+ raise ValueError(
472
+ "judge_external_api_token is required when judge_model_source is 'external'"
473
+ )
474
+
422
475
  # Build judge config
423
476
  judge_config = JudgeModelConfig(
424
- model_name=judge_model_name,
477
+ model=judge_model,
478
+ model_source=judge_model_source,
425
479
  system_template=judge_system_template,
480
+ external_api_token=judge_external_api_token,
481
+ external_base_url=judge_external_base_url,
426
482
  )
427
483
  parameters: Union[ClassifyParameters, ScoreParameters, CompareParameters]
428
484
  # Build parameters based on type
@@ -464,7 +520,8 @@ class AsyncEvaluation:
464
520
  elif isinstance(model_to_evaluate, dict):
465
521
  # Validate that all required fields are present for model config
466
522
  required_fields = [
467
- "model_name",
523
+ "model",
524
+ "model_source",
468
525
  "max_tokens",
469
526
  "temperature",
470
527
  "system_template",
@@ -480,6 +537,12 @@ class AsyncEvaluation:
480
537
  f"All model config parameters are required when using detailed configuration. "
481
538
  f"Missing: {', '.join(missing_fields)}"
482
539
  )
540
+ if model_to_evaluate.get(
541
+ "model_source"
542
+ ) == "external" and not model_to_evaluate.get("external_api_token"):
543
+ raise ValueError(
544
+ "external_api_token is required when model_source is 'external' for model_to_evaluate"
545
+ )
483
546
  parameters.model_to_evaluate = ModelRequest(**model_to_evaluate)
484
547
 
485
548
  elif type == "score":
@@ -515,7 +578,8 @@ class AsyncEvaluation:
515
578
  elif isinstance(model_to_evaluate, dict):
516
579
  # Validate that all required fields are present for model config
517
580
  required_fields = [
518
- "model_name",
581
+ "model",
582
+ "model_source",
519
583
  "max_tokens",
520
584
  "temperature",
521
585
  "system_template",
@@ -531,6 +595,12 @@ class AsyncEvaluation:
531
595
  f"All model config parameters are required when using detailed configuration. "
532
596
  f"Missing: {', '.join(missing_fields)}"
533
597
  )
598
+ if model_to_evaluate.get(
599
+ "model_source"
600
+ ) == "external" and not model_to_evaluate.get("external_api_token"):
601
+ raise ValueError(
602
+ "external_api_token is required when model_source is 'external' for model_to_evaluate"
603
+ )
534
604
  parameters.model_to_evaluate = ModelRequest(**model_to_evaluate)
535
605
 
536
606
  elif type == "compare":
@@ -575,7 +645,8 @@ class AsyncEvaluation:
575
645
  elif isinstance(model_a, dict):
576
646
  # Validate that all required fields are present for model config
577
647
  required_fields = [
578
- "model_name",
648
+ "model",
649
+ "model_source",
579
650
  "max_tokens",
580
651
  "temperature",
581
652
  "system_template",
@@ -589,6 +660,12 @@ class AsyncEvaluation:
589
660
  f"All model config parameters are required for model_a when using detailed configuration. "
590
661
  f"Missing: {', '.join(missing_fields)}"
591
662
  )
663
+ if model_a.get("model_source") == "external" and not model_a.get(
664
+ "external_api_token"
665
+ ):
666
+ raise ValueError(
667
+ "external_api_token is required when model_source is 'external' for model_a"
668
+ )
592
669
  parameters.model_a = ModelRequest(**model_a)
593
670
 
594
671
  # Handle model_b
@@ -597,7 +674,8 @@ class AsyncEvaluation:
597
674
  elif isinstance(model_b, dict):
598
675
  # Validate that all required fields are present for model config
599
676
  required_fields = [
600
- "model_name",
677
+ "model",
678
+ "model_source",
601
679
  "max_tokens",
602
680
  "temperature",
603
681
  "system_template",
@@ -611,6 +689,12 @@ class AsyncEvaluation:
611
689
  f"All model config parameters are required for model_b when using detailed configuration. "
612
690
  f"Missing: {', '.join(missing_fields)}"
613
691
  )
692
+ if model_b.get("model_source") == "external" and not model_b.get(
693
+ "external_api_token"
694
+ ):
695
+ raise ValueError(
696
+ "external_api_token is required when model_source is 'external' for model_b"
697
+ )
614
698
  parameters.model_b = ModelRequest(**model_b)
615
699
 
616
700
  else:
@@ -1,11 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import os
3
4
  from pathlib import Path
4
5
  from pprint import pformat
5
6
 
6
7
  from together.abstract import api_requestor
8
+ from together.constants import MULTIPART_THRESHOLD_GB, NUM_BYTES_IN_GB
7
9
  from together.error import FileTypeError
8
- from together.filemanager import DownloadManager, UploadManager
10
+ from together.filemanager import DownloadManager, UploadManager, MultipartUploadManager
9
11
  from together.together_response import TogetherResponse
10
12
  from together.types import (
11
13
  FileDeleteResponse,
@@ -30,7 +32,6 @@ class Files:
30
32
  purpose: FilePurpose | str = FilePurpose.FineTune,
31
33
  check: bool = True,
32
34
  ) -> FileResponse:
33
- upload_manager = UploadManager(self._client)
34
35
 
35
36
  if check and purpose == FilePurpose.FineTune:
36
37
  report_dict = check_file(file)
@@ -47,7 +48,15 @@ class Files:
47
48
 
48
49
  assert isinstance(purpose, FilePurpose)
49
50
 
50
- return upload_manager.upload("files", file, purpose=purpose, redirect=True)
51
+ file_size = os.stat(file).st_size
52
+ file_size_gb = file_size / NUM_BYTES_IN_GB
53
+
54
+ if file_size_gb > MULTIPART_THRESHOLD_GB:
55
+ multipart_manager = MultipartUploadManager(self._client)
56
+ return multipart_manager.upload("files", file, purpose)
57
+ else:
58
+ upload_manager = UploadManager(self._client)
59
+ return upload_manager.upload("files", file, purpose=purpose, redirect=True)
51
60
 
52
61
  def list(self) -> FileList:
53
62
  requestor = api_requestor.APIRequestor(