lm_deluge-0.0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lm-deluge might be problematic.
- lm_deluge/__init__.py +6 -0
- lm_deluge/api_requests/__init__.py +3 -0
- lm_deluge/api_requests/anthropic.py +177 -0
- lm_deluge/api_requests/base.py +375 -0
- lm_deluge/api_requests/cohere.py +138 -0
- lm_deluge/api_requests/common.py +18 -0
- lm_deluge/api_requests/deprecated/bedrock.py +288 -0
- lm_deluge/api_requests/deprecated/deepseek.py +118 -0
- lm_deluge/api_requests/deprecated/mistral.py +120 -0
- lm_deluge/api_requests/google.py +0 -0
- lm_deluge/api_requests/openai.py +145 -0
- lm_deluge/api_requests/vertex.py +365 -0
- lm_deluge/cache.py +144 -0
- lm_deluge/client.py +760 -0
- lm_deluge/embed.py +392 -0
- lm_deluge/errors.py +8 -0
- lm_deluge/gemini_limits.py +65 -0
- lm_deluge/image.py +200 -0
- lm_deluge/llm_tools/__init__.py +11 -0
- lm_deluge/llm_tools/extract.py +111 -0
- lm_deluge/llm_tools/score.py +71 -0
- lm_deluge/llm_tools/translate.py +44 -0
- lm_deluge/models.py +957 -0
- lm_deluge/prompt.py +355 -0
- lm_deluge/rerank.py +338 -0
- lm_deluge/sampling_params.py +25 -0
- lm_deluge/tool.py +106 -0
- lm_deluge/tracker.py +12 -0
- lm_deluge/util/json.py +167 -0
- lm_deluge/util/logprobs.py +446 -0
- lm_deluge/util/pdf.py +45 -0
- lm_deluge/util/validation.py +46 -0
- lm_deluge/util/xml.py +291 -0
- lm_deluge-0.0.3.dist-info/METADATA +127 -0
- lm_deluge-0.0.3.dist-info/RECORD +37 -0
- lm_deluge-0.0.3.dist-info/WHEEL +5 -0
- lm_deluge-0.0.3.dist-info/top_level.txt +1 -0
lm_deluge/prompt.py
ADDED
@@ -0,0 +1,355 @@
import io
import json
import tiktoken
import xxhash
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Sequence
from lm_deluge.models import APIModel
from lm_deluge.image import Image

###############################################################################
# 1. Low-level content blocks – either text or an image                      #
###############################################################################
Role = Literal["system", "user", "assistant"]


@dataclass(slots=True)
class Text:
    text: str
    type: str = field(init=False, default="text")

    @property
    def fingerprint(self) -> str:
        return xxhash.xxh64(self.text.encode()).hexdigest()

    # ── provider-specific emission ────────────────────────────────────────────
    def oa_chat(self) -> dict | str:  # OpenAI Chat Completions
        return {"type": "text", "text": self.text}

    def oa_resp(self) -> dict:  # OpenAI *Responses* (new)
        return {"type": "input_text", "text": self.text}

    def anthropic(self) -> dict:  # Anthropic Messages
        return {"type": "text", "text": self.text}

    def gemini(self) -> dict:
        return {"text": self.text}


###############################################################################
# 2. One conversational turn (role + parts)                                  #
###############################################################################
@dataclass(slots=True)
class Message:
    role: Role
    parts: list[Text | Image]

    @property
    def fingerprint(self) -> str:
        return self.role + "," + ",".join(part.fingerprint for part in self.parts)

    def add_text(self, content: str) -> "Message":
        """Append a text block and return self for chaining."""
        self.parts.append(Text(content))
        return self

    def add_image(
        self,
        data: bytes | str | Path | io.BytesIO,
        *,
        media_type: str | None = None,
        detail: Literal["low", "high", "auto"] = "auto",
        max_size: int | None = None,
    ) -> "Message":
        """
        Append an image block and return self for chaining.

        If max_size is provided, the image will be resized so that its longer
        dimension equals max_size, but only if the longer dimension is currently
        larger than max_size.
        """
        img = Image(data, media_type=media_type, detail=detail)

        # Resize if max_size is provided
        if max_size is not None:
            img.resize(max_size)

        self.parts.append(img)
        return self

    # -------- convenient constructors --------
    @classmethod
    def user(
        cls,
        text: str | None = None,
        *,
        image: str | bytes | Path | io.BytesIO | None = None,
    ) -> "Message":
        res = cls("user", [])
        if text is not None:
            res.add_text(text)
        if image is not None:
            res.add_image(image)
        return res

    @classmethod
    def system(cls, text: str | None = None) -> "Message":
        res = cls("system", [])
        if text is not None:
            res.add_text(text)
        return res

    @classmethod
    def ai(cls, text: str | None = None) -> "Message":
        res = cls("assistant", [])
        if text is not None:
            res.add_text(text)
        return res

    # ──── provider-specific constructors ───
    @classmethod
    def from_oa(cls, msg: dict):
        role = (
            "system"
            if msg["role"] in ["developer", "system"]
            else ("user" if msg["role"] == "user" else "assistant")
        )
        parts: list[Text | Image] = []
        content = msg["content"]
        if isinstance(content, str):
            parts.append(Text(content))
        else:
            for item in content:
                if item["type"] == "text":
                    parts.append(Text(item["text"]))
                elif item["type"] == "image_url":
                    parts.append(Image(data=item["image_url"]["url"]))
        return cls(role, parts)
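
        # For illustration (not part of the published file): from_oa maps an
        # OpenAI-style dict onto this class, e.g.
        #   {"role": "developer", "content": "be brief"}
        # becomes a system-role Message with a single Text part.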

    @classmethod
    def from_oa_resp(cls, msg: dict):
        raise NotImplementedError("not implemented")

    @classmethod
    def from_anthropic(cls, msg: dict):
        pass

    # ───── provider-specific emission ─────
    def oa_chat(self) -> dict:
        content = []
        for p in self.parts:
            content.append(p.oa_chat())
        return {"role": self.role, "content": content}

    def oa_resp(self) -> dict:
        content = [p.oa_resp() for p in self.parts]
        return {"role": self.role, "content": content}

    def anthropic(self) -> dict:
        # Anthropic: system message is *not* in the list
        if self.role == "system":
            raise ValueError("Anthropic keeps system outside message list")
        content = [p.anthropic() for p in self.parts]
        # Shortcut: single text becomes a bare string
        if len(content) == 1 and content[0]["type"] == "text":
            content = content[0]["text"]
        return {"role": self.role, "content": content}

    def gemini(self) -> dict:
        parts = [p.gemini() for p in self.parts]
        # Gemini calls the assistant role "model"
        role = "user" if self.role == "user" else "model"
        return {"role": role, "parts": parts}


###############################################################################
# 3. A whole conversation (ordered list of messages)                         #
###############################################################################


@dataclass(slots=True)
class Conversation:
    messages: list[Message] = field(default_factory=list)

    # ── convenience shorthands ------------------------------------------------
    @classmethod
    def system(cls, text: str) -> "Conversation":
        return cls([Message.system(text)])

    @classmethod
    def user(
        cls, text: str, *, image: bytes | str | Path | None = None
    ) -> "Conversation":
        msg = (
            Message.user(text) if image is None else Message.user(text).add_image(image)
        )
        return cls([msg])

    @classmethod
    def from_openai(cls, messages: list[dict]):
        """Compatibility with openai-formatted messages"""
        pass

    @classmethod
    def from_anthropic(cls, messages: list[dict], system: str | None = None):
        """Compatibility with anthropic-formatted messages"""
        pass

    # fluent additions
    def add(self, msg: Message) -> "Conversation":
        self.messages.append(msg)
        return self

    # ── conversions -----------------------------------------------------------
    def to_openai(self) -> list[dict]:
        return [m.oa_chat() for m in self.messages]

    def to_openai_responses(self) -> dict:
        # OpenAI Responses = single “input” array, role must be user/assistant
        return {"input": [m.oa_resp() for m in self.messages if m.role != "system"]}

    def to_anthropic(self) -> tuple[str | None, list[dict]]:
        system_msg = next(
            (
                m.parts[0].text
                for m in self.messages
                if m.role == "system" and isinstance(m.parts[0], Text)
            ),
            None,
        )
        other = [m.anthropic() for m in self.messages if m.role != "system"]
        return system_msg, other

    def to_gemini(self) -> tuple[str | None, list[dict]]:
        system_msg = next(
            (
                m.parts[0].text
                for m in self.messages
                if m.role == "system" and isinstance(m.parts[0], Text)
            ),
            None,
        )
        other = [m.gemini() for m in self.messages if m.role != "system"]
        return system_msg, other

    def to_cohere(self) -> list[dict]:
        messages = []
        for m in self.messages:
            if len(m.parts) > 1:
                raise ValueError("Cohere does not support multi-part messages")
            if isinstance(m.parts[0], Image):
                raise ValueError("Cohere does not support images")
            messages.append({"role": m.role, "text": m.parts[0].text})
        return messages

    # ── misc helpers ----------------------------------------------------------
    _tok = tiktoken.encoding_for_model("gpt-4")

    def count_tokens(self, max_new_tokens: int = 0, img_tokens: int = 85) -> int:
        n = max_new_tokens
        for m in self.messages:
            for p in m.parts:
                if isinstance(p, Text):
                    n += len(self._tok.encode(p.text))
                else:  # Image – crude flat cost per image
                    n += img_tokens

        # very rough BOS/EOS padding
        return n + 6 * len(self.messages)

    def dry_run(self, model_name: str, max_new_tokens: int):
        model_obj = APIModel.from_registry(model_name)
        if model_obj.api_spec == "openai":
            image_tokens = 85
        elif model_obj.api_spec == "anthropic":
            image_tokens = 1_200
        else:
            image_tokens = 0
        input_tokens = self.count_tokens(0, image_tokens)
        output_tokens = max_new_tokens

        min_cost, max_cost = None, None
        if model_obj.input_cost and model_obj.output_cost:
            min_cost = model_obj.input_cost * input_tokens / 1e6
            max_cost = min_cost + model_obj.output_cost * output_tokens / 1e6

        return input_tokens, output_tokens, min_cost, max_cost

    @property
    def fingerprint(self) -> str:
        hasher = xxhash.xxh64()
        hasher.update(json.dumps([m.fingerprint for m in self.messages]).encode())
        return hasher.hexdigest()

    def to_log(self) -> dict:
        """
        Return a JSON-serialisable dict that fully captures the conversation.
        """
        serialized: list[dict] = []

        for msg in self.messages:
            content_blocks: list[dict] = []
            for p in msg.parts:
                if isinstance(p, Text):
                    content_blocks.append({"type": "text", "text": p.text})
                else:  # Image – redact the bytes, keep a hint
                    w, h = p.size
                    content_blocks.append(
                        {"type": "image", "tag": f"<Image ({w}×{h})>"}
                    )
            serialized.append({"role": msg.role, "content": content_blocks})

        return {"messages": serialized}

    @classmethod
    def from_log(cls, payload: dict) -> "Conversation":
        """Re-hydrate a Conversation previously produced by `to_log()`."""
        msgs: list[Message] = []

        for m in payload.get("messages", []):
            role: Role = m["role"]  # 'system' | 'user' | 'assistant'
            parts: list[Text | Image] = []

            for p in m["content"]:
                if p["type"] == "text":
                    parts.append(Text(p["text"]))
                elif p["type"] == "image":
                    # We only stored a placeholder tag, so keep that placeholder.
                    # You could raise instead if real image bytes are required.
                    parts.append(Image(p["tag"], detail="low"))
                else:
                    raise ValueError(f"Unknown part type {p['type']!r}")

            msgs.append(Message(role, parts))

        return cls(msgs)


###############################################################################
# --------------------------------------------------------------------------- #
# Basic usage examples                                                        #
# --------------------------------------------------------------------------- #

# 1️⃣ trivial single-turn (text only) ---------------------------------------
# conv = Conversation.user("Hi Claude, who won the 2018 World Cup?")
|
|
336
|
+
# client.messages.create(model="claude-3-7-sonnet", **conv.to_anthropic())
|
|
337
|
+
|
|
338
|
+
# # 2️⃣ system + vision + follow-up for OpenAI Chat Completions ---------------
|
|
339
|
+
# conv = (
|
|
340
|
+
# Conversation.system("You are a visual assistant.")
|
|
341
|
+
# .add(
|
|
342
|
+
# Message.with_image(
|
|
343
|
+
# "user",
|
|
344
|
+
# "What's in this photo?",
|
|
345
|
+
# Image("boardwalk.jpg", detail="low"),
|
|
346
|
+
# )
|
|
347
|
+
# )
|
|
348
|
+
# .add(Message.text("assistant", "Looks like a lakeside boardwalk."))
|
|
349
|
+
# .add(Message.text("user", "Great, write a haiku about it."))
|
|
350
|
+
# )
|
|
351
|
+
|
|
352
|
+
# openai.chat.completions.create(model="gpt-4o-mini", messages=conv.to_openai_chat())
|
|
353
|
+
|
|
354
|
+
# # 3️⃣ Same conversation sent through new Responses API -----------------------
|
|
355
|
+
# openai.responses.create(model="gpt-4o-mini", **conv.to_openai_responses())
|
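
Taken together, prompt.py gives one canonical Conversation object with per-provider emitters. A minimal end-to-end sketch (illustrative, not part of the wheel; it assumes the package is importable and that the model name passed to dry_run, here "gpt-4o-mini", is a key in the models registry):

from lm_deluge.prompt import Conversation, Message

conv = (
    Conversation.system("You are a helpful assistant.")
    .add(Message.user("Summarize Hamlet in one sentence."))
)

# one canonical object, several provider payloads
oa_messages = conv.to_openai()                 # list[dict] for Chat Completions
system, claude_messages = conv.to_anthropic()  # system prompt returned separately

# token and cost estimate before anything goes over the wire
input_tokens, output_tokens, min_cost, max_cost = conv.dry_run("gpt-4o-mini", 256)

# the fingerprint survives a text-only to_log()/from_log() round trip,
# which is what makes it usable as a cache key
assert Conversation.from_log(conv.to_log()).fingerprint == conv.fingerprint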
lm_deluge/rerank.py
ADDED
@@ -0,0 +1,338 @@
### specific utility for cohere rerank api
import os
import aiohttp
from tqdm.auto import tqdm
import asyncio
import time
from typing import Optional
from dataclasses import dataclass
from .tracker import StatusTracker

registry = [
    "rerank-english-v3.0",
    "rerank-multilingual-v3.0",
    "rerank-english-v2.0",
    "rerank-multilingual-v2.0",
]


class RerankingRequest:
    def __init__(
        self,
        task_id: int,
        model_name: str,
        query: str,
        documents: list[str],
        top_k: int,
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        request_timeout: int,
        pbar: Optional[tqdm] = None,
    ):
        self.task_id = task_id
        self.model_name = model_name
        self.query = query
        self.documents = documents
        self.top_k = top_k
        self.attempts_left = attempts_left
        self.status_tracker = status_tracker
        self.retry_queue = retry_queue
        self.request_timeout = request_timeout
        self.pbar = pbar
        self.result = []

    def increment_pbar(self):
        if self.pbar is not None:
            self.pbar.update(1)

    def handle_success(self):
        self.increment_pbar()
        self.status_tracker.num_tasks_in_progress -= 1
        self.status_tracker.num_tasks_succeeded += 1

    def handle_error(self):
"""
|
|
56
|
+
If create_new_request is True, will create a new API request (so that it
|
|
57
|
+
has a chance of being sent to a different model). If false, will retry
|
|
58
|
+
the same request.
|
|
59
|
+
"""
        last_result: RerankingResponse = self.result[-1]
        error_to_print = (
            f"Error on task {self.task_id}, Code: {last_result.status_code}, "
        )
        error_to_print += f"Message: {last_result.error_message}."
        print(error_to_print)
        if self.attempts_left > 0:
            # the dispatch loop decrements attempts_left when the retry is
            # re-sent, so just requeue here (decrementing in both places
            # would burn two attempts per failure)
            self.retry_queue.put_nowait(self)
            return
        else:
            print(f"Task {self.task_id} out of tries.")
            self.status_tracker.num_tasks_in_progress -= 1
            self.status_tracker.num_tasks_failed += 1

    async def handle_response(self, response: aiohttp.ClientResponse):
        try:
            if response.status == 200:
                result = await response.json()
                # TODO: add cost calculation
                return RerankingResponse(
                    id=self.task_id,
                    status_code=response.status,
                    is_error=False,
                    error_message=None,
                    query=self.query,
                    documents=self.documents,
                    top_k_indices=[doc["index"] for doc in result["results"]],
                    top_k_scores=[doc["relevance_score"] for doc in result["results"]],
                )
            else:
                error_msg = await response.text()
                return RerankingResponse(
                    id=self.task_id,
                    status_code=response.status,
                    is_error=True,
                    error_message=error_msg,
                    query=self.query,
                    documents=[],
                    top_k_indices=[],
                    top_k_scores=[],
                )
        except Exception as e:
            return RerankingResponse(
                id=self.task_id,
                status_code=response.status,
                is_error=True,
                error_message=str(e),
                query=self.query,
                documents=[],
                top_k_indices=[],
                top_k_scores=[],
            )

    async def call_api(self):
        url = "https://api.cohere.com/v1/rerank"
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": f"Bearer {os.environ.get('COHERE_API_KEY')}",
        }
        data = {
            "model": self.model_name,
            "query": self.query,
            "top_n": self.top_k,
            "documents": self.documents,
        }
        try:
            self.status_tracker.total_requests += 1
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url, headers=headers, json=data, timeout=self.request_timeout
                ) as response:
                    # print("got response!!")
                    response_obj: RerankingResponse = await self.handle_response(
                        response
                    )
                    self.result.append(response_obj)
                    if response_obj.is_error:
                        self.handle_error()
                    else:
                        self.handle_success()

        except asyncio.TimeoutError:
            self.result.append(
                RerankingResponse(
                    id=self.task_id,
                    status_code=None,
                    is_error=True,
                    error_message="Timeout",
                    query=self.query,
                    documents=[],
                    top_k_indices=[],
                    top_k_scores=[],
                )
            )
            self.handle_error()

        except Exception as e:
            self.result.append(
                RerankingResponse(
                    id=self.task_id,
                    status_code=None,
                    is_error=True,
                    error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
                    query=self.query,
                    documents=[],
                    top_k_indices=[],
                    top_k_scores=[],
                )
            )
            self.handle_error()


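# For illustration (not part of the published file): a successful Cohere
# v1/rerank body looks like {"results": [{"index": 2, "relevance_score": 0.98},
# ...]}, which handle_response above maps onto the top_k_indices and
# top_k_scores fields of RerankingResponse below.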
@dataclass
class RerankingResponse:
    id: int
    status_code: int | None
    is_error: bool
    error_message: Optional[str]
    query: str
    documents: list[str]
    top_k_indices: list[int]
    top_k_scores: list[float]

    @property
    def ranked_documents(self):
        return [self.documents[i] for i in self.top_k_indices]


async def rerank_parallel_async(
    queries: list[str],
    docs: list[list[str]],  # one list per query
    top_k: int = 3,
    model: str = "rerank-english-v3.0",
    max_attempts: int = 5,
    max_requests_per_minute: int = 4_000,
    max_concurrent_requests: int = 500,
    request_timeout: int = 10,
    progress_bar: Optional[tqdm] = None,
):
    """Processes rerank requests in parallel, throttling to stay under rate limits."""
    ids = range(len(queries))
    # constants
    seconds_to_pause_after_rate_limit_error = 5
    seconds_to_sleep_each_loop = 0.003  # so concurrent tasks can run

    # initialize trackers
    retry_queue = asyncio.Queue()
    status_tracker = StatusTracker()
    next_request = None  # variable to hold the next request to call

    # initialize available capacity counts
    # capacity is replenished continuously in proportion to elapsed time,
    # so requests are smoothed out instead of bursting once per minute
    available_request_capacity = max_requests_per_minute
    last_update_time = time.time()
    last_pbar_update_time = time.time()

    # initialize flags
    prompts_not_finished = True
    prompts_iter = iter(zip(ids, queries, docs))
    results: list[RerankingRequest] = []

    while True:
        # get next request (if one is not already waiting for capacity)
        if next_request is None:
            if not retry_queue.empty():
                next_request = retry_queue.get_nowait()
                print(f"Retrying request {next_request.task_id}.")
            elif prompts_not_finished:
                try:
                    # get new request
                    req_id, req_query, req_docs = next(prompts_iter)
                    next_request = RerankingRequest(
                        task_id=req_id,
                        model_name=model,
                        query=req_query,
                        documents=req_docs,
                        top_k=top_k,
                        attempts_left=max_attempts,
                        status_tracker=status_tracker,
                        retry_queue=retry_queue,
                        request_timeout=request_timeout,
                        pbar=progress_bar,
                    )
                    status_tracker.num_tasks_started += 1
                    status_tracker.num_tasks_in_progress += 1
                    results.append(next_request)

                except StopIteration:
                    prompts_not_finished = False
                    print("API requests finished, only retries remain.")

        # update available capacity
        current_time = time.time()
        seconds_since_update = current_time - last_update_time
        available_request_capacity = min(
            available_request_capacity
            + max_requests_per_minute * seconds_since_update / 60.0,
            max_requests_per_minute,
        )
        last_update_time = current_time

        # update pbar status
        if progress_bar:
            if current_time - last_pbar_update_time > 1:
                last_pbar_update_time = current_time
                progress_bar.set_postfix(
                    {
                        "Request Capacity": f"{available_request_capacity:.1f}",
                        "Requests in Progress": status_tracker.num_tasks_in_progress,
                    }
                )

        # if enough capacity available, call API
        if next_request:
            if (
                available_request_capacity >= 1
                and status_tracker.num_tasks_in_progress < max_concurrent_requests
            ):
                # update counters
                available_request_capacity -= 1
                next_request.attempts_left -= 1

                # call API
                asyncio.create_task(next_request.call_api())
                next_request = None  # reset next_request to empty

        # if all tasks are finished, break
        if status_tracker.num_tasks_in_progress == 0:
            break

        # main loop sleeps briefly so concurrent tasks can run
        await asyncio.sleep(seconds_to_sleep_each_loop)

        # if a rate limit error was hit recently, pause to cool down
        seconds_since_rate_limit_error = (
            time.time() - status_tracker.time_of_last_rate_limit_error
        )
        if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
            remaining_seconds_to_pause = (
                seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
            )
            await asyncio.sleep(remaining_seconds_to_pause)
            # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago
            print(
                f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
            )

    # after finishing, log final status
    if status_tracker.num_tasks_failed > 0:
        print(
            f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
        )
    if status_tracker.num_rate_limit_errors > 0:
        print(
            f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
        )

    print(
        f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
    )

    # deduplicate results by id
    deduplicated = {}
    for request in results:
        if request.task_id not in deduplicated:
            deduplicated[request.task_id] = request.result[-1]
        else:
            current_response: RerankingResponse = deduplicated[request.task_id]
            # only replace if the current response has no top_k_indices and the new one does
            if request.result[-1].top_k_indices and not current_response.top_k_indices:
                deduplicated[request.task_id] = request.result[-1]

    output = list(deduplicated.values())
    print(f"Returning {len(output)} unique results.")

    return output
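
A minimal sketch of driving the reranker from synchronous code (illustrative, not part of the wheel; it assumes COHERE_API_KEY is set in the environment, and the caller supplies the tqdm bar since rerank_parallel_async never creates one itself):

import asyncio
from tqdm.auto import tqdm
from lm_deluge.rerank import rerank_parallel_async

queries = ["best sci-fi novels", "python packaging"]
docs = [
    ["Dune is a 1965 epic...", "Sourdough for beginners...", "Neuromancer..."],
    ["setuptools vs. hatchling compared...", "A history of the GIL..."],
]

async def main():
    pbar = tqdm(total=len(queries))
    responses = await rerank_parallel_async(queries, docs, top_k=2, progress_bar=pbar)
    for r in responses:
        # top_k_indices / top_k_scores are empty if the request ultimately failed
        print(r.query, r.top_k_indices, r.top_k_scores)

asyncio.run(main())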