lm-deluge 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


lm_deluge/prompt.py ADDED
@@ -0,0 +1,355 @@
+ import io
+ import json
+ import tiktoken
+ import xxhash
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Literal, Sequence
+ from lm_deluge.models import APIModel
+ from lm_deluge.image import Image
+
+ ###############################################################################
+ # 1. Low-level content blocks – either text or an image                       #
+ ###############################################################################
+ Role = Literal["system", "user", "assistant"]
+
+
+ @dataclass(slots=True)
+ class Text:
+     text: str
+     type: str = field(init=False, default="text")
+
+     @property
+     def fingerprint(self) -> str:
+         return xxhash.xxh64(self.text.encode()).hexdigest()
+
+     # ── provider-specific emission ────────────────────────────────────────────
+     def oa_chat(self) -> dict | str:  # OpenAI Chat Completions
+         return {"type": "text", "text": self.text}
+
+     def oa_resp(self) -> dict:  # OpenAI *Responses* (new)
+         return {"type": "input_text", "text": self.text}
+
+     def anthropic(self) -> dict:  # Anthropic Messages
+         return {"type": "text", "text": self.text}
+
+     def gemini(self) -> dict:
+         return {"text": self.text}
+
+
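+ # Example (a sketch derived from the methods above): one Text block emits a
+ # provider-specific payload per target API.
+ # >>> t = Text("hello")
+ # >>> t.oa_chat()   # {"type": "text", "text": "hello"}
+ # >>> t.oa_resp()   # {"type": "input_text", "text": "hello"}
+ # >>> t.gemini()    # {"text": "hello"}
+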
+ ###############################################################################
+ # 2. One conversational turn (role + parts)                                   #
+ ###############################################################################
+ @dataclass(slots=True)
+ class Message:
+     role: Role
+     parts: list[Text | Image]
+
+     @property
+     def fingerprint(self) -> str:
+         return self.role + "," + ",".join(part.fingerprint for part in self.parts)
+
+     def add_text(self, content: str) -> "Message":
+         """Append a text block and return self for chaining."""
+         self.parts.append(Text(content))
+         return self
+
+     def add_image(
+         self,
+         data: bytes | str | Path | io.BytesIO,
+         *,
+         media_type: str | None = None,
+         detail: Literal["low", "high", "auto"] = "auto",
+         max_size: int | None = None,
+     ) -> "Message":
+         """
+         Append an image block and return self for chaining.
+
+         If max_size is provided, the image will be resized so that its longer
+         dimension equals max_size, but only if the longer dimension is currently
+         larger than max_size.
+         """
+         img = Image(data, media_type=media_type, detail=detail)
+
+         # Resize if max_size is provided
+         if max_size is not None:
+             img.resize(max_size)
+
+         self.parts.append(img)
+         return self
+
+     # -------- convenient constructors --------
+     @classmethod
+     def user(
+         cls,
+         text: str | None = None,
+         *,
+         image: str | bytes | Path | io.BytesIO | None = None,
+     ) -> "Message":
+         res = cls("user", [])
+         if text is not None:
+             res.add_text(text)
+         if image is not None:
+             res.add_image(image)
+         return res
+
+     @classmethod
+     def system(cls, text: str | None = None) -> "Message":
+         res = cls("system", [])
+         if text is not None:
+             res.add_text(text)
+         return res
+
+     @classmethod
+     def ai(cls, text: str | None = None) -> "Message":
+         res = cls("assistant", [])
+         if text is not None:
+             res.add_text(text)
+         return res
+
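+     # Example (sketch): constructors and add_* methods chain, so a multi-part
+     # turn builds fluently. The file name is illustrative.
+     # >>> msg = Message.user("Describe this photo.", image="photo.jpg").add_text(
+     # ...     "Keep it under 50 words."
+     # ... )
+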
+     # ──── provider-specific constructors ───
+     @classmethod
+     def from_oa(cls, msg: dict):
+         role = (
+             "system"
+             if msg["role"] in ["developer", "system"]
+             else ("user" if msg["role"] == "user" else "assistant")
+         )
+         # must be a list (not Sequence) so we can append to it below
+         parts: list[Text | Image] = []
+         content = msg["content"]
+         if isinstance(content, str):
+             parts.append(Text(content))
+         else:
+             for item in content:
+                 if item["type"] == "text":
+                     parts.append(Text(item["text"]))
+                 elif item["type"] == "image_url":
+                     parts.append(Image(data=item["image_url"]["url"]))
+         return cls(role, parts)
+
+     @classmethod
+     def from_oa_resp(cls, msg: dict):
+         raise NotImplementedError("not implemented")
+
+     @classmethod
+     def from_anthropic(cls, msg: dict):
+         pass
+
+     # ───── provider-specific emission ─────
+     def oa_chat(self) -> dict:
+         content = []
+         for p in self.parts:
+             content.append(p.oa_chat())
+         return {"role": self.role, "content": content}
+
+     def oa_resp(self) -> dict:
+         content = [p.oa_resp() for p in self.parts]
+         return {"role": self.role, "content": content}
+
+     def anthropic(self) -> dict:
+         # Anthropic: system message is *not* in the list
+         if self.role == "system":
+             raise ValueError("Anthropic keeps system outside message list")
+         content = [p.anthropic() for p in self.parts]
+         # Shortcut: single text becomes a bare string
+         if len(content) == 1 and content[0]["type"] == "text":
+             content = content[0]["text"]
+         return {"role": self.role, "content": content}
+
+     def gemini(self) -> dict:
+         parts = [p.gemini() for p in self.parts]
+         role = "user" if self.role == "user" else "model"
+         return {"role": role, "parts": parts}
+
+
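+ # Example (sketch): the same turn, emitted per provider.
+ # >>> m = Message.user("hi")
+ # >>> m.oa_chat()     # {"role": "user", "content": [{"type": "text", "text": "hi"}]}
+ # >>> m.anthropic()   # {"role": "user", "content": "hi"}  (single-text shortcut)
+ # >>> m.gemini()      # {"role": "user", "parts": [{"text": "hi"}]}
+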
+ ###############################################################################
+ # 3. A whole conversation (ordered list of messages)                          #
+ ###############################################################################
+
+
+ @dataclass(slots=True)
+ class Conversation:
+     messages: list[Message] = field(default_factory=list)
+
+     # ── convenience shorthands ------------------------------------------------
+     @classmethod
+     def system(cls, text: str) -> "Conversation":
+         return cls([Message.system(text)])
+
+     @classmethod
+     def user(
+         cls, text: str, *, image: bytes | str | Path | None = None
+     ) -> "Conversation":
+         msg = (
+             Message.user(text) if image is None else Message.user(text).add_image(image)
+         )
+         return cls([msg])
+
+     @classmethod
+     def from_openai(cls, messages: list[dict]):
+         """Compatibility with OpenAI-formatted messages."""
+         pass
+
+     @classmethod
+     def from_anthropic(cls, messages: list[dict], system: str | None = None):
+         """Compatibility with Anthropic-formatted messages."""
+         pass
+
+     # fluent additions
+     def add(self, msg: Message) -> "Conversation":
+         self.messages.append(msg)
+         return self
+
+     # ── conversions -----------------------------------------------------------
+     def to_openai(self) -> list[dict]:
+         return [m.oa_chat() for m in self.messages]
+
+     def to_openai_responses(self) -> dict:
+         # OpenAI Responses = single "input" array, role must be user/assistant
+         return {"input": [m.oa_resp() for m in self.messages if m.role != "system"]}
+
+     def to_anthropic(self) -> tuple[str | None, list[dict]]:
+         system_msg = next(
+             (
+                 m.parts[0].text
+                 for m in self.messages
+                 if m.role == "system" and isinstance(m.parts[0], Text)
+             ),
+             None,
+         )
+         other = [m.anthropic() for m in self.messages if m.role != "system"]
+         return system_msg, other
+
+     def to_gemini(self) -> tuple[str | None, list[dict]]:
+         system_msg = next(
+             (
+                 m.parts[0].text
+                 for m in self.messages
+                 if m.role == "system" and isinstance(m.parts[0], Text)
+             ),
+             None,
+         )
+         other = [m.gemini() for m in self.messages if m.role != "system"]
+         return system_msg, other
+
+     def to_cohere(self) -> list[dict]:
+         messages = []
+         for m in self.messages:
+             if len(m.parts) > 1:
+                 raise ValueError("Cohere does not support multi-part messages")
+             if isinstance(m.parts[0], Image):
+                 raise ValueError("Cohere does not support images")
+             messages.append({"role": m.role, "text": m.parts[0].text})
+         return messages
+
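+     # Example (sketch): the conversions above keep the system turn where each
+     # API expects it.
+     # >>> conv = Conversation.system("Be terse.").add(Message.user("hi"))
+     # >>> conv.to_openai()      # system stays in the message list
+     # >>> conv.to_anthropic()   # ("Be terse.", [{"role": "user", "content": "hi"}])
+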
+     # ── misc helpers ----------------------------------------------------------
+     _tok = tiktoken.encoding_for_model("gpt-4")
+
+     def count_tokens(self, max_new_tokens: int = 0, img_tokens: int = 85) -> int:
+         n = max_new_tokens
+         for m in self.messages:
+             for p in m.parts:
+                 if isinstance(p, Text):
+                     n += len(self._tok.encode(p.text))
+                 else:  # Image – crude flat cost per image
+                     n += img_tokens
+
+         # very rough BOS/EOS padding
+         return n + 6 * len(self.messages)
+
+     def dry_run(self, model_name: str, max_new_tokens: int):
+         model_obj = APIModel.from_registry(model_name)
+         if model_obj.api_spec == "openai":
+             image_tokens = 85
+         elif model_obj.api_spec == "anthropic":
+             image_tokens = 1_200
+         else:
+             image_tokens = 0
+         input_tokens = self.count_tokens(0, image_tokens)
+         output_tokens = max_new_tokens
+
+         min_cost, max_cost = None, None
+         if model_obj.input_cost and model_obj.output_cost:
+             min_cost = model_obj.input_cost * input_tokens / 1e6
+             max_cost = min_cost + model_obj.output_cost * output_tokens / 1e6
+
+         return input_tokens, output_tokens, min_cost, max_cost
+
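+     # Example (sketch): estimate request size and cost bounds before sending.
+     # The model name is illustrative; any key in the package's model registry works.
+     # >>> conv = Conversation.user("Summarize the attached report.")
+     # >>> in_toks, out_toks, min_cost, max_cost = conv.dry_run("gpt-4o-mini", 256)
+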
+     @property
+     def fingerprint(self) -> str:
+         hasher = xxhash.xxh64()
+         hasher.update(json.dumps([m.fingerprint for m in self.messages]).encode())
+         return hasher.hexdigest()
+
+     def to_log(self) -> dict:
+         """
+         Return a JSON-serialisable dict that fully captures the conversation.
+         """
+         serialized: list[dict] = []
+
+         for msg in self.messages:
+             content_blocks: list[dict] = []
+             for p in msg.parts:
+                 if isinstance(p, Text):
+                     content_blocks.append({"type": "text", "text": p.text})
+                 else:  # Image – redact the bytes, keep a hint
+                     w, h = p.size
+                     content_blocks.append(
+                         {"type": "image", "tag": f"<Image ({w}×{h})>"}
+                     )
+             serialized.append({"role": msg.role, "content": content_blocks})
+
+         return {"messages": serialized}
+
+     @classmethod
+     def from_log(cls, payload: dict) -> "Conversation":
+         """Re-hydrate a Conversation previously produced by `to_log()`."""
+         msgs: list[Message] = []
+
+         for m in payload.get("messages", []):
+             role: Role = m["role"]  # 'system' | 'user' | 'assistant'
+             parts: list[Text | Image] = []
+
+             for p in m["content"]:
+                 if p["type"] == "text":
+                     parts.append(Text(p["text"]))
+                 elif p["type"] == "image":
+                     # We only stored a placeholder tag, so keep that placeholder.
+                     # You could raise instead if real image bytes are required.
+                     parts.append(Image(p["tag"], detail="low"))
+                 else:
+                     raise ValueError(f"Unknown part type {p['type']!r}")
+
+             msgs.append(Message(role, parts))
+
+         return cls(msgs)
+
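+     # Example (sketch): log round-trip. Image bytes are redacted by `to_log`,
+     # so the round-trip is lossless only for text-only conversations.
+     # >>> payload = conv.to_log()
+     # >>> restored = Conversation.from_log(payload)
+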
+
+
+ ###############################################################################
+ # --------------------------------------------------------------------------- #
+ #                          Basic usage examples                                #
+ # --------------------------------------------------------------------------- #
+
+ # 1️⃣ trivial single-turn (text only) ---------------------------------------
+ # conv = Conversation.user("Hi Claude, who won the 2018 World Cup?")
+ # system, messages = conv.to_anthropic()
+ # client.messages.create(model="claude-3-7-sonnet", system=system, messages=messages)
+
+ # 2️⃣ system + vision + follow-up for OpenAI Chat Completions ---------------
+ # conv = (
+ #     Conversation.system("You are a visual assistant.")
+ #     .add(Message.user("What's in this photo?").add_image("boardwalk.jpg", detail="low"))
+ #     .add(Message.ai("Looks like a lakeside boardwalk."))
+ #     .add(Message.user("Great, write a haiku about it."))
+ # )
+
+ # openai.chat.completions.create(model="gpt-4o-mini", messages=conv.to_openai())
+
+ # 3️⃣ Same conversation sent through the new Responses API -------------------
+ # openai.responses.create(model="gpt-4o-mini", **conv.to_openai_responses())
lm_deluge/rerank.py ADDED
@@ -0,0 +1,338 @@
+ ### specific utility for the Cohere rerank API
+ import os
+ import aiohttp
+ from tqdm.auto import tqdm
+ import asyncio
+ import time
+ from typing import Optional
+ from dataclasses import dataclass
+ from .tracker import StatusTracker
+
+ registry = [
+     "rerank-english-v3.0",
+     "rerank-multilingual-v3.0",
+     "rerank-english-v2.0",
+     "rerank-multilingual-v2.0",
+ ]
+
+
+ class RerankingRequest:
+     def __init__(
+         self,
+         task_id: int,
+         model_name: str,
+         query: str,
+         documents: list[str],
+         top_k: int,
+         attempts_left: int,
+         status_tracker: StatusTracker,
+         retry_queue: asyncio.Queue,
+         request_timeout: int,
+         pbar: Optional[tqdm] = None,
+     ):
+         self.task_id = task_id
+         self.model_name = model_name
+         self.query = query
+         self.documents = documents
+         self.top_k = top_k
+         self.attempts_left = attempts_left
+         self.status_tracker = status_tracker
+         self.retry_queue = retry_queue
+         self.request_timeout = request_timeout
+         self.pbar = pbar
+         self.result = []
+
+     def increment_pbar(self):
+         if self.pbar is not None:
+             self.pbar.update(1)
+
+     def handle_success(self):
+         self.increment_pbar()
+         self.status_tracker.num_tasks_in_progress -= 1
+         self.status_tracker.num_tasks_succeeded += 1
+
+     def handle_error(self):
+         """Requeue the request if any attempts remain; otherwise mark it failed."""
+         last_result: RerankingResponse = self.result[-1]
+         print(
+             f"Error on task {self.task_id}, Code: {last_result.status_code}, "
+             f"Message: {last_result.error_message}."
+         )
+         if self.attempts_left > 0:
+             # attempts_left is decremented when the request is dispatched in the
+             # main loop, so don't decrement it again here
+             self.retry_queue.put_nowait(self)
+         else:
+             print(f"Task {self.task_id} out of tries.")
+             self.status_tracker.num_tasks_in_progress -= 1
+             self.status_tracker.num_tasks_failed += 1
+
+     async def handle_response(self, response: aiohttp.ClientResponse):
+         try:
+             if response.status == 200:
+                 result = await response.json()
+                 # TODO: add cost calculation
+                 return RerankingResponse(
+                     id=self.task_id,
+                     status_code=response.status,
+                     is_error=False,
+                     error_message=None,
+                     query=self.query,
+                     documents=self.documents,
+                     top_k_indices=[doc["index"] for doc in result["results"]],
+                     top_k_scores=[doc["relevance_score"] for doc in result["results"]],
+                 )
+             else:
+                 error_msg = await response.text()
+                 return RerankingResponse(
+                     id=self.task_id,
+                     status_code=response.status,
+                     is_error=True,
+                     error_message=error_msg,
+                     query=self.query,
+                     documents=[],
+                     top_k_indices=[],
+                     top_k_scores=[],
+                 )
+         except Exception as e:
+             return RerankingResponse(
+                 id=self.task_id,
+                 status_code=response.status,
+                 is_error=True,
+                 error_message=str(e),
+                 query=self.query,
+                 documents=[],
+                 top_k_indices=[],
+                 top_k_scores=[],
+             )
+
+     async def call_api(self):
+         url = "https://api.cohere.com/v1/rerank"
+         headers = {
+             "accept": "application/json",
+             "content-type": "application/json",
+             "Authorization": f"Bearer {os.environ.get('COHERE_API_KEY')}",
+         }
+         data = {
+             "model": self.model_name,
+             "query": self.query,
+             "top_n": self.top_k,
+             "documents": self.documents,
+         }
+         try:
+             self.status_tracker.total_requests += 1
+             async with aiohttp.ClientSession() as session:
+                 async with session.post(
+                     url, headers=headers, json=data, timeout=self.request_timeout
+                 ) as response:
+                     response_obj: RerankingResponse = await self.handle_response(
+                         response
+                     )
+                     self.result.append(response_obj)
+                     if response_obj.is_error:
+                         self.handle_error()
+                     else:
+                         self.handle_success()
+
+         except asyncio.TimeoutError:
+             self.result.append(
+                 RerankingResponse(
+                     id=self.task_id,
+                     status_code=None,
+                     is_error=True,
+                     error_message="Timeout",
+                     query=self.query,
+                     documents=[],
+                     top_k_indices=[],
+                     top_k_scores=[],
+                 )
+             )
+             self.handle_error()
+
+         except Exception as e:
+             self.result.append(
+                 RerankingResponse(
+                     id=self.task_id,
+                     status_code=None,
+                     is_error=True,
+                     error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
+                     query=self.query,
+                     documents=[],
+                     top_k_indices=[],
+                     top_k_scores=[],
+                 )
+             )
+             self.handle_error()
+
+
+ @dataclass
+ class RerankingResponse:
+     id: int
+     status_code: int | None
+     is_error: bool
+     error_message: Optional[str]
+     query: str
+     documents: list[str]
+     top_k_indices: list[int]
+     top_k_scores: list[float]
+
+     @property
+     def ranked_documents(self):
+         return [self.documents[i] for i in self.top_k_indices]
+
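+     # Example (sketch): a successful response carries parallel index/score
+     # lists; `ranked_documents` reorders the originals by relevance.
+     # >>> resp.top_k_indices      # e.g. [2, 0, 5]
+     # >>> resp.ranked_documents   # the top documents, most relevant first
+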
+
+ async def rerank_parallel_async(
+     queries: list[str],
+     docs: list[list[str]],  # one list per query
+     top_k: int = 3,
+     model: str = "rerank-english-v3.0",
+     max_attempts: int = 5,
+     max_requests_per_minute: int = 4_000,
+     max_concurrent_requests: int = 500,
+     request_timeout: int = 10,
+     progress_bar: Optional[tqdm] = None,
+ ):
+     """Processes rerank requests in parallel, throttling to stay under rate limits."""
+     ids = range(len(queries))
+     # constants
+     seconds_to_pause_after_rate_limit_error = 5
+     seconds_to_sleep_each_loop = 0.003  # so concurrent tasks can run
+
+     # initialize trackers
+     retry_queue = asyncio.Queue()
+     status_tracker = StatusTracker()
+     next_request = None  # variable to hold the next request to call
+
+     # initialize available capacity counts: capacity refills continuously at
+     # the per-minute rate, capped at one minute's worth of requests
+     available_request_capacity = max_requests_per_minute
+     last_update_time = time.time()
+     last_pbar_update_time = time.time()
+
+     # initialize flags
+     prompts_not_finished = True
+     prompts_iter = iter(zip(ids, queries, docs))
+     results: list[RerankingRequest] = []
+
+     while True:
+         # get next request (if one is not already waiting for capacity)
+         if next_request is None:
+             if not retry_queue.empty():
+                 next_request = retry_queue.get_nowait()
+                 print(f"Retrying request {next_request.task_id}.")
+             elif prompts_not_finished:
+                 try:
+                     # get new request
+                     req_id, req_query, req_docs = next(prompts_iter)
+                     next_request = RerankingRequest(
+                         task_id=req_id,
+                         model_name=model,
+                         query=req_query,
+                         documents=req_docs,
+                         top_k=top_k,
+                         attempts_left=max_attempts,
+                         status_tracker=status_tracker,
+                         retry_queue=retry_queue,
+                         request_timeout=request_timeout,
+                         pbar=progress_bar,
+                     )
+                     status_tracker.num_tasks_started += 1
+                     status_tracker.num_tasks_in_progress += 1
+                     results.append(next_request)
+
+                 except StopIteration:
+                     prompts_not_finished = False
+                     print("API requests finished, only retries remain.")
+
+         # update available capacity
+         current_time = time.time()
+         seconds_since_update = current_time - last_update_time
+         available_request_capacity = min(
+             available_request_capacity
+             + max_requests_per_minute * seconds_since_update / 60.0,
+             max_requests_per_minute,
+         )
+         last_update_time = current_time
+
+         # update pbar status
+         if progress_bar:
+             if current_time - last_pbar_update_time > 1:
+                 last_pbar_update_time = current_time
+                 progress_bar.set_postfix(
+                     {
+                         "Request Capacity": f"{available_request_capacity:.1f}",
+                         "Requests in Progress": status_tracker.num_tasks_in_progress,
+                     }
+                 )
+
+         # if enough capacity is available, call the API
+         if next_request:
+             if (
+                 available_request_capacity >= 1
+                 and status_tracker.num_tasks_in_progress < max_concurrent_requests
+             ):
+                 # update counters
+                 available_request_capacity -= 1
+                 next_request.attempts_left -= 1
+
+                 # call API
+                 asyncio.create_task(next_request.call_api())
+                 next_request = None  # reset next_request to empty
+
+         # if all tasks are finished, break
+         if status_tracker.num_tasks_in_progress == 0:
+             break
+
+         # main loop sleeps briefly so concurrent tasks can run
+         await asyncio.sleep(seconds_to_sleep_each_loop)
+
+         # if a rate limit error was hit recently, pause to cool down
+         seconds_since_rate_limit_error = (
+             time.time() - status_tracker.time_of_last_rate_limit_error
+         )
+         if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error:
+             remaining_seconds_to_pause = (
+                 seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error
+             )
+             # e.g., if the pause is 15 seconds and the last limit error was hit
+             # 5 seconds ago, sleep for the remaining 10 seconds
+             await asyncio.sleep(remaining_seconds_to_pause)
+             print(
+                 f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}"
+             )
+
+     # after finishing, log final status
+     if status_tracker.num_tasks_failed > 0:
+         print(
+             f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed."
+         )
+     if status_tracker.num_rate_limit_errors > 0:
+         print(
+             f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate."
+         )
+
+     print(
+         f"After processing, got {len(results)} results for {len(ids)} inputs. Removing duplicates."
+     )
+
+     # deduplicate results by id
+     deduplicated = {}
+     for request in results:
+         if request.task_id not in deduplicated:
+             deduplicated[request.task_id] = request.result[-1]
+         else:
+             current_response: RerankingResponse = deduplicated[request.task_id]
+             # only replace if the kept response has no top_k_indices and the new one does
+             if request.result[-1].top_k_indices and not current_response.top_k_indices:
+                 deduplicated[request.task_id] = request.result[-1]
+
+     output = list(deduplicated.values())
+     print(f"Returning {len(output)} unique results.")
+
+     return output
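+
+
+ # Example (sketch): rerank documents for several queries at once. Requires
+ # COHERE_API_KEY in the environment; the document lists are hypothetical.
+ # >>> results = asyncio.run(
+ # ...     rerank_parallel_async(
+ # ...         queries=["best pizza in nyc", "openai cookbook"],
+ # ...         docs=[nyc_docs, cookbook_docs],
+ # ...         top_k=3,
+ # ...     )
+ # ... )
+ # >>> results[0].ranked_documents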