edsl 0.1.54__py3-none-any.whl → 0.1.56__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. edsl/__init__.py +8 -1
  2. edsl/__init__original.py +134 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/agent.py +29 -0
  5. edsl/agents/agent_list.py +36 -1
  6. edsl/base/base_class.py +281 -151
  7. edsl/base/data_transfer_models.py +15 -4
  8. edsl/buckets/__init__.py +8 -3
  9. edsl/buckets/bucket_collection.py +9 -3
  10. edsl/buckets/model_buckets.py +4 -2
  11. edsl/buckets/token_bucket.py +2 -2
  12. edsl/buckets/token_bucket_client.py +5 -3
  13. edsl/caching/cache.py +131 -62
  14. edsl/caching/cache_entry.py +70 -58
  15. edsl/caching/sql_dict.py +17 -0
  16. edsl/cli.py +99 -0
  17. edsl/config/config_class.py +16 -0
  18. edsl/conversation/__init__.py +31 -0
  19. edsl/coop/coop.py +276 -242
  20. edsl/coop/coop_jobs_objects.py +59 -0
  21. edsl/coop/coop_objects.py +29 -0
  22. edsl/coop/coop_regular_objects.py +26 -0
  23. edsl/coop/utils.py +24 -19
  24. edsl/dataset/dataset.py +338 -101
  25. edsl/dataset/dataset_operations_mixin.py +216 -180
  26. edsl/db_list/sqlite_list.py +349 -0
  27. edsl/inference_services/__init__.py +40 -5
  28. edsl/inference_services/exceptions.py +11 -0
  29. edsl/inference_services/services/anthropic_service.py +5 -2
  30. edsl/inference_services/services/aws_bedrock.py +6 -2
  31. edsl/inference_services/services/azure_ai.py +6 -2
  32. edsl/inference_services/services/google_service.py +7 -3
  33. edsl/inference_services/services/mistral_ai_service.py +6 -2
  34. edsl/inference_services/services/open_ai_service.py +6 -2
  35. edsl/inference_services/services/perplexity_service.py +6 -2
  36. edsl/inference_services/services/test_service.py +94 -5
  37. edsl/interviews/answering_function.py +167 -59
  38. edsl/interviews/interview.py +124 -72
  39. edsl/interviews/interview_task_manager.py +10 -0
  40. edsl/interviews/request_token_estimator.py +8 -0
  41. edsl/invigilators/invigilators.py +35 -13
  42. edsl/jobs/async_interview_runner.py +146 -104
  43. edsl/jobs/data_structures.py +6 -4
  44. edsl/jobs/decorators.py +61 -0
  45. edsl/jobs/fetch_invigilator.py +61 -18
  46. edsl/jobs/html_table_job_logger.py +14 -2
  47. edsl/jobs/jobs.py +180 -104
  48. edsl/jobs/jobs_component_constructor.py +2 -2
  49. edsl/jobs/jobs_interview_constructor.py +2 -0
  50. edsl/jobs/jobs_pricing_estimation.py +154 -113
  51. edsl/jobs/jobs_remote_inference_logger.py +4 -0
  52. edsl/jobs/jobs_runner_status.py +30 -25
  53. edsl/jobs/progress_bar_manager.py +79 -0
  54. edsl/jobs/remote_inference.py +35 -1
  55. edsl/key_management/key_lookup_builder.py +6 -1
  56. edsl/language_models/language_model.py +110 -12
  57. edsl/language_models/model.py +10 -3
  58. edsl/language_models/price_manager.py +176 -71
  59. edsl/language_models/registry.py +5 -0
  60. edsl/notebooks/notebook.py +77 -10
  61. edsl/questions/VALIDATION_README.md +134 -0
  62. edsl/questions/__init__.py +24 -1
  63. edsl/questions/exceptions.py +21 -0
  64. edsl/questions/question_dict.py +201 -16
  65. edsl/questions/question_multiple_choice_with_other.py +624 -0
  66. edsl/questions/question_registry.py +2 -1
  67. edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
  68. edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
  69. edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
  70. edsl/questions/validation_analysis.py +185 -0
  71. edsl/questions/validation_cli.py +131 -0
  72. edsl/questions/validation_html_report.py +404 -0
  73. edsl/questions/validation_logger.py +136 -0
  74. edsl/results/result.py +115 -46
  75. edsl/results/results.py +702 -171
  76. edsl/scenarios/construct_download_link.py +16 -3
  77. edsl/scenarios/directory_scanner.py +226 -226
  78. edsl/scenarios/file_methods.py +5 -0
  79. edsl/scenarios/file_store.py +150 -9
  80. edsl/scenarios/handlers/__init__.py +5 -1
  81. edsl/scenarios/handlers/mp4_file_store.py +104 -0
  82. edsl/scenarios/handlers/webm_file_store.py +104 -0
  83. edsl/scenarios/scenario.py +120 -101
  84. edsl/scenarios/scenario_list.py +800 -727
  85. edsl/scenarios/scenario_list_gc_test.py +146 -0
  86. edsl/scenarios/scenario_list_memory_test.py +214 -0
  87. edsl/scenarios/scenario_list_source_refactor.md +35 -0
  88. edsl/scenarios/scenario_selector.py +5 -4
  89. edsl/scenarios/scenario_source.py +1990 -0
  90. edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
  91. edsl/surveys/survey.py +22 -0
  92. edsl/tasks/__init__.py +4 -2
  93. edsl/tasks/task_history.py +198 -36
  94. edsl/tests/scenarios/test_ScenarioSource.py +51 -0
  95. edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
  96. edsl/utilities/__init__.py +2 -1
  97. edsl/utilities/decorators.py +121 -0
  98. edsl/utilities/memory_debugger.py +1010 -0
  99. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/METADATA +51 -76
  100. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/RECORD +103 -79
  101. edsl/jobs/jobs_runner_asyncio.py +0 -281
  102. edsl/language_models/unused/fake_openai_service.py +0 -60
  103. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/LICENSE +0 -0
  104. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/WHEEL +0 -0
  105. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,349 @@
