edsl 0.1.29.dev6__py3-none-any.whl → 0.1.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. edsl/Base.py +6 -3
  2. edsl/__init__.py +23 -23
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +43 -40
  5. edsl/agents/AgentList.py +23 -22
  6. edsl/agents/Invigilator.py +19 -2
  7. edsl/agents/descriptors.py +2 -1
  8. edsl/base/Base.py +289 -0
  9. edsl/config.py +2 -1
  10. edsl/conversation/car_buying.py +1 -1
  11. edsl/coop/utils.py +28 -1
  12. edsl/data/Cache.py +41 -18
  13. edsl/data/CacheEntry.py +6 -7
  14. edsl/data/SQLiteDict.py +11 -3
  15. edsl/data_transfer_models.py +4 -0
  16. edsl/jobs/Answers.py +15 -1
  17. edsl/jobs/Jobs.py +86 -33
  18. edsl/jobs/buckets/ModelBuckets.py +14 -2
  19. edsl/jobs/buckets/TokenBucket.py +32 -5
  20. edsl/jobs/interviews/Interview.py +99 -79
  21. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +18 -24
  22. edsl/jobs/runners/JobsRunnerAsyncio.py +16 -16
  23. edsl/jobs/tasks/QuestionTaskCreator.py +10 -6
  24. edsl/jobs/tasks/TaskHistory.py +4 -3
  25. edsl/language_models/LanguageModel.py +17 -17
  26. edsl/language_models/ModelList.py +1 -1
  27. edsl/language_models/repair.py +8 -7
  28. edsl/notebooks/Notebook.py +16 -10
  29. edsl/questions/QuestionBase.py +6 -2
  30. edsl/questions/QuestionBudget.py +5 -6
  31. edsl/questions/QuestionCheckBox.py +7 -3
  32. edsl/questions/QuestionExtract.py +5 -3
  33. edsl/questions/QuestionFreeText.py +7 -5
  34. edsl/questions/QuestionFunctional.py +34 -5
  35. edsl/questions/QuestionList.py +3 -4
  36. edsl/questions/QuestionMultipleChoice.py +68 -12
  37. edsl/questions/QuestionNumerical.py +4 -3
  38. edsl/questions/QuestionRank.py +5 -3
  39. edsl/questions/__init__.py +4 -3
  40. edsl/questions/descriptors.py +46 -4
  41. edsl/results/DatasetExportMixin.py +570 -0
  42. edsl/results/Result.py +66 -70
  43. edsl/results/Results.py +160 -68
  44. edsl/results/ResultsDBMixin.py +7 -3
  45. edsl/results/ResultsExportMixin.py +22 -537
  46. edsl/results/ResultsGGMixin.py +3 -3
  47. edsl/results/ResultsToolsMixin.py +1 -4
  48. edsl/scenarios/FileStore.py +299 -0
  49. edsl/scenarios/Scenario.py +16 -24
  50. edsl/scenarios/ScenarioList.py +25 -14
  51. edsl/scenarios/ScenarioListExportMixin.py +32 -0
  52. edsl/scenarios/ScenarioListPdfMixin.py +2 -1
  53. edsl/scenarios/__init__.py +1 -0
  54. edsl/study/Study.py +5 -7
  55. edsl/surveys/MemoryPlan.py +11 -4
  56. edsl/surveys/Survey.py +52 -15
  57. edsl/surveys/SurveyExportMixin.py +4 -2
  58. edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
  59. edsl/utilities/__init__.py +21 -21
  60. edsl/utilities/interface.py +66 -45
  61. edsl/utilities/utilities.py +11 -13
  62. {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/METADATA +1 -1
  63. {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/RECORD +65 -61
  64. {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/WHEEL +1 -1
  65. {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dist-info}/LICENSE +0 -0
edsl/base/Base.py ADDED
@@ -0,0 +1,289 @@
1
+ """Base class for all classes in the package. It provides rich printing and persistence of objects."""
2
+
3
+ from abc import ABC, abstractmethod, ABCMeta
4
+ import gzip
5
+ import io
6
+ import json
7
+ from typing import Any, Optional, Union
8
+ from uuid import UUID
9
+ from IPython.display import display
10
+ from rich.console import Console
11
+
12
+
13
class RichPrintingMixin:
    """Mixin that renders an object through ``rich``.

    The host class must supply ``rich_print()``; its return value is what
    gets handed to the ``rich`` console (or to ``display`` in a notebook).
    """

    def _for_console(self):
        """Render ``self.rich_print()`` into a recording console and return the text."""
        sink = io.StringIO()
        try:
            # The console writes into ``sink`` and records its output so that
            # ``export_text`` can hand it back as a plain string.
            recorder = Console(file=sink, record=True)
            recorder.print(self.rich_print())
            return recorder.export_text()
        finally:
            sink.close()

    def __str__(self):
        """Return the console rendering produced by ``_for_console``."""
        return self._for_console()

    def print(self):
        """Show the object: ``display`` in a notebook, a ``rich`` console otherwise."""
        from edsl.utilities.utilities import is_notebook

        if not is_notebook():
            Console().print(self.rich_print())
            return
        display(self.rich_print())
39
+
40
+
41
class PersistenceMixin:
    """Mixin for saving and loading objects to and from files and Coop.

    The host class must provide ``to_dict()`` (used by ``save``) and a
    ``from_dict(d)`` constructor (used by ``load``).
    """

    def push(
        self,
        description: Optional[str] = None,
        visibility: Optional[str] = "unlisted",
    ):
        """Post the object to coop via ``Coop().create``."""
        from edsl.coop import Coop

        c = Coop()
        return c.create(self, description, visibility)

    @classmethod
    def pull(cls, id_or_url: Union[str, UUID], exec_profile=None):
        """Pull the object from coop.

        :param id_or_url: Either an identifier or an ``http...`` URL whose
            last path segment is the identifier.
        :param exec_profile: Forwarded unchanged to ``Coop()._get_base``.
        """
        from edsl.coop import Coop

        # A UUID instance has no ``startswith``; only strings can be URLs,
        # so anything else is passed through untouched.
        if isinstance(id_or_url, str) and id_or_url.startswith("http"):
            uuid_value = id_or_url.split("/")[-1]
        else:
            uuid_value = id_or_url

        c = Coop()

        return c._get_base(cls, uuid_value, exec_profile=exec_profile)

    @classmethod
    def delete(cls, id_or_url: Union[str, UUID]):
        """Delete the object from coop via ``Coop()._delete_base``."""
        from edsl.coop import Coop

        c = Coop()
        return c._delete_base(cls, id_or_url)

    @classmethod
    def patch(
        cls,
        id_or_url: Union[str, UUID],
        description: Optional[str] = None,
        value: Optional[Any] = None,
        visibility: Optional[str] = None,
    ):
        """
        Patch an uploaded object's attributes.
        - `description` changes the description of the object on Coop
        - `value` changes the value of the object on Coop. **has to be an EDSL object**
        - `visibility` changes the visibility of the object on Coop
        """
        from edsl.coop import Coop

        c = Coop()
        return c._patch_base(cls, id_or_url, description, value, visibility)

    @classmethod
    def search(cls, query):
        """Search for objects on coop via ``Coop().search``."""
        from edsl.coop import Coop

        c = Coop()
        return c.search(cls, query)

    def save(self, filename, compress=True):
        """Save the object to a file as JSON, gzipped by default.

        The extension is appended automatically: ``.json.gz`` when
        ``compress`` is true, ``.json`` otherwise. If ``filename`` already
        ends with one of those extensions, a warning is issued and the
        extension is stripped before the correct one is appended.

        >>> obj.save("obj")  # doctest: +SKIP

        :param filename: Target path without extension.
        :param compress: Write gzip-compressed JSON when true.
        """
        # Imported at function scope so both extension checks can warn.
        import warnings

        # Check the longer suffix first and strip the dot along with it, so
        # "obj.json.gz" becomes "obj" rather than "obj.".
        for suffix in (".json.gz", ".json"):
            if filename.endswith(suffix):
                warnings.warn(
                    "Do not apply the file extensions. "
                    f"The filename should not end with '{suffix[1:]}'."
                )
                filename = filename[: -len(suffix)]
                break

        # Serialize before opening the file so a failing ``to_dict`` does not
        # leave an empty file behind.
        payload = json.dumps(self.to_dict())
        if compress:
            with gzip.open(filename + ".json.gz", "wb") as f:
                f.write(payload.encode("utf-8"))
        else:
            with open(filename + ".json", "w") as f:
                f.write(payload)

    @staticmethod
    def open_compressed_file(filename):
        """Read a gzip-compressed UTF-8 JSON file and return the parsed data."""
        with gzip.open(filename, "rb") as f:
            file_contents = f.read()
            file_contents_decoded = file_contents.decode("utf-8")
            d = json.loads(file_contents_decoded)
        return d

    @staticmethod
    def open_regular_file(filename):
        """Read a plain-text JSON file and return the parsed data."""
        with open(filename, "r") as f:
            d = json.loads(f.read())
        return d

    @classmethod
    def load(cls, filename):
        """Load the object from a file.

        Files ending in ``json.gz`` are read as gzipped JSON and files ending
        in ``json`` as plain JSON. Any other name is tried as gzipped JSON
        first and then as plain JSON.

        >>> obj = cls.load("obj.json.gz")  # doctest: +SKIP

        :raises ValueError: If an extension-less file is neither gzipped
            JSON nor plain JSON.
        """
        if filename.endswith("json.gz"):
            d = cls.open_compressed_file(filename)
        elif filename.endswith("json"):
            d = cls.open_regular_file(filename)
        else:
            try:
                d = cls.open_compressed_file(filename)
            except (OSError, EOFError, ValueError):
                # Not gzip (gzip.BadGzipFile is an OSError) or not valid
                # JSON after decompression: fall back to plain JSON.
                try:
                    d = cls.open_regular_file(filename)
                except ValueError as err:
                    # JSONDecodeError and UnicodeDecodeError are ValueErrors.
                    raise ValueError(
                        "File must be a json or json.gz file"
                    ) from err

        return cls.from_dict(d)
165
+
166
+
167
class RegisterSubclassesMeta(ABCMeta):
    """Metaclass that records every class it creates, keyed by class name."""

    # Shared across all classes built with this metaclass.
    _registry = {}

    def __init__(cls, name, bases, nmspc):
        """Finish class creation, then add the class to the registry.

        A class named ``"Base"`` is skipped.
        """
        super().__init__(name, bases, nmspc)
        if cls.__name__ == "Base":
            return
        RegisterSubclassesMeta._registry[cls.__name__] = cls

    @staticmethod
    def get_registry():
        """Return a shallow copy of the name-to-class registry."""
        return RegisterSubclassesMeta._registry.copy()
182
+
183
+
184
class DiffMethodsMixin:
    """Mixin that makes ``a - b`` produce a ``BaseDiff`` of the two operands."""

    def __sub__(self, other):
        """Return ``BaseDiff(self, other)``."""
        # Imported lazily, at call time rather than at module import.
        from edsl.BaseDiff import BaseDiff

        difference = BaseDiff(self, other)
        return difference
190
+
191
+
192
class Base(
    RichPrintingMixin,
    PersistenceMixin,
    DiffMethodsMixin,
    ABC,
    metaclass=RegisterSubclassesMeta,
):
    """Abstract root class combining rich printing, persistence and diffing.

    Concrete subclasses must implement ``example``, ``rich_print``,
    ``to_dict``, ``from_dict`` and ``code``. ``__eq__`` additionally relies
    on a ``_to_dict`` method being available on the instance.
    """

    def keys(self):
        """Return the ``to_dict()`` keys, minus the EDSL bookkeeping entries."""
        hidden = ("edsl_version", "edsl_class_name")
        return [name for name in self.to_dict().keys() if name not in hidden]

    def values(self):
        """Return the set of ``to_dict()`` values for the names in ``keys()``.

        NOTE(review): the result is a set, so values must be hashable and
        duplicates collapse.
        """
        snapshot = self.to_dict()
        return set(snapshot[name] for name in self.keys())

    def _repr_html_(self):
        """Return an HTML rendering of ``to_dict()`` built by ``data_to_html``."""
        from edsl.utilities.utilities import data_to_html

        as_dict = self.to_dict()
        return data_to_html(as_dict)

    def __eq__(self, other):
        """Return whether ``other`` is the same kind of object with equal ``_to_dict`` output."""
        if not isinstance(other, self.__class__):
            return False

        import inspect

        # Ask for sorted output only when this object's ``_to_dict``
        # accepts a ``sort`` argument.
        accepts_sort = "sort" in inspect.signature(self._to_dict).parameters
        options = {"sort": True} if accepts_sort else {}
        return self._to_dict(**options) == other._to_dict(**options)

    @abstractmethod
    def example():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def rich_print():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def to_dict():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def from_dict():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def code():
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    def show_methods(self, show_docstrings=True):
        """List the object's public callables.

        With ``show_docstrings`` true, print ``name: docstring`` for each and
        return ``None``; otherwise return the list of names.
        """
        catalogue = []
        for name in dir(self):
            if callable(getattr(self, name)) and not name.startswith("_"):
                catalogue.append((name, getattr(self, name).__doc__))

        if not show_docstrings:
            return [name for name, _ in catalogue]

        for name, documentation in catalogue:
            print(f"{name}: {documentation}")
edsl/config.py CHANGED
@@ -1,11 +1,11 @@
1
1
  """This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
2
2
 
3
3
  import os
4
- from dotenv import load_dotenv, find_dotenv
5
4
  from edsl.exceptions import (
6
5
  InvalidEnvironmentVariableError,
7
6
  MissingEnvironmentVariableError,
8
7
  )
8
+ from dotenv import load_dotenv, find_dotenv
9
9
 
10
10
  # valid values for EDSL_RUN_MODE
11
11
  EDSL_RUN_MODES = ["development", "development-testrun", "production"]
@@ -96,6 +96,7 @@ class Config:
96
96
  Loads the .env
97
97
  - Overrides existing env vars unless EDSL_RUN_MODE=="development-testrun"
98
98
  """
99
+
99
100
  override = True
100
101
  if self.EDSL_RUN_MODE == "development-testrun":
101
102
  override = False
@@ -30,7 +30,7 @@ c1 = Conversation(agent_list=AgentList([a1, a3, a2]), max_turns=5, verbose=True)
30
30
  c2 = Conversation(agent_list=AgentList([a1, a2]), max_turns=5, verbose=True)
31
31
 
32
32
  c = Cache.load("car_talk.json.gz")
33
- breakpoint()
33
+ # breakpoint()
34
34
  combo = ConversationList([c1, c2], cache=c)
35
35
  combo.run()
36
36
  results = combo.to_results()
edsl/coop/utils.py CHANGED
@@ -10,7 +10,7 @@ from edsl import (
10
10
  Study,
11
11
  )
12
12
  from edsl.questions import QuestionBase
13
- from typing import Literal, Type, Union
13
+ from typing import Literal, Optional, Type, Union
14
14
 
15
15
  EDSLObject = Union[
16
16
  Agent,
@@ -94,3 +94,30 @@ class ObjectRegistry:
94
94
  if EDSL_object is None:
95
95
  raise ValueError(f"EDSL class not found for {object_type=}")
96
96
  return EDSL_object
97
+
98
@classmethod
def get_registry(
    cls,
    subclass_registry: Optional[dict] = None,
    exclude_classes: Optional[list] = None,
) -> dict:
    """
    Return a mapping of class name to EDSL class built from ``cls.objects``.

    Classes whose name is already a key of ``subclass_registry`` are left
    out, which lets the caller isolate Coop-only objects. Classes whose name
    appears in ``exclude_classes`` are left out as well.
    """
    already_known = {} if subclass_registry is None else subclass_registry
    excluded = [] if exclude_classes is None else exclude_classes

    registry = {}
    for entry in cls.objects:
        class_name = entry["edsl_class"].__name__
        if class_name in already_known or class_name in excluded:
            continue
        registry[class_name] = entry["edsl_class"]
    return registry
edsl/data/Cache.py CHANGED
@@ -7,17 +7,10 @@ import json
7
7
  import os
8
8
  import warnings
9
9
  from typing import Optional, Union
10
-
11
- from edsl.config import CONFIG
12
- from edsl.data.CacheEntry import CacheEntry
13
- from edsl.data.SQLiteDict import SQLiteDict
14
10
  from edsl.Base import Base
11
+ from edsl.data.CacheEntry import CacheEntry
15
12
  from edsl.utilities.utilities import dict_hash
16
-
17
- from edsl.utilities.decorators import (
18
- add_edsl_version,
19
- remove_edsl_version,
20
- )
13
+ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
21
14
 
22
15
 
23
16
  class Cache(Base):
@@ -38,9 +31,10 @@ class Cache(Base):
38
31
  self,
39
32
  *,
40
33
  filename: Optional[str] = None,
41
- data: Optional[Union[SQLiteDict, dict]] = None,
34
+ data: Optional[Union["SQLiteDict", dict]] = None,
42
35
  immediate_write: bool = True,
43
36
  method=None,
37
+ verbose=False,
44
38
  ):
45
39
  """
46
40
  Create two dictionaries to store the cache data.
@@ -59,6 +53,7 @@ class Cache(Base):
59
53
  self.new_entries = {}
60
54
  self.new_entries_to_write_later = {}
61
55
  self.coop = None
56
+ self.verbose = verbose
62
57
 
63
58
  self.filename = filename
64
59
  if filename and data:
@@ -104,6 +99,8 @@ class Cache(Base):
104
99
 
105
100
  def _perform_checks(self):
106
101
  """Perform checks on the cache."""
102
+ from edsl.data.CacheEntry import CacheEntry
103
+
107
104
  if any(not isinstance(value, CacheEntry) for value in self.data.values()):
108
105
  raise Exception("Not all values are CacheEntry instances")
109
106
  if self.method is not None:
@@ -120,7 +117,7 @@ class Cache(Base):
120
117
  system_prompt: str,
121
118
  user_prompt: str,
122
119
  iteration: int,
123
- ) -> Union[None, str]:
120
+ ) -> tuple(Union[None, str], str):
124
121
  """
125
122
  Fetch a value (LLM output) from the cache.
126
123
 
@@ -133,11 +130,13 @@ class Cache(Base):
133
130
  Return None if the response is not found.
134
131
 
135
132
  >>> c = Cache()
136
- >>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1) is None
133
+ >>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1)[0] is None
137
134
  True
138
135
 
139
136
 
140
137
  """
138
+ from edsl.data.CacheEntry import CacheEntry
139
+
141
140
  key = CacheEntry.gen_key(
142
141
  model=model,
143
142
  parameters=parameters,
@@ -147,8 +146,13 @@ class Cache(Base):
147
146
  )
148
147
  entry = self.data.get(key, None)
149
148
  if entry is not None:
149
+ if self.verbose:
150
+ print(f"Cache hit for key: {key}")
150
151
  self.fetched_data[key] = entry
151
- return None if entry is None else entry.output
152
+ else:
153
+ if self.verbose:
154
+ print(f"Cache miss for key: {key}")
155
+ return None if entry is None else entry.output, key
152
156
 
153
157
  def store(
154
158
  self,
@@ -171,6 +175,7 @@ class Cache(Base):
171
175
  * If `immediate_write` is True , the key-value pair is added to `self.data`
172
176
  * If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
173
177
  """
178
+
174
179
  entry = CacheEntry(
175
180
  model=model,
176
181
  parameters=parameters,
@@ -188,13 +193,14 @@ class Cache(Base):
188
193
  return key
189
194
 
190
195
  def add_from_dict(
191
- self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
196
+ self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
192
197
  ) -> None:
193
198
  """
194
199
  Add entries to the cache from a dictionary.
195
200
 
196
201
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
197
202
  """
203
+
198
204
  for key, value in new_data.items():
199
205
  if key in self.data:
200
206
  if value != self.data[key]:
@@ -231,6 +237,8 @@ class Cache(Base):
231
237
 
232
238
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
233
239
  """
240
+ from edsl.data.SQLiteDict import SQLiteDict
241
+
234
242
  db = SQLiteDict(db_path)
235
243
  new_data = {}
236
244
  for key, value in db.items():
@@ -242,6 +250,8 @@ class Cache(Base):
242
250
  """
243
251
  Construct a Cache from a SQLite database.
244
252
  """
253
+ from edsl.data.SQLiteDict import SQLiteDict
254
+
245
255
  return cls(data=SQLiteDict(db_path))
246
256
 
247
257
  @classmethod
@@ -268,6 +278,8 @@ class Cache(Base):
268
278
  * If `db_path` is provided, the cache will be stored in an SQLite database.
269
279
  """
270
280
  # if a file doesn't exist at jsonfile, throw an error
281
+ from edsl.data.SQLiteDict import SQLiteDict
282
+
271
283
  if not os.path.exists(jsonlfile):
272
284
  raise FileNotFoundError(f"File {jsonlfile} not found")
273
285
 
@@ -286,6 +298,8 @@ class Cache(Base):
286
298
  """
287
299
  ## TODO: Check to make sure not over-writing (?)
288
300
  ## Should be added to SQLiteDict constructor (?)
301
+ from edsl.data.SQLiteDict import SQLiteDict
302
+
289
303
  new_data = SQLiteDict(db_path)
290
304
  for key, value in self.data.items():
291
305
  new_data[key] = value
@@ -340,6 +354,9 @@ class Cache(Base):
340
354
  for key, entry in self.new_entries_to_write_later.items():
341
355
  self.data[key] = entry
342
356
 
357
+ if self.filename:
358
+ self.write(self.filename)
359
+
343
360
  ####################
344
361
  # DUNDER / USEFUL
345
362
  ####################
@@ -456,12 +473,18 @@ class Cache(Base):
456
473
  webbrowser.open("file://" + filepath)
457
474
 
458
475
  @classmethod
459
- def example(cls) -> Cache:
476
+ def example(cls, randomize: bool = False) -> Cache:
460
477
  """
461
- Return an example Cache.
462
- The example Cache has one entry.
478
+ Returns an example Cache instance.
479
+
480
+ :param randomize: If True, uses CacheEntry's randomize method.
463
481
  """
464
- return cls(data={CacheEntry.example().key: CacheEntry.example()})
482
+ return cls(
483
+ data={
484
+ CacheEntry.example(randomize).key: CacheEntry.example(),
485
+ CacheEntry.example(randomize).key: CacheEntry.example(),
486
+ }
487
+ )
465
488
 
466
489
 
467
490
  if __name__ == "__main__":
edsl/data/CacheEntry.py CHANGED
@@ -2,11 +2,8 @@ from __future__ import annotations
2
2
  import json
3
3
  import datetime
4
4
  import hashlib
5
- import random
6
5
  from typing import Optional
7
-
8
-
9
- # TODO: Timestamp should probably be float?
6
+ from uuid import uuid4
10
7
 
11
8
 
12
9
  class CacheEntry:
@@ -151,10 +148,12 @@ class CacheEntry:
151
148
  @classmethod
152
149
  def example(cls, randomize: bool = False) -> CacheEntry:
153
150
  """
154
- Returns a CacheEntry example.
151
+ Returns an example CacheEntry instance.
152
+
153
+ :param randomize: If True, adds a random string to the system prompt.
155
154
  """
156
- # if random, create a random number for 0-100
157
- addition = "" if not randomize else str(random.randint(0, 1000))
155
+ # if random, create a uuid
156
+ addition = "" if not randomize else str(uuid4())
158
157
  return CacheEntry(
159
158
  model="gpt-3.5-turbo",
160
159
  parameters={"temperature": 0.5},
edsl/data/SQLiteDict.py CHANGED
@@ -1,9 +1,7 @@
1
1
  from __future__ import annotations
2
2
  import json
3
- from sqlalchemy import create_engine
4
- from sqlalchemy.exc import SQLAlchemyError
5
- from sqlalchemy.orm import sessionmaker
6
3
  from typing import Any, Generator, Optional, Union
4
+
7
5
  from edsl.config import CONFIG
8
6
  from edsl.data.CacheEntry import CacheEntry
9
7
  from edsl.data.orm import Base, Data
@@ -25,10 +23,16 @@ class SQLiteDict:
25
23
  >>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
26
24
 
27
25
  """
26
+ from sqlalchemy.exc import SQLAlchemyError
27
+ from sqlalchemy.orm import sessionmaker
28
+ from sqlalchemy import create_engine
29
+
28
30
  self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
29
31
  if not self.db_path.startswith("sqlite:///"):
30
32
  self.db_path = f"sqlite:///{self.db_path}"
31
33
  try:
34
+ from edsl.data.orm import Base, Data
35
+
32
36
  self.engine = create_engine(self.db_path, echo=False, future=True)
33
37
  Base.metadata.create_all(self.engine)
34
38
  self.Session = sessionmaker(bind=self.engine)
@@ -55,6 +59,8 @@ class SQLiteDict:
55
59
  if not isinstance(value, CacheEntry):
56
60
  raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
57
61
  with self.Session() as db:
62
+ from edsl.data.orm import Base, Data
63
+
58
64
  db.merge(Data(key=key, value=json.dumps(value.to_dict())))
59
65
  db.commit()
60
66
 
@@ -69,6 +75,8 @@ class SQLiteDict:
69
75
  True
70
76
  """
71
77
  with self.Session() as db:
78
+ from edsl.data.orm import Base, Data
79
+
72
80
  value = db.query(Data).filter_by(key=key).first()
73
81
  if not value:
74
82
  raise KeyError(f"Key '{key}' not found.")
@@ -17,6 +17,8 @@ class AgentResponseDict(UserDict):
17
17
  cached_response=None,
18
18
  raw_model_response=None,
19
19
  simple_model_raw_response=None,
20
+ cache_used=None,
21
+ cache_key=None,
20
22
  ):
21
23
  """Initialize the AgentResponseDict object."""
22
24
  usage = usage or {"prompt_tokens": 0, "completion_tokens": 0}
@@ -30,5 +32,7 @@ class AgentResponseDict(UserDict):
30
32
  "cached_response": cached_response,
31
33
  "raw_model_response": raw_model_response,
32
34
  "simple_model_raw_response": simple_model_raw_response,
35
+ "cache_used": cache_used,
36
+ "cache_key": cache_key,
33
37
  }
34
38
  )
edsl/jobs/Answers.py CHANGED
@@ -8,7 +8,15 @@ class Answers(UserDict):
8
8
  """Helper class to hold the answers to a survey."""
9
9
 
10
10
  def add_answer(self, response, question) -> None:
11
- """Add a response to the answers dictionary."""
11
+ """Add a response to the answers dictionary.
12
+
13
+ >>> from edsl import QuestionFreeText
14
+ >>> q = QuestionFreeText.example()
15
+ >>> answers = Answers()
16
+ >>> answers.add_answer({"answer": "yes"}, q)
17
+ >>> answers[q.question_name]
18
+ 'yes'
19
+ """
12
20
  answer = response.get("answer")
13
21
  comment = response.pop("comment", None)
14
22
  # record the answer
@@ -42,3 +50,9 @@ class Answers(UserDict):
42
50
  table.add_row(attr_name, repr(attr_value))
43
51
 
44
52
  return table
53
+
54
+
55
+ if __name__ == "__main__":
56
+ import doctest
57
+
58
+ doctest.testmod()