edsl 0.1.53__py3-none-any.whl → 0.1.55__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +8 -1
- edsl/__init__original.py +134 -0
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +29 -0
- edsl/agents/agent_list.py +36 -1
- edsl/base/base_class.py +281 -151
- edsl/buckets/__init__.py +8 -3
- edsl/buckets/bucket_collection.py +9 -3
- edsl/buckets/model_buckets.py +4 -2
- edsl/buckets/token_bucket.py +2 -2
- edsl/buckets/token_bucket_client.py +5 -3
- edsl/caching/cache.py +131 -62
- edsl/caching/cache_entry.py +70 -58
- edsl/caching/sql_dict.py +17 -0
- edsl/cli.py +99 -0
- edsl/config/config_class.py +16 -0
- edsl/conversation/__init__.py +31 -0
- edsl/coop/coop.py +276 -242
- edsl/coop/coop_jobs_objects.py +59 -0
- edsl/coop/coop_objects.py +29 -0
- edsl/coop/coop_regular_objects.py +26 -0
- edsl/coop/utils.py +24 -19
- edsl/dataset/dataset.py +338 -101
- edsl/db_list/sqlite_list.py +349 -0
- edsl/inference_services/__init__.py +40 -5
- edsl/inference_services/exceptions.py +11 -0
- edsl/inference_services/services/anthropic_service.py +5 -2
- edsl/inference_services/services/aws_bedrock.py +6 -2
- edsl/inference_services/services/azure_ai.py +6 -2
- edsl/inference_services/services/google_service.py +3 -2
- edsl/inference_services/services/mistral_ai_service.py +6 -2
- edsl/inference_services/services/open_ai_service.py +6 -2
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +105 -7
- edsl/interviews/answering_function.py +167 -59
- edsl/interviews/interview.py +124 -72
- edsl/interviews/interview_task_manager.py +10 -0
- edsl/invigilators/invigilators.py +10 -1
- edsl/jobs/async_interview_runner.py +146 -104
- edsl/jobs/data_structures.py +6 -4
- edsl/jobs/decorators.py +61 -0
- edsl/jobs/fetch_invigilator.py +61 -18
- edsl/jobs/html_table_job_logger.py +14 -2
- edsl/jobs/jobs.py +180 -104
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_interview_constructor.py +2 -0
- edsl/jobs/jobs_pricing_estimation.py +127 -46
- edsl/jobs/jobs_remote_inference_logger.py +4 -0
- edsl/jobs/jobs_runner_status.py +30 -25
- edsl/jobs/progress_bar_manager.py +79 -0
- edsl/jobs/remote_inference.py +35 -1
- edsl/key_management/key_lookup_builder.py +6 -1
- edsl/language_models/language_model.py +102 -12
- edsl/language_models/model.py +10 -3
- edsl/language_models/price_manager.py +45 -75
- edsl/language_models/registry.py +5 -0
- edsl/language_models/utilities.py +2 -1
- edsl/notebooks/notebook.py +77 -10
- edsl/questions/VALIDATION_README.md +134 -0
- edsl/questions/__init__.py +24 -1
- edsl/questions/exceptions.py +21 -0
- edsl/questions/question_check_box.py +171 -149
- edsl/questions/question_dict.py +243 -51
- edsl/questions/question_multiple_choice_with_other.py +624 -0
- edsl/questions/question_registry.py +2 -1
- edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
- edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
- edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
- edsl/questions/validation_analysis.py +185 -0
- edsl/questions/validation_cli.py +131 -0
- edsl/questions/validation_html_report.py +404 -0
- edsl/questions/validation_logger.py +136 -0
- edsl/results/result.py +63 -16
- edsl/results/results.py +702 -171
- edsl/scenarios/construct_download_link.py +16 -3
- edsl/scenarios/directory_scanner.py +226 -226
- edsl/scenarios/file_methods.py +5 -0
- edsl/scenarios/file_store.py +117 -6
- edsl/scenarios/handlers/__init__.py +5 -1
- edsl/scenarios/handlers/mp4_file_store.py +104 -0
- edsl/scenarios/handlers/webm_file_store.py +104 -0
- edsl/scenarios/scenario.py +120 -101
- edsl/scenarios/scenario_list.py +800 -727
- edsl/scenarios/scenario_list_gc_test.py +146 -0
- edsl/scenarios/scenario_list_memory_test.py +214 -0
- edsl/scenarios/scenario_list_source_refactor.md +35 -0
- edsl/scenarios/scenario_selector.py +5 -4
- edsl/scenarios/scenario_source.py +1990 -0
- edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
- edsl/surveys/survey.py +22 -0
- edsl/tasks/__init__.py +4 -2
- edsl/tasks/task_history.py +198 -36
- edsl/tests/scenarios/test_ScenarioSource.py +51 -0
- edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
- edsl/utilities/__init__.py +2 -1
- edsl/utilities/decorators.py +121 -0
- edsl/utilities/memory_debugger.py +1010 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/METADATA +52 -76
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/RECORD +102 -78
- edsl/jobs/jobs_runner_asyncio.py +0 -281
- edsl/language_models/unused/fake_openai_service.py +0 -60
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
edsl/db_list/sqlite_list.py
@@ -0,0 +1,349 @@
+import sqlite3
+import tempfile
+import os
+import json
+from typing import Any, Callable, Iterable, Iterator, List, Optional
+from abc import ABC, abstractmethod
+from collections.abc import MutableSequence
+
+
+class SQLiteList(MutableSequence, ABC):
+    """
+    An abstract base class for a MutableSequence that stores its data in a temporary SQLite file.
+    The file is removed when close() is called.
+    Subclasses must implement serialize and deserialize methods.
+    """
+
+    _TABLE_NAME = "list_data"  # Class constant instead of instance parameter
+
+    @abstractmethod
+    def serialize(self, value: Any) -> str:
+        """Convert a value to a string for storage in SQLite."""
+        pass
+
+    @abstractmethod
+    def deserialize(self, value: str) -> Any:
+        """Convert a stored string back to its original value."""
+        pass
+
+    def __init__(self, data=None):
+        # Create a temporary file for our SQLite database
+        tmpfile = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
+        self.db_path = tmpfile.name
+        # Close the file handle immediately; SQLite only needs the path
+        tmpfile.close()
+
+        self.conn = sqlite3.connect(self.db_path)
+        self._create_table_if_not_exists()
+
+        # Initialize with data if provided
+        if data is not None:
+            self._batch_insert(data)
+
+    def _create_table_if_not_exists(self):
+        query = f"CREATE TABLE IF NOT EXISTS {self._TABLE_NAME} (idx INTEGER UNIQUE, value BLOB)"
+        with self.conn:
+            self.conn.execute(query)
+            # Create an index for faster lookups
+            self.conn.execute(f"CREATE INDEX IF NOT EXISTS idx_index ON {self._TABLE_NAME} (idx)")
+
+    def _batch_insert(self, data: Iterable) -> None:
+        """
+        Insert items one at a time to minimize memory usage.
+
+        Args:
+            data: Iterable containing items to insert
+        """
+        with self.conn:
+            # Use a single transaction for better performance
+            for idx, item in enumerate(data):
+                # Serialize and insert one item at a time to minimize memory usage
+                serialized = self.serialize(item)
+                self.conn.execute(
+                    f"INSERT INTO {self._TABLE_NAME} (idx, value) VALUES (?, ?)",
+                    (idx, serialized)
+                )
+                # Clear reference to allow garbage collection
+                del serialized
+
+    def __len__(self):
+        cursor = self.conn.execute(f"SELECT COUNT(*) FROM {self._TABLE_NAME}")
+        (count,) = cursor.fetchone()
+        return count
+
+    def __getitem__(self, index):
+        if isinstance(index, slice):
+            # Handle slice object
+            start, stop, step = index.indices(len(self))
+            if step == 1:  # Simple range
+                cursor = self.conn.execute(
+                    f"SELECT value FROM {self._TABLE_NAME} WHERE idx >= ? AND idx < ? ORDER BY idx",
+                    (start, stop)
+                )
+                return [self.deserialize(row[0]) for row in cursor]
+            else:  # Need to handle step
+                indices = range(start, stop, step)
+                return [self[i] for i in indices]
+
+        # Handle integer index
+        if index < 0:
+            index = len(self) + index
+        if not 0 <= index < len(self):
+            raise IndexError("list index out of range")
+
+        cursor = self.conn.execute(
+            f"SELECT value FROM {self._TABLE_NAME} WHERE idx=?", (index,)
+        )
+        row = cursor.fetchone()
+        if row is None:
+            raise IndexError("list index out of range")
+        return self.deserialize(row[0])
+
+    def __setitem__(self, index, value):
+        if index < 0:
+            index = len(self) + index
+        if not 0 <= index < len(self):
+            raise IndexError("list assignment index out of range")
+
+        serialized = self.serialize(value)
+        with self.conn:
+            self.conn.execute(
+                f"UPDATE {self._TABLE_NAME} SET value=? WHERE idx=?",
+                (serialized, index),
+            )
+
+    def __delitem__(self, index):
+        if index < 0:
+            index = len(self) + index
+        if not 0 <= index < len(self):
+            raise IndexError("list assignment index out of range")
+
+        with self.conn:
+            self.conn.execute(f"DELETE FROM {self._TABLE_NAME} WHERE idx=?", (index,))
+            self.conn.execute(
+                f"UPDATE {self._TABLE_NAME} SET idx = idx - 1 WHERE idx > ?", (index,)
+            )
+
+    def insert(self, index, value):
+        """
+        Inserts a value at the given index by shifting everything
+        at or after `index` up by one in descending order.
+        """
+        length = len(self)
+        if index < 0:
+            index = 0
+        if index > length:
+            index = length
+
+        serialized = self.serialize(value)
+        with self.conn:
+            # Shift every idx >= `index` up by 1, in descending order
+            for i in reversed(range(index, length)):
+                self.conn.execute(
+                    f"UPDATE {self._TABLE_NAME} SET idx = ? WHERE idx = ?",
+                    (i + 1, i),
+                )
+
+            # Now insert the new item
+            self.conn.execute(
+                f"INSERT INTO {self._TABLE_NAME} (idx, value) VALUES (?, ?)",
+                (index, serialized),
+            )
+
+    def append(self, value):
+        """Append a value to the end of the list."""
+        index = len(self)
+        serialized = self.serialize(value)
+        with self.conn:
+            self.conn.execute(
+                f"INSERT INTO {self._TABLE_NAME} (idx, value) VALUES (?, ?)",
+                (index, serialized),
+            )
+
+    def extend(self, values: Iterable) -> None:
+        """
+        Extend the list by appending all items in the given iterable.
+
+        Processes one item at a time to minimize memory usage.
+
+        Args:
+            values: Iterable of values to append
+        """
+        start_idx = len(self)
+        with self.conn:
+            # Use a single transaction for efficiency
+            for i, item in enumerate(values):
+                # Serialize and insert one at a time to minimize memory usage
+                serialized = self.serialize(item)
+                self.conn.execute(
+                    f"INSERT INTO {self._TABLE_NAME} (idx, value) VALUES (?, ?)",
+                    (start_idx + i, serialized)
+                )
+                # Clear reference to allow garbage collection
+                del serialized
+
+    def close(self):
+        """
+        Close the database connection and remove the temporary file.
+        """
+        self.conn.close()
+        if os.path.exists(self.db_path):
+            os.remove(self.db_path)
+
+    def __repr__(self):
+        num_items = len(self)
+        preview_count = min(num_items, 10)
+        items = [self[i] for i in range(preview_count)]
+        if preview_count < num_items:
+            return f"{items}... (total {num_items} items)"
+        else:
+            return str(items)
+
+    def __add__(self, other):
+        """
+        Concatenates two SQLiteLists and returns a new SQLiteList containing all elements.
+        Use memory-efficient copy operation.
+        """
+        if not isinstance(other, SQLiteList):
+            raise TypeError(
+                f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
+            )
+
+        # Create a new instance of the same class
+        result = type(self)()
+
+        # Use stream to copy all items from self
+        result.extend(self.stream())
+
+        # Use stream to copy all items from other
+        result.extend(other.stream())
+
+        return result
+
+    def stream(self) -> Iterator[Any]:
+        """Stream items from the database without loading everything into memory."""
+        cursor = self.conn.execute(f"SELECT value FROM {self._TABLE_NAME} ORDER BY idx")
+        for row in cursor:
+            yield self.deserialize(row[0])
+
+    def stream_batched(self, batch_size: int = 1000) -> Iterator[List[Any]]:
+        """
+        Stream items in batches to reduce memory usage and improve performance.
+
+        Args:
+            batch_size: Number of items to yield in each batch
+
+        Yields:
+            Lists of deserialized items, with at most batch_size items per list
+        """
+        cursor = self.conn.execute(f"SELECT value FROM {self._TABLE_NAME} ORDER BY idx")
+        batch = []
+
+        for row in cursor:
+            batch.append(self.deserialize(row[0]))
+            if len(batch) >= batch_size:
+                yield batch
+                batch = []
+
+        if batch:  # Don't forget the last batch if it's not full
+            yield batch
+
+    def __iter__(self):
+        """Iterate over items using streaming."""
+        return self.stream()
+
+    def equals(self, other):
+        """Memory-efficient comparison of two SQLiteLists."""
+        if len(self) != len(other):
+            return False
+
+        # Compare in batches to reduce memory usage
+        batch_size = 1000
+        self_batches = self.stream_batched(batch_size)
+        other_batches = other.stream_batched(batch_size) if hasattr(other, 'stream_batched') else None
+
+        if other_batches:
+            # Both objects support batched streaming
+            for self_batch, other_batch in zip(self_batches, other_batches):
+                if len(self_batch) != len(other_batch):
+                    return False
+                for i in range(len(self_batch)):
+                    if self_batch[i] != other_batch[i]:
+                        return False
+            return True
+        else:
+            # Fall back to item-by-item comparison
+            for i in range(len(self)):
+                if self[i] != other[i]:
+                    return False
+            return True
+
+    def __eq__(self, other):
+        """Use memory-efficient comparison by default."""
+        return self.equals(other)
+
+    def copy_from(self, source_db_path: str) -> None:
+        """Copy data from another SQLite database file.
+
+        Args:
+            source_db_path: Path to the source SQLite database file
+
+        Raises:
+            sqlite3.Error: If there's an error accessing the source database
+        """
+        import sqlite3
+        import time
+        import shutil
+        import tempfile
+
+        # Make a temporary copy of the source database to avoid locking issues
+        with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_file:
+            temp_db_path = temp_file.name
+
+        try:
+            # Copy the source database to a temporary file
+            shutil.copy2(source_db_path, temp_db_path)
+
+            # Connect to the copied database
+            source_conn = sqlite3.connect(temp_db_path)
+            source_cursor = source_conn.cursor()
+
+            try:
+                # Check if the table exists in the source database
+                source_cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{self._TABLE_NAME}'")
+                if not source_cursor.fetchone():
+                    return  # Table doesn't exist in source, nothing to copy
+
+                # Get data from source database
+                source_cursor.execute(f"SELECT idx, value FROM {self._TABLE_NAME} ORDER BY idx")
+                rows = source_cursor.fetchall()
+
+                # Empty the current database
+                with self.conn:
+                    self.conn.execute(f"DELETE FROM {self._TABLE_NAME}")
+
+                # Insert data into the destination database
+                with self.conn:
+                    self.conn.executemany(
+                        f"INSERT INTO {self._TABLE_NAME} (idx, value) VALUES (?, ?)",
+                        rows
+                    )
+            finally:
+                source_cursor.close()
+                source_conn.close()
+        finally:
+            # Clean up the temporary file
+            import os
+            if os.path.exists(temp_db_path):
+                try:
+                    os.unlink(temp_db_path)
+                except:
+                    pass
+
+    def __del__(self):
+        """Clean up the temporary file when the object is deleted."""
+        try:
+            self.conn.close()
+            os.unlink(self.db_path)
+        except:
+            pass
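SQLiteList is abstract: the MutableSequence plumbing is provided above, and a subclass only supplies serialize/deserialize. An illustrative sketch of a concrete JSON-backed subclass (JSONSQLiteList is a hypothetical name, not part of this diff; the import path follows the new edsl/db_list/sqlite_list.py module):

# Illustrative sketch only (not part of the package): a concrete subclass
# that stores JSON-serializable items.
import json
from edsl.db_list.sqlite_list import SQLiteList

class JSONSQLiteList(SQLiteList):
    def serialize(self, value):
        # Each item is stored as its JSON text representation
        return json.dumps(value)

    def deserialize(self, value):
        # Recover the original Python object from the stored text
        return json.loads(value)

items = JSONSQLiteList(["a", "b"])
items.append({"c": 3})
assert items[2] == {"c": 3}
items.close()  # removes the temporary .db file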
edsl/inference_services/__init__.py
@@ -1,13 +1,48 @@
-
-from .
+# Import only the exceptions first to avoid circular imports
+from .exceptions import InferenceServiceError
+
+# Import the classes directly here for external use
+# While avoiding circular imports with other modules
 from .inference_service_abc import InferenceServiceABC
 from .data_structures import AvailableModels
-from .
+from .inference_services_collection import InferenceServicesCollection
 
+# Define __all__ without default - we'll add it later
 __all__ = [
     "InferenceServicesCollection",
-    "default",
     "InferenceServiceABC",
     "AvailableModels",
     "InferenceServiceError"
-]
+]
+
+# Better approach: import the default instance immediately in a try/except
+# If it fails, we'll define the default variable later to make imports work
+try:
+    from .registry import default
+except ImportError:
+    # This allows all imports to still work, while deferring the actual
+    # import of registry.default to when it's accessed
+    from types import SimpleNamespace
+
+    # Global placeholder for default registry instance
+    _default_instance = None
+
+    def _get_default():
+        """Load the default registry instance the first time it's needed"""
+        global _default_instance
+        if _default_instance is None:
+            from .registry import default as registry_default
+            _default_instance = registry_default
+        return _default_instance
+
+    # Define a simple proxy class for default
+    class DefaultProxy(SimpleNamespace):
+        def __getattr__(self, name):
+            # Forward all attribute access to the real default instance
+            return getattr(_get_default(), name)
+
+    # Create the proxy instance that stands in for default
+    default = DefaultProxy()
+
+    # Add default to __all__ now that we've defined it
+    __all__.append("default")
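The fallback branch works because attribute access on the proxy is what finally triggers the deferred registry import. A self-contained sketch of the same trick, with the standard-library statistics module standing in for edsl's registry (names are illustrative, not edsl's):

from types import SimpleNamespace

_real_module = None

def _load():
    # Deferred import: only runs the first time the proxy is actually used
    global _real_module
    if _real_module is None:
        import statistics
        _real_module = statistics
    return _real_module

class LazyProxy(SimpleNamespace):
    def __getattr__(self, name):
        # Forward every attribute lookup to the lazily imported module
        return getattr(_load(), name)

stats = LazyProxy()
print(stats.mean([1, 2, 3]))  # the import happens here; prints 2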
edsl/inference_services/exceptions.py
@@ -36,6 +36,17 @@ class InferenceServiceError(BaseException):
     relevant_doc = "https://docs.expectedparrot.com/en/latest/remote_inference.html"
 
 
+
+
+class InferenceServiceIntendedError(InferenceServiceError):
+    """
+    Test error - this is an error thrown on purpose to test the error handling in the framework.
+
+    """
+    relevant_doc = "https://docs.expectedparrot.com/en/latest/language_models.html#model-parameters"
+
+
+
 class InferenceServiceValueError(InferenceServiceError):
     """
     Exception raised when invalid values are provided to inference services.
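Because the new class subclasses InferenceServiceError, any handler that already catches the base class also catches the deliberately raised test error. A small sketch of that relationship:

# Sketch: the intentional test error is caught by existing base-class handlers.
from edsl.inference_services.exceptions import (
    InferenceServiceError,
    InferenceServiceIntendedError,
)

assert issubclass(InferenceServiceIntendedError, InferenceServiceError)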
edsl/inference_services/services/anthropic_service.py
@@ -3,9 +3,10 @@ from typing import Any, Optional, List, TYPE_CHECKING
 from anthropic import AsyncAnthropic
 
 from ..inference_service_abc import InferenceServiceABC
-from ...language_models import LanguageModel
 
+# Use TYPE_CHECKING to avoid circular imports at runtime
 if TYPE_CHECKING:
+    from ...language_models import LanguageModel
     from ....scenarios.file_store import FileStore as Files
 
 
@@ -40,10 +41,12 @@ class AnthropicService(InferenceServiceABC):
     @classmethod
     def create_model(
         cls, model_name: str = "claude-3-opus-20240229", model_class_name=None
-    ) -> LanguageModel:
+    ) -> 'LanguageModel':
         if model_class_name is None:
             model_class_name = cls.to_class_name(model_name)
 
+        # Import LanguageModel only when actually creating a model
+        from ...language_models import LanguageModel
         class LLM(LanguageModel):
             """
             Child class of LanguageModel for interacting with OpenAI models
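The same refactor repeats in each service module below: the LanguageModel import moves under TYPE_CHECKING, the return annotation becomes a string forward reference, and the real import is deferred to create_model. A self-contained sketch of the pattern, with decimal.Decimal standing in for LanguageModel (names are illustrative, not edsl's):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; nothing is imported at runtime here,
    # which is what breaks the circular-import chain.
    from decimal import Decimal

class ExampleService:
    @classmethod
    def create_model(cls) -> 'Decimal':
        # Deferred import: resolved only when a model is actually created
        from decimal import Decimal
        return Decimal("1.0")

print(ExampleService.create_model())  # prints 1.0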
edsl/inference_services/services/aws_bedrock.py
@@ -3,7 +3,9 @@ from typing import Any, List, Optional, TYPE_CHECKING
 import boto3
 from botocore.exceptions import ClientError
 from ..inference_service_abc import InferenceServiceABC
-
+# Use TYPE_CHECKING to avoid circular imports at runtime
+if TYPE_CHECKING:
+    from ...language_models import LanguageModel
 
 if TYPE_CHECKING:
     from ....scenarios.file_store import FileStore
@@ -49,10 +51,12 @@ class AwsBedrockService(InferenceServiceABC):
     @classmethod
     def create_model(
         cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
-    ) -> LanguageModel:
+    ) -> 'LanguageModel':
         if model_class_name is None:
             model_class_name = cls.to_class_name(model_name)
 
+        # Import LanguageModel only when actually creating a model
+        from ...language_models import LanguageModel
         class LLM(LanguageModel):
             """
             Child class of LanguageModel for interacting with AWS Bedrock models.
edsl/inference_services/services/azure_ai.py
@@ -2,7 +2,9 @@ import os
 from typing import Any, Optional, List, TYPE_CHECKING
 from openai import AsyncAzureOpenAI
 from ..inference_service_abc import InferenceServiceABC
-
+# Use TYPE_CHECKING to avoid circular imports at runtime
+if TYPE_CHECKING:
+    from ...language_models import LanguageModel
 
 if TYPE_CHECKING:
     from ....scenarios.file_store import FileStore
@@ -98,10 +100,12 @@ class AzureAIService(InferenceServiceABC):
     @classmethod
     def create_model(
         cls, model_name: str = "azureai", model_class_name=None
-    ) -> LanguageModel:
+    ) -> 'LanguageModel':
         if model_class_name is None:
             model_class_name = cls.to_class_name(model_name)
 
+        # Import LanguageModel only when actually creating a model
+        from ...language_models import LanguageModel
         class LLM(LanguageModel):
             """
             Child class of LanguageModel for interacting with Azure OpenAI models.
edsl/inference_services/services/google_service.py
@@ -7,9 +7,9 @@ from google.api_core.exceptions import InvalidArgument
 
 # from ...exceptions.general import MissingAPIKeyError
 from ..inference_service_abc import InferenceServiceABC
-
-
+# Use TYPE_CHECKING to avoid circular imports at runtime
 if TYPE_CHECKING:
+    from ...language_models import LanguageModel
     from ....scenarios.file_store import FileStore as Files
     #from ...coop import Coop
 
@@ -65,6 +65,7 @@ class GoogleService(InferenceServiceABC):
         if model_class_name is None:
             model_class_name = cls.to_class_name(model_name)
 
+        # Import LanguageModel only when actually creating a model
         from ...language_models import LanguageModel
 
         class LLM(LanguageModel):
edsl/inference_services/services/mistral_ai_service.py
@@ -4,7 +4,9 @@ from mistralai import Mistral
 
 
 from ..inference_service_abc import InferenceServiceABC
-
+# Use TYPE_CHECKING to avoid circular imports at runtime
+if TYPE_CHECKING:
+    from ...language_models import LanguageModel
 
 if TYPE_CHECKING:
     from ....scenarios.file_store import FileStore
@@ -64,10 +66,12 @@ class MistralAIService(InferenceServiceABC):
     @classmethod
     def create_model(
         cls, model_name: str = "mistral", model_class_name=None
-    ) -> LanguageModel:
+    ) -> 'LanguageModel':
         if model_class_name is None:
             model_class_name = cls.to_class_name(model_name)
 
+        # Import LanguageModel only when actually creating a model
+        from ...language_models import LanguageModel
         class LLM(LanguageModel):
             """
             Child class of LanguageModel for interacting with Mistral models.
edsl/inference_services/services/open_ai_service.py
@@ -5,7 +5,9 @@ import os
 import openai
 
 from ..inference_service_abc import InferenceServiceABC
-
+# Use TYPE_CHECKING to avoid circular imports at runtime
+if TYPE_CHECKING:
+    from ...language_models import LanguageModel
 from ..rate_limits_cache import rate_limits
 
 if TYPE_CHECKING:
@@ -110,10 +112,12 @@ class OpenAIService(InferenceServiceABC):
         return cls._models_list_cache
 
     @classmethod
-    def create_model(cls, model_name, model_class_name=None) -> LanguageModel:
+    def create_model(cls, model_name, model_class_name=None) -> 'LanguageModel':
         if model_class_name is None:
             model_class_name = cls.to_class_name(model_name)
 
+        # Import LanguageModel only when actually creating a model
+        from ...language_models import LanguageModel
         class LLM(LanguageModel):
             """
             Child class of LanguageModel for interacting with OpenAI models
edsl/inference_services/services/perplexity_service.py
@@ -1,7 +1,9 @@
 from typing import Any, List, Optional, TYPE_CHECKING
 from ..rate_limits_cache import rate_limits
 
-
+# Use TYPE_CHECKING to avoid circular imports at runtime
+if TYPE_CHECKING:
+    from ...language_models import LanguageModel
 
 from .open_ai_service import OpenAIService
 
@@ -40,10 +42,12 @@ class PerplexityService(OpenAIService):
     @classmethod
     def create_model(
         cls, model_name="llama-3.1-sonar-large-128k-online", model_class_name=None
-    ) -> LanguageModel:
+    ) -> 'LanguageModel':
         if model_class_name is None:
             model_class_name = cls.to_class_name(model_name)
 
+        # Import LanguageModel only when actually creating a model
+        from ...language_models import LanguageModel
         class LLM(LanguageModel):
             """
             Child class of LanguageModel for interacting with Perplexity models