1
+ import sqlite3
2
+ import tempfile
3
+ import os
4
+ import json
5
+ from typing import Any, Callable, Iterable, Iterator, List, Optional
6
+ from abc import ABC, abstractmethod
7
+ from collections.abc import MutableSequence
8
+
9
+
10
class SQLiteList(MutableSequence, ABC):
    """
    An abstract MutableSequence that stores its items in a temporary SQLite file.

    Each item is one row ``(idx, value)``: ``idx`` is the item's 0-based
    position and ``value`` is the output of ``serialize``. Every mutating
    method keeps ``idx`` contiguous (``0 .. len-1``) and non-negative; the
    renumbering helper ``_shift`` relies on that invariant.

    The file is removed when close() is called (or, best effort, when the
    object is garbage collected).
    Subclasses must implement serialize and deserialize methods.
    """

    # Interpolated directly into SQL statements, so it must be a trusted
    # identifier (never user input).
    _TABLE_NAME = "list_data"

    @abstractmethod
    def serialize(self, value: Any) -> str:
        """Convert a value to a string for storage in SQLite."""

    @abstractmethod
    def deserialize(self, value: str) -> Any:
        """Convert a stored string back to its original value."""

    def __init__(self, data: Optional[Iterable] = None):
        """Create a fresh temporary database, optionally seeded from ``data``.

        Args:
            data: Optional iterable whose items become the initial contents.
        """
        # Create a temporary file for our SQLite database
        tmpfile = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
        self.db_path = tmpfile.name
        # Close the file handle immediately; SQLite only needs the path
        tmpfile.close()

        self.conn = sqlite3.connect(self.db_path)
        self._create_table_if_not_exists()

        if data is not None:
            self._batch_insert(data)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _create_table_if_not_exists(self) -> None:
        """Create the backing table if it does not exist yet."""
        # The UNIQUE constraint already makes SQLite maintain an index on
        # idx, so a separate CREATE INDEX would only add write overhead.
        with self.conn:
            self.conn.execute(
                f"CREATE TABLE IF NOT EXISTS {self._TABLE_NAME} "
                "(idx INTEGER UNIQUE, value BLOB)"
            )

    def _insert_from(self, start_idx: int, values: Iterable) -> None:
        """Insert ``values`` at consecutive positions starting at ``start_idx``.

        Items are serialized and inserted one at a time inside a single
        transaction, so the iterable is never materialized and a failure
        rolls back every insert of this call.
        """
        query = f"INSERT INTO {self._TABLE_NAME} (idx, value) VALUES (?, ?)"
        with self.conn:
            for position, item in enumerate(values, start_idx):
                self.conn.execute(query, (position, self.serialize(item)))

    def _batch_insert(self, data: Iterable) -> None:
        """
        Insert items one at a time at positions 0, 1, 2, ... to minimize memory usage.

        Args:
            data: Iterable containing items to insert
        """
        self._insert_from(0, data)

    def _normalize_index(self, index, message: str) -> int:
        """Resolve a (possibly negative) integer index or raise IndexError(message)."""
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(message)
        return index

    def _shift(self, from_idx: int, delta: int) -> None:
        """Add ``delta`` to the idx of every row with ``idx >= from_idx``.

        Must be called inside an open transaction. Renumbering is done in two
        phases through negative values: a single ``SET idx = idx + delta``
        can trip the UNIQUE constraint on a row that has not been moved yet,
        because SQLite checks uniqueness as each row is updated. The caller
        must ensure ``from_idx + delta`` does not land on a surviving row.
        """
        if delta == 0:
            return
        table = self._TABLE_NAME
        # Phase 1: j -> -j - 1 (distinct negatives; live rows are all >= 0).
        self.conn.execute(
            f"UPDATE {table} SET idx = -idx - 1 WHERE idx >= ?", (from_idx,)
        )
        # Phase 2: -j - 1 -> j + delta (distinct non-negatives).
        self.conn.execute(
            f"UPDATE {table} SET idx = -idx - 1 + ? WHERE idx < 0", (delta,)
        )

    def _get_slice(self, index: slice) -> List[Any]:
        """Return ``self[index]`` as a plain list using a single query."""
        start, stop, step = index.indices(len(self))
        positions = range(start, stop, step)
        if not positions:
            return []
        low = min(positions[0], positions[-1])
        high = max(positions[0], positions[-1])
        # Every wanted idx lies in [low, high] and is congruent to low
        # modulo |step|; (idx - low) is never negative, so SQLite's
        # sign-of-dividend modulo is safe here.
        cursor = self.conn.execute(
            f"SELECT value FROM {self._TABLE_NAME} "
            "WHERE idx BETWEEN ? AND ? AND (idx - ?) % ? = 0 ORDER BY idx",
            (low, high, low, abs(step)),
        )
        items = [self.deserialize(value) for (value,) in cursor]
        if step < 0:
            items.reverse()
        return items

    def _delete_slice(self, index: slice) -> None:
        """Implement ``del self[index]`` with Python list semantics."""
        start, stop, step = index.indices(len(self))
        positions = range(start, stop, step)
        if not positions:
            return
        table = self._TABLE_NAME
        with self.conn:
            if step == 1:
                self.conn.execute(
                    f"DELETE FROM {table} WHERE idx >= ? AND idx < ?", (start, stop)
                )
                self._shift(stop, start - stop)
            else:
                # Delete from the highest position down so earlier positions
                # are not renumbered before they are removed.
                for position in sorted(positions, reverse=True):
                    self.conn.execute(f"DELETE FROM {table} WHERE idx=?", (position,))
                    self._shift(position + 1, -1)

    def _assign_slice(self, index: slice, values: Iterable) -> None:
        """Implement ``self[index] = values`` with Python list semantics."""
        start, stop, step = index.indices(len(self))
        # Materialize first: ``values`` may be (or read from) this very table.
        items = list(values)
        table = self._TABLE_NAME
        if step == 1:
            # A simple slice can grow or shrink the list.
            stop = max(start, stop)
            rows = [
                (start + offset, self.serialize(item))
                for offset, item in enumerate(items)
            ]
            with self.conn:
                self.conn.execute(
                    f"DELETE FROM {table} WHERE idx >= ? AND idx < ?", (start, stop)
                )
                self._shift(stop, len(items) - (stop - start))
                self.conn.executemany(
                    f"INSERT INTO {table} (idx, value) VALUES (?, ?)", rows
                )
            return

        positions = range(start, stop, step)
        if len(items) != len(positions):
            raise ValueError(
                f"attempt to assign sequence of size {len(items)} "
                f"to extended slice of size {len(positions)}"
            )
        updates = [
            (self.serialize(item), position)
            for position, item in zip(positions, items)
        ]
        with self.conn:
            self.conn.executemany(f"UPDATE {table} SET value=? WHERE idx=?", updates)

    # ------------------------------------------------------------------
    # MutableSequence protocol
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        cursor = self.conn.execute(f"SELECT COUNT(*) FROM {self._TABLE_NAME}")
        (count,) = cursor.fetchone()
        return count

    def __getitem__(self, index):
        """Return one item, or a plain list for a slice."""
        if isinstance(index, slice):
            return self._get_slice(index)

        index = self._normalize_index(index, "list index out of range")
        row = self.conn.execute(
            f"SELECT value FROM {self._TABLE_NAME} WHERE idx=?", (index,)
        ).fetchone()
        if row is None:
            raise IndexError("list index out of range")
        return self.deserialize(row[0])

    def __setitem__(self, index, value):
        """Replace one item, or assign an iterable to a slice."""
        if isinstance(index, slice):
            self._assign_slice(index, value)
            return

        index = self._normalize_index(index, "list assignment index out of range")
        serialized = self.serialize(value)
        with self.conn:
            self.conn.execute(
                f"UPDATE {self._TABLE_NAME} SET value=? WHERE idx=?",
                (serialized, index),
            )

    def __delitem__(self, index):
        """Delete one item or a slice, closing the gap in positions."""
        if isinstance(index, slice):
            self._delete_slice(index)
            return

        index = self._normalize_index(index, "list assignment index out of range")
        with self.conn:
            self.conn.execute(
                f"DELETE FROM {self._TABLE_NAME} WHERE idx=?", (index,)
            )
            self._shift(index + 1, -1)

    def insert(self, index, value):
        """
        Insert ``value`` before position ``index`` (same semantics as list.insert).

        Negative indices count from the end; out-of-range indices are clamped
        to the start or end of the list.
        """
        length = len(self)
        if index < 0:
            index = max(length + index, 0)
        index = min(index, length)

        serialized = self.serialize(value)
        with self.conn:
            # Make room by moving every idx >= index up by one.
            self._shift(index, 1)
            self.conn.execute(
                f"INSERT INTO {self._TABLE_NAME} (idx, value) VALUES (?, ?)",
                (index, serialized),
            )

    def append(self, value):
        """Append a value to the end of the list."""
        self._insert_from(len(self), (value,))

    def extend(self, values: Iterable) -> None:
        """
        Extend the list by appending all items in the given iterable.

        Processes one item at a time to minimize memory usage. When ``values``
        is this list itself it is snapshotted first, because reading a table
        while inserting into it on the same connection is undefined in SQLite.
        (A generator that lazily reads this list cannot be detected.)

        Args:
            values: Iterable of values to append
        """
        if values is self:
            values = list(self.stream())
        self._insert_from(len(self), values)

    def clear(self) -> None:
        """Remove every item with a single DELETE."""
        with self.conn:
            self.conn.execute(f"DELETE FROM {self._TABLE_NAME}")

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self):
        """
        Close the database connection and remove the temporary file.

        Safe to call more than once, and on a partially-initialised object.
        """
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()
        db_path = getattr(self, "db_path", None)
        if db_path is not None:
            try:
                os.remove(db_path)
            except FileNotFoundError:
                pass

    def __del__(self):
        """Clean up the temporary file when the object is deleted."""
        try:
            self.close()
        except Exception:
            # Best effort only: module globals may already be torn down
            # during interpreter shutdown.
            pass

    # ------------------------------------------------------------------
    # Presentation, iteration and comparison
    # ------------------------------------------------------------------

    def __repr__(self):
        num_items = len(self)
        preview_count = min(num_items, 10)
        items = self[:preview_count]
        if preview_count < num_items:
            return f"{items}... (total {num_items} items)"
        return str(items)

    def __add__(self, other):
        """
        Concatenate two SQLiteLists into a new instance of ``type(self)``.

        Both operands are streamed, so neither is loaded fully into memory.
        """
        if not isinstance(other, SQLiteList):
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
            )

        result = type(self)()
        result.extend(self.stream())
        result.extend(other.stream())
        return result

    def stream(self) -> Iterator[Any]:
        """Stream items from the database without loading everything into memory."""
        cursor = self.conn.execute(f"SELECT value FROM {self._TABLE_NAME} ORDER BY idx")
        for (value,) in cursor:
            yield self.deserialize(value)

    def stream_batched(self, batch_size: int = 1000) -> Iterator[List[Any]]:
        """
        Stream items in batches to reduce memory usage and improve performance.

        Args:
            batch_size: Number of items to yield in each batch

        Yields:
            Lists of deserialized items, with at most batch_size items per list
        """
        cursor = self.conn.execute(f"SELECT value FROM {self._TABLE_NAME} ORDER BY idx")
        batch = []
        for (value,) in cursor:
            batch.append(self.deserialize(value))
            if len(batch) >= batch_size:
                yield batch
                batch = []
        if batch:  # trailing partial batch
            yield batch

    def __iter__(self):
        """Iterate over items using streaming."""
        return self.stream()

    def __reversed__(self) -> Iterator[Any]:
        """Stream items from last to first with a single query."""
        cursor = self.conn.execute(
            f"SELECT value FROM {self._TABLE_NAME} ORDER BY idx DESC"
        )
        for (value,) in cursor:
            yield self.deserialize(value)

    def equals(self, other):
        """Memory-efficient element-wise comparison with another sized, indexable object."""
        if len(self) != len(other):
            return False

        if hasattr(other, "stream_batched"):
            # Both sides stream; equal lengths mean the batches line up.
            batch_size = 1000
            for mine, theirs in zip(
                self.stream_batched(batch_size), other.stream_batched(batch_size)
            ):
                if len(mine) != len(theirs):
                    return False
                for left, right in zip(mine, theirs):
                    if left != right:
                        return False
            return True

        # Stream our side instead of issuing one query (plus COUNTs) per item.
        for position, item in enumerate(self.stream()):
            if item != other[position]:
                return False
        return True

    def __eq__(self, other):
        """Use memory-efficient comparison; defer to Python for non-sequences."""
        if not (hasattr(other, "__len__") and hasattr(other, "__getitem__")):
            return NotImplemented
        return self.equals(other)

    def copy_from(self, source_db_path: str) -> None:
        """Replace this list's contents with the rows of another SQLite database file.

        The source file is copied to a temporary file first to avoid locking
        the original. If the source has no ``_TABLE_NAME`` table, this list is
        left unchanged. The delete and the re-insert run in one transaction,
        so a failure leaves the current contents intact.

        Args:
            source_db_path: Path to the source SQLite database file

        Raises:
            sqlite3.Error: If there's an error accessing the source database
        """
        import shutil

        with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as temp_file:
            temp_db_path = temp_file.name

        table = self._TABLE_NAME
        try:
            shutil.copy2(source_db_path, temp_db_path)
            source_conn = sqlite3.connect(temp_db_path)
            try:
                exists = source_conn.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                    (table,),
                ).fetchone()
                if exists is None:
                    return  # Table doesn't exist in source, nothing to copy

                # Stream rows straight from the source cursor into executemany
                # instead of materializing them with fetchall().
                rows = source_conn.execute(
                    f"SELECT idx, value FROM {table} ORDER BY idx"
                )
                with self.conn:
                    self.conn.execute(f"DELETE FROM {table}")
                    self.conn.executemany(
                        f"INSERT INTO {table} (idx, value) VALUES (?, ?)", rows
                    )
            finally:
                source_conn.close()
        finally:
            try:
                os.unlink(temp_db_path)
            except OSError:
                pass
