lm-deluge 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- lm_deluge/__init__.py +2 -1
- lm_deluge/api_requests/base.py +1 -0
- lm_deluge/api_requests/common.py +2 -11
- lm_deluge/api_requests/deprecated/cohere.py +132 -0
- lm_deluge/api_requests/deprecated/vertex.py +361 -0
- lm_deluge/api_requests/{cohere.py → mistral.py} +37 -31
- lm_deluge/api_requests/openai.py +10 -1
- lm_deluge/client.py +2 -0
- lm_deluge/image.py +6 -0
- lm_deluge/models.py +348 -288
- lm_deluge/prompt.py +11 -9
- lm_deluge/util/json.py +4 -3
- lm_deluge/util/xml.py +11 -12
- lm_deluge-0.0.6.dist-info/METADATA +170 -0
- {lm_deluge-0.0.5.dist-info → lm_deluge-0.0.6.dist-info}/RECORD +17 -17
- lm_deluge/api_requests/google.py +0 -0
- lm_deluge/api_requests/vertex.py +0 -361
- lm_deluge-0.0.5.dist-info/METADATA +0 -127
- {lm_deluge-0.0.5.dist-info → lm_deluge-0.0.6.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.5.dist-info → lm_deluge-0.0.6.dist-info}/top_level.txt +0 -0
lm_deluge/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 from .client import LLMClient, SamplingParams, APIResponse
+from .prompt import Conversation, Message
 import dotenv
 
 dotenv.load_dotenv()
 
-__all__ = ["LLMClient", "SamplingParams", "APIResponse"]
+__all__ = ["LLMClient", "SamplingParams", "APIResponse", "Conversation", "Message"]
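With 0.0.6 the prompt primitives are re-exported from the package root, so callers no longer have to reach into the `lm_deluge.prompt` submodule. A minimal sketch of the new import path (anything beyond the imports is assumed usage, not shown in the diff):

```python
# Sketch: Conversation and Message are importable from the package root as of 0.0.6.
from lm_deluge import LLMClient, Conversation, Message

# In 0.0.5 the same classes had to come from the submodule:
# from lm_deluge.prompt import Conversation, Message
```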
lm_deluge/api_requests/base.py
CHANGED
@@ -41,6 +41,7 @@ class APIResponse:
     logprobs: list | None = None
     finish_reason: str | None = None  # make required later
     cost: float | None = None  # calculated automatically
+    cache_hit: bool = False  # manually set if true
     # set to true if is_error and should be retried with a different model
     retry_with_different_model: bool | None = False
     # set to true if should NOT retry with the same model (unrecoverable error)
lm_deluge/api_requests/common.py
CHANGED
@@ -1,18 +1,9 @@
-# from .vertex import VertexAnthropicRequest, GeminiRequest
-# from .bedrock import BedrockAnthropicRequest, MistralBedrockRequest
-# from .deepseek import DeepseekRequest
 from .openai import OpenAIRequest
-from .cohere import CohereRequest
 from .anthropic import AnthropicRequest
+from .mistral import MistralRequest
 
 CLASSES = {
     "openai": OpenAIRequest,
-    # "deepseek": DeepseekRequest,
     "anthropic": AnthropicRequest,
-
-    # "vertex_gemini": GeminiRequest,
-    "cohere": CohereRequest,
-    # "bedrock_anthropic": BedrockAnthropicRequest,
-    # "bedrock_mistral": MistralBedrockRequest,
-    # "mistral": MistralRequest,
+    "mistral": MistralRequest,
 }
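`CLASSES` is the provider registry: it maps an API provider key to its request class, so registering Mistral and dropping Cohere changes which keys resolve. A small sketch of a lookup against the 0.0.6 registry:

```python
# Sketch: resolve a request class from the provider registry (0.0.6).
from lm_deluge.api_requests.common import CLASSES

request_cls = CLASSES["mistral"]  # -> MistralRequest, new in 0.0.6
# CLASSES["cohere"] now raises KeyError; CohereRequest was moved to deprecated/cohere.py
```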
lm_deluge/api_requests/deprecated/cohere.py
ADDED

@@ -0,0 +1,132 @@
+# # https://docs.cohere.com/reference/chat
+# # https://cohere.com/pricing
+# import asyncio
+# from aiohttp import ClientResponse
+# import json
+# import os
+# from tqdm import tqdm
+# from typing import Callable
+# from lm_deluge.prompt import Conversation
+# from .base import APIRequestBase, APIResponse
+
+# from ..tracker import StatusTracker
+# from ..sampling_params import SamplingParams
+# from ..models import APIModel
+
+
+# class CohereRequest(APIRequestBase):
+#     def __init__(
+#         self,
+#         task_id: int,
+#         # should always be 'role', 'content' keys.
+#         # internal logic should handle translating to specific API format
+#         model_name: str,  # must correspond to registry
+#         prompt: Conversation,
+#         attempts_left: int,
+#         status_tracker: StatusTracker,
+#         results_arr: list,
+#         retry_queue: asyncio.Queue,
+#         request_timeout: int = 30,
+#         sampling_params: SamplingParams = SamplingParams(),
+#         pbar: tqdm | None = None,
+#         callback: Callable | None = None,
+#         debug: bool = False,
+#         all_model_names: list[str] | None = None,
+#         all_sampling_params: list[SamplingParams] | None = None,
+#     ):
+#         super().__init__(
+#             task_id=task_id,
+#             model_name=model_name,
+#             prompt=prompt,
+#             attempts_left=attempts_left,
+#             status_tracker=status_tracker,
+#             retry_queue=retry_queue,
+#             results_arr=results_arr,
+#             request_timeout=request_timeout,
+#             sampling_params=sampling_params,
+#             pbar=pbar,
+#             callback=callback,
+#             debug=debug,
+#             all_model_names=all_model_names,
+#             all_sampling_params=all_sampling_params,
+#         )
+#         self.system_message = None
+#         self.last_user_message = None
+
+#         self.model = APIModel.from_registry(model_name)
+#         self.url = f"{self.model.api_base}/chat"
+#         messages = prompt.to_cohere()
+
+#         self.request_header = {
+#             "Authorization": f"bearer {os.getenv(self.model.api_key_env_var)}",
+#             "content-type": "application/json",
+#             "accept": "application/json",
+#         }
+
+#         self.request_json = {
+#             "model": self.model.name,
+#             "messages": messages,
+#             "temperature": sampling_params.temperature,
+#             "top_p": sampling_params.top_p,
+#             "max_tokens": sampling_params.max_new_tokens,
+#         }
+
+#     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+#         is_error = False
+#         error_message = None
+#         completion = None
+#         input_tokens = None
+#         output_tokens = None
+#         status_code = http_response.status
+#         mimetype = http_response.headers.get("Content-Type", None)
+#         if status_code >= 200 and status_code < 300:
+#             try:
+#                 data = await http_response.json()
+#             except Exception:
+#                 data = None
+#                 is_error = True
+#                 error_message = (
+#                     f"Error calling .json() on response w/ status {status_code}"
+#                 )
+#             if not is_error and isinstance(data, dict):
+#                 try:
+#                     completion = data["text"]
+#                     input_tokens = data["meta"]["billed_units"]["input_tokens"]
+#                     output_tokens = data["meta"]["billed_units"]["input_tokens"]
+#                 except Exception:
+#                     is_error = True
+#                     error_message = f"Error getting 'text' or 'meta' from {self.model.name} response."
+#         elif mimetype is not None and "json" in mimetype.lower():
+#             is_error = True  # expected status is 200, otherwise it's an error
+#             data = await http_response.json()
+#             error_message = json.dumps(data)
+
+#         else:
+#             is_error = True
+#             text = await http_response.text()
+#             error_message = text
+
+#         # handle special kinds of errors. TODO: make sure these are correct for anthropic
+#         if is_error and error_message is not None:
+#             if (
+#                 "rate limit" in error_message.lower()
+#                 or "overloaded" in error_message.lower()
+#             ):
+#                 error_message += " (Rate limit error, triggering cooldown.)"
+#                 self.status_tracker.rate_limit_exceeded()
+#             if "context length" in error_message:
+#                 error_message += " (Context length exceeded, set retries to 0.)"
+#                 self.attempts_left = 0
+
+#         return APIResponse(
+#             id=self.task_id,
+#             status_code=status_code,
+#             is_error=is_error,
+#             error_message=error_message,
+#             prompt=self.prompt,
+#             completion=completion,
+#             model_internal=self.model_name,
+#             sampling_params=self.sampling_params,
+#             input_tokens=input_tokens,
+#             output_tokens=output_tokens,
+#         )
lm_deluge/api_requests/deprecated/vertex.py
ADDED

@@ -0,0 +1,361 @@
+# # consider: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-gemini-using-openai-library#call-chat-completions-api
+# import asyncio
+# from aiohttp import ClientResponse
+# import json
+# import os
+# import time
+# from tqdm import tqdm
+# from typing import Callable
+
+# from lm_deluge.prompt import Conversation
+# from .base import APIRequestBase, APIResponse
+# from ..tracker import StatusTracker
+# from ..sampling_params import SamplingParams
+# from ..models import APIModel
+
+# from google.oauth2 import service_account
+# from google.auth.transport.requests import Request
+
+
+# def get_access_token(service_account_file: str):
+#     """
+#     Get access token from environment variables if another process/coroutine
+#     has already got them, otherwise get from service account file.
+#     """
+#     LAST_REFRESHED = os.getenv("VERTEX_TOKEN_LAST_REFRESHED", None)
+#     LAST_REFRESHED = int(LAST_REFRESHED) if LAST_REFRESHED is not None else 0
+#     VERTEX_API_TOKEN = os.getenv("VERTEX_API_TOKEN", None)
+
+#     if VERTEX_API_TOKEN is not None and time.time() - LAST_REFRESHED < 60 * 50:
+#         return VERTEX_API_TOKEN
+#     else:
+#         credentials = service_account.Credentials.from_service_account_file(
+#             service_account_file,
+#             scopes=["https://www.googleapis.com/auth/cloud-platform"],
+#         )
+#         credentials.refresh(Request())
+#         token = credentials.token
+#         os.environ["VERTEX_API_TOKEN"] = token
+#         os.environ["VERTEX_TOKEN_LAST_REFRESHED"] = str(int(time.time()))
+
+#         return token
+
+
+# class VertexAnthropicRequest(APIRequestBase):
+#     """
+#     For Claude on Vertex, you'll also have to set the PROJECT_ID environment variable.
+#     """
+
+#     def __init__(
+#         self,
+#         task_id: int,
+#         model_name: str,  # must correspond to registry
+#         prompt: Conversation,
+#         attempts_left: int,
+#         status_tracker: StatusTracker,
+#         retry_queue: asyncio.Queue,
+#         results_arr: list,
+#         request_timeout: int = 30,
+#         sampling_params: SamplingParams = SamplingParams(),
+#         pbar: tqdm | None = None,
+#         callback: Callable | None = None,
+#         debug: bool = False,
+#     ):
+#         super().__init__(
+#             task_id=task_id,
+#             model_name=model_name,
+#             prompt=prompt,
+#             attempts_left=attempts_left,
+#             status_tracker=status_tracker,
+#             retry_queue=retry_queue,
+#             results_arr=results_arr,
+#             request_timeout=request_timeout,
+#             sampling_params=sampling_params,
+#             pbar=pbar,
+#             callback=callback,
+#             debug=debug,
+#         )
+#         creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
+#         if not creds:
+#             raise RuntimeError(
+#                 "GOOGLE_APPLICATION_CREDENTIALS not provided in environment"
+#             )
+#         token = get_access_token(creds)
+
+#         self.model = APIModel.from_registry(model_name)
+#         project_id = os.getenv("PROJECT_ID")
+#         region = self.model.sample_region()
+
+#         endpoint = f"https://{region}-aiplatform.googleapis.com"
+#         self.url = f"{endpoint}/v1/projects/{project_id}/locations/{region}/publishers/anthropic/models/{self.model.name}:generateContent"
+#         self.request_header = {
+#             "Authorization": f"Bearer {token}",
+#             "Content-Type": "application/json",
+#         }
+#         self.system_message, messages = prompt.to_anthropic()
+
+#         self.request_json = {
+#             "anthropic_version": "vertex-2023-10-16",
+#             "messages": messages,
+#             "temperature": self.sampling_params.temperature,
+#             "top_p": self.sampling_params.top_p,
+#             "max_tokens": self.sampling_params.max_new_tokens,
+#         }
+#         if self.system_message is not None:
+#             self.request_json["system"] = self.system_message
+
+#     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+#         is_error = False
+#         error_message = None
+#         completion = None
+#         input_tokens = None
+#         output_tokens = None
+#         status_code = http_response.status
+#         mimetype = http_response.headers.get("Content-Type", None)
+#         if status_code >= 200 and status_code < 300:
+#             try:
+#                 data = await http_response.json()
+#                 completion = data["content"][0]["text"]
+#                 input_tokens = data["usage"]["input_tokens"]
+#                 output_tokens = data["usage"]["output_tokens"]
+#             except Exception as e:
+#                 is_error = True
+#                 error_message = (
+#                     f"Error calling .json() on response w/ status {status_code}: {e}"
+#                 )
+#         elif "json" in (mimetype or "").lower():
+#             is_error = True  # expected status is 200, otherwise it's an error
+#             data = await http_response.json()
+#             error_message = json.dumps(data)
+
+#         else:
+#             is_error = True
+#             text = await http_response.text()
+#             error_message = text
+
+#         # handle special kinds of errors. TODO: make sure these are correct for anthropic
+#         if is_error and error_message is not None:
+#             if (
+#                 "rate limit" in error_message.lower()
+#                 or "overloaded" in error_message.lower()
+#                 or status_code == 429
+#             ):
+#                 error_message += " (Rate limit error, triggering cooldown.)"
+#                 self.status_tracker.rate_limit_exceeded()
+#             if "context length" in error_message:
+#                 error_message += " (Context length exceeded, set retries to 0.)"
+#                 self.attempts_left = 0
+
+#         return APIResponse(
+#             id=self.task_id,
+#             status_code=status_code,
+#             is_error=is_error,
+#             error_message=error_message,
+#             prompt=self.prompt,
+#             completion=completion,
+#             model_internal=self.model_name,
+#             sampling_params=self.sampling_params,
+#             input_tokens=input_tokens,
+#             output_tokens=output_tokens,
+#         )
+
+
+# SAFETY_SETTING_CATEGORIES = [
+#     "HARM_CATEGORY_DANGEROUS_CONTENT",
+#     "HARM_CATEGORY_HARASSMENT",
+#     "HARM_CATEGORY_HATE_SPEECH",
+#     "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+# ]
+
+
+# class GeminiRequest(APIRequestBase):
+#     """
+#     For Gemini, you'll also have to set the PROJECT_ID environment variable.
+#     """
+
+#     def __init__(
+#         self,
+#         task_id: int,
+#         model_name: str,  # must correspond to registry
+#         prompt: Conversation,
+#         attempts_left: int,
+#         status_tracker: StatusTracker,
+#         retry_queue: asyncio.Queue,
+#         results_arr: list,
+#         request_timeout: int = 30,
+#         sampling_params: SamplingParams = SamplingParams(),
+#         pbar: tqdm | None = None,
+#         callback: Callable | None = None,
+#         debug: bool = False,
+#         all_model_names: list[str] | None = None,
+#         all_sampling_params: list[SamplingParams] | None = None,
+#     ):
+#         super().__init__(
+#             task_id=task_id,
+#             model_name=model_name,
+#             prompt=prompt,
+#             attempts_left=attempts_left,
+#             status_tracker=status_tracker,
+#             retry_queue=retry_queue,
+#             results_arr=results_arr,
+#             request_timeout=request_timeout,
+#             sampling_params=sampling_params,
+#             pbar=pbar,
+#             callback=callback,
+#             debug=debug,
+#             all_model_names=all_model_names,
+#             all_sampling_params=all_sampling_params,
+#         )
+#         self.model = APIModel.from_registry(model_name)
+#         credentials_file = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
+#         if not credentials_file:
+#             raise RuntimeError(
+#                 "no credentials file found. ensure you provide a google credentials file and point to it with GOOGLE_APPLICATION_CREDENTIALS environment variable."
+#             )
+#         token = get_access_token(credentials_file)
+#         self.project_id = os.getenv("PROJECT_ID")
+#         # sample weighted by region counts
+#         self.region = self.model.sample_region()
+#         assert self.region is not None, "unable to sample region"
+#         self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"
+
+#         self.request_header = {
+#             "Authorization": f"Bearer {token}",
+#             "Content-Type": "application/json",
+#         }
+#         self.system_message, contents = prompt.to_gemini()
+#         self.request_json = {
+#             "contents": contents,
+#             "generationConfig": {
+#                 "stopSequences": [],
+#                 "temperature": sampling_params.temperature,
+#                 "maxOutputTokens": sampling_params.max_new_tokens,
+#                 "topP": sampling_params.top_p,
+#                 "topK": None,
+#             },
+#             "safetySettings": [
+#                 {"category": category, "threshold": "BLOCK_NONE"}
+#                 for category in SAFETY_SETTING_CATEGORIES
+#             ],
+#         }
+#         if sampling_params.json_mode and self.model.supports_json:
+#             self.request_json["generationConfig"]["responseMimeType"] = (
+#                 "application/json"
+#             )
+
+#         if self.system_message is not None:
+#             self.request_json["systemInstruction"] = (
+#                 {"role": "SYSTEM", "parts": [{"text": self.system_message}]},
+#             )
+
+#     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+#         is_error = False
+#         error_message = None
+#         completion = None
+#         input_tokens = None
+#         output_tokens = None
+#         finish_reason = None
+#         data = None
+#         retry_with_different_model = False
+#         give_up_if_no_other_models = False
+#         status_code = http_response.status
+#         mimetype = http_response.headers.get("Content-Type", None)
+#         if status_code >= 200 and status_code < 300:
+#             try:
+#                 data = await http_response.json()
+#                 if "candidates" not in data:
+#                     is_error = True
+#                     if "promptFeedback" in data:
+#                         error_message = "Prompt rejected. Feedback: " + str(
+#                             data["promptFeedback"]
+#                         )
+#                     else:
+#                         error_message = "No candidates in response."
+#                     retry_with_different_model = True
+#                     give_up_if_no_other_models = True
+#                 else:
+#                     candidate = data["candidates"][0]
+#                     finish_reason = candidate["finishReason"]
+#                     if "content" in candidate:
+#                         parts = candidate["content"]["parts"]
+#                         completion = " ".join([part["text"] for part in parts])
+#                         usage = data["usageMetadata"]
+#                         input_tokens = usage["promptTokenCount"]
+#                         output_tokens = usage["candidatesTokenCount"]
+#                     elif finish_reason == "RECITATION":
+#                         is_error = True
+#                         citations = candidate.get("citationMetadata", {}).get(
+#                             "citations", []
+#                         )
+#                         urls = ",".join(
+#                             [citation.get("uri", "") for citation in citations]
+#                         )
+#                         error_message = "Finish reason RECITATION. URLS: " + urls
+#                         retry_with_different_model = True
+#                     elif finish_reason == "OTHER":
+#                         is_error = True
+#                         error_message = "Finish reason OTHER."
+#                         retry_with_different_model = True
+#                     elif finish_reason == "SAFETY":
+#                         is_error = True
+#                         error_message = "Finish reason SAFETY."
+#                         retry_with_different_model = True
+#                     else:
+#                         print("Actual structure of response:", data)
+#                         is_error = True
+#                         error_message = "No content in response."
+#             except Exception as e:
+#                 is_error = True
+#                 error_message = f"Error calling .json() on response w/ status {status_code}: {e.__class__} {e}"
+#                 if isinstance(e, KeyError):
+#                     print("Actual structure of response:", data)
+#         elif "json" in (mimetype or "").lower():
+#             is_error = True
+#             data = await http_response.json()
+#             error_message = json.dumps(data)
+#         else:
+#             is_error = True
+#             text = await http_response.text()
+#             error_message = text
+
+#         old_region = self.region
+#         if is_error and error_message is not None:
+#             if (
+#                 "rate limit" in error_message.lower()
+#                 or "temporarily out of capacity" in error_message.lower()
+#                 or "exceeded" in error_message.lower()
+#                 or
+#                 # 429 code
+#                 status_code == 429
+#             ):
+#                 error_message += " (Rate limit error, triggering cooldown & retrying with different model.)"
+#                 self.status_tracker.rate_limit_exceeded()
+#                 retry_with_different_model = (
+#                     True  # if possible, retry with a different model
+#                 )
+#         if is_error:
+#             # change the region in case error is due to region unavailability
+#             self.region = self.model.sample_region()
+#             assert self.region is not None, "Unable to sample region"
+#             self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"
+
+#         return APIResponse(
+#             id=self.task_id,
+#             status_code=status_code,
+#             is_error=is_error,
+#             error_message=error_message,
+#             prompt=self.prompt,
+#             completion=completion,
+#             model_internal=self.model_name,
+#             sampling_params=self.sampling_params,
+#             input_tokens=input_tokens,
+#             output_tokens=output_tokens,
+#             region=old_region,
+#             finish_reason=finish_reason,
+#             retry_with_different_model=retry_with_different_model,
+#             give_up_if_no_other_models=give_up_if_no_other_models,
+#         )
+
+
+# # class LlamaEndpointRequest(APIRequestBase):
+# #     raise NotImplementedError("Llama endpoints are not implemented and never will be because Vertex AI sucks ass.")
lm_deluge/api_requests/{cohere.py → mistral.py}
RENAMED

@@ -1,20 +1,19 @@
-# https://docs.cohere.com/reference/chat
-# https://cohere.com/pricing
 import asyncio
+import warnings
 from aiohttp import ClientResponse
 import json
 import os
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from typing import Callable
-from lm_deluge.prompt import Conversation
-from .base import APIRequestBase, APIResponse
 
+from .base import APIRequestBase, APIResponse
+from ..prompt import Conversation
 from ..tracker import StatusTracker
 from ..sampling_params import SamplingParams
 from ..models import APIModel
 
 
-class CohereRequest(APIRequestBase):
+class MistralRequest(APIRequestBase):
     def __init__(
         self,
         task_id: int,
@@ -24,10 +23,12 @@ class CohereRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        results_arr: list,
         retry_queue: asyncio.Queue,
+        results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
+        logprobs: bool = False,
+        top_logprobs: int | None = None,
         pbar: tqdm | None = None,
         callback: Callable | None = None,
         debug: bool = False,
@@ -44,32 +45,36 @@ class CohereRequest(APIRequestBase):
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
             pbar=pbar,
             callback=callback,
             debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
         )
-        self.system_message = None
-        self.last_user_message = None
-
        self.model = APIModel.from_registry(model_name)
-        self.url = f"{self.model.api_base}/chat"
-        messages = prompt.to_cohere()
-
+        self.url = f"{self.model.api_base}/chat/completions"
         self.request_header = {
-            "Authorization": f"bearer {os.getenv(self.model.api_key_env_var)}",
-            "content-type": "application/json",
-            "accept": "application/json",
+            "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }
-
         self.request_json = {
             "model": self.model.name,
-            "messages": messages,
+            "messages": prompt.to_mistral(),
             "temperature": sampling_params.temperature,
             "top_p": sampling_params.top_p,
             "max_tokens": sampling_params.max_new_tokens,
         }
+        if sampling_params.reasoning_effort:
+            warnings.warn(
+                f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+            )
+        if logprobs:
+            warnings.warn(
+                f"Ignoring logprobs param for non-logprobs model: {model_name}"
+            )
+        if sampling_params.json_mode and self.model.supports_json:
+            self.request_json["response_format"] = {"type": "json_object"}
 
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
@@ -77,41 +82,41 @@ class CohereRequest(APIRequestBase):
         completion = None
         input_tokens = None
         output_tokens = None
+        logprobs = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
+        data = None
        if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
             except Exception:
-                data = None
                 is_error = True
                 error_message = (
                     f"Error calling .json() on response w/ status {status_code}"
                 )
-            if not is_error and isinstance(data, dict):
+            if not is_error:
+                assert data is not None, "data is None"
                 try:
-                    completion = data["text"]
-                    input_tokens = data["meta"]["billed_units"]["input_tokens"]
-                    output_tokens = data["meta"]["billed_units"]["input_tokens"]
+                    completion = data["choices"][0]["message"]["content"]
+                    input_tokens = data["usage"]["prompt_tokens"]
+                    output_tokens = data["usage"]["completion_tokens"]
+                    if self.logprobs and "logprobs" in data["choices"][0]:
+                        logprobs = data["choices"][0]["logprobs"]["content"]
                 except Exception:
                     is_error = True
-                    error_message = f"Error getting 'text' or 'meta' from {self.model.name} response."
-        elif mimetype is not None and "json" in mimetype.lower():
+                    error_message = f"Error getting 'choices' and 'usage' from {self.model.name} response."
+        elif mimetype and "json" in mimetype.lower():
             is_error = True  # expected status is 200, otherwise it's an error
             data = await http_response.json()
             error_message = json.dumps(data)
-
         else:
             is_error = True
             text = await http_response.text()
             error_message = text
 
-        # handle special kinds of errors. TODO: make sure these are correct for anthropic
+        # handle special kinds of errors
         if is_error and error_message is not None:
-            if (
-                "rate limit" in error_message.lower()
-                or "overloaded" in error_message.lower()
-            ):
+            if "rate limit" in error_message.lower() or status_code == 429:
                 error_message += " (Rate limit error, triggering cooldown.)"
                 self.status_tracker.rate_limit_exceeded()
             if "context length" in error_message:
@@ -124,6 +129,7 @@ class CohereRequest(APIRequestBase):
             is_error=is_error,
             error_message=error_message,
             prompt=self.prompt,
+            logprobs=logprobs,
            completion=completion,
             model_internal=self.model_name,
             sampling_params=self.sampling_params,