lm-deluge 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,145 @@
import asyncio
import warnings
from aiohttp import ClientResponse
import json
import os
import time
from tqdm.auto import tqdm
from typing import Optional, Callable

from .base import APIRequestBase, APIResponse
from ..prompt import Conversation
from ..tracker import StatusTracker
from ..sampling_params import SamplingParams
from ..models import APIModel


class OpenAIRequest(APIRequestBase):
    def __init__(
        self,
        task_id: int,
        # should always be 'role', 'content' keys.
        # internal logic should handle translating to specific API format
        model_name: str,  # must correspond to registry
        prompt: Conversation,
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        results_arr: list,
        request_timeout: int = 30,
        sampling_params: SamplingParams = SamplingParams(),
        logprobs: bool = False,
        top_logprobs: Optional[int] = None,
        pbar: Optional[tqdm] = None,
        callback: Optional[Callable] = None,
        debug: bool = False,
        all_model_names: list[str] | None = None,
        all_sampling_params: list[SamplingParams] | None = None,
    ):
        super().__init__(
            task_id=task_id,
            model_name=model_name,
            prompt=prompt,
            attempts_left=attempts_left,
            status_tracker=status_tracker,
            retry_queue=retry_queue,
            results_arr=results_arr,
            request_timeout=request_timeout,
            sampling_params=sampling_params,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            pbar=pbar,
            callback=callback,
            debug=debug,
            all_model_names=all_model_names,
            all_sampling_params=all_sampling_params,
        )
        self.model = APIModel.from_registry(model_name)
        self.url = f"{self.model.api_base}/chat/completions"
        self.request_header = {
            "Authorization": f"Bearer {os.getenv(self.model.api_key_env_var)}"
        }
        self.request_json = {
            "model": self.model.name,
            "messages": prompt.to_openai(),
            "temperature": sampling_params.temperature,
            "top_p": sampling_params.top_p,
            "max_completion_tokens": sampling_params.max_new_tokens,
        }
        if self.model.reasoning_model:
            self.request_json["temperature"] = 1.0
            self.request_json["top_p"] = 1.0
            self.request_json["reasoning_effort"] = sampling_params.reasoning_effort
        else:
            if sampling_params.reasoning_effort:
                warnings.warn(
                    f"Ignoring reasoning_effort param for non-reasoning model: {model_name}"
                )
        if logprobs:
            self.request_json["logprobs"] = True
            if top_logprobs is not None:
                self.request_json["top_logprobs"] = top_logprobs
        if sampling_params.json_mode and self.model.supports_json:
            self.request_json["response_format"] = {"type": "json_object"}

    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
        is_error = False
        error_message = None
        completion = None
        input_tokens = None
        output_tokens = None
        logprobs = None
        status_code = http_response.status
        mimetype = http_response.headers.get("Content-Type", None)
        data = None
        if status_code >= 200 and status_code < 300:
            try:
                data = await http_response.json()
            except Exception:
                is_error = True
                error_message = (
                    f"Error calling .json() on response w/ status {status_code}"
                )
            if not is_error:
                assert data is not None, "data is None"
                try:
                    completion = data["choices"][0]["message"]["content"]
                    input_tokens = data["usage"]["prompt_tokens"]
                    output_tokens = data["usage"]["completion_tokens"]
                    if self.logprobs and "logprobs" in data["choices"][0]:
                        logprobs = data["choices"][0]["logprobs"]["content"]
                except Exception:
                    is_error = True
                    error_message = f"Error getting 'choices' and 'usage' from {self.model.name} response."
        elif mimetype and "json" in mimetype.lower():
            is_error = True  # expected status is 200, otherwise it's an error
            data = await http_response.json()
            error_message = json.dumps(data)
        else:
            is_error = True
            text = await http_response.text()
            error_message = text

        # handle special kinds of errors
        if is_error and error_message is not None:
            if "rate limit" in error_message.lower() or status_code == 429:
                error_message += " (Rate limit error, triggering cooldown.)"
                self.status_tracker.time_of_last_rate_limit_error = time.time()
                self.status_tracker.num_rate_limit_errors += 1
            if "context length" in error_message:
                error_message += " (Context length exceeded, set retries to 0.)"
                self.attempts_left = 0

        return APIResponse(
            id=self.task_id,
            status_code=status_code,
            is_error=is_error,
            error_message=error_message,
            prompt=self.prompt,
            logprobs=logprobs,
            completion=completion,
            model_internal=self.model_name,
            sampling_params=self.sampling_params,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
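For reference, the payload assembled above follows the OpenAI chat-completions format. Below is a minimal sketch of what request_json ends up looking like for a non-reasoning model; the model name, messages, and sampling values are illustrative, not taken from the package.

# Illustrative sketch only: the approximate shape of request_json built by
# OpenAIRequest for a non-reasoning model. Model name and values are examples.
example_request_json = {
    "model": "gpt-4o-mini",  # self.model.name from the registry
    "messages": [  # produced by prompt.to_openai()
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "temperature": 0.7,  # sampling_params.temperature
    "top_p": 1.0,  # sampling_params.top_p
    "max_completion_tokens": 512,  # sampling_params.max_new_tokens
    # added only when logprobs=True (and optionally top_logprobs):
    # "logprobs": True,
    # "top_logprobs": 5,
    # added only when json_mode is set and the model supports it:
    # "response_format": {"type": "json_object"},
}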
@@ -0,0 +1,365 @@
# consider: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-gemini-using-openai-library#call-chat-completions-api
import asyncio
from aiohttp import ClientResponse
import json
import os
import time
from tqdm import tqdm
from typing import Optional, Callable

from lm_deluge.prompt import Conversation
from .base import APIRequestBase, APIResponse
from ..tracker import StatusTracker
from ..sampling_params import SamplingParams
from ..models import APIModel

from google.oauth2 import service_account
from google.auth.transport.requests import Request


def get_access_token(service_account_file: str):
    """
    Get access token from environment variables if another process/coroutine
    has already got them, otherwise get from service account file.
    """
    LAST_REFRESHED = os.getenv("VERTEX_TOKEN_LAST_REFRESHED", None)
    LAST_REFRESHED = int(LAST_REFRESHED) if LAST_REFRESHED is not None else 0
    VERTEX_API_TOKEN = os.getenv("VERTEX_API_TOKEN", None)

    if VERTEX_API_TOKEN is not None and time.time() - LAST_REFRESHED < 60 * 50:
        return VERTEX_API_TOKEN
    else:
        credentials = service_account.Credentials.from_service_account_file(
            service_account_file,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
        credentials.refresh(Request())
        token = credentials.token
        os.environ["VERTEX_API_TOKEN"] = token
        os.environ["VERTEX_TOKEN_LAST_REFRESHED"] = str(int(time.time()))

        return token


class VertexAnthropicRequest(APIRequestBase):
    """
    For Claude on Vertex, you'll also have to set the PROJECT_ID environment variable.
    """

    def __init__(
        self,
        task_id: int,
        model_name: str,  # must correspond to registry
        prompt: Conversation,
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        results_arr: list,
        request_timeout: int = 30,
        sampling_params: SamplingParams = SamplingParams(),
        pbar: Optional[tqdm] = None,
        callback: Optional[Callable] = None,
        debug: bool = False,
    ):
        super().__init__(
            task_id=task_id,
            model_name=model_name,
            prompt=prompt,
            attempts_left=attempts_left,
            status_tracker=status_tracker,
            retry_queue=retry_queue,
            results_arr=results_arr,
            request_timeout=request_timeout,
            sampling_params=sampling_params,
            pbar=pbar,
            callback=callback,
            debug=debug,
        )
        creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if not creds:
            raise RuntimeError(
                "GOOGLE_APPLICATION_CREDENTIALS not provided in environment"
            )
        token = get_access_token(creds)

        self.model = APIModel.from_registry(model_name)
        project_id = os.getenv("PROJECT_ID")
        region = self.model.sample_region()

        endpoint = f"https://{region}-aiplatform.googleapis.com"
        # note: Claude models on Vertex are invoked via the rawPredict method
        # rather than generateContent (which is used for Google's own models).
        self.url = f"{endpoint}/v1/projects/{project_id}/locations/{region}/publishers/anthropic/models/{self.model.name}:rawPredict"
        self.request_header = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        self.system_message, messages = prompt.to_anthropic()

        self.request_json = {
            "anthropic_version": "vertex-2023-10-16",
            "messages": messages,
            "temperature": self.sampling_params.temperature,
            "top_p": self.sampling_params.top_p,
            "max_tokens": self.sampling_params.max_new_tokens,
        }
        if self.system_message is not None:
            self.request_json["system"] = self.system_message

    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
        is_error = False
        error_message = None
        completion = None
        input_tokens = None
        output_tokens = None
        status_code = http_response.status
        mimetype = http_response.headers.get("Content-Type", None)
        if status_code >= 200 and status_code < 300:
            try:
                data = await http_response.json()
                completion = data["content"][0]["text"]
                input_tokens = data["usage"]["input_tokens"]
                output_tokens = data["usage"]["output_tokens"]
            except Exception as e:
                is_error = True
                error_message = (
                    f"Error calling .json() on response w/ status {status_code}: {e}"
                )
        elif "json" in (mimetype or "").lower():
            is_error = True  # expected status is 200, otherwise it's an error
            data = await http_response.json()
            error_message = json.dumps(data)
        else:
            is_error = True
            text = await http_response.text()
            error_message = text

        # handle special kinds of errors. TODO: make sure these are correct for anthropic
        if is_error and error_message is not None:
            if (
                "rate limit" in error_message.lower()
                or "overloaded" in error_message.lower()
                or status_code == 429
            ):
                error_message += " (Rate limit error, triggering cooldown.)"
                self.status_tracker.time_of_last_rate_limit_error = time.time()
                self.status_tracker.num_rate_limit_errors += 1
            if "context length" in error_message:
                error_message += " (Context length exceeded, set retries to 0.)"
                self.attempts_left = 0

        return APIResponse(
            id=self.task_id,
            status_code=status_code,
            is_error=is_error,
            error_message=error_message,
            prompt=self.prompt,
            completion=completion,
            model_internal=self.model_name,
            sampling_params=self.sampling_params,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )


SAFETY_SETTING_CATEGORIES = [
    "HARM_CATEGORY_DANGEROUS_CONTENT",
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
]


class GeminiRequest(APIRequestBase):
    """
    For Gemini, you'll also have to set the PROJECT_ID environment variable.
    """

    def __init__(
        self,
        task_id: int,
        model_name: str,  # must correspond to registry
        prompt: Conversation,
        attempts_left: int,
        status_tracker: StatusTracker,
        retry_queue: asyncio.Queue,
        results_arr: list,
        request_timeout: int = 30,
        sampling_params: SamplingParams = SamplingParams(),
        pbar: Optional[tqdm] = None,
        callback: Optional[Callable] = None,
        debug: bool = False,
        all_model_names: list[str] | None = None,
        all_sampling_params: list[SamplingParams] | None = None,
    ):
        super().__init__(
            task_id=task_id,
            model_name=model_name,
            prompt=prompt,
            attempts_left=attempts_left,
            status_tracker=status_tracker,
            retry_queue=retry_queue,
            results_arr=results_arr,
            request_timeout=request_timeout,
            sampling_params=sampling_params,
            pbar=pbar,
            callback=callback,
            debug=debug,
            all_model_names=all_model_names,
            all_sampling_params=all_sampling_params,
        )
        self.model = APIModel.from_registry(model_name)
        credentials_file = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
        if not credentials_file:
            raise RuntimeError(
                "no credentials file found. ensure you provide a google credentials file and point to it with GOOGLE_APPLICATION_CREDENTIALS environment variable."
            )
        token = get_access_token(credentials_file)
        self.project_id = os.getenv("PROJECT_ID")
        # sample weighted by region counts
        self.region = self.model.sample_region()
        assert self.region is not None, "unable to sample region"
        self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"

        self.request_header = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        self.system_message, contents = prompt.to_gemini()
        self.request_json = {
            "contents": contents,
            "generationConfig": {
                "stopSequences": [],
                "temperature": sampling_params.temperature,
                "maxOutputTokens": sampling_params.max_new_tokens,
                "topP": sampling_params.top_p,
                "topK": None,
            },
            "safetySettings": [
                {"category": category, "threshold": "BLOCK_NONE"}
                for category in SAFETY_SETTING_CATEGORIES
            ],
        }
        if sampling_params.json_mode and self.model.supports_json:
            self.request_json["generationConfig"]["responseMimeType"] = (
                "application/json"
            )

        if self.system_message is not None:
            self.request_json["systemInstruction"] = {
                "role": "SYSTEM",
                "parts": [{"text": self.system_message}],
            }

    async def handle_response(self, http_response: ClientResponse) -> APIResponse:
        is_error = False
        error_message = None
        completion = None
        input_tokens = None
        output_tokens = None
        finish_reason = None
        data = None
        retry_with_different_model = False
        give_up_if_no_other_models = False
        status_code = http_response.status
        mimetype = http_response.headers.get("Content-Type", None)
        if status_code >= 200 and status_code < 300:
            try:
                data = await http_response.json()
                if "candidates" not in data:
                    is_error = True
                    if "promptFeedback" in data:
                        error_message = "Prompt rejected. Feedback: " + str(
                            data["promptFeedback"]
                        )
                    else:
                        error_message = "No candidates in response."
                    retry_with_different_model = True
                    give_up_if_no_other_models = True
                else:
                    candidate = data["candidates"][0]
                    finish_reason = candidate["finishReason"]
                    if "content" in candidate:
                        parts = candidate["content"]["parts"]
                        completion = " ".join([part["text"] for part in parts])
                        usage = data["usageMetadata"]
                        input_tokens = usage["promptTokenCount"]
                        output_tokens = usage["candidatesTokenCount"]
                    elif finish_reason == "RECITATION":
                        is_error = True
                        citations = candidate.get("citationMetadata", {}).get(
                            "citations", []
                        )
                        urls = ",".join(
                            [citation.get("uri", "") for citation in citations]
                        )
                        error_message = "Finish reason RECITATION. URLS: " + urls
                        retry_with_different_model = True
                    elif finish_reason == "OTHER":
                        is_error = True
                        error_message = "Finish reason OTHER."
                        retry_with_different_model = True
                    elif finish_reason == "SAFETY":
                        is_error = True
                        error_message = "Finish reason SAFETY."
                        retry_with_different_model = True
                    else:
                        print("Actual structure of response:")
                        print(data)
                        is_error = True
                        error_message = "No content in response."
            except Exception as e:
                is_error = True
                error_message = f"Error calling .json() on response w/ status {status_code}: {e.__class__} {e}"
                if isinstance(e, KeyError):
                    print("Actual structure of response:")
                    print(data)
        elif "json" in (mimetype or "").lower():
            is_error = True
            data = await http_response.json()
            error_message = json.dumps(data)
        else:
            is_error = True
            text = await http_response.text()
            error_message = text

        old_region = self.region
        if is_error and error_message is not None:
            if (
                "rate limit" in error_message.lower()
                or "temporarily out of capacity" in error_message.lower()
                or "exceeded" in error_message.lower()
                or status_code == 429  # 429 code
            ):
                error_message += " (Rate limit error, triggering cooldown & retrying with different model.)"
                self.status_tracker.time_of_last_rate_limit_error = time.time()
                self.status_tracker.num_rate_limit_errors += 1
                retry_with_different_model = True  # if possible, retry with a different model
        if is_error:
            # change the region in case error is due to region unavailability
            self.region = self.model.sample_region()
            assert self.region is not None, "Unable to sample region"
            self.url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{self.model.name}:generateContent"

        return APIResponse(
            id=self.task_id,
            status_code=status_code,
            is_error=is_error,
            error_message=error_message,
            prompt=self.prompt,
            completion=completion,
            model_internal=self.model_name,
            sampling_params=self.sampling_params,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            region=old_region,
            finish_reason=finish_reason,
            retry_with_different_model=retry_with_different_model,
            give_up_if_no_other_models=give_up_if_no_other_models,
        )


# class LlamaEndpointRequest(APIRequestBase):
#     raise NotImplementedError("Llama endpoints are not implemented and never will be because Vertex AI sucks ass.")
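The get_access_token helper above caches the Vertex access token in process environment variables so that concurrent request coroutines do not each re-run the service-account refresh flow. A rough usage sketch follows; the credentials path is a placeholder and a valid service-account file is assumed.

# Hypothetical usage of get_access_token; the credentials path is a placeholder.
# The first call refreshes credentials and stores the token in VERTEX_API_TOKEN;
# later calls within roughly 50 minutes return the cached value from the environment.
import os

token_a = get_access_token("/path/to/service-account.json")
token_b = get_access_token("/path/to/service-account.json")  # cached, no refresh
assert token_a == token_b
assert os.environ["VERTEX_API_TOKEN"] == token_b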
lm_deluge/cache.py ADDED
@@ -0,0 +1,144 @@
import tempfile
import json
import sqlite3
from typing import Any
from .prompt import Conversation
from .api_requests.base import APIResponse

try:
    import plyvel  # type: ignore
except ImportError:
    plyvel = None
    print("Warning: plyvel not installed, cannot use LevelDB.")


def encode_api_response(response: APIResponse) -> bytes:
    """
    Encode an API response as a string.
    """
    return json.dumps(response.to_dict()).encode()


def decode_api_response(data: bytes) -> APIResponse:
    """
    Decode an API response from a string.
    """
    return APIResponse.from_dict(json.loads(data.decode()))


class DistributedDictCache:
    """
    Use distributed dictionary (e.g. Modal Dict) as a cache.
    Pass in the dictionary object to use. Cache must implement
    'get' and 'put' methods.
    """

    def __init__(self, cache: Any, cache_key: str = "default"):
        self.cache = cache
        self.cache_key = cache_key  # for namespacing

    def get(self, prompt: Conversation) -> APIResponse | None:
        """
        Get an API response from the cache.
        """
        data = self.cache.get(f"{self.cache_key}:{prompt.fingerprint}")
        if data is not None:
            return decode_api_response(data)
        return None

    def put(self, prompt: Conversation, response: APIResponse) -> None:
        """
        Put an API response into the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        self.cache.put(key, encode_api_response(response))


class LevelDBCache:
    """
    Store API responses based on their input messages.
    """

    def __init__(self, path: str | None = None, cache_key: str = "default"):
        if path is None:
            # LevelDB expects a directory path, and a TemporaryFile does not
            # reliably expose a usable filesystem path, so create a temporary
            # directory instead (it is not removed automatically).
            path = tempfile.mkdtemp(suffix=".db")
            print(f"Using temporary cache at {path}")
        self.path = path
        if plyvel is not None:
            self.db = plyvel.DB(path, create_if_missing=True)
        else:
            raise ImportError("plyvel not installed, cannot use LevelDBCache.")
        self.cache_key = cache_key  # for namespacing

    def get(self, prompt: Conversation) -> APIResponse | None:
        """
        Get an API response from the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        data = self.db.get(key.encode())
        if data is not None:
            return decode_api_response(data)
        return None

    def put(self, prompt: Conversation, response: APIResponse):
        """
        Put an API response into the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        self.db.put(key.encode(), encode_api_response(response))

    def close(self):
        """
        Close the cache.
        """
        self.db.close()


class SqliteCache:
    """
    Same interface as LevelDBCache, but uses SQLite as KV store instead.
    Good to use on systems where LevelDB installation is problematic.
    """

    def __init__(self, path: str, cache_key: str = "default"):
        self.path = path
        self.cache_key = cache_key  # for namespacing
        self.conn = sqlite3.connect(path)
        self.cursor = self.conn.cursor()
        self.cursor.execute(
            "CREATE TABLE IF NOT EXISTS cache (key TEXT PRIMARY KEY, value BLOB)"
        )
        self.conn.commit()

    def get(self, prompt: Conversation) -> APIResponse | None:
        """
        Get an API response from the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        self.cursor.execute("SELECT value FROM cache WHERE key=?", (key,))
        data = self.cursor.fetchone()
        if data is not None and len(data) > 0:
            return decode_api_response(data[0])
        return None

    def put(self, prompt: Conversation, response: APIResponse):
        """
        Put an API response into the cache.
        """
        key = f"{self.cache_key}:{prompt.fingerprint}"
        self.cursor.execute(
            "INSERT OR REPLACE INTO cache (key, value) VALUES (?, ?)",
            (key, encode_api_response(response)),
        )
        self.conn.commit()

    def close(self):
        """
        Close the cache.
        """
        self.conn.close()
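Taken together, the three cache classes share the same get/put interface, keyed on cache_key plus the prompt's fingerprint. Here is a rough usage sketch for SqliteCache; the conversation object and the request function are placeholders, since their real constructors live elsewhere in the package and are not shown in this diff.

# Hypothetical usage sketch: `conversation` and `send_request` are placeholders
# standing in for a real lm_deluge Conversation and an actual API call.
from lm_deluge.cache import SqliteCache

cache = SqliteCache("responses.db", cache_key="my-run")

cached = cache.get(conversation)  # returns None on a cache miss
if cached is None:
    response = send_request(conversation)  # issue the API call (placeholder)
    cache.put(conversation, response)      # stored under "my-run:<fingerprint>"
else:
    response = cached

cache.close()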