@@ -1,13 +1,48 @@
1
- from .inference_services_collection import InferenceServicesCollection
2
- from .registry import default
1
+ # Import only the exceptions first to avoid circular imports
2
+ from .exceptions import InferenceServiceError
3
+
4
+ # Import the classes directly here for external use
5
+ # While avoiding circular imports with other modules
3
6
  from .inference_service_abc import InferenceServiceABC
4
7
  from .data_structures import AvailableModels
5
- from .exceptions import InferenceServiceError
8
+ from .inference_services_collection import InferenceServicesCollection
6
9
 
10
+ # Define __all__ without default - we'll add it later
7
11
  __all__ = [
8
12
  "InferenceServicesCollection",
9
- "default",
10
13
  "InferenceServiceABC",
11
14
  "AvailableModels",
12
15
  "InferenceServiceError"
13
- ]
16
+ ]
17
+
18
+ # Better approach: import the default instance immediately in a try/except
19
+ # If it fails, we'll define the default variable later to make imports work
20
+ try:
21
+ from .registry import default
22
+ except ImportError:
23
+ # This allows all imports to still work, while deferring the actual
24
+ # import of registry.default to when it's accessed
25
+ from types import SimpleNamespace
26
+
27
+ # Global placeholder for default registry instance
28
+ _default_instance = None
29
+
30
+ def _get_default():
31
+ """Load the default registry instance the first time it's needed"""
32
+ global _default_instance
33
+ if _default_instance is None:
34
+ from .registry import default as registry_default
35
+ _default_instance = registry_default
36
+ return _default_instance
37
+
38
+ # Define a simple proxy class for default
39
+ class DefaultProxy(SimpleNamespace):
40
+ def __getattr__(self, name):
41
+ # Forward all attribute access to the real default instance
42
+ return getattr(_get_default(), name)
43
+
44
+ # Create the proxy instance that stands in for default
45
+ default = DefaultProxy()
46
+
47
+ # Add default to __all__ now that we've defined it
48
+ __all__.append("default")
@@ -36,6 +36,17 @@ class InferenceServiceError(BaseException):
36
36
  relevant_doc = "https://docs.expectedparrot.com/en/latest/remote_inference.html"
