edsl 0.1.29.dev6__py3-none-any.whl → 0.1.30.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +6 -3
- edsl/__init__.py +23 -23
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +35 -34
- edsl/agents/AgentList.py +16 -5
- edsl/agents/Invigilator.py +19 -1
- edsl/agents/descriptors.py +2 -1
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/coop/utils.py +28 -1
- edsl/data/Cache.py +19 -5
- edsl/data/SQLiteDict.py +11 -3
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +69 -31
- edsl/jobs/buckets/ModelBuckets.py +4 -2
- edsl/jobs/buckets/TokenBucket.py +1 -2
- edsl/jobs/interviews/Interview.py +0 -6
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +9 -5
- edsl/jobs/runners/JobsRunnerAsyncio.py +12 -16
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +5 -11
- edsl/language_models/ModelList.py +1 -1
- edsl/language_models/repair.py +8 -7
- edsl/notebooks/Notebook.py +9 -3
- edsl/questions/QuestionBase.py +6 -2
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +3 -3
- edsl/questions/QuestionFunctional.py +0 -3
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +12 -5
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +4 -2
- edsl/results/DatasetExportMixin.py +491 -0
- edsl/results/Result.py +13 -65
- edsl/results/Results.py +91 -39
- edsl/results/ResultsDBMixin.py +7 -3
- edsl/results/ResultsExportMixin.py +22 -537
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +1 -4
- edsl/scenarios/FileStore.py +140 -0
- edsl/scenarios/Scenario.py +5 -6
- edsl/scenarios/ScenarioList.py +17 -8
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +9 -4
- edsl/surveys/SurveyExportMixin.py +4 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/utilities/__init__.py +21 -21
- edsl/utilities/interface.py +66 -45
- edsl/utilities/utilities.py +11 -13
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/METADATA +1 -1
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/RECORD +60 -56
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/WHEEL +0 -0
edsl/config.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
"""This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
|
2
2
|
|
3
3
|
import os
|
4
|
-
from dotenv import load_dotenv, find_dotenv
|
5
4
|
from edsl.exceptions import (
|
6
5
|
InvalidEnvironmentVariableError,
|
7
6
|
MissingEnvironmentVariableError,
|
8
7
|
)
|
8
|
+
from dotenv import load_dotenv, find_dotenv
|
9
9
|
|
10
10
|
# valid values for EDSL_RUN_MODE
|
11
11
|
EDSL_RUN_MODES = ["development", "development-testrun", "production"]
|
@@ -96,6 +96,7 @@ class Config:
|
|
96
96
|
Loads the .env
|
97
97
|
- Overrides existing env vars unless EDSL_RUN_MODE=="development-testrun"
|
98
98
|
"""
|
99
|
+
|
99
100
|
override = True
|
100
101
|
if self.EDSL_RUN_MODE == "development-testrun":
|
101
102
|
override = False
|
edsl/coop/utils.py
CHANGED
@@ -10,7 +10,7 @@ from edsl import (
|
|
10
10
|
Study,
|
11
11
|
)
|
12
12
|
from edsl.questions import QuestionBase
|
13
|
-
from typing import Literal, Type, Union
|
13
|
+
from typing import Literal, Optional, Type, Union
|
14
14
|
|
15
15
|
EDSLObject = Union[
|
16
16
|
Agent,
|
@@ -94,3 +94,30 @@ class ObjectRegistry:
|
|
94
94
|
if EDSL_object is None:
|
95
95
|
raise ValueError(f"EDSL class not found for {object_type=}")
|
96
96
|
return EDSL_object
|
97
|
+
|
98
|
+
@classmethod
|
99
|
+
def get_registry(
|
100
|
+
cls,
|
101
|
+
subclass_registry: Optional[dict] = None,
|
102
|
+
exclude_classes: Optional[list] = None,
|
103
|
+
) -> dict:
|
104
|
+
"""
|
105
|
+
Return the registry of objects.
|
106
|
+
|
107
|
+
Exclude objects that are already registered in subclass_registry.
|
108
|
+
This allows the user to isolate Coop-only objects.
|
109
|
+
|
110
|
+
Also exclude objects if their class name is in the exclude_classes list.
|
111
|
+
"""
|
112
|
+
|
113
|
+
if subclass_registry is None:
|
114
|
+
subclass_registry = {}
|
115
|
+
if exclude_classes is None:
|
116
|
+
exclude_classes = []
|
117
|
+
|
118
|
+
return {
|
119
|
+
class_name: o["edsl_class"]
|
120
|
+
for o in cls.objects
|
121
|
+
if (class_name := o["edsl_class"].__name__) not in subclass_registry
|
122
|
+
and class_name not in exclude_classes
|
123
|
+
}
|
edsl/data/Cache.py
CHANGED
@@ -7,13 +7,13 @@ import json
|
|
7
7
|
import os
|
8
8
|
import warnings
|
9
9
|
from typing import Optional, Union
|
10
|
-
|
10
|
+
import time
|
11
11
|
from edsl.config import CONFIG
|
12
12
|
from edsl.data.CacheEntry import CacheEntry
|
13
|
-
|
13
|
+
|
14
|
+
# from edsl.data.SQLiteDict import SQLiteDict
|
14
15
|
from edsl.Base import Base
|
15
16
|
from edsl.utilities.utilities import dict_hash
|
16
|
-
|
17
17
|
from edsl.utilities.decorators import (
|
18
18
|
add_edsl_version,
|
19
19
|
remove_edsl_version,
|
@@ -38,7 +38,7 @@ class Cache(Base):
|
|
38
38
|
self,
|
39
39
|
*,
|
40
40
|
filename: Optional[str] = None,
|
41
|
-
data: Optional[Union[SQLiteDict, dict]] = None,
|
41
|
+
data: Optional[Union["SQLiteDict", dict]] = None,
|
42
42
|
immediate_write: bool = True,
|
43
43
|
method=None,
|
44
44
|
):
|
@@ -104,6 +104,8 @@ class Cache(Base):
|
|
104
104
|
|
105
105
|
def _perform_checks(self):
|
106
106
|
"""Perform checks on the cache."""
|
107
|
+
from edsl.data.CacheEntry import CacheEntry
|
108
|
+
|
107
109
|
if any(not isinstance(value, CacheEntry) for value in self.data.values()):
|
108
110
|
raise Exception("Not all values are CacheEntry instances")
|
109
111
|
if self.method is not None:
|
@@ -138,6 +140,8 @@ class Cache(Base):
|
|
138
140
|
|
139
141
|
|
140
142
|
"""
|
143
|
+
from edsl.data.CacheEntry import CacheEntry
|
144
|
+
|
141
145
|
key = CacheEntry.gen_key(
|
142
146
|
model=model,
|
143
147
|
parameters=parameters,
|
@@ -171,6 +175,7 @@ class Cache(Base):
|
|
171
175
|
* If `immediate_write` is True , the key-value pair is added to `self.data`
|
172
176
|
* If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
|
173
177
|
"""
|
178
|
+
|
174
179
|
entry = CacheEntry(
|
175
180
|
model=model,
|
176
181
|
parameters=parameters,
|
@@ -188,13 +193,14 @@ class Cache(Base):
|
|
188
193
|
return key
|
189
194
|
|
190
195
|
def add_from_dict(
|
191
|
-
self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
|
196
|
+
self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
|
192
197
|
) -> None:
|
193
198
|
"""
|
194
199
|
Add entries to the cache from a dictionary.
|
195
200
|
|
196
201
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
197
202
|
"""
|
203
|
+
|
198
204
|
for key, value in new_data.items():
|
199
205
|
if key in self.data:
|
200
206
|
if value != self.data[key]:
|
@@ -231,6 +237,8 @@ class Cache(Base):
|
|
231
237
|
|
232
238
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
233
239
|
"""
|
240
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
241
|
+
|
234
242
|
db = SQLiteDict(db_path)
|
235
243
|
new_data = {}
|
236
244
|
for key, value in db.items():
|
@@ -242,6 +250,8 @@ class Cache(Base):
|
|
242
250
|
"""
|
243
251
|
Construct a Cache from a SQLite database.
|
244
252
|
"""
|
253
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
254
|
+
|
245
255
|
return cls(data=SQLiteDict(db_path))
|
246
256
|
|
247
257
|
@classmethod
|
@@ -268,6 +278,8 @@ class Cache(Base):
|
|
268
278
|
* If `db_path` is provided, the cache will be stored in an SQLite database.
|
269
279
|
"""
|
270
280
|
# if a file doesn't exist at jsonfile, throw an error
|
281
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
282
|
+
|
271
283
|
if not os.path.exists(jsonlfile):
|
272
284
|
raise FileNotFoundError(f"File {jsonlfile} not found")
|
273
285
|
|
@@ -286,6 +298,8 @@ class Cache(Base):
|
|
286
298
|
"""
|
287
299
|
## TODO: Check to make sure not over-writing (?)
|
288
300
|
## Should be added to SQLiteDict constructor (?)
|
301
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
302
|
+
|
289
303
|
new_data = SQLiteDict(db_path)
|
290
304
|
for key, value in self.data.items():
|
291
305
|
new_data[key] = value
|
edsl/data/SQLiteDict.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import json
|
3
|
-
from sqlalchemy import create_engine
|
4
|
-
from sqlalchemy.exc import SQLAlchemyError
|
5
|
-
from sqlalchemy.orm import sessionmaker
|
6
3
|
from typing import Any, Generator, Optional, Union
|
4
|
+
|
7
5
|
from edsl.config import CONFIG
|
8
6
|
from edsl.data.CacheEntry import CacheEntry
|
9
7
|
from edsl.data.orm import Base, Data
|
@@ -25,10 +23,16 @@ class SQLiteDict:
|
|
25
23
|
>>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
|
26
24
|
|
27
25
|
"""
|
26
|
+
from sqlalchemy.exc import SQLAlchemyError
|
27
|
+
from sqlalchemy.orm import sessionmaker
|
28
|
+
from sqlalchemy import create_engine
|
29
|
+
|
28
30
|
self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
|
29
31
|
if not self.db_path.startswith("sqlite:///"):
|
30
32
|
self.db_path = f"sqlite:///{self.db_path}"
|
31
33
|
try:
|
34
|
+
from edsl.data.orm import Base, Data
|
35
|
+
|
32
36
|
self.engine = create_engine(self.db_path, echo=False, future=True)
|
33
37
|
Base.metadata.create_all(self.engine)
|
34
38
|
self.Session = sessionmaker(bind=self.engine)
|
@@ -55,6 +59,8 @@ class SQLiteDict:
|
|
55
59
|
if not isinstance(value, CacheEntry):
|
56
60
|
raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
|
57
61
|
with self.Session() as db:
|
62
|
+
from edsl.data.orm import Base, Data
|
63
|
+
|
58
64
|
db.merge(Data(key=key, value=json.dumps(value.to_dict())))
|
59
65
|
db.commit()
|
60
66
|
|
@@ -69,6 +75,8 @@ class SQLiteDict:
|
|
69
75
|
True
|
70
76
|
"""
|
71
77
|
with self.Session() as db:
|
78
|
+
from edsl.data.orm import Base, Data
|
79
|
+
|
72
80
|
value = db.query(Data).filter_by(key=key).first()
|
73
81
|
if not value:
|
74
82
|
raise KeyError(f"Key '{key}' not found.")
|
edsl/jobs/Answers.py
CHANGED
@@ -8,7 +8,15 @@ class Answers(UserDict):
|
|
8
8
|
"""Helper class to hold the answers to a survey."""
|
9
9
|
|
10
10
|
def add_answer(self, response, question) -> None:
|
11
|
-
"""Add a response to the answers dictionary.
|
11
|
+
"""Add a response to the answers dictionary.
|
12
|
+
|
13
|
+
>>> from edsl import QuestionFreeText
|
14
|
+
>>> q = QuestionFreeText.example()
|
15
|
+
>>> answers = Answers()
|
16
|
+
>>> answers.add_answer({"answer": "yes"}, q)
|
17
|
+
>>> answers[q.question_name]
|
18
|
+
'yes'
|
19
|
+
"""
|
12
20
|
answer = response.get("answer")
|
13
21
|
comment = response.pop("comment", None)
|
14
22
|
# record the answer
|
@@ -42,3 +50,9 @@ class Answers(UserDict):
|
|
42
50
|
table.add_row(attr_name, repr(attr_value))
|
43
51
|
|
44
52
|
return table
|
53
|
+
|
54
|
+
|
55
|
+
if __name__ == "__main__":
|
56
|
+
import doctest
|
57
|
+
|
58
|
+
doctest.testmod()
|
edsl/jobs/Jobs.py
CHANGED
@@ -1,30 +1,15 @@
|
|
1
1
|
# """The Jobs class is a collection of agents, scenarios and models and one survey."""
|
2
2
|
from __future__ import annotations
|
3
|
-
import os
|
4
3
|
import warnings
|
5
4
|
from itertools import product
|
6
5
|
from typing import Optional, Union, Sequence, Generator
|
7
|
-
|
8
|
-
from edsl.agents import Agent
|
9
|
-
from edsl.agents.AgentList import AgentList
|
6
|
+
|
10
7
|
from edsl.Base import Base
|
11
|
-
from edsl.data.Cache import Cache
|
12
|
-
from edsl.data.CacheHandler import CacheHandler
|
13
|
-
from edsl.results.Dataset import Dataset
|
14
8
|
|
15
|
-
from edsl.exceptions.jobs import MissingRemoteInferenceError
|
16
9
|
from edsl.exceptions import MissingAPIKeyError
|
17
10
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
18
11
|
from edsl.jobs.interviews.Interview import Interview
|
19
|
-
from edsl.language_models import LanguageModel
|
20
|
-
from edsl.results import Results
|
21
|
-
from edsl.scenarios import Scenario
|
22
|
-
from edsl import ScenarioList
|
23
|
-
from edsl.surveys import Survey
|
24
12
|
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
25
|
-
|
26
|
-
from edsl.language_models.ModelList import ModelList
|
27
|
-
|
28
13
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
29
14
|
|
30
15
|
|
@@ -37,10 +22,10 @@ class Jobs(Base):
|
|
37
22
|
|
38
23
|
def __init__(
|
39
24
|
self,
|
40
|
-
survey: Survey,
|
41
|
-
agents: Optional[list[Agent]] = None,
|
42
|
-
models: Optional[list[LanguageModel]] = None,
|
43
|
-
scenarios: Optional[list[Scenario]] = None,
|
25
|
+
survey: "Survey",
|
26
|
+
agents: Optional[list["Agent"]] = None,
|
27
|
+
models: Optional[list["LanguageModel"]] = None,
|
28
|
+
scenarios: Optional[list["Scenario"]] = None,
|
44
29
|
):
|
45
30
|
"""Initialize a Jobs instance.
|
46
31
|
|
@@ -50,8 +35,8 @@ class Jobs(Base):
|
|
50
35
|
:param scenarios: a list of scenarios
|
51
36
|
"""
|
52
37
|
self.survey = survey
|
53
|
-
self.agents: AgentList = agents
|
54
|
-
self.scenarios: ScenarioList = scenarios
|
38
|
+
self.agents: "AgentList" = agents
|
39
|
+
self.scenarios: "ScenarioList" = scenarios
|
55
40
|
self.models = models
|
56
41
|
|
57
42
|
self.__bucket_collection = None
|
@@ -62,6 +47,8 @@ class Jobs(Base):
|
|
62
47
|
|
63
48
|
@models.setter
|
64
49
|
def models(self, value):
|
50
|
+
from edsl import ModelList
|
51
|
+
|
65
52
|
if value:
|
66
53
|
if not isinstance(value, ModelList):
|
67
54
|
self._models = ModelList(value)
|
@@ -76,6 +63,8 @@ class Jobs(Base):
|
|
76
63
|
|
77
64
|
@agents.setter
|
78
65
|
def agents(self, value):
|
66
|
+
from edsl import AgentList
|
67
|
+
|
79
68
|
if value:
|
80
69
|
if not isinstance(value, AgentList):
|
81
70
|
self._agents = AgentList(value)
|
@@ -90,6 +79,8 @@ class Jobs(Base):
|
|
90
79
|
|
91
80
|
@scenarios.setter
|
92
81
|
def scenarios(self, value):
|
82
|
+
from edsl import ScenarioList
|
83
|
+
|
93
84
|
if value:
|
94
85
|
if not isinstance(value, ScenarioList):
|
95
86
|
self._scenarios = ScenarioList(value)
|
@@ -101,10 +92,10 @@ class Jobs(Base):
|
|
101
92
|
def by(
|
102
93
|
self,
|
103
94
|
*args: Union[
|
104
|
-
Agent,
|
105
|
-
Scenario,
|
106
|
-
LanguageModel,
|
107
|
-
Sequence[Union[Agent, Scenario, LanguageModel]],
|
95
|
+
"Agent",
|
96
|
+
"Scenario",
|
97
|
+
"LanguageModel",
|
98
|
+
Sequence[Union["Agent", "Scenario", "LanguageModel"]],
|
108
99
|
],
|
109
100
|
) -> Jobs:
|
110
101
|
"""
|
@@ -144,7 +135,7 @@ class Jobs(Base):
|
|
144
135
|
setattr(self, objects_key, new_objects) # update the job
|
145
136
|
return self
|
146
137
|
|
147
|
-
def prompts(self) -> Dataset:
|
138
|
+
def prompts(self) -> "Dataset":
|
148
139
|
"""Return a Dataset of prompts that will be used.
|
149
140
|
|
150
141
|
|
@@ -160,6 +151,7 @@ class Jobs(Base):
|
|
160
151
|
user_prompts = []
|
161
152
|
system_prompts = []
|
162
153
|
scenario_indices = []
|
154
|
+
from edsl.results.Dataset import Dataset
|
163
155
|
|
164
156
|
for interview_index, interview in enumerate(interviews):
|
165
157
|
invigilators = list(interview._build_invigilators(debug=False))
|
@@ -182,7 +174,10 @@ class Jobs(Base):
|
|
182
174
|
|
183
175
|
@staticmethod
|
184
176
|
def _get_container_class(object):
|
185
|
-
from edsl import AgentList
|
177
|
+
from edsl.agents.AgentList import AgentList
|
178
|
+
from edsl.agents.Agent import Agent
|
179
|
+
from edsl.scenarios.Scenario import Scenario
|
180
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
186
181
|
|
187
182
|
if isinstance(object, Agent):
|
188
183
|
return AgentList
|
@@ -218,6 +213,10 @@ class Jobs(Base):
|
|
218
213
|
def _get_current_objects_of_this_type(
|
219
214
|
self, object: Union[Agent, Scenario, LanguageModel]
|
220
215
|
) -> tuple[list, str]:
|
216
|
+
from edsl.agents.Agent import Agent
|
217
|
+
from edsl.scenarios.Scenario import Scenario
|
218
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
219
|
+
|
221
220
|
"""Return the current objects of the same type as the first argument.
|
222
221
|
|
223
222
|
>>> from edsl.jobs import Jobs
|
@@ -246,6 +245,9 @@ class Jobs(Base):
|
|
246
245
|
@staticmethod
|
247
246
|
def _get_empty_container_object(object):
|
248
247
|
from edsl import AgentList
|
248
|
+
from edsl import Agent
|
249
|
+
from edsl import Scenario
|
250
|
+
from edsl import ScenarioList
|
249
251
|
|
250
252
|
if isinstance(object, Agent):
|
251
253
|
return AgentList([])
|
@@ -310,6 +312,10 @@ class Jobs(Base):
|
|
310
312
|
with us filling in defaults.
|
311
313
|
"""
|
312
314
|
# if no agents, models, or scenarios are set, set them to defaults
|
315
|
+
from edsl.agents.Agent import Agent
|
316
|
+
from edsl.language_models.registry import Model
|
317
|
+
from edsl.scenarios.Scenario import Scenario
|
318
|
+
|
313
319
|
self.agents = self.agents or [Agent()]
|
314
320
|
self.models = self.models or [Model()]
|
315
321
|
self.scenarios = self.scenarios or [Scenario()]
|
@@ -325,6 +331,7 @@ class Jobs(Base):
|
|
325
331
|
These buckets are used to track API calls and token usage.
|
326
332
|
|
327
333
|
>>> from edsl.jobs import Jobs
|
334
|
+
>>> from edsl import Model
|
328
335
|
>>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
|
329
336
|
>>> bc = j.create_bucket_collection()
|
330
337
|
>>> bc
|
@@ -368,6 +375,8 @@ class Jobs(Base):
|
|
368
375
|
"""Check if the parameters in the survey and scenarios are consistent.
|
369
376
|
|
370
377
|
>>> from edsl import QuestionFreeText
|
378
|
+
>>> from edsl import Survey
|
379
|
+
>>> from edsl import Scenario
|
371
380
|
>>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
|
372
381
|
>>> j = Jobs(survey = Survey(questions=[q]))
|
373
382
|
>>> with warnings.catch_warnings(record=True) as w:
|
@@ -464,8 +473,25 @@ class Jobs(Base):
|
|
464
473
|
status="queued",
|
465
474
|
)
|
466
475
|
self._output("Job sent!")
|
467
|
-
|
468
|
-
|
476
|
+
# Create mock results object to store job data
|
477
|
+
results = Results(
|
478
|
+
survey=Survey(),
|
479
|
+
data=[
|
480
|
+
Result(
|
481
|
+
agent=Agent.example(),
|
482
|
+
scenario=Scenario.example(),
|
483
|
+
model=Model(),
|
484
|
+
iteration=1,
|
485
|
+
answer={"info": "Remote job details"},
|
486
|
+
)
|
487
|
+
],
|
488
|
+
)
|
489
|
+
results.add_columns_from_dict([remote_job_data])
|
490
|
+
if self.verbose:
|
491
|
+
results.select(["info", "uuid", "status", "version"]).print(
|
492
|
+
format="rich"
|
493
|
+
)
|
494
|
+
return results
|
469
495
|
else:
|
470
496
|
if check_api_keys:
|
471
497
|
for model in self.models + [Model()]:
|
@@ -477,8 +503,12 @@ class Jobs(Base):
|
|
477
503
|
|
478
504
|
# handle cache
|
479
505
|
if cache is None:
|
506
|
+
from edsl.data.CacheHandler import CacheHandler
|
507
|
+
|
480
508
|
cache = CacheHandler().get_cache()
|
481
509
|
if cache is False:
|
510
|
+
from edsl.data.Cache import Cache
|
511
|
+
|
482
512
|
cache = Cache()
|
483
513
|
|
484
514
|
if not remote_cache:
|
@@ -630,6 +660,11 @@ class Jobs(Base):
|
|
630
660
|
@remove_edsl_version
|
631
661
|
def from_dict(cls, data: dict) -> Jobs:
|
632
662
|
"""Creates a Jobs instance from a dictionary."""
|
663
|
+
from edsl import Survey
|
664
|
+
from edsl.agents.Agent import Agent
|
665
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
666
|
+
from edsl.scenarios.Scenario import Scenario
|
667
|
+
|
633
668
|
return cls(
|
634
669
|
survey=Survey.from_dict(data["survey"]),
|
635
670
|
agents=[Agent.from_dict(agent) for agent in data["agents"]],
|
@@ -656,7 +691,8 @@ class Jobs(Base):
|
|
656
691
|
"""
|
657
692
|
import random
|
658
693
|
from edsl.questions import QuestionMultipleChoice
|
659
|
-
from edsl import Agent
|
694
|
+
from edsl.agents.Agent import Agent
|
695
|
+
from edsl.scenarios.Scenario import Scenario
|
660
696
|
|
661
697
|
# (status, question, period)
|
662
698
|
agent_answers = {
|
@@ -695,6 +731,8 @@ class Jobs(Base):
|
|
695
731
|
question_options=["Good", "Great", "OK", "Terrible"],
|
696
732
|
question_name="how_feeling_yesterday",
|
697
733
|
)
|
734
|
+
from edsl import Survey, ScenarioList
|
735
|
+
|
698
736
|
base_survey = Survey(questions=[q1, q2])
|
699
737
|
|
700
738
|
scenario_list = ScenarioList(
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from edsl.jobs.buckets.TokenBucket import TokenBucket
|
1
|
+
# from edsl.jobs.buckets.TokenBucket import TokenBucket
|
2
2
|
|
3
3
|
|
4
4
|
class ModelBuckets:
|
@@ -8,7 +8,7 @@ class ModelBuckets:
|
|
8
8
|
A request is one call to the service. The number of tokens required for a request depends on parameters.
|
9
9
|
"""
|
10
10
|
|
11
|
-
def __init__(self, requests_bucket: TokenBucket, tokens_bucket: TokenBucket):
|
11
|
+
def __init__(self, requests_bucket: "TokenBucket", tokens_bucket: "TokenBucket"):
|
12
12
|
"""Initialize the model buckets.
|
13
13
|
|
14
14
|
The requests bucket captures requests per unit of time.
|
@@ -28,6 +28,8 @@ class ModelBuckets:
|
|
28
28
|
@classmethod
|
29
29
|
def infinity_bucket(cls, model_name: str = "not_specified") -> "ModelBuckets":
|
30
30
|
"""Create a bucket with infinite capacity and refill rate."""
|
31
|
+
from edsl.jobs.buckets.TokenBucket import TokenBucket
|
32
|
+
|
31
33
|
return cls(
|
32
34
|
requests_bucket=TokenBucket(
|
33
35
|
bucket_name=model_name,
|
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1
1
|
from typing import Union, List, Any
|
2
2
|
import asyncio
|
3
3
|
import time
|
4
|
-
from collections import UserDict
|
5
|
-
from matplotlib import pyplot as plt
|
6
4
|
|
7
5
|
|
8
6
|
class TokenBucket:
|
@@ -114,6 +112,7 @@ class TokenBucket:
|
|
114
112
|
times, tokens = zip(*self.get_log())
|
115
113
|
start_time = times[0]
|
116
114
|
times = [t - start_time for t in times] # Normalize time to start from 0
|
115
|
+
from matplotlib import pyplot as plt
|
117
116
|
|
118
117
|
plt.figure(figsize=(10, 6))
|
119
118
|
plt.plot(times, tokens, label="Tokens Available")
|
@@ -6,15 +6,9 @@ import asyncio
|
|
6
6
|
import time
|
7
7
|
from typing import Any, Type, List, Generator, Optional
|
8
8
|
|
9
|
-
from edsl.agents import Agent
|
10
|
-
from edsl.language_models import LanguageModel
|
11
|
-
from edsl.scenarios import Scenario
|
12
|
-
from edsl.surveys import Survey
|
13
|
-
|
14
9
|
from edsl.jobs.Answers import Answers
|
15
10
|
from edsl.surveys.base import EndOfSurvey
|
16
11
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
17
|
-
|
18
12
|
from edsl.jobs.tasks.TaskCreators import TaskCreators
|
19
13
|
|
20
14
|
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
@@ -5,17 +5,19 @@ import asyncio
|
|
5
5
|
import time
|
6
6
|
import traceback
|
7
7
|
from typing import Generator, Union
|
8
|
+
|
8
9
|
from edsl import CONFIG
|
9
10
|
from edsl.exceptions import InterviewTimeoutError
|
10
|
-
|
11
|
-
from edsl.questions.QuestionBase import QuestionBase
|
11
|
+
|
12
|
+
# from edsl.questions.QuestionBase import QuestionBase
|
12
13
|
from edsl.surveys.base import EndOfSurvey
|
13
14
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
14
15
|
from edsl.jobs.interviews.interview_exception_tracking import InterviewExceptionEntry
|
15
16
|
from edsl.jobs.interviews.retry_management import retry_strategy
|
16
17
|
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
17
18
|
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
18
|
-
|
19
|
+
|
20
|
+
# from edsl.agents.InvigilatorBase import InvigilatorBase
|
19
21
|
|
20
22
|
TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
|
21
23
|
|
@@ -150,15 +152,17 @@ class InterviewTaskBuildingMixin:
|
|
150
152
|
async def _answer_question_and_record_task(
|
151
153
|
self,
|
152
154
|
*,
|
153
|
-
question: QuestionBase,
|
155
|
+
question: "QuestionBase",
|
154
156
|
debug: bool,
|
155
157
|
task=None,
|
156
|
-
) -> AgentResponseDict:
|
158
|
+
) -> "AgentResponseDict":
|
157
159
|
"""Answer a question and records the task.
|
158
160
|
|
159
161
|
This in turn calls the the passed-in agent's async_answer_question method, which returns a response dictionary.
|
160
162
|
Note that is updates answers dictionary with the response.
|
161
163
|
"""
|
164
|
+
from edsl.data_transfer_models import AgentResponseDict
|
165
|
+
|
162
166
|
try:
|
163
167
|
invigilator = self._get_invigilator(question, debug=debug)
|
164
168
|
|
@@ -1,29 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import time
|
3
3
|
import asyncio
|
4
|
-
import
|
4
|
+
import time
|
5
5
|
from contextlib import contextmanager
|
6
6
|
|
7
7
|
from typing import Coroutine, List, AsyncGenerator, Optional, Union
|
8
8
|
|
9
|
-
from rich.live import Live
|
10
|
-
from rich.console import Console
|
11
|
-
|
12
9
|
from edsl import shared_globals
|
13
|
-
from edsl.results import Results, Result
|
14
|
-
|
15
10
|
from edsl.jobs.interviews.Interview import Interview
|
16
|
-
from edsl.utilities.decorators import jupyter_nb_handler
|
17
|
-
|
18
|
-
# from edsl.jobs.Jobs import Jobs
|
19
11
|
from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
|
20
|
-
from edsl.language_models import LanguageModel
|
21
|
-
from edsl.data.Cache import Cache
|
22
|
-
|
23
12
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
24
13
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
25
|
-
|
26
|
-
import time
|
14
|
+
from edsl.utilities.decorators import jupyter_nb_handler
|
27
15
|
|
28
16
|
|
29
17
|
class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
@@ -42,13 +30,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
42
30
|
|
43
31
|
async def run_async_generator(
|
44
32
|
self,
|
45
|
-
cache: Cache,
|
33
|
+
cache: "Cache",
|
46
34
|
n: int = 1,
|
47
35
|
debug: bool = False,
|
48
36
|
stop_on_exception: bool = False,
|
49
37
|
sidecar_model: "LanguageModel" = None,
|
50
38
|
total_interviews: Optional[List["Interview"]] = None,
|
51
|
-
) -> AsyncGenerator[Result, None]:
|
39
|
+
) -> AsyncGenerator["Result", None]:
|
52
40
|
"""Creates the tasks, runs them asynchronously, and returns the results as a Results object.
|
53
41
|
|
54
42
|
Completed tasks are yielded as they are completed.
|
@@ -169,6 +157,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
169
157
|
question_name + "_raw_model_response"
|
170
158
|
] = result["raw_model_response"]
|
171
159
|
|
160
|
+
from edsl.results.Result import Result
|
161
|
+
|
172
162
|
result = Result(
|
173
163
|
agent=interview.agent,
|
174
164
|
scenario=interview.scenario,
|
@@ -197,6 +187,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
197
187
|
print_exceptions: bool = True,
|
198
188
|
) -> "Coroutine":
|
199
189
|
"""Runs a collection of interviews, handling both async and sync contexts."""
|
190
|
+
from rich.console import Console
|
191
|
+
|
200
192
|
console = Console()
|
201
193
|
self.results = []
|
202
194
|
self.start_time = time.monotonic()
|
@@ -204,6 +196,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
204
196
|
self.cache = cache
|
205
197
|
self.sidecar_model = sidecar_model
|
206
198
|
|
199
|
+
from edsl.results.Results import Results
|
200
|
+
|
207
201
|
if not progress_bar:
|
208
202
|
# print("Running without progress bar")
|
209
203
|
with cache as c:
|
@@ -225,6 +219,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
225
219
|
results = Results(survey=self.jobs.survey, data=self.results)
|
226
220
|
else:
|
227
221
|
# print("Running with progress bar")
|
222
|
+
from rich.live import Live
|
223
|
+
from rich.console import Console
|
228
224
|
|
229
225
|
def generate_table():
|
230
226
|
return self.status_table(self.results, self.elapsed_time)
|
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -1,8 +1,5 @@
|
|
1
1
|
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
2
|
-
from matplotlib import pyplot as plt
|
3
2
|
from typing import List, Optional
|
4
|
-
|
5
|
-
import matplotlib.pyplot as plt
|
6
3
|
from io import BytesIO
|
7
4
|
import base64
|
8
5
|
|
@@ -75,6 +72,8 @@ class TaskHistory:
|
|
75
72
|
|
76
73
|
def plot_completion_times(self):
|
77
74
|
"""Plot the completion times for each task."""
|
75
|
+
import matplotlib.pyplot as plt
|
76
|
+
|
78
77
|
updates = self.get_updates()
|
79
78
|
|
80
79
|
elapsed = [update.max_time - update.min_time for update in updates]
|
@@ -126,6 +125,8 @@ class TaskHistory:
|
|
126
125
|
rows = int(len(TaskStatus) ** 0.5) + 1
|
127
126
|
cols = (len(TaskStatus) + rows - 1) // rows # Ensure all plots fit
|
128
127
|
|
128
|
+
import matplotlib.pyplot as plt
|
129
|
+
|
129
130
|
fig, axes = plt.subplots(rows, cols, figsize=(15, 10))
|
130
131
|
axes = axes.flatten() # Flatten in case of a single row/column
|
131
132
|
|