edsl 0.1.29__py3-none-any.whl → 0.1.29.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +18 -18
- edsl/__init__.py +24 -24
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +41 -77
- edsl/agents/AgentList.py +6 -35
- edsl/agents/Invigilator.py +1 -19
- edsl/agents/InvigilatorBase.py +10 -15
- edsl/agents/PromptConstructionMixin.py +100 -342
- edsl/agents/descriptors.py +1 -2
- edsl/config.py +1 -2
- edsl/conjure/InputData.py +8 -39
- edsl/coop/coop.py +150 -187
- edsl/coop/utils.py +75 -43
- edsl/data/Cache.py +5 -19
- edsl/data/SQLiteDict.py +3 -11
- edsl/jobs/Answers.py +1 -15
- edsl/jobs/Jobs.py +46 -90
- edsl/jobs/buckets/ModelBuckets.py +2 -4
- edsl/jobs/buckets/TokenBucket.py +2 -1
- edsl/jobs/interviews/Interview.py +9 -3
- edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +10 -15
- edsl/jobs/runners/JobsRunnerAsyncio.py +25 -21
- edsl/jobs/tasks/TaskHistory.py +3 -4
- edsl/language_models/LanguageModel.py +11 -5
- edsl/language_models/ModelList.py +3 -3
- edsl/language_models/repair.py +7 -8
- edsl/notebooks/Notebook.py +3 -40
- edsl/prompts/Prompt.py +19 -31
- edsl/questions/QuestionBase.py +13 -38
- edsl/questions/QuestionBudget.py +6 -5
- edsl/questions/QuestionCheckBox.py +3 -7
- edsl/questions/QuestionExtract.py +3 -5
- edsl/questions/QuestionFreeText.py +3 -3
- edsl/questions/QuestionFunctional.py +3 -0
- edsl/questions/QuestionList.py +4 -3
- edsl/questions/QuestionMultipleChoice.py +8 -16
- edsl/questions/QuestionNumerical.py +3 -4
- edsl/questions/QuestionRank.py +3 -5
- edsl/questions/__init__.py +3 -4
- edsl/questions/descriptors.py +2 -4
- edsl/questions/question_registry.py +31 -20
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +0 -31
- edsl/results/Result.py +74 -22
- edsl/results/Results.py +47 -97
- edsl/results/ResultsDBMixin.py +3 -7
- edsl/results/ResultsExportMixin.py +537 -22
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +5 -5
- edsl/scenarios/Scenario.py +6 -5
- edsl/scenarios/ScenarioList.py +11 -34
- edsl/scenarios/ScenarioListPdfMixin.py +1 -2
- edsl/scenarios/__init__.py +0 -1
- edsl/study/ObjectEntry.py +13 -89
- edsl/study/ProofOfWork.py +2 -5
- edsl/study/SnapShot.py +8 -4
- edsl/study/Study.py +14 -21
- edsl/study/__init__.py +0 -2
- edsl/surveys/MemoryPlan.py +4 -11
- edsl/surveys/Survey.py +7 -46
- edsl/surveys/SurveyExportMixin.py +2 -4
- edsl/surveys/SurveyFlowVisualizationMixin.py +4 -6
- edsl/tools/plotting.py +2 -4
- edsl/utilities/__init__.py +21 -21
- edsl/utilities/interface.py +45 -66
- edsl/utilities/utilities.py +13 -11
- {edsl-0.1.29.dist-info → edsl-0.1.29.dev1.dist-info}/METADATA +10 -11
- {edsl-0.1.29.dist-info → edsl-0.1.29.dev1.dist-info}/RECORD +72 -75
- edsl-0.1.29.dev1.dist-info/entry_points.txt +3 -0
- edsl/base/Base.py +0 -289
- edsl/results/DatasetExportMixin.py +0 -493
- edsl/scenarios/FileStore.py +0 -140
- edsl/scenarios/ScenarioListExportMixin.py +0 -32
- {edsl-0.1.29.dist-info → edsl-0.1.29.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.29.dist-info → edsl-0.1.29.dev1.dist-info}/WHEEL +0 -0
edsl/data/Cache.py
CHANGED
@@ -7,13 +7,13 @@ import json
|
|
7
7
|
import os
|
8
8
|
import warnings
|
9
9
|
from typing import Optional, Union
|
10
|
-
|
10
|
+
|
11
11
|
from edsl.config import CONFIG
|
12
12
|
from edsl.data.CacheEntry import CacheEntry
|
13
|
-
|
14
|
-
# from edsl.data.SQLiteDict import SQLiteDict
|
13
|
+
from edsl.data.SQLiteDict import SQLiteDict
|
15
14
|
from edsl.Base import Base
|
16
15
|
from edsl.utilities.utilities import dict_hash
|
16
|
+
|
17
17
|
from edsl.utilities.decorators import (
|
18
18
|
add_edsl_version,
|
19
19
|
remove_edsl_version,
|
@@ -38,7 +38,7 @@ class Cache(Base):
|
|
38
38
|
self,
|
39
39
|
*,
|
40
40
|
filename: Optional[str] = None,
|
41
|
-
data: Optional[Union[
|
41
|
+
data: Optional[Union[SQLiteDict, dict]] = None,
|
42
42
|
immediate_write: bool = True,
|
43
43
|
method=None,
|
44
44
|
):
|
@@ -104,8 +104,6 @@ class Cache(Base):
|
|
104
104
|
|
105
105
|
def _perform_checks(self):
|
106
106
|
"""Perform checks on the cache."""
|
107
|
-
from edsl.data.CacheEntry import CacheEntry
|
108
|
-
|
109
107
|
if any(not isinstance(value, CacheEntry) for value in self.data.values()):
|
110
108
|
raise Exception("Not all values are CacheEntry instances")
|
111
109
|
if self.method is not None:
|
@@ -140,8 +138,6 @@ class Cache(Base):
|
|
140
138
|
|
141
139
|
|
142
140
|
"""
|
143
|
-
from edsl.data.CacheEntry import CacheEntry
|
144
|
-
|
145
141
|
key = CacheEntry.gen_key(
|
146
142
|
model=model,
|
147
143
|
parameters=parameters,
|
@@ -175,7 +171,6 @@ class Cache(Base):
|
|
175
171
|
* If `immediate_write` is True , the key-value pair is added to `self.data`
|
176
172
|
* If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
|
177
173
|
"""
|
178
|
-
|
179
174
|
entry = CacheEntry(
|
180
175
|
model=model,
|
181
176
|
parameters=parameters,
|
@@ -193,14 +188,13 @@ class Cache(Base):
|
|
193
188
|
return key
|
194
189
|
|
195
190
|
def add_from_dict(
|
196
|
-
self, new_data: dict[str,
|
191
|
+
self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
|
197
192
|
) -> None:
|
198
193
|
"""
|
199
194
|
Add entries to the cache from a dictionary.
|
200
195
|
|
201
196
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
202
197
|
"""
|
203
|
-
|
204
198
|
for key, value in new_data.items():
|
205
199
|
if key in self.data:
|
206
200
|
if value != self.data[key]:
|
@@ -237,8 +231,6 @@ class Cache(Base):
|
|
237
231
|
|
238
232
|
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
239
233
|
"""
|
240
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
241
|
-
|
242
234
|
db = SQLiteDict(db_path)
|
243
235
|
new_data = {}
|
244
236
|
for key, value in db.items():
|
@@ -250,8 +242,6 @@ class Cache(Base):
|
|
250
242
|
"""
|
251
243
|
Construct a Cache from a SQLite database.
|
252
244
|
"""
|
253
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
254
|
-
|
255
245
|
return cls(data=SQLiteDict(db_path))
|
256
246
|
|
257
247
|
@classmethod
|
@@ -278,8 +268,6 @@ class Cache(Base):
|
|
278
268
|
* If `db_path` is provided, the cache will be stored in an SQLite database.
|
279
269
|
"""
|
280
270
|
# if a file doesn't exist at jsonfile, throw an error
|
281
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
282
|
-
|
283
271
|
if not os.path.exists(jsonlfile):
|
284
272
|
raise FileNotFoundError(f"File {jsonlfile} not found")
|
285
273
|
|
@@ -298,8 +286,6 @@ class Cache(Base):
|
|
298
286
|
"""
|
299
287
|
## TODO: Check to make sure not over-writing (?)
|
300
288
|
## Should be added to SQLiteDict constructor (?)
|
301
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
302
|
-
|
303
289
|
new_data = SQLiteDict(db_path)
|
304
290
|
for key, value in self.data.items():
|
305
291
|
new_data[key] = value
|
edsl/data/SQLiteDict.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import json
|
3
|
+
from sqlalchemy import create_engine
|
4
|
+
from sqlalchemy.exc import SQLAlchemyError
|
5
|
+
from sqlalchemy.orm import sessionmaker
|
3
6
|
from typing import Any, Generator, Optional, Union
|
4
|
-
|
5
7
|
from edsl.config import CONFIG
|
6
8
|
from edsl.data.CacheEntry import CacheEntry
|
7
9
|
from edsl.data.orm import Base, Data
|
@@ -23,16 +25,10 @@ class SQLiteDict:
|
|
23
25
|
>>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
|
24
26
|
|
25
27
|
"""
|
26
|
-
from sqlalchemy.exc import SQLAlchemyError
|
27
|
-
from sqlalchemy.orm import sessionmaker
|
28
|
-
from sqlalchemy import create_engine
|
29
|
-
|
30
28
|
self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
|
31
29
|
if not self.db_path.startswith("sqlite:///"):
|
32
30
|
self.db_path = f"sqlite:///{self.db_path}"
|
33
31
|
try:
|
34
|
-
from edsl.data.orm import Base, Data
|
35
|
-
|
36
32
|
self.engine = create_engine(self.db_path, echo=False, future=True)
|
37
33
|
Base.metadata.create_all(self.engine)
|
38
34
|
self.Session = sessionmaker(bind=self.engine)
|
@@ -59,8 +55,6 @@ class SQLiteDict:
|
|
59
55
|
if not isinstance(value, CacheEntry):
|
60
56
|
raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
|
61
57
|
with self.Session() as db:
|
62
|
-
from edsl.data.orm import Base, Data
|
63
|
-
|
64
58
|
db.merge(Data(key=key, value=json.dumps(value.to_dict())))
|
65
59
|
db.commit()
|
66
60
|
|
@@ -75,8 +69,6 @@ class SQLiteDict:
|
|
75
69
|
True
|
76
70
|
"""
|
77
71
|
with self.Session() as db:
|
78
|
-
from edsl.data.orm import Base, Data
|
79
|
-
|
80
72
|
value = db.query(Data).filter_by(key=key).first()
|
81
73
|
if not value:
|
82
74
|
raise KeyError(f"Key '{key}' not found.")
|
edsl/jobs/Answers.py
CHANGED
@@ -8,15 +8,7 @@ class Answers(UserDict):
|
|
8
8
|
"""Helper class to hold the answers to a survey."""
|
9
9
|
|
10
10
|
def add_answer(self, response, question) -> None:
|
11
|
-
"""Add a response to the answers dictionary.
|
12
|
-
|
13
|
-
>>> from edsl import QuestionFreeText
|
14
|
-
>>> q = QuestionFreeText.example()
|
15
|
-
>>> answers = Answers()
|
16
|
-
>>> answers.add_answer({"answer": "yes"}, q)
|
17
|
-
>>> answers[q.question_name]
|
18
|
-
'yes'
|
19
|
-
"""
|
11
|
+
"""Add a response to the answers dictionary."""
|
20
12
|
answer = response.get("answer")
|
21
13
|
comment = response.pop("comment", None)
|
22
14
|
# record the answer
|
@@ -50,9 +42,3 @@ class Answers(UserDict):
|
|
50
42
|
table.add_row(attr_name, repr(attr_value))
|
51
43
|
|
52
44
|
return table
|
53
|
-
|
54
|
-
|
55
|
-
if __name__ == "__main__":
|
56
|
-
import doctest
|
57
|
-
|
58
|
-
doctest.testmod()
|
edsl/jobs/Jobs.py
CHANGED
@@ -1,15 +1,30 @@
|
|
1
1
|
# """The Jobs class is a collection of agents, scenarios and models and one survey."""
|
2
2
|
from __future__ import annotations
|
3
|
+
import os
|
3
4
|
import warnings
|
4
5
|
from itertools import product
|
5
6
|
from typing import Optional, Union, Sequence, Generator
|
6
|
-
|
7
|
+
from edsl import Model
|
8
|
+
from edsl.agents import Agent
|
9
|
+
from edsl.agents.AgentList import AgentList
|
7
10
|
from edsl.Base import Base
|
11
|
+
from edsl.data.Cache import Cache
|
12
|
+
from edsl.data.CacheHandler import CacheHandler
|
13
|
+
from edsl.results.Dataset import Dataset
|
8
14
|
|
15
|
+
from edsl.exceptions.jobs import MissingRemoteInferenceError
|
9
16
|
from edsl.exceptions import MissingAPIKeyError
|
10
17
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
11
18
|
from edsl.jobs.interviews.Interview import Interview
|
19
|
+
from edsl.language_models import LanguageModel
|
20
|
+
from edsl.results import Results
|
21
|
+
from edsl.scenarios import Scenario
|
22
|
+
from edsl import ScenarioList
|
23
|
+
from edsl.surveys import Survey
|
12
24
|
from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
|
25
|
+
|
26
|
+
from edsl.language_models.ModelList import ModelList
|
27
|
+
|
13
28
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
14
29
|
|
15
30
|
|
@@ -22,10 +37,10 @@ class Jobs(Base):
|
|
22
37
|
|
23
38
|
def __init__(
|
24
39
|
self,
|
25
|
-
survey:
|
26
|
-
agents: Optional[list[
|
27
|
-
models: Optional[list[
|
28
|
-
scenarios: Optional[list[
|
40
|
+
survey: Survey,
|
41
|
+
agents: Optional[list[Agent]] = None,
|
42
|
+
models: Optional[list[LanguageModel]] = None,
|
43
|
+
scenarios: Optional[list[Scenario]] = None,
|
29
44
|
):
|
30
45
|
"""Initialize a Jobs instance.
|
31
46
|
|
@@ -35,8 +50,8 @@ class Jobs(Base):
|
|
35
50
|
:param scenarios: a list of scenarios
|
36
51
|
"""
|
37
52
|
self.survey = survey
|
38
|
-
self.agents:
|
39
|
-
self.scenarios:
|
53
|
+
self.agents: AgentList = agents
|
54
|
+
self.scenarios: ScenarioList = scenarios
|
40
55
|
self.models = models
|
41
56
|
|
42
57
|
self.__bucket_collection = None
|
@@ -47,8 +62,6 @@ class Jobs(Base):
|
|
47
62
|
|
48
63
|
@models.setter
|
49
64
|
def models(self, value):
|
50
|
-
from edsl import ModelList
|
51
|
-
|
52
65
|
if value:
|
53
66
|
if not isinstance(value, ModelList):
|
54
67
|
self._models = ModelList(value)
|
@@ -63,8 +76,6 @@ class Jobs(Base):
|
|
63
76
|
|
64
77
|
@agents.setter
|
65
78
|
def agents(self, value):
|
66
|
-
from edsl import AgentList
|
67
|
-
|
68
79
|
if value:
|
69
80
|
if not isinstance(value, AgentList):
|
70
81
|
self._agents = AgentList(value)
|
@@ -79,8 +90,6 @@ class Jobs(Base):
|
|
79
90
|
|
80
91
|
@scenarios.setter
|
81
92
|
def scenarios(self, value):
|
82
|
-
from edsl import ScenarioList
|
83
|
-
|
84
93
|
if value:
|
85
94
|
if not isinstance(value, ScenarioList):
|
86
95
|
self._scenarios = ScenarioList(value)
|
@@ -92,10 +101,10 @@ class Jobs(Base):
|
|
92
101
|
def by(
|
93
102
|
self,
|
94
103
|
*args: Union[
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
Sequence[Union[
|
104
|
+
Agent,
|
105
|
+
Scenario,
|
106
|
+
LanguageModel,
|
107
|
+
Sequence[Union[Agent, Scenario, LanguageModel]],
|
99
108
|
],
|
100
109
|
) -> Jobs:
|
101
110
|
"""
|
@@ -135,7 +144,7 @@ class Jobs(Base):
|
|
135
144
|
setattr(self, objects_key, new_objects) # update the job
|
136
145
|
return self
|
137
146
|
|
138
|
-
def prompts(self) ->
|
147
|
+
def prompts(self) -> Dataset:
|
139
148
|
"""Return a Dataset of prompts that will be used.
|
140
149
|
|
141
150
|
|
@@ -151,7 +160,6 @@ class Jobs(Base):
|
|
151
160
|
user_prompts = []
|
152
161
|
system_prompts = []
|
153
162
|
scenario_indices = []
|
154
|
-
from edsl.results.Dataset import Dataset
|
155
163
|
|
156
164
|
for interview_index, interview in enumerate(interviews):
|
157
165
|
invigilators = list(interview._build_invigilators(debug=False))
|
@@ -174,10 +182,7 @@ class Jobs(Base):
|
|
174
182
|
|
175
183
|
@staticmethod
|
176
184
|
def _get_container_class(object):
|
177
|
-
from edsl
|
178
|
-
from edsl.agents.Agent import Agent
|
179
|
-
from edsl.scenarios.Scenario import Scenario
|
180
|
-
from edsl.scenarios.ScenarioList import ScenarioList
|
185
|
+
from edsl import AgentList
|
181
186
|
|
182
187
|
if isinstance(object, Agent):
|
183
188
|
return AgentList
|
@@ -213,10 +218,6 @@ class Jobs(Base):
|
|
213
218
|
def _get_current_objects_of_this_type(
|
214
219
|
self, object: Union[Agent, Scenario, LanguageModel]
|
215
220
|
) -> tuple[list, str]:
|
216
|
-
from edsl.agents.Agent import Agent
|
217
|
-
from edsl.scenarios.Scenario import Scenario
|
218
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
219
|
-
|
220
221
|
"""Return the current objects of the same type as the first argument.
|
221
222
|
|
222
223
|
>>> from edsl.jobs import Jobs
|
@@ -245,9 +246,6 @@ class Jobs(Base):
|
|
245
246
|
@staticmethod
|
246
247
|
def _get_empty_container_object(object):
|
247
248
|
from edsl import AgentList
|
248
|
-
from edsl import Agent
|
249
|
-
from edsl import Scenario
|
250
|
-
from edsl import ScenarioList
|
251
249
|
|
252
250
|
if isinstance(object, Agent):
|
253
251
|
return AgentList([])
|
@@ -312,12 +310,12 @@ class Jobs(Base):
|
|
312
310
|
with us filling in defaults.
|
313
311
|
"""
|
314
312
|
# if no agents, models, or scenarios are set, set them to defaults
|
315
|
-
from edsl.agents.Agent import Agent
|
316
|
-
from edsl.language_models.registry import Model
|
317
|
-
from edsl.scenarios.Scenario import Scenario
|
318
|
-
|
319
313
|
self.agents = self.agents or [Agent()]
|
320
314
|
self.models = self.models or [Model()]
|
315
|
+
# if remote, set all the models to remote
|
316
|
+
if hasattr(self, "remote") and self.remote:
|
317
|
+
for model in self.models:
|
318
|
+
model.remote = True
|
321
319
|
self.scenarios = self.scenarios or [Scenario()]
|
322
320
|
for agent, scenario, model in product(self.agents, self.scenarios, self.models):
|
323
321
|
yield Interview(
|
@@ -331,7 +329,6 @@ class Jobs(Base):
|
|
331
329
|
These buckets are used to track API calls and token usage.
|
332
330
|
|
333
331
|
>>> from edsl.jobs import Jobs
|
334
|
-
>>> from edsl import Model
|
335
332
|
>>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
|
336
333
|
>>> bc = j.create_bucket_collection()
|
337
334
|
>>> bc
|
@@ -371,16 +368,14 @@ class Jobs(Base):
|
|
371
368
|
if self.verbose:
|
372
369
|
print(message)
|
373
370
|
|
374
|
-
def _check_parameters(self, strict=False, warn=
|
371
|
+
def _check_parameters(self, strict=False, warn = True) -> None:
|
375
372
|
"""Check if the parameters in the survey and scenarios are consistent.
|
376
373
|
|
377
374
|
>>> from edsl import QuestionFreeText
|
378
|
-
>>> from edsl import Survey
|
379
|
-
>>> from edsl import Scenario
|
380
375
|
>>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
|
381
376
|
>>> j = Jobs(survey = Survey(questions=[q]))
|
382
377
|
>>> with warnings.catch_warnings(record=True) as w:
|
383
|
-
... j._check_parameters(
|
378
|
+
... j._check_parameters()
|
384
379
|
... assert len(w) == 1
|
385
380
|
... assert issubclass(w[-1].category, UserWarning)
|
386
381
|
... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
|
@@ -418,13 +413,15 @@ class Jobs(Base):
|
|
418
413
|
progress_bar: bool = False,
|
419
414
|
stop_on_exception: bool = False,
|
420
415
|
cache: Union[Cache, bool] = None,
|
416
|
+
remote: bool = (
|
417
|
+
False if os.getenv("DEFAULT_RUN_MODE", "local") == "local" else True
|
418
|
+
),
|
421
419
|
check_api_keys: bool = False,
|
422
420
|
sidecar_model: Optional[LanguageModel] = None,
|
423
421
|
batch_mode: Optional[bool] = None,
|
424
422
|
verbose: bool = False,
|
425
423
|
print_exceptions=True,
|
426
424
|
remote_cache_description: Optional[str] = None,
|
427
|
-
remote_inference_description: Optional[str] = None,
|
428
425
|
) -> Results:
|
429
426
|
"""
|
430
427
|
Runs the Job: conducts Interviews and returns their results.
|
@@ -434,11 +431,11 @@ class Jobs(Base):
|
|
434
431
|
:param progress_bar: shows a progress bar
|
435
432
|
:param stop_on_exception: stops the job if an exception is raised
|
436
433
|
:param cache: a cache object to store results
|
434
|
+
:param remote: run the job remotely
|
437
435
|
:param check_api_keys: check if the API keys are valid
|
438
436
|
:param batch_mode: run the job in batch mode i.e., no expecation of interaction with the user
|
439
437
|
:param verbose: prints messages
|
440
438
|
:param remote_cache_description: specifies a description for this group of entries in the remote cache
|
441
|
-
:param remote_inference_description: specifies a description for the remote inference job
|
442
439
|
"""
|
443
440
|
from edsl.coop.coop import Coop
|
444
441
|
|
@@ -449,50 +446,21 @@ class Jobs(Base):
|
|
449
446
|
"Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
|
450
447
|
)
|
451
448
|
|
449
|
+
self.remote = remote
|
452
450
|
self.verbose = verbose
|
453
451
|
|
454
452
|
try:
|
455
453
|
coop = Coop()
|
456
|
-
|
457
|
-
remote_cache = user_edsl_settings["remote_caching"]
|
458
|
-
remote_inference = user_edsl_settings["remote_inference"]
|
454
|
+
remote_cache = coop.edsl_settings["remote_caching"]
|
459
455
|
except Exception:
|
460
456
|
remote_cache = False
|
461
|
-
remote_inference = False
|
462
457
|
|
463
|
-
if
|
464
|
-
|
465
|
-
if
|
466
|
-
|
467
|
-
"Remote caching activated. The remote cache will be used for this job."
|
468
|
-
)
|
458
|
+
if self.remote:
|
459
|
+
## TODO: This should be a coop check
|
460
|
+
if os.getenv("EXPECTED_PARROT_API_KEY", None) is None:
|
461
|
+
raise MissingRemoteInferenceError()
|
469
462
|
|
470
|
-
|
471
|
-
self,
|
472
|
-
description=remote_inference_description,
|
473
|
-
status="queued",
|
474
|
-
)
|
475
|
-
self._output("Job sent!")
|
476
|
-
# Create mock results object to store job data
|
477
|
-
results = Results(
|
478
|
-
survey=Survey(),
|
479
|
-
data=[
|
480
|
-
Result(
|
481
|
-
agent=Agent.example(),
|
482
|
-
scenario=Scenario.example(),
|
483
|
-
model=Model(),
|
484
|
-
iteration=1,
|
485
|
-
answer={"info": "Remote job details"},
|
486
|
-
)
|
487
|
-
],
|
488
|
-
)
|
489
|
-
results.add_columns_from_dict([remote_job_data])
|
490
|
-
if self.verbose:
|
491
|
-
results.select(["info", "uuid", "status", "version"]).print(
|
492
|
-
format="rich"
|
493
|
-
)
|
494
|
-
return results
|
495
|
-
else:
|
463
|
+
if not self.remote:
|
496
464
|
if check_api_keys:
|
497
465
|
for model in self.models + [Model()]:
|
498
466
|
if not model.has_valid_api_key():
|
@@ -503,12 +471,8 @@ class Jobs(Base):
|
|
503
471
|
|
504
472
|
# handle cache
|
505
473
|
if cache is None:
|
506
|
-
from edsl.data.CacheHandler import CacheHandler
|
507
|
-
|
508
474
|
cache = CacheHandler().get_cache()
|
509
475
|
if cache is False:
|
510
|
-
from edsl.data.Cache import Cache
|
511
|
-
|
512
476
|
cache = Cache()
|
513
477
|
|
514
478
|
if not remote_cache:
|
@@ -660,11 +624,6 @@ class Jobs(Base):
|
|
660
624
|
@remove_edsl_version
|
661
625
|
def from_dict(cls, data: dict) -> Jobs:
|
662
626
|
"""Creates a Jobs instance from a dictionary."""
|
663
|
-
from edsl import Survey
|
664
|
-
from edsl.agents.Agent import Agent
|
665
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
666
|
-
from edsl.scenarios.Scenario import Scenario
|
667
|
-
|
668
627
|
return cls(
|
669
628
|
survey=Survey.from_dict(data["survey"]),
|
670
629
|
agents=[Agent.from_dict(agent) for agent in data["agents"]],
|
@@ -691,8 +650,7 @@ class Jobs(Base):
|
|
691
650
|
"""
|
692
651
|
import random
|
693
652
|
from edsl.questions import QuestionMultipleChoice
|
694
|
-
from edsl
|
695
|
-
from edsl.scenarios.Scenario import Scenario
|
653
|
+
from edsl import Agent
|
696
654
|
|
697
655
|
# (status, question, period)
|
698
656
|
agent_answers = {
|
@@ -731,8 +689,6 @@ class Jobs(Base):
|
|
731
689
|
question_options=["Good", "Great", "OK", "Terrible"],
|
732
690
|
question_name="how_feeling_yesterday",
|
733
691
|
)
|
734
|
-
from edsl import Survey, ScenarioList
|
735
|
-
|
736
692
|
base_survey = Survey(questions=[q1, q2])
|
737
693
|
|
738
694
|
scenario_list = ScenarioList(
|
@@ -1,4 +1,4 @@
|
|
1
|
-
|
1
|
+
from edsl.jobs.buckets.TokenBucket import TokenBucket
|
2
2
|
|
3
3
|
|
4
4
|
class ModelBuckets:
|
@@ -8,7 +8,7 @@ class ModelBuckets:
|
|
8
8
|
A request is one call to the service. The number of tokens required for a request depends on parameters.
|
9
9
|
"""
|
10
10
|
|
11
|
-
def __init__(self, requests_bucket:
|
11
|
+
def __init__(self, requests_bucket: TokenBucket, tokens_bucket: TokenBucket):
|
12
12
|
"""Initialize the model buckets.
|
13
13
|
|
14
14
|
The requests bucket captures requests per unit of time.
|
@@ -28,8 +28,6 @@ class ModelBuckets:
|
|
28
28
|
@classmethod
|
29
29
|
def infinity_bucket(cls, model_name: str = "not_specified") -> "ModelBuckets":
|
30
30
|
"""Create a bucket with infinite capacity and refill rate."""
|
31
|
-
from edsl.jobs.buckets.TokenBucket import TokenBucket
|
32
|
-
|
33
31
|
return cls(
|
34
32
|
requests_bucket=TokenBucket(
|
35
33
|
bucket_name=model_name,
|
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
from typing import Union, List, Any
|
2
2
|
import asyncio
|
3
3
|
import time
|
4
|
+
from collections import UserDict
|
5
|
+
from matplotlib import pyplot as plt
|
4
6
|
|
5
7
|
|
6
8
|
class TokenBucket:
|
@@ -112,7 +114,6 @@ class TokenBucket:
|
|
112
114
|
times, tokens = zip(*self.get_log())
|
113
115
|
start_time = times[0]
|
114
116
|
times = [t - start_time for t in times] # Normalize time to start from 0
|
115
|
-
from matplotlib import pyplot as plt
|
116
117
|
|
117
118
|
plt.figure(figsize=(10, 6))
|
118
119
|
plt.plot(times, tokens, label="Tokens Available")
|
@@ -6,9 +6,15 @@ import asyncio
|
|
6
6
|
import time
|
7
7
|
from typing import Any, Type, List, Generator, Optional
|
8
8
|
|
9
|
+
from edsl.agents import Agent
|
10
|
+
from edsl.language_models import LanguageModel
|
11
|
+
from edsl.scenarios import Scenario
|
12
|
+
from edsl.surveys import Survey
|
13
|
+
|
9
14
|
from edsl.jobs.Answers import Answers
|
10
15
|
from edsl.surveys.base import EndOfSurvey
|
11
16
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
17
|
+
|
12
18
|
from edsl.jobs.tasks.TaskCreators import TaskCreators
|
13
19
|
|
14
20
|
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
@@ -54,9 +60,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
54
60
|
self.debug = debug
|
55
61
|
self.iteration = iteration
|
56
62
|
self.cache = cache
|
57
|
-
self.answers: dict[
|
58
|
-
|
59
|
-
) # will get filled in as interview progresses
|
63
|
+
self.answers: dict[
|
64
|
+
str, str
|
65
|
+
] = Answers() # will get filled in as interview progresses
|
60
66
|
self.sidecar_model = sidecar_model
|
61
67
|
|
62
68
|
# Trackers
|
@@ -17,9 +17,9 @@ class InterviewStatusMixin:
|
|
17
17
|
The keys are the question names; the values are the lists of status log changes for each task.
|
18
18
|
"""
|
19
19
|
for task_creator in self.task_creators.values():
|
20
|
-
self._task_status_log_dict[
|
21
|
-
task_creator.
|
22
|
-
|
20
|
+
self._task_status_log_dict[
|
21
|
+
task_creator.question.question_name
|
22
|
+
] = task_creator.status_log
|
23
23
|
return self._task_status_log_dict
|
24
24
|
|
25
25
|
@property
|
@@ -5,19 +5,17 @@ import asyncio
|
|
5
5
|
import time
|
6
6
|
import traceback
|
7
7
|
from typing import Generator, Union
|
8
|
-
|
9
8
|
from edsl import CONFIG
|
10
9
|
from edsl.exceptions import InterviewTimeoutError
|
11
|
-
|
12
|
-
|
10
|
+
from edsl.data_transfer_models import AgentResponseDict
|
11
|
+
from edsl.questions.QuestionBase import QuestionBase
|
13
12
|
from edsl.surveys.base import EndOfSurvey
|
14
13
|
from edsl.jobs.buckets.ModelBuckets import ModelBuckets
|
15
14
|
from edsl.jobs.interviews.interview_exception_tracking import InterviewExceptionEntry
|
16
15
|
from edsl.jobs.interviews.retry_management import retry_strategy
|
17
16
|
from edsl.jobs.tasks.task_status_enum import TaskStatus
|
18
17
|
from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
|
19
|
-
|
20
|
-
# from edsl.agents.InvigilatorBase import InvigilatorBase
|
18
|
+
from edsl.agents.InvigilatorBase import InvigilatorBase
|
21
19
|
|
22
20
|
TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
|
23
21
|
|
@@ -46,7 +44,6 @@ class InterviewTaskBuildingMixin:
|
|
46
44
|
scenario=self.scenario,
|
47
45
|
model=self.model,
|
48
46
|
debug=debug,
|
49
|
-
survey=self.survey,
|
50
47
|
memory_plan=self.survey.memory_plan,
|
51
48
|
current_answers=self.answers,
|
52
49
|
iteration=self.iteration,
|
@@ -152,17 +149,15 @@ class InterviewTaskBuildingMixin:
|
|
152
149
|
async def _answer_question_and_record_task(
|
153
150
|
self,
|
154
151
|
*,
|
155
|
-
question:
|
152
|
+
question: QuestionBase,
|
156
153
|
debug: bool,
|
157
154
|
task=None,
|
158
|
-
) ->
|
155
|
+
) -> AgentResponseDict:
|
159
156
|
"""Answer a question and records the task.
|
160
157
|
|
161
158
|
This in turn calls the the passed-in agent's async_answer_question method, which returns a response dictionary.
|
162
159
|
Note that is updates answers dictionary with the response.
|
163
160
|
"""
|
164
|
-
from edsl.data_transfer_models import AgentResponseDict
|
165
|
-
|
166
161
|
try:
|
167
162
|
invigilator = self._get_invigilator(question, debug=debug)
|
168
163
|
|
@@ -258,11 +253,11 @@ class InterviewTaskBuildingMixin:
|
|
258
253
|
"""
|
259
254
|
current_question_index: int = self.to_index[current_question.question_name]
|
260
255
|
|
261
|
-
next_question: Union[
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
256
|
+
next_question: Union[
|
257
|
+
int, EndOfSurvey
|
258
|
+
] = self.survey.rule_collection.next_question(
|
259
|
+
q_now=current_question_index,
|
260
|
+
answers=self.answers | self.scenario | self.agent["traits"],
|
266
261
|
)
|
267
262
|
|
268
263
|
next_question_index = next_question.next_q
|