37
37
 
38
38
 
39
+
40
+
41
+ class InferenceServiceIntendedError(InferenceServiceError):
42
+ """
43
+ Test error - this is an error thrown on purpose to test the error handling in the framework.
44
+
45
+ """
46
+ relevant_doc = "https://docs.expectedparrot.com/en/latest/language_models.html#model-parameters"
47
+
48
+
49
+
39
50
  class InferenceServiceValueError(InferenceServiceError):
40
51
  """
41
52
  Exception raised when invalid values are provided to inference services.
@@ -3,9 +3,10 @@ from typing import Any, Optional, List, TYPE_CHECKING
3
3
  from anthropic import AsyncAnthropic
4
4
 
5
5
  from ..inference_service_abc import InferenceServiceABC
6
- from ...language_models import LanguageModel
7
6
 
7
+ # Use TYPE_CHECKING to avoid circular imports at runtime
8
8
  if TYPE_CHECKING:
9
+ from ...language_models import LanguageModel
9
10
  from ....scenarios.file_store import FileStore as Files
10
11
 
11
12
 
@@ -40,10 +41,12 @@ class AnthropicService(InferenceServiceABC):
40
41
  @classmethod
41
42
  def create_model(
42
43
  cls, model_name: str = "claude-3-opus-20240229", model_class_name=None
43
- ) -> LanguageModel:
44
+ ) -> 'LanguageModel':
44
45
  if model_class_name is None:
45
46
  model_class_name = cls.to_class_name(model_name)
46
47
 
48
+ # Import LanguageModel only when actually creating a model
49
+ from ...language_models import LanguageModel
47
50
  class LLM(LanguageModel):
48
51
  """
