edsl 0.1.29.dev3__py3-none-any.whl → 0.1.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +18 -18
- edsl/__init__.py +23 -23
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +79 -41
- edsl/agents/AgentList.py +26 -26
- edsl/agents/Invigilator.py +19 -2
- edsl/agents/InvigilatorBase.py +15 -10
- edsl/agents/PromptConstructionMixin.py +342 -100
- edsl/agents/descriptors.py +2 -1
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/conjure/InputData.py +39 -8
- edsl/conversation/car_buying.py +1 -1
- edsl/coop/coop.py +187 -150
- edsl/coop/utils.py +43 -75
- edsl/data/Cache.py +41 -18
- edsl/data/CacheEntry.py +6 -7
- edsl/data/SQLiteDict.py +11 -3
- edsl/data_transfer_models.py +4 -0
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +108 -49
- edsl/jobs/buckets/ModelBuckets.py +14 -2
- edsl/jobs/buckets/TokenBucket.py +32 -5
- edsl/jobs/interviews/Interview.py +99 -79
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +19 -24
- edsl/jobs/runners/JobsRunnerAsyncio.py +16 -16
- edsl/jobs/tasks/QuestionTaskCreator.py +10 -6
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +17 -17
- edsl/language_models/ModelList.py +1 -1
- edsl/language_models/repair.py +8 -7
- edsl/notebooks/Notebook.py +47 -10
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +38 -13
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +7 -5
- edsl/questions/QuestionFunctional.py +34 -5
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +68 -12
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +46 -4
- edsl/questions/question_registry.py +20 -31
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/DatasetExportMixin.py +570 -0
- edsl/results/Result.py +66 -70
- edsl/results/Results.py +160 -68
- edsl/results/ResultsDBMixin.py +7 -3
- edsl/results/ResultsExportMixin.py +22 -537
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +5 -5
- edsl/scenarios/FileStore.py +299 -0
- edsl/scenarios/Scenario.py +16 -24
- edsl/scenarios/ScenarioList.py +42 -17
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/study/Study.py +8 -16
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +88 -17
- edsl/surveys/SurveyExportMixin.py +4 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/tools/plotting.py +4 -2
- edsl/utilities/__init__.py +21 -21
- edsl/utilities/interface.py +66 -45
- edsl/utilities/utilities.py +11 -13
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/METADATA +11 -10
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/RECORD +74 -71
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/WHEEL +1 -1
- edsl-0.1.29.dev3.dist-info/entry_points.txt +0 -3
- {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/LICENSE +0 -0
edsl/data/Cache.py
CHANGED
@@ -7,17 +7,10 @@ import json
|
|
7
7
|
import os
|
8
8
|
import warnings
|
9
9
|
from typing import Optional, Union
|
10
|
-
|
11
|
-
from edsl.config import CONFIG
|
12
|
-
from edsl.data.CacheEntry import CacheEntry
|
13
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
14
10
|
from edsl.Base import Base
|
11
|
+
from edsl.data.CacheEntry import CacheEntry
|
15
12
|
from edsl.utilities.utilities import dict_hash
|
16
|
-
|
17
|
-
from edsl.utilities.decorators import (
|
18
|
-
add_edsl_version,
|
19
|
-
remove_edsl_version,
|
20
|
-
)
|
13
|
+
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
21
14
|
|
22
15
|
|
23
16
|
class Cache(Base):
|
@@ -38,9 +31,10 @@ class Cache(Base):
|
|
38
31
|
self,
|
39
32
|
*,
|
40
33
|
filename: Optional[str] = None,
|
41
|
-
data: Optional[Union[SQLiteDict, dict]] = None,
|
34
|
+
data: Optional[Union["SQLiteDict", dict]] = None,
|
42
35
|
immediate_write: bool = True,
|
43
36
|
method=None,
|
37
|
+
verbose=False,
|
44
38
|
):
|
45
39
|
"""
|
46
40
|
Create two dictionaries to store the cache data.
|
@@ -59,6 +53,7 @@ class Cache(Base):
|
|
59
53
|
self.new_entries = {}
|
60
54
|
self.new_entries_to_write_later = {}
|
61
55
|
self.coop = None
|
56
|
+
self.verbose = verbose
|
62
57
|
|
63
58
|
self.filename = filename
|
64
59
|
if filename and data:
|
@@ -104,6 +99,8 @@ class Cache(Base):
|
|
104
99
|
|
105
100
|
def _perform_checks(self):
|
106
101
|
"""Perform checks on the cache."""
|
102
|
+
from edsl.data.CacheEntry import CacheEntry
|
103
|
+
|
107
104
|
if any(not isinstance(value, CacheEntry) for value in self.data.values()):
|
108
105
|
raise Exception("Not all values are CacheEntry instances")
|
109
106
|
if self.method is not None:
|
@@ -120,7 +117,7 @@ class Cache(Base):
|
|
120
117
|
system_prompt: str,
|
121
118
|
user_prompt: str,
|
122
119
|
iteration: int,
|
123
|
-
) -> Union[None, str]:
|
120
|
+
) -> tuple(Union[None, str], str):
|
124
121
|
"""
|
125
122
|
Fetch a value (LLM output) from the cache.
|
126
123
|
|
@@ -133,11 +130,13 @@ class Cache(Base):
|
|
133
130
|
Return None if the response is not found.
|
134
131
|
|
135
132
|
>>> c = Cache()
|
136
|
-
>>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1) is None
|
133
|
+
>>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1)[0] is None
|
137
134
|
True
|
138
135
|
|
139
136
|
|
140
137
|
"""
|
138
|
+
from edsl.data.CacheEntry import CacheEntry
|
139
|
+
|
141
140
|
key = CacheEntry.gen_key(
|
142
141
|
model=model,
|
143
142
|
parameters=parameters,
|
@@ -147,8 +146,13 @@ class Cache(Base):
|
|
147
146
|
)
|
148
147
|
entry = self.data.get(key, None)
|
149
148
|
if entry is not None:
|
149
|
+
if self.verbose:
|
150
|
+
print(f"Cache hit for key: {key}")
|
150
151
|
self.fetched_data[key] = entry
|
151
|
-
|
152
|
+
else:
|
153
|
+
if self.verbose:
|
154
|
+
print(f"Cache miss for key: {key}")
|
155
|
+
return None if entry is None else entry.output, key
|
152
156
|
|
153
157
|
def store(
|
154
158
|
self,
|
@@ -171,6 +175,7 @@ class Cache(Base):
|
|
171
175
|
* If `immediate_write` is True , the key-value pair is added to `self.data`
|
172
176
|
* If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
|
173
177
|
"""
|
178
|
+
|
174
179
|
entry = CacheEntry(
|
175
180
|
model=model,
|
176
181
|
parameters=parameters,
|
@@ -188,13 +193,14 @@ class Cache(Base):
|
|
188
193
|
return key
|
189
194
|
|
190
195
|
def add_from_dict(
|
191
|
-
self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
|
196
|
+
self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
|
192
197
|
) -> None:
|
193
198
|
"""
|
194
199
|
Add entries to the cache from a dictionary.
|
195
200
|
|
196
201
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
197
202
|
"""
|
203
|
+
|
198
204
|
for key, value in new_data.items():
|
199
205
|
if key in self.data:
|
200
206
|
if value != self.data[key]:
|
@@ -231,6 +237,8 @@ class Cache(Base):
|
|
231
237
|
|
232
238
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
233
239
|
"""
|
240
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
241
|
+
|
234
242
|
db = SQLiteDict(db_path)
|
235
243
|
new_data = {}
|
236
244
|
for key, value in db.items():
|
@@ -242,6 +250,8 @@ class Cache(Base):
|
|
242
250
|
"""
|
243
251
|
Construct a Cache from a SQLite database.
|
244
252
|
"""
|
253
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
254
|
+
|
245
255
|
return cls(data=SQLiteDict(db_path))
|
246
256
|
|
247
257
|
@classmethod
|
@@ -268,6 +278,8 @@ class Cache(Base):
|
|
268
278
|
* If `db_path` is provided, the cache will be stored in an SQLite database.
|
269
279
|
"""
|
270
280
|
# if a file doesn't exist at jsonfile, throw an error
|
281
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
282
|
+
|
271
283
|
if not os.path.exists(jsonlfile):
|
272
284
|
raise FileNotFoundError(f"File {jsonlfile} not found")
|
273
285
|
|
@@ -286,6 +298,8 @@ class Cache(Base):
|
|
286
298
|
"""
|
287
299
|
## TODO: Check to make sure not over-writing (?)
|
288
300
|
## Should be added to SQLiteDict constructor (?)
|
301
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
302
|
+
|
289
303
|
new_data = SQLiteDict(db_path)
|
290
304
|
for key, value in self.data.items():
|
291
305
|
new_data[key] = value
|
@@ -340,6 +354,9 @@ class Cache(Base):
|
|
340
354
|
for key, entry in self.new_entries_to_write_later.items():
|
341
355
|
self.data[key] = entry
|
342
356
|
|
357
|
+
if self.filename:
|
358
|
+
self.write(self.filename)
|
359
|
+
|
343
360
|
####################
|
344
361
|
# DUNDER / USEFUL
|
345
362
|
####################
|
@@ -456,12 +473,18 @@ class Cache(Base):
|
|
456
473
|
webbrowser.open("file://" + filepath)
|
457
474
|
|
458
475
|
@classmethod
|
459
|
-
def example(cls) -> Cache:
|
476
|
+
def example(cls, randomize: bool = False) -> Cache:
|
460
477
|
"""
|
461
|
-
|
462
|
-
|
478
|
+
Returns an example Cache instance.
|
479
|
+
|
480
|
+
:param randomize: If True, uses CacheEntry's randomize method.
|
463
481
|
"""
|
464
|
-
return cls(
|
482
|
+
return cls(
|
483
|
+
data={
|
484
|
+
CacheEntry.example(randomize).key: CacheEntry.example(),
|
485
|
+
CacheEntry.example(randomize).key: CacheEntry.example(),
|
486
|
+
}
|
487
|
+
)
|
465
488
|
|
466
489
|
|
467
490
|
if __name__ == "__main__":
|
edsl/data/CacheEntry.py
CHANGED
@@ -2,11 +2,8 @@ from __future__ import annotations
|
|
2
2
|
import json
|
3
3
|
import datetime
|
4
4
|
import hashlib
|
5
|
-
import random
|
6
5
|
from typing import Optional
|
7
|
-
|
8
|
-
|
9
|
-
# TODO: Timestamp should probably be float?
|
6
|
+
from uuid import uuid4
|
10
7
|
|
11
8
|
|
12
9
|
class CacheEntry:
|
@@ -151,10 +148,12 @@ class CacheEntry:
|
|
151
148
|
@classmethod
|
152
149
|
def example(cls, randomize: bool = False) -> CacheEntry:
|
153
150
|
"""
|
154
|
-
Returns
|
151
|
+
Returns an example CacheEntry instance.
|
152
|
+
|
153
|
+
:param randomize: If True, adds a random string to the system prompt.
|
155
154
|
"""
|
156
|
-
# if random, create a
|
157
|
-
addition = "" if not randomize else str(
|
155
|
+
# if random, create a uuid
|
156
|
+
addition = "" if not randomize else str(uuid4())
|
158
157
|
return CacheEntry(
|
159
158
|
model="gpt-3.5-turbo",
|
160
159
|
parameters={"temperature": 0.5},
|
edsl/data/SQLiteDict.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import json
|
3
|
-
from sqlalchemy import create_engine
|
4
|
-
from sqlalchemy.exc import SQLAlchemyError
|
5
|
-
from sqlalchemy.orm import sessionmaker
|
6
3
|
from typing import Any, Generator, Optional, Union
|
4
|
+
|
7
5
|
from edsl.config import CONFIG
|
8
6
|
from edsl.data.CacheEntry import CacheEntry
|
9
7
|
from edsl.data.orm import Base, Data
|
@@ -25,10 +23,16 @@ class SQLiteDict:
|
|
25
23
|
>>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
|
26
24
|
|
27
25
|
"""
|
26
|
+
from sqlalchemy.exc import SQLAlchemyError
|
27
|
+
from sqlalchemy.orm import sessionmaker
|
28
|
+
from sqlalchemy import create_engine
|
29
|
+
|
28
30
|
self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
|
29
31
|
if not self.db_path.startswith("sqlite:///"):
|
30
32
|
self.db_path = f"sqlite:///{self.db_path}"
|
31
33
|
try:
|
34
|
+
from edsl.data.orm import Base, Data
|
35
|
+
|
32
36
|
self.engine = create_engine(self.db_path, echo=False, future=True)
|
33
37
|
Base.metadata.create_all(self.engine)
|
34
38
|
self.Session = sessionmaker(bind=self.engine)
|
@@ -55,6 +59,8 @@ class SQLiteDict:
|
|
55
59
|
if not isinstance(value, CacheEntry):
|
56
60
|
raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
|
57
61
|
with self.Session() as db:
|
62
|
+
from edsl.data.orm import Base, Data
|
63
|
+
|
58
64
|
db.merge(Data(key=key, value=json.dumps(value.to_dict())))
|
59
65
|
db.commit()
|
60
66
|
|
@@ -69,6 +75,8 @@ class SQLiteDict:
|
|
69
75
|
True
|
70
76
|
"""
|
71
77
|
with self.Session() as db:
|
78
|
+
from edsl.data.orm import Base, Data
|
79
|
+
|
72
80
|
value = db.query(Data).filter_by(key=key).first()
|
73
81
|
if not value:
|
74
82
|
raise KeyError(f"Key '{key}' not found.")
|
edsl/data_transfer_models.py
CHANGED
@@ -17,6 +17,8 @@ class AgentResponseDict(UserDict):
|
|
17
17
|
cached_response=None,
|
18
18
|
raw_model_response=None,
|
19
19
|
simple_model_raw_response=None,
|
20
|
+
cache_used=None,
|
21
|
+
cache_key=None,
|
20
22
|
):
|
21
23
|
"""Initialize the AgentResponseDict object."""
|
22
24
|
usage = usage or {"prompt_tokens": 0, "completion_tokens": 0}
|
@@ -30,5 +32,7 @@ class AgentResponseDict(UserDict):
|
|
30
32
|
"cached_response": cached_response,
|
31
33
|
"raw_model_response": raw_model_response,
|
32
34
|
"simple_model_raw_response": simple_model_raw_response,
|
35
|
+
"cache_used": cache_used,
|
36
|
+
"cache_key": cache_key,
|
33
37
|
}
|
34
38
|
)
|
edsl/jobs/Answers.py
CHANGED
@@ -8,7 +8,15 @@ class Answers(UserDict):
|
|
8
8
|
"""Helper class to hold the answers to a survey."""
|
9
9
|
|
10
10
|
def add_answer(self, response, question) -> None:
|
11
|
-
"""Add a response to the answers dictionary.
|
11
|
+
"""Add a response to the answers dictionary.
|
12
|
+
|
13
|
+
>>> from edsl import QuestionFreeText
|
14
|
+
>>> q = QuestionFreeText.example()
|
15
|
+
>>> answers = Answers()
|
16
|
+
>>> answers.add_answer({"answer": "yes"}, q)
|
17
|
+
>>> answers[q.question_name]
|
18
|
+
'yes'
|
19
|
+
"""
|
12
20
|
answer = response.get("answer")
|
13
21
|
comment = response.pop("comment", None)
|
14
22
|
# record the answer
|
@@ -42,3 +50,9 @@ class Answers(UserDict):
|
|
42
50
|
table.add_row(attr_name, repr(attr_value))
|
43
51
|
|
44
52
|
return table
|
53
|
+
|
54
|
+
|
55
|
+
if __name__ == "__main__":
|
56
|
+
import doctest
|
57
|
+
|
58
|
+
doctest.testmod()
|
edsl/jobs/Jobs.py
CHANGED
@@ -1,30 +1,15 @@
|
|
1
1
|
# """The Jobs class is a collection of agents, scenarios and models and one survey."""
|
2
2
|
from __future__ import annotations
|
3
|
-
import os
|
4
3
|
import warnings
|
5
4
|
from itertools import product
|
6
5
|
from typing import Optional, Union, Sequence, Generator
|
7
|
-
|
8
|
-
from edsl.agents import Agent
|
9
|
-
from edsl.agents.AgentList import AgentList
|
6
|
+
|
10
7
|
from edsl.Base import Base
|
11
|
-
from edsl.data.Cache import Cache
|
12
|
-
from edsl.data.CacheHandler import CacheHandler
|
13
|
-
from edsl.results.Dataset import Dataset
|
14
8
|
|
15
|
-
from edsl.exceptions.jobs import MissingRemoteInferenceError
|
16
9
|
from edsl.exceptions import MissingAPIKeyError
|
17
10
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
18
11
|
from edsl.jobs.interviews.Interview import Interview
|
19
|
-
from edsl.language_models import LanguageModel
|
20
|
-
from edsl.results import Results
|
21
|
-
from edsl.scenarios import Scenario
|
22
|
-
from edsl import ScenarioList
|
23
|
-
from edsl.surveys import Survey
|
24
12
|
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
25
|
-
|
26
|
-
from edsl.language_models.ModelList import ModelList
|
27
|
-
|
28
13
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
29
14
|
|
30
15
|
|
@@ -37,10 +22,10 @@ class Jobs(Base):
|
|
37
22
|
|
38
23
|
def __init__(
|
39
24
|
self,
|
40
|
-
survey: Survey,
|
41
|
-
agents: Optional[list[Agent]] = None,
|
42
|
-
models: Optional[list[LanguageModel]] = None,
|
43
|
-
scenarios: Optional[list[Scenario]] = None,
|
25
|
+
survey: "Survey",
|
26
|
+
agents: Optional[list["Agent"]] = None,
|
27
|
+
models: Optional[list["LanguageModel"]] = None,
|
28
|
+
scenarios: Optional[list["Scenario"]] = None,
|
44
29
|
):
|
45
30
|
"""Initialize a Jobs instance.
|
46
31
|
|
@@ -50,8 +35,8 @@ class Jobs(Base):
|
|
50
35
|
:param scenarios: a list of scenarios
|
51
36
|
"""
|
52
37
|
self.survey = survey
|
53
|
-
self.agents: AgentList = agents
|
54
|
-
self.scenarios: ScenarioList = scenarios
|
38
|
+
self.agents: "AgentList" = agents
|
39
|
+
self.scenarios: "ScenarioList" = scenarios
|
55
40
|
self.models = models
|
56
41
|
|
57
42
|
self.__bucket_collection = None
|
@@ -62,6 +47,8 @@ class Jobs(Base):
|
|
62
47
|
|
63
48
|
@models.setter
|
64
49
|
def models(self, value):
|
50
|
+
from edsl import ModelList
|
51
|
+
|
65
52
|
if value:
|
66
53
|
if not isinstance(value, ModelList):
|
67
54
|
self._models = ModelList(value)
|
@@ -76,6 +63,8 @@ class Jobs(Base):
|
|
76
63
|
|
77
64
|
@agents.setter
|
78
65
|
def agents(self, value):
|
66
|
+
from edsl import AgentList
|
67
|
+
|
79
68
|
if value:
|
80
69
|
if not isinstance(value, AgentList):
|
81
70
|
self._agents = AgentList(value)
|
@@ -90,6 +79,8 @@ class Jobs(Base):
|
|
90
79
|
|
91
80
|
@scenarios.setter
|
92
81
|
def scenarios(self, value):
|
82
|
+
from edsl import ScenarioList
|
83
|
+
|
93
84
|
if value:
|
94
85
|
if not isinstance(value, ScenarioList):
|
95
86
|
self._scenarios = ScenarioList(value)
|
@@ -101,10 +92,10 @@ class Jobs(Base):
|
|
101
92
|
def by(
|
102
93
|
self,
|
103
94
|
*args: Union[
|
104
|
-
Agent,
|
105
|
-
Scenario,
|
106
|
-
LanguageModel,
|
107
|
-
Sequence[Union[Agent, Scenario, LanguageModel]],
|
95
|
+
"Agent",
|
96
|
+
"Scenario",
|
97
|
+
"LanguageModel",
|
98
|
+
Sequence[Union["Agent", "Scenario", "LanguageModel"]],
|
108
99
|
],
|
109
100
|
) -> Jobs:
|
110
101
|
"""
|
@@ -144,7 +135,7 @@ class Jobs(Base):
|
|
144
135
|
setattr(self, objects_key, new_objects) # update the job
|
145
136
|
return self
|
146
137
|
|
147
|
-
def prompts(self) -> Dataset:
|
138
|
+
def prompts(self) -> "Dataset":
|
148
139
|
"""Return a Dataset of prompts that will be used.
|
149
140
|
|
150
141
|
|
@@ -160,6 +151,7 @@ class Jobs(Base):
|
|
160
151
|
user_prompts = []
|
161
152
|
system_prompts = []
|
162
153
|
scenario_indices = []
|
154
|
+
from edsl.results.Dataset import Dataset
|
163
155
|
|
164
156
|
for interview_index, interview in enumerate(interviews):
|
165
157
|
invigilators = list(interview._build_invigilators(debug=False))
|
@@ -182,7 +174,10 @@ class Jobs(Base):
|
|
182
174
|
|
183
175
|
@staticmethod
|
184
176
|
def _get_container_class(object):
|
185
|
-
from edsl import AgentList
|
177
|
+
from edsl.agents.AgentList import AgentList
|
178
|
+
from edsl.agents.Agent import Agent
|
179
|
+
from edsl.scenarios.Scenario import Scenario
|
180
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
186
181
|
|
187
182
|
if isinstance(object, Agent):
|
188
183
|
return AgentList
|
@@ -218,6 +213,10 @@ class Jobs(Base):
|
|
218
213
|
def _get_current_objects_of_this_type(
|
219
214
|
self, object: Union[Agent, Scenario, LanguageModel]
|
220
215
|
) -> tuple[list, str]:
|
216
|
+
from edsl.agents.Agent import Agent
|
217
|
+
from edsl.scenarios.Scenario import Scenario
|
218
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
219
|
+
|
221
220
|
"""Return the current objects of the same type as the first argument.
|
222
221
|
|
223
222
|
>>> from edsl.jobs import Jobs
|
@@ -246,6 +245,9 @@ class Jobs(Base):
|
|
246
245
|
@staticmethod
|
247
246
|
def _get_empty_container_object(object):
|
248
247
|
from edsl import AgentList
|
248
|
+
from edsl import Agent
|
249
|
+
from edsl import Scenario
|
250
|
+
from edsl import ScenarioList
|
249
251
|
|
250
252
|
if isinstance(object, Agent):
|
251
253
|
return AgentList([])
|
@@ -310,12 +312,12 @@ class Jobs(Base):
|
|
310
312
|
with us filling in defaults.
|
311
313
|
"""
|
312
314
|
# if no agents, models, or scenarios are set, set them to defaults
|
315
|
+
from edsl.agents.Agent import Agent
|
316
|
+
from edsl.language_models.registry import Model
|
317
|
+
from edsl.scenarios.Scenario import Scenario
|
318
|
+
|
313
319
|
self.agents = self.agents or [Agent()]
|
314
320
|
self.models = self.models or [Model()]
|
315
|
-
# if remote, set all the models to remote
|
316
|
-
if hasattr(self, "remote") and self.remote:
|
317
|
-
for model in self.models:
|
318
|
-
model.remote = True
|
319
321
|
self.scenarios = self.scenarios or [Scenario()]
|
320
322
|
for agent, scenario, model in product(self.agents, self.scenarios, self.models):
|
321
323
|
yield Interview(
|
@@ -329,6 +331,7 @@ class Jobs(Base):
|
|
329
331
|
These buckets are used to track API calls and token usage.
|
330
332
|
|
331
333
|
>>> from edsl.jobs import Jobs
|
334
|
+
>>> from edsl import Model
|
332
335
|
>>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
|
333
336
|
>>> bc = j.create_bucket_collection()
|
334
337
|
>>> bc
|
@@ -368,14 +371,16 @@ class Jobs(Base):
|
|
368
371
|
if self.verbose:
|
369
372
|
print(message)
|
370
373
|
|
371
|
-
def _check_parameters(self, strict=False, warn
|
374
|
+
def _check_parameters(self, strict=False, warn=False) -> None:
|
372
375
|
"""Check if the parameters in the survey and scenarios are consistent.
|
373
376
|
|
374
377
|
>>> from edsl import QuestionFreeText
|
378
|
+
>>> from edsl import Survey
|
379
|
+
>>> from edsl import Scenario
|
375
380
|
>>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
|
376
381
|
>>> j = Jobs(survey = Survey(questions=[q]))
|
377
382
|
>>> with warnings.catch_warnings(record=True) as w:
|
378
|
-
... j._check_parameters()
|
383
|
+
... j._check_parameters(warn = True)
|
379
384
|
... assert len(w) == 1
|
380
385
|
... assert issubclass(w[-1].category, UserWarning)
|
381
386
|
... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
|
@@ -413,15 +418,13 @@ class Jobs(Base):
|
|
413
418
|
progress_bar: bool = False,
|
414
419
|
stop_on_exception: bool = False,
|
415
420
|
cache: Union[Cache, bool] = None,
|
416
|
-
remote: bool = (
|
417
|
-
False if os.getenv("DEFAULT_RUN_MODE", "local") == "local" else True
|
418
|
-
),
|
419
421
|
check_api_keys: bool = False,
|
420
422
|
sidecar_model: Optional[LanguageModel] = None,
|
421
423
|
batch_mode: Optional[bool] = None,
|
422
424
|
verbose: bool = False,
|
423
425
|
print_exceptions=True,
|
424
426
|
remote_cache_description: Optional[str] = None,
|
427
|
+
remote_inference_description: Optional[str] = None,
|
425
428
|
) -> Results:
|
426
429
|
"""
|
427
430
|
Runs the Job: conducts Interviews and returns their results.
|
@@ -431,11 +434,11 @@ class Jobs(Base):
|
|
431
434
|
:param progress_bar: shows a progress bar
|
432
435
|
:param stop_on_exception: stops the job if an exception is raised
|
433
436
|
:param cache: a cache object to store results
|
434
|
-
:param remote: run the job remotely
|
435
437
|
:param check_api_keys: check if the API keys are valid
|
436
438
|
:param batch_mode: run the job in batch mode, i.e., no expectation of interaction with the user
|
437
439
|
:param verbose: prints messages
|
438
440
|
:param remote_cache_description: specifies a description for this group of entries in the remote cache
|
441
|
+
:param remote_inference_description: specifies a description for the remote inference job
|
439
442
|
"""
|
440
443
|
from edsl.coop.coop import Coop
|
441
444
|
|
@@ -446,21 +449,57 @@ class Jobs(Base):
|
|
446
449
|
"Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
|
447
450
|
)
|
448
451
|
|
449
|
-
self.remote = remote
|
450
452
|
self.verbose = verbose
|
451
453
|
|
452
454
|
try:
|
453
455
|
coop = Coop()
|
454
|
-
|
456
|
+
user_edsl_settings = coop.edsl_settings
|
457
|
+
remote_cache = user_edsl_settings["remote_caching"]
|
458
|
+
remote_inference = user_edsl_settings["remote_inference"]
|
455
459
|
except Exception:
|
456
460
|
remote_cache = False
|
461
|
+
remote_inference = False
|
462
|
+
|
463
|
+
if remote_inference:
|
464
|
+
from edsl.agents.Agent import Agent
|
465
|
+
from edsl.language_models.registry import Model
|
466
|
+
from edsl.results.Result import Result
|
467
|
+
from edsl.results.Results import Results
|
468
|
+
from edsl.scenarios.Scenario import Scenario
|
469
|
+
from edsl.surveys.Survey import Survey
|
470
|
+
|
471
|
+
self._output("Remote inference activated. Sending job to server...")
|
472
|
+
if remote_cache:
|
473
|
+
self._output(
|
474
|
+
"Remote caching activated. The remote cache will be used for this job."
|
475
|
+
)
|
457
476
|
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
477
|
+
remote_job_data = coop.remote_inference_create(
|
478
|
+
self,
|
479
|
+
description=remote_inference_description,
|
480
|
+
status="queued",
|
481
|
+
)
|
482
|
+
self._output("Job sent!")
|
483
|
+
# Create mock results object to store job data
|
484
|
+
results = Results(
|
485
|
+
survey=Survey(),
|
486
|
+
data=[
|
487
|
+
Result(
|
488
|
+
agent=Agent.example(),
|
489
|
+
scenario=Scenario.example(),
|
490
|
+
model=Model(),
|
491
|
+
iteration=1,
|
492
|
+
answer={"info": "Remote job details"},
|
493
|
+
)
|
494
|
+
],
|
495
|
+
)
|
496
|
+
results.add_columns_from_dict([remote_job_data])
|
497
|
+
if self.verbose:
|
498
|
+
results.select(["info", "uuid", "status", "version"]).print(
|
499
|
+
format="rich"
|
500
|
+
)
|
501
|
+
return results
|
502
|
+
else:
|
464
503
|
if check_api_keys:
|
465
504
|
for model in self.models + [Model()]:
|
466
505
|
if not model.has_valid_api_key():
|
@@ -471,8 +510,12 @@ class Jobs(Base):
|
|
471
510
|
|
472
511
|
# handle cache
|
473
512
|
if cache is None:
|
513
|
+
from edsl.data.CacheHandler import CacheHandler
|
514
|
+
|
474
515
|
cache = CacheHandler().get_cache()
|
475
516
|
if cache is False:
|
517
|
+
from edsl.data.Cache import Cache
|
518
|
+
|
476
519
|
cache = Cache()
|
477
520
|
|
478
521
|
if not remote_cache:
|
@@ -624,6 +667,11 @@ class Jobs(Base):
|
|
624
667
|
@remove_edsl_version
|
625
668
|
def from_dict(cls, data: dict) -> Jobs:
|
626
669
|
"""Creates a Jobs instance from a dictionary."""
|
670
|
+
from edsl import Survey
|
671
|
+
from edsl.agents.Agent import Agent
|
672
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
673
|
+
from edsl.scenarios.Scenario import Scenario
|
674
|
+
|
627
675
|
return cls(
|
628
676
|
survey=Survey.from_dict(data["survey"]),
|
629
677
|
agents=[Agent.from_dict(agent) for agent in data["agents"]],
|
@@ -639,7 +687,9 @@ class Jobs(Base):
|
|
639
687
|
# Example methods
|
640
688
|
#######################
|
641
689
|
@classmethod
|
642
|
-
def example(
|
690
|
+
def example(
|
691
|
+
cls, throw_exception_probability: int = 0, randomize: bool = False
|
692
|
+
) -> Jobs:
|
643
693
|
"""Return an example Jobs instance.
|
644
694
|
|
645
695
|
:param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
|
@@ -649,8 +699,12 @@ class Jobs(Base):
|
|
649
699
|
|
650
700
|
"""
|
651
701
|
import random
|
702
|
+
from uuid import uuid4
|
652
703
|
from edsl.questions import QuestionMultipleChoice
|
653
|
-
from edsl import Agent
|
704
|
+
from edsl.agents.Agent import Agent
|
705
|
+
from edsl.scenarios.Scenario import Scenario
|
706
|
+
|
707
|
+
addition = "" if not randomize else str(uuid4())
|
654
708
|
|
655
709
|
# (status, question, period)
|
656
710
|
agent_answers = {
|
@@ -689,10 +743,15 @@ class Jobs(Base):
|
|
689
743
|
question_options=["Good", "Great", "OK", "Terrible"],
|
690
744
|
question_name="how_feeling_yesterday",
|
691
745
|
)
|
746
|
+
from edsl import Survey, ScenarioList
|
747
|
+
|
692
748
|
base_survey = Survey(questions=[q1, q2])
|
693
749
|
|
694
750
|
scenario_list = ScenarioList(
|
695
|
-
[
|
751
|
+
[
|
752
|
+
Scenario({"period": f"morning{addition}"}),
|
753
|
+
Scenario({"period": "afternoon"}),
|
754
|
+
]
|
696
755
|
)
|
697
756
|
job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
|
698
757
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from edsl.jobs.buckets.TokenBucket import TokenBucket
|
1
|
+
# from edsl.jobs.buckets.TokenBucket import TokenBucket
|
2
2
|
|
3
3
|
|
4
4
|
class ModelBuckets:
|
@@ -8,7 +8,7 @@ class ModelBuckets:
|
|
8
8
|
A request is one call to the service. The number of tokens required for a request depends on parameters.
|
9
9
|
"""
|
10
10
|
|
11
|
-
def __init__(self, requests_bucket: TokenBucket, tokens_bucket: TokenBucket):
|
11
|
+
def __init__(self, requests_bucket: "TokenBucket", tokens_bucket: "TokenBucket"):
|
12
12
|
"""Initialize the model buckets.
|
13
13
|
|
14
14
|
The requests bucket captures requests per unit of time.
|
@@ -25,9 +25,21 @@ class ModelBuckets:
|
|
25
25
|
tokens_bucket=self.tokens_bucket + other.tokens_bucket,
|
26
26
|
)
|
27
27
|
|
28
|
+
def turbo_mode_on(self):
|
29
|
+
"""Set the refill rate to infinity for both buckets."""
|
30
|
+
self.requests_bucket.turbo_mode_on()
|
31
|
+
self.tokens_bucket.turbo_mode_on()
|
32
|
+
|
33
|
+
def turbo_mode_off(self):
|
34
|
+
"""Restore the refill rate to its original value for both buckets."""
|
35
|
+
self.requests_bucket.turbo_mode_off()
|
36
|
+
self.tokens_bucket.turbo_mode_off()
|
37
|
+
|
28
38
|
@classmethod
|
29
39
|
def infinity_bucket(cls, model_name: str = "not_specified") -> "ModelBuckets":
|
30
40
|
"""Create a bucket with infinite capacity and refill rate."""
|
41
|
+
from edsl.jobs.buckets.TokenBucket import TokenBucket
|
42
|
+
|
31
43
|
return cls(
|
32
44
|
requests_bucket=TokenBucket(
|
33
45
|
bucket_name=model_name,
|