lm-deluge 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -1,20 +1,19 @@
-# https://docs.cohere.com/reference/chat
-# https://cohere.com/pricing
 import asyncio
+import warnings
 from aiohttp import ClientResponse
 import json
 import os
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from typing import Callable
-from lm_deluge.prompt import Conversation
-from .base import APIRequestBase, APIResponse
 
+from .base import APIRequestBase, APIResponse
+from ..prompt import Conversation
 from ..tracker import StatusTracker
 from ..sampling_params import SamplingParams
 from ..models import APIModel
 
 
-class CohereRequest(APIRequestBase):
+class MistralRequest(APIRequestBase):
     def __init__(
         self,
         task_id: int,
@@ -24,10 +23,12 @@ class CohereRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        results_arr: list,
         retry_queue: asyncio.Queue,
+        results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
+        logprobs: bool = False,
+        top_logprobs: int | None = None,
         pbar: tqdm | None = None,
         callback: Callable | None = None,
         debug: bool = False,
@@ -44,36 +45,36 @@ class CohereRequest(APIRequestBase):
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
             pbar=pbar,
             callback=callback,
             debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
         )
-        self.system_message = None
-        self.last_user_message = None
-
         self.model = APIModel.from_registry(model_name)
-        self.url = f"{self.model.api_base}/chat"
-        self.system_message, chat_history, last_user_message = prompt.to_cohere()
-
+        self.url = f"{self.model.api_base}/chat/completions"
         self.request_header = {
-            "Authorization": f"bearer {os.getenv(self.model.api_key_env_var)}",
-            "content-type": "application/json",
-            "accept": "application/json",
+            "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }
-
         self.request_json = {
             "model": self.model.name,
-            "chat_history": chat_history,
-            "message": last_user_message,
+            "messages": prompt.to_mistral(),
             "temperature": sampling_params.temperature,
             "top_p": sampling_params.top_p,
             "max_tokens": sampling_params.max_new_tokens,
         }
-
-        if self.system_message:
-            self.request_json["preamble"] = self.system_message
+        if sampling_params.reasoning_effort:
+            warnings.warn(
+                f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+            )
+        if logprobs:
+            warnings.warn(
+                f"Ignoring logprobs param for non-logprobs model: {model_name}"
+            )
+        if sampling_params.json_mode and self.model.supports_json:
+            self.request_json["response_format"] = {"type": "json_object"}
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
@@ -81,41 +82,41 @@ class CohereRequest(APIRequestBase):
         completion = None
         input_tokens = None
         output_tokens = None
+        logprobs = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
+        data = None
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
             except Exception:
-                data = None
                 is_error = True
                 error_message = (
                     f"Error calling .json() on response w/ status {status_code}"
                 )
-            if not is_error and isinstance(data, dict):
+            if not is_error:
+                assert data is not None, "data is None"
                 try:
-                    completion = data["text"]
-                    input_tokens = data["meta"]["billed_units"]["input_tokens"]
-                    output_tokens = data["meta"]["billed_units"]["input_tokens"]
+                    completion = data["choices"][0]["message"]["content"]
+                    input_tokens = data["usage"]["prompt_tokens"]
+                    output_tokens = data["usage"]["completion_tokens"]
+                    if self.logprobs and "logprobs" in data["choices"][0]:
+                        logprobs = data["choices"][0]["logprobs"]["content"]
                 except Exception:
                     is_error = True
-                    error_message = f"Error getting 'text' or 'meta' from {self.model.name} response."
-        elif mimetype is not None and "json" in mimetype.lower():
+                    error_message = f"Error getting 'choices' and 'usage' from {self.model.name} response."
+        elif mimetype and "json" in mimetype.lower():
             is_error = True  # expected status is 200, otherwise it's an error
             data = await http_response.json()
             error_message = json.dumps(data)
-
         else:
             is_error = True
             text = await http_response.text()
             error_message = text
 
-        # handle special kinds of errors. TODO: make sure these are correct for anthropic
+        # handle special kinds of errors
         if is_error and error_message is not None:
-            if (
-                "rate limit" in error_message.lower()
-                or "overloaded" in error_message.lower()
-            ):
+            if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
                 self.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
@@ -128,6 +129,7 @@ class CohereRequest(APIRequestBase):
             is_error=is_error,
             error_message=error_message,
             prompt=self.prompt,
+            logprobs=logprobs,
             completion=completion,
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
@@ -58,13 +58,18 @@ class OpenAIRequest(APIRequestBase):
         self.request_header = {
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }
+
         self.request_json = {
             "model": self.model.name,
             "messages": prompt.to_openai(),
             "temperature": sampling_params.temperature,
             "top_p": sampling_params.top_p,
-            "max_completion_tokens": sampling_params.max_new_tokens,
         }
+        # set max_tokens or max_completion_tokens dep. on provider
+        if "cohere" in self.model.api_base:
+            self.request_json["max_tokens"] = sampling_params.max_new_tokens
+        elif "openai" in self.model.api_base:
+            self.request_json["max_completion_tokens"] = sampling_params.max_new_tokens
         if self.model.reasoning_model:
             self.request_json["temperature"] = 1.0
             self.request_json["top_p"] = 1.0
@@ -84,6 +89,7 @@ class OpenAIRequest(APIRequestBase):
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None
+        thinking = None
         completion = None
         input_tokens = None
         output_tokens = None
@@ -103,6 +109,8 @@ class OpenAIRequest(APIRequestBase):
                 assert data is not None, "data is None"
                 try:
                     completion = data["choices"][0]["message"]["content"]
+                    if "reasoning_content" in data["choices"][0]["message"]:
+                        thinking = data["choices"][0]["message"]["reasoning_content"]
                     input_tokens = data["usage"]["prompt_tokens"]
                     output_tokens = data["usage"]["completion_tokens"]
                     if self.logprobs and "logprobs" in data["choices"][0]:
@@ -135,6 +143,7 @@ class OpenAIRequest(APIRequestBase):
             error_message=error_message,
             prompt=self.prompt,
             logprobs=logprobs,
+            thinking=thinking,
             completion=completion,
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
lm_deluge/client.py CHANGED
@@ -85,6 +85,7 @@ class LLMClient:
     def __init__(
         self,
         model_names: list[str],
+        *,
         max_requests_per_minute: int,
         max_tokens_per_minute: int,
         max_concurrent_requests: int,
@@ -345,6 +346,7 @@ class LLMClient:
 
         # add cache hits back in
         for id, res in zip(cache_hit_ids, cache_hit_results):
+            res.cache_hit = True
             results[id] = res
 
         if return_completions_only:
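The bare * added to LLMClient.__init__ makes every parameter after model_names keyword-only, so positional calls that relied on the 0.0.4 argument order will now raise a TypeError; results merged back in from the cache are also now flagged with res.cache_hit = True. A hedged sketch of a call under the new signature (the model name and limits are invented, and any other required parameters are omitted):

# Hypothetical usage; the model name, limits, and import path are assumptions based on
# the lm_deluge/client.py diff above, not a verified public API.
from lm_deluge.client import LLMClient

client = LLMClient(
    ["gpt-4o-mini"],                 # model_names can still be passed positionally
    max_requests_per_minute=60,      # everything after the bare * must be a keyword argument
    max_tokens_per_minute=100_000,
    max_concurrent_requests=10,
)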
lm_deluge/image.py CHANGED
@@ -191,6 +191,12 @@ class Image:
             },
         }
 
+    def mistral(self) -> dict:
+        return {
+            "type": "image_url",
+            "image_url": self._base64(),
+        }
+
     def gemini(self) -> dict:
         return {
             "inlineData": {