edsl 0.1.29.dev6__py3-none-any.whl → 0.1.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +6 -3
- edsl/__init__.py +23 -23
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +43 -40
- edsl/agents/AgentList.py +23 -22
- edsl/agents/Invigilator.py +19 -2
- edsl/agents/descriptors.py +2 -1
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/conversation/car_buying.py +1 -1
- edsl/coop/utils.py +28 -1
- edsl/data/Cache.py +41 -18
- edsl/data/CacheEntry.py +6 -7
- edsl/data/SQLiteDict.py +11 -3
- edsl/data_transfer_models.py +4 -0
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +86 -33
- edsl/jobs/buckets/ModelBuckets.py +14 -2
- edsl/jobs/buckets/TokenBucket.py +32 -5
- edsl/jobs/interviews/Interview.py +99 -79
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +18 -24
- edsl/jobs/runners/JobsRunnerAsyncio.py +16 -16
- edsl/jobs/tasks/QuestionTaskCreator.py +10 -6
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +17 -17
- edsl/language_models/ModelList.py +1 -1
- edsl/language_models/repair.py +8 -7
- edsl/notebooks/Notebook.py +16 -10
- edsl/questions/QuestionBase.py +6 -2
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +7 -5
- edsl/questions/QuestionFunctional.py +34 -5
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +68 -12
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +46 -4
- edsl/results/DatasetExportMixin.py +570 -0
- edsl/results/Result.py +66 -70
- edsl/results/Results.py +160 -68
- edsl/results/ResultsDBMixin.py +7 -3
- edsl/results/ResultsExportMixin.py +22 -537
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +1 -4
- edsl/scenarios/FileStore.py +299 -0
- edsl/scenarios/Scenario.py +16 -24
- edsl/scenarios/ScenarioList.py +25 -14
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/study/Study.py +5 -7
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +52 -15
- edsl/surveys/SurveyExportMixin.py +4 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/utilities/__init__.py +21 -21
- edsl/utilities/interface.py +66 -45
- edsl/utilities/utilities.py +11 -13
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/METADATA +1 -1
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/RECORD +65 -61
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/WHEEL +1 -1
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/LICENSE +0 -0
edsl/base/Base.py
ADDED
@@ -0,0 +1,289 @@
|
|
1
|
+
"""Base class for all classes in the package. It provides rich printing and persistence of objects."""
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod, ABCMeta
|
4
|
+
import gzip
|
5
|
+
import io
|
6
|
+
import json
|
7
|
+
from typing import Any, Optional, Union
|
8
|
+
from uuid import UUID
|
9
|
+
from IPython.display import display
|
10
|
+
from rich.console import Console
|
11
|
+
|
12
|
+
|
13
|
+
class RichPrintingMixin:
    """Mixin that gives objects a rich-rendered text form and a ``print`` helper.

    Classes using it must provide ``rich_print()``, returning an object that
    ``rich``'s ``Console.print`` (and IPython's ``display``) can render.
    """

    def _for_console(self):
        """Render ``rich_print()`` into an in-memory recording console and return its plain text."""
        with io.StringIO() as sink:
            recorder = Console(file=sink, record=True)
            recorder.print(self.rich_print())
            return recorder.export_text()

    def __str__(self):
        """Return the console (plain-text) rendering of the object."""
        return self._for_console()

    def print(self):
        """Show the object: via ``display`` in a notebook, otherwise on a rich console."""
        from edsl.utilities.utilities import is_notebook

        if is_notebook():
            display(self.rich_print())
            return
        # ``Console`` is already imported at module level.
        Console().print(self.rich_print())
