lm-deluge 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic. Click here for more details.
- lm_deluge/__init__.py +6 -0
- lm_deluge/api_requests/__init__.py +3 -0
- lm_deluge/api_requests/anthropic.py +177 -0
- lm_deluge/api_requests/base.py +375 -0
- lm_deluge/api_requests/cohere.py +138 -0
- lm_deluge/api_requests/common.py +18 -0
- lm_deluge/api_requests/deprecated/bedrock.py +288 -0
- lm_deluge/api_requests/deprecated/deepseek.py +118 -0
- lm_deluge/api_requests/deprecated/mistral.py +120 -0
- lm_deluge/api_requests/google.py +0 -0
- lm_deluge/api_requests/openai.py +145 -0
- lm_deluge/api_requests/vertex.py +365 -0
- lm_deluge/cache.py +144 -0
- lm_deluge/client.py +760 -0
- lm_deluge/embed.py +392 -0
- lm_deluge/errors.py +8 -0
- lm_deluge/gemini_limits.py +65 -0
- lm_deluge/image.py +200 -0
- lm_deluge/llm_tools/__init__.py +11 -0
- lm_deluge/llm_tools/extract.py +111 -0
- lm_deluge/llm_tools/score.py +71 -0
- lm_deluge/llm_tools/translate.py +44 -0
- lm_deluge/models.py +957 -0
- lm_deluge/prompt.py +355 -0
- lm_deluge/rerank.py +338 -0
- lm_deluge/sampling_params.py +25 -0
- lm_deluge/tool.py +106 -0
- lm_deluge/tracker.py +12 -0
- lm_deluge/util/json.py +167 -0
- lm_deluge/util/logprobs.py +446 -0
- lm_deluge/util/pdf.py +45 -0
- lm_deluge/util/validation.py +46 -0
- lm_deluge/util/xml.py +291 -0
- lm_deluge-0.0.3.dist-info/METADATA +127 -0
- lm_deluge-0.0.3.dist-info/RECORD +37 -0
- lm_deluge-0.0.3.dist-info/WHEEL +5 -0
- lm_deluge-0.0.3.dist-info/top_level.txt +1 -0
lm_deluge/embed.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
### utility for batched embedding calls to the OpenAI and Cohere APIs
|
|
2
|
+
import os
|
|
3
|
+
import numpy as np
|
|
4
|
+
import aiohttp
|
|
5
|
+
from tqdm.auto import tqdm
|
|
6
|
+
import asyncio
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any, Optional
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from .tracker import StatusTracker
|
|
11
|
+
|
|
12
|
+
# Supported embedding models, keyed by model name.
# "provider" selects the API endpoint and payload format;
# "cost" is the list price per million input tokens (USD).
registry = {
    "text-embedding-3-small": {
        "name": "text-embedding-3-small",
        "provider": "openai",
        "cost": 0.02,  # per million tokens
    },
    "text-embedding-3-large": {
        "name": "text-embedding-3-large",
        "provider": "openai",
        "cost": 0.13,
    },
    "text-embedding-ada-002": {
        "name": "text-embedding-ada-002",
        "provider": "openai",
        # OpenAI lists ada-002 at $0.10 per 1M tokens (the value 1 was 10x too high).
        "cost": 0.10,
    },
    "embed-english-v3.0": {
        "name": "embed-english-v3.0",
        "provider": "cohere",
        "cost": 0.1,
    },
    "embed-english-light-v3.0": {
        "name": "embed-english-light-v3.0",
        "provider": "cohere",
        "cost": 0.1,
    },
    "embed-multilingual-v3.0": {
        "name": "embed-multilingual-v3.0",
        "provider": "cohere",
        "cost": 0.1,
    },
    "embed-multilingual-light-v3.0": {
        "name": "embed-multilingual-light-v3.0",
        "provider": "cohere",
        "cost": 0.1,
    },
}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class EmbeddingRequest:
    """One batch of texts to embed, plus the retry state for sending it.

    ``embed_parallel_async`` creates these and runs ``call_api`` as background
    tasks. Each attempt appends an ``EmbeddingResponse`` to ``self.result`` (the
    last entry is the current outcome). A failed attempt is pushed back onto
    ``retry_queue`` while ``attempts_left`` is positive; otherwise the task is
    counted as failed on ``status_tracker``.
    """

    def __init__(
        self,
        task_id: int,
        model_name: str,
        texts: list[str],
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        request_timeout: int,
        pbar: Optional[tqdm] = None,
        **kwargs,  # openai or cohere specific params
    ):
        self.task_id = task_id
        self.model_name = model_name
        self.texts = texts
        # Remaining attempts. The dispatcher is expected to decrement this once
        # per attempt it sends (embed_parallel_async does). handle_error only
        # reads it, so attempts are not double-counted.
        self.attempts_left = attempts_left
        self.status_tracker = status_tracker
        self.retry_queue = retry_queue
        self.request_timeout = request_timeout  # seconds, total per attempt
        self.pbar = pbar
        self.result = []  # one EmbeddingResponse per attempt
        self.kwargs = kwargs  # merged into the JSON payload, overriding defaults

    def increment_pbar(self):
        """Advance the shared progress bar by one finished batch, if there is one."""
        if self.pbar is not None:
            self.pbar.update(1)

    def handle_success(self):
        """Count this task as finished successfully."""
        self.increment_pbar()
        self.status_tracker.num_tasks_in_progress -= 1
        self.status_tracker.num_tasks_succeeded += 1

    def handle_error(self):
        """Report the latest failed attempt, then re-queue the task or mark it failed.

        On HTTP 429 this also records the rate-limit error on the tracker, whose
        ``time_of_last_rate_limit_error`` drives the dispatcher's cooldown pause.
        """
        last_result: EmbeddingResponse = self.result[-1]
        if last_result.status_code == 429:
            self.status_tracker.num_rate_limit_errors += 1
            self.status_tracker.time_of_last_rate_limit_error = time.time()
        print(
            f"Error on task {self.task_id}, Code: {last_result.status_code}, "
            f"Message: {last_result.error_message}."
        )
        if self.attempts_left > 0:
            # Still in progress: the dispatcher picks it up from the queue and
            # accounts for the next attempt itself.
            self.retry_queue.put_nowait(self)
            return
        print(f"Task {self.task_id} out of tries.")
        self.status_tracker.num_tasks_in_progress -= 1
        self.status_tracker.num_tasks_failed += 1

    def _error_response(self, status_code: Optional[int], message: str):
        """Build a failed ``EmbeddingResponse`` for this task (no texts/embeddings)."""
        return EmbeddingResponse(
            id=self.task_id,
            status_code=status_code,
            is_error=True,
            error_message=message,
            texts=[],
            embeddings=[],
        )

    async def handle_response(self, response: aiohttp.ClientResponse):
        """Convert an HTTP response into an ``EmbeddingResponse``; never raises.

        Non-200 responses and any parsing problem become error responses that
        carry the response's status code.
        """
        try:
            if response.status != 200:
                return self._error_response(response.status, await response.text())
            body = await response.json()
            # TODO: add cost calculation
            provider = registry.get(self.model_name, {}).get("provider")
            if provider == "openai":
                embeddings = [item["embedding"] for item in body["data"]]
            elif provider == "cohere":
                embeddings = body["embeddings"]
            else:
                raise ValueError(f"Unsupported model {self.model_name}")
            return EmbeddingResponse(
                id=self.task_id,
                status_code=response.status,
                is_error=False,
                error_message=None,
                texts=self.texts,
                embeddings=embeddings,
            )
        except Exception as e:
            return self._error_response(response.status, str(e))

    def _build_request(self) -> tuple[str, dict[str, str], dict[str, Any]]:
        """Return ``(url, headers, payload)`` for this batch.

        Raises:
            ValueError: if the batch has more than 96 texts or the model is not
                in ``registry``.
        """
        if len(self.texts) > 96:
            raise ValueError("Embeddings only support up to 96 texts per request.")
        model_obj = registry.get(self.model_name)
        if model_obj is None:
            raise ValueError(f"Unsupported model {self.model_name}")
        provider = model_obj["provider"]
        if provider == "openai":
            url = "https://api.openai.com/v1/embeddings"
            headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
            payload: dict[str, Any] = {
                "model": self.model_name,
                "input": self.texts,
                "encoding_format": "float",
                **self.kwargs,
            }
        elif provider == "cohere":
            url = "https://api.cohere.com/v1/embed"
            headers = {"Authorization": f"bearer {os.environ.get('COHERE_API_KEY')}"}
            payload = {
                "model": self.model_name,
                "texts": self.texts,
                "input_type": "search_document",  # kwargs may override
                **self.kwargs,
            }
        else:
            raise ValueError(f"Unsupported model {self.model_name}")
        return url, headers, payload

    async def call_api(
        self,
        session: aiohttp.ClientSession,
    ):
        """Send one attempt for this batch and record its outcome.

        Every path ends in ``handle_success`` or ``handle_error``, so the task is
        always either re-queued or counted as finished. Raising out of this
        coroutine instead would end the background task without decrementing
        ``num_tasks_in_progress``, and a dispatcher waiting for that counter to
        reach zero would never stop.
        """
        try:
            url, headers, payload = self._build_request()
        except ValueError as e:
            # Deterministic input errors (oversized batch, unknown model) would
            # fail identically on retry, so fail the task immediately.
            self.attempts_left = 0
            self.result.append(self._error_response(None, str(e)))
            self.handle_error()
            return

        try:
            self.status_tracker.total_requests += 1
            timeout = aiohttp.ClientTimeout(total=self.request_timeout)
            async with session.post(
                url, json=payload, headers=headers, timeout=timeout
            ) as response:
                response_obj = await self.handle_response(response)
        except asyncio.TimeoutError:
            response_obj = self._error_response(None, "Timeout")
        except Exception as e:
            response_obj = self._error_response(
                None, f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}"
            )

        self.result.append(response_obj)
        if response_obj.is_error:
            self.handle_error()
        else:
            self.handle_success()
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@dataclass
|
|
216
|
+
class EmbeddingResponse:
|
|
217
|
+
id: int
|
|
218
|
+
status_code: int | None
|
|
219
|
+
is_error: bool
|
|
220
|
+
error_message: Optional[str]
|
|
221
|
+
texts: list[str]
|
|
222
|
+
embeddings: list[list[float]]
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
async def embed_parallel_async(
    texts: list[str],
    model: str = "rerank-english-v3.0",
    max_attempts: int = 5,
    max_requests_per_minute: int = 4_000,
    max_concurrent_requests: int = 500,
    request_timeout: int = 10,
    batch_size: int = 16,
    show_progress: bool = True,
    **kwargs,
):
    """Processes embed requests in parallel, throttling to stay under rate limits.

    Args:
        texts: strings to embed; split into consecutive batches of ``batch_size``.
        model: a key of ``registry``. NOTE: the default, "rerank-english-v3.0",
            is not in ``registry`` and is rejected, so pass a model explicitly.
        max_attempts: attempt budget per batch.
        max_requests_per_minute: request-rate budget.
        max_concurrent_requests: cap on batches in progress at once.
        request_timeout: per-request timeout (seconds), passed to each request.
        batch_size: texts per request; at most 96.
        show_progress: show a tqdm progress bar over batches.
        **kwargs: provider-specific params forwarded to every request payload.

    Returns:
        One ``EmbeddingResponse`` per batch, sorted by batch id (i.e. input
        order). Failed batches are included with ``is_error=True``.

    Raises:
        ValueError: if ``batch_size`` exceeds 96 or ``model`` is not in ``registry``.
    """
    if batch_size > 96:
        raise ValueError("Embeddings only support up to 96 texts per request.")
    if model not in registry:
        # Fail fast here rather than inside every background task.
        raise ValueError(f"Unsupported model {model}")
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    pbar = tqdm(total=len(batches), desc="Embedding") if show_progress else None
    ids = range(len(batches))
    # constants
    seconds_to_pause_after_rate_limit_error = 5
    seconds_to_sleep_each_loop = 0.003  # so concurrent tasks can run

    # initialize trackers
    retry_queue = asyncio.Queue()
    status_tracker = StatusTracker()
    next_request = None  # request waiting for capacity, if any

    # Request budget: refilled continuously at max_requests_per_minute / 60 per
    # second and capped at one minute's worth of requests.
    available_request_capacity = max_requests_per_minute
    last_update_time = time.time()
    last_pbar_update_time = time.time()

    prompts_not_finished = True
    prompts_iter = iter(zip(ids, batches))
    results: list = []  # every EmbeddingRequest created, in creation order
    # The event loop keeps only weak references to tasks, so hold strong ones
    # until each finishes; otherwise a task can be garbage-collected mid-flight.
    in_flight: set = set()

    try:
        # `async with` closes the session even if the loop raises or is cancelled.
        async with aiohttp.ClientSession() as session:
            while True:
                # get next request (if one is not already waiting for capacity)
                if next_request is None:
                    if not retry_queue.empty():
                        next_request = retry_queue.get_nowait()
                        print(f"Retrying request {next_request.task_id}.")
                    elif prompts_not_finished:
                        try:
                            batch_id, batch = next(prompts_iter)
                        except StopIteration:
                            prompts_not_finished = False
                            print("API requests finished, only retries remain.")
                        else:
                            next_request = EmbeddingRequest(
                                task_id=batch_id,
                                model_name=model,
                                texts=batch,
                                attempts_left=max_attempts,
                                status_tracker=status_tracker,
                                retry_queue=retry_queue,
                                request_timeout=request_timeout,
                                pbar=pbar,
                                **kwargs,
                            )
                            status_tracker.num_tasks_started += 1
                            status_tracker.num_tasks_in_progress += 1
                            results.append(next_request)

                # update available capacity
                current_time = time.time()
                seconds_since_update = current_time - last_update_time
                available_request_capacity = min(
                    available_request_capacity
                    + max_requests_per_minute * seconds_since_update / 60.0,
                    max_requests_per_minute,
                )
                last_update_time = current_time

                # update pbar status at most once per second
                if pbar and current_time - last_pbar_update_time > 1:
                    last_pbar_update_time = current_time
                    pbar.set_postfix(
                        {
                            "Req. Capacity": f"{available_request_capacity:.1f}",
                            "Reqs in Progress": status_tracker.num_tasks_in_progress,
                        }
                    )

                # if enough capacity available, call API
                if (
                    next_request is not None
                    and available_request_capacity >= 1
                    and status_tracker.num_tasks_in_progress < max_concurrent_requests
                ):
                    available_request_capacity -= 1
                    next_request.attempts_left -= 1  # this dispatch uses one attempt
                    task = asyncio.create_task(next_request.call_api(session=session))
                    in_flight.add(task)
                    task.add_done_callback(in_flight.discard)
                    next_request = None

                # if all tasks are finished, break
                if status_tracker.num_tasks_in_progress == 0:
                    break

                # main loop sleeps briefly so concurrent tasks can run
                await asyncio.sleep(seconds_to_sleep_each_loop)

                # if a rate limit error was hit recently, pause to cool down
                seconds_since_rate_limit_error = (
                    time.time() - status_tracker.time_of_last_rate_limit_error
                )
                if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
                    remaining_seconds_to_pause = (
                        seconds_to_pause_after_rate_limit_error
                        - seconds_since_rate_limit_error
                    )
                    print(
                        f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
                    )
                    await asyncio.sleep(remaining_seconds_to_pause)
    finally:
        if pbar is not None:
            pbar.close()

    # after finishing, log final status
    if status_tracker.num_tasks_failed > 0:
        print(
            f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
        )
    if status_tracker.num_rate_limit_errors > 0:
        print(
            f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
        )

    print(
        f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
    )

    # deduplicate results by id, keeping each request's latest response
    deduplicated = {}
    for request in results:
        latest = request.result[-1]
        if request.task_id not in deduplicated:
            deduplicated[request.task_id] = latest
        elif latest.embeddings and not deduplicated[request.task_id].embeddings:
            # only replace if the kept response has no embeddings and this one does
            deduplicated[request.task_id] = latest

    output = sorted(deduplicated.values(), key=lambda x: x.id)
    print(f"Returning {len(output)} unique results.")
    return output
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def stack_results(
|
|
383
|
+
results: list[EmbeddingResponse], return_numpy: bool = True
|
|
384
|
+
) -> list[list[float]] | np.ndarray:
|
|
385
|
+
if not all(response.status_code == 200 for response in results):
|
|
386
|
+
raise ValueError("Some responses were not successful; cannot coalesce results.")
|
|
387
|
+
stacked = np.concatenate([response.embeddings for response in results], axis=0)
|
|
388
|
+
return stacked.tolist() if not return_numpy else stacked # type: ignore
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def submit_batch_request():
    """Placeholder for a batch-submission API; not implemented and does nothing."""
    return None