49
52
  Child class of LanguageModel for interacting with OpenAI models
@@ -3,7 +3,9 @@ from typing import Any, List, Optional, TYPE_CHECKING
3
3
  import boto3
4
4
  from botocore.exceptions import ClientError
5
5
  from ..inference_service_abc import InferenceServiceABC
6
- from ...language_models import LanguageModel
6
+ # Use TYPE_CHECKING to avoid circular imports at runtime
7
+ if TYPE_CHECKING:
8
+ from ...language_models import LanguageModel
7
9
 
8
10
  if TYPE_CHECKING:
9
11
  from ....scenarios.file_store import FileStore
@@ -49,10 +51,12 @@ class AwsBedrockService(InferenceServiceABC):
49
51
  @classmethod
50
52
  def create_model(
51
53
  cls, model_name: str = "amazon.titan-tg1-large", model_class_name=None
52
- ) -> LanguageModel:
54
+ ) -> 'LanguageModel':
53
55
  if model_class_name is None:
54
56
  model_class_name = cls.to_class_name(model_name)
55
57
 
58
+ # Import LanguageModel only when actually creating a model
59
+ from ...language_models import LanguageModel
56
60
  class LLM(LanguageModel):
57
61
  """
58
62
  Child class of LanguageModel for interacting with AWS Bedrock models.
@@ -2,7 +2,9 @@ import os
2
2
  from typing import Any, Optional, List, TYPE_CHECKING
3
3
  from openai import AsyncAzureOpenAI
4
4
  from ..inference_service_abc import InferenceServiceABC
5
- from ...language_models import LanguageModel
5
+ # Use TYPE_CHECKING to avoid circular imports at runtime
6
+ if TYPE_CHECKING:
7
+ from ...language_models import LanguageModel
6
8
 
7
9
  if TYPE_CHECKING:
8
10
  from ....scenarios.file_store import FileStore
@@ -98,10 +100,12 @@ class AzureAIService(InferenceServiceABC):
98
100
  @classmethod
99
101
  def create_model(
100
102
  cls, model_name: str = "azureai", model_class_name=None
101
- ) -> LanguageModel:
103
+ ) -> 'LanguageModel':
102
104
  if model_class_name is None:
103
105
  model_class_name = cls.to_class_name(model_name)
104
106
 
107
+ # Import LanguageModel only when actually creating a model
108
+ from ...language_models import LanguageModel
105
109
  class LLM(LanguageModel):
106
110
  """