|
39
|
+
|
40
|
+
|
41
|
+
class PersistenceMixin:
    """Mixin for saving/loading objects to and from files and syncing them with Coop.

    Classes using this mixin are expected to implement ``to_dict()`` and a
    ``from_dict()`` classmethod; ``save`` and ``load`` round-trip through them.
    """

    def push(
        self,
        description: Optional[str] = None,
        visibility: Optional[str] = "unlisted",
    ):
        """Post the object to coop and return the result of ``Coop.create``."""
        from edsl.coop import Coop

        c = Coop()
        return c.create(self, description, visibility)

    @classmethod
    def pull(cls, id_or_url: Union[str, UUID], exec_profile=None):
        """Pull an object of this class from coop.

        :param id_or_url: the object's UUID (``str`` or ``uuid.UUID``), or a URL
            whose last path segment is the UUID.
        :param exec_profile: passed through to ``Coop._get_base``.
        """
        from edsl.coop import Coop

        # Normalize first: a uuid.UUID has no ``startswith`` method.
        id_or_url = str(id_or_url)
        if id_or_url.startswith("http"):
            # Ignore a trailing slash so ".../<uuid>/" still yields the uuid.
            uuid_value = id_or_url.rstrip("/").split("/")[-1]
        else:
            uuid_value = id_or_url

        c = Coop()
        return c._get_base(cls, uuid_value, exec_profile=exec_profile)

    @classmethod
    def delete(cls, id_or_url: Union[str, UUID]):
        """Delete the object from coop."""
        from edsl.coop import Coop

        c = Coop()
        return c._delete_base(cls, id_or_url)

    @classmethod
    def patch(
        cls,
        id_or_url: Union[str, UUID],
        description: Optional[str] = None,
        value: Optional[Any] = None,
        visibility: Optional[str] = None,
    ):
        """
        Patch an uploaded object's attributes.
        - `description` changes the description of the object on Coop
        - `value` changes the value of the object on Coop. **has to be an EDSL object**
        - `visibility` changes the visibility of the object on Coop
        """
        from edsl.coop import Coop

        c = Coop()
        return c._patch_base(cls, id_or_url, description, value, visibility)

    @classmethod
    def search(cls, query):
        """Search for objects on coop."""
        from edsl.coop import Coop

        c = Coop()
        return c.search(cls, query)

    def save(self, filename, compress=True):
        """Save the object to a file as (optionally gzipped) JSON.

        The extension is appended automatically: ``<filename>.json.gz`` when
        ``compress`` is true, ``<filename>.json`` otherwise. If ``filename``
        already ends in ``json.gz`` or ``json``, that suffix (and a dot before
        it) is stripped with a warning so the extension is not doubled.

        Example: ``obj.save("obj")`` writes ``obj.json.gz``.
        """
        import warnings

        filename = str(filename)  # also accept os.PathLike
        for ext in ("json.gz", "json"):
            if filename.endswith(ext):
                warnings.warn(
                    f"Do not apply the file extensions. The filename should not end with '{ext}'.",
                    stacklevel=2,
                )
                filename = filename[: -len(ext)]
                # Drop the separating dot too, otherwise "obj.json" -> "obj..json".
                if filename.endswith("."):
                    filename = filename[:-1]
                break

        # Serialize before opening so a failing to_dict() leaves no empty file behind.
        payload = json.dumps(self.to_dict())
        if compress:
            with gzip.open(filename + ".json.gz", "wb") as f:
                f.write(payload.encode("utf-8"))
        else:
            with open(filename + ".json", "w", encoding="utf-8") as f:
                f.write(payload)

    @staticmethod
    def open_compressed_file(filename):
        """Read a gzip file and return its UTF-8 JSON content, parsed."""
        with gzip.open(filename, "rb") as f:
            file_contents = f.read()
            file_contents_decoded = file_contents.decode("utf-8")
            d = json.loads(file_contents_decoded)
        return d

    @staticmethod
    def open_regular_file(filename):
        """Read a plain UTF-8 JSON file and return its parsed content."""
        with open(filename, "r", encoding="utf-8") as f:
            d = json.loads(f.read())
        return d

    @classmethod
    def load(cls, filename):
        """Load an object of this class from a JSON or gzipped-JSON file.

        Names ending in ``json.gz`` are read as gzip and names ending in
        ``json`` as plain JSON. Any other name is tried as gzip first and as
        plain JSON second.

        :raises ValueError: if an extension-less file is neither gzipped JSON
            nor plain JSON.
        """
        filename = str(filename)  # also accept os.PathLike
        if filename.endswith("json.gz"):
            d = cls.open_compressed_file(filename)
        elif filename.endswith("json"):
            d = cls.open_regular_file(filename)
        else:
            try:
                d = cls.open_compressed_file(filename)
            except (OSError, EOFError, ValueError):
                # Not gzip (gzip.BadGzipFile is an OSError), truncated, or not
                # JSON inside -- fall back to reading it as plain JSON.
                try:
                    d = cls.open_regular_file(filename)
                except (OSError, ValueError) as err:
                    raise ValueError("File must be a json or json.gz file") from err

        return cls.from_dict(d)
