edsl 0.1.27.dev2__py3-none-any.whl → 0.1.29__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (119)
  1. edsl/Base.py +107 -30
  2. edsl/BaseDiff.py +260 -0
  3. edsl/__init__.py +25 -21
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +103 -46
  6. edsl/agents/AgentList.py +97 -13
  7. edsl/agents/Invigilator.py +23 -10
  8. edsl/agents/InvigilatorBase.py +19 -14
  9. edsl/agents/PromptConstructionMixin.py +342 -100
  10. edsl/agents/descriptors.py +5 -2
  11. edsl/base/Base.py +289 -0
  12. edsl/config.py +2 -1
  13. edsl/conjure/AgentConstructionMixin.py +152 -0
  14. edsl/conjure/Conjure.py +56 -0
  15. edsl/conjure/InputData.py +659 -0
  16. edsl/conjure/InputDataCSV.py +48 -0
  17. edsl/conjure/InputDataMixinQuestionStats.py +182 -0
  18. edsl/conjure/InputDataPyRead.py +91 -0
  19. edsl/conjure/InputDataSPSS.py +8 -0
  20. edsl/conjure/InputDataStata.py +8 -0
  21. edsl/conjure/QuestionOptionMixin.py +76 -0
  22. edsl/conjure/QuestionTypeMixin.py +23 -0
  23. edsl/conjure/RawQuestion.py +65 -0
  24. edsl/conjure/SurveyResponses.py +7 -0
  25. edsl/conjure/__init__.py +9 -4
  26. edsl/conjure/examples/placeholder.txt +0 -0
  27. edsl/conjure/naming_utilities.py +263 -0
  28. edsl/conjure/utilities.py +165 -28
  29. edsl/conversation/Conversation.py +238 -0
  30. edsl/conversation/car_buying.py +58 -0
  31. edsl/conversation/mug_negotiation.py +81 -0
  32. edsl/conversation/next_speaker_utilities.py +93 -0
  33. edsl/coop/coop.py +337 -121
  34. edsl/coop/utils.py +56 -70
  35. edsl/data/Cache.py +74 -22
  36. edsl/data/CacheHandler.py +10 -9
  37. edsl/data/SQLiteDict.py +11 -3
  38. edsl/inference_services/AnthropicService.py +1 -0
  39. edsl/inference_services/DeepInfraService.py +20 -13
  40. edsl/inference_services/GoogleService.py +7 -1
  41. edsl/inference_services/InferenceServicesCollection.py +33 -7
  42. edsl/inference_services/OpenAIService.py +17 -10
  43. edsl/inference_services/models_available_cache.py +69 -0
  44. edsl/inference_services/rate_limits_cache.py +25 -0
  45. edsl/inference_services/write_available.py +10 -0
  46. edsl/jobs/Answers.py +15 -1
  47. edsl/jobs/Jobs.py +322 -73
  48. edsl/jobs/buckets/BucketCollection.py +9 -3
  49. edsl/jobs/buckets/ModelBuckets.py +4 -2
  50. edsl/jobs/buckets/TokenBucket.py +1 -2
  51. edsl/jobs/interviews/Interview.py +7 -10
  52. edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
  53. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +39 -20
  54. edsl/jobs/interviews/retry_management.py +4 -4
  55. edsl/jobs/runners/JobsRunnerAsyncio.py +103 -65
  56. edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
  57. edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
  58. edsl/jobs/tasks/TaskHistory.py +4 -3
  59. edsl/language_models/LanguageModel.py +42 -55
  60. edsl/language_models/ModelList.py +96 -0
  61. edsl/language_models/registry.py +14 -0
  62. edsl/language_models/repair.py +97 -25
  63. edsl/notebooks/Notebook.py +157 -32
  64. edsl/prompts/Prompt.py +31 -19
  65. edsl/questions/QuestionBase.py +145 -23
  66. edsl/questions/QuestionBudget.py +5 -6
  67. edsl/questions/QuestionCheckBox.py +7 -3
  68. edsl/questions/QuestionExtract.py +5 -3
  69. edsl/questions/QuestionFreeText.py +3 -3
  70. edsl/questions/QuestionFunctional.py +0 -3
  71. edsl/questions/QuestionList.py +3 -4
  72. edsl/questions/QuestionMultipleChoice.py +16 -8
  73. edsl/questions/QuestionNumerical.py +4 -3
  74. edsl/questions/QuestionRank.py +5 -3
  75. edsl/questions/__init__.py +4 -3
  76. edsl/questions/descriptors.py +9 -4
  77. edsl/questions/question_registry.py +27 -31
  78. edsl/questions/settings.py +1 -1
  79. edsl/results/Dataset.py +31 -0
  80. edsl/results/DatasetExportMixin.py +493 -0
  81. edsl/results/Result.py +42 -82
  82. edsl/results/Results.py +178 -66
  83. edsl/results/ResultsDBMixin.py +10 -9
  84. edsl/results/ResultsExportMixin.py +23 -507
  85. edsl/results/ResultsGGMixin.py +3 -3
  86. edsl/results/ResultsToolsMixin.py +9 -9
  87. edsl/scenarios/FileStore.py +140 -0
  88. edsl/scenarios/Scenario.py +59 -6
  89. edsl/scenarios/ScenarioList.py +138 -52
  90. edsl/scenarios/ScenarioListExportMixin.py +32 -0
  91. edsl/scenarios/ScenarioListPdfMixin.py +2 -1
  92. edsl/scenarios/__init__.py +1 -0
  93. edsl/study/ObjectEntry.py +173 -0
  94. edsl/study/ProofOfWork.py +113 -0
  95. edsl/study/SnapShot.py +73 -0
  96. edsl/study/Study.py +498 -0
  97. edsl/study/__init__.py +4 -0
  98. edsl/surveys/MemoryPlan.py +11 -4
  99. edsl/surveys/Survey.py +124 -37
  100. edsl/surveys/SurveyExportMixin.py +25 -5
  101. edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
  102. edsl/tools/plotting.py +4 -2
  103. edsl/utilities/__init__.py +21 -20
  104. edsl/utilities/gcp_bucket/__init__.py +0 -0
  105. edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
  106. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  107. edsl/utilities/interface.py +90 -73
  108. edsl/utilities/repair_functions.py +28 -0
  109. edsl/utilities/utilities.py +59 -6
  110. {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/METADATA +42 -15
  111. edsl-0.1.29.dist-info/RECORD +203 -0
  112. edsl/conjure/RawResponseColumn.py +0 -327
  113. edsl/conjure/SurveyBuilder.py +0 -308
  114. edsl/conjure/SurveyBuilderCSV.py +0 -78
  115. edsl/conjure/SurveyBuilderSPSS.py +0 -118
  116. edsl/data/RemoteDict.py +0 -103
  117. edsl-0.1.27.dev2.dist-info/RECORD +0 -172
  118. {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/LICENSE +0 -0
  119. {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/WHEEL +0 -0
edsl/coop/utils.py CHANGED
@@ -1,45 +1,49 @@
1
- from edsl import Agent, AgentList, Cache, Jobs, Results, Scenario, ScenarioList, Survey
2
- from edsl.notebooks import Notebook
1
+ from edsl import (
2
+ Agent,
3
+ AgentList,
4
+ Cache,
5
+ Notebook,
6
+ Results,
7
+ Scenario,
8
+ ScenarioList,
9
+ Survey,
10
+ Study,
11
+ )
3
12
  from edsl.questions import QuestionBase
4
- from typing import Literal, Type, Union
13
+ from typing import Literal, Optional, Type, Union
5
14
 
6
15
  EDSLObject = Union[
7
16
  Agent,
8
17
  AgentList,
9
18
  Cache,
10
- Jobs,
11
19
  Notebook,
12
20
  Type[QuestionBase],
13
21
  Results,
14
22
  Scenario,
15
23
  ScenarioList,
16
24
  Survey,
25
+ Study,
17
26
  ]
18
27
 
19
28
  ObjectType = Literal[
20
29
  "agent",
21
30
  "agent_list",
22
31
  "cache",
23
- "job",
24
- "question",
25
32
  "notebook",
33
+ "question",
26
34
  "results",
27
35
  "scenario",
28
36
  "scenario_list",
29
37
  "survey",
38
+ "study",
30
39
  ]
31
40
 
32
- ObjectPage = Literal[
33
- "agents",
34
- "agentlists",
35
- "caches",
36
- "jobs",
37
- "notebooks",
38
- "questions",
39
- "results",
40
- "scenarios",
41
- "scenariolists",
42
- "surveys",
41
+
42
+ RemoteJobStatus = Literal[
43
+ "queued",
44
+ "running",
45
+ "completed",
46
+ "failed",
43
47
  ]
44
48
 
45
49
  VisibilityType = Literal[
@@ -55,62 +59,21 @@ class ObjectRegistry:
55
59
  """
56
60
 
57
61
  objects = [
58
- {
59
- "object_type": "agent",
60
- "edsl_class": Agent,
61
- "object_page": "agents",
62
- },
63
- {
64
- "object_type": "agent_list",
65
- "edsl_class": AgentList,
66
- "object_page": "agentlists",
67
- },
68
- {
69
- "object_type": "cache",
70
- "edsl_class": Cache,
71
- "object_page": "caches",
72
- },
73
- {
74
- "object_type": "job",
75
- "edsl_class": Jobs,
76
- "object_page": "jobs",
77
- },
78
- {
79
- "object_type": "question",
80
- "edsl_class": QuestionBase,
81
- "object_page": "questions",
82
- },
83
- {
84
- "object_type": "notebook",
85
- "edsl_class": Notebook,
86
- "object_page": "notebooks",
87
- },
88
- {
89
- "object_type": "results",
90
- "edsl_class": Results,
91
- "object_page": "results",
92
- },
93
- {
94
- "object_type": "scenario",
95
- "edsl_class": Scenario,
96
- "object_page": "scenarios",
97
- },
98
- {
99
- "object_type": "scenario_list",
100
- "edsl_class": ScenarioList,
101
- "object_page": "scenariolists",
102
- },
103
- {
104
- "object_type": "survey",
105
- "edsl_class": Survey,
106
- "object_page": "surveys",
107
- },
62
+ {"object_type": "agent", "edsl_class": Agent},
63
+ {"object_type": "agent_list", "edsl_class": AgentList},
64
+ {"object_type": "cache", "edsl_class": Cache},
65
+ {"object_type": "question", "edsl_class": QuestionBase},
66
+ {"object_type": "notebook", "edsl_class": Notebook},
67
+ {"object_type": "results", "edsl_class": Results},
68
+ {"object_type": "scenario", "edsl_class": Scenario},
69
+ {"object_type": "scenario_list", "edsl_class": ScenarioList},
70
+ {"object_type": "survey", "edsl_class": Survey},
71
+ {"object_type": "study", "edsl_class": Study},
108
72
  ]
109
73
  object_type_to_edsl_class = {o["object_type"]: o["edsl_class"] for o in objects}
110
74
  edsl_class_to_object_type = {
111
75
  o["edsl_class"].__name__: o["object_type"] for o in objects
112
76
  }
113
- object_type_to_object_page = {o["object_type"]: o["object_page"] for o in objects}
114
77
 
115
78
  @classmethod
116
79
  def get_object_type_by_edsl_class(cls, edsl_object: EDSLObject) -> ObjectType:
@@ -133,5 +96,28 @@ class ObjectRegistry:
133
96
  return EDSL_object
134
97
 
135
98
  @classmethod
136
- def get_object_page_by_object_type(cls, object_type: ObjectType) -> ObjectPage:
137
- return cls.object_type_to_object_page.get(object_type)
99
+ def get_registry(
100
+ cls,
101
+ subclass_registry: Optional[dict] = None,
102
+ exclude_classes: Optional[list] = None,
103
+ ) -> dict:
104
+ """
105
+ Return the registry of objects.
106
+
107
+ Exclude objects that are already registered in subclass_registry.
108
+ This allows the user to isolate Coop-only objects.
109
+
110
+ Also exclude objects if their class name is in the exclude_classes list.
111
+ """
112
+
113
+ if subclass_registry is None:
114
+ subclass_registry = {}
115
+ if exclude_classes is None:
116
+ exclude_classes = []
117
+
118
+ return {
119
+ class_name: o["edsl_class"]
120
+ for o in cls.objects
121
+ if (class_name := o["edsl_class"].__name__) not in subclass_registry
122
+ and class_name not in exclude_classes
123
+ }
edsl/data/Cache.py CHANGED
@@ -7,12 +7,13 @@ import json
7
7
  import os
8
8
  import warnings
9
9
  from typing import Optional, Union
10
-
10
+ import time
11
11
  from edsl.config import CONFIG
12
12
  from edsl.data.CacheEntry import CacheEntry
13
- from edsl.data.SQLiteDict import SQLiteDict
14
- from edsl.Base import Base
15
13
 
14
+ # from edsl.data.SQLiteDict import SQLiteDict
15
+ from edsl.Base import Base
16
+ from edsl.utilities.utilities import dict_hash
16
17
  from edsl.utilities.decorators import (
17
18
  add_edsl_version,
18
19
  remove_edsl_version,
@@ -24,7 +25,6 @@ class Cache(Base):
24
25
  A class that represents a cache of responses from a language model.
25
26
 
26
27
  :param data: The data to initialize the cache with.
27
- :param remote: Whether to sync the Cache with the server.
28
28
  :param immediate_write: Whether to write to the cache immediately after storing a new entry.
29
29
 
30
30
  Deprecated:
@@ -37,24 +37,51 @@ class Cache(Base):
37
37
  def __init__(
38
38
  self,
39
39
  *,
40
- data: Optional[Union[SQLiteDict, dict]] = None,
41
- remote: bool = False,
40
+ filename: Optional[str] = None,
41
+ data: Optional[Union["SQLiteDict", dict]] = None,
42
42
  immediate_write: bool = True,
43
43
  method=None,
44
44
  ):
45
45
  """
46
46
  Create two dictionaries to store the cache data.
47
47
 
48
+ :param filename: The name of the file to read/write the cache from/to.
49
+ :param data: The data to initialize the cache with.
50
+ :param immediate_write: Whether to write to the cache immediately after storing a new entry.
51
+ :param method: The method of storage to use for the cache.
52
+
48
53
  """
49
- self.data = data or {}
54
+
50
55
  # self.data_at_init = data or {}
51
56
  self.fetched_data = {}
52
- self.remote = remote
53
57
  self.immediate_write = immediate_write
54
58
  self.method = method
55
59
  self.new_entries = {}
56
60
  self.new_entries_to_write_later = {}
57
61
  self.coop = None
62
+
63
+ self.filename = filename
64
+ if filename and data:
65
+ raise ValueError("Cannot provide both filename and data")
66
+ if filename is None and data is None:
67
+ data = {}
68
+ if data is not None:
69
+ self.data = data
70
+ if filename is not None:
71
+ self.data = {}
72
+ if filename.endswith(".jsonl"):
73
+ if os.path.exists(filename):
74
+ self.add_from_jsonl(filename)
75
+ else:
76
+ print(
77
+ f"File {filename} not found, but will write to this location."
78
+ )
79
+ elif filename.endswith(".db"):
80
+ if os.path.exists(filename):
81
+ self.add_from_sqlite(filename)
82
+ else:
83
+ raise ValueError("Invalid file extension. Must be .jsonl or .db")
84
+
58
85
  self._perform_checks()
59
86
 
60
87
  def rich_print(sefl):
@@ -77,14 +104,12 @@ class Cache(Base):
77
104
 
78
105
  def _perform_checks(self):
79
106
  """Perform checks on the cache."""
107
+ from edsl.data.CacheEntry import CacheEntry
108
+
80
109
  if any(not isinstance(value, CacheEntry) for value in self.data.values()):
81
110
  raise Exception("Not all values are CacheEntry instances")
82
111
  if self.method is not None:
83
112
  warnings.warn("Argument `method` is deprecated", DeprecationWarning)
84
- if self.remote:
85
- from edsl.coop import Coop
86
-
87
- self.coop = Coop()
88
113
 
89
114
  ####################
90
115
  # READ/WRITE
@@ -115,6 +140,8 @@ class Cache(Base):
115
140
 
116
141
 
117
142
  """
143
+ from edsl.data.CacheEntry import CacheEntry
144
+
118
145
  key = CacheEntry.gen_key(
119
146
  model=model,
120
147
  parameters=parameters,
@@ -148,6 +175,7 @@ class Cache(Base):
148
175
  * If `immediate_write` is True , the key-value pair is added to `self.data`
149
176
  * If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
150
177
  """
178
+
151
179
  entry = CacheEntry(
152
180
  model=model,
153
181
  parameters=parameters,
@@ -165,13 +193,14 @@ class Cache(Base):
165
193
  return key
166
194
 
167
195
  def add_from_dict(
168
- self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
196
+ self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
169
197
  ) -> None:
170
198
  """
171
199
  Add entries to the cache from a dictionary.
172
200
 
173
201
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
174
202
  """
203
+
175
204
  for key, value in new_data.items():
176
205
  if key in self.data:
177
206
  if value != self.data[key]:
@@ -208,6 +237,8 @@ class Cache(Base):
208
237
 
209
238
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
210
239
  """
240
+ from edsl.data.SQLiteDict import SQLiteDict
241
+
211
242
  db = SQLiteDict(db_path)
212
243
  new_data = {}
213
244
  for key, value in db.items():
@@ -219,6 +250,8 @@ class Cache(Base):
219
250
  """
220
251
  Construct a Cache from a SQLite database.
221
252
  """
253
+ from edsl.data.SQLiteDict import SQLiteDict
254
+
222
255
  return cls(data=SQLiteDict(db_path))
223
256
 
224
257
  @classmethod
@@ -245,6 +278,8 @@ class Cache(Base):
245
278
  * If `db_path` is provided, the cache will be stored in an SQLite database.
246
279
  """
247
280
  # if a file doesn't exist at jsonfile, throw an error
281
+ from edsl.data.SQLiteDict import SQLiteDict
282
+
248
283
  if not os.path.exists(jsonlfile):
249
284
  raise FileNotFoundError(f"File {jsonlfile} not found")
250
285
 
@@ -263,10 +298,25 @@ class Cache(Base):
263
298
  """
264
299
  ## TODO: Check to make sure not over-writing (?)
265
300
  ## Should be added to SQLiteDict constructor (?)
301
+ from edsl.data.SQLiteDict import SQLiteDict
302
+
266
303
  new_data = SQLiteDict(db_path)
267
304
  for key, value in self.data.items():
268
305
  new_data[key] = value
269
306
 
307
+ def write(self, filename: Optional[str] = None) -> None:
308
+ """
309
+ Write the cache to a file at the specified location.
310
+ """
311
+ if filename is None:
312
+ filename = self.filename
313
+ if filename.endswith(".jsonl"):
314
+ self.write_jsonl(filename)
315
+ elif filename.endswith(".db"):
316
+ self.write_sqlite_db(filename)
317
+ else:
318
+ raise ValueError("Invalid file extension. Must be .jsonl or .db")
319
+
270
320
  def write_jsonl(self, filename: str) -> None:
271
321
  """
272
322
  Write the cache to a JSONL file.
@@ -295,11 +345,6 @@ class Cache(Base):
295
345
  """
296
346
  Run when a context is entered.
297
347
  """
298
- if self.remote:
299
- print("Syncing local and remote caches")
300
- exclude_keys = list(self.data.keys())
301
- cache_entries = self.coop.get_cache_entries(exclude_keys)
302
- self.add_from_dict({c.key: c for c in cache_entries}, write_now=True)
303
348
  return self
304
349
 
305
350
  def __exit__(self, exc_type, exc_value, traceback):
@@ -308,16 +353,21 @@ class Cache(Base):
308
353
  """
309
354
  for key, entry in self.new_entries_to_write_later.items():
310
355
  self.data[key] = entry
311
- if self.remote:
312
- _ = self.coop.create_cache_entries(cache_dict=self.new_entries)
313
356
 
314
357
  ####################
315
358
  # DUNDER / USEFUL
316
359
  ####################
360
+ def __hash__(self):
361
+ """Return the hash of the Cache."""
362
+ return dict_hash(self._to_dict())
363
+
364
+ def _to_dict(self) -> dict:
365
+ return {k: v.to_dict() for k, v in self.data.items()}
366
+
317
367
  @add_edsl_version
318
368
  def to_dict(self) -> dict:
319
369
  """Return the Cache as a dictionary."""
320
- return {k: v.to_dict() for k, v in self.data.items()}
370
+ return self._to_dict()
321
371
 
322
372
  def _repr_html_(self):
323
373
  from edsl.utilities.utilities import data_to_html
@@ -359,7 +409,9 @@ class Cache(Base):
359
409
  """
360
410
  Return a string representation of the Cache object.
361
411
  """
362
- return f"Cache(data = {repr(self.data)}, immediate_write={self.immediate_write}, remote={self.remote})"
412
+ return (
413
+ f"Cache(data = {repr(self.data)}, immediate_write={self.immediate_write})"
414
+ )
363
415
 
364
416
  ####################
365
417
  # EXAMPLES
edsl/data/CacheHandler.py CHANGED
@@ -9,22 +9,22 @@ from edsl.data.Cache import Cache
9
9
  from edsl.data.CacheEntry import CacheEntry
10
10
  from edsl.data.SQLiteDict import SQLiteDict
11
11
 
12
+ from edsl.config import CONFIG
13
+
12
14
 
13
15
  def set_session_cache(cache: Cache) -> None:
14
16
  """
15
17
  Set the session cache.
16
18
  """
17
- print("All calls to 'run' will now use this cache by default.")
18
- global _CACHE
19
- _CACHE = cache
19
+ CONFIG.EDSL_SESSION_CACHE = cache
20
20
 
21
21
 
22
22
  def unset_session_cache() -> None:
23
23
  """
24
24
  Unset the session cache.
25
25
  """
26
- global _CACHE
27
- _CACHE = None
26
+ if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
27
+ del CONFIG.EDSL_SESSION_CACHE
28
28
 
29
29
 
30
30
  class CacheHandler:
@@ -49,7 +49,9 @@ class CacheHandler:
49
49
  dir_path = os.path.dirname(path)
50
50
  if dir_path and not os.path.exists(dir_path):
51
51
  os.makedirs(dir_path)
52
- print(f"Created cache directory: {dir_path}")
52
+ import warnings
53
+
54
+ warnings.warn(f"Created cache directory: {dir_path}")
53
55
 
54
56
  def gen_cache(self) -> Cache:
55
57
  """
@@ -58,9 +60,8 @@ class CacheHandler:
58
60
  if self.test:
59
61
  return Cache(data={})
60
62
 
61
- if "_CACHE" in globals() and _CACHE is not None:
62
- # print("Using globally-set cache.")
63
- return _CACHE
63
+ if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
64
+ return CONFIG.EDSL_SESSION_CACHE
64
65
 
65
66
  cache = Cache(data=SQLiteDict(self.CACHE_PATH))
66
67
  return cache
edsl/data/SQLiteDict.py CHANGED
@@ -1,9 +1,7 @@
1
1
  from __future__ import annotations
2
2
  import json
3
- from sqlalchemy import create_engine
4
- from sqlalchemy.exc import SQLAlchemyError
5
- from sqlalchemy.orm import sessionmaker
6
3
  from typing import Any, Generator, Optional, Union
4
+
7
5
  from edsl.config import CONFIG
8
6
  from edsl.data.CacheEntry import CacheEntry
9
7
  from edsl.data.orm import Base, Data
@@ -25,10 +23,16 @@ class SQLiteDict:
25
23
  >>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
26
24
 
27
25
  """
26
+ from sqlalchemy.exc import SQLAlchemyError
27
+ from sqlalchemy.orm import sessionmaker
28
+ from sqlalchemy import create_engine
29
+
28
30
  self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
29
31
  if not self.db_path.startswith("sqlite:///"):
30
32
  self.db_path = f"sqlite:///{self.db_path}"
31
33
  try:
34
+ from edsl.data.orm import Base, Data
35
+
32
36
  self.engine = create_engine(self.db_path, echo=False, future=True)
33
37
  Base.metadata.create_all(self.engine)
34
38
  self.Session = sessionmaker(bind=self.engine)
@@ -55,6 +59,8 @@ class SQLiteDict:
55
59
  if not isinstance(value, CacheEntry):
56
60
  raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
57
61
  with self.Session() as db:
62
+ from edsl.data.orm import Base, Data
63
+
58
64
  db.merge(Data(key=key, value=json.dumps(value.to_dict())))
59
65
  db.commit()
60
66
 
@@ -69,6 +75,8 @@ class SQLiteDict:
69
75
  True
70
76
  """
71
77
  with self.Session() as db:
78
+ from edsl.data.orm import Base, Data
79
+
72
80
  value = db.query(Data).filter_by(key=key).first()
73
81
  if not value:
74
82
  raise KeyError(f"Key '{key}' not found.")
edsl/inference_services/AnthropicService.py CHANGED
@@ -16,6 +16,7 @@ class AnthropicService(InferenceServiceABC):
16
16
  def available(cls):
17
17
  # TODO - replace with an API call
18
18
  return [
19
+ "claude-3-5-sonnet-20240620",
19
20
  "claude-3-opus-20240229",
20
21
  "claude-3-sonnet-20240229",
21
22
  "claude-3-haiku-20240307",
edsl/inference_services/DeepInfraService.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import aiohttp
2
2
  import json
3
3
  import requests
4
- from typing import Any
4
+ from typing import Any, List
5
5
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
6
6
  from edsl.language_models import LanguageModel
7
7
 
@@ -12,6 +12,8 @@ class DeepInfraService(InferenceServiceABC):
12
12
  _inference_service_ = "deep_infra"
13
13
  _env_key_name_ = "DEEP_INFRA_API_KEY"
14
14
 
15
+ _models_list_cache: List[str] = []
16
+
15
17
  @classmethod
16
18
  def available(cls):
17
19
  text_models = cls.full_details_available()
@@ -19,20 +21,25 @@ class DeepInfraService(InferenceServiceABC):
19
21
 
20
22
  @classmethod
21
23
  def full_details_available(cls, verbose=False):
22
- url = "https://api.deepinfra.com/models/list"
23
- response = requests.get(url)
24
- if response.status_code == 200:
25
- text_generation_models = [
26
- r for r in response.json() if r["type"] == "text-generation"
27
- ]
28
- from rich import print_json
29
- import json
24
+ if not cls._models_list_cache:
25
+ url = "https://api.deepinfra.com/models/list"
26
+ response = requests.get(url)
27
+ if response.status_code == 200:
28
+ text_generation_models = [
29
+ r for r in response.json() if r["type"] == "text-generation"
30
+ ]
31
+ cls._models_list_cache = text_generation_models
32
+
33
+ from rich import print_json
34
+ import json
30
35
 
31
- if verbose:
32
- print_json(json.dumps(text_generation_models))
33
- return text_generation_models
36
+ if verbose:
37
+ print_json(json.dumps(text_generation_models))
38
+ return text_generation_models
39
+ else:
40
+ return f"Failed to fetch data: Status code {response.status_code}"
34
41
  else:
35
- return f"Failed to fetch data: Status code {response.status_code}"
42
+ return cls._models_list_cache
36
43
 
37
44
  @classmethod
38
45
  def create_model(cls, model_name: str, model_class_name=None) -> LanguageModel:
edsl/inference_services/GoogleService.py CHANGED
@@ -60,7 +60,13 @@ class GoogleService(InferenceServiceABC):
60
60
 
61
61
  def parse_response(self, raw_response: dict[str, Any]) -> str:
62
62
  data = raw_response
63
- return data["candidates"][0]["content"]["parts"][0]["text"]
63
+ try:
64
+ return data["candidates"][0]["content"]["parts"][0]["text"]
65
+ except KeyError as e:
66
+ print(
67
+ f"The data return was {data}, which was missing the key 'candidates'"
68
+ )
69
+ raise e
64
70
 
65
71
  LLM.__name__ = model_name
66
72
 
edsl/inference_services/InferenceServicesCollection.py CHANGED
@@ -1,21 +1,47 @@
1
1
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
2
+ import warnings
2
3
 
3
4
 
4
5
  class InferenceServicesCollection:
6
+ added_models = {}
7
+
5
8
  def __init__(self, services: list[InferenceServiceABC] = None):
6
9
  self.services = services or []
7
10
 
11
+ @classmethod
12
+ def add_model(cls, service_name, model_name):
13
+ if service_name not in cls.added_models:
14
+ cls.added_models[service_name] = []
15
+ cls.added_models[service_name].append(model_name)
16
+
17
+ @staticmethod
18
+ def _get_service_available(service) -> list[str]:
19
+ from_api = True
20
+ try:
21
+ service_models = service.available()
22
+ except Exception as e:
23
+ warnings.warn(
24
+ f"Error getting models for {service._inference_service_}. Relying on cache.",
25
+ UserWarning,
26
+ )
27
+ from edsl.inference_services.models_available_cache import models_available
28
+
29
+ service_models = models_available.get(service._inference_service_, [])
30
+ # cache results
31
+ service._models_list_cache = service_models
32
+ from_api = False
33
+ return service_models # , from_api
34
+
8
35
  def available(self):
9
36
  total_models = []
10
37
  for service in self.services:
11
- try:
12
- service_models = service.available()
13
- except Exception as e:
14
- print(f"Error getting models for {service._inference_service_}: {e}")
15
- service_models = []
16
- continue
38
+ service_models = self._get_service_available(service)
17
39
  for model in service_models:
18
40
  total_models.append([model, service._inference_service_, -1])
41
+
42
+ for model in self.added_models.get(service._inference_service_, []):
43
+ total_models.append([model, service._inference_service_, -1])
44
+
19
45
  sorted_models = sorted(total_models)
20
46
  for i, model in enumerate(sorted_models):
21
47
  model[2] = i
@@ -27,7 +53,7 @@ class InferenceServicesCollection:
27
53
 
28
54
  def create_model_factory(self, model_name: str, service_name=None, index=None):
29
55
  for service in self.services:
30
- if model_name in service.available():
56
+ if model_name in self._get_service_available(service):
31
57
  if service_name is None or service_name == service._inference_service_:
32
58
  return service.create_model(model_name)
33
59
 
edsl/inference_services/OpenAIService.py CHANGED
@@ -4,6 +4,7 @@ from openai import AsyncOpenAI
4
4
 
5
5
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
6
6
  from edsl.language_models import LanguageModel
7
+ from edsl.inference_services.rate_limits_cache import rate_limits
7
8
 
8
9
 
9
10
  class OpenAIService(InferenceServiceABC):
@@ -43,15 +44,16 @@ class OpenAIService(InferenceServiceABC):
43
44
  if m.id not in cls.model_exclude_list
44
45
  ]
45
46
  except Exception as e:
46
- print(
47
- f"""Error retrieving models: {e}.
48
- See instructions about storing your API keys: https://docs.expectedparrot.com/en/latest/api_keys.html"""
49
- )
50
- cls._models_list_cache = [
51
- "gpt-3.5-turbo",
52
- "gpt-4-1106-preview",
53
- "gpt-4",
54
- ] # Fallback list
47
+ raise
48
+ # print(
49
+ # f"""Error retrieving models: {e}.
50
+ # See instructions about storing your API keys: https://docs.expectedparrot.com/en/latest/api_keys.html"""
51
+ # )
52
+ # cls._models_list_cache = [
53
+ # "gpt-3.5-turbo",
54
+ # "gpt-4-1106-preview",
55
+ # "gpt-4",
56
+ # ] # Fallback list
55
57
  return cls._models_list_cache
56
58
 
57
59
  @classmethod
@@ -98,7 +100,12 @@ class OpenAIService(InferenceServiceABC):
98
100
 
99
101
  def get_rate_limits(self) -> dict[str, Any]:
100
102
  try:
101
- headers = self.get_headers()
103
+ if "openai" in rate_limits:
104
+ headers = rate_limits["openai"]
105
+
106
+ else:
107
+ headers = self.get_headers()
108
+
102
109
  except Exception as e:
103
110
  return {
104
111
  "rpm": 10_000,