edsl 0.1.28__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +18 -18
- edsl/__init__.py +24 -24
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +77 -41
- edsl/agents/AgentList.py +35 -6
- edsl/agents/Invigilator.py +19 -1
- edsl/agents/InvigilatorBase.py +15 -10
- edsl/agents/PromptConstructionMixin.py +342 -100
- edsl/agents/descriptors.py +2 -1
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/conjure/InputData.py +39 -8
- edsl/coop/coop.py +188 -151
- edsl/coop/utils.py +43 -75
- edsl/data/Cache.py +19 -5
- edsl/data/SQLiteDict.py +11 -3
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +92 -47
- edsl/jobs/buckets/ModelBuckets.py +4 -2
- edsl/jobs/buckets/TokenBucket.py +1 -2
- edsl/jobs/interviews/Interview.py +3 -9
- edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +15 -10
- edsl/jobs/runners/JobsRunnerAsyncio.py +21 -25
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +5 -11
- edsl/language_models/ModelList.py +3 -3
- edsl/language_models/repair.py +8 -7
- edsl/notebooks/Notebook.py +40 -3
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +38 -13
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +3 -3
- edsl/questions/QuestionFunctional.py +0 -3
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +16 -8
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +4 -2
- edsl/questions/question_registry.py +20 -31
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/DatasetExportMixin.py +493 -0
- edsl/results/Result.py +22 -74
- edsl/results/Results.py +105 -67
- edsl/results/ResultsDBMixin.py +7 -3
- edsl/results/ResultsExportMixin.py +22 -537
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +5 -5
- edsl/scenarios/FileStore.py +140 -0
- edsl/scenarios/Scenario.py +5 -6
- edsl/scenarios/ScenarioList.py +44 -15
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/study/ObjectEntry.py +89 -13
- edsl/study/ProofOfWork.py +5 -2
- edsl/study/SnapShot.py +4 -8
- edsl/study/Study.py +21 -14
- edsl/study/__init__.py +2 -0
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +46 -7
- edsl/surveys/SurveyExportMixin.py +4 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/tools/plotting.py +4 -2
- edsl/utilities/__init__.py +21 -21
- edsl/utilities/interface.py +66 -45
- edsl/utilities/utilities.py +11 -13
- {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/METADATA +11 -10
- {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/RECORD +75 -72
- edsl-0.1.28.dist-info/entry_points.txt +0 -3
- {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/LICENSE +0 -0
- {edsl-0.1.28.dist-info → edsl-0.1.29.dist-info}/WHEEL +0 -0
edsl/coop/utils.py
CHANGED
@@ -2,7 +2,6 @@ from edsl import (
     Agent,
     AgentList,
     Cache,
-    Jobs,
     Notebook,
     Results,
     Scenario,
@@ -11,13 +10,12 @@ from edsl import (
     Study,
 )
 from edsl.questions import QuestionBase
-from typing import Literal, Type, Union
+from typing import Literal, Optional, Type, Union

 EDSLObject = Union[
     Agent,
     AgentList,
     Cache,
-    Jobs,
     Notebook,
     Type[QuestionBase],
     Results,
@@ -31,9 +29,8 @@ ObjectType = Literal[
     "agent",
     "agent_list",
     "cache",
-    "job",
-    "question",
     "notebook",
+    "question",
     "results",
     "scenario",
     "scenario_list",
@@ -41,18 +38,12 @@ ObjectType = Literal[
     "study",
 ]

-
-
-    "
-    "
-    "
-    "
-    "questions",
-    "results",
-    "scenarios",
-    "scenariolists",
-    "surveys",
-    "studies",
+
+RemoteJobStatus = Literal[
+    "queued",
+    "running",
+    "completed",
+    "failed",
 ]

 VisibilityType = Literal[
@@ -68,67 +59,21 @@ class ObjectRegistry:
     """

     objects = [
-        {
-            "object_type": "agent",
-            "edsl_class": Agent,
-            "object_page": "agents",
-        },
-        {
-            "object_type": "agent_list",
-            "edsl_class": AgentList,
-            "object_page": "agentlists",
-        },
-        {
-            "object_type": "cache",
-            "edsl_class": Cache,
-            "object_page": "caches",
-        },
-        {
-            "object_type": "job",
-            "edsl_class": Jobs,
-            "object_page": "jobs",
-        },
-        {
-            "object_type": "question",
-            "edsl_class": QuestionBase,
-            "object_page": "questions",
-        },
-        {
-            "object_type": "notebook",
-            "edsl_class": Notebook,
-            "object_page": "notebooks",
-        },
-        {
-            "object_type": "results",
-            "edsl_class": Results,
-            "object_page": "results",
-        },
-        {
-            "object_type": "scenario",
-            "edsl_class": Scenario,
-            "object_page": "scenarios",
-        },
-        {
-            "object_type": "scenario_list",
-            "edsl_class": ScenarioList,
-            "object_page": "scenariolists",
-        },
-        {
-            "object_type": "survey",
-            "edsl_class": Survey,
-            "object_page": "surveys",
-        },
-        {
-            "object_type": "study",
-            "edsl_class": Study,
-            "object_page": "studies",
-        },
+        {"object_type": "agent", "edsl_class": Agent},
+        {"object_type": "agent_list", "edsl_class": AgentList},
+        {"object_type": "cache", "edsl_class": Cache},
+        {"object_type": "question", "edsl_class": QuestionBase},
+        {"object_type": "notebook", "edsl_class": Notebook},
+        {"object_type": "results", "edsl_class": Results},
+        {"object_type": "scenario", "edsl_class": Scenario},
+        {"object_type": "scenario_list", "edsl_class": ScenarioList},
+        {"object_type": "survey", "edsl_class": Survey},
+        {"object_type": "study", "edsl_class": Study},
     ]
     object_type_to_edsl_class = {o["object_type"]: o["edsl_class"] for o in objects}
     edsl_class_to_object_type = {
         o["edsl_class"].__name__: o["object_type"] for o in objects
     }
-    object_type_to_object_page = {o["object_type"]: o["object_page"] for o in objects}

     @classmethod
     def get_object_type_by_edsl_class(cls, edsl_object: EDSLObject) -> ObjectType:
@@ -151,5 +96,28 @@ class ObjectRegistry:
         return EDSL_object

     @classmethod
-    def
-
+    def get_registry(
+        cls,
+        subclass_registry: Optional[dict] = None,
+        exclude_classes: Optional[list] = None,
+    ) -> dict:
+        """
+        Return the registry of objects.
+
+        Exclude objects that are already registered in subclass_registry.
+        This allows the user to isolate Coop-only objects.
+
+        Also exclude objects if their class name is in the exclude_classes list.
+        """
+
+        if subclass_registry is None:
+            subclass_registry = {}
+        if exclude_classes is None:
+            exclude_classes = []
+
+        return {
+            class_name: o["edsl_class"]
+            for o in cls.objects
+            if (class_name := o["edsl_class"].__name__) not in subclass_registry
+            and class_name not in exclude_classes
+        }
edsl/data/Cache.py
CHANGED
@@ -7,13 +7,13 @@ import json
 import os
 import warnings
 from typing import Optional, Union
-
+import time
 from edsl.config import CONFIG
 from edsl.data.CacheEntry import CacheEntry
-
+
+# from edsl.data.SQLiteDict import SQLiteDict
 from edsl.Base import Base
 from edsl.utilities.utilities import dict_hash
-
 from edsl.utilities.decorators import (
     add_edsl_version,
     remove_edsl_version,
@@ -38,7 +38,7 @@ class Cache(Base):
         self,
         *,
         filename: Optional[str] = None,
-        data: Optional[Union[SQLiteDict, dict]] = None,
+        data: Optional[Union["SQLiteDict", dict]] = None,
         immediate_write: bool = True,
         method=None,
     ):
@@ -104,6 +104,8 @@ class Cache(Base):

     def _perform_checks(self):
         """Perform checks on the cache."""
+        from edsl.data.CacheEntry import CacheEntry
+
         if any(not isinstance(value, CacheEntry) for value in self.data.values()):
             raise Exception("Not all values are CacheEntry instances")
         if self.method is not None:
@@ -138,6 +140,8 @@ class Cache(Base):


         """
+        from edsl.data.CacheEntry import CacheEntry
+
         key = CacheEntry.gen_key(
             model=model,
             parameters=parameters,
@@ -171,6 +175,7 @@ class Cache(Base):
         * If `immediate_write` is True , the key-value pair is added to `self.data`
         * If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
         """
+
         entry = CacheEntry(
             model=model,
             parameters=parameters,
@@ -188,13 +193,14 @@ class Cache(Base):
         return key

     def add_from_dict(
-        self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
+        self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
     ) -> None:
         """
         Add entries to the cache from a dictionary.

         :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
         """
+
         for key, value in new_data.items():
             if key in self.data:
                 if value != self.data[key]:
@@ -231,6 +237,8 @@ class Cache(Base):

         :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
         """
+        from edsl.data.SQLiteDict import SQLiteDict
+
         db = SQLiteDict(db_path)
         new_data = {}
         for key, value in db.items():
@@ -242,6 +250,8 @@ class Cache(Base):
         """
         Construct a Cache from a SQLite database.
         """
+        from edsl.data.SQLiteDict import SQLiteDict
+
         return cls(data=SQLiteDict(db_path))

     @classmethod
@@ -268,6 +278,8 @@ class Cache(Base):
         * If `db_path` is provided, the cache will be stored in an SQLite database.
         """
         # if a file doesn't exist at jsonfile, throw an error
+        from edsl.data.SQLiteDict import SQLiteDict
+
         if not os.path.exists(jsonlfile):
             raise FileNotFoundError(f"File {jsonlfile} not found")

@@ -286,6 +298,8 @@ class Cache(Base):
         """
         ## TODO: Check to make sure not over-writing (?)
         ## Should be added to SQLiteDict constructor (?)
+        from edsl.data.SQLiteDict import SQLiteDict
+
         new_data = SQLiteDict(db_path)
         for key, value in self.data.items():
             new_data[key] = value
edsl/data/SQLiteDict.py
CHANGED
@@ -1,9 +1,7 @@
 from __future__ import annotations
 import json
-from sqlalchemy import create_engine
-from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.orm import sessionmaker
 from typing import Any, Generator, Optional, Union
+
 from edsl.config import CONFIG
 from edsl.data.CacheEntry import CacheEntry
 from edsl.data.orm import Base, Data
@@ -25,10 +23,16 @@ class SQLiteDict:
         >>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test

         """
+        from sqlalchemy.exc import SQLAlchemyError
+        from sqlalchemy.orm import sessionmaker
+        from sqlalchemy import create_engine
+
         self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
         if not self.db_path.startswith("sqlite:///"):
             self.db_path = f"sqlite:///{self.db_path}"
         try:
+            from edsl.data.orm import Base, Data
+
             self.engine = create_engine(self.db_path, echo=False, future=True)
             Base.metadata.create_all(self.engine)
             self.Session = sessionmaker(bind=self.engine)
@@ -55,6 +59,8 @@ class SQLiteDict:
         if not isinstance(value, CacheEntry):
             raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
         with self.Session() as db:
+            from edsl.data.orm import Base, Data
+
             db.merge(Data(key=key, value=json.dumps(value.to_dict())))
             db.commit()

@@ -69,6 +75,8 @@ class SQLiteDict:
         True
         """
         with self.Session() as db:
+            from edsl.data.orm import Base, Data
+
             value = db.query(Data).filter_by(key=key).first()
             if not value:
                 raise KeyError(f"Key '{key}' not found.")
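Call sites are unaffected by moving the SQLAlchemy and ORM imports into __init__ and the item methods; the value-type check in __setitem__ shown above still applies. A small usage sketch, with the temp-file handling mirroring the class docstring visible in the diff:

# Sketch: SQLiteDict still rejects non-CacheEntry values after the import shuffle.
import os
import tempfile

from edsl.data.SQLiteDict import SQLiteDict

temp_db_path = tempfile.NamedTemporaryFile(suffix=".db", delete=False).name
d = SQLiteDict(temp_db_path)
try:
    d["foo"] = "not a CacheEntry"  # __setitem__ enforces CacheEntry values
except ValueError as err:
    print(err)  # Value must be a CacheEntry object (got <class 'str'>).
os.unlink(temp_db_path)  # clean up the temp file, as in the docstring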
edsl/jobs/Answers.py
CHANGED
@@ -8,7 +8,15 @@ class Answers(UserDict):
     """Helper class to hold the answers to a survey."""

     def add_answer(self, response, question) -> None:
-        """Add a response to the answers dictionary."""
+        """Add a response to the answers dictionary.
+
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText.example()
+        >>> answers = Answers()
+        >>> answers.add_answer({"answer": "yes"}, q)
+        >>> answers[q.question_name]
+        'yes'
+        """
         answer = response.get("answer")
         comment = response.pop("comment", None)
         # record the answer
@@ -42,3 +50,9 @@
             table.add_row(attr_name, repr(attr_value))

         return table
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod()
edsl/jobs/Jobs.py
CHANGED
@@ -1,30 +1,15 @@
 # """The Jobs class is a collection of agents, scenarios and models and one survey."""
 from __future__ import annotations
-import os
 import warnings
 from itertools import product
 from typing import Optional, Union, Sequence, Generator
-
-from edsl.agents import Agent
-from edsl.agents.AgentList import AgentList
+
 from edsl.Base import Base
-from edsl.data.Cache import Cache
-from edsl.data.CacheHandler import CacheHandler
-from edsl.results.Dataset import Dataset

-from edsl.exceptions.jobs import MissingRemoteInferenceError
 from edsl.exceptions import MissingAPIKeyError
 from edsl.jobs.buckets.BucketCollection import BucketCollection
 from edsl.jobs.interviews.Interview import Interview
-from edsl.language_models import LanguageModel
-from edsl.results import Results
-from edsl.scenarios import Scenario
-from edsl import ScenarioList
-from edsl.surveys import Survey
 from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
-
-from edsl.language_models.ModelList import ModelList
-
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version


@@ -37,10 +22,10 @@ class Jobs(Base):

     def __init__(
         self,
-        survey: Survey,
-        agents: Optional[list[Agent]] = None,
-        models: Optional[list[LanguageModel]] = None,
-        scenarios: Optional[list[Scenario]] = None,
+        survey: "Survey",
+        agents: Optional[list["Agent"]] = None,
+        models: Optional[list["LanguageModel"]] = None,
+        scenarios: Optional[list["Scenario"]] = None,
     ):
         """Initialize a Jobs instance.

@@ -50,8 +35,8 @@ class Jobs(Base):
         :param scenarios: a list of scenarios
         """
         self.survey = survey
-        self.agents: AgentList = agents
-        self.scenarios: ScenarioList = scenarios
+        self.agents: "AgentList" = agents
+        self.scenarios: "ScenarioList" = scenarios
         self.models = models

         self.__bucket_collection = None
@@ -62,6 +47,8 @@ class Jobs(Base):

     @models.setter
     def models(self, value):
+        from edsl import ModelList
+
         if value:
             if not isinstance(value, ModelList):
                 self._models = ModelList(value)
@@ -76,6 +63,8 @@ class Jobs(Base):

     @agents.setter
     def agents(self, value):
+        from edsl import AgentList
+
         if value:
             if not isinstance(value, AgentList):
                 self._agents = AgentList(value)
@@ -90,6 +79,8 @@ class Jobs(Base):

     @scenarios.setter
     def scenarios(self, value):
+        from edsl import ScenarioList
+
         if value:
             if not isinstance(value, ScenarioList):
                 self._scenarios = ScenarioList(value)
@@ -101,10 +92,10 @@ class Jobs(Base):
     def by(
         self,
         *args: Union[
-            Agent,
-            Scenario,
-            LanguageModel,
-            Sequence[Union[Agent, Scenario, LanguageModel]],
+            "Agent",
+            "Scenario",
+            "LanguageModel",
+            Sequence[Union["Agent", "Scenario", "LanguageModel"]],
         ],
     ) -> Jobs:
         """
@@ -144,7 +135,7 @@
         setattr(self, objects_key, new_objects)  # update the job
         return self

-    def prompts(self) -> Dataset:
+    def prompts(self) -> "Dataset":
         """Return a Dataset of prompts that will be used.


@@ -160,6 +151,7 @@
         user_prompts = []
         system_prompts = []
         scenario_indices = []
+        from edsl.results.Dataset import Dataset

         for interview_index, interview in enumerate(interviews):
             invigilators = list(interview._build_invigilators(debug=False))
@@ -182,7 +174,10 @@

     @staticmethod
     def _get_container_class(object):
-        from edsl import AgentList
+        from edsl.agents.AgentList import AgentList
+        from edsl.agents.Agent import Agent
+        from edsl.scenarios.Scenario import Scenario
+        from edsl.scenarios.ScenarioList import ScenarioList

         if isinstance(object, Agent):
             return AgentList
@@ -218,6 +213,10 @@
     def _get_current_objects_of_this_type(
         self, object: Union[Agent, Scenario, LanguageModel]
     ) -> tuple[list, str]:
+        from edsl.agents.Agent import Agent
+        from edsl.scenarios.Scenario import Scenario
+        from edsl.language_models.LanguageModel import LanguageModel
+
         """Return the current objects of the same type as the first argument.

         >>> from edsl.jobs import Jobs
@@ -246,6 +245,9 @@
     @staticmethod
     def _get_empty_container_object(object):
         from edsl import AgentList
+        from edsl import Agent
+        from edsl import Scenario
+        from edsl import ScenarioList

         if isinstance(object, Agent):
             return AgentList([])
@@ -310,12 +312,12 @@
         with us filling in defaults.
         """
         # if no agents, models, or scenarios are set, set them to defaults
+        from edsl.agents.Agent import Agent
+        from edsl.language_models.registry import Model
+        from edsl.scenarios.Scenario import Scenario
+
         self.agents = self.agents or [Agent()]
         self.models = self.models or [Model()]
-        # if remote, set all the models to remote
-        if hasattr(self, "remote") and self.remote:
-            for model in self.models:
-                model.remote = True
         self.scenarios = self.scenarios or [Scenario()]
         for agent, scenario, model in product(self.agents, self.scenarios, self.models):
             yield Interview(
@@ -329,6 +331,7 @@
         These buckets are used to track API calls and token usage.

         >>> from edsl.jobs import Jobs
+        >>> from edsl import Model
         >>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
         >>> bc = j.create_bucket_collection()
         >>> bc
@@ -368,14 +371,16 @@
         if self.verbose:
             print(message)

-    def _check_parameters(self, strict=False) -> None:
+    def _check_parameters(self, strict=False, warn=False) -> None:
         """Check if the parameters in the survey and scenarios are consistent.

         >>> from edsl import QuestionFreeText
+        >>> from edsl import Survey
+        >>> from edsl import Scenario
         >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
         >>> j = Jobs(survey = Survey(questions=[q]))
         >>> with warnings.catch_warnings(record=True) as w:
-        ...     j._check_parameters()
+        ...     j._check_parameters(warn = True)
         ...     assert len(w) == 1
         ...     assert issubclass(w[-1].category, UserWarning)
         ...     assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
@@ -403,7 +408,8 @@
         if strict:
             raise ValueError(message)
         else:
-            warnings.warn(message)
+            if warn:
+                warnings.warn(message)

     def run(
         self,
@@ -412,15 +418,13 @@
         progress_bar: bool = False,
         stop_on_exception: bool = False,
         cache: Union[Cache, bool] = None,
-        remote: bool = (
-            False if os.getenv("DEFAULT_RUN_MODE", "local") == "local" else True
-        ),
         check_api_keys: bool = False,
         sidecar_model: Optional[LanguageModel] = None,
         batch_mode: Optional[bool] = None,
         verbose: bool = False,
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
+        remote_inference_description: Optional[str] = None,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.
@@ -430,11 +434,11 @@
         :param progress_bar: shows a progress bar
         :param stop_on_exception: stops the job if an exception is raised
         :param cache: a cache object to store results
-        :param remote: run the job remotely
         :param check_api_keys: check if the API keys are valid
         :param batch_mode: run the job in batch mode i.e., no expecation of interaction with the user
         :param verbose: prints messages
         :param remote_cache_description: specifies a description for this group of entries in the remote cache
+        :param remote_inference_description: specifies a description for the remote inference job
         """
         from edsl.coop.coop import Coop

@@ -445,21 +449,50 @@
                 "Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
             )

-        self.remote = remote
         self.verbose = verbose

         try:
             coop = Coop()
-
+            user_edsl_settings = coop.edsl_settings
+            remote_cache = user_edsl_settings["remote_caching"]
+            remote_inference = user_edsl_settings["remote_inference"]
         except Exception:
             remote_cache = False
+            remote_inference = False

-        if
-
-        if
-
+        if remote_inference:
+            self._output("Remote inference activated. Sending job to server...")
+            if remote_cache:
+                self._output(
+                    "Remote caching activated. The remote cache will be used for this job."
+                )

-
+            remote_job_data = coop.remote_inference_create(
+                self,
+                description=remote_inference_description,
+                status="queued",
+            )
+            self._output("Job sent!")
+            # Create mock results object to store job data
+            results = Results(
+                survey=Survey(),
+                data=[
+                    Result(
+                        agent=Agent.example(),
+                        scenario=Scenario.example(),
+                        model=Model(),
+                        iteration=1,
+                        answer={"info": "Remote job details"},
+                    )
+                ],
+            )
+            results.add_columns_from_dict([remote_job_data])
+            if self.verbose:
+                results.select(["info", "uuid", "status", "version"]).print(
+                    format="rich"
+                )
+            return results
+        else:
             if check_api_keys:
                 for model in self.models + [Model()]:
                     if not model.has_valid_api_key():
@@ -470,8 +503,12 @@

             # handle cache
             if cache is None:
+                from edsl.data.CacheHandler import CacheHandler
+
                 cache = CacheHandler().get_cache()
             if cache is False:
+                from edsl.data.Cache import Cache
+
                 cache = Cache()

             if not remote_cache:
@@ -623,6 +660,11 @@
     @remove_edsl_version
     def from_dict(cls, data: dict) -> Jobs:
         """Creates a Jobs instance from a dictionary."""
+        from edsl import Survey
+        from edsl.agents.Agent import Agent
+        from edsl.language_models.LanguageModel import LanguageModel
+        from edsl.scenarios.Scenario import Scenario
+
         return cls(
             survey=Survey.from_dict(data["survey"]),
             agents=[Agent.from_dict(agent) for agent in data["agents"]],
@@ -649,7 +691,8 @@
         """
         import random
         from edsl.questions import QuestionMultipleChoice
-        from edsl import Agent
+        from edsl.agents.Agent import Agent
+        from edsl.scenarios.Scenario import Scenario

         # (status, question, period)
         agent_answers = {
@@ -688,6 +731,8 @@
             question_options=["Good", "Great", "OK", "Terrible"],
             question_name="how_feeling_yesterday",
         )
+        from edsl import Survey, ScenarioList
+
         base_survey = Survey(questions=[q1, q2])

         scenario_list = ScenarioList(
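Taken together, the run() changes drop the old remote flag (and its DEFAULT_RUN_MODE environment check) in favor of the account's Coop settings: coop.edsl_settings now decides whether remote caching and remote inference apply, and the new remote_inference_description parameter labels the submitted job. A hedged sketch of the updated call surface; the question text and descriptions are placeholders, and actually running it requires model API keys or a Coop account:

# Sketch of calling Jobs.run under the 0.1.29 signature shown above.
from edsl import QuestionFreeText, Survey
from edsl.jobs import Jobs

survey = Survey(
    questions=[QuestionFreeText(question_name="hello", question_text="Say hello.")]
)
job = Jobs(survey=survey)  # agents, models, and scenarios fall back to defaults

results = job.run(
    progress_bar=False,
    remote_cache_description="example cache batch",      # existing parameter
    remote_inference_description="example remote job",   # new in 0.1.29
)
# If remote inference is enabled for the account, results is the mock Results
# object carrying the queued job's uuid and status, as built in the diff.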
edsl/jobs/buckets/ModelBuckets.py
CHANGED
@@ -1,4 +1,4 @@
-from edsl.jobs.buckets.TokenBucket import TokenBucket
+# from edsl.jobs.buckets.TokenBucket import TokenBucket


 class ModelBuckets:
@@ -8,7 +8,7 @@ class ModelBuckets:
     A request is one call to the service. The number of tokens required for a request depends on parameters.
     """

-    def __init__(self, requests_bucket: TokenBucket, tokens_bucket: TokenBucket):
+    def __init__(self, requests_bucket: "TokenBucket", tokens_bucket: "TokenBucket"):
         """Initialize the model buckets.

         The requests bucket captures requests per unit of time.
@@ -28,6 +28,8 @@ class ModelBuckets:
     @classmethod
     def infinity_bucket(cls, model_name: str = "not_specified") -> "ModelBuckets":
         """Create a bucket with infinite capacity and refill rate."""
+        from edsl.jobs.buckets.TokenBucket import TokenBucket
+
         return cls(
             requests_bucket=TokenBucket(
                 bucket_name=model_name,