|
165
|
+
|
166
|
+
|
167
|
+
class RegisterSubclassesMeta(ABCMeta):
    """Metaclass that records each class it creates in a name -> class registry.

    The class named ``Base`` is deliberately left out of the registry.
    """

    # One registry shared by every class built with this metaclass; a later
    # class with the same ``__name__`` replaces the earlier entry.
    _registry = {}

    def __init__(cls, name, bases, nmspc):
        """Finish creating the class, then record it unless it is named ``Base``."""
        super().__init__(name, bases, nmspc)
        class_name = cls.__name__
        if class_name == "Base":
            return
        RegisterSubclassesMeta._registry[class_name] = cls

    @staticmethod
    def get_registry():
        """Return a shallow copy of the name -> class registry."""
        return {**RegisterSubclassesMeta._registry}
|
182
|
+
|
183
|
+
|
184
|
+
class DiffMethodsMixin:
    """Mixin that makes ``a - b`` return a ``BaseDiff`` of the two objects."""

    def __sub__(self, other):
        """Return ``BaseDiff(self, other)`` describing how the two objects differ."""
        # Imported lazily, at call time, rather than at module import.
        from edsl.BaseDiff import BaseDiff as _BaseDiff

        diff = _BaseDiff(self, other)
        return diff
|
190
|
+
|
191
|
+
|
192
|
+
class Base(
    RichPrintingMixin,
    PersistenceMixin,
    DiffMethodsMixin,
    ABC,
    metaclass=RegisterSubclassesMeta,
):
    """Base class for all classes in the package.

    Mixes in rich printing (``RichPrintingMixin``), file/Coop persistence
    (``PersistenceMixin``) and ``a - b`` diffing (``DiffMethodsMixin``).
    Every subclass is recorded by name through ``RegisterSubclassesMeta``.
    Subclasses must implement ``example``, ``rich_print``, ``to_dict``,
    ``from_dict`` and ``code``.
    """

    def keys(self):
        """Return the keys of ``to_dict()``, minus ``edsl_version`` and ``edsl_class_name``."""
        return [
            key
            for key in self.to_dict().keys()
            if key not in ("edsl_version", "edsl_class_name")
        ]

    def values(self):
        """Return the values of ``to_dict()`` for ``keys()``, in the same order.

        A list is returned rather than a set: values are frequently dicts or
        lists, which are unhashable, and a set would also drop duplicate values
        and lose the positional pairing with ``keys()``.
        """
        data = self.to_dict()
        return [data[key] for key in self.keys()]

    def _repr_html_(self):
        """Return an HTML rendering of ``to_dict()`` produced by ``data_to_html``."""
        from edsl.utilities.utilities import data_to_html

        return data_to_html(self.to_dict())

    def __eq__(self, other):
        """Return whether two objects are equal.

        ``other`` must be an instance of ``self``'s class, and the two
        ``_to_dict()`` results must compare equal. ``sort=True`` is passed when
        ``_to_dict`` accepts it (presumably to make the comparison
        order-insensitive).
        """
        import inspect

        if not isinstance(other, self.__class__):
            return False
        if "sort" in inspect.signature(self._to_dict).parameters:
            return self._to_dict(sort=True) == other._to_dict(sort=True)
        else:
            return self._to_dict() == other._to_dict()

    @abstractmethod
    def example():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def rich_print():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def to_dict():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def from_dict():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def code():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    def show_methods(self, show_docstrings=True):
        """Show the public (non-underscore) callable attributes of the object.

        :param show_docstrings: if true, print ``name: docstring`` for each one
            and return None; otherwise return the list of names.
        """
        public_methods_with_docstrings = [
            (method, getattr(self, method).__doc__)
            for method in dir(self)
            if callable(getattr(self, method)) and not method.startswith("_")
        ]
        if show_docstrings:
            for method, documentation in public_methods_with_docstrings:
                print(f"{method}: {documentation}")
        else:
            return [x[0] for x in public_methods_with_docstrings]