107
111
  Child class of LanguageModel for interacting with Azure OpenAI models.
@@ -7,11 +7,13 @@ from google.api_core.exceptions import InvalidArgument
7
7
 
8
8
  # from ...exceptions.general import MissingAPIKeyError
9
9
  from ..inference_service_abc import InferenceServiceABC
10
- from ...language_models import LanguageModel
11
10
 
11
+ # Use TYPE_CHECKING to avoid circular imports at runtime
12
12
  if TYPE_CHECKING:
13
+ from ...language_models import LanguageModel
13
14
  from ....scenarios.file_store import FileStore as Files
14
- #from ...coop import Coop
15
+ # from ...coop import Coop
16
+ import asyncio
15
17
 
16
18
  safety_settings = [
17
19
  {
@@ -61,10 +63,11 @@ class GoogleService(InferenceServiceABC):
61
63
  @classmethod
62
64
  def create_model(
63
65
  cls, model_name: str = "gemini-pro", model_class_name=None
64
- ) -> 'LanguageModel':
66
+ ) -> "LanguageModel":
65
67
  if model_class_name is None:
66
68
  model_class_name = cls.to_class_name(model_name)
67
69
 
70
+ # Import LanguageModel only when actually creating a model
68
71
  from ...language_models import LanguageModel
69
72
 
70
73
  class LLM(LanguageModel):
@@ -137,6 +140,7 @@ class GoogleService(InferenceServiceABC):
137
140
  gen_ai_file = google.generativeai.types.file_types.File(
138
141
  file.external_locations["google"]
139
142
  )
143
+
140
144
  combined_prompt.append(gen_ai_file)
141
145
 
142
146
  try:
@@ -4,7 +4,9 @@ from mistralai import Mistral
4
4
 
5
5
 
6
6
  from ..inference_service_abc import InferenceServiceABC
7
- from ...language_models import LanguageModel
7
+ # Use TYPE_CHECKING to avoid circular imports at runtime
8
+ if TYPE_CHECKING:
9
+ from ...language_models import LanguageModel
8
10
 
9
11
  if TYPE_CHECKING:
10
12
  from ....scenarios.file_store import FileStore
@@ -64,10 +66,12 @@ class MistralAIService(InferenceServiceABC):
64
66
  @classmethod
65
67
  def create_model(
66
68
  cls, model_name: str = "mistral", model_class_name=None
67
- ) -> LanguageModel:
69
+ ) -> 'LanguageModel':
68
70
  if model_class_name is None:
