lm-deluge 0.0.3 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,138 @@
+ # https://docs.cohere.com/reference/chat
+ # https://cohere.com/pricing
+ import asyncio
+ from aiohttp import ClientResponse
+ import json
+ import os
+ import time
+ from tqdm import tqdm
+ from typing import Optional, Callable
+ from lm_deluge.prompt import Conversation
+ from .base import APIRequestBase, APIResponse
+
+ from ..tracker import StatusTracker
+ from ..sampling_params import SamplingParams
+ from ..models import APIModel
+
+
+ class CohereRequest(APIRequestBase):
+     def __init__(
+         self,
+         task_id: int,
+         # should always be 'role', 'content' keys.
+         # internal logic should handle translating to specific API format
+         model_name: str, # must correspond to registry
+         prompt: Conversation,
+         attempts_left: int,
+         status_tracker: StatusTracker,
+         results_arr: list,
+         retry_queue: asyncio.Queue,
+         request_timeout: int = 30,
+         sampling_params: SamplingParams = SamplingParams(),
+         pbar: Optional[tqdm] = None,
+         callback: Optional[Callable] = None,
+         debug: bool = False,
+         all_model_names: list[str] | None = None,
+         all_sampling_params: list[SamplingParams] | None = None,
+     ):
+         super().__init__(
+             task_id=task_id,
+             model_name=model_name,
+             prompt=prompt,
+             attempts_left=attempts_left,
+             status_tracker=status_tracker,
+             retry_queue=retry_queue,
+             results_arr=results_arr,
+             request_timeout=request_timeout,
+             sampling_params=sampling_params,
+             pbar=pbar,
+             callback=callback,
+             debug=debug,
+             all_model_names=all_model_names,
+             all_sampling_params=all_sampling_params,
+         )
+         self.system_message = None
+         self.last_user_message = None
+
+         self.model = APIModel.from_registry(model_name)
+         self.url = f"{self.model.api_base}/chat"
+         self.system_message, chat_history, last_user_message = prompt.to_cohere()
+
+         self.request_header = {
+             "Authorization": f"bearer {os.getenv(self.model.api_key_env_var)}",
+             "content-type": "application/json",
+             "accept": "application/json",
+         }
+
+         self.request_json = {
+             "model": self.model.name,
+             "chat_history": chat_history,
+             "message": last_user_message,
+             "temperature": sampling_params.temperature,
+             "top_p": sampling_params.top_p,
+             "max_tokens": sampling_params.max_new_tokens,
+         }
+
+         if self.system_message:
+             self.request_json["preamble"] = self.system_message
+
+     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+         is_error = False
+         error_message = None
+         completion = None
+         input_tokens = None
+         output_tokens = None
+         status_code = http_response.status
+         mimetype = http_response.headers.get("Content-Type", None)
+         if status_code >= 200 and status_code < 300:
+             try:
+                 data = await http_response.json()
+             except Exception:
+                 data = None
+                 is_error = True
+                 error_message = (
+                     f"Error calling .json() on response w/ status {status_code}"
+                 )
+             if not is_error and isinstance(data, dict):
+                 try:
+                     completion = data["text"]
+                     input_tokens = data["meta"]["billed_units"]["input_tokens"]
+                     output_tokens = data["meta"]["billed_units"]["output_tokens"]
+                 except Exception:
+                     is_error = True
+                     error_message = f"Error getting 'text' or 'meta' from {self.model.name} response."
+         elif mimetype is not None and "json" in mimetype.lower():
+             is_error = True # expected status is 200, otherwise it's an error
+             data = await http_response.json()
+             error_message = json.dumps(data)
+
+         else:
+             is_error = True
+             text = await http_response.text()
+             error_message = text
+
+         # handle special kinds of errors. TODO: make sure these are correct for cohere
+         if is_error and error_message is not None:
+             if (
+                 "rate limit" in error_message.lower()
+                 or "overloaded" in error_message.lower()
+             ):
+                 error_message += " (Rate limit error, triggering cooldown.)"
+                 self.status_tracker.time_of_last_rate_limit_error = time.time()
+                 self.status_tracker.num_rate_limit_errors += 1
+             if "context length" in error_message:
+                 error_message += " (Context length exceeded, set retries to 0.)"
+                 self.attempts_left = 0
+
+         return APIResponse(
+             id=self.task_id,
+             status_code=status_code,
+             is_error=is_error,
+             error_message=error_message,
+             prompt=self.prompt,
+             completion=completion,
+             model_internal=self.model_name,
+             sampling_params=self.sampling_params,
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+         )
@@ -0,0 +1,18 @@
+ # from .vertex import VertexAnthropicRequest, GeminiRequest
+ # from .bedrock import BedrockAnthropicRequest, MistralBedrockRequest
+ # from .deepseek import DeepseekRequest
+ from .openai import OpenAIRequest
+ from .cohere import CohereRequest
+ from .anthropic import AnthropicRequest
+
+ CLASSES = {
+     "openai": OpenAIRequest,
+     # "deepseek": DeepseekRequest,
+     "anthropic": AnthropicRequest,
+     # "vertex_anthropic": VertexAnthropicRequest,
+     # "vertex_gemini": GeminiRequest,
+     "cohere": CohereRequest,
+     # "bedrock_anthropic": BedrockAnthropicRequest,
+     # "bedrock_mistral": MistralBedrockRequest,
+     # "mistral": MistralRequest,
+ }
@@ -0,0 +1,288 @@
+ # import asyncio
+ # import requests
+ # from requests.structures import CaseInsensitiveDict
+ # from requests_aws4auth import AWS4Auth
+ # from aiohttp import ClientResponse
+ # import json
+ # import os
+ # import time
+ # from tqdm import tqdm
+ # from typing import Optional, Callable
+ # from lm_deluge.prompt import Conversation
+ # from .base import APIRequestBase, APIResponse
+ # from ..tracker import StatusTracker
+ # from ..sampling_params import SamplingParams
+ # from ..models import APIModel
+
+
+ # def get_aws_headers(
+ #     access_key_id: str,
+ #     secret_access_key: str,
+ #     region: str,
+ #     url: str,
+ #     request_json: dict,
+ #     service: str = "bedrock",
+ # ):
+ #     auth = AWS4Auth(
+ #         access_key_id,
+ #         secret_access_key,
+ #         region,
+ #         service,
+ #     )
+
+ #     headers = CaseInsensitiveDict()
+ #     mock_request = requests.Request(
+ #         method="POST", url=url, headers=headers, json=request_json
+ #     ).prepare()
+ #     auth(mock_request)
+ #     # print("headers:", mock_request.headers)
+ #     return mock_request.headers
+
+
+ # class BedrockAnthropicRequest(APIRequestBase):
+ #     """
+ #     For Claude on Bedrock, you'll also have to set the PROJECT_ID environment variable.
+ #     """
+
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         model_name: str, # must correspond to registry
+ #         prompt: Conversation,
+ #         attempts_left: int,
+ #         results_arr: list,
+ #         status_tracker: StatusTracker,
+ #         retry_queue: asyncio.Queue,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: Optional[tqdm] = None,
+ #         callback: Optional[Callable] = None,
+ #         debug: bool = False,
+ #         all_model_names: list[str] | None = None,
+ #         all_sampling_params: list[SamplingParams] | None = None,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #             all_model_names=all_model_names,
+ #             all_sampling_params=all_sampling_params,
+ #         )
+ #         self.model = APIModel.from_registry(model_name)
+ #         region = self.model.sample_region()
+ #         assert region is not None, "unable to sample a region"
+ #         self.url = f"https://bedrock-runtime.{region}.amazonaws.com/model/{self.model.name}/invoke"
+ #         self.system_message, messages = prompt.to_anthropic()
+
+ #         self.request_json = {
+ #             "anthropic_version": "bedrock-2023-05-31",
+ #             "messages": messages,
+ #             "temperature": self.sampling_params.temperature,
+ #             "top_p": self.sampling_params.top_p,
+ #             "max_tokens": self.sampling_params.max_new_tokens,
+ #         }
+ #         if self.system_message is not None:
+ #             self.request_json["system"] = self.system_message
+
+ #         self.request_header = dict(
+ #             get_aws_headers(
+ #                 access_key_id=os.getenv("AWS_ACCESS_KEY_ID", ""),
+ #                 secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
+ #                 region=region,
+ #                 url=self.url,
+ #                 request_json=self.request_json,
+ #             )
+ #         )
+
+ #     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         status_code = http_response.status
+ #         mimetype = http_response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await http_response.json()
+ #                 completion = data["content"][0]["text"]
+ #                 input_tokens = data["usage"]["input_tokens"]
+ #                 output_tokens = data["usage"]["output_tokens"]
+ #             except Exception as e:
+ #                 is_error = True
+ #                 error_message = (
+ #                     f"Error calling .json() on response w/ status {status_code}: {e}"
+ #                 )
+ #         elif "json" in mimetype.lower() if mimetype else "":
+ #             is_error = True # expected status is 200, otherwise it's an error
+ #             data = await http_response.json()
+ #             error_message = json.dumps(data)
+
+ #         else:
+ #             is_error = True
+ #             text = await http_response.text()
+ #             error_message = text
+
+ #         # handle special kinds of errors. TODO: make sure these are correct for anthropic
+ #         if is_error and error_message is not None:
+ #             if (
+ #                 "rate limit" in error_message.lower()
+ #                 or "overloaded" in error_message.lower()
+ #             ):
+ #                 error_message += " (Rate limit error, triggering cooldown.)"
+ #                 self.status_tracker.time_of_last_rate_limit_error = time.time()
+ #                 self.status_tracker.num_rate_limit_errors += 1
+ #             if "context length" in error_message:
+ #                 error_message += " (Context length exceeded, set retries to 0.)"
+ #                 self.attempts_left = 0
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #         )
+
+
+ # class MistralBedrockRequest(APIRequestBase):
+ #     """
+ #     Documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html#model-parameters-mistral-request-response
+ #     """
+
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         model_name: str, # must correspond to registry
+ #         prompt: Conversation,
+ #         attempts_left: int,
+ #         status_tracker: StatusTracker,
+ #         retry_queue: asyncio.Queue,
+ #         results_arr: list,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: Optional[tqdm] = None,
+ #         callback: Optional[Callable] = None,
+ #         debug: bool = False,
+ #         all_model_names: list[str] | None = None,
+ #         all_sampling_params: list[SamplingParams] | None = None,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #             all_model_names=all_model_names,
+ #             all_sampling_params=all_sampling_params,
+ #         )
+ #         self.model = APIModel.from_registry(model_name)
+ #         self.region = self.model.sample_region()
+ #         assert self.region is not None, "unable to select a region"
+ #         self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
+ #         self.system_message = None
+ #         self.request_json = {
+ #             "prompt": prompt.to_mistral_bedrock(),
+ #             "max_tokens": self.sampling_params.max_new_tokens,
+ #             "temperature": self.sampling_params.temperature,
+ #             "top_p": self.sampling_params.top_p,
+ #         }
+ #         self.request_header = dict(
+ #             get_aws_headers(
+ #                 access_key_id=os.getenv("AWS_ACCESS_KEY_ID", ""),
+ #                 secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
+ #                 region=self.region,
+ #                 url=self.url,
+ #                 request_json=self.request_json,
+ #             )
+ #         )
+
+ #     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message: str | None = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         status_code = http_response.status
+ #         mimetype = http_response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await http_response.json()
+ #                 completion = data["outputs"][0]["text"]
+ #                 input_tokens = len(self.request_json["prompt"]) // 4 # approximate
+ #                 output_tokens = len(completion) // 4 # approximate
+ #             except Exception as e:
+ #                 is_error = True
+ #                 error_message = (
+ #                     f"Error calling .json() on response w/ status {status_code}: {e}"
+ #                 )
+ #         elif "json" in (mimetype.lower() if mimetype else ""):
+ #             is_error = True # expected status is 200, otherwise it's an error
+ #             data = await http_response.json()
+ #             error_message = json.dumps(data)
+
+ #         else:
+ #             is_error = True
+ #             text = await http_response.text()
+ #             error_message = (
+ #                 text if isinstance(text, str) else (str(text) if text else "")
+ #             )
+
+ #         # TODO: Handle rate-limit errors
+ #         # TODO: in the future, instead of slowing down, switch models?
+ #         if status_code == 429:
+ #             assert isinstance(error_message, str)
+ #             error_message += " (Rate limit error, triggering cooldown.)"
+ #             self.status_tracker.time_of_last_rate_limit_error = time.time()
+ #             self.status_tracker.num_rate_limit_errors += 1
+
+ #         # if error, change the region
+ #         old_region = self.region
+ #         if is_error:
+ #             self.region = self.model.sample_region()
+ #             assert self.region is not None, "could not select a region"
+ #             self.url = f"https://bedrock-runtime.{self.region}.amazonaws.com/model/{self.model.name}/invoke"
+ #             self.request_header = dict(
+ #                 get_aws_headers(
+ #                     access_key_id=os.getenv("AWS_ACCESS_KEY_ID", ""),
+ #                     secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY", ""),
+ #                     region=self.region,
+ #                     url=self.url,
+ #                     request_json=self.request_json,
+ #                 )
+ #             )
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #             region=old_region,
+ #         )
@@ -0,0 +1,118 @@
+ # import asyncio
+ # from aiohttp import ClientResponse
+ # import json
+ # import os
+ # import time
+ # from tqdm import tqdm
+ # from typing import Optional, Callable
+
+ # from .base import APIRequestBase, APIResponse
+ # from ..prompt import Prompt
+ # from ..tracker import StatusTracker
+ # from ..sampling_params import SamplingParams
+ # from ..models import APIModel
+
+
+ # class DeepseekRequest(APIRequestBase):
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         model_name: str, # must correspond to registry
+ #         prompt: Prompt,
+ #         attempts_left: int,
+ #         status_tracker: StatusTracker,
+ #         retry_queue: asyncio.Queue,
+ #         results_arr: list,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: Optional[tqdm] = None,
+ #         callback: Optional[Callable] = None,
+ #         debug: bool = False,
+ #         all_model_names: list[str] = None,
+ #         all_sampling_params: list[SamplingParams] = None,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #             all_model_names=all_model_names,
+ #             all_sampling_params=all_sampling_params,
+ #         )
+ #         self.model = APIModel.from_registry(model_name)
+ #         self.url = f"{self.model.api_base}/chat/completions"
+ #         self.request_header = {
+ #             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
+ #         }
+ #         if prompt.image is not None:
+ #             raise ValueError("Deepseek does not support images.")
+
+ #         self.request_json = {
+ #             "model": self.model.name,
+ #             "messages": prompt.to_openai(),
+ #             "temperature": sampling_params.temperature,
+ #             "top_p": sampling_params.top_p,
+ #             "max_tokens": sampling_params.max_new_tokens,
+ #         }
+ #         if sampling_params.json_mode and self.model.supports_json:
+ #             self.request_json["response_format"] = {"type": "json_object"}
+
+ #     async def handle_response(self, response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         status_code = response.status
+ #         mimetype = response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await response.json()
+ #                 completion = data["choices"][0]["message"]["content"]
+ #                 input_tokens = data["usage"]["prompt_tokens"]
+ #                 output_tokens = data["usage"]["completion_tokens"]
+
+ #             except Exception:
+ #                 is_error = True
+ #                 error_message = (
+ #                     f"Error calling .json() on response w/ status {status_code}"
+ #                 )
+ #         elif "json" in mimetype.lower():
+ #             is_error = True # expected status is 200, otherwise it's an error
+ #             data = await response.json()
+ #             error_message = json.dumps(data)
+ #         else:
+ #             is_error = True
+ #             text = await response.text()
+ #             error_message = text
+
+ #         # handle special kinds of errors
+ #         if is_error and error_message is not None:
+ #             if "rate limit" in error_message.lower():
+ #                 error_message += " (Rate limit error, triggering cooldown.)"
+ #                 self.status_tracker.time_of_last_rate_limit_error = time.time()
+ #                 self.status_tracker.num_rate_limit_errors += 1
+ #             if "context length" in error_message:
+ #                 error_message += " (Context length exceeded, set retries to 0.)"
+ #                 self.attempts_left = 0
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #         )
@@ -0,0 +1,120 @@
+ # import asyncio
+ # from aiohttp import ClientResponse
+ # import json
+ # import os
+ # import time
+ # from tqdm import tqdm
+ # from typing import Optional, Callable
+
+ # from .base import APIRequestBase, APIResponse
+ # from ..prompt import Prompt
+ # from ..tracker import StatusTracker
+ # from ..sampling_params import SamplingParams
+ # from ..models import APIModel
+
+
+ # class MistralRequest(APIRequestBase):
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         # should always be 'role', 'content' keys.
+ #         # internal logic should handle translating to specific API format
+ #         model_name: str, # must correspond to registry
+ #         prompt: Prompt,
+ #         attempts_left: int,
+ #         status_tracker: StatusTracker,
+ #         retry_queue: asyncio.Queue,
+ #         results_arr: list,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: Optional[tqdm] = None,
+ #         callback: Optional[Callable] = None,
+ #         debug: bool = False,
+ #         all_model_names: list[str] = None,
+ #         all_sampling_params: list[SamplingParams] = None,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #             all_model_names=all_model_names,
+ #             all_sampling_params=all_sampling_params,
+ #         )
+ #         self.model = APIModel.from_registry(model_name)
+ #         self.url = f"{self.model.api_base}/chat/completions"
+ #         self.request_header = {
+ #             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
+ #         }
+ #         if prompt.image is not None:
+ #             raise ValueError("Mistral does not support images.")
+
+ #         self.request_json = {
+ #             "model": self.model.name,
+ #             "messages": prompt.to_openai(),
+ #             "temperature": sampling_params.temperature,
+ #             "top_p": sampling_params.top_p,
+ #             "max_tokens": sampling_params.max_new_tokens,
+ #         }
+ #         if sampling_params.json_mode and self.model.supports_json:
+ #             self.request_json["response_format"] = {"type": "json_object"}
+
+ #     async def handle_response(self, response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         status_code = response.status
+ #         mimetype = response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await response.json()
+ #                 completion = data["choices"][0]["message"]["content"]
+ #                 input_tokens = data["usage"]["prompt_tokens"]
+ #                 output_tokens = data["usage"]["completion_tokens"]
+
+ #             except Exception:
+ #                 is_error = True
+ #                 error_message = (
+ #                     f"Error calling .json() on response w/ status {status_code}"
+ #                 )
+ #         elif "json" in mimetype.lower():
+ #             is_error = True # expected status is 200, otherwise it's an error
+ #             data = await response.json()
+ #             error_message = json.dumps(data)
+ #         else:
+ #             is_error = True
+ #             text = await response.text()
+ #             error_message = text
+
+ #         # handle special kinds of errors
+ #         if is_error and error_message is not None:
+ #             if "rate limit" in error_message.lower():
+ #                 error_message += " (Rate limit error, triggering cooldown.)"
+ #                 self.status_tracker.time_of_last_rate_limit_error = time.time()
+ #                 self.status_tracker.num_rate_limit_errors += 1
+ #             if "context length" in error_message:
+ #                 error_message += " (Context length exceeded, set retries to 0.)"
+ #                 self.attempts_left = 0
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #         )
File without changes