edsl 0.1.27.dev2__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +107 -30
- edsl/BaseDiff.py +260 -0
- edsl/__init__.py +25 -21
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +103 -46
- edsl/agents/AgentList.py +97 -13
- edsl/agents/Invigilator.py +23 -10
- edsl/agents/InvigilatorBase.py +19 -14
- edsl/agents/PromptConstructionMixin.py +342 -100
- edsl/agents/descriptors.py +5 -2
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/conjure/AgentConstructionMixin.py +152 -0
- edsl/conjure/Conjure.py +56 -0
- edsl/conjure/InputData.py +659 -0
- edsl/conjure/InputDataCSV.py +48 -0
- edsl/conjure/InputDataMixinQuestionStats.py +182 -0
- edsl/conjure/InputDataPyRead.py +91 -0
- edsl/conjure/InputDataSPSS.py +8 -0
- edsl/conjure/InputDataStata.py +8 -0
- edsl/conjure/QuestionOptionMixin.py +76 -0
- edsl/conjure/QuestionTypeMixin.py +23 -0
- edsl/conjure/RawQuestion.py +65 -0
- edsl/conjure/SurveyResponses.py +7 -0
- edsl/conjure/__init__.py +9 -4
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/naming_utilities.py +263 -0
- edsl/conjure/utilities.py +165 -28
- edsl/conversation/Conversation.py +238 -0
- edsl/conversation/car_buying.py +58 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/coop.py +337 -121
- edsl/coop/utils.py +56 -70
- edsl/data/Cache.py +74 -22
- edsl/data/CacheHandler.py +10 -9
- edsl/data/SQLiteDict.py +11 -3
- edsl/inference_services/AnthropicService.py +1 -0
- edsl/inference_services/DeepInfraService.py +20 -13
- edsl/inference_services/GoogleService.py +7 -1
- edsl/inference_services/InferenceServicesCollection.py +33 -7
- edsl/inference_services/OpenAIService.py +17 -10
- edsl/inference_services/models_available_cache.py +69 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +322 -73
- edsl/jobs/buckets/BucketCollection.py +9 -3
- edsl/jobs/buckets/ModelBuckets.py +4 -2
- edsl/jobs/buckets/TokenBucket.py +1 -2
- edsl/jobs/interviews/Interview.py +7 -10
- edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +39 -20
- edsl/jobs/interviews/retry_management.py +4 -4
- edsl/jobs/runners/JobsRunnerAsyncio.py +103 -65
- edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
- edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +42 -55
- edsl/language_models/ModelList.py +96 -0
- edsl/language_models/registry.py +14 -0
- edsl/language_models/repair.py +97 -25
- edsl/notebooks/Notebook.py +157 -32
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +145 -23
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +3 -3
- edsl/questions/QuestionFunctional.py +0 -3
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +16 -8
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +9 -4
- edsl/questions/question_registry.py +27 -31
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/DatasetExportMixin.py +493 -0
- edsl/results/Result.py +42 -82
- edsl/results/Results.py +178 -66
- edsl/results/ResultsDBMixin.py +10 -9
- edsl/results/ResultsExportMixin.py +23 -507
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +9 -9
- edsl/scenarios/FileStore.py +140 -0
- edsl/scenarios/Scenario.py +59 -6
- edsl/scenarios/ScenarioList.py +138 -52
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +73 -0
- edsl/study/Study.py +498 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +124 -37
- edsl/surveys/SurveyExportMixin.py +25 -5
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/tools/plotting.py +4 -2
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/interface.py +90 -73
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/utilities.py +59 -6
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/METADATA +42 -15
- edsl-0.1.29.dist-info/RECORD +203 -0
- edsl/conjure/RawResponseColumn.py +0 -327
- edsl/conjure/SurveyBuilder.py +0 -308
- edsl/conjure/SurveyBuilderCSV.py +0 -78
- edsl/conjure/SurveyBuilderSPSS.py +0 -118
- edsl/data/RemoteDict.py +0 -103
- edsl-0.1.27.dev2.dist-info/RECORD +0 -172
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/LICENSE +0 -0
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/WHEEL +0 -0
edsl/coop/utils.py
CHANGED
@@ -1,45 +1,49 @@
|
|
1
|
-
from edsl import
|
2
|
-
|
1
|
+
from edsl import (
|
2
|
+
Agent,
|
3
|
+
AgentList,
|
4
|
+
Cache,
|
5
|
+
Notebook,
|
6
|
+
Results,
|
7
|
+
Scenario,
|
8
|
+
ScenarioList,
|
9
|
+
Survey,
|
10
|
+
Study,
|
11
|
+
)
|
3
12
|
from edsl.questions import QuestionBase
|
4
|
-
from typing import Literal, Type, Union
|
13
|
+
from typing import Literal, Optional, Type, Union
|
5
14
|
|
6
15
|
EDSLObject = Union[
|
7
16
|
Agent,
|
8
17
|
AgentList,
|
9
18
|
Cache,
|
10
|
-
Jobs,
|
11
19
|
Notebook,
|
12
20
|
Type[QuestionBase],
|
13
21
|
Results,
|
14
22
|
Scenario,
|
15
23
|
ScenarioList,
|
16
24
|
Survey,
|
25
|
+
Study,
|
17
26
|
]
|
18
27
|
|
19
28
|
ObjectType = Literal[
|
20
29
|
"agent",
|
21
30
|
"agent_list",
|
22
31
|
"cache",
|
23
|
-
"job",
|
24
|
-
"question",
|
25
32
|
"notebook",
|
33
|
+
"question",
|
26
34
|
"results",
|
27
35
|
"scenario",
|
28
36
|
"scenario_list",
|
29
37
|
"survey",
|
38
|
+
"study",
|
30
39
|
]
|
31
40
|
|
32
|
-
|
33
|
-
|
34
|
-
"
|
35
|
-
"
|
36
|
-
"
|
37
|
-
"
|
38
|
-
"questions",
|
39
|
-
"results",
|
40
|
-
"scenarios",
|
41
|
-
"scenariolists",
|
42
|
-
"surveys",
|
41
|
+
|
42
|
+
RemoteJobStatus = Literal[
|
43
|
+
"queued",
|
44
|
+
"running",
|
45
|
+
"completed",
|
46
|
+
"failed",
|
43
47
|
]
|
44
48
|
|
45
49
|
VisibilityType = Literal[
|
@@ -55,62 +59,21 @@ class ObjectRegistry:
|
|
55
59
|
"""
|
56
60
|
|
57
61
|
objects = [
|
58
|
-
{
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
},
|
63
|
-
{
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
},
|
68
|
-
{
|
69
|
-
"object_type": "cache",
|
70
|
-
"edsl_class": Cache,
|
71
|
-
"object_page": "caches",
|
72
|
-
},
|
73
|
-
{
|
74
|
-
"object_type": "job",
|
75
|
-
"edsl_class": Jobs,
|
76
|
-
"object_page": "jobs",
|
77
|
-
},
|
78
|
-
{
|
79
|
-
"object_type": "question",
|
80
|
-
"edsl_class": QuestionBase,
|
81
|
-
"object_page": "questions",
|
82
|
-
},
|
83
|
-
{
|
84
|
-
"object_type": "notebook",
|
85
|
-
"edsl_class": Notebook,
|
86
|
-
"object_page": "notebooks",
|
87
|
-
},
|
88
|
-
{
|
89
|
-
"object_type": "results",
|
90
|
-
"edsl_class": Results,
|
91
|
-
"object_page": "results",
|
92
|
-
},
|
93
|
-
{
|
94
|
-
"object_type": "scenario",
|
95
|
-
"edsl_class": Scenario,
|
96
|
-
"object_page": "scenarios",
|
97
|
-
},
|
98
|
-
{
|
99
|
-
"object_type": "scenario_list",
|
100
|
-
"edsl_class": ScenarioList,
|
101
|
-
"object_page": "scenariolists",
|
102
|
-
},
|
103
|
-
{
|
104
|
-
"object_type": "survey",
|
105
|
-
"edsl_class": Survey,
|
106
|
-
"object_page": "surveys",
|
107
|
-
},
|
62
|
+
{"object_type": "agent", "edsl_class": Agent},
|
63
|
+
{"object_type": "agent_list", "edsl_class": AgentList},
|
64
|
+
{"object_type": "cache", "edsl_class": Cache},
|
65
|
+
{"object_type": "question", "edsl_class": QuestionBase},
|
66
|
+
{"object_type": "notebook", "edsl_class": Notebook},
|
67
|
+
{"object_type": "results", "edsl_class": Results},
|
68
|
+
{"object_type": "scenario", "edsl_class": Scenario},
|
69
|
+
{"object_type": "scenario_list", "edsl_class": ScenarioList},
|
70
|
+
{"object_type": "survey", "edsl_class": Survey},
|
71
|
+
{"object_type": "study", "edsl_class": Study},
|
108
72
|
]
|
109
73
|
object_type_to_edsl_class = {o["object_type"]: o["edsl_class"] for o in objects}
|
110
74
|
edsl_class_to_object_type = {
|
111
75
|
o["edsl_class"].__name__: o["object_type"] for o in objects
|
112
76
|
}
|
113
|
-
object_type_to_object_page = {o["object_type"]: o["object_page"] for o in objects}
|
114
77
|
|
115
78
|
@classmethod
|
116
79
|
def get_object_type_by_edsl_class(cls, edsl_object: EDSLObject) -> ObjectType:
|
@@ -133,5 +96,28 @@ class ObjectRegistry:
|
|
133
96
|
return EDSL_object
|
134
97
|
|
135
98
|
@classmethod
|
136
|
-
def
|
137
|
-
|
99
|
+
def get_registry(
|
100
|
+
cls,
|
101
|
+
subclass_registry: Optional[dict] = None,
|
102
|
+
exclude_classes: Optional[list] = None,
|
103
|
+
) -> dict:
|
104
|
+
"""
|
105
|
+
Return the registry of objects.
|
106
|
+
|
107
|
+
Exclude objects that are already registered in subclass_registry.
|
108
|
+
This allows the user to isolate Coop-only objects.
|
109
|
+
|
110
|
+
Also exclude objects if their class name is in the exclude_classes list.
|
111
|
+
"""
|
112
|
+
|
113
|
+
if subclass_registry is None:
|
114
|
+
subclass_registry = {}
|
115
|
+
if exclude_classes is None:
|
116
|
+
exclude_classes = []
|
117
|
+
|
118
|
+
return {
|
119
|
+
class_name: o["edsl_class"]
|
120
|
+
for o in cls.objects
|
121
|
+
if (class_name := o["edsl_class"].__name__) not in subclass_registry
|
122
|
+
and class_name not in exclude_classes
|
123
|
+
}
|
edsl/data/Cache.py
CHANGED
@@ -7,12 +7,13 @@ import json
|
|
7
7
|
import os
|
8
8
|
import warnings
|
9
9
|
from typing import Optional, Union
|
10
|
-
|
10
|
+
import time
|
11
11
|
from edsl.config import CONFIG
|
12
12
|
from edsl.data.CacheEntry import CacheEntry
|
13
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
14
|
-
from edsl.Base import Base
|
15
13
|
|
14
|
+
# from edsl.data.SQLiteDict import SQLiteDict
|
15
|
+
from edsl.Base import Base
|
16
|
+
from edsl.utilities.utilities import dict_hash
|
16
17
|
from edsl.utilities.decorators import (
|
17
18
|
add_edsl_version,
|
18
19
|
remove_edsl_version,
|
@@ -24,7 +25,6 @@ class Cache(Base):
|
|
24
25
|
A class that represents a cache of responses from a language model.
|
25
26
|
|
26
27
|
:param data: The data to initialize the cache with.
|
27
|
-
:param remote: Whether to sync the Cache with the server.
|
28
28
|
:param immediate_write: Whether to write to the cache immediately after storing a new entry.
|
29
29
|
|
30
30
|
Deprecated:
|
@@ -37,24 +37,51 @@ class Cache(Base):
|
|
37
37
|
def __init__(
|
38
38
|
self,
|
39
39
|
*,
|
40
|
-
|
41
|
-
|
40
|
+
filename: Optional[str] = None,
|
41
|
+
data: Optional[Union["SQLiteDict", dict]] = None,
|
42
42
|
immediate_write: bool = True,
|
43
43
|
method=None,
|
44
44
|
):
|
45
45
|
"""
|
46
46
|
Create two dictionaries to store the cache data.
|
47
47
|
|
48
|
+
:param filename: The name of the file to read/write the cache from/to.
|
49
|
+
:param data: The data to initialize the cache with.
|
50
|
+
:param immediate_write: Whether to write to the cache immediately after storing a new entry.
|
51
|
+
:param method: The method of storage to use for the cache.
|
52
|
+
|
48
53
|
"""
|
49
|
-
|
54
|
+
|
50
55
|
# self.data_at_init = data or {}
|
51
56
|
self.fetched_data = {}
|
52
|
-
self.remote = remote
|
53
57
|
self.immediate_write = immediate_write
|
54
58
|
self.method = method
|
55
59
|
self.new_entries = {}
|
56
60
|
self.new_entries_to_write_later = {}
|
57
61
|
self.coop = None
|
62
|
+
|
63
|
+
self.filename = filename
|
64
|
+
if filename and data:
|
65
|
+
raise ValueError("Cannot provide both filename and data")
|
66
|
+
if filename is None and data is None:
|
67
|
+
data = {}
|
68
|
+
if data is not None:
|
69
|
+
self.data = data
|
70
|
+
if filename is not None:
|
71
|
+
self.data = {}
|
72
|
+
if filename.endswith(".jsonl"):
|
73
|
+
if os.path.exists(filename):
|
74
|
+
self.add_from_jsonl(filename)
|
75
|
+
else:
|
76
|
+
print(
|
77
|
+
f"File {filename} not found, but will write to this location."
|
78
|
+
)
|
79
|
+
elif filename.endswith(".db"):
|
80
|
+
if os.path.exists(filename):
|
81
|
+
self.add_from_sqlite(filename)
|
82
|
+
else:
|
83
|
+
raise ValueError("Invalid file extension. Must be .jsonl or .db")
|
84
|
+
|
58
85
|
self._perform_checks()
|
59
86
|
|
60
87
|
def rich_print(sefl):
|
@@ -77,14 +104,12 @@ class Cache(Base):
|
|
77
104
|
|
78
105
|
def _perform_checks(self):
|
79
106
|
"""Perform checks on the cache."""
|
107
|
+
from edsl.data.CacheEntry import CacheEntry
|
108
|
+
|
80
109
|
if any(not isinstance(value, CacheEntry) for value in self.data.values()):
|
81
110
|
raise Exception("Not all values are CacheEntry instances")
|
82
111
|
if self.method is not None:
|
83
112
|
warnings.warn("Argument `method` is deprecated", DeprecationWarning)
|
84
|
-
if self.remote:
|
85
|
-
from edsl.coop import Coop
|
86
|
-
|
87
|
-
self.coop = Coop()
|
88
113
|
|
89
114
|
####################
|
90
115
|
# READ/WRITE
|
@@ -115,6 +140,8 @@ class Cache(Base):
|
|
115
140
|
|
116
141
|
|
117
142
|
"""
|
143
|
+
from edsl.data.CacheEntry import CacheEntry
|
144
|
+
|
118
145
|
key = CacheEntry.gen_key(
|
119
146
|
model=model,
|
120
147
|
parameters=parameters,
|
@@ -148,6 +175,7 @@ class Cache(Base):
|
|
148
175
|
* If `immediate_write` is True , the key-value pair is added to `self.data`
|
149
176
|
* If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
|
150
177
|
"""
|
178
|
+
|
151
179
|
entry = CacheEntry(
|
152
180
|
model=model,
|
153
181
|
parameters=parameters,
|
@@ -165,13 +193,14 @@ class Cache(Base):
|
|
165
193
|
return key
|
166
194
|
|
167
195
|
def add_from_dict(
|
168
|
-
self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
|
196
|
+
self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
|
169
197
|
) -> None:
|
170
198
|
"""
|
171
199
|
Add entries to the cache from a dictionary.
|
172
200
|
|
173
201
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
174
202
|
"""
|
203
|
+
|
175
204
|
for key, value in new_data.items():
|
176
205
|
if key in self.data:
|
177
206
|
if value != self.data[key]:
|
@@ -208,6 +237,8 @@ class Cache(Base):
|
|
208
237
|
|
209
238
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
210
239
|
"""
|
240
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
241
|
+
|
211
242
|
db = SQLiteDict(db_path)
|
212
243
|
new_data = {}
|
213
244
|
for key, value in db.items():
|
@@ -219,6 +250,8 @@ class Cache(Base):
|
|
219
250
|
"""
|
220
251
|
Construct a Cache from a SQLite database.
|
221
252
|
"""
|
253
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
254
|
+
|
222
255
|
return cls(data=SQLiteDict(db_path))
|
223
256
|
|
224
257
|
@classmethod
|
@@ -245,6 +278,8 @@ class Cache(Base):
|
|
245
278
|
* If `db_path` is provided, the cache will be stored in an SQLite database.
|
246
279
|
"""
|
247
280
|
# if a file doesn't exist at jsonfile, throw an error
|
281
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
282
|
+
|
248
283
|
if not os.path.exists(jsonlfile):
|
249
284
|
raise FileNotFoundError(f"File {jsonlfile} not found")
|
250
285
|
|
@@ -263,10 +298,25 @@ class Cache(Base):
|
|
263
298
|
"""
|
264
299
|
## TODO: Check to make sure not over-writing (?)
|
265
300
|
## Should be added to SQLiteDict constructor (?)
|
301
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
302
|
+
|
266
303
|
new_data = SQLiteDict(db_path)
|
267
304
|
for key, value in self.data.items():
|
268
305
|
new_data[key] = value
|
269
306
|
|
307
|
+
def write(self, filename: Optional[str] = None) -> None:
|
308
|
+
"""
|
309
|
+
Write the cache to a file at the specified location.
|
310
|
+
"""
|
311
|
+
if filename is None:
|
312
|
+
filename = self.filename
|
313
|
+
if filename.endswith(".jsonl"):
|
314
|
+
self.write_jsonl(filename)
|
315
|
+
elif filename.endswith(".db"):
|
316
|
+
self.write_sqlite_db(filename)
|
317
|
+
else:
|
318
|
+
raise ValueError("Invalid file extension. Must be .jsonl or .db")
|
319
|
+
|
270
320
|
def write_jsonl(self, filename: str) -> None:
|
271
321
|
"""
|
272
322
|
Write the cache to a JSONL file.
|
@@ -295,11 +345,6 @@ class Cache(Base):
|
|
295
345
|
"""
|
296
346
|
Run when a context is entered.
|
297
347
|
"""
|
298
|
-
if self.remote:
|
299
|
-
print("Syncing local and remote caches")
|
300
|
-
exclude_keys = list(self.data.keys())
|
301
|
-
cache_entries = self.coop.get_cache_entries(exclude_keys)
|
302
|
-
self.add_from_dict({c.key: c for c in cache_entries}, write_now=True)
|
303
348
|
return self
|
304
349
|
|
305
350
|
def __exit__(self, exc_type, exc_value, traceback):
|
@@ -308,16 +353,21 @@ class Cache(Base):
|
|
308
353
|
"""
|
309
354
|
for key, entry in self.new_entries_to_write_later.items():
|
310
355
|
self.data[key] = entry
|
311
|
-
if self.remote:
|
312
|
-
_ = self.coop.create_cache_entries(cache_dict=self.new_entries)
|
313
356
|
|
314
357
|
####################
|
315
358
|
# DUNDER / USEFUL
|
316
359
|
####################
|
360
|
+
def __hash__(self):
|
361
|
+
"""Return the hash of the Cache."""
|
362
|
+
return dict_hash(self._to_dict())
|
363
|
+
|
364
|
+
def _to_dict(self) -> dict:
|
365
|
+
return {k: v.to_dict() for k, v in self.data.items()}
|
366
|
+
|
317
367
|
@add_edsl_version
|
318
368
|
def to_dict(self) -> dict:
|
319
369
|
"""Return the Cache as a dictionary."""
|
320
|
-
return
|
370
|
+
return self._to_dict()
|
321
371
|
|
322
372
|
def _repr_html_(self):
|
323
373
|
from edsl.utilities.utilities import data_to_html
|
@@ -359,7 +409,9 @@ class Cache(Base):
|
|
359
409
|
"""
|
360
410
|
Return a string representation of the Cache object.
|
361
411
|
"""
|
362
|
-
return
|
412
|
+
return (
|
413
|
+
f"Cache(data = {repr(self.data)}, immediate_write={self.immediate_write})"
|
414
|
+
)
|
363
415
|
|
364
416
|
####################
|
365
417
|
# EXAMPLES
|
edsl/data/CacheHandler.py
CHANGED
@@ -9,22 +9,22 @@ from edsl.data.Cache import Cache
|
|
9
9
|
from edsl.data.CacheEntry import CacheEntry
|
10
10
|
from edsl.data.SQLiteDict import SQLiteDict
|
11
11
|
|
12
|
+
from edsl.config import CONFIG
|
13
|
+
|
12
14
|
|
13
15
|
def set_session_cache(cache: Cache) -> None:
|
14
16
|
"""
|
15
17
|
Set the session cache.
|
16
18
|
"""
|
17
|
-
|
18
|
-
global _CACHE
|
19
|
-
_CACHE = cache
|
19
|
+
CONFIG.EDSL_SESSION_CACHE = cache
|
20
20
|
|
21
21
|
|
22
22
|
def unset_session_cache() -> None:
|
23
23
|
"""
|
24
24
|
Unset the session cache.
|
25
25
|
"""
|
26
|
-
|
27
|
-
|
26
|
+
if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
|
27
|
+
del CONFIG.EDSL_SESSION_CACHE
|
28
28
|
|
29
29
|
|
30
30
|
class CacheHandler:
|
@@ -49,7 +49,9 @@ class CacheHandler:
|
|
49
49
|
dir_path = os.path.dirname(path)
|
50
50
|
if dir_path and not os.path.exists(dir_path):
|
51
51
|
os.makedirs(dir_path)
|
52
|
-
|
52
|
+
import warnings
|
53
|
+
|
54
|
+
warnings.warn(f"Created cache directory: {dir_path}")
|
53
55
|
|
54
56
|
def gen_cache(self) -> Cache:
|
55
57
|
"""
|
@@ -58,9 +60,8 @@ class CacheHandler:
|
|
58
60
|
if self.test:
|
59
61
|
return Cache(data={})
|
60
62
|
|
61
|
-
if "
|
62
|
-
|
63
|
-
return _CACHE
|
63
|
+
if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
|
64
|
+
return CONFIG.EDSL_SESSION_CACHE
|
64
65
|
|
65
66
|
cache = Cache(data=SQLiteDict(self.CACHE_PATH))
|
66
67
|
return cache
|
edsl/data/SQLiteDict.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import json
|
3
|
-
from sqlalchemy import create_engine
|
4
|
-
from sqlalchemy.exc import SQLAlchemyError
|
5
|
-
from sqlalchemy.orm import sessionmaker
|
6
3
|
from typing import Any, Generator, Optional, Union
|
4
|
+
|
7
5
|
from edsl.config import CONFIG
|
8
6
|
from edsl.data.CacheEntry import CacheEntry
|
9
7
|
from edsl.data.orm import Base, Data
|
@@ -25,10 +23,16 @@ class SQLiteDict:
|
|
25
23
|
>>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
|
26
24
|
|
27
25
|
"""
|
26
|
+
from sqlalchemy.exc import SQLAlchemyError
|
27
|
+
from sqlalchemy.orm import sessionmaker
|
28
|
+
from sqlalchemy import create_engine
|
29
|
+
|
28
30
|
self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
|
29
31
|
if not self.db_path.startswith("sqlite:///"):
|
30
32
|
self.db_path = f"sqlite:///{self.db_path}"
|
31
33
|
try:
|
34
|
+
from edsl.data.orm import Base, Data
|
35
|
+
|
32
36
|
self.engine = create_engine(self.db_path, echo=False, future=True)
|
33
37
|
Base.metadata.create_all(self.engine)
|
34
38
|
self.Session = sessionmaker(bind=self.engine)
|
@@ -55,6 +59,8 @@ class SQLiteDict:
|
|
55
59
|
if not isinstance(value, CacheEntry):
|
56
60
|
raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
|
57
61
|
with self.Session() as db:
|
62
|
+
from edsl.data.orm import Base, Data
|
63
|
+
|
58
64
|
db.merge(Data(key=key, value=json.dumps(value.to_dict())))
|
59
65
|
db.commit()
|
60
66
|
|
@@ -69,6 +75,8 @@ class SQLiteDict:
|
|
69
75
|
True
|
70
76
|
"""
|
71
77
|
with self.Session() as db:
|
78
|
+
from edsl.data.orm import Base, Data
|
79
|
+
|
72
80
|
value = db.query(Data).filter_by(key=key).first()
|
73
81
|
if not value:
|
74
82
|
raise KeyError(f"Key '{key}' not found.")
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import aiohttp
|
2
2
|
import json
|
3
3
|
import requests
|
4
|
-
from typing import Any
|
4
|
+
from typing import Any, List
|
5
5
|
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
6
6
|
from edsl.language_models import LanguageModel
|
7
7
|
|
@@ -12,6 +12,8 @@ class DeepInfraService(InferenceServiceABC):
|
|
12
12
|
_inference_service_ = "deep_infra"
|
13
13
|
_env_key_name_ = "DEEP_INFRA_API_KEY"
|
14
14
|
|
15
|
+
_models_list_cache: List[str] = []
|
16
|
+
|
15
17
|
@classmethod
|
16
18
|
def available(cls):
|
17
19
|
text_models = cls.full_details_available()
|
@@ -19,20 +21,25 @@ class DeepInfraService(InferenceServiceABC):
|
|
19
21
|
|
20
22
|
@classmethod
|
21
23
|
def full_details_available(cls, verbose=False):
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
24
|
+
if not cls._models_list_cache:
|
25
|
+
url = "https://api.deepinfra.com/models/list"
|
26
|
+
response = requests.get(url)
|
27
|
+
if response.status_code == 200:
|
28
|
+
text_generation_models = [
|
29
|
+
r for r in response.json() if r["type"] == "text-generation"
|
30
|
+
]
|
31
|
+
cls._models_list_cache = text_generation_models
|
32
|
+
|
33
|
+
from rich import print_json
|
34
|
+
import json
|
30
35
|
|
31
|
-
|
32
|
-
|
33
|
-
|
36
|
+
if verbose:
|
37
|
+
print_json(json.dumps(text_generation_models))
|
38
|
+
return text_generation_models
|
39
|
+
else:
|
40
|
+
return f"Failed to fetch data: Status code {response.status_code}"
|
34
41
|
else:
|
35
|
-
return
|
42
|
+
return cls._models_list_cache
|
36
43
|
|
37
44
|
@classmethod
|
38
45
|
def create_model(cls, model_name: str, model_class_name=None) -> LanguageModel:
|
@@ -60,7 +60,13 @@ class GoogleService(InferenceServiceABC):
|
|
60
60
|
|
61
61
|
def parse_response(self, raw_response: dict[str, Any]) -> str:
|
62
62
|
data = raw_response
|
63
|
-
|
63
|
+
try:
|
64
|
+
return data["candidates"][0]["content"]["parts"][0]["text"]
|
65
|
+
except KeyError as e:
|
66
|
+
print(
|
67
|
+
f"The data return was {data}, which was missing the key 'candidates'"
|
68
|
+
)
|
69
|
+
raise e
|
64
70
|
|
65
71
|
LLM.__name__ = model_name
|
66
72
|
|
@@ -1,21 +1,47 @@
|
|
1
1
|
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
2
|
+
import warnings
|
2
3
|
|
3
4
|
|
4
5
|
class InferenceServicesCollection:
|
6
|
+
added_models = {}
|
7
|
+
|
5
8
|
def __init__(self, services: list[InferenceServiceABC] = None):
|
6
9
|
self.services = services or []
|
7
10
|
|
11
|
+
@classmethod
|
12
|
+
def add_model(cls, service_name, model_name):
|
13
|
+
if service_name not in cls.added_models:
|
14
|
+
cls.added_models[service_name] = []
|
15
|
+
cls.added_models[service_name].append(model_name)
|
16
|
+
|
17
|
+
@staticmethod
|
18
|
+
def _get_service_available(service) -> list[str]:
|
19
|
+
from_api = True
|
20
|
+
try:
|
21
|
+
service_models = service.available()
|
22
|
+
except Exception as e:
|
23
|
+
warnings.warn(
|
24
|
+
f"Error getting models for {service._inference_service_}. Relying on cache.",
|
25
|
+
UserWarning,
|
26
|
+
)
|
27
|
+
from edsl.inference_services.models_available_cache import models_available
|
28
|
+
|
29
|
+
service_models = models_available.get(service._inference_service_, [])
|
30
|
+
# cache results
|
31
|
+
service._models_list_cache = service_models
|
32
|
+
from_api = False
|
33
|
+
return service_models # , from_api
|
34
|
+
|
8
35
|
def available(self):
|
9
36
|
total_models = []
|
10
37
|
for service in self.services:
|
11
|
-
|
12
|
-
service_models = service.available()
|
13
|
-
except Exception as e:
|
14
|
-
print(f"Error getting models for {service._inference_service_}: {e}")
|
15
|
-
service_models = []
|
16
|
-
continue
|
38
|
+
service_models = self._get_service_available(service)
|
17
39
|
for model in service_models:
|
18
40
|
total_models.append([model, service._inference_service_, -1])
|
41
|
+
|
42
|
+
for model in self.added_models.get(service._inference_service_, []):
|
43
|
+
total_models.append([model, service._inference_service_, -1])
|
44
|
+
|
19
45
|
sorted_models = sorted(total_models)
|
20
46
|
for i, model in enumerate(sorted_models):
|
21
47
|
model[2] = i
|
@@ -27,7 +53,7 @@ class InferenceServicesCollection:
|
|
27
53
|
|
28
54
|
def create_model_factory(self, model_name: str, service_name=None, index=None):
|
29
55
|
for service in self.services:
|
30
|
-
if model_name in
|
56
|
+
if model_name in self._get_service_available(service):
|
31
57
|
if service_name is None or service_name == service._inference_service_:
|
32
58
|
return service.create_model(model_name)
|
33
59
|
|
@@ -4,6 +4,7 @@ from openai import AsyncOpenAI
|
|
4
4
|
|
5
5
|
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
6
6
|
from edsl.language_models import LanguageModel
|
7
|
+
from edsl.inference_services.rate_limits_cache import rate_limits
|
7
8
|
|
8
9
|
|
9
10
|
class OpenAIService(InferenceServiceABC):
|
@@ -43,15 +44,16 @@ class OpenAIService(InferenceServiceABC):
|
|
43
44
|
if m.id not in cls.model_exclude_list
|
44
45
|
]
|
45
46
|
except Exception as e:
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
47
|
+
raise
|
48
|
+
# print(
|
49
|
+
# f"""Error retrieving models: {e}.
|
50
|
+
# See instructions about storing your API keys: https://docs.expectedparrot.com/en/latest/api_keys.html"""
|
51
|
+
# )
|
52
|
+
# cls._models_list_cache = [
|
53
|
+
# "gpt-3.5-turbo",
|
54
|
+
# "gpt-4-1106-preview",
|
55
|
+
# "gpt-4",
|
56
|
+
# ] # Fallback list
|
55
57
|
return cls._models_list_cache
|
56
58
|
|
57
59
|
@classmethod
|
@@ -98,7 +100,12 @@ class OpenAIService(InferenceServiceABC):
|
|
98
100
|
|
99
101
|
def get_rate_limits(self) -> dict[str, Any]:
|
100
102
|
try:
|
101
|
-
|
103
|
+
if "openai" in rate_limits:
|
104
|
+
headers = rate_limits["openai"]
|
105
|
+
|
106
|
+
else:
|
107
|
+
headers = self.get_headers()
|
108
|
+
|
102
109
|
except Exception as e:
|
103
110
|
return {
|
104
111
|
"rpm": 10_000,
|