lm-deluge 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release.

This version of lm-deluge might be problematic.

lm_deluge/__init__.py CHANGED
@@ -1,6 +1,7 @@
  from .client import LLMClient, SamplingParams, APIResponse
+ from .prompt import Conversation, Message
  import dotenv

  dotenv.load_dotenv()

- __all__ = ["LLMClient", "SamplingParams", "APIResponse"]
+ __all__ = ["LLMClient", "SamplingParams", "APIResponse", "Conversation", "Message"]
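
With this release, Conversation and Message are re-exported from the package root, so callers no longer need to reach into lm_deluge.prompt. A minimal sketch of the new import surface (the constructor arguments shown are illustrative assumptions; this diff does not show the actual signatures defined in lm_deluge.prompt):

    from lm_deluge import LLMClient, Conversation, Message

    # hypothetical construction; the real Conversation/Message signatures
    # live in lm_deluge.prompt and are not part of this diff
    convo = Conversation(messages=[Message(role="user", content="Hello!")])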
@@ -41,6 +41,7 @@ class APIResponse:
      logprobs: list | None = None
      finish_reason: str | None = None  # make required later
      cost: float | None = None  # calculated automatically
+     cache_hit: bool = False  # manually set if true
      # set to true if is_error and should be retried with a different model
      retry_with_different_model: bool | None = False
      # set to true if should NOT retry with the same model (unrecoverable error)
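
The new cache_hit flag lets a caller mark responses that were served from a local cache rather than a fresh API call. A sketch of one way the flag might be used downstream (the cache layer itself is hypothetical, not part of this diff):

    # hypothetical caller-side handling of the new flag
    if response.cache_hit:
        response.cost = 0.0  # a cached response incurs no new API spend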
@@ -1,18 +1,9 @@
- # from .vertex import VertexAnthropicRequest, GeminiRequest
- # from .bedrock import BedrockAnthropicRequest, MistralBedrockRequest
- # from .deepseek import DeepseekRequest
  from .openai import OpenAIRequest
- from .cohere import CohereRequest
  from .anthropic import AnthropicRequest
+ from .mistral import MistralRequest

  CLASSES = {
      "openai": OpenAIRequest,
-     # "deepseek": DeepseekRequest,
      "anthropic": AnthropicRequest,
-     # "vertex_anthropic": VertexAnthropicRequest,
-     # "vertex_gemini": GeminiRequest,
-     "cohere": CohereRequest,
-     # "bedrock_anthropic": BedrockAnthropicRequest,
-     # "bedrock_mistral": MistralBedrockRequest,
-     # "mistral": MistralRequest,
+     "mistral": MistralRequest,
  }
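
The CLASSES registry maps a provider key to its request class, so request construction can be a dictionary lookup instead of an if/elif chain; this release activates the "mistral" entry and drops "cohere". A minimal sketch of the dispatch, assuming the provider key comes from the model registry (the api_spec attribute name and the model name are assumptions for illustration):

    # hypothetical dispatch through the registry
    model = APIModel.from_registry("mistral-large")   # model name is illustrative
    request_cls = CLASSES[model.api_spec]             # e.g. "mistral" -> MistralRequest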
@@ -0,0 +1,132 @@
+ # # https://docs.cohere.com/reference/chat
+ # # https://cohere.com/pricing
+ # import asyncio
+ # from aiohttp import ClientResponse
+ # import json
+ # import os
+ # from tqdm import tqdm
+ # from typing import Callable
+ # from lm_deluge.prompt import Conversation
+ # from .base import APIRequestBase, APIResponse
+
+ # from ..tracker import StatusTracker
+ # from ..sampling_params import SamplingParams
+ # from ..models import APIModel
+
+
+ # class CohereRequest(APIRequestBase):
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         # should always be 'role', 'content' keys.
+ #         # internal logic should handle translating to specific API format
+ #         model_name: str,  # must correspond to registry
+ #         prompt: Conversation,
+ #         attempts_left: int,
+ #         status_tracker: StatusTracker,
+ #         results_arr: list,
+ #         retry_queue: asyncio.Queue,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: tqdm | None = None,
+ #         callback: Callable | None = None,
+ #         debug: bool = False,
+ #         all_model_names: list[str] | None = None,
+ #         all_sampling_params: list[SamplingParams] | None = None,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #             all_model_names=all_model_names,
+ #             all_sampling_params=all_sampling_params,
+ #         )
+ #         self.system_message = None
+ #         self.last_user_message = None
+
+ #         self.model = APIModel.from_registry(model_name)
+ #         self.url = f"{self.model.api_base}/chat"
+ #         messages = prompt.to_cohere()
+
+ #         self.request_header = {
+ #             "Authorization": f"bearer {os.getenv(self.model.api_key_env_var)}",
+ #             "content-type": "application/json",
+ #             "accept": "application/json",
+ #         }
+
+ #         self.request_json = {
+ #             "model": self.model.name,
+ #             "messages": messages,
+ #             "temperature": sampling_params.temperature,
+ #             "top_p": sampling_params.top_p,
+ #             "max_tokens": sampling_params.max_new_tokens,
+ #         }
+
+ #     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         status_code = http_response.status
+ #         mimetype = http_response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await http_response.json()
+ #             except Exception:
+ #                 data = None
+ #                 is_error = True
+ #                 error_message = (
+ #                     f"Error calling .json() on response w/ status {status_code}"
+ #                 )
+ #             if not is_error and isinstance(data, dict):
+ #                 try:
+ #                     completion = data["text"]
+ #                     input_tokens = data["meta"]["billed_units"]["input_tokens"]
+ #                     output_tokens = data["meta"]["billed_units"]["input_tokens"]
+ #                 except Exception:
+ #                     is_error = True
+ #                     error_message = f"Error getting 'text' or 'meta' from {self.model.name} response."
+ #         elif mimetype is not None and "json" in mimetype.lower():
+ #             is_error = True  # expected status is 200, otherwise it's an error
+ #             data = await http_response.json()
+ #             error_message = json.dumps(data)
+
+ #         else:
+ #             is_error = True
+ #             text = await http_response.text()
+ #             error_message = text
+
+ #         # handle special kinds of errors. TODO: make sure these are correct for anthropic
+ #         if is_error and error_message is not None:
+ #             if (
+ #                 "rate limit" in error_message.lower()
+ #                 or "overloaded" in error_message.lower()
+ #             ):
+ #                 error_message += " (Rate limit error, triggering cooldown.)"
+ #                 self.status_tracker.rate_limit_exceeded()
+ #             if "context length" in error_message:
+ #                 error_message += " (Context length exceeded, set retries to 0.)"
+ #                 self.attempts_left = 0
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #         )
@@ -0,0 +1,361 @@
+ # # consider: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-gemini-using-openai-library#call-chat-completions-api
+ # import asyncio
+ # from aiohttp import ClientResponse
+ # import json
+ # import os
+ # import time
+ # from tqdm import tqdm
+ # from typing import Callable
+
+ # from lm_deluge.prompt import Conversation
+ # from .base import APIRequestBase, APIResponse
+ # from ..tracker import StatusTracker
+ # from ..sampling_params import SamplingParams
+ # from ..models import APIModel
+
+ # from google.oauth2 import service_account
+ # from google.auth.transport.requests import Request
+
+
+ # def get_access_token(service_account_file: str):
+ #     """
+ #     Get access token from environment variables if another process/coroutine
+ #     has already got them, otherwise get from service account file.
+ #     """
+ #     LAST_REFRESHED = os.getenv("VERTEX_TOKEN_LAST_REFRESHED", None)
+ #     LAST_REFRESHED = int(LAST_REFRESHED) if LAST_REFRESHED is not None else 0
+ #     VERTEX_API_TOKEN = os.getenv("VERTEX_API_TOKEN", None)
+
+ #     if VERTEX_API_TOKEN is not None and time.time() - LAST_REFRESHED < 60 * 50:
+ #         return VERTEX_API_TOKEN
+ #     else:
+ #         credentials = service_account.Credentials.from_service_account_file(
+ #             service_account_file,
+ #             scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ #         )
+ #         credentials.refresh(Request())
+ #         token = credentials.token
+ #         os.environ["VERTEX_API_TOKEN"] = token
+ #         os.environ["VERTEX_TOKEN_LAST_REFRESHED"] = str(int(time.time()))
+
+ #         return token
+
+
+ # class VertexAnthropicRequest(APIRequestBase):
+ #     """
+ #     For Claude on Vertex, you'll also have to set the PROJECT_ID environment variable.
+ #     """
+
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         model_name: str,  # must correspond to registry
+ #         prompt: Conversation,
+ #         attempts_left: int,
+ #         status_tracker: StatusTracker,
+ #         retry_queue: asyncio.Queue,
+ #         results_arr: list,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: tqdm | None = None,
+ #         callback: Callable | None = None,
+ #         debug: bool = False,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #         )
+ #         creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
+ #         if not creds:
+ #             raise RuntimeError(
+ #                 "GOOGLE_APPLICATION_CREDENTIALS not provided in environment"
+ #             )
+ #         token = get_access_token(creds)
+
+ #         self.model = APIModel.from_registry(model_name)
+ #         project_id = os.getenv("PROJECT_ID")
+ #         region = self.model.sample_region()
+
+ #         endpoint = f"https://{region}-aiplatform.googleapis.com"
+ #         self.url = f"{endpoint}/v1/projects/{project_id}/locations/{region}/publishers/anthropic/models/{self.model.name}:generateContent"
+ #         self.request_header = {
+ #             "Authorization": f"Bearer {token}",
+ #             "Content-Type": "application/json",
+ #         }
+ #         self.system_message, messages = prompt.to_anthropic()
+
+ #         self.request_json = {
+ #             "anthropic_version": "vertex-2023-10-16",
+ #             "messages": messages,
+ #             "temperature": self.sampling_params.temperature,
+ #             "top_p": self.sampling_params.top_p,
+ #             "max_tokens": self.sampling_params.max_new_tokens,
+ #         }
+ #         if self.system_message is not None:
+ #             self.request_json["system"] = self.system_message
+
+ #     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         status_code = http_response.status
+ #         mimetype = http_response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await http_response.json()
+ #                 completion = data["content"][0]["text"]
+ #                 input_tokens = data["usage"]["input_tokens"]
+ #                 output_tokens = data["usage"]["output_tokens"]
+ #             except Exception as e:
+ #                 is_error = True
+ #                 error_message = (
+ #                     f"Error calling .json() on response w/ status {status_code}: {e}"
+ #                 )
+ #         elif "json" in (mimetype or "").lower():
+ #             is_error = True  # expected status is 200, otherwise it's an error
+ #             data = await http_response.json()
+ #             error_message = json.dumps(data)
+
+ #         else:
+ #             is_error = True
+ #             text = await http_response.text()
+ #             error_message = text
+
+ #         # handle special kinds of errors. TODO: make sure these are correct for anthropic
+ #         if is_error and error_message is not None:
+ #             if (
+ #                 "rate limit" in error_message.lower()
+ #                 or "overloaded" in error_message.lower()
+ #                 or status_code == 429
+ #             ):
+ #                 error_message += " (Rate limit error, triggering cooldown.)"
+ #                 self.status_tracker.rate_limit_exceeded()
+ #             if "context length" in error_message:
+ #                 error_message += " (Context length exceeded, set retries to 0.)"
+ #                 self.attempts_left = 0
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #         )
+
+
+ # SAFETY_SETTING_CATEGORIES = [
+ #     "HARM_CATEGORY_DANGEROUS_CONTENT",
+ #     "HARM_CATEGORY_HARASSMENT",
+ #     "HARM_CATEGORY_HATE_SPEECH",
+ #     "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ # ]
+
+
+ # class GeminiRequest(APIRequestBase):
+ #     """
+ #     For Gemini, you'll also have to set the PROJECT_ID environment variable.
+ #     """
+
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         model_name: str,  # must correspond to registry
+ #         prompt: Conversation,
+ #         attempts_left: int,
+ #         status_tracker: StatusTracker,
+ #         retry_queue: asyncio.Queue,
+ #         results_arr: list,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: tqdm | None = None,
+ #         callback: Callable | None = None,
+ #         debug: bool = False,
+ #         all_model_names: list[str] | None = None,
+ #         all_sampling_params: list[SamplingParams] | None = None,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #             all_model_names=all_model_names,
+ #             all_sampling_params=all_sampling_params,
+ #         )
+ #         self.model = APIModel.from_registry(model_name)
+ #         credentials_file = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
+ #         if not credentials_file:
+ #             raise RuntimeError(
+ #                 "no credentials file found. ensure you provide a google credentials file and point to it with GOOGLE_APPLICATION_CREDENTIALS environment variable."
+ #             )
+ #         token = get_access_token(credentials_file)
+ #         self.project_id = os.getenv("PROJECT_ID")
+ #         # sample weighted by region counts
+ #         self.region = self.model.sample_region()
+ #         assert self.region is not None, "unable to sample region"
+ #         self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"
+
+ #         self.request_header = {
+ #             "Authorization": f"Bearer {token}",
+ #             "Content-Type": "application/json",
+ #         }
+ #         self.system_message, contents = prompt.to_gemini()
+ #         self.request_json = {
+ #             "contents": contents,
+ #             "generationConfig": {
+ #                 "stopSequences": [],
+ #                 "temperature": sampling_params.temperature,
+ #                 "maxOutputTokens": sampling_params.max_new_tokens,
+ #                 "topP": sampling_params.top_p,
+ #                 "topK": None,
+ #             },
+ #             "safetySettings": [
+ #                 {"category": category, "threshold": "BLOCK_NONE"}
+ #                 for category in SAFETY_SETTING_CATEGORIES
+ #             ],
+ #         }
+ #         if sampling_params.json_mode and self.model.supports_json:
+ #             self.request_json["generationConfig"]["responseMimeType"] = (
+ #                 "application/json"
+ #             )
+
+ #         if self.system_message is not None:
+ #             self.request_json["systemInstruction"] = (
+ #                 {"role": "SYSTEM", "parts": [{"text": self.system_message}]},
+ #             )
+
+ #     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         finish_reason = None
+ #         data = None
+ #         retry_with_different_model = False
+ #         give_up_if_no_other_models = False
+ #         status_code = http_response.status
+ #         mimetype = http_response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await http_response.json()
+ #                 if "candidates" not in data:
+ #                     is_error = True
+ #                     if "promptFeedback" in data:
+ #                         error_message = "Prompt rejected. Feedback: " + str(
+ #                             data["promptFeedback"]
+ #                         )
+ #                     else:
+ #                         error_message = "No candidates in response."
+ #                     retry_with_different_model = True
+ #                     give_up_if_no_other_models = True
+ #                 else:
+ #                     candidate = data["candidates"][0]
+ #                     finish_reason = candidate["finishReason"]
+ #                     if "content" in candidate:
+ #                         parts = candidate["content"]["parts"]
+ #                         completion = " ".join([part["text"] for part in parts])
+ #                         usage = data["usageMetadata"]
+ #                         input_tokens = usage["promptTokenCount"]
+ #                         output_tokens = usage["candidatesTokenCount"]
+ #                     elif finish_reason == "RECITATION":
+ #                         is_error = True
+ #                         citations = candidate.get("citationMetadata", {}).get(
+ #                             "citations", []
+ #                         )
+ #                         urls = ",".join(
+ #                             [citation.get("uri", "") for citation in citations]
+ #                         )
+ #                         error_message = "Finish reason RECITATION. URLS: " + urls
+ #                         retry_with_different_model = True
+ #                     elif finish_reason == "OTHER":
+ #                         is_error = True
+ #                         error_message = "Finish reason OTHER."
+ #                         retry_with_different_model = True
+ #                     elif finish_reason == "SAFETY":
+ #                         is_error = True
+ #                         error_message = "Finish reason SAFETY."
+ #                         retry_with_different_model = True
+ #                     else:
+ #                         print("Actual structure of response:", data)
+ #                         is_error = True
+ #                         error_message = "No content in response."
+ #             except Exception as e:
+ #                 is_error = True
+ #                 error_message = f"Error calling .json() on response w/ status {status_code}: {e.__class__} {e}"
+ #                 if isinstance(e, KeyError):
+ #                     print("Actual structure of response:", data)
+ #         elif "json" in (mimetype or "").lower():
+ #             is_error = True
+ #             data = await http_response.json()
+ #             error_message = json.dumps(data)
+ #         else:
+ #             is_error = True
+ #             text = await http_response.text()
+ #             error_message = text
+
+ #         old_region = self.region
+ #         if is_error and error_message is not None:
+ #             if (
+ #                 "rate limit" in error_message.lower()
+ #                 or "temporarily out of capacity" in error_message.lower()
+ #                 or "exceeded" in error_message.lower()
+ #                 or
+ #                 # 429 code
+ #                 status_code == 429
+ #             ):
+ #                 error_message += " (Rate limit error, triggering cooldown & retrying with different model.)"
+ #                 self.status_tracker.rate_limit_exceeded()
+ #                 retry_with_different_model = (
+ #                     True  # if possible, retry with a different model
+ #                 )
+ #         if is_error:
+ #             # change the region in case error is due to region unavailability
+ #             self.region = self.model.sample_region()
+ #             assert self.region is not None, "Unable to sample region"
+ #             self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #             region=old_region,
+ #             finish_reason=finish_reason,
+ #             retry_with_different_model=retry_with_different_model,
+ #             give_up_if_no_other_models=give_up_if_no_other_models,
+ #         )
+
+
+ # # class LlamaEndpointRequest(APIRequestBase):
+ # #     raise NotImplementedError("Llama endpoints are not implemented and never will be because Vertex AI sucks ass.")
@@ -1,20 +1,19 @@
- # https://docs.cohere.com/reference/chat
- # https://cohere.com/pricing
  import asyncio
+ import warnings
  from aiohttp import ClientResponse
  import json
  import os
- from tqdm import tqdm
+ from tqdm.auto import tqdm
  from typing import Callable
- from lm_deluge.prompt import Conversation
- from .base import APIRequestBase, APIResponse

+ from .base import APIRequestBase, APIResponse
+ from ..prompt import Conversation
  from ..tracker import StatusTracker
  from ..sampling_params import SamplingParams
  from ..models import APIModel


- class CohereRequest(APIRequestBase):
+ class MistralRequest(APIRequestBase):
      def __init__(
          self,
          task_id: int,
@@ -24,10 +23,12 @@ class CohereRequest(APIRequestBase):
          prompt: Conversation,
          attempts_left: int,
          status_tracker: StatusTracker,
-         results_arr: list,
          retry_queue: asyncio.Queue,
+         results_arr: list,
          request_timeout: int = 30,
          sampling_params: SamplingParams = SamplingParams(),
+         logprobs: bool = False,
+         top_logprobs: int | None = None,
          pbar: tqdm | None = None,
          callback: Callable | None = None,
          debug: bool = False,
@@ -44,32 +45,36 @@ class CohereRequest(APIRequestBase):
              results_arr=results_arr,
              request_timeout=request_timeout,
              sampling_params=sampling_params,
+             logprobs=logprobs,
+             top_logprobs=top_logprobs,
              pbar=pbar,
              callback=callback,
              debug=debug,
              all_model_names=all_model_names,
              all_sampling_params=all_sampling_params,
          )
-         self.system_message = None
-         self.last_user_message = None
-
          self.model = APIModel.from_registry(model_name)
-         self.url = f"{self.model.api_base}/chat"
-         messages = prompt.to_cohere()
-
+         self.url = f"{self.model.api_base}/chat/completions"
          self.request_header = {
-             "Authorization": f"bearer {os.getenv(self.model.api_key_env_var)}",
-             "content-type": "application/json",
-             "accept": "application/json",
+             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
          }
-
          self.request_json = {
              "model": self.model.name,
-             "messages": messages,
+             "messages": prompt.to_mistral(),
              "temperature": sampling_params.temperature,
              "top_p": sampling_params.top_p,
              "max_tokens": sampling_params.max_new_tokens,
          }
+         if sampling_params.reasoning_effort:
+             warnings.warn(
+                 f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+             )
+         if logprobs:
+             warnings.warn(
+                 f"Ignoring logprobs param for non-logprobs model: {model_name}"
+             )
+         if sampling_params.json_mode and self.model.supports_json:
+             self.request_json["response_format"] = {"type": "json_object"}

      async def handle_response(self, http_response: ClientResponse) -> APIResponse:
          is_error = False
@@ -77,41 +82,41 @@ class CohereRequest(APIRequestBase):
          completion = None
          input_tokens = None
          output_tokens = None
+         logprobs = None
          status_code = http_response.status
          mimetype = http_response.headers.get("Content-Type", None)
+         data = None
          if status_code >= 200 and status_code < 300:
              try:
                  data = await http_response.json()
              except Exception:
-                 data = None
                  is_error = True
                  error_message = (
                      f"Error calling .json() on response w/ status {status_code}"
                  )
-             if not is_error and isinstance(data, dict):
+             if not is_error:
+                 assert data is not None, "data is None"
                  try:
-                     completion = data["text"]
-                     input_tokens = data["meta"]["billed_units"]["input_tokens"]
-                     output_tokens = data["meta"]["billed_units"]["input_tokens"]
+                     completion = data["choices"][0]["message"]["content"]
+                     input_tokens = data["usage"]["prompt_tokens"]
+                     output_tokens = data["usage"]["completion_tokens"]
+                     if self.logprobs and "logprobs" in data["choices"][0]:
+                         logprobs = data["choices"][0]["logprobs"]["content"]
                  except Exception:
                      is_error = True
-                     error_message = f"Error getting 'text' or 'meta' from {self.model.name} response."
-         elif mimetype is not None and "json" in mimetype.lower():
+                     error_message = f"Error getting 'choices' and 'usage' from {self.model.name} response."
+         elif mimetype and "json" in mimetype.lower():
              is_error = True  # expected status is 200, otherwise it's an error
              data = await http_response.json()
              error_message = json.dumps(data)
-
          else:
              is_error = True
              text = await http_response.text()
              error_message = text

-         # handle special kinds of errors. TODO: make sure these are correct for anthropic
+         # handle special kinds of errors
          if is_error and error_message is not None:
-             if (
-                 "rate limit" in error_message.lower()
-                 or "overloaded" in error_message.lower()
-             ):
+             if "rate limit" in error_message.lower() or status_code == 429:
                  error_message += " (Rate limit error, triggering cooldown.)"
                  self.status_tracker.rate_limit_exceeded()
              if "context length" in error_message:
@@ -124,6 +129,7 @@ class CohereRequest(APIRequestBase):
              is_error=is_error,
              error_message=error_message,
              prompt=self.prompt,
+             logprobs=logprobs,
              completion=completion,
              model_internal=self.model_name,
              sampling_params=self.sampling_params,
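
Taken together, the new MistralRequest targets Mistral's OpenAI-compatible /chat/completions endpoint and parses the standard choices/usage response shape. For orientation, a sketch of the request body the constructor above assembles (the model name and parameter values are illustrative; they depend on the model and SamplingParams in use):

    # illustrative request_json produced by MistralRequest.__init__
    request_json = {
        "model": "mistral-large-latest",   # assumed registry name
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
        "top_p": 1.0,
        "max_tokens": 512,
    }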