lm-deluge 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of lm-deluge might be problematic.

lm_deluge/__init__.py CHANGED
@@ -1,7 +1,15 @@
  from .client import LLMClient, SamplingParams, APIResponse
  from .prompt import Conversation, Message
+ from .tool import Tool
  import dotenv

  dotenv.load_dotenv()

- __all__ = ["LLMClient", "SamplingParams", "APIResponse", "Conversation", "Message"]
+ __all__ = [
+     "LLMClient",
+     "SamplingParams",
+     "APIResponse",
+     "Conversation",
+     "Message",
+     "Tool",
+ ]
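
The practical effect of this change is that Tool is now re-exported from the package root alongside the existing public names. A minimal import sketch (names taken directly from the new __all__; nothing else is implied about their signatures):

# New in 0.0.13: Tool is importable from the package root.
from lm_deluge import LLMClient, SamplingParams, APIResponse, Conversation, Message, Tool

# The direct module import continues to work as before:
from lm_deluge.tool import Tool
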
lm_deluge/agent.py ADDED
File without changes
@@ -1,9 +1,6 @@
- import asyncio
  from aiohttp import ClientResponse
  import json
  import os
- import warnings
- from tqdm import tqdm
  from typing import Callable

  from lm_deluge.prompt import (
@@ -14,12 +11,84 @@ from lm_deluge.prompt import (
      Thinking,
      CachePattern,
  )
+ from lm_deluge.tool import Tool
  from lm_deluge.usage import Usage
  from .base import APIRequestBase, APIResponse

  from ..tracker import StatusTracker
- from ..sampling_params import SamplingParams
+ from ..config import SamplingParams
  from ..models import APIModel
+ from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+
+
+ def _build_anthropic_request(
+     model: APIModel,
+     prompt: Conversation,
+     tools: list[Tool] | None,
+     sampling_params: SamplingParams,
+     cache_pattern: CachePattern | None = None,
+     computer_use: bool = False,
+     display_width: int = 1024,
+     display_height: int = 768,
+ ):
+     system_message, messages = prompt.to_anthropic(cache_pattern=cache_pattern)
+     request_header = {
+         "x-api-key": os.getenv(model.api_key_env_var),
+         "anthropic-version": "2023-06-01",
+         "content-type": "application/json",
+     }
+
+     # Add beta header for Computer Use
+     if computer_use:
+         request_header["anthropic-beta"] = "computer-use-2025-01-24"
+
+     request_json = {
+         "model": model.name,
+         "messages": messages,
+         "temperature": sampling_params.temperature,
+         "top_p": sampling_params.top_p,
+         "max_tokens": sampling_params.max_new_tokens,
+     }
+
+     # handle thinking
+     if model.reasoning_model and sampling_params.reasoning_effort:
+         # translate reasoning effort of low, medium, high to budget tokens
+         budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
+             sampling_params.reasoning_effort
+         )
+         request_json["thinking"] = {
+             "type": "enabled",
+             "budget_tokens": budget,
+         }
+         request_json.pop("top_p")
+         request_json["temperature"] = 1.0
+         request_json["max_tokens"] += budget
+     else:
+         request_json["thinking"] = {"type": "disabled"}
+         if sampling_params.reasoning_effort:
+             print("ignoring reasoning_effort for non-reasoning model")
+     if system_message is not None:
+         request_json["system"] = system_message
+     if tools or computer_use:
+         tool_definitions = []
+         if tools:
+             tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+         # Add Computer Use tools
+         if computer_use:
+             cu_tools = get_anthropic_cu_tools(
+                 model=model.id,
+                 display_width=display_width,  # todo: set from ComputerUseParams
+                 display_height=display_height,
+             )
+             tool_definitions.extend(cu_tools)
+
+         # Add cache control to last tool if tools_only caching is specified
+         if cache_pattern == "tools_only" and tool_definitions:
+             tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
+         request_json["tools"] = tool_definitions
+
+     return request_json, request_header


  class AnthropicRequest(APIRequestBase):
@@ -32,18 +101,19 @@ class AnthropicRequest(APIRequestBase):
          prompt: Conversation,
          attempts_left: int,
          status_tracker: StatusTracker,
-         retry_queue: asyncio.Queue,
          results_arr: list,
          request_timeout: int = 30,
          sampling_params: SamplingParams = SamplingParams(),
-         pbar: tqdm | None = None,
          callback: Callable | None = None,
-         debug: bool = False,
          # for retries
          all_model_names: list[str] | None = None,
          all_sampling_params: list[SamplingParams] | None = None,
          tools: list | None = None,
          cache: CachePattern | None = None,
+         # Computer Use support
+         computer_use: bool = False,
+         display_width: int = 1024,
+         display_height: int = 768,
      ):
          super().__init__(
              task_id=task_id,
@@ -51,18 +121,18 @@ class AnthropicRequest(APIRequestBase):
              prompt=prompt,
              attempts_left=attempts_left,
              status_tracker=status_tracker,
-             retry_queue=retry_queue,
              results_arr=results_arr,
              request_timeout=request_timeout,
              sampling_params=sampling_params,
-             pbar=pbar,
              callback=callback,
-             debug=debug,
              all_model_names=all_model_names,
              all_sampling_params=all_sampling_params,
              tools=tools,
              cache=cache,
          )
+         self.computer_use = computer_use
+         self.display_width = display_width
+         self.display_height = display_height
          self.model = APIModel.from_registry(model_name)
          self.url = f"{self.model.api_base}/messages"

@@ -70,52 +140,16 @@ class AnthropicRequest(APIRequestBase):
          if cache is not None:
              prompt.lock_images_as_bytes()

-         self.system_message, messages = prompt.to_anthropic(cache_pattern=cache)
-         self.request_header = {
-             "x-api-key": os.getenv(self.model.api_key_env_var),
-             "anthropic-version": "2023-06-01",
-             "content-type": "application/json",
-         }
-
-         self.request_json = {
-             "model": self.model.name,
-             "messages": messages,
-             "temperature": self.sampling_params.temperature,
-             "top_p": self.sampling_params.top_p,
-             "max_tokens": self.sampling_params.max_new_tokens,
-         }
-         # handle thinking
-         if self.model.reasoning_model:
-             if sampling_params.reasoning_effort:
-                 # translate reasoning effort of low, medium, high to budget tokens
-                 budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
-                     sampling_params.reasoning_effort
-                 )
-                 self.request_json["thinking"] = {
-                     "type": "enabled",
-                     "budget_tokens": budget,
-                 }
-                 self.request_json.pop("top_p")
-                 self.request_json["temperature"] = 1.0
-                 self.request_json["max_tokens"] += (
-                     budget  # assume max tokens is max completion tokens
-                 )
-             else:
-                 # no thinking
-                 self.request_json["thinking"] = {"type": "disabled"}
-         else:
-             if sampling_params.reasoning_effort:
-                 warnings.warn(
-                     f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
-                 )
-         if self.system_message is not None:
-             self.request_json["system"] = self.system_message
-         if tools:
-             tool_definitions = [tool.dump_for("anthropic") for tool in tools]
-             # Add cache control to last tool if tools_only caching is specified
-             if cache == "tools_only" and tool_definitions:
-                 tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
-             self.request_json["tools"] = tool_definitions
+         self.request_json, self.request_header = _build_anthropic_request(
+             self.model,
+             prompt,
+             tools,
+             sampling_params,
+             cache,
+             computer_use,
+             display_width,
+             display_height,
+         )

      async def handle_response(self, http_response: ClientResponse) -> APIResponse:
          is_error = False
@@ -135,8 +169,6 @@ class AnthropicRequest(APIRequestBase):
              "anthropic-ratelimit-tokens-reset",
          ]:
              rate_limits[header] = http_response.headers.get(header, None)
-         if self.debug:
-             print(f"Rate limits: {rate_limits}")
          if status_code >= 200 and status_code < 300:
              try:
                  data = await http_response.json()
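
In the refactored request builder above, a reasoning_effort of low, medium, or high becomes an Anthropic thinking budget of 1024, 4096, or 16384 tokens; the budget is added on top of max_tokens, top_p is dropped, and temperature is pinned to 1.0. A standalone sketch of that arithmetic (the helper name and the example numbers are illustrative, not part of the package):

def thinking_budget(reasoning_effort: str | None) -> int | None:
    # same mapping used by _build_anthropic_request
    return {"low": 1024, "medium": 4096, "high": 16384}.get(reasoning_effort or "")

max_new_tokens = 512                # what the caller asked for
budget = thinking_budget("medium")  # 4096
max_tokens = max_new_tokens + (budget or 0)
assert max_tokens == 4608           # completion tokens + thinking budget
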
@@ -1,20 +1,21 @@
- import aiohttp
  import asyncio
  import json
  import random
- from tqdm import tqdm
- from dataclasses import dataclass
+ import traceback
  from abc import ABC, abstractmethod
+ from dataclasses import dataclass
  from typing import Callable

- from lm_deluge.prompt import Conversation, Message, CachePattern
+ import aiohttp
+ from aiohttp import ClientResponse
+
+ from lm_deluge.prompt import CachePattern, Conversation, Message
  from lm_deluge.usage import Usage

- from ..tracker import StatusTracker
- from ..sampling_params import SamplingParams
- from ..models import APIModel
+ from ..config import SamplingParams
  from ..errors import raise_if_modal_exception
- from aiohttp import ClientResponse
+ from ..models import APIModel
+ from ..tracker import StatusTracker


  @dataclass
@@ -48,6 +49,10 @@ class APIResponse:
      retry_with_different_model: bool | None = False
      # set to true if should NOT retry with the same model (unrecoverable error)
      give_up_if_no_other_models: bool | None = False
+     # OpenAI Responses API specific - used for computer use continuation
+     response_id: str | None = None
+     # Raw API response for debugging
+     raw_response: dict | None = None

      @property
      def completion(self) -> str | None:
@@ -176,16 +181,11 @@ class APIRequestBase(ABC):
          prompt: Conversation,
          attempts_left: int,
          status_tracker: StatusTracker,
-         retry_queue: asyncio.Queue,
          # needed in order to retry with a different model and not throw the output away
          results_arr: list["APIRequestBase"],
          request_timeout: int = 30,
          sampling_params: SamplingParams = SamplingParams(),
-         logprobs: bool = False,
-         top_logprobs: int | None = None,
-         pbar: tqdm | None = None,
          callback: Callable | None = None,
-         debug: bool = False,
          all_model_names: list[str] | None = None,
          all_sampling_params: list[SamplingParams] | None = None,
          tools: list | None = None,
@@ -199,16 +199,11 @@ class APIRequestBase(ABC):
          self.prompt = prompt
          self.attempts_left = attempts_left
          self.status_tracker = status_tracker
-         self.retry_queue = retry_queue
          self.request_timeout = request_timeout
          self.sampling_params = sampling_params
-         self.logprobs = logprobs  # len(completion) logprobs
-         self.top_logprobs = top_logprobs
-         self.pbar = pbar
          self.callback = callback
          self.num_tokens = prompt.count_tokens(sampling_params.max_new_tokens)
          self.results_arr = results_arr
-         self.debug = debug
          self.all_model_names = all_model_names
          self.all_sampling_params = all_sampling_params
          self.tools = tools
@@ -222,8 +217,7 @@ class APIRequestBase(ABC):
          self.region = None

      def increment_pbar(self):
-         if self.pbar is not None:
-             self.pbar.update(1)
+         self.status_tracker.increment_pbar()

      def call_callback(self):
          if self.callback is not None:
@@ -232,7 +226,6 @@ class APIRequestBase(ABC):

      def handle_success(self, data):
          self.call_callback()
-         self.increment_pbar()
          self.status_tracker.task_succeeded(self.task_id)

      def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
@@ -253,7 +246,8 @@ class APIRequestBase(ABC):
          if self.attempts_left > 0:
              self.attempts_left -= 1
              if not create_new_request:
-                 self.retry_queue.put_nowait(self)
+                 assert self.status_tracker.retry_queue
+                 self.status_tracker.retry_queue.put_nowait(self)
                  return
              else:
                  # make sure we have another model to send it to besides the current one
@@ -267,7 +261,8 @@ class APIRequestBase(ABC):
                      print(
                          f"No other models to try for task {self.task_id}. Retrying with same model."
                      )
-                     self.retry_queue.put_nowait(self)
+                     assert self.status_tracker.retry_queue
+                     self.status_tracker.retry_queue.put_nowait(self)
                  else:
                      # two things to change: model_name and sampling_params
                      new_model_name = self.model_name
@@ -292,21 +287,21 @@ class APIRequestBase(ABC):
                          prompt=self.prompt,
                          attempts_left=self.attempts_left,
                          status_tracker=self.status_tracker,
-                         retry_queue=self.retry_queue,
                          results_arr=self.results_arr,
                          request_timeout=self.request_timeout,
                          sampling_params=new_sampling_params,
-                         logprobs=self.logprobs,
-                         top_logprobs=self.top_logprobs,
-                         pbar=self.pbar,
                          callback=self.callback,
                          all_model_names=self.all_model_names,
                          all_sampling_params=self.all_sampling_params,
                          tools=self.tools,
                          cache=self.cache,
+                         computer_use=getattr(self, "computer_use", False),
+                         display_width=getattr(self, "display_width", 1024),
+                         display_height=getattr(self, "display_height", 768),
                      )
                      # PROBLEM: new request is never put into results array, so we can't get the result.
-                     self.retry_queue.put_nowait(new_request)
+                     assert self.status_tracker.retry_queue
+                     self.status_tracker.retry_queue.put_nowait(self)
                      # SOLUTION: just need to make sure it's deduplicated by task_id later.
                      self.results_arr.append(new_request)
          else:
@@ -354,6 +349,8 @@ class APIRequestBase(ABC):

          except Exception as e:
              raise_if_modal_exception(e)
+             tb = traceback.format_exc()
+             print(tb)
              self.result.append(
                  APIResponse(
                      id=self.task_id,
@@ -381,39 +378,52 @@ def create_api_request(
      prompt: Conversation,
      attempts_left: int,
      status_tracker: StatusTracker,
-     retry_queue: asyncio.Queue,
      results_arr: list["APIRequestBase"],
      request_timeout: int = 30,
      sampling_params: SamplingParams = SamplingParams(),
-     logprobs: bool = False,
-     top_logprobs: int | None = None,
-     pbar: tqdm | None = None,
      callback: Callable | None = None,
      all_model_names: list[str] | None = None,
      all_sampling_params: list[SamplingParams] | None = None,
      tools: list | None = None,
      cache: CachePattern | None = None,
+     computer_use: bool = False,
+     display_width: int = 1024,
+     display_height: int = 768,
+     use_responses_api: bool = False,
  ) -> APIRequestBase:
      from .common import CLASSES  # circular import so made it lazy, does this work?

      model_obj = APIModel.from_registry(model_name)
-     request_class = CLASSES.get(model_obj.api_spec, None)
+
+     # Choose API spec based on use_responses_api flag and model support
+     api_spec = model_obj.api_spec
+     if use_responses_api and model_obj.supports_responses and api_spec == "openai":
+         api_spec = "openai-responses"
+
+     request_class = CLASSES.get(api_spec, None)
      if request_class is None:
-         raise ValueError(f"Unsupported API spec: {model_obj.api_spec}")
-     kwargs = (
-         {} if not logprobs else {"logprobs": logprobs, "top_logprobs": top_logprobs}
-     )
+         raise ValueError(f"Unsupported API spec: {api_spec}")
+     kwargs = {}
+     # Add computer_use to kwargs if the request class supports it
+     model_obj = APIModel.from_registry(model_name)
+     if computer_use and api_spec in ["anthropic", "bedrock", "openai-responses"]:
+         kwargs.update(
+             {
+                 "computer_use": computer_use,
+                 "display_width": display_width,
+                 "display_height": display_height,
+             }
+         )
+
      return request_class(
          task_id=task_id,
          model_name=model_name,
          prompt=prompt,
          attempts_left=attempts_left,
          status_tracker=status_tracker,
-         retry_queue=retry_queue,
          results_arr=results_arr,
          request_timeout=request_timeout,
          sampling_params=sampling_params,
-         pbar=pbar,
          callback=callback,
          all_model_names=all_model_names,
          all_sampling_params=all_sampling_params,
@@ -421,3 +431,22 @@ def create_api_request(
          cache=cache,
          **kwargs,
      )
+
+
+ def deduplicate_responses(results: list[APIRequestBase]) -> list[APIResponse]:
+     deduplicated = {}
+     for request in results:
+         if request.task_id not in deduplicated:
+             deduplicated[request.task_id] = request.result[-1]
+         else:
+             current_response: APIResponse = deduplicated[request.task_id]
+             # only replace if the current request has no completion and the new one does
+             if (
+                 request.result[-1].completion is not None
+                 and current_response.completion is None
+             ):
+                 deduplicated[request.task_id] = request.result[-1]
+
+     output = [deduplicated[request.task_id] for request in results]
+
+     return output
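
The new deduplicate_responses helper keeps one APIResponse per task_id, replacing an earlier entry only when that entry has no completion and a later result does. A self-contained sketch of that selection rule using stand-in dataclasses (FakeResponse and FakeRequest are illustrative, not package classes):

from dataclasses import dataclass, field

@dataclass
class FakeResponse:
    completion: str | None

@dataclass
class FakeRequest:
    task_id: int
    result: list = field(default_factory=list)

def dedupe(results):
    # mirrors the selection rule in deduplicate_responses
    chosen = {}
    for req in results:
        last = req.result[-1]
        if req.task_id not in chosen:
            chosen[req.task_id] = last
        elif last.completion is not None and chosen[req.task_id].completion is None:
            chosen[req.task_id] = last
    return [chosen[req.task_id] for req in results]

failed = FakeRequest(7, [FakeResponse(None)])
retried = FakeRequest(7, [FakeResponse("ok")])
assert [r.completion for r in dedupe([failed, retried])] == ["ok", "ok"]
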
@@ -2,7 +2,6 @@ import asyncio
  import json
  import os
  from aiohttp import ClientResponse
- from tqdm import tqdm
  from typing import Callable

  try:
@@ -24,7 +23,7 @@ from lm_deluge.usage import Usage
  from .base import APIRequestBase, APIResponse

  from ..tracker import StatusTracker
- from ..sampling_params import SamplingParams
+ from ..config import SamplingParams
  from ..models import APIModel


@@ -36,17 +35,18 @@ class BedrockRequest(APIRequestBase):
          prompt: Conversation,
          attempts_left: int,
          status_tracker: StatusTracker,
-         retry_queue: asyncio.Queue,
          results_arr: list,
          request_timeout: int = 30,
          sampling_params: SamplingParams = SamplingParams(),
-         pbar: tqdm | None = None,
          callback: Callable | None = None,
-         debug: bool = False,
          all_model_names: list[str] | None = None,
          all_sampling_params: list[SamplingParams] | None = None,
          tools: list | None = None,
          cache: CachePattern | None = None,
+         # Computer Use support
+         computer_use: bool = False,
+         display_width: int = 1024,
+         display_height: int = 768,
      ):
          super().__init__(
              task_id=task_id,
@@ -54,19 +54,20 @@ class BedrockRequest(APIRequestBase):
              prompt=prompt,
              attempts_left=attempts_left,
              status_tracker=status_tracker,
-             retry_queue=retry_queue,
              results_arr=results_arr,
              request_timeout=request_timeout,
              sampling_params=sampling_params,
-             pbar=pbar,
              callback=callback,
-             debug=debug,
              all_model_names=all_model_names,
              all_sampling_params=all_sampling_params,
              tools=tools,
              cache=cache,
          )

+         self.computer_use = computer_use
+         self.display_width = display_width
+         self.display_height = display_height
+
          # Lock images as bytes if caching is enabled
          if cache is not None:
              prompt.lock_images_as_bytes()
@@ -115,11 +116,34 @@ class BedrockRequest(APIRequestBase):
          if self.system_message is not None:
              self.request_json["system"] = self.system_message

-         if tools:
-             tool_definitions = [tool.dump_for("anthropic") for tool in tools]
+         if tools or self.computer_use:
+             tool_definitions = []
+
+             # Add Computer Use tools at the beginning if enabled
+             if self.computer_use:
+                 from ..computer_use.anthropic_tools import get_anthropic_cu_tools
+
+                 cu_tools = get_anthropic_cu_tools(
+                     model=self.model.id,
+                     display_width=self.display_width,
+                     display_height=self.display_height,
+                 )
+                 tool_definitions.extend(cu_tools)
+
+                 # Add computer use display parameters to the request
+                 self.request_json["computer_use_display_width_px"] = self.display_width
+                 self.request_json["computer_use_display_height_px"] = (
+                     self.display_height
+                 )
+
+             # Add user-provided tools
+             if tools:
+                 tool_definitions.extend([tool.dump_for("anthropic") for tool in tools])
+
              # Add cache control to last tool if tools_only caching is specified
              if cache == "tools_only" and tool_definitions:
                  tool_definitions[-1]["cache_control"] = {"type": "ephemeral"}
+
              self.request_json["tools"] = tool_definitions

          # Setup AWS4Auth for signing
@@ -1,10 +1,11 @@
- from .openai import OpenAIRequest
+ from .openai import OpenAIRequest, OpenAIResponsesRequest
  from .anthropic import AnthropicRequest
  from .mistral import MistralRequest
  from .bedrock import BedrockRequest

  CLASSES = {
      "openai": OpenAIRequest,
+     "openai-responses": OpenAIResponsesRequest,
      "anthropic": AnthropicRequest,
      "mistral": MistralRequest,
      "bedrock": BedrockRequest,
@@ -1,16 +1,14 @@
- import asyncio
  import warnings
  from aiohttp import ClientResponse
  import json
  import os
- from tqdm.auto import tqdm
  from typing import Callable

  from .base import APIRequestBase, APIResponse
  from ..prompt import Conversation, Message, CachePattern
  from ..usage import Usage
  from ..tracker import StatusTracker
- from ..sampling_params import SamplingParams
+ from ..config import SamplingParams
  from ..models import APIModel


@@ -24,15 +22,10 @@ class MistralRequest(APIRequestBase):
          prompt: Conversation,
          attempts_left: int,
          status_tracker: StatusTracker,
-         retry_queue: asyncio.Queue,
          results_arr: list,
          request_timeout: int = 30,
          sampling_params: SamplingParams = SamplingParams(),
-         logprobs: bool = False,
-         top_logprobs: int | None = None,
-         pbar: tqdm | None = None,
          callback: Callable | None = None,
-         debug: bool = False,
          all_model_names: list[str] | None = None,
          all_sampling_params: list[SamplingParams] | None = None,
          tools: list | None = None,
@@ -44,15 +37,10 @@ class MistralRequest(APIRequestBase):
              prompt=prompt,
              attempts_left=attempts_left,
              status_tracker=status_tracker,
-             retry_queue=retry_queue,
              results_arr=results_arr,
              request_timeout=request_timeout,
              sampling_params=sampling_params,
-             logprobs=logprobs,
-             top_logprobs=top_logprobs,
-             pbar=pbar,
              callback=callback,
-             debug=debug,
              all_model_names=all_model_names,
              all_sampling_params=all_sampling_params,
              tools=tools,
@@ -80,7 +68,7 @@ class MistralRequest(APIRequestBase):
              warnings.warn(
                  f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
              )
-         if logprobs:
+         if sampling_params.logprobs:
              warnings.warn(
                  f"Ignoring logprobs param for non-logprobs model: {model_name}"
              )
@@ -109,7 +97,10 @@ class MistralRequest(APIRequestBase):
              try:
                  completion = data["choices"][0]["message"]["content"]
                  usage = Usage.from_mistral_usage(data["usage"])
-                 if self.logprobs and "logprobs" in data["choices"][0]:
+                 if (
+                     self.sampling_params.logprobs
+                     and "logprobs" in data["choices"][0]
+                 ):
                      logprobs = data["choices"][0]["logprobs"]["content"]
              except Exception:
                  is_error = True