together 1.5.17__py3-none-any.whl → 2.0.0a8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- together/__init__.py +101 -63
- together/_base_client.py +1995 -0
- together/_client.py +1033 -0
- together/_compat.py +219 -0
- together/_constants.py +14 -0
- together/_exceptions.py +108 -0
- together/_files.py +123 -0
- together/_models.py +857 -0
- together/_qs.py +150 -0
- together/_resource.py +43 -0
- together/_response.py +830 -0
- together/_streaming.py +370 -0
- together/_types.py +260 -0
- together/_utils/__init__.py +64 -0
- together/_utils/_compat.py +45 -0
- together/_utils/_datetime_parse.py +136 -0
- together/_utils/_logs.py +25 -0
- together/_utils/_proxy.py +65 -0
- together/_utils/_reflection.py +42 -0
- together/_utils/_resources_proxy.py +24 -0
- together/_utils/_streams.py +12 -0
- together/_utils/_sync.py +58 -0
- together/_utils/_transform.py +457 -0
- together/_utils/_typing.py +156 -0
- together/_utils/_utils.py +421 -0
- together/_version.py +4 -0
- together/lib/.keep +4 -0
- together/lib/__init__.py +23 -0
- together/{cli → lib/cli}/api/endpoints.py +108 -75
- together/lib/cli/api/evals.py +588 -0
- together/{cli → lib/cli}/api/files.py +20 -17
- together/{cli/api/finetune.py → lib/cli/api/fine_tuning.py} +161 -120
- together/lib/cli/api/models.py +140 -0
- together/{cli → lib/cli}/api/utils.py +6 -7
- together/{cli → lib/cli}/cli.py +16 -24
- together/{constants.py → lib/constants.py} +17 -12
- together/lib/resources/__init__.py +11 -0
- together/lib/resources/files.py +999 -0
- together/lib/resources/fine_tuning.py +280 -0
- together/lib/resources/models.py +35 -0
- together/lib/types/__init__.py +13 -0
- together/lib/types/error.py +9 -0
- together/lib/types/fine_tuning.py +455 -0
- together/{utils → lib/utils}/__init__.py +6 -14
- together/{utils → lib/utils}/_log.py +11 -16
- together/lib/utils/files.py +628 -0
- together/lib/utils/serializer.py +10 -0
- together/{utils → lib/utils}/tools.py +19 -55
- together/resources/__init__.py +225 -33
- together/resources/audio/__init__.py +72 -21
- together/resources/audio/audio.py +198 -0
- together/resources/audio/speech.py +574 -122
- together/resources/audio/transcriptions.py +282 -0
- together/resources/audio/translations.py +256 -0
- together/resources/audio/voices.py +135 -0
- together/resources/batches.py +417 -0
- together/resources/chat/__init__.py +30 -21
- together/resources/chat/chat.py +102 -0
- together/resources/chat/completions.py +1063 -263
- together/resources/code_interpreter/__init__.py +33 -0
- together/resources/code_interpreter/code_interpreter.py +258 -0
- together/resources/code_interpreter/sessions.py +135 -0
- together/resources/completions.py +884 -225
- together/resources/embeddings.py +172 -68
- together/resources/endpoints.py +598 -395
- together/resources/evals.py +452 -0
- together/resources/files.py +398 -121
- together/resources/fine_tuning.py +1033 -0
- together/resources/hardware.py +181 -0
- together/resources/images.py +256 -108
- together/resources/jobs.py +214 -0
- together/resources/models.py +238 -90
- together/resources/rerank.py +190 -92
- together/resources/videos.py +374 -0
- together/types/__init__.py +65 -109
- together/types/audio/__init__.py +10 -0
- together/types/audio/speech_create_params.py +75 -0
- together/types/audio/transcription_create_params.py +54 -0
- together/types/audio/transcription_create_response.py +111 -0
- together/types/audio/translation_create_params.py +40 -0
- together/types/audio/translation_create_response.py +70 -0
- together/types/audio/voice_list_response.py +23 -0
- together/types/audio_speech_stream_chunk.py +16 -0
- together/types/autoscaling.py +13 -0
- together/types/autoscaling_param.py +15 -0
- together/types/batch_create_params.py +24 -0
- together/types/batch_create_response.py +14 -0
- together/types/batch_job.py +45 -0
- together/types/batch_list_response.py +10 -0
- together/types/chat/__init__.py +18 -0
- together/types/chat/chat_completion.py +60 -0
- together/types/chat/chat_completion_chunk.py +61 -0
- together/types/chat/chat_completion_structured_message_image_url_param.py +18 -0
- together/types/chat/chat_completion_structured_message_text_param.py +13 -0
- together/types/chat/chat_completion_structured_message_video_url_param.py +18 -0
- together/types/chat/chat_completion_usage.py +13 -0
- together/types/chat/chat_completion_warning.py +9 -0
- together/types/chat/completion_create_params.py +329 -0
- together/types/code_interpreter/__init__.py +5 -0
- together/types/code_interpreter/session_list_response.py +31 -0
- together/types/code_interpreter_execute_params.py +45 -0
- together/types/completion.py +42 -0
- together/types/completion_chunk.py +66 -0
- together/types/completion_create_params.py +138 -0
- together/types/dedicated_endpoint.py +44 -0
- together/types/embedding.py +24 -0
- together/types/embedding_create_params.py +31 -0
- together/types/endpoint_create_params.py +43 -0
- together/types/endpoint_list_avzones_response.py +11 -0
- together/types/endpoint_list_params.py +18 -0
- together/types/endpoint_list_response.py +41 -0
- together/types/endpoint_update_params.py +27 -0
- together/types/eval_create_params.py +263 -0
- together/types/eval_create_response.py +16 -0
- together/types/eval_list_params.py +21 -0
- together/types/eval_list_response.py +10 -0
- together/types/eval_status_response.py +100 -0
- together/types/evaluation_job.py +139 -0
- together/types/execute_response.py +108 -0
- together/types/file_delete_response.py +13 -0
- together/types/file_list.py +12 -0
- together/types/file_purpose.py +9 -0
- together/types/file_response.py +31 -0
- together/types/file_type.py +7 -0
- together/types/fine_tuning_cancel_response.py +194 -0
- together/types/fine_tuning_content_params.py +24 -0
- together/types/fine_tuning_delete_params.py +11 -0
- together/types/fine_tuning_delete_response.py +12 -0
- together/types/fine_tuning_list_checkpoints_response.py +21 -0
- together/types/fine_tuning_list_events_response.py +12 -0
- together/types/fine_tuning_list_response.py +199 -0
- together/types/finetune_event.py +41 -0
- together/types/finetune_event_type.py +33 -0
- together/types/finetune_response.py +177 -0
- together/types/hardware_list_params.py +16 -0
- together/types/hardware_list_response.py +58 -0
- together/types/image_data_b64.py +15 -0
- together/types/image_data_url.py +15 -0
- together/types/image_file.py +23 -0
- together/types/image_generate_params.py +85 -0
- together/types/job_list_response.py +47 -0
- together/types/job_retrieve_response.py +43 -0
- together/types/log_probs.py +18 -0
- together/types/model_list_response.py +10 -0
- together/types/model_object.py +42 -0
- together/types/model_upload_params.py +36 -0
- together/types/model_upload_response.py +23 -0
- together/types/rerank_create_params.py +36 -0
- together/types/rerank_create_response.py +36 -0
- together/types/tool_choice.py +23 -0
- together/types/tool_choice_param.py +23 -0
- together/types/tools_param.py +23 -0
- together/types/training_method_dpo.py +22 -0
- together/types/training_method_sft.py +18 -0
- together/types/video_create_params.py +86 -0
- together/types/video_job.py +57 -0
- together-2.0.0a8.dist-info/METADATA +680 -0
- together-2.0.0a8.dist-info/RECORD +164 -0
- {together-1.5.17.dist-info → together-2.0.0a8.dist-info}/WHEEL +1 -1
- together-2.0.0a8.dist-info/entry_points.txt +2 -0
- {together-1.5.17.dist-info → together-2.0.0a8.dist-info/licenses}/LICENSE +1 -1
- together/abstract/api_requestor.py +0 -729
- together/cli/api/chat.py +0 -276
- together/cli/api/completions.py +0 -119
- together/cli/api/images.py +0 -93
- together/cli/api/models.py +0 -55
- together/client.py +0 -176
- together/error.py +0 -194
- together/filemanager.py +0 -389
- together/legacy/__init__.py +0 -0
- together/legacy/base.py +0 -27
- together/legacy/complete.py +0 -93
- together/legacy/embeddings.py +0 -27
- together/legacy/files.py +0 -146
- together/legacy/finetune.py +0 -177
- together/legacy/images.py +0 -27
- together/legacy/models.py +0 -44
- together/resources/batch.py +0 -136
- together/resources/code_interpreter.py +0 -82
- together/resources/finetune.py +0 -1064
- together/together_response.py +0 -50
- together/types/abstract.py +0 -26
- together/types/audio_speech.py +0 -110
- together/types/batch.py +0 -53
- together/types/chat_completions.py +0 -197
- together/types/code_interpreter.py +0 -57
- together/types/common.py +0 -66
- together/types/completions.py +0 -107
- together/types/embeddings.py +0 -35
- together/types/endpoints.py +0 -123
- together/types/error.py +0 -16
- together/types/files.py +0 -90
- together/types/finetune.py +0 -398
- together/types/images.py +0 -44
- together/types/models.py +0 -45
- together/types/rerank.py +0 -43
- together/utils/api_helpers.py +0 -124
- together/utils/files.py +0 -425
- together/version.py +0 -6
- together-1.5.17.dist-info/METADATA +0 -525
- together-1.5.17.dist-info/RECORD +0 -69
- together-1.5.17.dist-info/entry_points.txt +0 -3
- /together/{abstract → lib/cli}/__init__.py +0 -0
- /together/{cli → lib/cli/api}/__init__.py +0 -0
- /together/{cli/api/__init__.py → py.typed} +0 -0
together/lib/resources/files.py
@@ -0,0 +1,999 @@
from __future__ import annotations

import os
import math
import stat
import uuid
import shutil
import asyncio
import logging
import tempfile
from typing import IO, Any, Dict, List, Tuple, cast
from pathlib import Path
from functools import partial
from concurrent.futures import Future, ThreadPoolExecutor, as_completed

import httpx
from tqdm import tqdm
from filelock import FileLock
from tqdm.utils import CallbackIOWrapper

from ...types import FileType, FilePurpose, FileResponse
from ..._types import RequestOptions
from ..constants import (
    DISABLE_TQDM,
    NUM_BYTES_IN_GB,
    MAX_FILE_SIZE_GB,
    MIN_PART_SIZE_MB,
    DOWNLOAD_BLOCK_SIZE,
    MAX_MULTIPART_PARTS,
    TARGET_PART_SIZE_MB,
    MAX_CONCURRENT_PARTS,
    MULTIPART_THRESHOLD_GB,
    MULTIPART_UPLOAD_TIMEOUT,
)
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..types.error import DownloadError, FileTypeError
from ..._exceptions import APIStatusError, AuthenticationError

log: logging.Logger = logging.getLogger(__name__)


def chmod_and_replace(src: Path, dst: Path) -> None:
    """Set the correct permissions before moving a blob from the tmp directory to the cache dir.

    Does not take the process `umask` into account, as there is no convenient
    thread-safe way to read it.
    """

    # Get umask by creating a temporary file in the cache folder.
    tmp_file = dst.parent / f"tmp_{uuid.uuid4()}"

    try:
        tmp_file.touch()

        cache_dir_mode = Path(tmp_file).stat().st_mode

        os.chmod(src.as_posix(), stat.S_IMODE(cache_dir_mode))

    finally:
        tmp_file.unlink()

    shutil.move(src.as_posix(), dst.as_posix())

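# Editorial aside (not part of the package): the probe-file dance above exists
# because os.umask() can only be read by also setting it, which is racy in
# threaded code, so the helper infers the effective mode from a freshly created
# file instead. A minimal sketch of the same inference, using only names
# imported in this module:
#
#     probe = Path(tempfile.gettempdir()) / f"probe_{uuid.uuid4()}"
#     probe.touch()
#     mode = stat.S_IMODE(probe.stat().st_mode)  # e.g. 0o644 under umask 0o022
#     probe.unlink()
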
def _get_file_size(
    headers: httpx.Headers,
) -> int:
    """
    Extracts file size from header
    """
    total_size_in_bytes = 0

    parts = headers.get("Content-Range", "").split(" ")

    if len(parts) == 2:
        range_parts = parts[1].split("/")

        if len(range_parts) == 2:
            total_size_in_bytes = int(range_parts[1])

    assert total_size_in_bytes != 0, "Unable to retrieve remote file."

    return total_size_in_bytes

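# Editorial example (not part of the package): the helper above reads the total
# size from a Content-Range header, which is why get_file_metadata() below sends
# "Range: bytes=0-1" — the partial response carries the full length after the "/":
#
#     headers = httpx.Headers({"Content-Range": "bytes 0-1/1069547520"})
#     _get_file_size(headers)  # -> 1069547520
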
def _prepare_output(
    headers: httpx.Headers,
    step: int = -1,
    output: Path | None = None,
    remote_name: str | None = None,
) -> Path:
    """
    Generates output file name from remote name and headers
    """
    if output:
        return output

    content_type = str(headers.get("content-type"))

    assert remote_name, "No model name found in fine_tuning object. Please specify an `output` file name."

    if step > 0:
        remote_name += f"-checkpoint-{step}"

    if "x-tar" in content_type.lower():
        remote_name += ".tar.gz"

    else:
        remote_name += ".tar.zst"

    return Path(remote_name)

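# Editorial example (not part of the package): deriving a checkpoint file name
# for a zstd-compressed archive:
#
#     headers = httpx.Headers({"content-type": "application/zstd"})
#     _prepare_output(headers, step=100, remote_name="my-model")
#     # -> Path("my-model-checkpoint-100.tar.zst")
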
class DownloadManager(SyncAPIResource):
    def get_file_metadata(
        self,
        url: str,
        output: Path | None = None,
        remote_name: str | None = None,
        fetch_metadata: bool = False,
    ) -> Tuple[Path, int]:
        """
        Gets the remote file head and parses out the file name and file size
        """

        if not fetch_metadata:
            if isinstance(output, Path):
                file_path = output
            else:
                assert isinstance(remote_name, str)
                file_path = Path(remote_name)

            return file_path, 0

        try:
            response = self._client.get(
                path=url,
                options=RequestOptions(
                    headers={"Range": "bytes=0-1"},
                ),
                cast_to=httpx.Response,
                stream=False,
            )
        except APIStatusError as e:
            raise APIStatusError(
                "Error fetching file metadata",
                response=e.response,
                body=e.body,
            ) from e

        headers = response.headers

        assert isinstance(headers, httpx.Headers)

        file_path = _prepare_output(
            headers=headers,
            output=output,
            remote_name=remote_name,
        )

        file_size = _get_file_size(headers)

        return file_path, file_size

    def download(
        self,
        url: str,
        output: Path | None = None,
        remote_name: str | None = None,
        fetch_metadata: bool = False,
    ) -> Tuple[str, int]:
        # pre-fetch remote file name and file size
        file_path, file_size = self.get_file_metadata(url, output, remote_name, fetch_metadata)

        temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=file_path.parent, delete=False)

        # Prevent parallel downloads of the same file with a lock.
        lock_path = Path(file_path.as_posix() + ".lock")

        with FileLock(lock_path.as_posix()):
            with temp_file_manager() as temp_file:
                try:
                    response = self._client.get(
                        path=url,
                        cast_to=httpx.Response,
                        stream=True,
                    )
                except APIStatusError as e:
                    os.remove(lock_path)
                    raise APIStatusError(
                        "Error downloading file",
                        response=e.response,
                        body=e.body,
                    ) from e

                if not fetch_metadata:
                    file_size = int(response.headers.get("content-length", 0))

                assert file_size != 0, "Unable to retrieve remote file."

                with tqdm(
                    total=file_size,
                    unit="B",
                    unit_scale=True,
                    desc=f"Downloading file {file_path.name}",
                    disable=bool(DISABLE_TQDM),
                ) as pbar:
                    for chunk in response.iter_bytes(DOWNLOAD_BLOCK_SIZE):
                        pbar.update(len(chunk))
                        temp_file.write(chunk)  # type: ignore

            # Raise exception if remote file size does not match downloaded file size
            if os.stat(temp_file.name).st_size != file_size:
                raise DownloadError(
                    f"Downloaded file size `{pbar.n}` bytes does not match remote file size `{file_size}` bytes."
                )

            # Moves temp file to output file path
            chmod_and_replace(Path(temp_file.name), file_path)

        os.remove(lock_path)

        return str(file_path.resolve()), file_size

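# Editorial usage sketch (hedged: the endpoint path and client wiring below are
# illustrative assumptions, not confirmed by this diff — only the DownloadManager
# API itself is). fetch_metadata=True resolves the name and size via a ranged GET
# before streaming the body under a file lock:
#
#     manager = DownloadManager(client)  # client: a configured together client
#     path, size = manager.download(
#         "/finetune/download?ft_id=ft-xxxx",  # hypothetical path
#         remote_name="my-model",
#         fetch_metadata=True,
#     )
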
class UploadManager(SyncAPIResource):
    def get_upload_url(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
        filetype: FileType,
    ) -> Tuple[str, str]:
        data = {
            "purpose": purpose,
            "file_name": file.name,
            "file_type": filetype,
        }

        try:
            response = self._client.post(
                path=url,
                cast_to=httpx.Response,
                body=data,
                options={"headers": {"Content-Type": "multipart/form-data"}, "follow_redirects": False},
            )
        except APIStatusError as e:
            if e.response.status_code == 401:
                raise AuthenticationError(
                    "This job would exceed your free trial credits. "
                    "Please upgrade to a paid account through "
                    "Settings -> Billing on api.together.ai to continue.",
                    response=e.response,
                    body=e.body,
                ) from e
            if e.response.status_code != 302:
                raise APIStatusError(
                    f"Unexpected error raised by endpoint: {e.response.content.decode()}, headers: {e.response.headers}",
                    response=e.response,
                    body=e.response.content.decode(),
                ) from e
            response = e.response

        redirect_url = response.headers.get("Location")
        file_id = response.headers.get("X-Together-File-Id")

        if not redirect_url or not file_id:
            raise APIStatusError(
                f"Missing required headers in response. Location: {redirect_url}, File-Id: {file_id}",
                response=response,
                body=response.content.decode() if hasattr(response, "content") else "",
            )

        return redirect_url, file_id

    def callback(self, url: str) -> FileResponse:
        response = self._client.post(
            cast_to=FileResponse,
            path=url,
        )

        return response

    def upload(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
    ) -> FileResponse:
        file_size = os.stat(file.as_posix()).st_size
        file_size_gb = file_size / NUM_BYTES_IN_GB

        if file_size_gb > MAX_FILE_SIZE_GB:
            raise FileTypeError(
                f"File size {file_size_gb:.1f}GB exceeds maximum supported size of {MAX_FILE_SIZE_GB}GB"
            )

        if file_size_gb > MULTIPART_THRESHOLD_GB:
            multipart_manager = MultipartUploadManager(self._client)
            return multipart_manager.upload(url, file, purpose)
        else:
            return self._upload_single_file(url, file, purpose)

    def _upload_single_file(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
    ) -> FileResponse:
        file_id = None

        redirect_url = None
        if file.suffix == ".jsonl":
            filetype = "jsonl"
        elif file.suffix == ".parquet":
            filetype = "parquet"
        else:
            raise FileTypeError(
                f"Unknown extension of file {file}. Only files with extensions .jsonl and .parquet are supported."
            )
        redirect_url, file_id = self.get_upload_url(url, file, purpose, filetype)  # type: ignore

        file_size = os.stat(file.as_posix()).st_size

        with tqdm(
            total=file_size,
            unit="B",
            unit_scale=True,
            desc=f"Uploading file {file.name}",
            disable=bool(DISABLE_TQDM),
        ) as pbar:
            with file.open("rb") as f:
                wrapped_file = cast(IO[bytes], CallbackIOWrapper(pbar.update, f, "read"))

                assert redirect_url is not None
                callback_response = self._client._client.put(
                    url=redirect_url,
                    content=wrapped_file.read(),
                )
                log.debug(
                    'HTTP Response: %s %s "%i %s" %s',
                    "put",
                    redirect_url,
                    callback_response.status_code,
                    callback_response.reason_phrase,
                    callback_response.headers,
                )

        assert isinstance(callback_response, httpx.Response)  # type: ignore

        if not callback_response.status_code == 200:
            raise APIStatusError(
                f"Error during file upload: {callback_response.content.decode()}, headers: {callback_response.headers}",
                response=callback_response,
                body=callback_response.content.decode(),
            )

        response = self.callback(f"{url}/{file_id}/preprocess")

        assert isinstance(response, FileResponse)  # type: ignore

        return response

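# Editorial note: UploadManager.upload() routes on size — files larger than
# MULTIPART_THRESHOLD_GB go through MultipartUploadManager below; smaller ones do
# a single PUT to a signed URL that get_upload_url() recovers from a deliberately
# unfollowed 302 (the Location and X-Together-File-Id headers are the payload).
# A minimal usage sketch, assuming a configured client and that the route and
# "fine-tune" purpose value shown here are valid (not confirmed by this diff):
#
#     manager = UploadManager(client)
#     file_response = manager.upload("files", Path("train.jsonl"), purpose="fine-tune")
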
class MultipartUploadManager(SyncAPIResource):
    """Handles multipart uploads for large files"""

    def __init__(self, client: Any) -> None:  # Accept any client type
        super().__init__(client)
        self.max_concurrent_parts = MAX_CONCURRENT_PARTS

    def upload(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
    ) -> FileResponse:
        """Upload large file using multipart upload"""

        file_size = os.stat(file.as_posix()).st_size
        file_size_gb = file_size / NUM_BYTES_IN_GB

        if file_size_gb > MAX_FILE_SIZE_GB:
            raise FileTypeError(
                f"File size {file_size_gb:.1f}GB exceeds maximum supported size of {MAX_FILE_SIZE_GB}GB"
            )

        part_size, num_parts = _calculate_parts(file_size)
        file_type = self._get_file_type(file)
        upload_info = None

        try:
            upload_info = self._initiate_upload(url, file, file_size, num_parts, purpose, file_type)

            completed_parts = self._upload_parts_concurrent(file, upload_info, part_size)

            upload_id = upload_info.get("upload_id")
            file_id = upload_info.get("file_id")
            if not upload_id or not file_id:
                raise ValueError("Missing upload_id or file_id from initiate response")

            return self._complete_upload(url, upload_id, file_id, completed_parts)

        except Exception as e:
            if upload_info is not None:
                upload_id = upload_info.get("upload_id")
                file_id = upload_info.get("file_id")
                if upload_id and file_id:
                    self._abort_upload(url, upload_id, file_id)
            raise e

    def _get_file_type(self, file: Path) -> str:
        """Get file type from extension"""
        if file.suffix == ".jsonl":
            return "jsonl"
        elif file.suffix == ".parquet":
            return "parquet"
        elif file.suffix == ".csv":
            return "csv"
        else:
            raise ValueError(
                f"Unsupported file extension: '{file.suffix}'. Supported extensions: .jsonl, .parquet, .csv"
            )

    def _initiate_upload(
        self,
        url: str,
        file: Path,
        file_size: int,
        num_parts: int,
        purpose: FilePurpose,
        file_type: str,
    ) -> Dict[str, Any]:
        """Initiate multipart upload with backend"""

        payload: Dict[str, Any] = {
            "file_name": file.name,
            "file_size": file_size,
            "num_parts": num_parts,
            "purpose": str(purpose),
            "file_type": file_type,
        }

        try:
            response = self._client.post(
                path=f"{url}/multipart/initiate",
                cast_to=httpx.Response,
                body=payload,
                options={"headers": {"Content-Type": "application/json"}},
            )
        except APIStatusError as e:
            if e.response.status_code == 400:
                response = e.response
            else:
                raise e from e

        if response.status_code == 200:
            return cast(Dict[str, Any], response.json())
        else:
            raise APIStatusError(
                f"Failed to initiate multipart upload: {response.text}",
                response=response,
                body=response.text,
            )

    def _submit_part(
        self,
        executor: ThreadPoolExecutor,
        file_handle: IO[bytes],
        part_info: Dict[str, Any],
        part_size: int,
    ) -> Tuple[Future[str], int]:
        """Submit a single part for upload and return its future and part number."""

        part_number = part_info.get("PartNumber", part_info.get("part_number", 1))
        file_handle.seek((part_number - 1) * part_size)
        part_data = file_handle.read(part_size)

        future = executor.submit(self._upload_single_part, part_info, part_data)
        return future, part_number

    def _upload_parts_concurrent(self, file: Path, upload_info: Dict[str, Any], part_size: int) -> List[Dict[str, Any]]:
        """Upload file parts concurrently with progress tracking"""

        parts = upload_info["parts"]
        completed_parts: List[Dict[str, Any]] = []

        with ThreadPoolExecutor(max_workers=self.max_concurrent_parts) as executor:
            with tqdm(total=len(parts), desc="Uploading parts", unit="part", disable=bool(DISABLE_TQDM)) as pbar:
                with open(file, "rb") as f:
                    future_to_part: Dict[Future[str], int] = {}
                    part_index = 0

                    while part_index < len(parts) and len(future_to_part) < self.max_concurrent_parts:
                        part_info = parts[part_index]
                        future, part_number = self._submit_part(executor, f, part_info, part_size)
                        future_to_part[future] = part_number
                        part_index += 1

                    while future_to_part:
                        done_future = next(as_completed(future_to_part))
                        part_number = future_to_part.pop(done_future)

                        try:
                            etag = done_future.result()
                            completed_parts.append({"part_number": part_number, "etag": etag})
                            pbar.update(1)
                        except Exception as e:
                            raise Exception(f"Failed to upload part {part_number}: {e}") from e

                        if part_index < len(parts):
                            part_info = parts[part_index]
                            future, next_part_number = self._submit_part(executor, f, part_info, part_size)
                            future_to_part[future] = next_part_number
                            part_index += 1

        completed_parts.sort(key=lambda x: x["part_number"])
        return completed_parts

    def _upload_single_part(self, part_info: Dict[str, Any], part_data: bytes) -> str:
        """Upload a single part and return ETag"""

        upload_url = part_info.get("URL", part_info.get("UploadURL"))
        if not upload_url:
            raise ValueError("Missing upload URL in part info")

        part_headers = part_info.get("Headers", {})

        response = self._client._client.put(
            url=upload_url,
            content=part_data,
            headers=part_headers,
            timeout=MULTIPART_UPLOAD_TIMEOUT,
        )
        response.raise_for_status()

        etag = str(response.headers.get("ETag", "")).strip('"')
        if not etag:
            part_number = part_info.get("PartNumber", part_info.get("part_number", "unknown"))
            raise APIStatusError(
                f"No ETag returned for part {part_number}",
                response=response,
                body=response.content.decode(),
            )

        return etag

    def _complete_upload(
        self,
        url: str,
        upload_id: str,
        file_id: str,
        completed_parts: List[Dict[str, Any]],
    ) -> FileResponse:
        """Complete the multipart upload"""

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
            "parts": completed_parts,
        }

        try:
            response = self._client.post(
                path=f"{url}/multipart/complete",
                cast_to=httpx.Response,
                body=payload,
                options={"headers": {"Content-Type": "application/json"}},
            )
        except APIStatusError as e:
            if e.response.status_code == 400:
                response = e.response
            else:
                raise e from e

        if response.status_code == 200:
            response_data = response.json()
            file_data = response_data.get("file", response_data)
            return FileResponse(**file_data)
        else:
            raise APIStatusError(
                f"Failed to complete multipart upload: {response.text}",
                response=response,
                body=response.text,
            )

    def _abort_upload(self, url: str, upload_id: str, file_id: str) -> None:
        """Abort the multipart upload"""

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
        }

        self._client.post(
            path=f"{url}/multipart/abort",
            cast_to=dict,
            body=payload,
            options={"headers": {"Content-Type": "application/json"}},
        )

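# Editorial sketch (not part of the package) of the sliding-window pattern that
# _upload_parts_concurrent() implements: seed up to max_concurrent_parts futures,
# then replace each finished future with the next pending part, so at most N
# parts are in flight (and buffered in memory) at once:
#
#     def run_windowed(items, work, width):
#         results = []
#         with ThreadPoolExecutor(max_workers=width) as pool:
#             live = {pool.submit(work, item): n for n, item in enumerate(items[:width])}
#             i = len(live)
#             while live:
#                 done = next(as_completed(live))
#                 results.append((live.pop(done), done.result()))
#                 if i < len(items):
#                     live[pool.submit(work, items[i])] = i
#                     i += 1
#         return [r for _, r in sorted(results)]
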
class AsyncUploadManager(AsyncAPIResource):
    async def get_upload_url(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
        filetype: FileType,
    ) -> Tuple[str, str]:
        data = {
            "purpose": str(purpose),
            "file_name": file.name,
            "file_type": filetype,
        }

        try:
            response = await self._client.post(
                path=url,
                cast_to=httpx.Response,
                body=data,
                options={"headers": {"Content-Type": "multipart/form-data"}, "follow_redirects": False},
            )
        except APIStatusError as e:
            if e.response.status_code == 401:
                raise AuthenticationError(
                    "This job would exceed your free trial credits. "
                    "Please upgrade to a paid account through "
                    "Settings -> Billing on api.together.ai to continue.",
                    response=e.response,
                    body=e.body,
                ) from e
            if e.response.status_code != 302:
                raise APIStatusError(
                    f"Unexpected error raised by endpoint: {e.response.content.decode()}, headers: {e.response.headers}",
                    response=e.response,
                    body=e.response.content.decode(),
                ) from e
            response = e.response

        redirect_url = response.headers.get("Location")
        file_id = response.headers.get("X-Together-File-Id")

        if not redirect_url or not file_id:
            # Mock server scenario - return mock values for testing
            if response.status_code == 200:
                return "https://mock-upload-url.com", "mock-file-id"
            else:
                raise APIStatusError(
                    f"Missing required headers in response. Location: {redirect_url}, File-Id: {file_id}",
                    response=response,
                    body=response.content.decode() if hasattr(response, "content") else "",
                )

        return redirect_url, file_id

    async def callback(self, url: str) -> FileResponse:
        response = self._client.post(
            cast_to=FileResponse,
            path=url,
        )

        return await response

    async def upload(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
    ) -> FileResponse:
        file_size = os.stat(file.as_posix()).st_size
        file_size_gb = file_size / NUM_BYTES_IN_GB

        if file_size_gb > MAX_FILE_SIZE_GB:
            raise FileTypeError(
                f"File size {file_size_gb:.1f}GB exceeds maximum supported size of {MAX_FILE_SIZE_GB}GB"
            )

        if file_size_gb > MULTIPART_THRESHOLD_GB:
            multipart_manager = AsyncMultipartUploadManager(self._client)
            return await multipart_manager.upload(url, file, purpose)
        else:
            return await self._upload_single_file(url, file, purpose)

    async def _upload_single_file(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
    ) -> FileResponse:
        file_id = None

        redirect_url = None
        if file.suffix == ".jsonl":
            filetype = "jsonl"
        elif file.suffix == ".parquet":
            filetype = "parquet"
        else:
            raise FileTypeError(
                f"Unknown extension of file {file}. Only files with extensions .jsonl and .parquet are supported."
            )
        redirect_url, file_id = await self.get_upload_url(url, file, purpose, filetype)  # type: ignore

        file_size = os.stat(file.as_posix()).st_size

        with tqdm(
            total=file_size,
            unit="B",
            unit_scale=True,
            desc=f"Uploading file {file.name}",
            disable=bool(DISABLE_TQDM),
        ) as pbar:
            with file.open("rb") as f:
                wrapped_file = cast(IO[bytes], CallbackIOWrapper(pbar.update, f, "read"))

                assert redirect_url is not None
                callback_response = await self._client._client.put(
                    url=redirect_url,
                    content=wrapped_file.read(),
                )
                log.debug(
                    'HTTP Response: %s %s "%i %s" %s',
                    "put",
                    redirect_url,
                    callback_response.status_code,
                    callback_response.reason_phrase,
                    callback_response.headers,
                )

        assert isinstance(callback_response, httpx.Response)  # type: ignore

        if not callback_response.status_code == 200:
            raise APIStatusError(
                f"Error during file upload: {callback_response.content.decode()}, headers: {callback_response.headers}",
                response=callback_response,
                body=callback_response.content.decode(),
            )

        response = await self.callback(f"{url}/{file_id}/preprocess")

        assert isinstance(response, FileResponse)  # type: ignore

        return response


class AsyncMultipartUploadManager(AsyncAPIResource):
    """Handles async multipart uploads using ThreadPoolExecutor for efficiency"""

    def __init__(self, client: Any) -> None:  # Accept any client type
        super().__init__(client)
        self.max_concurrent_parts = MAX_CONCURRENT_PARTS

    async def upload(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
    ) -> FileResponse:
        """Upload large file using multipart upload via ThreadPoolExecutor"""

        file_size = os.stat(file.as_posix()).st_size
        file_size_gb = file_size / NUM_BYTES_IN_GB

        if file_size_gb > MAX_FILE_SIZE_GB:
            raise FileTypeError(
                f"File size {file_size_gb:.1f}GB exceeds maximum supported size of {MAX_FILE_SIZE_GB}GB"
            )

        part_size, num_parts = _calculate_parts(file_size)
        file_type = self._get_file_type(file)
        upload_info = None

        try:
            upload_info = await self._initiate_upload(url, file, file_size, num_parts, purpose, file_type)

            completed_parts = await self._upload_parts_concurrent(file, upload_info, part_size)

            upload_id = upload_info.get("upload_id")
            file_id = upload_info.get("file_id")
            if not upload_id or not file_id:
                raise ValueError("Missing upload_id or file_id from initiate response")

            return await self._complete_upload(url, upload_id, file_id, completed_parts)

        except Exception as e:
            if upload_info is not None:
                upload_id = upload_info.get("upload_id")
                file_id = upload_info.get("file_id")
                if upload_id and file_id:
                    await self._abort_upload(url, upload_id, file_id)
            raise e

    def _get_file_type(self, file: Path) -> str:
        """Get file type from extension"""
        if file.suffix == ".jsonl":
            return "jsonl"
        elif file.suffix == ".parquet":
            return "parquet"
        elif file.suffix == ".csv":
            return "csv"
        else:
            raise ValueError(
                f"Unsupported file extension: '{file.suffix}'. Supported extensions: .jsonl, .parquet, .csv"
            )

    async def _initiate_upload(
        self,
        url: str,
        file: Path,
        file_size: int,
        num_parts: int,
        purpose: FilePurpose,
        file_type: str,
    ) -> Dict[str, Any]:
        """Initiate multipart upload with backend"""

        payload = {
            "file_name": file.name,
            "file_size": file_size,
            "num_parts": num_parts,
            "purpose": str(purpose),
            "file_type": file_type,
        }

        try:
            response = await self._client.post(
                path=f"{url}/multipart/initiate",
                cast_to=httpx.Response,
                body=payload,
                options={"headers": {"Content-Type": "application/json"}},
            )
        except APIStatusError as e:
            if e.response.status_code == 400:
                response = e.response
            else:
                raise e from e

        if response.status_code == 200:
            return cast(Dict[str, Any], response.json())
        else:
            raise APIStatusError(
                f"Failed to initiate multipart upload: {response.text}",
                response=response,
                body=response.text,
            )

    async def _upload_parts_concurrent(
        self, file: Path, upload_info: Dict[str, Any], part_size: int
    ) -> List[Dict[str, Any]]:
        """Upload file parts concurrently using ThreadPoolExecutor"""

        parts = upload_info["parts"]
        completed_parts: List[Dict[str, Any]] = []

        # Use ThreadPoolExecutor for HTTP I/O efficiency
        loop = asyncio.get_event_loop()

        with ThreadPoolExecutor(max_workers=self.max_concurrent_parts) as executor:
            with tqdm(total=len(parts), desc="Uploading parts", unit="part", disable=bool(DISABLE_TQDM)) as pbar:
                with open(file, "rb") as f:
                    future_to_part: Dict[asyncio.Future[str], int] = {}
                    part_index = 0

                    while part_index < len(parts) and len(future_to_part) < self.max_concurrent_parts:
                        part_info = parts[part_index]
                        part_number = part_info.get("PartNumber", part_info.get("part_number", 1))
                        f.seek((part_number - 1) * part_size)
                        part_data = f.read(part_size)

                        future = loop.run_in_executor(executor, self._upload_single_part_sync, part_info, part_data)
                        future_to_part[future] = part_number
                        part_index += 1

                    while future_to_part:
                        done, _ = await asyncio.wait(
                            tuple(future_to_part.keys()),
                            return_when=asyncio.FIRST_COMPLETED,
                        )

                        for done_future in done:
                            part_number = future_to_part.pop(done_future)

                            try:
                                etag = await done_future
                                completed_parts.append({"part_number": part_number, "etag": etag})
                                pbar.update(1)
                            except Exception as e:
                                raise Exception(f"Failed to upload part {part_number}: {e}") from e

                            if part_index < len(parts):
                                part_info = parts[part_index]
                                next_part_number = part_info.get("PartNumber", part_info.get("part_number", 1))
                                f.seek((next_part_number - 1) * part_size)
                                part_data = f.read(part_size)
                                future = loop.run_in_executor(
                                    executor, self._upload_single_part_sync, part_info, part_data
                                )
                                future_to_part[future] = next_part_number
                                part_index += 1

        completed_parts.sort(key=lambda x: x["part_number"])
        return completed_parts

    def _upload_single_part_sync(self, part_info: Dict[str, Any], part_data: bytes) -> str:
        """Sync version of single part upload for use in ThreadPoolExecutor"""

        upload_url = part_info.get("URL", part_info.get("UploadURL"))
        if not upload_url:
            raise ValueError("Missing upload URL in part info")

        part_headers = part_info.get("Headers", {})

        with httpx.Client() as client:
            response = client.put(
                url=upload_url,
                content=part_data,
                headers=part_headers,
                timeout=MULTIPART_UPLOAD_TIMEOUT,
            )
            response.raise_for_status()

        etag = str(response.headers.get("ETag", "")).strip('"')
        if not etag:
            part_number = part_info.get("PartNumber", part_info.get("part_number", "unknown"))
            raise ValueError(f"No ETag returned for part {part_number}")

        return etag

    async def _complete_upload(
        self,
        url: str,
        upload_id: str,
        file_id: str,
        completed_parts: List[Dict[str, Any]],
    ) -> FileResponse:
        """Complete the multipart upload"""

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
            "parts": completed_parts,
        }

        try:
            response = await self._client.post(
                path=f"{url}/multipart/complete",
                cast_to=httpx.Response,
                body=payload,
                options={"headers": {"Content-Type": "application/json"}},
            )
        except APIStatusError as e:
            if e.response.status_code == 400:
                response = e.response
            else:
                raise e from e

        if response.status_code == 200:
            response_data = response.json()
            file_data = response_data.get("file", response_data)
            return FileResponse(**file_data)
        else:
            raise APIStatusError(
                f"Failed to complete multipart upload: {response.text}",
                response=response,
                body=response.text,
            )

    async def _abort_upload(self, url: str, upload_id: str, file_id: str) -> None:
        """Abort the multipart upload"""

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
        }

        await self._client.post(
            path=f"{url}/multipart/abort",
            cast_to=dict,
            body=payload,
            options={"headers": {"Content-Type": "application/json"}},
        )

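# Editorial note: the async variant above keeps the same sliding window but
# offloads each blocking httpx PUT to the executor via loop.run_in_executor()
# and drains completions with asyncio.wait(..., return_when=FIRST_COMPLETED).
# A minimal sketch of that core move (toy code, not the package's):
#
#     async def run_blocking_batch(blocking_fn, items, width=4):
#         loop = asyncio.get_event_loop()
#         results = []
#         with ThreadPoolExecutor(max_workers=width) as pool:
#             pending = {loop.run_in_executor(pool, blocking_fn, item) for item in items}
#             while pending:
#                 done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
#                 results.extend(f.result() for f in done)
#         return results
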
def _calculate_parts(file_size: int) -> Tuple[int, int]:
    """Calculate optimal part size and count"""
    min_part_size = MIN_PART_SIZE_MB * 1024 * 1024  # 5MB
    target_part_size = TARGET_PART_SIZE_MB * 1024 * 1024  # 100MB

    if file_size <= target_part_size:
        return file_size, 1

    num_parts = min(MAX_MULTIPART_PARTS, math.ceil(file_size / target_part_size))
    part_size = math.ceil(file_size / num_parts)

    if part_size < min_part_size:
        part_size = min_part_size
        num_parts = math.ceil(file_size / part_size)

    return part_size, num_parts
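
# Editorial worked example (assumes the constants match their inline comments
# above, i.e. TARGET_PART_SIZE_MB = 100 and MIN_PART_SIZE_MB = 5; the real values
# live in lib/constants.py):
#
#     _calculate_parts(50 * 1024 * 1024)  # -> (52428800, 1): at or under target, single part
#     _calculate_parts(1024 ** 3)         # -> (97612894, 11): ceil(1 GiB / 100 MiB) = 11 parts
#
# With those values the MIN_PART_SIZE_MB floor is purely defensive: capping
# num_parts at MAX_MULTIPART_PARTS only makes parts larger, never smaller.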