69
71
  model_class_name = cls.to_class_name(model_name)
70
72
 
73
+ # Import LanguageModel only when actually creating a model
74
+ from ...language_models import LanguageModel
71
75
  class LLM(LanguageModel):
72
76
  """
73
77
  Child class of LanguageModel for interacting with Mistral models.
@@ -5,7 +5,9 @@ import os
5
5
  import openai
6
6
 
7
7
  from ..inference_service_abc import InferenceServiceABC
8
- from ...language_models import LanguageModel
8
+ # Use TYPE_CHECKING to avoid circular imports at runtime
9
+ if TYPE_CHECKING:
10
+ from ...language_models import LanguageModel
9
11
  from ..rate_limits_cache import rate_limits
10
12
 
11
13
  if TYPE_CHECKING:
@@ -110,10 +112,12 @@ class OpenAIService(InferenceServiceABC):
110
112
  return cls._models_list_cache
111
113
 
112
114
  @classmethod
113
- def create_model(cls, model_name, model_class_name=None) -> LanguageModel:
115
+ def create_model(cls, model_name, model_class_name=None) -> 'LanguageModel':
114
116
  if model_class_name is None:
115
117
  model_class_name = cls.to_class_name(model_name)
116
118
 
119
+ # Import LanguageModel only when actually creating a model
120
+ from ...language_models import LanguageModel
117
121
  class LLM(LanguageModel):
118
122
  """
119
123
  Child class of LanguageModel for interacting with OpenAI models
@@ -1,7 +1,9 @@
1
1
  from typing import Any, List, Optional, TYPE_CHECKING
2
2
  from ..rate_limits_cache import rate_limits
3
3
 
4
- from ...language_models import LanguageModel
4
+ # Use TYPE_CHECKING to avoid circular imports at runtime
5
+ if TYPE_CHECKING:
6
+ from ...language_models import LanguageModel
5
7
 
6
8
  from .open_ai_service import OpenAIService
7
9
 
@@ -40,10 +42,12 @@ class PerplexityService(OpenAIService):
40
42
  @classmethod
41
43
  def create_model(
42
44
  cls, model_name="llama-3.1-sonar-large-128k-online", model_class_name=None
43
- ) -> LanguageModel:
45
+ ) -> 'LanguageModel':
44
46
  if model_class_name is None:
45
47
  model_class_name = cls.to_class_name(model_name)
46
48
 
49
+ # Import LanguageModel only when actually creating a model
50
+ from ...language_models import LanguageModel
47
51
  class LLM(LanguageModel):
48
52
  """
49
53
  Child class of LanguageModel for interacting with Perplexity models