lm-deluge 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

Potentially problematic release: this version of lm-deluge has been flagged as possibly problematic.

lm_deluge/__init__.py CHANGED
@@ -1,6 +1,7 @@
  from .client import LLMClient, SamplingParams, APIResponse
+ from .prompt import Conversation, Message
  import dotenv

  dotenv.load_dotenv()

- __all__ = ["LLMClient", "SamplingParams", "APIResponse"]
+ __all__ = ["LLMClient", "SamplingParams", "APIResponse", "Conversation", "Message"]
@@ -119,9 +119,7 @@ class AnthropicRequest(APIRequestBase):
          if status_code >= 200 and status_code < 300:
              try:
                  data = await http_response.json()
-                 print("response data:", data)
                  content = data["content"]  # [0]["text"]
-                 print("content is length", len(content))
                  for item in content:
                      if item["type"] == "text":
                          completion = item["text"]
@@ -41,6 +41,7 @@ class APIResponse:
      logprobs: list | None = None
      finish_reason: str | None = None  # make required later
      cost: float | None = None  # calculated automatically
+     cache_hit: bool = False  # manually set if true
      # set to true if is_error and should be retried with a different model
      retry_with_different_model: bool | None = False
      # set to true if should NOT retry with the same model (unrecoverable error)
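
Note: the new cache_hit field defaults to False and, per its comment, is set manually. A hedged sketch of how calling code might flag cached responses; the cache dict and helper below are illustrative and not part of lm-deluge:

    from lm_deluge import APIResponse

    # Illustrative caller-side cache that marks responses it serves as cache hits.
    _cache: dict[str, APIResponse] = {}

    def resolve(key: str, fresh: APIResponse) -> APIResponse:
        cached = _cache.get(key)
        if cached is not None:
            cached.cache_hit = True  # "manually set if true", per the field's comment
            return cached
        _cache[key] = fresh
        return fresh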
@@ -1,18 +1,9 @@
- # from .vertex import VertexAnthropicRequest, GeminiRequest
- # from .bedrock import BedrockAnthropicRequest, MistralBedrockRequest
- # from .deepseek import DeepseekRequest
  from .openai import OpenAIRequest
- from .cohere import CohereRequest
  from .anthropic import AnthropicRequest
+ from .mistral import MistralRequest

  CLASSES = {
      "openai": OpenAIRequest,
-     # "deepseek": DeepseekRequest,
      "anthropic": AnthropicRequest,
-     # "vertex_anthropic": VertexAnthropicRequest,
-     # "vertex_gemini": GeminiRequest,
-     "cohere": CohereRequest,
-     # "bedrock_anthropic": BedrockAnthropicRequest,
-     # "bedrock_mistral": MistralBedrockRequest,
-     # "mistral": MistralRequest,
+     "mistral": MistralRequest,
  }
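
Note: after this change the registry maps exactly three provider keys ("openai", "anthropic", "mistral"); the Cohere entry and the commented-out Vertex, Bedrock, and DeepSeek entries are removed. A small sketch of a lookup against the registry as defined above (the dispatch pattern itself is an assumption, not shown in this diff):

    # Illustrative lookup against CLASSES from this module.
    request_cls = CLASSES["mistral"]  # -> MistralRequest
    assert "cohere" not in CLASSES    # CohereRequest is no longer registered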
@@ -0,0 +1,132 @@
+ # # https://docs.cohere.com/reference/chat
+ # # https://cohere.com/pricing
+ # import asyncio
+ # from aiohttp import ClientResponse
+ # import json
+ # import os
+ # from tqdm import tqdm
+ # from typing import Callable
+ # from lm_deluge.prompt import Conversation
+ # from .base import APIRequestBase, APIResponse
+
+ # from ..tracker import StatusTracker
+ # from ..sampling_params import SamplingParams
+ # from ..models import APIModel
+
+
+ # class CohereRequest(APIRequestBase):
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         # should always be 'role', 'content' keys.
+ #         # internal logic should handle translating to specific API format
+ #         model_name: str,  # must correspond to registry
+ #         prompt: Conversation,
+ #         attempts_left: int,
+ #         status_tracker: StatusTracker,
+ #         results_arr: list,
+ #         retry_queue: asyncio.Queue,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: tqdm | None = None,
+ #         callback: Callable | None = None,
+ #         debug: bool = False,
+ #         all_model_names: list[str] | None = None,
+ #         all_sampling_params: list[SamplingParams] | None = None,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #             all_model_names=all_model_names,
+ #             all_sampling_params=all_sampling_params,
+ #         )
+ #         self.system_message = None
+ #         self.last_user_message = None
+
+ #         self.model = APIModel.from_registry(model_name)
+ #         self.url = f"{self.model.api_base}/chat"
+ #         messages = prompt.to_cohere()
+
+ #         self.request_header = {
+ #             "Authorization": f"bearer {os.getenv(self.model.api_key_env_var)}",
+ #             "content-type": "application/json",
+ #             "accept": "application/json",
+ #         }
+
+ #         self.request_json = {
+ #             "model": self.model.name,
+ #             "messages": messages,
+ #             "temperature": sampling_params.temperature,
+ #             "top_p": sampling_params.top_p,
+ #             "max_tokens": sampling_params.max_new_tokens,
+ #         }
+
+ #     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         status_code = http_response.status
+ #         mimetype = http_response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await http_response.json()
+ #             except Exception:
+ #                 data = None
+ #                 is_error = True
+ #                 error_message = (
+ #                     f"Error calling .json() on response w/ status {status_code}"
+ #                 )
+ #             if not is_error and isinstance(data, dict):
+ #                 try:
+ #                     completion = data["text"]
+ #                     input_tokens = data["meta"]["billed_units"]["input_tokens"]
+ #                     output_tokens = data["meta"]["billed_units"]["input_tokens"]
+ #                 except Exception:
+ #                     is_error = True
+ #                     error_message = f"Error getting 'text' or 'meta' from {self.model.name} response."
+ #         elif mimetype is not None and "json" in mimetype.lower():
+ #             is_error = True  # expected status is 200, otherwise it's an error
+ #             data = await http_response.json()
+ #             error_message = json.dumps(data)
+
+ #         else:
+ #             is_error = True
+ #             text = await http_response.text()
+ #             error_message = text
+
+ #         # handle special kinds of errors. TODO: make sure these are correct for anthropic
+ #         if is_error and error_message is not None:
+ #             if (
+ #                 "rate limit" in error_message.lower()
+ #                 or "overloaded" in error_message.lower()
+ #             ):
+ #                 error_message += " (Rate limit error, triggering cooldown.)"
+ #                 self.status_tracker.rate_limit_exceeded()
+ #             if "context length" in error_message:
+ #                 error_message += " (Context length exceeded, set retries to 0.)"
+ #                 self.attempts_left = 0
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #         )
@@ -0,0 +1,361 @@
+ # # consider: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-gemini-using-openai-library#call-chat-completions-api
+ # import asyncio
+ # from aiohttp import ClientResponse
+ # import json
+ # import os
+ # import time
+ # from tqdm import tqdm
+ # from typing import Callable
+
+ # from lm_deluge.prompt import Conversation
+ # from .base import APIRequestBase, APIResponse
+ # from ..tracker import StatusTracker
+ # from ..sampling_params import SamplingParams
+ # from ..models import APIModel
+
+ # from google.oauth2 import service_account
+ # from google.auth.transport.requests import Request
+
+
+ # def get_access_token(service_account_file: str):
+ #     """
+ #     Get access token from environment variables if another process/coroutine
+ #     has already got them, otherwise get from service account file.
+ #     """
+ #     LAST_REFRESHED = os.getenv("VERTEX_TOKEN_LAST_REFRESHED", None)
+ #     LAST_REFRESHED = int(LAST_REFRESHED) if LAST_REFRESHED is not None else 0
+ #     VERTEX_API_TOKEN = os.getenv("VERTEX_API_TOKEN", None)
+
+ #     if VERTEX_API_TOKEN is not None and time.time() - LAST_REFRESHED < 60 * 50:
+ #         return VERTEX_API_TOKEN
+ #     else:
+ #         credentials = service_account.Credentials.from_service_account_file(
+ #             service_account_file,
+ #             scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ #         )
+ #         credentials.refresh(Request())
+ #         token = credentials.token
+ #         os.environ["VERTEX_API_TOKEN"] = token
+ #         os.environ["VERTEX_TOKEN_LAST_REFRESHED"] = str(int(time.time()))
+
+ #         return token
+
+
+ # class VertexAnthropicRequest(APIRequestBase):
+ #     """
+ #     For Claude on Vertex, you'll also have to set the PROJECT_ID environment variable.
+ #     """
+
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         model_name: str,  # must correspond to registry
+ #         prompt: Conversation,
+ #         attempts_left: int,
+ #         status_tracker: StatusTracker,
+ #         retry_queue: asyncio.Queue,
+ #         results_arr: list,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: tqdm | None = None,
+ #         callback: Callable | None = None,
+ #         debug: bool = False,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #         )
+ #         creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
+ #         if not creds:
+ #             raise RuntimeError(
+ #                 "GOOGLE_APPLICATION_CREDENTIALS not provided in environment"
+ #             )
+ #         token = get_access_token(creds)
+
+ #         self.model = APIModel.from_registry(model_name)
+ #         project_id = os.getenv("PROJECT_ID")
+ #         region = self.model.sample_region()
+
+ #         endpoint = f"https://{region}-aiplatform.googleapis.com"
+ # self.url = f"{endpoint}/v1/projects/{project_id}/locations/{region}/publishers/anthropic/models/{self.model.name}:generateContent"
+ #         self.request_header = {
+ #             "Authorization": f"Bearer {token}",
+ #             "Content-Type": "application/json",
+ #         }
+ #         self.system_message, messages = prompt.to_anthropic()
+
+ #         self.request_json = {
+ #             "anthropic_version": "vertex-2023-10-16",
+ #             "messages": messages,
+ #             "temperature": self.sampling_params.temperature,
+ #             "top_p": self.sampling_params.top_p,
+ #             "max_tokens": self.sampling_params.max_new_tokens,
+ #         }
+ #         if self.system_message is not None:
+ #             self.request_json["system"] = self.system_message
+
+ #     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         status_code = http_response.status
+ #         mimetype = http_response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await http_response.json()
+ #                 completion = data["content"][0]["text"]
+ #                 input_tokens = data["usage"]["input_tokens"]
+ #                 output_tokens = data["usage"]["output_tokens"]
+ #             except Exception as e:
+ #                 is_error = True
+ #                 error_message = (
+ #                     f"Error calling .json() on response w/ status {status_code}: {e}"
+ #                 )
+ #         elif "json" in (mimetype or "").lower():
+ #             is_error = True  # expected status is 200, otherwise it's an error
+ #             data = await http_response.json()
+ #             error_message = json.dumps(data)
+
+ #         else:
+ #             is_error = True
+ #             text = await http_response.text()
+ #             error_message = text
+
+ #         # handle special kinds of errors. TODO: make sure these are correct for anthropic
+ #         if is_error and error_message is not None:
+ #             if (
+ #                 "rate limit" in error_message.lower()
+ #                 or "overloaded" in error_message.lower()
+ #                 or status_code == 429
+ #             ):
+ #                 error_message += " (Rate limit error, triggering cooldown.)"
+ #                 self.status_tracker.rate_limit_exceeded()
+ #             if "context length" in error_message:
+ #                 error_message += " (Context length exceeded, set retries to 0.)"
+ #                 self.attempts_left = 0
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #         )
+
+
+ # SAFETY_SETTING_CATEGORIES = [
+ #     "HARM_CATEGORY_DANGEROUS_CONTENT",
+ #     "HARM_CATEGORY_HARASSMENT",
+ #     "HARM_CATEGORY_HATE_SPEECH",
+ #     "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ # ]
+
+
+ # class GeminiRequest(APIRequestBase):
+ #     """
+ #     For Gemini, you'll also have to set the PROJECT_ID environment variable.
+ #     """
+
+ #     def __init__(
+ #         self,
+ #         task_id: int,
+ #         model_name: str,  # must correspond to registry
+ #         prompt: Conversation,
+ #         attempts_left: int,
+ #         status_tracker: StatusTracker,
+ #         retry_queue: asyncio.Queue,
+ #         results_arr: list,
+ #         request_timeout: int = 30,
+ #         sampling_params: SamplingParams = SamplingParams(),
+ #         pbar: tqdm | None = None,
+ #         callback: Callable | None = None,
+ #         debug: bool = False,
+ #         all_model_names: list[str] | None = None,
+ #         all_sampling_params: list[SamplingParams] | None = None,
+ #     ):
+ #         super().__init__(
+ #             task_id=task_id,
+ #             model_name=model_name,
+ #             prompt=prompt,
+ #             attempts_left=attempts_left,
+ #             status_tracker=status_tracker,
+ #             retry_queue=retry_queue,
+ #             results_arr=results_arr,
+ #             request_timeout=request_timeout,
+ #             sampling_params=sampling_params,
+ #             pbar=pbar,
+ #             callback=callback,
+ #             debug=debug,
+ #             all_model_names=all_model_names,
+ #             all_sampling_params=all_sampling_params,
+ #         )
+ #         self.model = APIModel.from_registry(model_name)
+ #         credentials_file = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
+ #         if not credentials_file:
+ #             raise RuntimeError(
+ #                 "no credentials file found. ensure you provide a google credentials file and point to it with GOOGLE_APPLICATION_CREDENTIALS environment variable."
+ #             )
+ #         token = get_access_token(credentials_file)
+ #         self.project_id = os.getenv("PROJECT_ID")
+ #         # sample weighted by region counts
+ #         self.region = self.model.sample_region()
+ #         assert self.region is not None, "unable to sample region"
+ #         self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"
+
+ #         self.request_header = {
+ #             "Authorization": f"Bearer {token}",
+ #             "Content-Type": "application/json",
+ #         }
+ #         self.system_message, contents = prompt.to_gemini()
+ #         self.request_json = {
+ #             "contents": contents,
+ #             "generationConfig": {
+ #                 "stopSequences": [],
+ #                 "temperature": sampling_params.temperature,
+ #                 "maxOutputTokens": sampling_params.max_new_tokens,
+ #                 "topP": sampling_params.top_p,
+ #                 "topK": None,
+ #             },
+ #             "safetySettings": [
+ #                 {"category": category, "threshold": "BLOCK_NONE"}
+ #                 for category in SAFETY_SETTING_CATEGORIES
+ #             ],
+ #         }
+ #         if sampling_params.json_mode and self.model.supports_json:
+ #             self.request_json["generationConfig"]["responseMimeType"] = (
+ #                 "application/json"
+ #             )
+
+ #         if self.system_message is not None:
+ #             self.request_json["systemInstruction"] = (
+ #                 {"role": "SYSTEM", "parts": [{"text": self.system_message}]},
+ #             )
+
+ #     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+ #         is_error = False
+ #         error_message = None
+ #         completion = None
+ #         input_tokens = None
+ #         output_tokens = None
+ #         finish_reason = None
+ #         data = None
+ #         retry_with_different_model = False
+ #         give_up_if_no_other_models = False
+ #         status_code = http_response.status
+ #         mimetype = http_response.headers.get("Content-Type", None)
+ #         if status_code >= 200 and status_code < 300:
+ #             try:
+ #                 data = await http_response.json()
+ #                 if "candidates" not in data:
+ #                     is_error = True
+ #                     if "promptFeedback" in data:
+ #                         error_message = "Prompt rejected. Feedback: " + str(
+ #                             data["promptFeedback"]
+ #                         )
+ #                     else:
+ #                         error_message = "No candidates in response."
+ #                     retry_with_different_model = True
+ #                     give_up_if_no_other_models = True
+ #                 else:
+ #                     candidate = data["candidates"][0]
+ #                     finish_reason = candidate["finishReason"]
+ #                     if "content" in candidate:
+ #                         parts = candidate["content"]["parts"]
+ #                         completion = " ".join([part["text"] for part in parts])
+ #                         usage = data["usageMetadata"]
+ #                         input_tokens = usage["promptTokenCount"]
+ #                         output_tokens = usage["candidatesTokenCount"]
+ #                     elif finish_reason == "RECITATION":
+ #                         is_error = True
+ #                         citations = candidate.get("citationMetadata", {}).get(
+ #                             "citations", []
+ #                         )
+ #                         urls = ",".join(
+ #                             [citation.get("uri", "") for citation in citations]
+ #                         )
+ #                         error_message = "Finish reason RECITATION. URLS: " + urls
+ #                         retry_with_different_model = True
+ #                     elif finish_reason == "OTHER":
+ #                         is_error = True
+ #                         error_message = "Finish reason OTHER."
+ #                         retry_with_different_model = True
+ #                     elif finish_reason == "SAFETY":
+ #                         is_error = True
+ #                         error_message = "Finish reason SAFETY."
+ #                         retry_with_different_model = True
+ #                     else:
+ #                         print("Actual structure of response:", data)
+ #                         is_error = True
+ #                         error_message = "No content in response."
+ #             except Exception as e:
+ #                 is_error = True
+ #                 error_message = f"Error calling .json() on response w/ status {status_code}: {e.__class__} {e}"
+ #                 if isinstance(e, KeyError):
+ #                     print("Actual structure of response:", data)
+ #         elif "json" in (mimetype or "").lower():
+ #             is_error = True
+ #             data = await http_response.json()
+ #             error_message = json.dumps(data)
+ #         else:
+ #             is_error = True
+ #             text = await http_response.text()
+ #             error_message = text
+
+ #         old_region = self.region
+ #         if is_error and error_message is not None:
+ #             if (
+ #                 "rate limit" in error_message.lower()
+ #                 or "temporarily out of capacity" in error_message.lower()
+ #                 or "exceeded" in error_message.lower()
+ #                 or
+ #                 # 429 code
+ #                 status_code == 429
+ #             ):
+ #                 error_message += " (Rate limit error, triggering cooldown & retrying with different model.)"
+ #                 self.status_tracker.rate_limit_exceeded()
+ #                 retry_with_different_model = (
+ #                     True  # if possible, retry with a different model
+ #                 )
+ #         if is_error:
+ #             # change the region in case error is due to region unavailability
+ #             self.region = self.model.sample_region()
+ #             assert self.region is not None, "Unable to sample region"
+ #             self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"
+
+ #         return APIResponse(
+ #             id=self.task_id,
+ #             status_code=status_code,
+ #             is_error=is_error,
+ #             error_message=error_message,
+ #             prompt=self.prompt,
+ #             completion=completion,
+ #             model_internal=self.model_name,
+ #             sampling_params=self.sampling_params,
+ #             input_tokens=input_tokens,
+ #             output_tokens=output_tokens,
+ #             region=old_region,
+ #             finish_reason=finish_reason,
+ #             retry_with_different_model=retry_with_different_model,
+ #             give_up_if_no_other_models=give_up_if_no_other_models,
+ #         )
+
+
+ # # class LlamaEndpointRequest(APIRequestBase):
+ # #     raise NotImplementedError("Llama endpoints are not implemented and never will be because Vertex AI sucks ass.")