lm-deluge 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff shows the changes between publicly available package versions as published to their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of lm-deluge might be problematic.
- lm_deluge/__init__.py +2 -1
- lm_deluge/api_requests/anthropic.py +0 -2
- lm_deluge/api_requests/base.py +1 -0
- lm_deluge/api_requests/common.py +2 -11
- lm_deluge/api_requests/deprecated/cohere.py +132 -0
- lm_deluge/api_requests/deprecated/vertex.py +361 -0
- lm_deluge/api_requests/{cohere.py → mistral.py} +37 -35
- lm_deluge/api_requests/openai.py +10 -1
- lm_deluge/client.py +2 -0
- lm_deluge/image.py +6 -0
- lm_deluge/models.py +348 -288
- lm_deluge/prompt.py +11 -9
- lm_deluge/util/json.py +4 -3
- lm_deluge/util/xml.py +11 -12
- lm_deluge-0.0.6.dist-info/METADATA +170 -0
- {lm_deluge-0.0.4.dist-info → lm_deluge-0.0.6.dist-info}/RECORD +18 -18
- lm_deluge/api_requests/google.py +0 -0
- lm_deluge/api_requests/vertex.py +0 -361
- lm_deluge-0.0.4.dist-info/METADATA +0 -127
- {lm_deluge-0.0.4.dist-info → lm_deluge-0.0.6.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.4.dist-info → lm_deluge-0.0.6.dist-info}/top_level.txt +0 -0
lm_deluge/api_requests/{cohere.py → mistral.py}
RENAMED
@@ -1,20 +1,19 @@
-# https://docs.cohere.com/reference/chat
-# https://cohere.com/pricing
 import asyncio
+import warnings
 from aiohttp import ClientResponse
 import json
 import os
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from typing import Callable
-from lm_deluge.prompt import Conversation
-from .base import APIRequestBase, APIResponse

+from .base import APIRequestBase, APIResponse
+from ..prompt import Conversation
 from ..tracker import StatusTracker
 from ..sampling_params import SamplingParams
 from ..models import APIModel


-class CohereRequest(APIRequestBase):
+class MistralRequest(APIRequestBase):
     def __init__(
         self,
         task_id: int,
@@ -24,10 +23,12 @@ class CohereRequest(APIRequestBase):
         prompt: Conversation,
         attempts_left: int,
         status_tracker: StatusTracker,
-        results_arr: list,
         retry_queue: asyncio.Queue,
+        results_arr: list,
         request_timeout: int = 30,
         sampling_params: SamplingParams = SamplingParams(),
+        logprobs: bool = False,
+        top_logprobs: int | None = None,
         pbar: tqdm | None = None,
         callback: Callable | None = None,
         debug: bool = False,
@@ -44,36 +45,36 @@ class CohereRequest(APIRequestBase):
             results_arr=results_arr,
             request_timeout=request_timeout,
             sampling_params=sampling_params,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
             pbar=pbar,
             callback=callback,
             debug=debug,
             all_model_names=all_model_names,
             all_sampling_params=all_sampling_params,
         )
-        self.system_message = None
-        self.last_user_message = None
-
         self.model = APIModel.from_registry(model_name)
-        self.url = f"{self.model.api_base}/chat"
-        self.system_message, chat_history, last_user_message = prompt.to_cohere()
-
+        self.url = f"{self.model.api_base}/chat/completions"
         self.request_header = {
-            "Authorization": f"
-            "content-type": "application/json",
-            "accept": "application/json",
+            "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }
-
         self.request_json = {
             "model": self.model.name,
-            "
-            "message": last_user_message,
+            "messages": prompt.to_mistral(),
             "temperature": sampling_params.temperature,
             "top_p": sampling_params.top_p,
             "max_tokens": sampling_params.max_new_tokens,
         }
-
-
-
+        if sampling_params.reasoning_effort:
+            warnings.warn(
+                f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+            )
+        if logprobs:
+            warnings.warn(
+                f"Ignoring logprobs param for non-logprobs model: {model_name}"
+            )
+        if sampling_params.json_mode and self.model.supports_json:
+            self.request_json["response_format"] = {"type": "json_object"}

     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
@@ -81,41 +82,41 @@ class CohereRequest(APIRequestBase):
         completion = None
         input_tokens = None
         output_tokens = None
+        logprobs = None
         status_code = http_response.status
         mimetype = http_response.headers.get("Content-Type", None)
+        data = None
         if status_code >= 200 and status_code < 300:
             try:
                 data = await http_response.json()
             except Exception:
-                data = None
                 is_error = True
                 error_message = (
                     f"Error calling .json() on response w/ status {status_code}"
                 )
-        if not is_error
+        if not is_error:
+            assert data is not None, "data is None"
             try:
-                completion = data["
-                input_tokens = data["
-                output_tokens = data["
+                completion = data["choices"][0]["message"]["content"]
+                input_tokens = data["usage"]["prompt_tokens"]
+                output_tokens = data["usage"]["completion_tokens"]
+                if self.logprobs and "logprobs" in data["choices"][0]:
+                    logprobs = data["choices"][0]["logprobs"]["content"]
             except Exception:
                 is_error = True
-                error_message = f"Error getting '
-        elif mimetype
+                error_message = f"Error getting 'choices' and 'usage' from {self.model.name} response."
+        elif mimetype and "json" in mimetype.lower():
             is_error = True  # expected status is 200, otherwise it's an error
             data = await http_response.json()
             error_message = json.dumps(data)
-
         else:
             is_error = True
             text = await http_response.text()
             error_message = text

-        # handle special kinds of errors
+        # handle special kinds of errors
         if is_error and error_message is not None:
-            if (
-                "rate limit" in error_message.lower()
-                or "overloaded" in error_message.lower()
-            ):
+            if "rate limit" in error_message.lower() or status_code == 429:
                error_message += " (Rate limit error, triggering cooldown.)"
                self.status_tracker.rate_limit_exceeded()
            if "context length" in error_message:
@@ -128,6 +129,7 @@ class CohereRequest(APIRequestBase):
            is_error=is_error,
            error_message=error_message,
            prompt=self.prompt,
+            logprobs=logprobs,
            completion=completion,
            model_internal=self.model_name,
            sampling_params=self.sampling_params,
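The rename above is more than cosmetic: CohereRequest targeted Cohere's bespoke /chat API (system message, chat_history, message), while the new MistralRequest posts to an OpenAI-compatible /chat/completions endpoint and reads choices[0].message.content plus usage counts back. A standalone sketch of that request/response shape, assuming a MISTRAL_API_KEY environment variable and Mistral's public https://api.mistral.ai/v1 base; this is not lm-deluge's class, which additionally wires in the status tracker and retry queue:

import asyncio
import os

import aiohttp


async def mistral_chat(messages: list[dict]) -> dict:
    # POST one OpenAI-style chat completion and return the completion text
    # plus token usage, parsed the same way the diff above does.
    url = "https://api.mistral.ai/v1/chat/completions"  # assumed public base URL
    headers = {"Authorization": f"Bearer {os.getenv('MISTRAL_API_KEY')}"}
    payload = {
        "model": "mistral-small-latest",  # placeholder model name
        "messages": messages,             # OpenAI shape: [{"role": ..., "content": ...}]
        "temperature": 0.7,
        "top_p": 1.0,
        "max_tokens": 256,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as resp:
            data = await resp.json()
    return {
        "completion": data["choices"][0]["message"]["content"],
        "input_tokens": data["usage"]["prompt_tokens"],
        "output_tokens": data["usage"]["completion_tokens"],
    }


if __name__ == "__main__":
    print(asyncio.run(mistral_chat([{"role": "user", "content": "Say hi."}])))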
lm_deluge/api_requests/openai.py
CHANGED
@@ -58,13 +58,18 @@ class OpenAIRequest(APIRequestBase):
         self.request_header = {
             "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
         }
+
         self.request_json = {
             "model": self.model.name,
             "messages": prompt.to_openai(),
             "temperature": sampling_params.temperature,
             "top_p": sampling_params.top_p,
-            "max_completion_tokens": sampling_params.max_new_tokens,
         }
+        # set max_tokens or max_completion_tokens dep. on provider
+        if "cohere" in self.model.api_base:
+            self.request_json["max_tokens"] = sampling_params.max_new_tokens
+        elif "openai" in self.model.api_base:
+            self.request_json["max_completion_tokens"] = sampling_params.max_new_tokens
         if self.model.reasoning_model:
             self.request_json["temperature"] = 1.0
             self.request_json["top_p"] = 1.0
@@ -84,6 +89,7 @@ class OpenAIRequest(APIRequestBase):
     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
         is_error = False
         error_message = None
+        thinking = None
         completion = None
         input_tokens = None
         output_tokens = None
@@ -103,6 +109,8 @@ class OpenAIRequest(APIRequestBase):
             assert data is not None, "data is None"
             try:
                 completion = data["choices"][0]["message"]["content"]
+                if "reasoning_content" in data["choices"][0]["message"]:
+                    thinking = data["choices"][0]["message"]["reasoning_content"]
                 input_tokens = data["usage"]["prompt_tokens"]
                 output_tokens = data["usage"]["completion_tokens"]
                 if self.logprobs and "logprobs" in data["choices"][0]:
@@ -135,6 +143,7 @@ class OpenAIRequest(APIRequestBase):
             error_message=error_message,
             prompt=self.prompt,
             logprobs=logprobs,
+            thinking=thinking,
             completion=completion,
             model_internal=self.model_name,
             sampling_params=self.sampling_params,
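Two behaviors change in OpenAIRequest: the output-token cap is now written under a provider-dependent key (Cohere's OpenAI-compatible endpoint expects max_tokens, OpenAI expects max_completion_tokens), and a reasoning_content field, when a provider returns one, is carried through as thinking on the APIResponse. A minimal sketch of the same two pieces of logic in isolation (the helper names are illustrative, not part of lm-deluge):

def apply_token_limit(request_json: dict, api_base: str, max_new_tokens: int) -> dict:
    # Pick the provider-appropriate key for capping output tokens.
    if "cohere" in api_base:
        request_json["max_tokens"] = max_new_tokens
    elif "openai" in api_base:
        request_json["max_completion_tokens"] = max_new_tokens
    return request_json


def parse_message(data: dict) -> tuple[str, str | None]:
    # Return (completion, thinking); thinking stays None unless the provider
    # emits a reasoning_content field alongside the normal content.
    message = data["choices"][0]["message"]
    return message["content"], message.get("reasoning_content")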
lm_deluge/client.py
CHANGED
@@ -85,6 +85,7 @@ class LLMClient:
     def __init__(
         self,
         model_names: list[str],
+        *,
         max_requests_per_minute: int,
         max_tokens_per_minute: int,
         max_concurrent_requests: int,
@@ -345,6 +346,7 @@ class LLMClient:

         # add cache hits back in
         for id, res in zip(cache_hit_ids, cache_hit_results):
+            res.cache_hit = True
             results[id] = res

         if return_completions_only: