lm_deluge-0.0.3-py3-none-any.whl

This version of lm-deluge has been flagged as potentially problematic.

lm_deluge/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .client import LLMClient, SamplingParams, APIResponse
+ import dotenv
+
+ dotenv.load_dotenv()
+
+ __all__ = ["LLMClient", "SamplingParams", "APIResponse"]
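Note: importing the package runs dotenv.load_dotenv() as a side effect, so API keys kept in a local .env file are exported into the environment before any request is built. A minimal sketch of the implication (the ANTHROPIC_API_KEY variable name is illustrative; the real name comes from each model's api_key_env_var in the registry):

    import os
    import lm_deluge  # dotenv.load_dotenv() runs at import time

    # a .env entry like ANTHROPIC_API_KEY=... is now visible via os.getenv
    assert os.getenv("ANTHROPIC_API_KEY") is not None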
@@ -0,0 +1,3 @@
+ from .base import create_api_request
+
+ __all__ = ["create_api_request"]
@@ -0,0 +1,177 @@
+ import asyncio
+ from aiohttp import ClientResponse
+ import json
+ import os
+ import warnings
+ import time
+ from tqdm import tqdm
+ from typing import Optional, Callable
+
+ from lm_deluge.prompt import Conversation
+ from .base import APIRequestBase, APIResponse
+
+ from ..tracker import StatusTracker
+ from ..sampling_params import SamplingParams
+ from ..models import APIModel
+
+
+ class AnthropicRequest(APIRequestBase):
+     def __init__(
+         self,
+         task_id: int,
+         # prompt should always have 'role', 'content' keys;
+         # internal logic handles translating to the specific API format
+         model_name: str,  # must correspond to registry
+         prompt: Conversation,
+         attempts_left: int,
+         status_tracker: StatusTracker,
+         retry_queue: asyncio.Queue,
+         results_arr: list,
+         request_timeout: int = 30,
+         sampling_params: SamplingParams = SamplingParams(),
+         pbar: Optional[tqdm] = None,
+         callback: Optional[Callable] = None,
+         debug: bool = False,
+         # for retries
+         all_model_names: list[str] | None = None,
+         all_sampling_params: list[SamplingParams] | None = None,
+     ):
+         super().__init__(
+             task_id=task_id,
+             model_name=model_name,
+             prompt=prompt,
+             attempts_left=attempts_left,
+             status_tracker=status_tracker,
+             retry_queue=retry_queue,
+             results_arr=results_arr,
+             request_timeout=request_timeout,
+             sampling_params=sampling_params,
+             pbar=pbar,
+             callback=callback,
+             debug=debug,
+             all_model_names=all_model_names,
+             all_sampling_params=all_sampling_params,
+         )
+         self.model = APIModel.from_registry(model_name)
+         self.url = f"{self.model.api_base}/messages"
+
+         self.system_message, messages = prompt.to_anthropic()
+         self.request_header = {
+             "x-api-key": os.getenv(self.model.api_key_env_var),
+             "anthropic-version": "2023-06-01",
+             "content-type": "application/json",
+         }
+
+         self.request_json = {
+             "model": self.model.name,
+             "messages": messages,
+             "temperature": self.sampling_params.temperature,
+             "top_p": self.sampling_params.top_p,
+             "max_tokens": self.sampling_params.max_new_tokens,
+         }
+         # handle thinking
+         if self.model.reasoning_model:
+             if sampling_params.reasoning_effort:
+                 # translate reasoning effort of low, medium, high to budget tokens;
+                 # fall back to the medium budget for unrecognized values so the
+                 # payload never contains budget_tokens=None
+                 budget = {"low": 1024, "medium": 4096, "high": 16384}.get(
+                     sampling_params.reasoning_effort, 4096
+                 )
+                 self.request_json["thinking"] = {
+                     "type": "enabled",
+                     "budget_tokens": budget,
+                 }
+                 # thinking requires temperature=1.0 and no top_p
+                 self.request_json.pop("top_p")
+                 self.request_json["temperature"] = 1.0
+                 # assume max_tokens is max completion tokens, so widen it
+                 # to leave room for the thinking budget
+                 self.request_json["max_tokens"] += budget
+             else:
+                 # no thinking
+                 self.request_json["thinking"] = {"type": "disabled"}
+         else:
+             if sampling_params.reasoning_effort:
+                 warnings.warn(
+                     f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
+                 )
+         if self.system_message is not None:
+             self.request_json["system"] = self.system_message
+
+     async def handle_response(self, http_response: ClientResponse) -> APIResponse:
+         is_error = False
+         error_message = None
+         thinking = None
+         completion = None
+         input_tokens = None
+         output_tokens = None
+         status_code = http_response.status
+         mimetype = http_response.headers.get("Content-Type", None)
+         rate_limits = {}
+         for header in [
+             "anthropic-ratelimit-requests-limit",
+             "anthropic-ratelimit-requests-remaining",
+             "anthropic-ratelimit-requests-reset",
+             "anthropic-ratelimit-tokens-limit",
+             "anthropic-ratelimit-tokens-remaining",
+             "anthropic-ratelimit-tokens-reset",
+         ]:
+             rate_limits[header] = http_response.headers.get(header, None)
+         if self.debug:
+             print(f"Rate limits: {rate_limits}")
+         if 200 <= status_code < 300:
+             try:
+                 data = await http_response.json()
+                 if self.debug:
+                     print("response data:", data)
+                 content = data["content"]
+                 for item in content:
+                     if item["type"] == "text":
+                         completion = item["text"]
+                     elif item["type"] == "thinking":
+                         thinking = item["thinking"]
+                     elif item["type"] == "tool_use":
+                         continue  # TODO: implement and report tool use
+                 input_tokens = data["usage"]["input_tokens"]
+                 output_tokens = data["usage"]["output_tokens"]
+             except Exception as e:
+                 is_error = True
+                 error_message = (
+                     f"Error calling .json() on response w/ status {status_code}: {e}"
+                 )
+         elif mimetype and "json" in mimetype.lower():
+             is_error = True  # expected status is 2xx, otherwise it's an error
+             data = await http_response.json()
+             error_message = json.dumps(data)
+         else:
+             is_error = True
+             text = await http_response.text()
+             error_message = text
+
+         # handle special kinds of errors. TODO: make sure these are correct for Anthropic
+         if is_error and error_message is not None:
+             # match case-insensitively for both checks
+             lowered = error_message.lower()
+             if "rate limit" in lowered or "overloaded" in lowered:
+                 error_message += " (Rate limit error, triggering cooldown.)"
+                 self.status_tracker.time_of_last_rate_limit_error = time.time()
+                 self.status_tracker.num_rate_limit_errors += 1
+             if "context length" in lowered:
+                 error_message += " (Context length exceeded, set retries to 0.)"
+                 self.attempts_left = 0
+
+         return APIResponse(
+             id=self.task_id,
+             status_code=status_code,
+             is_error=is_error,
+             error_message=error_message,
+             prompt=self.prompt,
+             completion=completion,
+             thinking=thinking,
+             model_internal=self.model_name,
+             sampling_params=self.sampling_params,
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+         )
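The thinking branch above maps a coarse reasoning_effort of low/medium/high onto Anthropic budget_tokens of 1024/4096/16384, forces temperature to 1.0, drops top_p, and widens max_tokens so the thinking budget does not eat into the completion budget. A worked sketch of the resulting payload for reasoning_effort="medium" with max_new_tokens=1024 (model name illustrative):

    request_json = {
        "model": "claude-example",  # illustrative; the real name comes from the registry
        "messages": [...],
        "temperature": 1.0,          # forced when thinking is enabled
        "max_tokens": 1024 + 4096,   # completion budget + thinking budget
        "thinking": {"type": "enabled", "budget_tokens": 4096},
    }
    # note that top_p is absent: it is popped from the payload in this branch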
@@ -0,0 +1,375 @@
+ import aiohttp
+ import asyncio
+ import json
+ import random
+ from tqdm import tqdm
+ from dataclasses import dataclass
+ from abc import ABC, abstractmethod
+ from typing import Optional, Callable
+
+ from lm_deluge.prompt import Conversation
+
+ from ..tracker import StatusTracker
+ from ..sampling_params import SamplingParams
+ from ..models import APIModel
+ from ..errors import raise_if_modal_exception
+ from aiohttp import ClientResponse
+
+
+ @dataclass
+ class APIResponse:
+     # request information
+     id: int  # should be unique to the request within a given prompt-processing call
+     model_internal: str  # our internal model tag
+     prompt: Conversation
+     sampling_params: SamplingParams
+
+     # http response information
+     status_code: int | None
+     is_error: Optional[bool]
+     error_message: Optional[str]
+
+     # completion information
+     completion: Optional[str]
+     input_tokens: Optional[int]
+     output_tokens: Optional[int]
+
+     # optional or calculated automatically
+     thinking: Optional[str] = None  # if model shows thinking tokens
+     model_external: Optional[str] = None  # the model tag used by the API
+     region: Optional[str] = None
+     logprobs: Optional[list] = None
+     finish_reason: Optional[str] = None  # make required later
+     cost: Optional[float] = None  # calculated automatically
+     # set to true if is_error and should be retried with a different model
+     retry_with_different_model: Optional[bool] = False
+     # set to true if should NOT retry with the same model (unrecoverable error)
+     give_up_if_no_other_models: Optional[bool] = False
+
+     def __post_init__(self):
+         # calculate cost & get external model name
+         self.id = int(self.id)
+         api_model = APIModel.from_registry(self.model_internal)
+         self.model_external = api_model.name
+         self.cost = None
+         if (
+             self.input_tokens is not None
+             and self.output_tokens is not None
+             and api_model.input_cost is not None
+             and api_model.output_cost is not None
+         ):
+             # registry costs are per-million-token rates
+             self.cost = (
+                 self.input_tokens * api_model.input_cost / 1e6
+                 + self.output_tokens * api_model.output_cost / 1e6
+             )
+         elif self.completion is not None:
+             print(
+                 f"Warning: Completion provided without token counts for model {self.model_internal}."
+             )
+
70
+ def to_dict(self):
71
+ return {
72
+ "id": self.id,
73
+ "model_internal": self.model_internal,
74
+ "model_external": self.model_external,
75
+ "region": self.region,
76
+ "prompt": self.prompt.to_log(), # destroys image if present
77
+ "sampling_params": self.sampling_params.__dict__,
78
+ "status_code": self.status_code,
79
+ "is_error": self.is_error,
80
+ "error_message": self.error_message,
81
+ "completion": self.completion,
82
+ "input_tokens": self.input_tokens,
83
+ "output_tokens": self.output_tokens,
84
+ "finish_reason": self.finish_reason,
85
+ "cost": self.cost,
86
+ }
87
+
88
+ @classmethod
89
+ def from_dict(cls, data: dict):
90
+ return cls(
91
+ id=data.get("id", random.randint(0, 1_000_000_000)),
92
+ model_internal=data["model_internal"],
93
+ model_external=data["model_external"],
94
+ region=data["region"],
95
+ prompt=Conversation.from_log(data["prompt"]),
96
+ sampling_params=SamplingParams(**data["sampling_params"]),
97
+ status_code=data["status_code"],
98
+ is_error=data["is_error"],
99
+ error_message=data["error_message"],
100
+ input_tokens=data["input_tokens"],
101
+ output_tokens=data["output_tokens"],
102
+ completion=data["completion"],
103
+ finish_reason=data["finish_reason"],
104
+ cost=data["cost"],
105
+ )
106
+
107
+ def write_to_file(self, filename):
108
+ """
109
+ Writes the APIResponse as a line to a file.
110
+ If file exists, appends to it.
111
+ """
112
+ with open(filename, "a") as f:
113
+ f.write(json.dumps(self.to_dict()) + "\n")
114
+
115
+
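For reference, the cost computed in __post_init__ treats input_cost and output_cost as per-million-token rates. A quick worked example with made-up rates (not real registry values):

    # hypothetical rates: 3.00 per 1M input tokens, 15.00 per 1M output tokens
    input_tokens, output_tokens = 1_200, 350
    cost = input_tokens * 3.00 / 1e6 + output_tokens * 15.00 / 1e6
    # = 0.0036 + 0.00525 = 0.00885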
+ class APIRequestBase(ABC):
+     """
+     Class for handling API requests. All model/endpoint-specific logic should be
+     handled by overriding __init__ and implementing the handle_response method.
+     For call_api to work, the __init__ must handle setting:
+         - url
+         - request_header
+         - request_json
+     """
+
+     def __init__(
+         self,
+         task_id: int,
+         # prompt should always have 'role', 'content' keys;
+         # internal logic handles translating to the specific API format
+         model_name: str,  # must correspond to registry
+         prompt: Conversation,
+         attempts_left: int,
+         status_tracker: StatusTracker,
+         retry_queue: asyncio.Queue,
+         # needed in order to retry with a different model and not throw the output away
+         results_arr: list["APIRequestBase"],
+         request_timeout: int = 30,
+         sampling_params: SamplingParams = SamplingParams(),
+         logprobs: bool = False,
+         top_logprobs: Optional[int] = None,
+         pbar: Optional[tqdm] = None,
+         callback: Optional[Callable] = None,
+         debug: bool = False,
+         all_model_names: list[str] | None = None,
+         all_sampling_params: list[SamplingParams] | None = None,
+     ):
+         if all_model_names is None:
+             raise ValueError("all_model_names must be provided.")
+         self.task_id = task_id
+         self.model_name = model_name
+         self.system_prompt = None
+         self.prompt = prompt
+         self.attempts_left = attempts_left
+         self.status_tracker = status_tracker
+         self.retry_queue = retry_queue
+         self.request_timeout = request_timeout
+         self.sampling_params = sampling_params
+         self.logprobs = logprobs  # len(completion) logprobs
+         self.top_logprobs = top_logprobs
+         self.pbar = pbar
+         self.callback = callback
+         self.num_tokens = prompt.count_tokens(sampling_params.max_new_tokens)
+         self.results_arr = results_arr
+         self.debug = debug
+         self.all_model_names = all_model_names
+         self.all_sampling_params = all_sampling_params
+         self.result = []  # list of APIResponse objects from each attempt
+
+         # these should be set in the __init__ of the subclass
+         self.url = None
+         self.request_header = None
+         self.request_json = None
+         self.region = None
+
+     def increment_pbar(self):
+         if self.pbar is not None:
+             self.pbar.update(1)
+
+     def call_callback(self):
+         if self.callback is not None:
+             # the APIResponse in self.result includes all the information
+             self.callback(self.result[-1], self.status_tracker)
+
+     def handle_success(self, data):
+         self.call_callback()
+         self.increment_pbar()
+         self.status_tracker.num_tasks_in_progress -= 1
+         self.status_tracker.num_tasks_succeeded += 1
+
+     def handle_error(self, create_new_request=False, give_up_if_no_other_models=False):
+         """
+         If create_new_request is True, creates a new API request (so that it
+         has a chance of being sent to a different model). If False, retries
+         the same request.
+         """
+         last_result: APIResponse = self.result[-1]
+         error_to_print = f"Error task {self.task_id}. "
+         error_to_print += (
+             f"Model: {last_result.model_internal} Code: {last_result.status_code}, "
+         )
+         if self.region is not None:
+             error_to_print += f"Region: {self.region}, "
+         error_to_print += f"Message: {last_result.error_message}."
+         print(error_to_print)
+         if self.attempts_left > 0:
+             self.attempts_left -= 1
+             if not create_new_request:
+                 self.retry_queue.put_nowait(self)
+                 return
+             else:
+                 # make sure we have another model to send it to besides the current one
+                 if self.all_model_names is None or len(self.all_model_names) < 2:
+                     if give_up_if_no_other_models:
+                         print(
+                             f"No other models to try for task {self.task_id}. Giving up."
+                         )
+                         self.status_tracker.num_tasks_in_progress -= 1
+                         self.status_tracker.num_tasks_failed += 1
+                     else:
+                         print(
+                             f"No other models to try for task {self.task_id}. Retrying with same model."
+                         )
+                         self.retry_queue.put_nowait(self)
+                 else:
+                     # two things to change: model_name and sampling_params
+                     new_model_name = self.model_name
+                     new_model_idx = 0
+                     while new_model_name == self.model_name:
+                         new_model_idx = random.randint(0, len(self.all_model_names) - 1)
+                         new_model_name = self.all_model_names[new_model_idx]
+
+                     if isinstance(self.all_sampling_params, list):
+                         new_sampling_params = self.all_sampling_params[new_model_idx]
+                     elif isinstance(self.all_sampling_params, SamplingParams):
+                         new_sampling_params = self.all_sampling_params
+                     else:
+                         new_sampling_params = self.sampling_params
+
+                     print("Creating new request with model", new_model_name)
+                     new_request = create_api_request(
+                         task_id=self.task_id,
+                         model_name=new_model_name,
+                         prompt=self.prompt,
+                         attempts_left=self.attempts_left,
+                         status_tracker=self.status_tracker,
+                         retry_queue=self.retry_queue,
+                         results_arr=self.results_arr,
+                         request_timeout=self.request_timeout,
+                         sampling_params=new_sampling_params,
+                         logprobs=self.logprobs,
+                         top_logprobs=self.top_logprobs,
+                         pbar=self.pbar,
+                         callback=self.callback,
+                         all_model_names=self.all_model_names,
+                         all_sampling_params=self.all_sampling_params,
+                     )
+                     self.retry_queue.put_nowait(new_request)
+                     # the new request must also go into results_arr so its result
+                     # isn't lost; duplicates are deduplicated by task_id later
+                     self.results_arr.append(new_request)
+         else:
+             print(f"Task {self.task_id} out of tries.")
+             self.status_tracker.num_tasks_in_progress -= 1
+             self.status_tracker.num_tasks_failed += 1
+
+
269
+ async def call_api(self):
270
+ try:
271
+ self.status_tracker.total_requests += 1
272
+ timeout = aiohttp.ClientTimeout(total=self.request_timeout)
273
+ async with aiohttp.ClientSession(timeout=timeout) as session:
274
+ assert self.url is not None, "URL is not set"
275
+ async with session.post(
276
+ url=self.url,
277
+ headers=self.request_header,
278
+ json=self.request_json,
279
+ ) as http_response:
280
+ response: APIResponse = await self.handle_response(http_response)
281
+
282
+ self.result.append(response)
283
+ if response.is_error:
284
+ self.handle_error(
285
+ create_new_request=response.retry_with_different_model or False,
286
+ give_up_if_no_other_models=response.give_up_if_no_other_models
287
+ or False,
288
+ )
289
+ else:
290
+ self.handle_success(response)
291
+
292
+ except asyncio.TimeoutError:
293
+ self.result.append(
294
+ APIResponse(
295
+ id=self.task_id,
296
+ model_internal=self.model_name,
297
+ prompt=self.prompt,
298
+ sampling_params=self.sampling_params,
299
+ status_code=None,
300
+ is_error=True,
301
+ error_message="Request timed out (terminated by client).",
302
+ completion=None,
303
+ input_tokens=None,
304
+ output_tokens=None,
305
+ )
306
+ )
307
+ self.handle_error(create_new_request=False)
308
+
309
+ except Exception as e:
310
+ raise_if_modal_exception(e)
311
+ # print(f"Unexpected error {type(e).__name__}: {str(e) or 'No message.'}")
312
+ self.result.append(
313
+ APIResponse(
314
+ id=self.task_id,
315
+ model_internal=self.model_name,
316
+ prompt=self.prompt,
317
+ sampling_params=self.sampling_params,
318
+ status_code=None,
319
+ is_error=True,
320
+ error_message=f"Unexpected {type(e).__name__}: {str(e) or 'No message.'}",
321
+ completion=None,
322
+ input_tokens=None,
323
+ output_tokens=None,
324
+ )
325
+ )
326
+ # maybe consider making True?
327
+ self.handle_error(create_new_request=False)
328
+
329
+ @abstractmethod
330
+ async def handle_response(self, http_response: ClientResponse) -> APIResponse:
331
+ raise NotImplementedError
332
+
333
+
+ def create_api_request(
+     task_id: int,
+     model_name: str,
+     prompt: Conversation,
+     attempts_left: int,
+     status_tracker: StatusTracker,
+     retry_queue: asyncio.Queue,
+     results_arr: list["APIRequestBase"],
+     request_timeout: int = 30,
+     sampling_params: SamplingParams = SamplingParams(),
+     logprobs: bool = False,
+     top_logprobs: Optional[int] = None,
+     pbar: Optional[tqdm] = None,
+     callback: Optional[Callable] = None,
+     all_model_names: list[str] | None = None,
+     all_sampling_params: list[SamplingParams] | None = None,
+ ) -> APIRequestBase:
+     # lazy import to avoid a circular dependency with .common
+     from .common import CLASSES
+
+     model_obj = APIModel.from_registry(model_name)
+     request_class = CLASSES.get(model_obj.api_spec, None)
+     if request_class is None:
+         raise ValueError(f"Unsupported API spec: {model_obj.api_spec}")
+     # only pass logprobs kwargs when requested, since not every request class accepts them
+     kwargs = (
+         {} if not logprobs else {"logprobs": logprobs, "top_logprobs": top_logprobs}
+     )
+     return request_class(
+         task_id=task_id,
+         model_name=model_name,
+         prompt=prompt,
+         attempts_left=attempts_left,
+         status_tracker=status_tracker,
+         retry_queue=retry_queue,
+         results_arr=results_arr,
+         request_timeout=request_timeout,
+         sampling_params=sampling_params,
+         pbar=pbar,
+         callback=callback,
+         all_model_names=all_model_names,
+         all_sampling_params=all_sampling_params,
+         **kwargs,
+     )
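As the APIRequestBase docstring spells out, a provider integration only has to set url, request_header, and request_json in __init__ and implement handle_response; call_api, retries, and bookkeeping come from the base class. A minimal sketch of such a subclass (the endpoint and payload shape are placeholders, not a real provider):

    class EchoRequest(APIRequestBase):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.url = "https://example.invalid/v1/complete"  # placeholder endpoint
            self.request_header = {"content-type": "application/json"}
            self.request_json = {"prompt": str(self.prompt)}

        async def handle_response(self, http_response: ClientResponse) -> APIResponse:
            # a real implementation would branch on status codes like AnthropicRequest does
            data = await http_response.json()
            return APIResponse(
                id=self.task_id,
                model_internal=self.model_name,
                prompt=self.prompt,
                sampling_params=self.sampling_params,
                status_code=http_response.status,
                is_error=False,
                error_message=None,
                completion=data.get("completion"),
                input_tokens=data.get("input_tokens"),
                output_tokens=data.get("output_tokens"),
            )

Registering such a class under an api_spec key in the CLASSES dict in .common is what lets create_api_request dispatch to it.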