|
edsl/config.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
"""This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
|
2
2
|
|
3
3
|
import os
|
4
|
-
from dotenv import load_dotenv, find_dotenv
|
5
4
|
from edsl.exceptions import (
|
6
5
|
InvalidEnvironmentVariableError,
|
7
6
|
MissingEnvironmentVariableError,
|
8
7
|
)
|
8
|
+
from dotenv import load_dotenv, find_dotenv
|
9
9
|
|
10
10
|
# valid values for EDSL_RUN_MODE
|
11
11
|
EDSL_RUN_MODES = ["development", "development-testrun", "production"]
|
@@ -96,6 +96,7 @@ class Config:
|
|
96
96
|
Loads the .env
|
97
97
|
- Overrides existing env vars unless EDSL_RUN_MODE=="development-testrun"
|
98
98
|
"""
|
99
|
+
|
99
100
|
override = True
|
100
101
|
if self.EDSL_RUN_MODE == "development-testrun":
|
101
102
|
override = False
|
edsl/conversation/car_buying.py
CHANGED
@@ -30,7 +30,7 @@ c1 = Conversation(agent_list=AgentList([a1, a3, a2]), max_turns=5, verbose=True)
|
|
30
30
|
c2 = Conversation(agent_list=AgentList([a1, a2]), max_turns=5, verbose=True)
|
31
31
|
|
32
32
|
c = Cache.load("car_talk.json.gz")
|
33
|
-
breakpoint()
|
33
|
+
# breakpoint()
|
34
34
|
combo = ConversationList([c1, c2], cache=c)
|
35
35
|
combo.run()
|
36
36
|
results = combo.to_results()
|
edsl/coop/utils.py
CHANGED
@@ -10,7 +10,7 @@ from edsl import (
|
|
10
10
|
Study,
|
11
11
|
)
|
12
12
|
from edsl.questions import QuestionBase
|
13
|
-
from typing import Literal, Type, Union
|
13
|
+
from typing import Literal, Optional, Type, Union
|
14
14
|
|
15
15
|
EDSLObject = Union[
|
16
16
|
Agent,
|
@@ -94,3 +94,30 @@ class ObjectRegistry:
|
|
94
94
|
if EDSL_object is None:
|
95
95
|
raise ValueError(f"EDSL class not found for {object_type=}")
|
96
96
|
return EDSL_object
|
97
|
+
|
98
|
+
@classmethod
def get_registry(
    cls,
    subclass_registry: Optional[dict] = None,
    exclude_classes: Optional[list] = None,
) -> dict:
    """
    Map class name -> EDSL class for each entry in ``cls.objects``.

    A class is skipped when its name is already a key of ``subclass_registry``
    (this lets the caller isolate Coop-only objects), or when its name appears
    in ``exclude_classes``.
    """
    already_registered = {} if subclass_registry is None else subclass_registry
    excluded = [] if exclude_classes is None else exclude_classes

    registry = {}
    for obj in cls.objects:
        edsl_class = obj["edsl_class"]
        name = edsl_class.__name__
        if name in already_registered or name in excluded:
            continue
        registry[name] = edsl_class
    return registry
|
edsl/data/Cache.py
CHANGED
@@ -7,17 +7,10 @@ import json
|
|
7
7
|
import os
|
8
8
|
import warnings
|
9
9
|
from typing import Optional, Union
|
10
|
-
|
11
|
-
from edsl.config import CONFIG
|
12
|
-
from edsl.data.CacheEntry import CacheEntry
|
13
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
14
10
|
from edsl.Base import Base
|
11
|
+
from edsl.data.CacheEntry import CacheEntry
|
15
12
|
from edsl.utilities.utilities import dict_hash
|
16
|
-
|
17
|
-
from edsl.utilities.decorators import (
|
18
|
-
add_edsl_version,
|
19
|
-
remove_edsl_version,
|
20
|
-
)
|
13
|
+
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
21
14
|
|
22
15
|
|
23
16
|
class Cache(Base):
|
@@ -38,9 +31,10 @@ class Cache(Base):
|
|
38
31
|
self,
|
39
32
|
*,
|
40
33
|
filename: Optional[str] = None,
|
41
|
-
data: Optional[Union[SQLiteDict, dict]] = None,
|
34
|
+
data: Optional[Union["SQLiteDict", dict]] = None,
|
42
35
|
immediate_write: bool = True,
|
43
36
|
method=None,
|
37
|
+
verbose=False,
|
44
38
|
):
|
45
39
|
"""
|
46
40
|
Create two dictionaries to store the cache data.
|
@@ -59,6 +53,7 @@ class Cache(Base):
|
|
59
53
|
self.new_entries = {}
|
60
54
|
self.new_entries_to_write_later = {}
|
61
55
|
self.coop = None
|
56
|
+
self.verbose = verbose
|
62
57
|
|
63
58
|
self.filename = filename
|
64
59
|
if filename and data:
|
@@ -104,6 +99,8 @@ class Cache(Base):
|
|
104
99
|
|
105
100
|
def _perform_checks(self):
|
106
101
|
"""Perform checks on the cache."""
|
102
|
+
from edsl.data.CacheEntry import CacheEntry
|
103
|
+
|
107
104
|
if any(not isinstance(value, CacheEntry) for value in self.data.values()):
|
108
105
|
raise Exception("Not all values are CacheEntry instances")
|
109
106
|
if self.method is not None:
|
@@ -120,7 +117,7 @@ class Cache(Base):
|
|
120
117
|
system_prompt: str,
|
121
118
|
user_prompt: str,
|
122
119
|
iteration: int,
|
123
|
-
) -> Union[None, str]:
|
120
|
+
) -> tuple(Union[None, str], str):
|
124
121
|
"""
|
125
122
|
Fetch a value (LLM output) from the cache.
|
126
123
|
|
@@ -133,11 +130,13 @@ class Cache(Base):
|
|
133
130
|
Return None if the response is not found.
|
134
131
|
|
135
132
|
>>> c = Cache()
|
136
|
-
>>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1) is None
|
133
|
+
>>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1)[0] is None
|
137
134
|
True
|
138
135
|
|
139
136
|
|
140
137
|
"""
|
138
|
+
from edsl.data.CacheEntry import CacheEntry
|
139
|
+
|
141
140
|
key = CacheEntry.gen_key(
|
142
141
|
model=model,
|
143
142
|
parameters=parameters,
|
@@ -147,8 +146,13 @@ class Cache(Base):
|
|
147
146
|
)
|
148
147
|
entry = self.data.get(key, None)
|
149
148
|
if entry is not None:
|
149
|
+
if self.verbose:
|
150
|
+
print(f"Cache hit for key: {key}")
|
150
151
|
self.fetched_data[key] = entry
|
151
|
-
|
152
|
+
else:
|
153
|
+
if self.verbose:
|
154
|
+
print(f"Cache miss for key: {key}")
|
155
|
+
return None if entry is None else entry.output, key
|
152
156
|
|
153
157
|
def store(
|
154
158
|
self,
|
@@ -171,6 +175,7 @@ class Cache(Base):
|
|
171
175
|
* If `immediate_write` is True , the key-value pair is added to `self.data`
|
172
176
|
* If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
|
173
177
|
"""
|
178
|
+
|
174
179
|
entry = CacheEntry(
|
175
180
|
model=model,
|
176
181
|
parameters=parameters,
|
@@ -188,13 +193,14 @@ class Cache(Base):
|
|
188
193
|
return key
|
189
194
|
|
190
195
|
def add_from_dict(
|
191
|
-
self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
|
196
|
+
self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
|
192
197
|
) -> None:
|
193
198
|
"""
|
194
199
|
Add entries to the cache from a dictionary.
|
195
200
|
|
196
201
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
197
202
|
"""
|
203
|
+
|
198
204
|
for key, value in new_data.items():
|
199
205
|
if key in self.data:
|
200
206
|
if value != self.data[key]:
|
@@ -231,6 +237,8 @@ class Cache(Base):
|
|
231
237
|
|
232
238
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
233
239
|
"""
|
240
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
241
|
+
|
234
242
|
db = SQLiteDict(db_path)
|
235
243
|
new_data = {}
|
236
244
|
for key, value in db.items():
|
@@ -242,6 +250,8 @@ class Cache(Base):
|
|
242
250
|
"""
|
243
251
|
Construct a Cache from a SQLite database.
|
244
252
|
"""
|
253
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
254
|
+
|
245
255
|
return cls(data=SQLiteDict(db_path))
|
246
256
|
|
247
257
|
@classmethod
|
@@ -268,6 +278,8 @@ class Cache(Base):
|
|
268
278
|
* If `db_path` is provided, the cache will be stored in an SQLite database.
|
269
279
|
"""
|
270
280
|
# if a file doesn't exist at jsonfile, throw an error
|
281
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
282
|
+
|
271
283
|
if not os.path.exists(jsonlfile):
|
272
284
|
raise FileNotFoundError(f"File {jsonlfile} not found")
|
273
285
|
|
@@ -286,6 +298,8 @@ class Cache(Base):
|
|
286
298
|
"""
|
287
299
|
## TODO: Check to make sure not over-writing (?)
|
288
300
|
## Should be added to SQLiteDict constructor (?)
|
301
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
302
|
+
|
289
303
|
new_data = SQLiteDict(db_path)
|
290
304
|
for key, value in self.data.items():
|
291
305
|
new_data[key] = value
|
@@ -340,6 +354,9 @@ class Cache(Base):
|
|
340
354
|
for key, entry in self.new_entries_to_write_later.items():
|
341
355
|
self.data[key] = entry
|
342
356
|
|
357
|
+
if self.filename:
|
358
|
+
self.write(self.filename)
|
359
|
+
|
343
360
|
####################
|
344
361
|
# DUNDER / USEFUL
|
345
362
|
####################
|
@@ -456,12 +473,18 @@ class Cache(Base):
|
|
456
473
|
webbrowser.open("file://" + filepath)
|
457
474
|
|
458
475
|
@classmethod
def example(cls, randomize: bool = False) -> Cache:
    """
    Returns an example Cache instance.

    :param randomize: If True, each entry is built with
        ``CacheEntry.example(randomize=True)``, which adds a random uuid to the
        entry's system prompt.
    """
    # Build each entry once and store it under its own ``key`` so every dict
    # key matches the entry stored under it. With randomize=False the two
    # examples presumably share a key and collapse into a single entry.
    entries = [CacheEntry.example(randomize) for _ in range(2)]
    return cls(data={entry.key: entry for entry in entries})
|
465
488
|
|
466
489
|
|
467
490
|
if __name__ == "__main__":
|
edsl/data/CacheEntry.py
CHANGED
@@ -2,11 +2,8 @@ from __future__ import annotations
|
|
2
2
|
import json
|
3
3
|
import datetime
|
4
4
|
import hashlib
|
5
|
-
import random
|
6
5
|
from typing import Optional
|
7
|
-
|
8
|
-
|
9
|
-
# TODO: Timestamp should probably be float?
|
6
|
+
from uuid import uuid4
|
10
7
|
|
11
8
|
|
12
9
|
class CacheEntry:
|
@@ -151,10 +148,12 @@ class CacheEntry:
|
|
151
148
|
@classmethod
|
152
149
|
def example(cls, randomize: bool = False) -> CacheEntry:
|
153
150
|
"""
|
154
|
-
Returns
|
151
|
+
Returns an example CacheEntry instance.
|
152
|
+
|
153
|
+
:param randomize: If True, adds a random string to the system prompt.
|
155
154
|
"""
|
156
|
-
# if random, create a
|
157
|
-
addition = "" if not randomize else str(
|
155
|
+
# if random, create a uuid
|
156
|
+
addition = "" if not randomize else str(uuid4())
|
158
157
|
return CacheEntry(
|
159
158
|
model="gpt-3.5-turbo",
|
160
159
|
parameters={"temperature": 0.5},
|
edsl/data/SQLiteDict.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import json
|
3
|
-
from sqlalchemy import create_engine
|
4
|
-
from sqlalchemy.exc import SQLAlchemyError
|
5
|
-
from sqlalchemy.orm import sessionmaker
|
6
3
|
from typing import Any, Generator, Optional, Union
|
4
|
+
|
7
5
|
from edsl.config import CONFIG
|
8
6
|
from edsl.data.CacheEntry import CacheEntry
|
9
7
|
from edsl.data.orm import Base, Data
|
@@ -25,10 +23,16 @@ class SQLiteDict:
|
|
25
23
|
>>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
|
26
24
|
|
27
25
|
"""
|
26
|
+
from sqlalchemy.exc import SQLAlchemyError
|
27
|
+
from sqlalchemy.orm import sessionmaker
|
28
|
+
from sqlalchemy import create_engine
|
29
|
+
|
28
30
|
self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
|
29
31
|
if not self.db_path.startswith("sqlite:///"):
|
30
32
|
self.db_path = f"sqlite:///{self.db_path}"
|
31
33
|
try:
|
34
|
+
from edsl.data.orm import Base, Data
|
35
|
+
|
32
36
|
self.engine = create_engine(self.db_path, echo=False, future=True)
|
33
37
|
Base.metadata.create_all(self.engine)
|
34
38
|
self.Session = sessionmaker(bind=self.engine)
|
@@ -55,6 +59,8 @@ class SQLiteDict:
|
|
55
59
|
if not isinstance(value, CacheEntry):
|
56
60
|
raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
|
57
61
|
with self.Session() as db:
|
62
|
+
from edsl.data.orm import Base, Data
|
63
|
+
|
58
64
|
db.merge(Data(key=key, value=json.dumps(value.to_dict())))
|
59
65
|
db.commit()
|
60
66
|
|
@@ -69,6 +75,8 @@ class SQLiteDict:
|
|
69
75
|
True
|
70
76
|
"""
|
71
77
|
with self.Session() as db:
|
78
|
+
from edsl.data.orm import Base, Data
|
79
|
+
|
72
80
|
value = db.query(Data).filter_by(key=key).first()
|
73
81
|
if not value:
|
74
82
|
raise KeyError(f"Key '{key}' not found.")
|
edsl/data_transfer_models.py
CHANGED
@@ -17,6 +17,8 @@ class AgentResponseDict(UserDict):
|
|
17
17
|
cached_response=None,
|
18
18
|
raw_model_response=None,
|
19
19
|
simple_model_raw_response=None,
|
20
|
+
cache_used=None,
|
21
|
+
cache_key=None,
|
20
22
|
):
|
21
23
|
"""Initialize the AgentResponseDict object."""
|
22
24
|
usage = usage or {"prompt_tokens": 0, "completion_tokens": 0}
|
@@ -30,5 +32,7 @@ class AgentResponseDict(UserDict):
|
|
30
32
|
"cached_response": cached_response,
|
31
33
|
"raw_model_response": raw_model_response,
|
32
34
|
"simple_model_raw_response": simple_model_raw_response,
|
35
|
+
"cache_used": cache_used,
|
36
|
+
"cache_key": cache_key,
|
33
37
|
}
|
34
38
|
)
|
edsl/jobs/Answers.py
CHANGED
@@ -8,7 +8,15 @@ class Answers(UserDict):
|
|
8
8
|
"""Helper class to hold the answers to a survey."""
|
9
9
|
|
10
10
|
def add_answer(self, response, question) -> None:
|
11
|
-
"""Add a response to the answers dictionary.
|
11
|
+
"""Add a response to the answers dictionary.
|
12
|
+
|
13
|
+
>>> from edsl import QuestionFreeText
|
14
|
+
>>> q = QuestionFreeText.example()
|
15
|
+
>>> answers = Answers()
|
16
|
+
>>> answers.add_answer({"answer": "yes"}, q)
|
17
|
+
>>> answers[q.question_name]
|
18
|
+
'yes'
|
19
|
+
"""
|
12
20
|
answer = response.get("answer")
|
13
21
|
comment = response.pop("comment", None)
|
14
22
|
# record the answer
|
@@ -42,3 +50,9 @@ class Answers(UserDict):
|
|
42
50
|
table.add_row(attr_name, repr(attr_value))
|
43
51
|
|
44
52
|
return table
|
53
|
+
|
54
|
+
|
55
|
+
if __name__ == "__main__":
|
56
|
+
import doctest
|
57
|
+
|
58
|
+
doctest.testmod()
|