|
lm_deluge/gemini_limits.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# Per-region quota for Gemini Flash, keyed by cloud region name (presumably
# Vertex AI regions; the unit is presumably requests per minute -- confirm
# against the caller). Commented-out regions are intentionally excluded.
gemini_flash_limits = {
    "asia-east1": 2_000,
    "asia-east2": 200,
    "asia-northeast1": 200,
    "asia-northeast3": 200,
    "asia-south1": 200,
    "asia-southeast1": 3_000,
    "australia-southeast1": 200,
    "europe-central2": 200,
    "europe-north1": 200,
    "europe-southwest1": 200,
    "europe-west1": 10_000,
    "europe-west2": 200,
    "europe-west3": 200,
    "europe-west4": 200,
    "europe-west6": 200,
    "europe-west8": 200,
    "europe-west9": 200,
    # "me-central1": 200,
    "me-central2": 200,
    "me-west1": 200,
    "northamerica-northeast1": 200,
    "southamerica-east1": 200,
    "us-central1": 5_000,
    "us-east1": 3_000,
    "us-east4": 200,
    # "us-east5": 200,
    "us-south1": 3_000,
    "us-west1": 5_000,
    "us-west4": 200,
}
|
|
32
|
+
|
|
33
|
+
# Per-region quota for Gemini 1.5 Pro, keyed by cloud region name (unit
# presumably requests per minute -- confirm against the caller).
# Total of the active entries: 6_960 (7_520 if the commented-out
# asia-northeast2 (500) and us-east5 (60) are re-enabled).
gemini_1_5_pro_limits = {
    "asia-east1": 500,
    "asia-east2": 500,
    "asia-northeast1": 500,
    # "asia-northeast2": 500,
    "asia-northeast3": 500,
    "asia-south1": 500,
    "asia-southeast1": 500,
    "australia-southeast1": 60,
    "europe-central2": 500,
    "europe-north1": 60,
    "europe-southwest1": 60,
    "europe-west1": 500,
    "europe-west2": 60,
    "europe-west3": 60,
    "europe-west4": 60,
    "europe-west6": 60,
    "europe-west8": 60,
    "europe-west9": 60,
    "me-central1": 60,
    "me-central2": 60,
    "me-west1": 60,
    "northamerica-northeast1": 60,
    "southamerica-east1": 500,
    "us-central1": 500,
    "us-east1": 500,
    "us-east4": 60,
    # "us-east5": 60,
    "us-south1": 60,
    "us-west1": 500,
    "us-west4": 60,
}
|
lm_deluge/image.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
import io
|
|
5
|
+
import requests
|
|
6
|
+
from PIL import Image as PILImage # type: ignore
|
|
7
|
+
import base64
|
|
8
|
+
import mimetypes
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Literal
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(slots=True)
|
|
15
|
+
class Image:
|
|
16
|
+
# raw bytes, pathlike, http url, or base64 data url
|
|
17
|
+
data: bytes | io.BytesIO | Path | str
|
|
18
|
+
media_type: str | None = None # inferred if None
|
|
19
|
+
detail: Literal["low", "high", "auto"] = "auto"
|
|
20
|
+
type: str = field(init=False, default="image")
|
|
21
|
+
|
|
22
|
+
@classmethod
|
|
23
|
+
def from_pdf(
|
|
24
|
+
cls,
|
|
25
|
+
pdf_path: str,
|
|
26
|
+
dpi: int = 200,
|
|
27
|
+
target_size: int = 1024,
|
|
28
|
+
first_page: int | None = None,
|
|
29
|
+
last_page: int | None = None,
|
|
30
|
+
) -> list["Image"]:
|
|
31
|
+
try:
|
|
32
|
+
from pdf2image import convert_from_path # type: ignore
|
|
33
|
+
except ImportError:
|
|
34
|
+
raise RuntimeError("pdf2image is required for PDF conversion.")
|
|
35
|
+
|
|
36
|
+
# Convert the first page of the PDF to an image
|
|
37
|
+
pages = convert_from_path(
|
|
38
|
+
pdf_path,
|
|
39
|
+
dpi=dpi,
|
|
40
|
+
first_page=first_page or 1,
|
|
41
|
+
last_page=last_page, # type: ignore
|
|
42
|
+
)
|
|
43
|
+
images = []
|
|
44
|
+
for page in pages:
|
|
45
|
+
buffer = io.BytesIO()
|
|
46
|
+
page.save(buffer, format="JPEG")
|
|
47
|
+
image = cls(buffer.getvalue(), media_type="image/jpeg")
|
|
48
|
+
image.resize(target_size)
|
|
49
|
+
images.append(image)
|
|
50
|
+
return images
|
|
51
|
+
|
|
52
|
+
# helpers -----------------------------------------------------------------
|
|
53
|
+
def _bytes(self) -> bytes:
|
|
54
|
+
if isinstance(self.data, bytes):
|
|
55
|
+
return self.data
|
|
56
|
+
elif isinstance(self.data, io.BytesIO):
|
|
57
|
+
return self.data.getvalue()
|
|
58
|
+
elif isinstance(self.data, str) and self.data.startswith("http"):
|
|
59
|
+
res = requests.get(self.data)
|
|
60
|
+
res.raise_for_status()
|
|
61
|
+
return res.content
|
|
62
|
+
elif isinstance(self.data, str) and os.path.exists(self.data):
|
|
63
|
+
with open(self.data, "rb") as f:
|
|
64
|
+
return f.read()
|
|
65
|
+
elif isinstance(self.data, Path) and self.data.exists():
|
|
66
|
+
return Path(self.data).read_bytes()
|
|
67
|
+
elif isinstance(self.data, str) and self.data.startswith("data:"):
|
|
68
|
+
header, encoded = self.data.split(",", 1)
|
|
69
|
+
return base64.b64decode(encoded)
|
|
70
|
+
else:
|
|
71
|
+
raise ValueError("unreadable image format")
|
|
72
|
+
|
|
73
|
+
def _mime(self) -> str:
|
|
74
|
+
if self.media_type:
|
|
75
|
+
return self.media_type
|
|
76
|
+
if isinstance(self.data, (Path, str)):
|
|
77
|
+
guess = mimetypes.guess_type(str(self.data))[0]
|
|
78
|
+
if guess:
|
|
79
|
+
return guess
|
|
80
|
+
return "image/png"
|
|
81
|
+
|
|
82
|
+
def _base64(self, include_header: bool = True) -> str:
|
|
83
|
+
encoded = base64.b64encode(self._bytes()).decode("utf-8")
|
|
84
|
+
if not include_header:
|
|
85
|
+
return encoded
|
|
86
|
+
return f"data:{self._mime()};base64,{encoded}"
|
|
87
|
+
|
|
88
|
+
@contextmanager
|
|
89
|
+
def _image(self):
|
|
90
|
+
img = None
|
|
91
|
+
try:
|
|
92
|
+
img = PILImage.open(io.BytesIO(self._bytes()))
|
|
93
|
+
yield img
|
|
94
|
+
finally:
|
|
95
|
+
if img:
|
|
96
|
+
img.close()
|
|
97
|
+
|
|
98
|
+
@cached_property
|
|
99
|
+
def size(self) -> tuple[int, int]:
|
|
100
|
+
with self._image() as img:
|
|
101
|
+
return img.size
|
|
102
|
+
|
|
103
|
+
@cached_property
|
|
104
|
+
def num_pixels(self) -> int:
|
|
105
|
+
return self.size[0] * self.size[1]
|
|
106
|
+
|
|
107
|
+
def _resize(self, size: tuple[int, int]) -> bytes:
|
|
108
|
+
buffer = io.BytesIO()
|
|
109
|
+
new_width, new_height = size
|
|
110
|
+
with self._image() as img:
|
|
111
|
+
# Resize with Lanczos antialiasing
|
|
112
|
+
img.resize((new_width, new_height), PILImage.Resampling.LANCZOS).save(
|
|
113
|
+
buffer, format=self._mime().split("/")[-1].upper()
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
return buffer.getvalue()
|
|
117
|
+
|
|
118
|
+
def _resize_longer(
|
|
119
|
+
self, *, size: int | None = None, max_size: int | None = None
|
|
120
|
+
) -> bytes:
|
|
121
|
+
if not max_size and not size:
|
|
122
|
+
raise ValueError("Either size or max_size must be provided")
|
|
123
|
+
width, height = self.size
|
|
124
|
+
if width > height:
|
|
125
|
+
new_width = size if size is not None else min(max_size, width) # type: ignore
|
|
126
|
+
new_height = int(new_width / width * height)
|
|
127
|
+
else:
|
|
128
|
+
new_height = size if size is not None else min(max_size, height) # type: ignore
|
|
129
|
+
new_width = int(new_height / height * width)
|
|
130
|
+
return self._resize((new_width, new_height))
|
|
131
|
+
|
|
132
|
+
def _resize_shorter(
|
|
133
|
+
self, *, size: int | None = None, max_size: int | None = None
|
|
134
|
+
) -> bytes:
|
|
135
|
+
if not max_size and not size:
|
|
136
|
+
raise ValueError("Either size or max_size must be provided")
|
|
137
|
+
width, height = self.size
|
|
138
|
+
if width <= height:
|
|
139
|
+
new_width = size if size is not None else min(max_size, width) # type: ignore
|
|
140
|
+
new_height = int(new_width / width * height)
|
|
141
|
+
else:
|
|
142
|
+
new_height = size if size is not None else min(max_size, height) # type: ignore
|
|
143
|
+
new_width = int(new_height / height * width)
|
|
144
|
+
return self._resize((new_width, new_height))
|
|
145
|
+
|
|
146
|
+
@cached_property
|
|
147
|
+
def fingerprint(self) -> str:
|
|
148
|
+
# return base64 of a very small version of the image
|
|
149
|
+
small_image = self._resize_longer(max_size=48) # longer side = 48px
|
|
150
|
+
return base64.b64encode(small_image).decode("utf-8")
|
|
151
|
+
|
|
152
|
+
def resize(self, max_size: int) -> None:
|
|
153
|
+
"""
|
|
154
|
+
Resize the image and save to the data value.
|
|
155
|
+
"""
|
|
156
|
+
self.data = self._resize_longer(max_size=max_size)
|
|
157
|
+
|
|
158
|
+
# ── provider-specific emission ────────────────────────────────────────────
|
|
159
|
+
def oa_chat(self) -> dict:
|
|
160
|
+
# if max(self.size) > 1_568:
|
|
161
|
+
# self.resize_longer_side(1_568)
|
|
162
|
+
return {
|
|
163
|
+
"type": "image_url",
|
|
164
|
+
"image_url": {
|
|
165
|
+
"url": self._base64(),
|
|
166
|
+
"detail": self.detail,
|
|
167
|
+
},
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
def oa_resp(self) -> dict:
|
|
171
|
+
# if max(self.size) > 1_568:
|
|
172
|
+
# self.resize_longer_side(1_568)
|
|
173
|
+
return {"type": "input_image", "image_url": self._base64()}
|
|
174
|
+
|
|
175
|
+
def anthropic(self) -> dict:
|
|
176
|
+
# n_pixels = self.num_pixels
|
|
177
|
+
# if n_pixels > 1_200_000:
|
|
178
|
+
# resize_factor = (1_200_000 / n_pixels) ** 0.5
|
|
179
|
+
# new_size = (
|
|
180
|
+
# int(self.size[0] * resize_factor),
|
|
181
|
+
# int(self.size[1] * resize_factor),
|
|
182
|
+
# )
|
|
183
|
+
# self.resize(new_size)
|
|
184
|
+
b64 = base64.b64encode(self._bytes()).decode()
|
|
185
|
+
return {
|
|
186
|
+
"type": "image",
|
|
187
|
+
"source": {
|
|
188
|
+
"type": "base64",
|
|
189
|
+
"media_type": self._mime(),
|
|
190
|
+
"data": b64,
|
|
191
|
+
},
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
def gemini(self) -> dict:
|
|
195
|
+
return {
|
|
196
|
+
"inlineData": {
|
|
197
|
+
"mimeType": self._mime(),
|
|
198
|
+
"data": self._base64(include_header=False),
|
|
199
|
+
}
|
|
200
|
+
}
|